Merge pull request #3 from guardrails-ai/host_setup
Host NSFWText validator
dtam authored Sep 5, 2024
2 parents 1734503 + 06f5233 commit adbb456
Showing 2 changed files with 220 additions and 53 deletions.
83 changes: 83 additions & 0 deletions app.py
@@ -0,0 +1,83 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import pipeline
import os
import torch

app = FastAPI()

device = os.environ.get("GUARDRAILS_DEVICE", "cpu")

if device == "cuda" and torch.cuda.is_available():
    torch_device = "cuda"
elif device == "cuda" and not torch.cuda.is_available():
    print("Warning: CUDA is not available. Falling back to CPU.")
    torch_device = "cpu"
else:
    torch_device = "cpu"

class InferenceData(BaseModel):
    name: str
    shape: List[int]
    data: List
    datatype: str

class InputRequest(BaseModel):
    inputs: List[InferenceData]

class OutputResponse(BaseModel):
    modelname: str
    modelversion: str
    outputs: List[InferenceData]

@app.get("/")
async def hello_world():
    return "nsfw_text"

@app.post("/validate", response_model=OutputResponse)
async def check_nsfw(input_request: InputRequest):
    # Initialize both inputs so a missing field is caught below.
    text_vals = None
    threshold = None
    for inp in input_request.inputs:
        if inp.name == "text":
            text_vals = inp.data
        elif inp.name == "threshold":
            threshold = float(inp.data[0])

    if text_vals is None or threshold is None:
        raise HTTPException(status_code=400, detail="Invalid input format")

    return NSFWText.infer(text_vals, threshold)

class NSFWText:
    model_name = "michellejieli/NSFW_text_classifier"
    pipe = pipeline(
        "text-classification",
        model=model_name,
        device=torch_device,
    )

    @staticmethod
    def infer(text_vals, threshold) -> OutputResponse:
        outputs = []
        for idx, text in enumerate(text_vals):
            results = NSFWText.pipe(text)
            # Keep only labels classified as NSFW above the requested threshold.
            pred_labels = [
                result["label"]
                for result in results
                if result["label"] == "NSFW" and result["score"] > threshold
            ]
            outputs.append(
                InferenceData(
                    name=f"result{idx}",
                    datatype="BYTES",
                    shape=[len(pred_labels)],
                    data=[pred_labels],  # label list wrapped as a single data element
                )
            )

        output_data = OutputResponse(
            modelname=NSFWText.model_name, modelversion="1", outputs=outputs
        )

        return output_data

# Run the app with uvicorn
# Save this script as app.py and run with: uvicorn app:app --reload
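For reference, a minimal client sketch follows. It assumes the app is served locally by uvicorn on the default port 8000, and the example text and threshold are placeholder values; the payload mirrors the InputRequest schema above and the response is the OutputResponse returned by NSFWText.infer.

# Client sketch (assumptions: local server at http://localhost:8000, placeholder text and threshold).
import requests

payload = {
    "inputs": [
        {"name": "text", "shape": [1], "data": ["An example sentence to screen."], "datatype": "BYTES"},
        {"name": "threshold", "shape": [1], "data": [0.8], "datatype": "FP32"},
    ]
}

response = requests.post("http://localhost:8000/validate", json=payload)
response.raise_for_status()
print(response.json())  # OutputResponse: one "result{idx}" entry per input text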
190 changes: 137 additions & 53 deletions validator/main.py
@@ -1,18 +1,22 @@
from typing import Any, Callable, Dict, Optional
import difflib
import json
from typing import Any, Callable, Dict, List, Optional, Union, cast

import nltk
from guardrails.validator_base import (
    ErrorSpan,
    FailResult,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)

import nltk
from transformers import pipeline


@register_validator(name="guardrails/nsfw_text", data_type="string")
@register_validator(
    name="guardrails/nsfw_text", data_type="string", has_guardrails_endpoint=True
)
class NSFWText(Validator):
"""Validates that the generated text is safe for work (SFW).
Expand Down Expand Up @@ -45,42 +49,46 @@ def __init__(
        self,
        threshold: float = 0.8,
        validation_method: str = "sentence",
        on_fail: Optional[Callable] = None,
        device: Optional[Union[str, int]] = "cpu",
        model_name: Optional[str] = "michellejieli/NSFW_text_classifier",
        on_fail: Optional[Callable[..., Any]] = None,
        **kwargs,
    ):
        super().__init__(
            on_fail, threshold=threshold, validation_method=validation_method, **kwargs
            on_fail=on_fail,
            threshold=threshold,
            validation_method=validation_method,
            **kwargs,
        )
        self._threshold = float(threshold)
        if validation_method not in ["sentence", "full"]:
            raise ValueError("validation_method must be 'sentence' or 'full'.")
        self._validation_method = validation_method
        if self.use_local:
            self._model = pipeline(
                "text-classification",
                model=model_name,
            )

        # Define the model, pipeline and labels
        self._model_name = "michellejieli/NSFW_text_classifier"
        self._pipe = pipeline(
            "text-classification",
            model=self._model_name,
        )
        print("Pipeline setup successfully.")

    def is_nsfw(self, value: str) -> bool:
    def is_nsfw(self, value: str) -> List[str]:
        """Determines if the generated text is NSFW.
        Args:
            value (str): The generated text.
        Returns:
            bool: Whether the generated text is NSFW.
            list[str]: Labels predicted by the model
        """
        result = self._pipe(value)
        if not result:
            raise RuntimeError("Failed to get model prediction.")

        pred_label, confidence = result[0]["label"], result[0]["score"]  # type: ignore
        if pred_label == "NSFW" and confidence > self._threshold:
            return True
        return False
        pred_labels = []
        if value:
            results = self._model(value)
            if results:
                results = cast(List[List[Dict[str, Any]]], results)
                for result in results:
                    label, score = result["label"], result["score"]
                    if label == "NSFW" and score > self._threshold:
                        pred_labels.append(label)
        return pred_labels

    def validate_each_sentence(
        self, value: str, metadata: Dict[str, Any]
@@ -90,44 +98,120 @@ def validate_each_sentence(
        # Split the value into sentences using nltk sentence tokenizer.
        sentences = nltk.sent_tokenize(value)

        unsupported_sentences, supported_sentences = [], []
        for sentence in sentences:
            is_nsfw = self.is_nsfw(sentence)
            if is_nsfw:
                unsupported_sentences.append(sentence)
        clean_sentences, nsfw_sentences = [], []
        error_spans: List[ErrorSpan] = []
        char_index = 0

        sentence_predictions = self._inference(sentences)

        for idx, sentence in enumerate(sentences):
            pred_labels = sentence_predictions[idx]

            if pred_labels:
                nsfw_sentences.append(sentence)
                error_spans.append(
                    ErrorSpan(
                        start=char_index,
                        end=char_index + len(sentence),
                        reason=f"NSFW content detected: {', '.join(pred_labels)}",
                    )
                )
            else:
                supported_sentences.append(sentence)
                clean_sentences.append(sentence)
            char_index += len(sentence) + 1

        if nsfw_sentences:
            nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)

        if unsupported_sentences:
            unsupported_sentences = "- " + "\n- ".join(unsupported_sentences)
            return FailResult(
                metadata=metadata,
                error_message=(
                    f"The following sentences in your response "
                    "were found to be NSFW:\n"
                    f"\n{unsupported_sentences}"
                    f"\n{nsfw_sentences_text}"
                ),
                fix_value="\n".join(supported_sentences),
                fix_value="\n".join(clean_sentences),
                error_spans=error_spans,
            )
        return PassResult()

    def validate_full_text(
        self, value: str, metadata: Dict[str, Any]
    ) -> ValidationResult:
        """Validate that the entire generated text is SFW."""

        is_nsfw = self.is_nsfw(value)
        if is_nsfw:
            return FailResult(
                metadata=metadata,
                error_message="The generated text was found to be NSFW.",
            )
        return PassResult()
        return PassResult(metadata=metadata)

    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
        """Validation method of the NSFWText validator."""
        if self._validation_method == "sentence":
            return self.validate_each_sentence(value, metadata)
        if self._validation_method == "full":
            return self.validate_full_text(value, metadata)
        """Validation method for the NSFW text validator."""
        if not value:
            raise ValueError("Value cannot be empty.")

        return self.validate_each_sentence(value, metadata)

    def _inference_local(self, model_input: str | list) -> ValidationResult:
        """Local inference method for the NSFW text validator."""

        if isinstance(model_input, str):
            model_input = [model_input]
        predictions = []
        for text in model_input:
            pred_labels = self.is_nsfw(text)
            predictions.append(pred_labels)

        return predictions

    def _inference_remote(self, model_input: str | list) -> ValidationResult:
        """Remote inference method for the NSFW text validator."""

        if isinstance(model_input, str):
            model_input = [model_input]

        request_body = {
            "inputs": [
                {
                    "name": "text",
                    "shape": [len(model_input)],
                    "data": model_input,
                    "datatype": "BYTES"
                },
                {
                    "name": "threshold",
                    "shape": [1],
                    "data": [self._threshold],
                    "datatype": "FP32"
                }
            ]
        }
        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
        if not response or "outputs" not in response:
            raise ValueError("Invalid response from remote inference", response)

        data = [output["data"][0] for output in response["outputs"]]
        return data
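    # For orientation (not part of this diff), an illustrative response from the hosted
    # endpoint, mirroring the OutputResponse model in app.py; values are assumptions:
    # {
    #     "modelname": "michellejieli/NSFW_text_classifier",
    #     "modelversion": "1",
    #     "outputs": [
    #         {"name": "result0", "datatype": "BYTES", "shape": [1], "data": [["NSFW"]]}
    #     ]
    # }
    # The comprehension above then yields one label list (e.g. ["NSFW"]) per input text.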


    def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
        """Generate error spans to display in FailResult (if they exist). Error
        spans show the character-level range of text that has failed validation.
        Args:
            original (str): The input string
            fixed (str): The 'validated' output string
        Returns:
            List[ErrorSpan]: A list of ErrorSpans to represent validation failures
            over the character sequence.
        """
        differ = difflib.Differ()
        diffs = list(differ.compare(original, fixed))
        error_spans = []
        start = None
        for i, diff in enumerate(diffs):
            if diff.startswith("- "):
                if start is None:
                    start = i
            else:
                if start is not None:
                    error_spans.append(
                        ErrorSpan(
                            start=start,
                            end=i,
                            reason="NSFW content detected",
                        )
                    )
                    start = None
        return error_spans
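As a usage sketch (not part of this commit): once the validator has been installed from the Guardrails Hub, it is typically attached to a Guard; the threshold, on_fail policy, and example strings below are illustrative assumptions.

# Illustrative usage sketch; assumes `guardrails hub install hub://guardrails/nsfw_text` has been run.
from guardrails import Guard
from guardrails.hub import NSFWText

guard = Guard().use(
    NSFWText,
    threshold=0.8,                 # assumed example threshold
    validation_method="sentence",
    on_fail="exception",
)

guard.validate("Meditation is a great way to relax and find inner peace.")  # expected to pass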
