diff --git a/validator/main.py b/validator/main.py index 8a796cc..c150063 100644 --- a/validator/main.py +++ b/validator/main.py @@ -16,7 +16,9 @@ from transformers import pipeline -@register_validator(name="tryolabs/restricttotopic", data_type="string") +@register_validator( + name="tryolabs/restricttotopic", data_type="string", has_guardrails_endpoint=True +) class RestrictToTopic(Validator): """Checks if text's main topic is specified within a list of valid topics and ensures that the text is not about any of the invalid topics. @@ -86,6 +88,7 @@ def __init__( on_fail: Optional[Callable[..., Any]] = None, zero_shot_threshold: Optional[float] = 0.5, llm_threshold: Optional[int] = 3, + **kwargs, ): super().__init__( valid_topics=valid_topics, @@ -99,9 +102,9 @@ def __init__( on_fail=on_fail, zero_shot_threshold=zero_shot_threshold, llm_threshold=llm_threshold, + **kwargs, ) self._valid_topics = valid_topics - if invalid_topics is None: self._invalid_topics = [] else: @@ -126,7 +129,7 @@ def __init__( raise ValueError("llm_threshold must be a number between 0 and 5") self.set_callable(llm_callable) - if self._classifier_api_endpoint is None: + if self._classifier_api_endpoint is None and self.use_local: self._classifier = pipeline( "zero-shot-classification", model=self._model, @@ -150,7 +153,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[st List[str]: The found topics """ # Find topics based on zero shot model - zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics) + zero_shot_topics = self._inference({"text": text, "valid_topics": candidate_topics, "invalid_topics": []}) # Find topics based on llm llm_topics = self.get_topics_llm(text, candidate_topics) @@ -269,25 +272,6 @@ def openai_callable(text: str, topics: List[str]) -> str: else: raise ValueError("llm_callable must be a string or a Callable") - def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> List[str]: - """Gets the topics found through the zero shot classifier - - Args: - text (str): The text to classify - candidate_topics (List[str]): The potential topics to look for - - Returns: - List[str]: The resulting topics found that meet the given threshold - """ - result = self._classifier(text, candidate_topics) - topics = result["labels"] - scores = result["scores"] - found_topics = [] - for topic, score in zip(topics, scores): - if score > self._zero_shot_threshold: - found_topics.append(topic) - return found_topics - def validate( self, value: str, metadata: Optional[Dict[str, Any]] = {} ) -> ValidationResult: @@ -316,10 +300,17 @@ def validate( "`valid_topics` must be set and contain at least one topic." ) + # throw if valid and invalid topics are not disjoint if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") + model_input = { + "text": value, + "valid_topics": self._valid_topics, + "invalid_topics": self._invalid_topics + } + # Ensemble method if not self._disable_classifier and not self._disable_llm: found_topics = self.get_topics_ensemble(value, all_topics) @@ -328,10 +319,10 @@ def validate( found_topics = self.get_topics_llm(value, all_topics) # Zero Shot Classifier Only elif not self._disable_classifier and self._disable_llm: - found_topics = self.get_topics_zero_shot(value, all_topics) + found_topics = self._inference(model_input) else: raise ValueError("Either classifier or llm must be enabled.") - + # Determine if valid or invalid topics were found invalid_topics_found = [] valid_topics_found = [] @@ -350,3 +341,50 @@ def validate( return FailResult(error_message="No valid topic was found.") return PassResult() + + def _inference_local(self, model_input: Any) -> Any: + """Local inference method for the restrict-to-topic validator.""" + text = model_input["text"] + candidate_topics = model_input["valid_topics"] + model_input["invalid_topics"] + + result = self._classifier(text, candidate_topics) + topics = result["labels"] + scores = result["scores"] + found_topics = [] + for topic, score in zip(topics, scores): + if score > self._zero_shot_threshold: + found_topics.append(topic) + return found_topics + + + def _inference_remote(self, model_input: Any) -> Any: + """Remote inference method for the restrict-to-topic validator.""" + request_body = { + "inputs": [ + { + "name": "text", + "shape": [1], + "data": [model_input["text"]], + "datatype": "BYTES" + }, + { + "name": "candidate_topics", + "shape": [len(model_input["valid_topics"]) + len(model_input["invalid_topics"])], + "data": model_input["valid_topics"] + model_input["invalid_topics"], + "datatype": "BYTES" + }, + { + "name": "zero_shot_threshold", + "shape": [1], + "data": [self._zero_shot_threshold], + "datatype": "FP32" + } + ] + } + + response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint) + + if not response or "outputs" not in response: + raise ValueError("Invalid response from remote inference", response) + + return response["outputs"][0]["data"] \ No newline at end of file