Commit

changes
aaravnavani committed Sep 5, 2024
1 parent b8d78df commit a9748b4
Showing 3 changed files with 241 additions and 3 deletions.
217 changes: 217 additions & 0 deletions NSFWText.py
@@ -0,0 +1,217 @@
import difflib
import json
from typing import Any, Callable, Dict, List, Optional, Union, cast

import nltk
from guardrails.validator_base import (
    ErrorSpan,
    FailResult,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)
from transformers import pipeline


@register_validator(
    name="guardrails/nsfw_text", data_type="string", has_guardrails_endpoint=True
)
class NSFWText(Validator):
"""Validates that the generated text is safe for work (SFW).
**Key Properties**
| Property | Description |
| ----------------------------- | ---------------------- |
| Name for `format` attribute | `guardrails/nsfw_text` |
| Supported data types | `string` |
| Programmatic fix | N/A |
Args:
threshold: The confidence threshold over which model inferences are considered.
Must be a float between 0 and 1. Defaults to 0.8
validation_method: Whether to validate at the sentence level or
over the full text. Must be one of `sentence` or `full`.
Defaults to `sentence`
This validator uses the pre-trained multi-class model from HuggingFace -
`michellejieli/NSFW_text_classifier` to check whether the generated text is
safe for work. If the model predicts the text to be "NSFW" with a confidence
higher than the threshold, the validator fails. Otherwise, it passes.
If validation_method is `sentence`, the validator will remove the sentences
that are predicted to be NSFW and return the remaining sentences. If
validation_method is `full`, the validator will remove the entire text if
the prediction is deemed NSFW it will return an empty string.
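
    Example (mirrors the usage in ``test.py``; assumes ``guardrails`` is
    installed and the nltk ``punkt`` tokenizer data is available)::

        from guardrails import Guard

        guard = Guard().use(
            NSFWText, threshold=0.8, validation_method="sentence"
        )
        guard.validate("Some generated text.")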
"""

    def __init__(
        self,
        threshold: float = 0.8,
        validation_method: str = "sentence",
        device: Optional[Union[str, int]] = "cpu",
        model_name: Optional[str] = "michellejieli/NSFW_text_classifier",
        on_fail: Optional[Callable[..., Any]] = None,
        **kwargs,
    ):
        super().__init__(
            on_fail=on_fail,
            threshold=threshold,
            validation_method=validation_method,
            **kwargs,
        )
        self._threshold = float(threshold)
        if validation_method not in ["sentence", "full"]:
            raise ValueError("validation_method must be 'sentence' or 'full'.")
        self._validation_method = validation_method
        if self.use_local:
            self._model = pipeline(
                "text-classification",
                model=model_name,
                device=device,  # run the pipeline on the requested device
            )

    def is_nsfw(self, value: str) -> List[str]:
        """Determine whether the generated text is NSFW.

        Args:
            value (str): The generated text.

        Returns:
            List[str]: Labels predicted by the model.
        """
        pred_labels = []
        if value:
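            # A transformers text-classification pipeline called on a single
            # string is expected to return a list of {"label", "score"} dicts.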
            results = self._model(value)
            if results:
                results = cast(List[Dict[str, Any]], results)
                for result in results:
                    label, score = result["label"], result["score"]
                    if label == "NSFW" and score > self._threshold:
                        pred_labels.append(label)
        return pred_labels

    def validate_each_sentence(
        self, value: str, metadata: Dict[str, Any]
    ) -> ValidationResult:
        """Validate that each sentence in the generated text is SFW."""

        # Split the value into sentences using the nltk sentence tokenizer.
        sentences = nltk.sent_tokenize(value)

        clean_sentences, nsfw_sentences = [], []
        error_spans: List[ErrorSpan] = []
        char_index = 0

        sentence_predictions = self._inference(sentences)

        for idx, sentence in enumerate(sentences):
            pred_labels = sentence_predictions[idx]

            if pred_labels:
                nsfw_sentences.append(sentence)
                error_spans.append(
                    ErrorSpan(
                        start=char_index,
                        end=char_index + len(sentence),
                        reason=f"NSFW content detected: {', '.join(pred_labels)}",
                    )
                )
            else:
                clean_sentences.append(sentence)
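            # NOTE: this assumes one separator character between sentences, so
            # error spans may drift if the source text used different spacing.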
            char_index += len(sentence) + 1

        if nsfw_sentences:
            nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)

            return FailResult(
                metadata=metadata,
                error_message=(
                    "The following sentences in your response "
                    "were found to be NSFW:\n"
                    f"\n{nsfw_sentences_text}"
                ),
                fix_value="\n".join(clean_sentences),
                error_spans=error_spans,
            )
        return PassResult(metadata=metadata)

    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
        """Validation method for the NSFW text validator."""
        if not value:
            raise ValueError("Value cannot be empty.")
        if self._validation_method == "sentence":
            return self.validate_each_sentence(value, metadata)
        # validation_method == "full": classify the whole text in one pass.
        if self._inference([value])[0]:
            return FailResult(
                metadata=metadata,
                error_message="The generated text was found to be NSFW.",
                fix_value="",
            )
        return PassResult(metadata=metadata)

    def _inference_local(self, model_input: Union[str, list]) -> List[List[str]]:
        """Local inference method for the NSFW text validator."""

        if isinstance(model_input, str):
            model_input = [model_input]
        predictions = []
        for text in model_input:
            pred_labels = self.is_nsfw(text)
            predictions.append(pred_labels)

        return predictions

    def _inference_remote(self, model_input: Union[str, list]) -> List[Any]:
        """Remote inference method for the NSFW text validator."""

        if isinstance(model_input, str):
            model_input = [model_input]

        request_body = {
            "inputs": [
                {
                    "name": "text",
                    "shape": [len(model_input)],
                    "data": model_input,
                    "datatype": "BYTES",
                },
                {
                    "name": "threshold",
                    "shape": [1],
                    "data": [self._threshold],
                    "datatype": "FP32",
                },
            ]
        }
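        # The response is assumed to mirror the request format above:
        # {"outputs": [{"name": ..., "shape": ..., "data": [...]}, ...]},
        # with one entry per input, from which the first datum is unpacked.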
        response = self._hub_inference_request(
            json.dumps(request_body), self.validation_endpoint
        )
        if not response or "outputs" not in response:
            raise ValueError("Invalid response from remote inference", response)

        data = [output["data"][0] for output in response["outputs"]]
        return data


    def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
        """Generate error spans to display in a FailResult (if they exist).

        Error spans show the character-level range of text that has failed
        validation.

        Args:
            original (str): The input string.
            fixed (str): The 'validated' output string.

        Returns:
            List[ErrorSpan]: A list of ErrorSpans representing validation
                failures over the character sequence.
        """
        differ = difflib.Differ()
        diffs = list(differ.compare(original, fixed))
        error_spans = []
        start = None
        for i, diff in enumerate(diffs):
            if diff.startswith("- "):
                if start is None:
                    start = i
            else:
                if start is not None:
                    error_spans.append(
                        ErrorSpan(
                            start=start,
                            end=i,
                            reason="NSFW content detected",
                        )
                    )
                    start = None
        # Flush a span that runs to the end of the diff sequence.
        if start is not None:
            error_spans.append(
                ErrorSpan(start=start, end=len(diffs), reason="NSFW content detected")
            )
        return error_spans
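
# Illustrative sanity check for get_error_spans (a sketch, not part of the
# validator: `validator` stands for any constructed NSFWText instance, and the
# span indices are positions in the diff sequence, not raw character offsets):
#
#     spans = validator.get_error_spans("Fine text. Bad text.", "Fine text.")
#     print([(span.start, span.end, span.reason) for span in spans])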
21 changes: 21 additions & 0 deletions test.py
@@ -0,0 +1,21 @@
# Import Guard and Validator
from NSFWText import NSFWText
from guardrails import Guard

# Setup Guard with the validator
guard = Guard().use(
    NSFWText,
    threshold=0.8,
    validation_method="sentence",
    on_fail="exception",
    use_local=False,
    validation_endpoint="http://127.0.0.1:8000/validate",
)

# Test passing response
guard.validate(
    "Christopher Nolan's Tenet is a mind-bending action thriller that will keep "
    "you on the edge of your seat. The film is a must-watch for all Nolan fans."
)

try:
    # Test failing response
    guard.validate(
        "Climate Change is real and we need to do something about it. Dumping "
        "one's shit into the river is a great way to help fight climate change."
    )
except Exception as e:
    print(e)
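
# Local-inference variant (a sketch; assumes the environment can download the
# default HuggingFace model, which use_local=True loads when the validator is
# constructed, and that the nltk "punkt" tokenizer data is installed):
#
#     guard_local = Guard().use(
#         NSFWText, threshold=0.8, validation_method="sentence",
#         on_fail="exception", use_local=True,
#     )
#     guard_local.validate("Some generated text to check.")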
6 changes: 3 additions & 3 deletions validator/main.py
@@ -128,9 +128,9 @@ def validate_each_sentence(
                 error_message=(
                     f"The following sentences in your response "
                     "were found to be NSFW:\n"
-                    f"\n{unsupported_sentences_text}"
+                    f"\n{nsfw_sentences_text}"
                 ),
-                fix_value="\n".join(supported_sentences),
+                fix_value="\n".join(clean_sentences),
                 error_spans=error_spans,
             )
         return PassResult(metadata=metadata)
@@ -214,4 +214,4 @@ def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
                         )
                     )
                     start = None
-        return error_spans
+        return error_spans
