Skip to content

Commit

Permalink
Merge pull request #2 from guardrails-ai/main
Browse files Browse the repository at this point in the history
Clean-up + Polish
  • Loading branch information
pazcuturi authored Mar 5, 2024
2 parents a1c09f7 + 0e0332d commit c4a7f82
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 44 deletions.
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
dev:
pip install -e ".[dev]"

lint:
ruff check .

type:
pyright validator

qa:
make lint
make type
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Overview
## Overview

| Developed by | Tryolabs |
| --- | --- |
| Date of development | Feb 15, 2024 |
| Validator type | Format |
| Blog | |
| Blog | - |
| License | Apache 2 |
| Input/Output | Output |

# Description

## Intended Use
## Description
This validator checks if a text is related with a topic.

## Requirements
Expand All @@ -26,7 +25,7 @@ This validator checks if a text is related with a topic.
# Installation

```bash
$ guardrails hub install hub://tryolabs/restricttotopic
guardrails hub install hub://tryolabs/restricttotopic
```

# Usage Examples
Expand Down
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ authors = [
]
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">= 3.8"
requires-python = ">= 3.8.1"
dependencies = [
"guardrails-ai>=0.4.0",
"pydantic>=2.4.2",
"tenacity>=8.1.0",
"transformers>=4.11.3",
"torch>=2.1.1",
"python-dotenv"
]
]

[project.optional-dependencies]
dev = [
"pyright",
"ruff"
]

[tool.pyright]
include = ["validator"]
42 changes: 21 additions & 21 deletions test/test-validator.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import os
from dotenv import load_dotenv

load_dotenv()

from dotenv import load_dotenv
from guardrails import Guard
from pydantic import BaseModel, Field
from validator import RestrictToTopic

load_dotenv()


class ValidatorTestObject(BaseModel):
test_val: str = Field(
validators=[
RestrictToTopic(
valid_topics=["sports"],
invalid_topics=["music"],
disable_classifier=True,
disable_llm=False,
on_fail="exception"
)
],
api_key=os.getenv("OPENAI_API_KEY")
validators=[
RestrictToTopic(
valid_topics=["sports"],
invalid_topics=["music"],
disable_classifier=True,
disable_llm=False,
on_fail="exception",
)
],
api_key=os.getenv("OPENAI_API_KEY"),
)


Expand All @@ -33,10 +33,10 @@ class ValidatorTestObject(BaseModel):
guard = Guard.from_pydantic(output_class=ValidatorTestObject)

try:
guard.parse(TEST_OUTPUT)
print ("Successfully passed validation when it was supposed to.")
except (Exception):
print ("Failed to pass validation when it was supposed to.")
guard.parse(TEST_OUTPUT)
print("Successfully passed validation when it was supposed to.")
except Exception:
print("Failed to pass validation when it was supposed to.")


TEST_FAIL_OUTPUT = """
Expand All @@ -46,7 +46,7 @@ class ValidatorTestObject(BaseModel):
"""

try:
guard.parse(TEST_FAIL_OUTPUT)
print ("Failed to fail validation when it was supposed to")
except (Exception):
print ("Successfully failed validation when it was supposed to.")
guard.parse(TEST_FAIL_OUTPUT)
print("Failed to fail validation when it was supposed to")
except Exception:
print("Successfully failed validation when it was supposed to.")
19 changes: 4 additions & 15 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tenacity import retry, stop_after_attempt, wait_random_exponential

from guardrails.utils.casting_utils import to_int
from guardrails.utils.openai_utils import OpenAIClient
from guardrails.validator_base import (
Expand All @@ -13,11 +11,8 @@
Validator,
register_validator,
)

try:
from transformers import pipeline
except ImportError:
pipeline = None
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import pipeline


@register_validator(name="tryolabs/restricttotopic", data_type="string")
Expand All @@ -41,7 +36,7 @@ class RestrictToTopic(Validator):
| Property | Description |
| ----------------------------- | ---------------------------------------- |
| Name for `format` attribute | `restrict_to_topic` |
| Name for `format` attribute | `tryolabs/restricttotopic` |
| Supported data types | `string` |
| Programmatic fix | Removes lines with off-topic information |
Expand Down Expand Up @@ -97,13 +92,6 @@ def __init__(
)
self._valid_topics = valid_topics

if pipeline is None:
raise ValueError(
"You must install transformers in order to "
"use the RestrictToTopic validator."
"Install it using `pip install transformers`."
)

if invalid_topics is None:
self._invalid_topics = []
else:
Expand Down Expand Up @@ -161,6 +149,7 @@ def call_llm(self, text: str, topics: List[str]) -> str:
response (str): String representing the LLM response.
"""
from dotenv import load_dotenv

load_dotenv()
return self._llm_callable(text, topics)

Expand Down

0 comments on commit c4a7f82

Please sign in to comment.