Bring back LoRA changes, with slight refactor
rmccorm4 committed May 24, 2024
1 parent cdd7e77 commit a4e6162
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions src/model.py
@@ -98,8 +98,31 @@ def auto_complete_config(auto_complete_model_config):
return auto_complete_model_config

def initialize(self, args):
self.args = args
self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"])
output_config = pb_utils.get_output_config_by_name(
self.model_config, "text_output"
)
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

# Prepare vLLM engine
self.init_engine()

# Counter to keep track of ongoing request counts
self.ongoing_request_count = 0

# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._loop_thread = threading.Thread(
target=self.engine_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._loop_thread.start()

def init_engine(self):
# Currently, Triton needs to use decoupled policy for asynchronously
# forwarding requests to vLLM engine, so assert it.
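For illustration, here is a minimal, self-contained sketch (not part of the changed file) of the pattern that initialize() sets up above: an asyncio event loop running in a background thread, with coroutines submitted to it from the request-handling thread. The helper names _run_loop and say are placeholders; the real engine_loop() and create_task() methods are outside this diff.

import asyncio
import threading

def _run_loop(loop: asyncio.AbstractEventLoop):
    # Presumed shape of engine_loop(): bind the loop to this thread and run it
    # until another thread stops it.
    asyncio.set_event_loop(loop)
    loop.run_forever()

loop = asyncio.new_event_loop()  # the diff uses get_event_loop(); new_event_loop() avoids deprecation warnings in newer Python
thread = threading.Thread(target=_run_loop, args=(loop,))
thread.start()

async def say(msg: str) -> str:
    # Stand-in for an async call into the vLLM engine.
    await asyncio.sleep(0.1)
    return msg

# Submit work to the background loop from another thread and wait for it.
future = asyncio.run_coroutine_threadsafe(say("hello"), loop)
print(future.result())

# Stop the loop and join the thread on shutdown.
loop.call_soon_threadsafe(loop.stop)
thread.join()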
@@ -118,17 +141,25 @@ def initialize(self, args):
engine_args_filepath
), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
with open(engine_args_filepath) as file:
self.vllm_engine_config = json.load(file)

# Validate that device and multi-processing settings are consistent with the model/engine configs.
self.validate_device_config()

# Check for LoRA config and set it up if enabled
self.setup_lora()

# Create an AsyncLLMEngine from the JSON engine config
self.llm_engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(**self.vllm_engine_config)
)

def setup_lora(self):
self.enable_lora = False

if (
"enable_lora" in vllm_engine_config.keys()
and vllm_engine_config["enable_lora"].lower() == "true"
"enable_lora" in self.vllm_engine_config.keys()
and self.vllm_engine_config["enable_lora"].lower() == "true"
):
# create Triton LoRA weights repository
multi_lora_args_filepath = os.path.join(
@@ -146,21 +177,31 @@ def initialize(self, args):
f"Triton backend cannot find {multi_lora_args_filepath}."
)
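For illustration, a sketch of the two JSON files that init_engine() and setup_lora() above read from the model directory. The filenames model.json and multi_lora.json, the base model name, and the adapter paths are assumptions (the _VLLM_ENGINE_ARGS_FILENAME and LoRA filename constants, and the rest of setup_lora(), are not shown in this diff); what is grounded in the diff is that enable_lora is compared as a lower-cased string, so it should be the string "true" rather than a JSON boolean.

import json
import os

model_dir = "model_repository/vllm_model/1"  # illustrative path only
os.makedirs(model_dir, exist_ok=True)

# Engine args consumed by init_engine(); setup_lora() checks
# vllm_engine_config["enable_lora"].lower() == "true", hence the string value.
engine_args = {
    "model": "meta-llama/Llama-2-7b-hf",  # assumed base model
    "enable_lora": "true",
    "max_lora_rank": 16,
    "tensor_parallel_size": 1,
}
with open(os.path.join(model_dir, "model.json"), "w") as f:  # assumed filename
    json.dump(engine_args, f, indent=2)

# LoRA-name -> adapter-path mapping loaded by setup_lora() as its repository.
lora_repository = {
    "sql-lora": "/weights/loras/sql-lora",
    "chat-lora": "/weights/loras/chat-lora",
}
with open(os.path.join(model_dir, "multi_lora.json"), "w") as f:  # assumed filename
    json.dump(lora_repository, f, indent=2)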

def validate_device_config(self):
triton_kind = self.args["model_instance_kind"]
triton_device_id = self.args["model_instance_device_id"]
triton_instance = f"{self.args['model_name']}_{triton_device_id}"

# Triton's current definition of KIND_GPU makes assumptions that
# models only use a single GPU. For multi-GPU models, the recommendation
# is to specify KIND_MODEL to acknowledge that the model will take control
# of the devices made available to it.
# NOTE: Consider other parameters that would indicate multi-GPU in the future.
tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
if tp_size > 1 and triton_kind == "GPU":
raise ValueError(
"KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models"
)

# If KIND_GPU is specified, isolate the selected GPU device so that multiple
# model instances do not all end up on the same default device (usually device 0).
if triton_kind == "GPU" and int(triton_device_id) >= 0:
self.logger.log_info(
f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
)
# NOTE: this only affects this process and its subprocesses, not other processes.
# vLLM doesn't currently seem to expose selecting a specific device in the APIs.
os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id

def create_task(self, coro):
"""
