Skip to content

Commit

Permalink
gh workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
aaravnavani committed Sep 12, 2024
1 parent adbb456 commit d958683
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr_qc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ jobs:
python-version: 3.11
- name: Run qa
run: |
pip install ".[dev]"
make dev
make qa
12 changes: 9 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
.PHONY: test dev lint type qa

dev:
pip install -e ".[dev]"
python3 validator/post-install.py

lint:
ruff check .

tests:
pytest ./test
test:
pytest -v tests

type:
pyright validator

qa:
make lint
make type
make tests
make test
64 changes: 0 additions & 64 deletions test/test_validator.py

This file was deleted.

23 changes: 23 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from guardrails import Guard
from validator.main import NSFWText

# Setup Guard with NSFWText validator
guard = Guard().use(
NSFWText, threshold=0.8, validation_method="sentence", on_fail="exception"
)

# Test passing response (no NSFW content)
def test_nsfw_text_pass():
response = guard.validate(
"Christopher Nolan's Tenet is a mind-bending action thriller that will keep you on the edge of your seat. The film is a must-watch for all Nolan fans."
)
assert response.validation_passed is True

# Test failing response (contains NSFW content)
def test_nsfw_text_fail():
with pytest.raises(Exception) as e:
guard.validate(
"Climate Change is real and we need to do something about it. Dumping one's shit into the river is a great way to help fight climate change."
)
assert "Validation failed for field with errors:" in str(e.value)
6 changes: 3 additions & 3 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def is_nsfw(self, value: str) -> List[str]:
if results:
results = cast(List[List[Dict[str, Any]]], results)
for result in results:
label, score = result["label"], result["score"]
label, score = result["label"], result["score"] #type: ignore
if label == "NSFW" and score > self._threshold:
pred_labels.append(label)
return pred_labels
Expand Down Expand Up @@ -152,7 +152,7 @@ def _inference_local(self, model_input: str | list) -> ValidationResult:
pred_labels = self.is_nsfw(text)
predictions.append(pred_labels)

return predictions
return predictions #type: ignore

def _inference_remote(self, model_input: str | list) -> ValidationResult:
"""Remote inference method for the NSFW text validator."""
Expand Down Expand Up @@ -181,7 +181,7 @@ def _inference_remote(self, model_input: str | list) -> ValidationResult:
raise ValueError("Invalid response from remote inference", response)

data = [output["data"][0] for output in response["outputs"]]
return data
return data #type: ignore


def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
Expand Down

0 comments on commit d958683

Please sign in to comment.