From a9748b41825c33587a5bcab943d8c80df58ead91 Mon Sep 17 00:00:00 2001
From: Aarav Navani <arav.navani@gmail.com>
Date: Thu, 5 Sep 2024 13:12:50 -0700
Subject: [PATCH] changes

---
 NSFWText.py       | 217 ++++++++++++++++++++++++++++++++++++++++++++++
 test.py           |  21 +++++
 validator/main.py |   6 +-
 3 files changed, 241 insertions(+), 3 deletions(-)
 create mode 100644 NSFWText.py
 create mode 100644 test.py

diff --git a/NSFWText.py b/NSFWText.py
new file mode 100644
index 0000000..41f58ee
--- /dev/null
+++ b/NSFWText.py
@@ -0,0 +1,217 @@
+import difflib
+import json
+from typing import Any, Callable, Dict, List, Optional, Union, cast
+
+import nltk
+from guardrails.validator_base import (
+    ErrorSpan,
+    FailResult,
+    PassResult,
+    ValidationResult,
+    Validator,
+    register_validator,
+)
+from transformers import pipeline
+
+
+@register_validator(
+    name="guardrails/nsfw_text", data_type="string", has_guardrails_endpoint=True
+)
+class NSFWText(Validator):
+    """Validates that the generated text is safe for work (SFW).
+
+    **Key Properties**
+    | Property                      | Description            |
+    | ----------------------------- | ---------------------- |
+    | Name for `format` attribute   | `guardrails/nsfw_text` |
+    | Supported data types          | `string`               |
+    | Programmatic fix              | N/A                    |
+
+    Args:
+        threshold: The confidence threshold over which model inferences are considered.
+            Must be a float between 0 and 1. Defaults to 0.8
+        validation_method: Whether to validate at the sentence level or
+            over the full text. Must be one of `sentence` or `full`.
+            Defaults to `sentence`
+
+    This validator uses the pre-trained multi-class model from HuggingFace -
+    `michellejieli/NSFW_text_classifier` to check whether the generated text is
+    safe for work. If the model predicts the text to be "NSFW" with a confidence
+    higher than the threshold, the validator fails. Otherwise, it passes.
+
+    If validation_method is `sentence`, the validator will remove the sentences
+    that are predicted to be NSFW and return the remaining sentences. If
+    validation_method is `full`, the validator will remove the entire text if
+    the prediction is deemed NSFW it will return an empty string.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.8,
+        validation_method: str = "sentence",
+        device: Optional[Union[str, int]] = "cpu",
+        model_name: Optional[str] = "michellejieli/NSFW_text_classifier",
+        on_fail: Optional[Callable[..., Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            on_fail=on_fail,
+            threshold=threshold,
+            validation_method=validation_method,
+            **kwargs,
+        )
+        self._threshold = float(threshold)
+        if validation_method not in ["sentence", "full"]:
+            raise ValueError("validation_method must be 'sentence' or 'full'.")
+        self._validation_method = validation_method
+        if self.use_local:
+            self._model = pipeline(
+                "text-classification",
+                model=model_name,
+            )
+
+    def is_nsfw(self, value: str) -> List[str]:
+        """Determines if the generated text is NSFW.
+
+        Args:
+            value (str): The generated text.
+
+        Returns:
+            list[str]: Labels predicted by the model
+        """
+        pred_labels = []
+        if value:
+            results = self._model(value)
+            if results:
+                results = cast(List[List[Dict[str, Any]]], results)
+                for result in results:
+                    label, score = result["label"], result["score"]
+                    if label == "NSFW" and score > self._threshold:
+                        pred_labels.append(label)
+        return pred_labels
+
+    def validate_each_sentence(
+        self, value: str, metadata: Dict[str, Any]
+    ) -> ValidationResult:
+        """Validate that each sentence in the generated text is SFW."""
+
+        # Split the value into sentences using nltk sentence tokenizer.
+        sentences = nltk.sent_tokenize(value)
+
+        clean_sentences, nsfw_sentences = [], []
+        error_spans: List[ErrorSpan] = []
+        char_index = 0
+
+        sentence_predictions = self._inference(sentences)
+
+        for idx, sentence in enumerate(sentences):
+            pred_labels = sentence_predictions[idx]
+
+            if pred_labels:
+                nsfw_sentences.append(sentence)
+                error_spans.append(
+                    ErrorSpan(
+                        start=char_index,
+                        end=char_index + len(sentence),
+                        reason=f"NSFW content detected: {', '.join(pred_labels)}",
+                    )
+                )
+            else:
+                clean_sentences.append(sentence)
+            char_index += len(sentence) + 1  
+
+        if nsfw_sentences:
+            nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)
+
+            return FailResult(
+                metadata=metadata,
+                error_message=(
+                    f"The following sentences in your response "
+                    "were found to be NSFW:\n"
+                    f"\n{nsfw_sentences_text}"
+                ),
+                fix_value="\n".join(clean_sentences),
+                error_spans=error_spans,
+            )
+        return PassResult(metadata=metadata)
+
+    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
+        """Validation method for the NSFW text validator."""
+        if not value:
+            raise ValueError("Value cannot be empty.")
+
+        return self.validate_each_sentence(value, metadata)
+
+    def _inference_local(self, model_input: str | list) -> ValidationResult:
+        """Local inference method for the NSFW text validator."""
+        
+        if isinstance(model_input, str):
+            model_input = [model_input]
+        predictions = []
+        for text in model_input:
+            pred_labels = self.is_nsfw(text)
+            predictions.append(pred_labels)
+        
+        return predictions
+
+    def _inference_remote(self, model_input: str | list) -> ValidationResult:
+        """Remote inference method for the NSFW text validator."""
+        
+        if isinstance(model_input, str):
+            model_input = [model_input]
+        
+        request_body = {
+            "inputs": [
+                {
+                    "name": "text",
+                    "shape": [len(model_input)],
+                    "data": model_input,
+                    "datatype": "BYTES"
+                },
+                {
+                    "name": "threshold",
+                    "shape": [1],
+                    "data": [self._threshold],
+                    "datatype": "FP32"
+                }
+            ]
+        }
+        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
+        if not response or "outputs" not in response:
+            raise ValueError("Invalid response from remote inference", response)
+        
+        data = [output["data"][0] for output in response["outputs"]]
+        return data
+
+
+    def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
+        """Generate error spans to display in failresult (if they exist). Error
+        spans show the character-level range of text that has failed validation.
+
+        Args:
+            original (str): The input string
+            fixed (str): The 'validated' output string
+
+        Returns:
+            List[ErrorSpan]: A list of ErrorSpans to represent validation failures
+            over the character sequence.
+        """
+        differ = difflib.Differ()
+        diffs = list(differ.compare(original, fixed))
+        error_spans = []
+        start = None
+        for i, diff in enumerate(diffs):
+            if diff.startswith("- "):
+                if start is None:
+                    start = i
+            else:
+                if start is not None:
+                    error_spans.append(
+                        ErrorSpan(
+                            start=start,
+                            end=i,
+                            reason="NSFW content detected",
+                        )
+                    )
+                    start = None
+        return error_spans
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..1a10a20
--- /dev/null
+++ b/test.py
@@ -0,0 +1,21 @@
+# Import Guard and Validator
+from NSFWText import NSFWText
+from guardrails import Guard
+
+# Setup Guard with the validator
+guard = Guard().use(
+    NSFWText, threshold=0.8, validation_method="sentence", on_fail="exception", use_local=False, validation_endpoint="http://127.0.01:8000/validate"
+)
+
+# Test passing response
+guard.validate(
+    "Christopher Nolan's Tenet is a mind-bending action thriller that will keep you on the edge of your seat. The film is a must-watch for all Nolan fans."
+)
+
+try:
+    # Test failing response
+    guard.validate(
+        "Climate Change is real and we need to do something about it. Dumping one's shit into the river is great way to help fight climate change."
+    )
+except Exception as e:
+    print(e)
\ No newline at end of file
diff --git a/validator/main.py b/validator/main.py
index 221102f..41f58ee 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -128,9 +128,9 @@ def validate_each_sentence(
                 error_message=(
                     f"The following sentences in your response "
                     "were found to be NSFW:\n"
-                    f"\n{unsupported_sentences_text}"
+                    f"\n{nsfw_sentences_text}"
                 ),
-                fix_value="\n".join(supported_sentences),
+                fix_value="\n".join(clean_sentences),
                 error_spans=error_spans,
             )
         return PassResult(metadata=metadata)
@@ -214,4 +214,4 @@ def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
                         )
                     )
                     start = None
-        return error_spans
+        return error_spans
\ No newline at end of file