Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated for SageMaker endpoints compatibility #9

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Next Next commit
fastapi setup
aaravnavani committed Aug 14, 2024
commit 59e9d47cb71291d025087d888b011a7b72d68722
102 changes: 73 additions & 29 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,78 @@
import json
import torch
import nltk
from typing import Any, Dict, List


from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
from transformers import pipeline
import torch

class InferlessPythonModel:
    """Zero-shot topic classifier following the Inferless runtime contract.

    Lifecycle: `initialize` loads the model once, `infer` serves requests,
    `finalize` releases resources.
    """

    # NOTE(review): a FastAPI instance as a class attribute looks out of place
    # for the Inferless contract — confirm it is actually used anywhere.
    app = FastAPI()

    def initialize(self):
        """Load the BART-large-MNLI zero-shot pipeline once at startup."""
        # Fall back to CPU when no GPU is available instead of hard-coding
        # "cuda", which crashes on CPU-only hosts.
        self._classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device="cuda" if torch.cuda.is_available() else "cpu",
            hypothesis_template="This example has to do with topic {}.",
            multi_label=True,
        )

    def infer(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return topics scoring above inputs["zero_shot_threshold"].

        Expects keys "text" (str), "candidate_topics" (list of str) and
        "zero_shot_threshold" (float). Returns {"results": [...]} with the
        matching topics, or a single sentinel entry when none match.
        """
        result = self._classifier(inputs["text"], inputs["candidate_topics"])
        threshold = inputs["zero_shot_threshold"]
        found_topics = [
            topic
            for topic, score in zip(result["labels"], result["scores"])
            if score > threshold
        ]
        if not found_topics:
            return {"results": ["No valid topic found."]}
        return {"results": found_topics}

    def finalize(self):
        """Release resources (nothing to clean up)."""
        pass
# Initialize the zero-shot classification pipeline once at module import so
# the model is loaded before the first request is served.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    # Prefer the GPU when present; fall back to CPU otherwise.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Template used to turn each candidate topic into an NLI hypothesis.
    hypothesis_template="This example has to do with topic {}.",
    # Score each topic independently rather than as a softmax over topics.
    multi_label=True,
)

class InferenceData(BaseModel):
    """One named tensor in a KServe/SageMaker-style inference payload."""

    # Logical name of the input/output (e.g. "text", "candidate_topics").
    name: str
    # Tensor shape; for flat lists this is [len(data)].
    shape: List[int]
    # Payload values — strings or floats depending on `datatype`.
    data: Union[List[str], List[float]]
    # Wire datatype tag (e.g. "BYTES", "FP32").
    datatype: str

class InputRequest(BaseModel):
    """Request body for /validate: a list of named inference inputs."""

    inputs: List[InferenceData]

class OutputResponse(BaseModel):
    """Response body for /validate, echoing model identity plus outputs."""

    modelname: str
    modelversion: str
    outputs: List[InferenceData]

@app.post("/validate", response_model=OutputResponse)
async def restrict_to_topic(input_request: InputRequest):
    """Zero-shot-classify `text` against `candidate_topics`.

    Reads inputs named "text", "candidate_topics" and optionally
    "zero_shot_threshold" (default 0.5) from the request. Returns the topics
    whose score exceeds the threshold, or a single sentinel entry when none do.

    Raises:
        HTTPException(400): when "text" or "candidate_topics" is missing/empty.
    """
    text = None
    candidate_topics = None
    zero_shot_threshold = 0.5  # default cut-off when the caller omits it

    for inp in input_request.inputs:
        # Guard each indexed access: an empty `data` list would otherwise
        # raise an unhandled IndexError (HTTP 500) instead of a clean 400.
        if inp.name == "text" and inp.data:
            text = inp.data[0]
        elif inp.name == "candidate_topics":
            candidate_topics = inp.data
        elif inp.name == "zero_shot_threshold" and inp.data:
            zero_shot_threshold = float(inp.data[0])

    # `not candidate_topics` also rejects an empty topic list, which would
    # crash inside the pipeline rather than fail validation.
    if text is None or not candidate_topics:
        raise HTTPException(status_code=400, detail="Invalid input format")

    # Perform zero-shot classification (multi_label: scores are independent).
    result = classifier(text, candidate_topics)
    found_topics = [
        topic
        for topic, score in zip(result["labels"], result["scores"])
        if score > zero_shot_threshold
    ]

    if not found_topics:
        found_topics = ["No valid topic found."]

    return OutputResponse(
        modelname="RestrictToTopicModel",
        modelversion="1",
        outputs=[
            InferenceData(
                name="results",
                datatype="BYTES",
                shape=[len(found_topics)],
                data=found_topics,
            )
        ],
    )

# Run the app with uvicorn
# Save this script as app.py and run with: uvicorn app:app --reload