diff --git a/README.md b/README.md
index 5da02bf..dc5a0ba 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
| Developed by | Guardrails AI |
| --- | --- |
-| Date of development | Feb 15, 2024 |
-| Validator type | Format |
+| Date of development | Aug 15, 2024 |
+| Validator type | Moderation |
| Blog | |
| License | Apache 2 |
| Input/Output | Output |
@@ -11,20 +11,40 @@
## Description
### Intended Use
-This validator is a template for creating other validators, but for demonstrative purposes it ensures that a generated output is the literal `pass`.
+
+> ⚠️ This validator requires remote inference, so remote inferencing must be enabled during `guardrails configure`.
+
+This validator moderates both user prompts and LLM output responses to prevent harmful topics from surfacing in either case. It is based on [LlamaGuard 7B](https://huggingface.co/meta-llama/LlamaGuard-7b), which is in turn based on [Llama 2](https://arxiv.org/abs/2307.09288).
+
+The following policies are available and can be accessed directly as attributes of the `LlamaGuard7B` validator class:
+
+- `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`
+- `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`
+- `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`
+- `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`
+- `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`
+- `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`
+
+If no policies are supplied, all of them are enforced; otherwise, only the supplied policies are checked.
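+
+For example, a guard that enforces only a subset of policies can be configured as follows (a minimal sketch mirroring the usage example below):
+
+```python
+from guardrails import Guard, OnFailAction
+from guardrails.hub import LlamaGuard7B
+
+guard = Guard().use(
+    LlamaGuard7B,
+    policies=[
+        LlamaGuard7B.POLICY__NO_VIOLENCE_HATE,
+        LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS,
+    ],
+    on_fail=OnFailAction.EXCEPTION,
+)
+```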
### Requirements
* Dependencies:
- guardrails-ai>=0.4.0
-* Foundation model access keys:
- - OPENAI_API_KEY
## Installation
```bash
-$ guardrails hub install hub://guardrails/validator_template
+$ guardrails hub install hub://guardrails/llamaguard_7b
+```
+
+or
+
+```python
+from guardrails import install
+install("hub://guardrails/llamaguard_7b")
```
## Usage Examples
@@ -34,58 +54,43 @@ $ guardrails hub install hub://guardrails/validator_template
In this example, we apply the validator to a string output generated by an LLM.
```python
-# Import Guard and Validator
-from guardrails.hub import ValidatorTemplate
-from guardrails import Guard
+from guardrails import Guard, OnFailAction
+from guardrails.hub import LlamaGuard7B
-# Setup Guard
guard = Guard().use(
- ValidatorTemplate
+ LlamaGuard7B,
+    policies=[LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS], # omitting the `policies` kwarg applies all policies
+ on_fail=OnFailAction.EXCEPTION
)
-guard.validate("pass") # Validator passes
-guard.validate("fail") # Validator fails
+try:
+ guard.validate("How can I get weeds out of my garbage bag after cutting my lawn") # Guardrail passes
+ print("Expected: Validation passes")
+except Exception as e:
+ print("Unexpected: ",e)
+
+try:
+    guard.validate("How can I get weed for when cutting my lawn") # Guardrail fails
+ print("Unexpected: Validation passes")
+except Exception as e:
+ print("Expected: ",e)
```
-### Validating JSON output via Python
-
-In this example, we apply the validator to a string field of a JSON output generated by an LLM.
+Output:
-```python
-# Import Guard and Validator
-from pydantic import BaseModel, Field
-from guardrails.hub import ValidatorTemplate
-from guardrails import Guard
-
-# Initialize Validator
-val = ValidatorTemplate()
-
-# Create Pydantic BaseModel
-class Process(BaseModel):
- process_name: str
- status: str = Field(validators=[val])
-
-# Create a Guard to check for valid Pydantic output
-guard = Guard.from_pydantic(output_class=Process)
-
-# Run LLM output generating JSON through guard
-guard.parse("""
-{
- "process_name": "templating",
- "status": "pass"
-}
-""")
+```
+Expected: Validation passes
+Expected: Validation failed for field with errors: Prompt contains unsafe content. Classification: unsafe, Violated Policy: POLICY__NO_ILLEGAL_DRUGS
```
# API Reference
-**`__init__(self, on_fail="noop")`**
+**`__init__(self, policies=None, on_fail="noop")`**
-Initializes a new instance of the ValidatorTemplate class.
+Initializes a new instance of the `LlamaGuard7B` class.
**Parameters**
-- **`arg_1`** *(str)*: A placeholder argument to demonstrate how to use init arguments.
-- **`arg_2`** *(str)*: Another placeholder argument to demonstrate how to use init arguments.
+- **`policies`** *(List[str])*: A list of policies to enforce, each one of `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`, `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`, `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`, `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`, `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`, or `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`. If omitted, all policies are enforced.
- **`on_fail`** *(str, Callable)*: The policy to enact when a validator fails. If `str`, must be one of `reask`, `fix`, `filter`, `refrain`, `noop`, `exception` or `fix_reask`. Otherwise, must be a function that is called when the validator fails.
@@ -101,10 +106,4 @@ Note:
**Parameters**
- **`value`** *(Any)*: The input value to validate.
-- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. Keys and values must match the expectations of this validator.
-
-
- | Key | Type | Description | Default |
- | --- | --- | --- | --- |
- | `key1` | String | Description of key1's role. | N/A |
-
+- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.
\ No newline at end of file
diff --git a/inference/serving-non-optimized-fastapi.py b/inference/serving-non-optimized-fastapi.py
index f764ea5..0c55876 100644
--- a/inference/serving-non-optimized-fastapi.py
+++ b/inference/serving-non-optimized-fastapi.py
@@ -1,5 +1,4 @@
-import os
-import time
+from typing import Optional
import modal
@@ -101,11 +100,27 @@ def tgi_app():
from typing import List
from pydantic import BaseModel
+ import logging
TOKEN = os.getenv("TOKEN")
if TOKEN is None:
raise ValueError("Please set the TOKEN environment variable")
+ # Create a logger
+ logger = logging.getLogger(MODEL_ALIAS)
+ logger.setLevel(logging.DEBUG)
+
+ # Create a handler for logging to stdout
+ stdout_handler = logging.StreamHandler()
+ stdout_handler.setLevel(logging.DEBUG)
+
+ # Create a formatter for the log messages
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ stdout_handler.setFormatter(formatter)
+
+ # Add the handler to the logger
+ logger.addHandler(stdout_handler)
+
volume.reload() # ensure we have the latest version of the weights
app = fastapi.FastAPI()
@@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}
+
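+    # Normalize all errors into the {"status": ..., "response": {"detail": ...}} envelope that the validator client expects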
+ @app.exception_handler(Exception)
+ def error_handler(request, exc):
+ status_code = 500
+ detail = "Internal Server Error"
+ logger.exception(exc)
+ if isinstance(exc, fastapi.HTTPException):
+ status_code = exc.status_code
+ detail = exc.detail
+ return fastapi.responses.JSONResponse(
+ status_code=status_code,
+ content={
+ "status": status_code,
+ "response": {
+ "detail": detail,
+ }
+ },
+ )
router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
@@ -137,11 +170,13 @@ class ChatMessages(BaseModel):
content: str
class ChatClassificationRequestBody(BaseModel):
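+        # Optional list of LlamaGuard policy codes (e.g. "O5"); when omitted, any "unsafe" classification is reported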
+ policies: Optional[List[str]] = None
chat: List[ChatMessages]
@router.post("/v1/chat/classification")
async def chat_classification_response(body: ChatClassificationRequestBody):
+ policies = body.policies
chat = body.model_dump().get("chat",[])
print("Serving request for chat classification...")
@@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody):
cleaned_response = response.lower().strip()
if "unsafe" in cleaned_response:
- is_unsafe = True
split_cln_response = response.strip().split(os.linesep)
subclass = split_cln_response[1] if len(split_cln_response) > 1 else None
+
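+            # Only report "unsafe" when the violated policy code is among the requested policies,
+            # or when no policy filter was supplied at all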
+ if policies and subclass in policies:
+ is_unsafe = True
+ elif policies and subclass not in policies:
+ is_unsafe = False
+ else:
+ is_unsafe = True
else:
is_unsafe = False
return {
- "class": "unsafe" if is_unsafe else "safe",
- "subclass": subclass,
- "response": response
+ "status": 200,
+ "response": {
+ "class": "unsafe" if is_unsafe else "safe",
+ "subclass": subclass,
+ "applied_policies": policies,
+ "raw_output": response
+ }
}
app.include_router(router)
return app
-
-
-# @app.local_entrypoint()
-# def main():
-# model = Model()
-# model.generate.remote()
\ No newline at end of file
diff --git a/validator/main.py b/validator/main.py
index 759d0b4..4b57e33 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,4 +1,6 @@
-from typing import Any, Callable, Dict, Optional
+import json
+from typing import Any, Callable, Dict, List, Optional
+from guardrails.validator_base import ErrorSpan
from guardrails.validator_base import (
FailResult,
@@ -8,41 +10,113 @@
register_validator,
)
-
-@register_validator(name="guardrails/validator_template", data_type="string")
-class ValidatorTemplate(Validator):
- """Validates that {fill in how you validator interacts with the passed value}.
-
+@register_validator(name="guardrails/llamaguard_7b", data_type="string")
+class LlamaGuard7B(Validator):
+ """
+ Classifies model inputs or outputs as "safe" or "unsafe" based on certain policies defined by the LlamaGuard-7B model.
+
**Key Properties**
| Property | Description |
| ----------------------------- | --------------------------------- |
- | Name for `format` attribute | `guardrails/validator_template` |
+ | Name for `format` attribute | `guardrails/llamaguard_7b` |
| Supported data types | `string` |
- | Programmatic fix | {If you support programmatic fixes, explain it here. Otherwise `None`} |
+ | Programmatic fix | None |
Args:
- arg_1 (string): {Description of the argument here}
- arg_2 (string): {Description of the argument here}
+        policies (List[str]): A list of policies, each one of the `LlamaGuard7B.POLICY__*` constants. If not provided, all policies are applied.
""" # noqa
- # If you don't have any init args, you can omit the __init__ method.
+
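+    # LlamaGuard taxonomy category codes; each constant maps to one of the model's O1-O6 harm categories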
+ POLICY__NO_VIOLENCE_HATE = "O1"
+ POLICY__NO_SEXUAL_CONTENT = "O2"
+ POLICY__NO_CRIMINAL_PLANNING = "O3"
+ POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS = "O4"
+ POLICY__NO_ILLEGAL_DRUGS = "O5"
+ POLICY__NO_ENOURAGE_SELF_HARM = "O6"
+
def __init__(
self,
- arg_1: str,
- arg_2: str,
+ policies: Optional[List[str]] = None,
+ validation_method: Optional[str] = "full",
on_fail: Optional[Callable] = None,
+ **kwargs,
):
- super().__init__(on_fail=on_fail, arg_1=arg_1, arg_2=arg_2)
- self._arg_1 = arg_1
- self._arg_2 = arg_2
+
+ super().__init__(
+ on_fail=on_fail,
+ validation_method=validation_method,
+ **kwargs,
+ )
+
+ self._policies = policies
+
def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
- """Validates that {fill in how you validator interacts with the passed value}."""
- # Add your custom validator logic here and return a PassResult or FailResult accordingly.
- if value != "pass": # FIXME
+
+ if not value:
+ raise ValueError("Value cannot be empty.")
+
+ (classification, subclass) = self._inference(value)
+
+ is_unsafe = classification == "unsafe"
+
+ if is_unsafe:
+ error_span = ErrorSpan(
+ start=0,
+ end=len(value),
+ reason=f"Unsafe content: {value}",
+ )
+
+            # Map the violated subclass code (e.g. "O5") back to its POLICY__* constant name for the error message
+
+ find_policy_violated = next(
+ (policy_key for policy_key in [
+ "POLICY__NO_VIOLENCE_HATE",
+ "POLICY__NO_CRIMINAL_PLANNING",
+ "POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS",
+ "POLICY__NO_ILLEGAL_DRUGS",
+ "POLICY__NO_ENOURAGE_SELF_HARM",
+ "POLICY__NO_SEXUAL_CONTENT"
+ ] if getattr(self,policy_key) == subclass),
+ None
+ )
return FailResult(
- error_message="{A descriptive but concise error message about why validation failed}",
- fix_value="{The programmtic fix if applicable, otherwise remove this kwarg.}",
+ error_message=(
+ f"Prompt contains unsafe content. Classification: {classification}, Violated Policy: {find_policy_violated}"
+ ),
+ error_spans=[error_span],
)
- return PassResult()
+        else:
+            return PassResult()
+
+
+ def _inference_local(self, value: str):
+ raise NotImplementedError("Local inference is not supported for LlamaGuard7B validator.")
+
+    def _inference_remote(self, value: str) -> tuple:
+ """Remote inference method for this validator."""
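+        # Request body mirrors the hosted endpoint's ChatClassificationRequestBody: optional policy codes plus the chat to classify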
+ request_body = {
+ "policies": self._policies,
+ "chat": [
+ {
+ "role": "user",
+ "content": value
+ }
+ ]
+ }
+
+ response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
+
+ status = response.get("status")
+ if status != 200:
+ detail = response.get("response",{}).get("detail", "Unknown error")
+            raise ValueError(f"Failed to get a valid response from the LlamaGuard-7B model. Status: {status}. Detail: {detail}")
+
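+        # Expected success payload: {"status": 200, "response": {"class": "safe"|"unsafe", "subclass": "O1".."O6" or null, ...}}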
+ response_data = response.get("response")
+
+ classification = response_data.get("class")
+ subclass = response_data.get("subclass")
+
+ return (classification, subclass)