Merge pull request #12 from guardrails-ai/ml_setup
ML Endpoint Setup
CalebCourier authored Aug 7, 2024
2 parents aecf97f + 9718545 commit 8714b46
Showing 1 changed file with 63 additions and 25 deletions.
validator/main.py: 88 changes (63 additions & 25 deletions)
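In brief: the commit registers the validator with has_guardrails_endpoint=True and splits zero-shot classification into _inference_local and _inference_remote, so the same validator can either run the HuggingFace pipeline in-process or call a hosted inference endpoint. A minimal usage sketch follows; the hub import path and the use_local/on_fail arguments are assumed from Guardrails conventions, not shown in this diff:

# Minimal usage sketch (assumptions: the validator is installed from the
# Guardrails Hub, and `use_local` toggles local vs. remote inference).
from guardrails import Guard
from guardrails.hub import RestrictToTopic

guard = Guard().use(
    RestrictToTopic(
        valid_topics=["sports"],
        invalid_topics=["politics"],
        disable_llm=True,   # zero-shot classifier only
        use_local=True,     # False would route to the remote endpoint
        on_fail="exception",
    )
)
guard.validate("The match went to extra time after a late equalizer.")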
@@ -16,7 +16,9 @@
from transformers import pipeline


-@register_validator(name="tryolabs/restricttotopic", data_type="string")
+@register_validator(
+    name="tryolabs/restricttotopic", data_type="string", has_guardrails_endpoint=True
+)
class RestrictToTopic(Validator):
"""Checks if text's main topic is specified within a list of valid topics
and ensures that the text is not about any of the invalid topics.
@@ -86,6 +88,7 @@ def __init__(
on_fail: Optional[Callable[..., Any]] = None,
zero_shot_threshold: Optional[float] = 0.5,
llm_threshold: Optional[int] = 3,
+**kwargs,
):
super().__init__(
valid_topics=valid_topics,
@@ -99,9 +102,9 @@ def __init__(
on_fail=on_fail,
zero_shot_threshold=zero_shot_threshold,
llm_threshold=llm_threshold,
+**kwargs,
)
self._valid_topics = valid_topics

if invalid_topics is None:
self._invalid_topics = []
else:
@@ -126,7 +129,7 @@ def __init__(
raise ValueError("llm_threshold must be a number between 0 and 5")
self.set_callable(llm_callable)

-if self._classifier_api_endpoint is None:
+if self._classifier_api_endpoint is None and self.use_local:
self._classifier = pipeline(
"zero-shot-classification",
model=self._model,
@@ -150,7 +153,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]:
List[str]: The found topics
"""
# Find topics based on zero shot model
-zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics)
+zero_shot_topics = self._inference({"text": text, "valid_topics": candidate_topics, "invalid_topics": []})

# Find topics based on llm
llm_topics = self.get_topics_llm(text, candidate_topics)
@@ -269,25 +272,6 @@ def openai_callable(text: str, topics: List[str]) -> str:
else:
raise ValueError("llm_callable must be a string or a Callable")

-def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> List[str]:
-    """Gets the topics found through the zero shot classifier
-
-    Args:
-        text (str): The text to classify
-        candidate_topics (List[str]): The potential topics to look for
-
-    Returns:
-        List[str]: The resulting topics found that meet the given threshold
-    """
-    result = self._classifier(text, candidate_topics)
-    topics = result["labels"]
-    scores = result["scores"]
-    found_topics = []
-    for topic, score in zip(topics, scores):
-        if score > self._zero_shot_threshold:
-            found_topics.append(topic)
-    return found_topics

def validate(
self, value: str, metadata: Optional[Dict[str, Any]] = {}
) -> ValidationResult:
@@ -316,10 +300,17 @@ def validate(
"`valid_topics` must be set and contain at least one topic."
)


# throw if valid and invalid topics are not disjoint
if bool(valid_topics.intersection(invalid_topics)):
raise ValueError("A topic cannot be valid and invalid at the same time.")

+model_input = {
+    "text": value,
+    "valid_topics": self._valid_topics,
+    "invalid_topics": self._invalid_topics
+}

# Ensemble method
if not self._disable_classifier and not self._disable_llm:
found_topics = self.get_topics_ensemble(value, all_topics)
@@ -328,10 +319,10 @@ def validate(
found_topics = self.get_topics_llm(value, all_topics)
# Zero Shot Classifier Only
elif not self._disable_classifier and self._disable_llm:
-found_topics = self.get_topics_zero_shot(value, all_topics)
+found_topics = self._inference(model_input)
else:
raise ValueError("Either classifier or llm must be enabled.")

# Determine if valid or invalid topics were found
invalid_topics_found = []
valid_topics_found = []
@@ -350,3 +341,50 @@ def validate(
return FailResult(error_message="No valid topic was found.")

return PassResult()

+def _inference_local(self, model_input: Any) -> Any:
+    """Local inference method for the restrict-to-topic validator."""
+    text = model_input["text"]
+    candidate_topics = model_input["valid_topics"] + model_input["invalid_topics"]
+
+    result = self._classifier(text, candidate_topics)
+    topics = result["labels"]
+    scores = result["scores"]
+    found_topics = []
+    for topic, score in zip(topics, scores):
+        if score > self._zero_shot_threshold:
+            found_topics.append(topic)
+    return found_topics
+
+
+def _inference_remote(self, model_input: Any) -> Any:
+    """Remote inference method for the restrict-to-topic validator."""
+    request_body = {
+        "inputs": [
+            {
+                "name": "text",
+                "shape": [1],
+                "data": [model_input["text"]],
+                "datatype": "BYTES"
+            },
+            {
+                "name": "candidate_topics",
+                "shape": [len(model_input["valid_topics"]) + len(model_input["invalid_topics"])],
+                "data": model_input["valid_topics"] + model_input["invalid_topics"],
+                "datatype": "BYTES"
+            },
+            {
+                "name": "zero_shot_threshold",
+                "shape": [1],
+                "data": [self._zero_shot_threshold],
+                "datatype": "FP32"
+            }
+        ]
+    }
+
+    response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
+
+    if not response or "outputs" not in response:
+        raise ValueError("Invalid response from remote inference", response)
+
+    return response["outputs"][0]["data"]
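Note: the request body assembled in _inference_remote resembles the KServe/Triton V2 inference protocol (named tensors with shape, data, and datatype). A response for a single found topic might look like the sketch below; the values and the output tensor name are invented for illustration:

# Illustrative response consumed by _inference_remote (invented values;
# the output tensor name may differ on a real endpoint):
response = {
    "outputs": [
        {"name": "results", "datatype": "BYTES", "shape": [1], "data": ["sports"]}
    ]
}
# The method returns response["outputs"][0]["data"], i.e. ["sports"].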

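For the local path, _inference_local filters the output of the transformers zero-shot pipeline by _zero_shot_threshold. A sketch of what that pipeline returns, with invented scores (the model name below is a common default for zero-shot classification, assumed here since the diff only shows model=self._model):

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier("The match went to extra time.", ["sports", "politics"])
# result resembles:
# {"sequence": "The match went to extra time.",
#  "labels": ["sports", "politics"],   # sorted by descending score
#  "scores": [0.97, 0.03]}             # invented values
# _inference_local keeps labels whose score exceeds self._zero_shot_threshold.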