added initial validator
AlejandroEsquivel committed Aug 18, 2024
1 parent 9422670 commit 8c31d98
Showing 2 changed files with 142 additions and 33 deletions.
63 changes: 51 additions & 12 deletions inference/serving-non-optimized-fastapi.py
@@ -1,5 +1,4 @@
import os
import time
from typing import Optional

import modal

@@ -101,11 +100,27 @@ def tgi_app():

from typing import List
from pydantic import BaseModel
import logging

TOKEN = os.getenv("TOKEN")
if TOKEN is None:
raise ValueError("Please set the TOKEN environment variable")

# Create a logger
logger = logging.getLogger(MODEL_ALIAS)
logger.setLevel(logging.DEBUG)

# Create a handler for logging to stdout
stdout_handler = logging.StreamHandler()
stdout_handler.setLevel(logging.DEBUG)

# Create a formatter for the log messages
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
stdout_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(stdout_handler)

volume.reload() # ensure we have the latest version of the weights

app = fastapi.FastAPI()
@@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}

@app.exception_handler(Exception)
def error_handler(request, exc):
status_code = 500
detail = "Internal Server Error"
logger.exception(exc)
if isinstance(exc, fastapi.HTTPException):
status_code = exc.status_code
detail = exc.detail
return fastapi.responses.JSONResponse(
status_code=status_code,
content={
"status": status_code,
"response": {
"detail": detail,
}
},
)

router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

@@ -137,11 +170,13 @@ class ChatMessages(BaseModel):
content: str

class ChatClassificationRequestBody(BaseModel):
policies: Optional[List[str]] = None
chat: List[ChatMessages]


@router.post("/v1/chat/classification")
async def chat_classification_response(body: ChatClassificationRequestBody):
policies = body.policies
chat = body.model_dump().get("chat",[])

print("Serving request for chat classification...")
@@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody):
cleaned_response = response.lower().strip()

if "unsafe" in cleaned_response:
is_unsafe = True
split_cln_response = response.strip().split(os.linesep)
subclass = split_cln_response[1] if len(split_cln_response) > 1 else None

if policies and subclass in policies:
is_unsafe = True
elif policies and subclass not in policies:
is_unsafe = False
else:
is_unsafe = True
else:
is_unsafe = False

return {
"class": "unsafe" if is_unsafe else "safe",
"subclass": subclass,
"response": response
"status": 200,
"response": {
"class": "unsafe" if is_unsafe else "safe",
"subclass": subclass,
"applied_policies": policies,
"raw_output": response
}
}


app.include_router(router)
return app


# @app.local_entrypoint()
# def main():
# model = Model()
# model.generate.remote()
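
For reference, a minimal client-side sketch of how the new /v1/chat/classification endpoint might be exercised once the Modal app is deployed. BASE_URL is a hypothetical placeholder; the request body shape (policies, chat) and the bearer-token auth mirror the handler above.

import os
import requests

BASE_URL = "https://example--llamaguard-app.modal.run"  # hypothetical deployment URL
TOKEN = os.environ["TOKEN"]  # must match the server's TOKEN environment variable

payload = {
    # Optional subset of LlamaGuard policy codes; omit to flag any unsafe subclass
    "policies": ["O1", "O3"],
    "chat": [{"role": "user", "content": "Example user message to classify"}],
}

resp = requests.post(
    f"{BASE_URL}/v1/chat/classification",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(resp.json())  # e.g. {"status": 200, "response": {"class": "safe", "subclass": None, ...}}
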
112 changes: 91 additions & 21 deletions validator/main.py
@@ -1,4 +1,7 @@
from typing import Any, Callable, Dict, Optional
import json
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from guardrails.validator_base import ErrorSpan

from guardrails.validator_base import (
FailResult,
@@ -7,42 +10,109 @@
Validator,
register_validator,
)
from guardrails.logger import logger

class Policies(str, Enum):
NO_VIOLENCE_HATE = "O1"
NO_SEXUAL_CONTENT = "O2"
NO_CRIMINAL_PLANNING = "O3"
NO_GUNS_AND_ILLEGAL_WEAPONS = "O4"
NO_ILLEGAL_DRUGS = "O5"
NO_ENOURAGE_SELF_HARM = "O6"

@register_validator(name="guardrails/validator_template", data_type="string")
class ValidatorTemplate(Validator):
"""Validates that {fill in how you validator interacts with the passed value}.
@register_validator(name="guardrails/llamaguard_7b", data_type="string")
class LlamaGuard7B(Validator):
"""
Classifies model inputs or outputs as "safe" or "unsafe" based on certain policies defined by the LlamaGuard-7B model.
**Key Properties**
| Property | Description |
| ----------------------------- | --------------------------------- |
| Name for `format` attribute | `guardrails/validator_template` |
| Name for `format` attribute | `guardrails/llamaguard_7b` |
| Supported data types | `string` |
| Programmatic fix | {If you support programmatic fixes, explain it here. Otherwise `None`} |
| Programmatic fix | None |
Args:
arg_1 (string): {Description of the argument here}
arg_2 (string): {Description of the argument here}
policies (List[Policies]): List of LlamaGuard7B.Policies enum values to enforce.
score_threshold (float): Threshold score for the classification. If the score is above this threshold, the input is considered unsafe.
""" # noqa

# If you don't have any init args, you can omit the __init__ method.
Policies = Policies

def __init__(
self,
arg_1: str,
arg_2: str,
policies: Optional[List[Policies]] = None,
validation_method: Optional[str] = "full",
on_fail: Optional[Callable] = None,
):
super().__init__(on_fail=on_fail, arg_1=arg_1, arg_2=arg_2)
self._arg_1 = arg_1
self._arg_2 = arg_2

super().__init__(
on_fail=on_fail,
validation_method=validation_method,
)

try:
self._policies = [policy.value for policy in policies] if policies else []
except AttributeError as e:
raise ValueError("Invalid policies provided. Please provide a list of LlamaGuard7B.Policies enum values.") from e


def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
"""Validates that {fill in how you validator interacts with the passed value}."""
# Add your custom validator logic here and return a PassResult or FailResult accordingly.
if value != "pass": # FIXME

if not value:
raise ValueError("Value cannot be empty.")

(classification, subclass) = self._inference(value)

is_unsafe = classification == "unsafe"

if is_unsafe:
error_span = ErrorSpan(
start=0,
end=len(value),
reason=f"Unsafe content: {value}",
)

find_policy_violated = next(
(policy for policy in self.Policies if policy.value == subclass),
None
)
return FailResult(
error_message="{A descriptive but concise error message about why validation failed}",
fix_value="{The programmtic fix if applicable, otherwise remove this kwarg.}",
error_message=(
f"Prompt contains unsafe content. Classification: {classification}, Violated Policy: {find_policy_violated}"
),
error_spans=[error_span],
)
return PassResult()
else:
return PassResult()


def _inference_local(self, value: str):
raise NotImplementedError("Local inference is not supported for LlamaGuard7B validator.")

def _inference_remote(self, value: str) -> tuple:
"""Remote inference method for this validator."""
request_body = {
"policies": self._policies,
"chat": [
{
"role": "user",
"content": value
}
]
}

response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)

status = response.get("status")
if status != 200:
detail = response.get("response",{}).get("detail", "Unknown error")
raise ValueError(f"Failed to get valid response from Llamaguard-7B model. Status: {status}. Detail: {detail}")

response_data = response.get("response")

classification = response_data.get("class")
subclass = response_data.get("subclass")

return (classification, subclass)
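
A minimal usage sketch for the new LlamaGuard7B validator, assuming it is installed and that remote inference against the endpoint above is configured. Guard().use(...) follows the standard Guardrails pattern; the import path is an assumption that depends on how the package is published.

from guardrails import Guard
from validator import LlamaGuard7B  # hypothetical import path; hub installs may differ

# Enforce only selected policies; other "unsafe" subclasses will pass validation
guard = Guard().use(
    LlamaGuard7B,
    policies=[
        LlamaGuard7B.Policies.NO_VIOLENCE_HATE,
        LlamaGuard7B.Policies.NO_ILLEGAL_DRUGS,
    ],
    on_fail="exception",
)

try:
    guard.validate("Some LLM output to screen")
    print("safe")
except Exception as e:
    print(f"blocked: {e}")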
