diff --git a/README.md b/README.md
index 5da02bf..dc5a0ba 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
 | Developed by | Guardrails AI |
 | --- | --- |
-| Date of development | Feb 15, 2024 |
-| Validator type | Format |
+| Date of development | Aug 15, 2024 |
+| Validator type | Moderation |
 | Blog | |
 | License | Apache 2 |
 | Input/Output | Output |
 
@@ -11,20 +11,40 @@
 ## Description
 
 ### Intended Use
-This validator is a template for creating other validators, but for demonstrative purposes it ensures that a generated output is the literal `pass`.
+
+> ⚠️ This is a remote-inference-only validator, so remote inferencing must be enabled during `guardrails configure`.
+
+This validator moderates both user prompts and LLM responses to prevent harmful content from surfacing in either scenario. It is based on [LlamaGuard 7B](https://huggingface.co/meta-llama/LlamaGuard-7b), which is in turn based on [Llama 2](https://arxiv.org/abs/2307.09288).
+
+
+The following policies are available; they are accessed directly as constants on the `LlamaGuard7B` validator class:
+
+- `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`
+- `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`
+- `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`
+- `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`
+- `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`
+- `LlamaGuard7B.POLICY__NO_ENCOURAGE_SELF_HARM`
+
+If no policies are supplied, all policies are enforced. Otherwise, only the supplied policies are enforced.
 
 ### Requirements
 
 * Dependencies:
     - guardrails-ai>=0.4.0
 
-* Foundation model access keys:
-    - OPENAI_API_KEY
 
 ## Installation
 
 ```bash
-$ guardrails hub install hub://guardrails/validator_template
+$ guardrails hub install hub://guardrails/llamaguard_7b
+```
+
+or
+
+```python
+from guardrails import install
+install("hub://guardrails/llamaguard_7b")
 ```
 
 ## Usage Examples
@@ -34,58 +54,43 @@ $ guardrails hub install hub://guardrails/validator_template
 In this example, we apply the validator to a string output generated by an LLM.
 
 ```python
-# Import Guard and Validator
-from guardrails.hub import ValidatorTemplate
-from guardrails import Guard
+from guardrails import Guard, OnFailAction
+from guardrails.hub import LlamaGuard7B
 
-# Setup Guard
 guard = Guard().use(
-    ValidatorTemplate
+    LlamaGuard7B,
+    policies=[LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS], # omitting the `policies` kwarg applies all policies
+    on_fail=OnFailAction.EXCEPTION
 )
 
-guard.validate("pass") # Validator passes
-guard.validate("fail") # Validator fails
+try:
+    guard.validate("How can I get weeds out of my garbage bag after cutting my lawn") # Guardrail passes
+    print("Expected: Validation passes")
+except Exception as e:
+    print("Unexpected: ",e)
+
+try:
+    guard.validate("How can I get weed for when cutting my lawn") # Guardrail fails
+    print("Unexpected: Validation passes")
+except Exception as e:
+    print("Expected: ",e)
 ```
 
-### Validating JSON output via Python
-
-In this example, we apply the validator to a string field of a JSON output generated by an LLM.
+Output: -```python -# Import Guard and Validator -from pydantic import BaseModel, Field -from guardrails.hub import ValidatorTemplate -from guardrails import Guard - -# Initialize Validator -val = ValidatorTemplate() - -# Create Pydantic BaseModel -class Process(BaseModel): - process_name: str - status: str = Field(validators=[val]) - -# Create a Guard to check for valid Pydantic output -guard = Guard.from_pydantic(output_class=Process) - -# Run LLM output generating JSON through guard -guard.parse(""" -{ - "process_name": "templating", - "status": "pass" -} -""") +``` +Expected: Validation passes +Expected: Validation failed for field with errors: Prompt contains unsafe content. Classification: unsafe, Violated Policy: POLICY__NO_ILLEGAL_DRUGS ``` # API Reference **`__init__(self, on_fail="noop")`**
@@ -101,10 +106,4 @@ Note: **Parameters** - **`value`** *(Any)*: The input value to validate. -- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. Keys and values must match the expectations of this validator. - - - | Key | Type | Description | Default | - | --- | --- | --- | --- | - | `key1` | String | Description of key1's role. | N/A | - +- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator. \ No newline at end of file diff --git a/inference/serving-non-optimized-fastapi.py b/inference/serving-non-optimized-fastapi.py index f764ea5..0c55876 100644 --- a/inference/serving-non-optimized-fastapi.py +++ b/inference/serving-non-optimized-fastapi.py @@ -1,5 +1,4 @@ -import os -import time +from typing import Optional import modal @@ -101,11 +100,27 @@ def tgi_app(): from typing import List from pydantic import BaseModel + import logging TOKEN = os.getenv("TOKEN") if TOKEN is None: raise ValueError("Please set the TOKEN environment variable") + # Create a logger + logger = logging.getLogger(MODEL_ALIAS) + logger.setLevel(logging.DEBUG) + + # Create a handler for logging to stdout + stdout_handler = logging.StreamHandler() + stdout_handler.setLevel(logging.DEBUG) + + # Create a formatter for the log messages + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + stdout_handler.setFormatter(formatter) + + # Add the handler to the logger + logger.addHandler(stdout_handler) + volume.reload() # ensure we have the latest version of the weights app = fastapi.FastAPI() @@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): detail="Invalid authentication credentials", ) return {"username": "authenticated_user"} + + @app.exception_handler(Exception) + def error_handler(request, exc): + status_code = 500 + detail = "Internal Server Error" + logger.exception(exc) + if isinstance(exc, fastapi.HTTPException): + status_code = exc.status_code + detail = exc.detail + return fastapi.responses.JSONResponse( + status_code=status_code, + content={ + "status": status_code, + "response": { + "detail": detail, + } + }, + ) router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) @@ -137,11 +170,13 @@ class ChatMessages(BaseModel): content: str class ChatClassificationRequestBody(BaseModel): + policies: Optional[List[str]] = None chat: List[ChatMessages] @router.post("/v1/chat/classification") async def chat_classification_response(body: ChatClassificationRequestBody): + policies = body.policies chat = body.model_dump().get("chat",[]) print("Serving request for chat classification...") @@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody): cleaned_response = response.lower().strip() if "unsafe" in cleaned_response: - is_unsafe = True split_cln_response = response.strip().split(os.linesep) subclass = split_cln_response[1] if len(split_cln_response) > 1 else None + + if policies and subclass in policies: + is_unsafe = True + elif policies and subclass not in policies: + is_unsafe = False + else: + is_unsafe = True else: is_unsafe = False return { - "class": "unsafe" if is_unsafe else "safe", - "subclass": subclass, - "response": response + "status": 200, + "response": { + "class": "unsafe" if is_unsafe else "safe", + "subclass": subclass, + "applied_policies": policies, + "raw_output": response + } } app.include_router(router) return app - - -# 
-# @app.local_entrypoint()
-# def main():
-#     model = Model()
-#     model.generate.remote()
\ No newline at end of file
diff --git a/validator/main.py b/validator/main.py
index 759d0b4..4b57e33 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,4 +1,6 @@
-from typing import Any, Callable, Dict, Optional
+import json
+from typing import Any, Callable, Dict, List, Optional
+from guardrails.validator_base import ErrorSpan
 
 from guardrails.validator_base import (
     FailResult,
@@ -8,41 +10,113 @@
     register_validator,
 )
 
-
-@register_validator(name="guardrails/validator_template", data_type="string")
-class ValidatorTemplate(Validator):
-    """Validates that {fill in how you validator interacts with the passed value}.
-
+@register_validator(name="guardrails/llamaguard_7b", data_type="string")
+class LlamaGuard7B(Validator):
+    """
+    Classifies model inputs or outputs as "safe" or "unsafe" based on the policies defined by the LlamaGuard-7B model.
+
     **Key Properties**
 
     | Property                      | Description                       |
     | ----------------------------- | --------------------------------- |
-    | Name for `format` attribute   | `guardrails/validator_template`   |
+    | Name for `format` attribute   | `guardrails/llamaguard_7b`        |
     | Supported data types          | `string`                          |
-    | Programmatic fix              | {If you support programmatic fixes, explain it here. Otherwise `None`} |
+    | Programmatic fix              | None |
 
     Args:
-        arg_1 (string): {Description of the argument here}
-        arg_2 (string): {Description of the argument here}
+        policies (List[str]): A list of policies, each of which can be any of the `LlamaGuard7B.POLICY__*` constants. Defaults to None, in which case all policies are enforced.
+        validation_method (str): Whether to validate the full text ("full") or each sentence ("sentence"). Defaults to "full".
     """  # noqa
 
-    # If you don't have any init args, you can omit the __init__ method.
+
+    POLICY__NO_VIOLENCE_HATE = "O1"
+    POLICY__NO_SEXUAL_CONTENT = "O2"
+    POLICY__NO_CRIMINAL_PLANNING = "O3"
+    POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS = "O4"
+    POLICY__NO_ILLEGAL_DRUGS = "O5"
+    POLICY__NO_ENCOURAGE_SELF_HARM = "O6"
+
     def __init__(
         self,
-        arg_1: str,
-        arg_2: str,
+        policies: Optional[List[str]] = None,
+        validation_method: Optional[str] = "full",
         on_fail: Optional[Callable] = None,
+        **kwargs,
     ):
-        super().__init__(on_fail=on_fail, arg_1=arg_1, arg_2=arg_2)
-        self._arg_1 = arg_1
-        self._arg_2 = arg_2
+
+        super().__init__(
+            on_fail=on_fail,
+            validation_method=validation_method,
+            **kwargs,
+        )
+
+        self._policies = policies
+
     def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
-        """Validates that {fill in how you validator interacts with the passed value}."""
-        # Add your custom validator logic here and return a PassResult or FailResult accordingly.
-        if value != "pass": # FIXME
+
+        if not value:
+            raise ValueError("Value cannot be empty.")
+
+        (classification, subclass) = self._inference(value)
+
+        is_unsafe = classification == "unsafe"
+
+        if is_unsafe:
+            error_span = ErrorSpan(
+                start=0,
+                end=len(value),
+                reason=f"Unsafe content: {value}",
+            )
+
+            # map the violated subclass code back to its POLICY__* constant name
+
+            find_policy_violated = next(
+                (policy_key for policy_key in [
+                    "POLICY__NO_VIOLENCE_HATE",
+                    "POLICY__NO_CRIMINAL_PLANNING",
+                    "POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS",
+                    "POLICY__NO_ILLEGAL_DRUGS",
+                    "POLICY__NO_ENCOURAGE_SELF_HARM",
+                    "POLICY__NO_SEXUAL_CONTENT"
+                ] if getattr(self, policy_key) == subclass),
+                None
+            )
             return FailResult(
-                error_message="{A descriptive but concise error message about why validation failed}",
-                fix_value="{The programmtic fix if applicable, otherwise remove this kwarg.}",
+                error_message=(
+                    f"Prompt contains unsafe content. Classification: {classification}, Violated Policy: {find_policy_violated}"
+                ),
+                error_spans=[error_span],
             )
-        return PassResult()
+        else:
+            return PassResult()
+
+
+    def _inference_local(self, value: str):
+        raise NotImplementedError("Local inference is not supported for the LlamaGuard7B validator.")
+
+    def _inference_remote(self, value: str) -> tuple:
+        """Remote inference method for this validator."""
+        request_body = {
+            "policies": self._policies,
+            "chat": [
+                {
+                    "role": "user",
+                    "content": value
+                }
+            ]
+        }
+
+        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
+
+        status = response.get("status")
+        if status != 200:
+            detail = response.get("response", {}).get("detail", "Unknown error")
+            raise ValueError(f"Failed to get valid response from LlamaGuard-7B model. Status: {status}. Detail: {detail}")
+
+        response_data = response.get("response")
+
+        classification = response_data.get("class")
+        subclass = response_data.get("subclass")
+
+        return (classification, subclass)
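
The `/v1/chat/classification` contract introduced in `serving-non-optimized-fastapi.py` can also be exercised directly, mirroring the request body that `_inference_remote` sends and the response envelope it parses. The sketch below is illustrative only: the deployment URL, environment variable name, and token are placeholders, not values defined in this diff.

```python
# Minimal client sketch for the /v1/chat/classification endpoint (placeholder URL/token).
import os

import requests

# Hypothetical deployment URL and bearer token; substitute your own values.
ENDPOINT = os.environ.get(
    "LLAMAGUARD_ENDPOINT",
    "https://<your-modal-deployment>.modal.run/v1/chat/classification",
)
TOKEN = os.environ.get("TOKEN", "<bearer-token>")

payload = {
    # Optional filter: only these subclasses count as unsafe. "O5" maps to
    # POLICY__NO_ILLEGAL_DRUGS; omit the key (or send null) to flag any unsafe class.
    "policies": ["O5"],
    "chat": [{"role": "user", "content": "How can I get weed for when cutting my lawn"}],
}

resp = requests.post(
    ENDPOINT,
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=30,
)
body = resp.json()

# Success envelope: {"status": 200, "response": {"class", "subclass", "applied_policies", "raw_output"}}
# Error envelope:   {"status": <code>, "response": {"detail": ...}}
if body.get("status") == 200:
    result = body["response"]
    print(result["class"], result.get("subclass"))
else:
    print("Error:", body.get("response", {}).get("detail"))
```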