
Commit

removed ray serve code, added back original fastapi script for backwards compatibility with inference box
AlejandroEsquivel committed Sep 16, 2024
1 parent 38ad66e commit 20f343c
Showing 4 changed files with 71 additions and 115 deletions.
app.py (94 changes: 69 additions & 25 deletions)
@@ -1,42 +1,86 @@
+# Note: this will be deprecated in the future in favour of the shared "model-hosts" package which contains this code.
 from fastapi import FastAPI, HTTPException
 from contextlib import contextmanager, asynccontextmanager
 from pydantic import BaseModel
 from typing import List
-from app_inference_spec import InferenceSpec, InputRequest, OutputResponse
+import detoxify
+import torch
+import os
 
-app = FastAPI()
+# Initialize the detoxify model once
+env = os.environ.get("env", "dev")
+torch_device = "cuda" if env == "prod" else "cpu"
+model = detoxify.Detoxify("unbiased-small", device=torch.device(torch_device))
+
+####################################################################
+# FastAPI Setup & Endpoints
+####################################################################
+
+app = FastAPI()
+
+class InferenceData(BaseModel):
+    name: str
+    shape: List[int]
+    data: List
+    datatype: str
+
+
+class InputRequest(BaseModel):
+    inputs: List[InferenceData]
 
-inference_spec = InferenceSpec()
-
-# Load the model once before the app starts
-# Not using lifespan events as they don't support sync functions.
-@app.on_event("startup")
-def startup_event():
-    inference_spec.load()
+
+class OutputResponse(BaseModel):
+    modelname: str
+    modelversion: str
+    outputs: List[InferenceData]
 
 
 @app.post("/validate", response_model=OutputResponse)
-def validate(input_request: InputRequest):
-    args, kwargs = inference_spec.process_request(input_request)
-    return inference_spec.infer(*args, **kwargs)
+async def check_toxicity(input_request: InputRequest):
+    threshold = None
+    for inp in input_request.inputs:
+        if inp.name == "text":
+            text_vals = inp.data
+        elif inp.name == "threshold":
+            threshold = float(inp.data[0])
+
+    if text_vals is None or threshold is None:
+        raise HTTPException(status_code=400, detail="Invalid input format")
+
+    return ToxicLanguage.infer(text_vals, threshold)
+
+
+class ToxicLanguage:
+    model_name = "unbiased-small"
+    validation_method = "sentence"
+    device = torch.device(torch_device)
+    model = detoxify.Detoxify(model_name, device=device)
+    labels = [
+        "toxicity",
+        "severe_toxicity",
+        "obscene",
+        "threat",
+        "insult",
+        "identity_attack",
+        "sexual_explicit",
+    ]
 
-####################################################################
-# Sagemaker Specific Endpoints
-####################################################################
+    def infer(text_vals, threshold) -> OutputResponse:
+        outputs = []
+        for idx, text in enumerate(text_vals):
+            results = ToxicLanguage.model.predict(text)
+            pred_labels = [
+                label for label, score in results.items() if score > threshold
+            ]
+            outputs.append(
+                InferenceData(
+                    name=f"result{idx}",
+                    datatype="BYTES",
+                    shape=[len(pred_labels)],
+                    data=[pred_labels],
+                )
+            )
 
-@app.get("/ping")
-async def healtchcheck():
-    return {"status": "ok"}
+        output_data = OutputResponse(
+            modelname="unbiased-small", modelversion="1", outputs=outputs
+        )
 
-@app.post("/invocations", response_model=OutputResponse)
-def validate_sagemaker(input_request: InputRequest):
-    args, kwargs = inference_spec.process_request(input_request)
-    return inference_spec.infer(*args, **kwargs)
+        return output_data
 
 
 # Run the app with uvicorn
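For reference, the restored /validate endpoint can be exercised with a small client like the sketch below. Only the field names ("text", "threshold") and the response model come from the diff above; the host/port, the example strings, the 0.5 threshold, and the "FP32" datatype label are illustrative assumptions.

# Hypothetical client for the restored /validate endpoint; assumes the app
# is running locally on port 8000 and that the requests package is installed.
import requests

payload = {
    "inputs": [
        {
            "name": "text",
            "shape": [2],
            "data": ["you are a wonderful person", "you are a worthless idiot"],
            "datatype": "BYTES",
        },
        {
            "name": "threshold",  # read as float(data[0]) by check_toxicity
            "shape": [1],
            "data": [0.5],
            "datatype": "FP32",  # assumed label; the endpoint only inspects name/data
        },
    ]
}

resp = requests.post("http://localhost:8000/validate", json=payload)
resp.raise_for_status()
print(resp.json())
# Expected shape of the response, per OutputResponse:
# {"modelname": "unbiased-small", "modelversion": "1",
#  "outputs": [{"name": "result0", ...}, {"name": "result1", ...}]}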
app_inference_spec.py (10 changes: 2 additions & 8 deletions)
@@ -1,5 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from models_host.base_inference_spec import BaseInferenceSpec
 from typing import List
 import detoxify
 import torch
@@ -22,9 +23,8 @@ class OutputResponse(BaseModel):
     outputs: List[InferenceData]
 
 # Using same nomencalture as in Sagemaker classes
-class InferenceSpec:
+class InferenceSpec(BaseInferenceSpec):
     model = None
-    _instance = None
 
     model_name = "unbiased-small"
     validation_method = "sentence"
@@ -43,12 +43,6 @@ def torch_device(self):
         env = os.environ.get("env", "dev")
         torch_device = "cuda" if env == "prod" else "cpu"
         return torch_device
 
-    # Singleton pattern
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super(InferenceSpec, cls).__new__(cls)
-        return cls._instance
-
     def load(self):
         model_name = self.model_name
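For context, the deleted app.py endpoints earlier in this commit show the lifecycle a host drives against InferenceSpec, and with the singleton __new__ removed, callers now simply instantiate the class. A minimal sketch, assuming load, process_request, and infer keep the signatures used in the old endpoint code:

# Host-side lifecycle implied by the old app.py endpoints; the helper name
# run_validation is hypothetical, everything else mirrors the deleted code.
from app_inference_spec import InferenceSpec, InputRequest, OutputResponse

spec = InferenceSpec()  # plain instantiation now that the singleton is gone
spec.load()             # formerly called once in the @app.on_event("startup") hook

def run_validation(request: InputRequest) -> OutputResponse:
    # process_request unpacks the request into (args, kwargs) for infer
    args, kwargs = spec.process_request(request)
    return spec.infer(*args, **kwargs)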
app_ray_serve.py (44 changes: 0 additions & 44 deletions)

This file was deleted.

app_ray_serve_config.yaml (38 changes: 0 additions & 38 deletions)

This file was deleted.
