Skip to content

Commit

Permalink
#195 Add Model Type Attribute (#200)
Browse files Browse the repository at this point in the history
* add ModelType to address #167

* fix syntax issue
  • Loading branch information
mvanniasingheTT authored Feb 25, 2025
1 parent 5b23ffd commit 9ba2fcc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
4 changes: 3 additions & 1 deletion app/api/docker_control/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
perform_reset,
)
from shared_config.model_config import model_implmentations
from shared_config.model_type_config import ModelTypes
from .serializers import DeploymentSerializer, StopSerializer
from shared_config.logger_config import get_logger

Expand Down Expand Up @@ -114,7 +115,8 @@ def post(self, request, *args, **kwargs):
weights_id = request.data.get("weights_id")
impl = model_implmentations[impl_id]
response = run_container(impl, weights_id)
run_agent_container(response["container_name"], response["port_bindings"], impl) # run agent container that maps to appropriate LLM container
if impl.model_type == ModelTypes.CHAT:
run_agent_container(response["container_name"], response["port_bindings"], impl) # run agent container that maps to appropriate LLM container
return Response(response, status=status.HTTP_201_CREATED)
else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
Expand Down
11 changes: 11 additions & 0 deletions app/api/shared_config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from shared_config.device_config import DeviceConfigurations
from shared_config.backend_config import backend_config
from shared_config.setup_config import SetupTypes
from shared_config.model_type_config import ModelTypes
from shared_config.logger_config import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -50,6 +51,7 @@ class ModelImpl:
docker_config: Dict[str, Any]
service_route: str
setup_type: SetupTypes
model_type: ModelTypes
hf_model_id: str = None
model_name: str = None # uses defaults based on hf_model_id
model_id: str = None # uses defaults based on hf_model_id
Expand Down Expand Up @@ -236,6 +238,7 @@ def base_docker_config():
service_port=7000,
service_route="/objdetection_v2",
setup_type=SetupTypes.NO_SETUP,
model_type=ModelTypes.OBJECT_DETECTION
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.1-70B-Instruct",
Expand All @@ -249,6 +252,7 @@ def base_docker_config():
service_port=7000,
service_route="/v1/chat/completions",
setup_type=SetupTypes.MAKE_VOLUMES,
model_type=ModelTypes.MOCK
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.1-70B-Instruct",
Expand All @@ -261,6 +265,7 @@ def base_docker_config():
service_route="/v1/chat/completions",
env_file=os.environ.get("VLLM_LLAMA31_ENV_FILE"),
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.2-1B-Instruct",
Expand All @@ -270,6 +275,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.2-3B-Instruct",
Expand All @@ -279,6 +285,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.1-8B-Instruct",
Expand All @@ -288,6 +295,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
Expand All @@ -297,6 +305,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.1-70B-Instruct",
Expand All @@ -306,6 +315,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
ModelImpl(
hf_model_id="meta-llama/Llama-3.3-70B-Instruct",
Expand All @@ -315,6 +325,7 @@ def base_docker_config():
docker_config=base_docker_config(),
service_route="/v1/chat/completions",
setup_type=SetupTypes.TT_INFERENCE_SERVER,
model_type=ModelTypes.CHAT
),
#! Add new model vLLM model implementations here
]
Expand Down
7 changes: 7 additions & 0 deletions app/api/shared_config/model_type_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum, unique

@unique
class ModelTypes(Enum):
    """Categories of model implementations served by the app.

    Each ``ModelImpl`` in ``shared_config.model_config`` carries one of
    these as its ``model_type``; callers (e.g. the docker_control views)
    branch on the member identity, so member names and values must stay
    stable. ``@unique`` guards against two members accidentally sharing a
    value, which would silently alias them into one category.
    """

    MOCK = "mock"  # mock deployment used in place of a real inference server
    CHAT = "chat"  # chat-completions LLMs; gates the extra agent container
    OBJECT_DETECTION = "object_detection"  # e.g. the /objdetection_v2 service
    IMAGE_GENERATION = "image_generation"  # reserved for image-generation models

0 comments on commit 9ba2fcc

Please sign in to comment.