From 06f5233969b24f6f73b0333c66c9eb7d58d7bf6e Mon Sep 17 00:00:00 2001
From: Aarav Navani
Date: Thu, 5 Sep 2024 13:13:25 -0700
Subject: [PATCH] delete

---
 NSFWText.py | 217 ----------------------------------------------------
 test.py     |  21 -----
 2 files changed, 238 deletions(-)
 delete mode 100644 NSFWText.py
 delete mode 100644 test.py

diff --git a/NSFWText.py b/NSFWText.py
deleted file mode 100644
index 41f58ee..0000000
--- a/NSFWText.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import difflib
-import json
-from typing import Any, Callable, Dict, List, Optional, Union, cast
-
-import nltk
-from guardrails.validator_base import (
-    ErrorSpan,
-    FailResult,
-    PassResult,
-    ValidationResult,
-    Validator,
-    register_validator,
-)
-from transformers import pipeline
-
-
-@register_validator(
-    name="guardrails/nsfw_text", data_type="string", has_guardrails_endpoint=True
-)
-class NSFWText(Validator):
-    """Validates that the generated text is safe for work (SFW).
-
-    **Key Properties**
-    | Property                      | Description            |
-    | ----------------------------- | ---------------------- |
-    | Name for `format` attribute   | `guardrails/nsfw_text` |
-    | Supported data types          | `string`               |
-    | Programmatic fix              | N/A                    |
-
-    Args:
-        threshold: The confidence threshold over which model inferences are considered.
-            Must be a float between 0 and 1. Defaults to 0.8.
-        validation_method: Whether to validate at the sentence level or
-            over the full text. Must be one of `sentence` or `full`.
-            Defaults to `sentence`.
-
-    This validator uses the pre-trained multi-class model from HuggingFace -
-    `michellejieli/NSFW_text_classifier` to check whether the generated text is
-    safe for work. If the model predicts the text to be "NSFW" with a confidence
-    higher than the threshold, the validator fails. Otherwise, it passes.
-
-    If validation_method is `sentence`, the validator will remove the sentences
-    that are predicted to be NSFW and return the remaining sentences. If
-    validation_method is `full`, the validator will remove the entire text if
-    the prediction is deemed NSFW and return an empty string.
-    """
-
-    def __init__(
-        self,
-        threshold: float = 0.8,
-        validation_method: str = "sentence",
-        device: Optional[Union[str, int]] = "cpu",
-        model_name: Optional[str] = "michellejieli/NSFW_text_classifier",
-        on_fail: Optional[Callable[..., Any]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            on_fail=on_fail,
-            threshold=threshold,
-            validation_method=validation_method,
-            **kwargs,
-        )
-        self._threshold = float(threshold)
-        if validation_method not in ["sentence", "full"]:
-            raise ValueError("validation_method must be 'sentence' or 'full'.")
-        self._validation_method = validation_method
-        if self.use_local:
-            self._model = pipeline(
-                "text-classification",
-                model=model_name,
-            )
-
-    def is_nsfw(self, value: str) -> List[str]:
-        """Determines if the generated text is NSFW.
-
-        Args:
-            value (str): The generated text.
-
-        Returns:
-            list[str]: Labels predicted by the model
-        """
-        pred_labels = []
-        if value:
-            results = self._model(value)
-            if results:
-                results = cast(List[List[Dict[str, Any]]], results)
-                for result in results:
-                    label, score = result["label"], result["score"]
-                    if label == "NSFW" and score > self._threshold:
-                        pred_labels.append(label)
-        return pred_labels
-
-    def validate_each_sentence(
-        self, value: str, metadata: Dict[str, Any]
-    ) -> ValidationResult:
-        """Validate that each sentence in the generated text is SFW."""
-
-        # Split the value into sentences using nltk sentence tokenizer.
-        sentences = nltk.sent_tokenize(value)
-
-        clean_sentences, nsfw_sentences = [], []
-        error_spans: List[ErrorSpan] = []
-        char_index = 0
-
-        sentence_predictions = self._inference(sentences)
-
-        for idx, sentence in enumerate(sentences):
-            pred_labels = sentence_predictions[idx]
-
-            if pred_labels:
-                nsfw_sentences.append(sentence)
-                error_spans.append(
-                    ErrorSpan(
-                        start=char_index,
-                        end=char_index + len(sentence),
-                        reason=f"NSFW content detected: {', '.join(pred_labels)}",
-                    )
-                )
-            else:
-                clean_sentences.append(sentence)
-            char_index += len(sentence) + 1
-
-        if nsfw_sentences:
-            nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)
-
-            return FailResult(
-                metadata=metadata,
-                error_message=(
-                    f"The following sentences in your response "
-                    "were found to be NSFW:\n"
-                    f"\n{nsfw_sentences_text}"
-                ),
-                fix_value="\n".join(clean_sentences),
-                error_spans=error_spans,
-            )
-        return PassResult(metadata=metadata)
-
-    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
-        """Validation method for the NSFW text validator."""
-        if not value:
-            raise ValueError("Value cannot be empty.")
-
-        return self.validate_each_sentence(value, metadata)
-
-    def _inference_local(self, model_input: str | list) -> ValidationResult:
-        """Local inference method for the NSFW text validator."""
-
-        if isinstance(model_input, str):
-            model_input = [model_input]
-        predictions = []
-        for text in model_input:
-            pred_labels = self.is_nsfw(text)
-            predictions.append(pred_labels)
-
-        return predictions
-
-    def _inference_remote(self, model_input: str | list) -> ValidationResult:
-        """Remote inference method for the NSFW text validator."""
-
-        if isinstance(model_input, str):
-            model_input = [model_input]
-
-        request_body = {
-            "inputs": [
-                {
-                    "name": "text",
-                    "shape": [len(model_input)],
-                    "data": model_input,
-                    "datatype": "BYTES"
-                },
-                {
-                    "name": "threshold",
-                    "shape": [1],
-                    "data": [self._threshold],
-                    "datatype": "FP32"
-                }
-            ]
-        }
-        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
-        if not response or "outputs" not in response:
-            raise ValueError("Invalid response from remote inference", response)
-
-        data = [output["data"][0] for output in response["outputs"]]
-        return data
-
-
-    def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
-        """Generate error spans to display in failresult (if they exist). Error
-        spans show the character-level range of text that has failed validation.
-
-        Args:
-            original (str): The input string
-            fixed (str): The 'validated' output string
-
-        Returns:
-            List[ErrorSpan]: A list of ErrorSpans to represent validation failures
-            over the character sequence.
-        """
-        differ = difflib.Differ()
-        diffs = list(differ.compare(original, fixed))
-        error_spans = []
-        start = None
-        for i, diff in enumerate(diffs):
-            if diff.startswith("- "):
-                if start is None:
-                    start = i
-            else:
-                if start is not None:
-                    error_spans.append(
-                        ErrorSpan(
-                            start=start,
-                            end=i,
-                            reason="NSFW content detected",
-                        )
-                    )
-                    start = None
-        return error_spans
\ No newline at end of file
diff --git a/test.py b/test.py
deleted file mode 100644
index 1a10a20..0000000
--- a/test.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Import Guard and Validator
-from NSFWText import NSFWText
-from guardrails import Guard
-
-# Setup Guard with the validator
-guard = Guard().use(
-    NSFWText, threshold=0.8, validation_method="sentence", on_fail="exception", use_local=False, validation_endpoint="http://127.0.0.1:8000/validate"
-)
-
-# Test passing response
-guard.validate(
-    "Christopher Nolan's Tenet is a mind-bending action thriller that will keep you on the edge of your seat. The film is a must-watch for all Nolan fans."
-)
-
-try:
-    # Test failing response
-    guard.validate(
-        "Climate Change is real and we need to do something about it. Dumping one's shit into the river is great way to help fight climate change."
-    )
-except Exception as e:
-    print(e)
\ No newline at end of file
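
Note: the deleted test.py exercised only the remote-inference path
(use_local=False with a validation_endpoint). For reference, a minimal sketch
of the local path follows; it assumes NSFWText is still importable, that the
installed guardrails version accepts these keyword arguments, and that
on_fail="fix" is the desired failure behavior. It is illustrative only and not
part of this patch.

    # Hypothetical local-inference usage: runs michellejieli/NSFW_text_classifier
    # in-process via transformers.pipeline instead of calling a remote endpoint.
    # Requires the nltk "punkt" data: python -m nltk.downloader punkt
    from guardrails import Guard

    from NSFWText import NSFWText

    guard = Guard().use(
        NSFWText,
        threshold=0.8,
        validation_method="sentence",
        use_local=True,   # load the HuggingFace classifier locally
        on_fail="fix",    # drop NSFW sentences using the validator's fix_value
    )

    outcome = guard.validate("Tenet is a mind-bending thriller worth a rewatch.")
    print(outcome.validation_passed)  # True when no sentence exceeds the threshold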