Merge pull request #15 from guardrails-ai/main
manifest and endpoint updates
zsimjee authored Aug 26, 2024
2 parents 98f081a + 895e77c commit baf839e
Showing 5 changed files with 168 additions and 14 deletions.
README.md (12 additions, 11 deletions)
@@ -1,17 +1,18 @@
-## Overview
+# Overview

| Developed by | Tryolabs |
| --- | --- |
| Date of development | Feb 15, 2024 |
| Validator type | Format |
-| Blog | - |
+| Blog | |
| License | Apache 2 |
| Input/Output | Output |

## Description

+### Intended Use
This validator checks whether a text is related to a topic.

-## Requirements
+### Requirements

* Dependencies:
    - guardrails-ai>=0.4.0
@@ -22,15 +23,15 @@ This validator checks whether a text is related to a topic.
* Foundation model access keys:
    - OPENAI_API_KEY

-# Installation
+## Installation

```bash
-guardrails hub install hub://tryolabs/restricttotopic
+$ guardrails hub install hub://tryolabs/restricttotopic
```

-# Usage Examples
+## Usage Examples

-## Validating string output via Python
+### Validating string output via Python

In this example, we apply the validator to a string output generated by an LLM.

@@ -59,7 +60,7 @@ The Beatles were a charismatic English pop-rock band of the 1960s.
""") # Validator fails
```

-## Validating JSON output via Python
+### Validating JSON output via Python

In this example, we apply the validator to a string field of a JSON output generated by an LLM.

@@ -123,6 +124,6 @@ Note:
2. When invoking `guard.parse(...)`, ensure you pass a `metadata` dictionary that includes the keys and values required by this validator. If `guard` is associated with multiple validators, combine all necessary metadata into a single dictionary.

**Parameters**
-- **`value`** *(Any):* The input value to validate.
-- **`metadata`** *(dict):* A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.
+- **`value`** *(Any)*: The input value to validate.
+- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.
</ul>
app.py (new file, 76 additions)
@@ -0,0 +1,76 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
from transformers import pipeline
import torch
import os

app = FastAPI()

# Use CUDA in production, CPU everywhere else
env = os.environ.get("env", "dev")
torch_device = "cuda" if env == "prod" else "cpu"

# Initialize the zero-shot classification pipeline
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=torch.device(torch_device),
    hypothesis_template="This example has to do with topic {}.",
    multi_label=True,
)


class InferenceData(BaseModel):
    name: str
    shape: List[int]
    data: Union[List[str], List[float]]
    datatype: str


class InputRequest(BaseModel):
    text: str
    candidate_topics: List[str]
    zero_shot_threshold: float = 0.5


class OutputResponse(BaseModel):
    modelname: str
    modelversion: str
    outputs: List[InferenceData]


@app.post("/validate", response_model=OutputResponse)
def restrict_to_topic(input_request: InputRequest):
    print("make request")

    text = input_request.text
    candidate_topics = input_request.candidate_topics
    zero_shot_threshold = input_request.zero_shot_threshold

    if text is None or candidate_topics is None:
        raise HTTPException(status_code=400, detail="Invalid input format")

    # Perform zero-shot classification and keep every candidate topic
    # whose score clears the threshold
    result = classifier(text, candidate_topics)
    print("result: ", result)
    topics = result["labels"]
    scores = result["scores"]
    found_topics = [
        topic for topic, score in zip(topics, scores) if score > zero_shot_threshold
    ]

    if not found_topics:
        found_topics = ["No valid topic found."]

    output_data = OutputResponse(
        modelname="RestrictToTopicModel",
        modelversion="1",
        outputs=[
            InferenceData(
                name="results",
                datatype="BYTES",
                shape=[len(found_topics)],
                data=found_topics,
            )
        ],
    )

    return output_data


# Run the app with uvicorn
# Save this script as app.py and run with: uvicorn app:app --reload
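With the server running, the endpoint above can be exercised with a small client. A minimal sketch, assuming a local instance on uvicorn's default port 8000 and the `requests` package; the payload fields mirror `InputRequest`:

```python
# Sketch: calling the /validate endpoint defined above. Assumes the app
# is running locally on uvicorn's default port 8000 and that the
# `requests` package is installed.
import requests

payload = {
    "text": "Let's talk about Iphones made by Apple",
    "candidate_topics": ["Apple Iphone", "Samsung Galaxy"],
    "zero_shot_threshold": 0.5,
}

response = requests.post("http://localhost:8000/validate", json=payload)
response.raise_for_status()

# The body mirrors OutputResponse: topics whose zero-shot score cleared
# the threshold come back in outputs[0].data.
print(response.json()["outputs"][0]["data"])
```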
inferless.yaml (new file, 39 additions)
@@ -0,0 +1,39 @@
# Inferless config file (version: 1.0.0)
version: 1.0.0

name: Restrict-To-Topic
import_source: GIT

# choose one of ONNX, TENSORFLOW, PYTORCH
source_framework_type: PYTORCH

configuration:
  # if you want to use a custom runtime, add the runtime id below.
  # you can find it by running `inferless runtime list`, or create one with
  # `inferless runtime upload` and update this file by running
  # `inferless runtime select --id <RUNTIME_ID>`.
  custom_runtime_id: 035c210b-3425-43d7-be4f-f341fae13842
  custom_runtime_version: '0'

  # if you want to use a custom volume, add the volume id and name below.
  # you can find them by running `inferless volume list` or create one with
  # `inferless volume create -n {VOLUME_NAME}`
  custom_volume_id: ''
  custom_volume_name: ''

  gpu_type: T4
  inference_time: '180'
  is_dedicated: false
  is_serverless: false
  max_replica: '1'
  min_replica: '0'
  scale_down_delay: '600'
  region: region-1
  vcpu: '1.5'
  ram: '7'
env:
  # Add your environment variables here
  # ENV: 'PROD'
secrets:
  # Add your secret ids here; you can find them by running `inferless secrets list`
  # - 65723205-ce21-4392-a10b-3tf00c58988c
io_schema: true
model_url: https://github.com/guardrails-ai/restricttotopic
provider: GITHUB
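One detail worth flagging in the `env:` block: `app.py` reads a lowercase `env` variable and compares it to lowercase `"prod"`, so the commented `ENV: 'PROD'` example would not, by itself, select the GPU path. A minimal sketch of the check as written:

```python
# Sketch of the device selection performed in app.py. Environment
# variable names are case-sensitive: setting ENV='PROD' (as the
# commented example suggests) leaves this lookup at its "dev" default,
# so the service stays on CPU unless the deployment sets env=prod.
import os

env = os.environ.get("env", "dev")
torch_device = "cuda" if env == "prod" else "cpu"
print(torch_device)
```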
input_schema.py (new file, 20 additions)
@@ -0,0 +1,20 @@
INPUT_SCHEMA = {
    "text": {
        "example": ["Let's talk about Iphones made by Apple"],
        "shape": [1],
        "datatype": "STRING",
        "required": True,
    },
    "candidate_topics": {
        "example": ["Apple Iphone", "Samsung Galaxy"],
        "shape": [-1],
        "datatype": "STRING",
        "required": True,
    },
    "zero_shot_threshold": {
        "example": [0.5],
        "shape": [1],
        "datatype": "FP32",
        "required": True,
    },
}
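Because every field declares an `example`, a sample request body can be assembled straight from the schema; a minimal sketch, assuming `INPUT_SCHEMA` is importable from this module:

```python
# Build a sample payload from the schema's own example values. The
# shapes declared above ([1] for scalars, [-1] for a variable-length
# list) describe how each field is batched.
from input_schema import INPUT_SCHEMA

sample_request = {key: spec["example"] for key, spec in INPUT_SCHEMA.items()}
print(sample_request)
# {'text': ["Let's talk about Iphones made by Apple"],
#  'candidate_topics': ['Apple Iphone', 'Samsung Galaxy'],
#  'zero_shot_threshold': [0.5]}
```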
validator/main.py (21 additions, 3 deletions)
@@ -5,6 +5,7 @@

from dotenv import load_dotenv
from guardrails.validator_base import (
+    ErrorSpan,
    FailResult,
    PassResult,
    ValidationResult,
@@ -332,14 +333,31 @@ def validate(
            elif topic in self._invalid_topics:
                invalid_topics_found.append(topic)

+        error_spans = []
+
        # Require at least one valid topic and no invalid topics
        if invalid_topics_found:
+            for topic in invalid_topics_found:
+                error_spans.append(
+                    ErrorSpan(
+                        start=value.find(topic),
+                        end=value.find(topic) + len(topic),
+                        reason=f"Text contains invalid topic: {topic}",
+                    )
+                )
            return FailResult(
-                error_message=f"Invalid topics found: {invalid_topics_found}"
+                error_message=f"Invalid topics found: {invalid_topics_found}",
+                error_spans=error_spans
            )
        if not valid_topics_found:
-            return FailResult(error_message="No valid topic was found.")
+            return FailResult(
+                error_message="No valid topic was found.",
+                error_spans=[ErrorSpan(
+                    start=0,
+                    end=len(value),
+                    reason="No valid topic was found."
+                )]
+            )
        return PassResult()

    def _inference_local(self, model_input: Any) -> Any:
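The practical effect of this change is that a failing validation now carries character-level spans alongside the error message. A minimal sketch of how a caller might inspect them, assuming the validator is installed from the hub as the README describes (the constructor arguments are illustrative; the `ErrorSpan` fields are the ones populated above):

```python
# Sketch: inspecting the spans attached to a FailResult. The topic
# lists here are illustrative; the ErrorSpan fields (start, end,
# reason) are exactly the ones set in the diff above.
from guardrails.hub import RestrictToTopic
from guardrails.validator_base import FailResult

validator = RestrictToTopic(valid_topics=["sports"], invalid_topics=["music"])
result = validator.validate(
    "The Beatles were a charismatic English pop-rock band of the 1960s.",
    metadata={},
)

if isinstance(result, FailResult):
    for span in result.error_spans:
        # Each span points at the offending substring of the input text.
        print(span.start, span.end, span.reason)
```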
