From d958683d5c782feffb930db45a7427c0adac6b24 Mon Sep 17 00:00:00 2001 From: Aarav Navani Date: Thu, 12 Sep 2024 14:53:16 -0700 Subject: [PATCH] gh workflows --- .github/workflows/pr_qc.yml | 2 +- Makefile | 12 +++++-- test/test_validator.py | 64 ------------------------------------- tests/test_validator.py | 23 +++++++++++++ validator/main.py | 6 ++-- 5 files changed, 36 insertions(+), 71 deletions(-) delete mode 100644 test/test_validator.py create mode 100644 tests/test_validator.py diff --git a/.github/workflows/pr_qc.yml b/.github/workflows/pr_qc.yml index 7a46efe..5c35f9e 100644 --- a/.github/workflows/pr_qc.yml +++ b/.github/workflows/pr_qc.yml @@ -15,5 +15,5 @@ jobs: python-version: 3.11 - name: Run qa run: | - pip install ".[dev]" + make dev make qa diff --git a/Makefile b/Makefile index 5311f25..cb81129 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,14 @@ +.PHONY: test dev lint type qa + +dev: + pip install -e ".[dev]" + python3 validator/post-install.py + lint: ruff check . -tests: - pytest ./test +test: + pytest -v tests type: pyright validator @@ -10,4 +16,4 @@ type: qa: make lint make type - make tests \ No newline at end of file + make test \ No newline at end of file diff --git a/test/test_validator.py b/test/test_validator.py deleted file mode 100644 index 4fe9b43..0000000 --- a/test/test_validator.py +++ /dev/null @@ -1,64 +0,0 @@ -from guardrails import Guard -from pydantic import BaseModel, Field -from validator import NSFWText -import pytest - - -# Create a pydantic model with a field that uses the custom validator -class ValidatorTestObject(BaseModel): - text: str = Field(validators=[NSFWText(on_fail="exception")]) - - -# Test happy path -@pytest.mark.parametrize( - "value", - [ - """ - { - "text": "Meditation is peaceful and calming after taking a long uninterrupted walk. Helps one relax and unwind." - } - """, - """ - { - "text": "Telluride Film Festival is a great place to see new movies, and also enjoy a unique perspective into the director's mind. It's a great place to meet new people and make new friends." - } - """, - ], -) -def test_happy_path(value): - """Test the happy path for the validator.""" - # Create a guard from the pydantic model - guard = Guard.from_pydantic(output_class=ValidatorTestObject) - response = guard.parse(value) - print("Happy path response", response) - assert response.validation_passed is True - - -# Test fail path -@pytest.mark.parametrize( - "value", - [ - """ - { - "text": "Dumping one's shit into the river is great way to help fight climate change." - } - """, - """ - { - "text": "Pissing all over the floor is a great way to reduce stress. It's also a great way to improve mental health." - } - """, - """ - { - "text": "Smoking some weed and shit-talking your neighbors helps one enjoy life even more." - } - """, - ], -) -def test_fail_path(value): - # Create a guard from the pydantic model - guard = Guard.from_pydantic(output_class=ValidatorTestObject) - - with pytest.raises(Exception): - response = guard.parse(value) - print("Fail path response", response) diff --git a/tests/test_validator.py b/tests/test_validator.py new file mode 100644 index 0000000..1881a4e --- /dev/null +++ b/tests/test_validator.py @@ -0,0 +1,23 @@ +import pytest +from guardrails import Guard +from validator.main import NSFWText + +# Setup Guard with NSFWText validator +guard = Guard().use( + NSFWText, threshold=0.8, validation_method="sentence", on_fail="exception" +) + +# Test passing response (no NSFW content) +def test_nsfw_text_pass(): + response = guard.validate( + "Christopher Nolan's Tenet is a mind-bending action thriller that will keep you on the edge of your seat. The film is a must-watch for all Nolan fans." + ) + assert response.validation_passed is True + +# Test failing response (contains NSFW content) +def test_nsfw_text_fail(): + with pytest.raises(Exception) as e: + guard.validate( + "Climate Change is real and we need to do something about it. Dumping one's shit into the river is a great way to help fight climate change." + ) + assert "Validation failed for field with errors:" in str(e.value) diff --git a/validator/main.py b/validator/main.py index 41f58ee..92c3cf1 100644 --- a/validator/main.py +++ b/validator/main.py @@ -85,7 +85,7 @@ def is_nsfw(self, value: str) -> List[str]: if results: results = cast(List[List[Dict[str, Any]]], results) for result in results: - label, score = result["label"], result["score"] + label, score = result["label"], result["score"] #type: ignore if label == "NSFW" and score > self._threshold: pred_labels.append(label) return pred_labels @@ -152,7 +152,7 @@ def _inference_local(self, model_input: str | list) -> ValidationResult: pred_labels = self.is_nsfw(text) predictions.append(pred_labels) - return predictions + return predictions #type: ignore def _inference_remote(self, model_input: str | list) -> ValidationResult: """Remote inference method for the NSFW text validator.""" @@ -181,7 +181,7 @@ def _inference_remote(self, model_input: str | list) -> ValidationResult: raise ValueError("Invalid response from remote inference", response) data = [output["data"][0] for output in response["outputs"]] - return data + return data #type: ignore def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]: