generated from guardrails-ai/validator-template
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from guardrails-ai/main
manifest and endpoint updates
- Loading branch information
Showing
5 changed files
with
168 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from fastapi import FastAPI, HTTPException | ||
from pydantic import BaseModel | ||
from typing import List, Union | ||
from transformers import pipeline | ||
import torch | ||
import os | ||
|
||
app = FastAPI()

# Select compute device from the deployment environment: reads the lowercase
# "env" variable (defaulting to "dev") and uses the GPU only in prod.
# NOTE(review): the Inferless config example sets `ENV: 'PROD'` (uppercase
# name and value) — confirm which variable name/case the deployment actually
# exports, otherwise prod silently falls back to CPU.
env = os.environ.get("env", "dev")
torch_device = "cuda" if env == "prod" else "cpu"

# Initialize the zero-shot classification pipeline once at import time so the
# BART-MNLI model is loaded before the first request arrives.
# NOTE(review): `hypothesis_template` and `multi_label` are normally call-time
# parameters of the zero-shot pipeline; passing them to `pipeline(...)` here
# relies on transformers forwarding them as call defaults — confirm against
# the pinned transformers version.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=torch.device(torch_device),
    hypothesis_template="This example has to do with topic {}.",
    multi_label=True,
)
|
||
class InferenceData(BaseModel):
    """One named output tensor in the Triton-style inference response."""

    # Output tensor name (this service emits a single tensor named "results").
    name: str
    # Tensor shape; here a 1-D length, e.g. [len(data)].
    shape: List[int]
    # Flat payload: topic-label strings (or floats for numeric outputs).
    data: Union[List[str], List[float]]
    # Datatype tag, e.g. "BYTES" for strings or "FP32" for floats.
    datatype: str
|
||
class InputRequest(BaseModel):
    """Request body for POST /validate."""

    # Text to classify.
    text: str
    # Candidate topic labels for zero-shot classification.
    candidate_topics: List[str]
    # Minimum pipeline score for a topic to count as detected.
    zero_shot_threshold: float = 0.5
|
||
class OutputResponse(BaseModel):
    """Response body for POST /validate (Triton-style envelope)."""

    # Model identifier reported to the caller.
    modelname: str
    # Model version string reported to the caller.
    modelversion: str
    # One InferenceData entry per output tensor.
    outputs: List[InferenceData]
|
||
@app.post("/validate", response_model=OutputResponse)
def restrict_to_topic(input_request: InputRequest):
    """Classify `text` against `candidate_topics` via zero-shot inference.

    Args:
        input_request: validated request body (text, candidate topics, and
            the score threshold, defaulting to 0.5).

    Returns:
        OutputResponse with a single "results" tensor holding every candidate
        topic whose score exceeds `zero_shot_threshold`, or the sentinel
        string "No valid topic found." when none qualify.

    Raises:
        HTTPException: 400 when `text` or `candidate_topics` is empty.
    """
    text = input_request.text
    candidate_topics = input_request.candidate_topics
    zero_shot_threshold = input_request.zero_shot_threshold

    # Pydantic already guarantees the fields are present (they can never be
    # None), so the useful guard is against *empty* values, which the original
    # `is None` check let through to crash the classifier.
    if not text or not candidate_topics:
        raise HTTPException(status_code=400, detail="Invalid input format")

    # Perform zero-shot classification; `labels` and `scores` are parallel
    # lists keyed by candidate topic.
    result = classifier(text, candidate_topics)
    topics = result["labels"]
    scores = result["scores"]
    found_topics = [
        topic
        for topic, score in zip(topics, scores)
        if score > zero_shot_threshold
    ]

    # Keep the sentinel so callers always receive a non-empty tensor.
    if not found_topics:
        found_topics = ["No valid topic found."]

    return OutputResponse(
        modelname="RestrictToTopicModel",
        modelversion="1",
        outputs=[
            InferenceData(
                name="results",
                datatype="BYTES",
                shape=[len(found_topics)],
                data=found_topics,
            )
        ],
    )
|
||
# Run the app with uvicorn | ||
# Save this script as app.py and run with: uvicorn app:app --reload |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Inferless config file (version: 1.0.0) | ||
version: 1.0.0 | ||
|
||
name: Restrict-To-Topic | ||
import_source: GIT | ||
|
||
# you can choose the options between ONNX, TENSORFLOW, PYTORCH | ||
source_framework_type: PYTORCH | ||
|
||
configuration: | ||
# if you want to use a custom runtime, add the runtime id below. | ||
# you can find it by running `inferless runtime list` or create one with `inferless runtime upload` and update this file by running `inferless runtime select --id <RUNTIME_ID>`.
custom_runtime_id: 035c210b-3425-43d7-be4f-f341fae13842 | ||
custom_runtime_version: '0' | ||
|
||
# if you want to use a custom volume, add the volume id and name below, | ||
# you can find it by running `inferless volume list` or create one with `inferless volume create -n {VOLUME_NAME}` | ||
custom_volume_id: '' | ||
custom_volume_name: '' | ||
|
||
gpu_type: T4 | ||
inference_time: '180' | ||
is_dedicated: false | ||
is_serverless: false | ||
max_replica: '1' | ||
min_replica: '0' | ||
scale_down_delay: '600' | ||
region: region-1 | ||
vcpu: '1.5' | ||
ram: '7' | ||
env: | ||
# Add your environment variables here | ||
# ENV: 'PROD' | ||
secrets: | ||
# Add your secret ids here you can find it by running `inferless secrets list` | ||
# - 65723205-ce21-4392-a10b-3tf00c58988c | ||
io_schema: true | ||
model_url: https://github.com/guardrails-ai/restricttotopic | ||
provider: GITHUB |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Inferless IO schema for the /validate endpoint; field names mirror the
# service's InputRequest model (text, candidate_topics, zero_shot_threshold).
INPUT_SCHEMA = {
    "text": {
        "example": ["Let's talk about Iphones made by Apple"],
        "shape": [1],  # exactly one input string per request
        "datatype": "STRING",
        "required": True,
    },
    "candidate_topics": {
        "example": ["Apple Iphone", "Samsung Galaxy"],
        "shape": [-1],  # variable-length list of topic labels
        "datatype": "STRING",
        "required": True,
    },
    "zero_shot_threshold": {
        "example": [0.5],
        "shape": [1],
        "datatype": "FP32",
        # NOTE(review): marked required here, but the FastAPI model defaults
        # this field to 0.5 — confirm which contract is intended.
        "required": True,
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters