diff --git a/project/src/data_models.py b/project/src/data_models.py index 2061531..c0a2c98 100644 --- a/project/src/data_models.py +++ b/project/src/data_models.py @@ -5,6 +5,7 @@ # NOTE: If you're using a different model ensure that you add in the Results and ModelResponse # Pydantic models below! + class SimpleModelRequest(BaseModel): review: str @@ -20,7 +21,7 @@ def process_labels(cls, data: dict[int, float]) -> dict[str, float]: return {LABEL_CLASS_TO_NAME[key]: value for key, value in data.items()} -class SimpleModelRespone(BaseModel): +class SimpleModelResponse(BaseModel): label: SentimentLabel score: float diff --git a/project/src/server.py b/project/src/server.py index 63c37a7..36c9761 100644 --- a/project/src/server.py +++ b/project/src/server.py @@ -2,7 +2,7 @@ from ray import serve from ray.serve.handle import DeploymentHandle -from src.data_models import SimpleModelRequest, SimpleModelRespone, SimpleModelResults +from src.data_models import SimpleModelRequest, SimpleModelResponse, SimpleModelResults from src.model import Model app = FastAPI( @@ -28,7 +28,7 @@ def __init__(self, simple_model_handle: DeploymentHandle) -> None: async def predict(self, request: SimpleModelRequest): # TODO: Use the handle.predict which is a remote function # to get the result - return SimpleModelRespone.model_validate(result.model_dump()) + return SimpleModelResponse.model_validate(result.model_dump()) @serve.deployment(