From a9748b41825c33587a5bcab943d8c80df58ead91 Mon Sep 17 00:00:00 2001
From: Aarav Navani <arav.navani@gmail.com>
Date: Thu, 5 Sep 2024 13:12:50 -0700
Subject: [PATCH] changes
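
A quick usage sketch for the new validator (illustrative, not part of the
diff). It assumes the nltk "punkt" tokenizer data is installed (run
nltk.download("punkt") once if it is not); use_local=True and on_fail="fix"
are example choices here, not defaults:

    from NSFWText import NSFWText
    from guardrails import Guard

    guard = Guard().use(
        NSFWText,
        threshold=0.8,
        validation_method="sentence",
        on_fail="fix",   # drop NSFW sentences instead of raising
        use_local=True,  # run michellejieli/NSFW_text_classifier in-process
    )

    outcome = guard.validate("Some generated text to screen.")
    print(outcome.validation_passed)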
---
 NSFWText.py       | 217 ++++++++++++++++++++++++++++++++++++++++++++++
 test.py           |  21 +++++
 validator/main.py |   6 +-
 3 files changed, 241 insertions(+), 3 deletions(-)
 create mode 100644 NSFWText.py
 create mode 100644 test.py

diff --git a/NSFWText.py b/NSFWText.py
new file mode 100644
index 0000000..41f58ee
--- /dev/null
+++ b/NSFWText.py
@@ -0,0 +1,217 @@
+import difflib
+import json
+from typing import Any, Callable, Dict, List, Optional, Union, cast
+
+import nltk
+from guardrails.validator_base import (
+    ErrorSpan,
+    FailResult,
+    PassResult,
+    ValidationResult,
+    Validator,
+    register_validator,
+)
+from transformers import pipeline
+
+
+@register_validator(
+    name="guardrails/nsfw_text", data_type="string", has_guardrails_endpoint=True
+)
+class NSFWText(Validator):
+    """Validates that the generated text is safe for work (SFW).
+
+    **Key Properties**
+    | Property                      | Description            |
+    | ----------------------------- | ---------------------- |
+    | Name for `format` attribute   | `guardrails/nsfw_text` |
+    | Supported data types          | `string`               |
+    | Programmatic fix              | N/A                    |
+
+    Args:
+        threshold: The confidence threshold over which model inferences are
+            considered. Must be a float between 0 and 1. Defaults to 0.8.
+        validation_method: Whether to validate at the sentence level or over
+            the full text. Must be one of `sentence` or `full`. Defaults to
+            `sentence`.
+
+    This validator uses the pre-trained classifier from HuggingFace -
+    `michellejieli/NSFW_text_classifier` - to check whether the generated
+    text is safe for work. If the model predicts the text to be "NSFW" with
+    a confidence higher than the threshold, the validator fails; otherwise,
+    it passes.
+
+    If validation_method is `sentence`, the validator removes the sentences
+    that are predicted to be NSFW and returns the remaining sentences. If
+    validation_method is `full`, the validator removes the entire text and
+    returns an empty string if the prediction is deemed NSFW.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.8,
+        validation_method: str = "sentence",
+        device: Optional[Union[str, int]] = "cpu",
+        model_name: Optional[str] = "michellejieli/NSFW_text_classifier",
+        on_fail: Optional[Callable[..., Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            on_fail=on_fail,
+            threshold=threshold,
+            validation_method=validation_method,
+            **kwargs,
+        )
+        self._threshold = float(threshold)
+        if validation_method not in ["sentence", "full"]:
+            raise ValueError("validation_method must be 'sentence' or 'full'.")
+        self._validation_method = validation_method
+        if self.use_local:
+            self._model = pipeline(
+                "text-classification",
+                model=model_name,
+                device=device,
+            )
+
+    def is_nsfw(self, value: str) -> List[str]:
+        """Determines if the generated text is NSFW.
+
+        Args:
+            value (str): The generated text.
+
+        Returns:
+            list[str]: Labels predicted by the model.
+        """
+        pred_labels = []
+        if value:
+            results = self._model(value)
+            if results:
+                # The pipeline returns one {"label": ..., "score": ...} dict
+                # per input string.
+                results = cast(List[Dict[str, Any]], results)
+                for result in results:
+                    label, score = result["label"], result["score"]
+                    if label == "NSFW" and score > self._threshold:
+                        pred_labels.append(label)
+        return pred_labels
+
+    def validate_each_sentence(
+        self, value: str, metadata: Dict[str, Any]
+    ) -> ValidationResult:
+        """Validate that each sentence in the generated text is SFW."""
+
+        # Split the value into sentences using the nltk sentence tokenizer.
+        sentences = nltk.sent_tokenize(value)
+
+        clean_sentences, nsfw_sentences = [], []
+        error_spans: List[ErrorSpan] = []
+        char_index = 0
+
+        sentence_predictions = self._inference(sentences)
+
+        for idx, sentence in enumerate(sentences):
+            pred_labels = sentence_predictions[idx]
+
+            if pred_labels:
+                nsfw_sentences.append(sentence)
+                error_spans.append(
+                    ErrorSpan(
+                        start=char_index,
+                        end=char_index + len(sentence),
+                        reason=f"NSFW content detected: {', '.join(pred_labels)}",
+                    )
+                )
+            else:
+                clean_sentences.append(sentence)
+            # Advance past this sentence plus the separating space so that
+            # later error spans stay aligned with the original text.
+            char_index += len(sentence) + 1
+
+        if nsfw_sentences:
+            nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)
+
+            return FailResult(
+                metadata=metadata,
+                error_message=(
+                    f"The following sentences in your response "
+                    "were found to be NSFW:\n"
+                    f"\n{nsfw_sentences_text}"
+                ),
+                fix_value="\n".join(clean_sentences),
+                error_spans=error_spans,
+            )
+        return PassResult(metadata=metadata)
+
+    def validate_full_text(
+        self, value: str, metadata: Dict[str, Any]
+    ) -> ValidationResult:
+        """Validate that the full generated text is SFW."""
+
+        pred_labels = self._inference([value])[0]
+        if pred_labels:
+            return FailResult(
+                metadata=metadata,
+                error_message="The generated text was found to be NSFW.",
+                fix_value="",
+            )
+        return PassResult(metadata=metadata)
+
+    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
+        """Validation method for the NSFW text validator."""
+        if not value:
+            raise ValueError("Value cannot be empty.")
+
+        if self._validation_method == "sentence":
+            return self.validate_each_sentence(value, metadata)
+        return self.validate_full_text(value, metadata)
+
+    def _inference_local(self, model_input: Union[str, list]) -> List[List[str]]:
+        """Local inference method for the NSFW text validator."""
+
+        if isinstance(model_input, str):
+            model_input = [model_input]
+        predictions = []
+        for text in model_input:
+            pred_labels = self.is_nsfw(text)
+            predictions.append(pred_labels)
+
+        return predictions
+
+    def _inference_remote(self, model_input: Union[str, list]) -> List[Any]:
+        """Remote inference method for the NSFW text validator."""
+
+        if isinstance(model_input, str):
+            model_input = [model_input]
+
+        request_body = {
+            "inputs": [
+                {
+                    "name": "text",
+                    "shape": [len(model_input)],
+                    "data": model_input,
+                    "datatype": "BYTES",
+                },
+                {
+                    "name": "threshold",
+                    "shape": [1],
+                    "data": [self._threshold],
+                    "datatype": "FP32",
+                },
+            ]
+        }
+        response = self._hub_inference_request(
+            json.dumps(request_body), self.validation_endpoint
+        )
+        if not response or "outputs" not in response:
+            raise ValueError("Invalid response from remote inference", response)
+
+        # One output per input text; data[0] holds that text's label list.
+        data = [output["data"][0] for output in response["outputs"]]
+        return data
+
+    def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
+        """Generate error spans to display in a FailResult (if they exist).
+
+        Error spans show the character-level range of text that has failed
+        validation.
+
+        Args:
+            original (str): The input string.
+            fixed (str): The 'validated' output string.
+
+        Returns:
+            List[ErrorSpan]: A list of ErrorSpans representing validation
+                failures over the character sequence.
+        """
+ """ + differ = difflib.Differ() + diffs = list(differ.compare(original, fixed)) + error_spans = [] + start = None + for i, diff in enumerate(diffs): + if diff.startswith("- "): + if start is None: + start = i + else: + if start is not None: + error_spans.append( + ErrorSpan( + start=start, + end=i, + reason="NSFW content detected", + ) + ) + start = None + return error_spans \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..1a10a20 --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +# Import Guard and Validator +from NSFWText import NSFWText +from guardrails import Guard + +# Setup Guard with the validator +guard = Guard().use( + NSFWText, threshold=0.8, validation_method="sentence", on_fail="exception", use_local=False, validation_endpoint="http://127.0.01:8000/validate" +) + +# Test passing response +guard.validate( + "Christopher Nolan's Tenet is a mind-bending action thriller that will keep you on the edge of your seat. The film is a must-watch for all Nolan fans." +) + +try: + # Test failing response + guard.validate( + "Climate Change is real and we need to do something about it. Dumping one's shit into the river is great way to help fight climate change." + ) +except Exception as e: + print(e) \ No newline at end of file diff --git a/validator/main.py b/validator/main.py index 221102f..41f58ee 100644 --- a/validator/main.py +++ b/validator/main.py @@ -128,9 +128,9 @@ def validate_each_sentence( error_message=( f"The following sentences in your response " "were found to be NSFW:\n" - f"\n{unsupported_sentences_text}" + f"\n{nsfw_sentences_text}" ), - fix_value="\n".join(supported_sentences), + fix_value="\n".join(clean_sentences), error_spans=error_spans, ) return PassResult(metadata=metadata) @@ -214,4 +214,4 @@ def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]: ) ) start = None - return error_spans + return error_spans \ No newline at end of file