From eb79ab73f412844d63f457d3220b794ccf3f88ec Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 12:07:02 -0700 Subject: [PATCH 01/25] optimize to a single call instead of multiple calls per topic --- validator/main.py | 53 +++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/validator/main.py b/validator/main.py index c992c9c..733022b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -97,7 +97,7 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = to_int(device) + self._device = device if device == "mps" else to_int(device) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm @@ -112,12 +112,16 @@ def __init__( def get_topic_ensemble( self, text: str, candidate_topics: List[str] ) -> ValidationResult: - topic, confidence = self.get_topic_zero_shot(text, candidate_topics) - - if confidence > self._model_threshold: - return self.verify_topic(topic) - else: - return self.get_topic_llm(text, candidate_topics) + topics, scores = self.get_topic_zero_shot(text, candidate_topics) + succesfully_on_topic = [] + for score, topic in zip(scores, topics): + if score > self._model_threshold and topic in self._valid_topics: + succesfully_on_topic.append(topic) + if score > self._model_threshold and topic in self._invalid_topics: + return FailResult(error_message=f"Invalid {topic} was found to be relevant.") + if not succesfully_on_topic: + return FailResult(error_message="No valid topic was found.") + return self.get_topic_llm(text, candidate_topics) def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: response = self.call_llm(text, candidate_topics) @@ -207,13 +211,14 @@ def get_topic_zero_shot( classifier = pipeline( "zero-shot-classification", model=self._model, - device=self._device, + device="mps", hypothesis_template="This example has to do with topic {}.", + multi_label=True ) result = classifier(text, candidate_topics) - topic = result["labels"][0] - score = result["scores"][0] - return topic, score + topics = result["labels"] + scores = result["scores"] + return topics, scores def validate( self, value: str, metadata: Optional[Dict[str, Any]] = {} @@ -231,12 +236,6 @@ def validate( if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") - # Add 'other' to the invalid topics list - if "other" not in invalid_topics: - invalid_topics.add("other") - - # Combine valid and invalid topics - candidate_topics = valid_topics.union(invalid_topics) # Check which model(s) to use if self._disable_classifier and self._disable_llm: # Error, no model set @@ -244,14 +243,18 @@ def validate( elif ( not self._disable_classifier and not self._disable_llm ): # Use ensemble (Zero-Shot + Ensemble) - return self.get_topic_ensemble(value, list(candidate_topics)) + return self.get_topic_ensemble(value, list(invalid_topics)) elif self._disable_classifier and not self._disable_llm: # Use only LLM - return self.get_topic_llm(value, list(candidate_topics)) + return self.get_topic_llm(value, list(invalid_topics)) # Use only Zero-Shot - topic, _score = self.get_topic_zero_shot(value, list(candidate_topics)) - - if _score > self._model_threshold: - return self.verify_topic(topic) - else: - return self.verify_topic("other") + topics, scores = self.get_topic_zero_shot(value, list(invalid_topics) + list(valid_topics)) + succesfully_on_topic = [] + for score, topic in zip(scores, topics): + if score > self._model_threshold and topic in self._valid_topics: + succesfully_on_topic.append(topic) + if score > self._model_threshold and topic in self._invalid_topics: + return FailResult(error_message=f"Invalid {topic} was found to be relevant.") + if not succesfully_on_topic: + return FailResult(error_message="No valid topic was found.") + return PassResult() \ No newline at end of file From 1fcc8be2fbdc2754525a867a72e004b74cffb594 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 12:08:56 -0700 Subject: [PATCH 02/25] change device --- validator/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index 733022b..460daeb 100644 --- a/validator/main.py +++ b/validator/main.py @@ -97,7 +97,7 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = device if device == "mps" else to_int(device) + self._device = device if device in ["cpu", "mps"] else to_int(device) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm @@ -211,7 +211,7 @@ def get_topic_zero_shot( classifier = pipeline( "zero-shot-classification", model=self._model, - device="mps", + device=self._device, hypothesis_template="This example has to do with topic {}.", multi_label=True ) From 3ff10e568949c39d4e937aa497a7bebff1260f90 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 12:48:14 -0700 Subject: [PATCH 03/25] updating with error fix --- validator/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index 460daeb..4e3700c 100644 --- a/validator/main.py +++ b/validator/main.py @@ -97,7 +97,7 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = device if device in ["cpu", "mps"] else to_int(device) + self._device = device if device == "mps" else to_int(device) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm @@ -118,7 +118,7 @@ def get_topic_ensemble( if score > self._model_threshold and topic in self._valid_topics: succesfully_on_topic.append(topic) if score > self._model_threshold and topic in self._invalid_topics: - return FailResult(error_message=f"Invalid {topic} was found to be relevant.") + return FailResult(error_message=f"Invalid topic {topic} was found to be relevant.") if not succesfully_on_topic: return FailResult(error_message="No valid topic was found.") return self.get_topic_llm(text, candidate_topics) @@ -211,7 +211,7 @@ def get_topic_zero_shot( classifier = pipeline( "zero-shot-classification", model=self._model, - device=self._device, + device="mps", hypothesis_template="This example has to do with topic {}.", multi_label=True ) From f283425687117a051fc25bf143fd49381d53ff22 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 12:49:30 -0700 Subject: [PATCH 04/25] revert device --- validator/main.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/validator/main.py b/validator/main.py index 4e3700c..3a48a50 100644 --- a/validator/main.py +++ b/validator/main.py @@ -97,7 +97,7 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = device if device == "mps" else to_int(device) + self._device = device if device in ["cpu", "mps"] else to_int(device) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm @@ -118,7 +118,9 @@ def get_topic_ensemble( if score > self._model_threshold and topic in self._valid_topics: succesfully_on_topic.append(topic) if score > self._model_threshold and topic in self._invalid_topics: - return FailResult(error_message=f"Invalid topic {topic} was found to be relevant.") + return FailResult( + error_message=f"Invalid topic {topic} was found to be relevant." + ) if not succesfully_on_topic: return FailResult(error_message="No valid topic was found.") return self.get_topic_llm(text, candidate_topics) @@ -211,9 +213,9 @@ def get_topic_zero_shot( classifier = pipeline( "zero-shot-classification", model=self._model, - device="mps", + device=self._device, hypothesis_template="This example has to do with topic {}.", - multi_label=True + multi_label=True, ) result = classifier(text, candidate_topics) topics = result["labels"] @@ -236,7 +238,6 @@ def validate( if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") - # Check which model(s) to use if self._disable_classifier and self._disable_llm: # Error, no model set raise ValueError("Either classifier or llm must be enabled.") @@ -248,13 +249,17 @@ def validate( return self.get_topic_llm(value, list(invalid_topics)) # Use only Zero-Shot - topics, scores = self.get_topic_zero_shot(value, list(invalid_topics) + list(valid_topics)) + topics, scores = self.get_topic_zero_shot( + value, list(invalid_topics) + list(valid_topics) + ) succesfully_on_topic = [] for score, topic in zip(scores, topics): if score > self._model_threshold and topic in self._valid_topics: succesfully_on_topic.append(topic) if score > self._model_threshold and topic in self._invalid_topics: - return FailResult(error_message=f"Invalid {topic} was found to be relevant.") + return FailResult( + error_message=f"Invalid {topic} was found to be relevant." + ) if not succesfully_on_topic: return FailResult(error_message="No valid topic was found.") - return PassResult() \ No newline at end of file + return PassResult() From 83459a5cc7bdf58902521c73fe25ac39958068bc Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 14:43:22 -0700 Subject: [PATCH 05/25] moving model initialization to init instead of call --- validator/main.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/validator/main.py b/validator/main.py index 3a48a50..89f3828 100644 --- a/validator/main.py +++ b/validator/main.py @@ -108,6 +108,13 @@ def __init__( self._model_threshold = model_threshold self.set_callable(llm_callable) + self.classifier = pipeline( + "zero-shot-classification", + model=self._model, + device=self._device, + hypothesis_template="This example has to do with topic {}.", + multi_label=True, + ) def get_topic_ensemble( self, text: str, candidate_topics: List[str] @@ -210,14 +217,8 @@ def openai_callable(text: str, topics: List[str]) -> str: def get_topic_zero_shot( self, text: str, candidate_topics: List[str] ) -> Tuple[str, float]: - classifier = pipeline( - "zero-shot-classification", - model=self._model, - device=self._device, - hypothesis_template="This example has to do with topic {}.", - multi_label=True, - ) - result = classifier(text, candidate_topics) + + result = self.classifier(text, candidate_topics) topics = result["labels"] scores = result["scores"] return topics, scores From 1d053c5a66ee30c38a8df1a30d4c97da9b7f8862 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 15:13:18 -0700 Subject: [PATCH 06/25] fixing llm call --- validator/main.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/validator/main.py b/validator/main.py index 89f3828..1259d7f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -97,7 +97,7 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = device if device in ["cpu", "mps"] else to_int(device) + self._device = device if device == "mps" else to_int(device) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm @@ -128,8 +128,7 @@ def get_topic_ensemble( return FailResult( error_message=f"Invalid topic {topic} was found to be relevant." ) - if not succesfully_on_topic: - return FailResult(error_message="No valid topic was found.") + return self.get_topic_llm(text, candidate_topics) def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: @@ -150,7 +149,11 @@ def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: return (api_key, api_base) - @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5)) + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(5), + reraise=True, + ) def call_llm(self, text: str, topics: List[str]) -> str: """Call the LLM with the given prompt. @@ -161,9 +164,6 @@ def call_llm(self, text: str, topics: List[str]) -> str: Returns: response (str): String representing the LLM response. """ - from dotenv import load_dotenv - - load_dotenv() return self._llm_callable(text, topics) def verify_topic(self, topic: str) -> ValidationResult: @@ -217,7 +217,6 @@ def openai_callable(text: str, topics: List[str]) -> str: def get_topic_zero_shot( self, text: str, candidate_topics: List[str] ) -> Tuple[str, float]: - result = self.classifier(text, candidate_topics) topics = result["labels"] scores = result["scores"] From 1503bef2ff9e0bc88b39ce18116bd71beb2d8f9d Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 22:13:28 -0700 Subject: [PATCH 07/25] json mode for openai call plus cleanup & organization of validation --- validator/main.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/validator/main.py b/validator/main.py index 1259d7f..2412728 100644 --- a/validator/main.py +++ b/validator/main.py @@ -120,15 +120,15 @@ def get_topic_ensemble( self, text: str, candidate_topics: List[str] ) -> ValidationResult: topics, scores = self.get_topic_zero_shot(text, candidate_topics) - succesfully_on_topic = [] + failed = [] for score, topic in zip(scores, topics): - if score > self._model_threshold and topic in self._valid_topics: - succesfully_on_topic.append(topic) if score > self._model_threshold and topic in self._invalid_topics: - return FailResult( - error_message=f"Invalid topic {topic} was found to be relevant." - ) + failed.append(topic) + if failed: + return FailResult( + error_message=f"The following invalid topics were found to be relevant: {failed}", + ) return self.get_topic_llm(text, candidate_topics) def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: @@ -195,6 +195,7 @@ def openai_callable(text: str, topics: List[str]) -> str: api_key, api_base = self.get_client_args() response = OpenAIClient(api_key, api_base).create_chat_completion( model=llm_callable, + response_format={"type": "json_object"}, messages=[ { "role": "user", From bc338175de7673f074da82c8d0db70240422252d Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Wed, 29 May 2024 22:21:26 -0700 Subject: [PATCH 08/25] add ignore --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 2412728..9f50cf9 100644 --- a/validator/main.py +++ b/validator/main.py @@ -16,7 +16,7 @@ @register_validator(name="tryolabs/restricttotopic", data_type="string") -class RestrictToTopic(Validator): +class RestrictToTopic(Validator): # type: ignore """Checks if text's main topic is specified within a list of valid topics and ensures that the text is not about any of the invalid topics. From a54fee40fc9f28839f9455d121d999fd92996ee8 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 00:10:18 -0700 Subject: [PATCH 09/25] throwing error if no restricted topics provided --- validator/main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index 9f50cf9..f83ec5e 100644 --- a/validator/main.py +++ b/validator/main.py @@ -16,7 +16,7 @@ @register_validator(name="tryolabs/restricttotopic", data_type="string") -class RestrictToTopic(Validator): # type: ignore +class RestrictToTopic(Validator): """Checks if text's main topic is specified within a list of valid topics and ensures that the text is not about any of the invalid topics. @@ -234,7 +234,10 @@ def validate( raise ValueError( "`valid_topics` must be set and contain at least one topic." ) - + if not invalid_topics: + raise ValueError( + "`invalid topics` must be set and contain at least one topic." + ) # throw if valid and invalid topics are not disjoint if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") From 52a626af65ea7a565f6f0cfe2234b3e3593ae4a0 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 00:14:58 -0700 Subject: [PATCH 10/25] handling when there is no invalid topics --- validator/main.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/validator/main.py b/validator/main.py index f83ec5e..bdc6dbe 100644 --- a/validator/main.py +++ b/validator/main.py @@ -234,10 +234,7 @@ def validate( raise ValueError( "`valid_topics` must be set and contain at least one topic." ) - if not invalid_topics: - raise ValueError( - "`invalid topics` must be set and contain at least one topic." - ) + # throw if valid and invalid topics are not disjoint if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") @@ -247,10 +244,12 @@ def validate( raise ValueError("Either classifier or llm must be enabled.") elif ( not self._disable_classifier and not self._disable_llm - ): # Use ensemble (Zero-Shot + Ensemble) - return self.get_topic_ensemble(value, list(invalid_topics)) + ): + if invalid_topics:# Use ensemble (Zero-Shot + Ensemble) + return self.get_topic_ensemble(value, list(invalid_topics)) elif self._disable_classifier and not self._disable_llm: # Use only LLM - return self.get_topic_llm(value, list(invalid_topics)) + if invalid_topics: + return self.get_topic_llm(value, list(invalid_topics)) # Use only Zero-Shot topics, scores = self.get_topic_zero_shot( From 3bc96d3c910f88ee10fc0733b8f7d10408f99aa8 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 16:41:55 -0700 Subject: [PATCH 11/25] cleanup and optimization --- validator/main.py | 209 ++++++++++++++++++++++++++++++---------------- 1 file changed, 139 insertions(+), 70 deletions(-) diff --git a/validator/main.py b/validator/main.py index bdc6dbe..8b9fe1b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -75,9 +75,11 @@ def __init__( model: Optional[str] = "facebook/bart-large-mnli", llm_callable: Union[str, Callable, None] = None, disable_classifier: Optional[bool] = False, + classifier_api_endpoint: Optional[str] = None, disable_llm: Optional[bool] = False, on_fail: Optional[Callable[..., Any]] = None, - model_threshold: Optional[float] = 0.5, + zero_shot_threshold: Optional[float] = 0.5, + llm_theshold: Optional[int] = 3, ): super().__init__( valid_topics=valid_topics, @@ -85,10 +87,12 @@ def __init__( device=device, model=model, disable_classifier=disable_classifier, + classifier_api_endpoint=classifier_api_endpoint, disable_llm=disable_llm, llm_callable=llm_callable, on_fail=on_fail, - model_threshold=model_threshold, + zero_shot_threshold=zero_shot_threshold, + llm_theshold=llm_theshold, ) self._valid_topics = valid_topics @@ -101,40 +105,68 @@ def __init__( self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm + self._classifier_api_endpoint = classifier_api_endpoint - if not model_threshold: - model_threshold = 0.5 - else: - self._model_threshold = model_threshold + self._zero_shot_threshold = zero_shot_threshold + if self._zero_shot_threshold < 0 or self._zero_shot_threshold > 1: + raise ValueError("zero_shot_threshold must be a number between 0 and 1") + self._llm_threshold = llm_theshold + if self._llm_threshold < 0 or self._llm_threshold > 5: + raise ValueError("llm_threshold must be a number between 0 and 5") self.set_callable(llm_callable) - self.classifier = pipeline( - "zero-shot-classification", - model=self._model, - device=self._device, - hypothesis_template="This example has to do with topic {}.", - multi_label=True, - ) - def get_topic_ensemble( - self, text: str, candidate_topics: List[str] - ) -> ValidationResult: - topics, scores = self.get_topic_zero_shot(text, candidate_topics) - failed = [] - for score, topic in zip(scores, topics): - if score > self._model_threshold and topic in self._invalid_topics: - failed.append(topic) + if self._classifier_api_endpoint is None: + self._classifier = pipeline( + "zero-shot-classification", + model=self._model, + device=self._device, + hypothesis_template="This example has to do with topic {}.", + multi_label=True, + ) + else: + # TODO api endpoint + ... - if failed: - return FailResult( - error_message=f"The following invalid topics were found to be relevant: {failed}", + self._json_schema = self._create_json_schema( + self._valid_topics, self._invalid_topics + ) + + def _create_json_schema(self, valid_topics: list, invalid_topics: list): + json_schema = [] + for topic in set(valid_topics + invalid_topics): + json_schema.append( + {topic: {"present": "[bool]", "confidence": "[int, 1, 5]"}} ) - return self.get_topic_llm(text, candidate_topics) + return str(json_schema) + + def get_topic_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: + """Finds the topics in the input text based on if it is determined by the zero + shot model or the llm. + + Args: + text (str): The input text to find categories from + candidate_topics (List[str]): The topics to search for in the input text + + Returns: + list[str]: The found topics + """ + # Find topics based on zero shot model + zero_shot_topics = self.get_topic_zero_shot(text, candidate_topics) + + # Find topics based on llm + llm_topics = self.get_topic_llm(text, candidate_topics) + + return list(set(zero_shot_topics + llm_topics)) def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: response = self.call_llm(text, candidate_topics) - topic = json.loads(response)["topic"] - return self.verify_topic(topic) + topics = json.loads(response) + found_topics = [] + for topic, data in topics.items(): + if data["present"] and data["confidence"] > self._llm_threshold: + found_topics.append(topic) + return found_topics def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: kwargs = {} @@ -149,11 +181,11 @@ def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: return (api_key, api_base) - @retry( - wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(5), - reraise=True, - ) + # @retry( + # wait=wait_random_exponential(min=1, max=60), + # stop=stop_after_attempt(5), + # reraise=True, + # ) def call_llm(self, text: str, topics: List[str]) -> str: """Call the LLM with the given prompt. @@ -164,13 +196,7 @@ def call_llm(self, text: str, topics: List[str]) -> str: Returns: response (str): String representing the LLM response. """ - return self._llm_callable(text, topics) - - def verify_topic(self, topic: str) -> ValidationResult: - if topic in self._valid_topics: - return PassResult() - else: - return FailResult(error_message=f"Most relevant topic is {topic}.") + return self._llm_callable(text) def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: """Set the LLM callable. @@ -191,7 +217,7 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: "Check out ProvenanceV1 documentation for an example." ) - def openai_callable(text: str, topics: List[str]) -> str: + def openai_callable(text: str) -> str: api_key, api_base = self.get_client_args() response = OpenAIClient(api_key, api_base).create_chat_completion( model=llm_callable, @@ -199,14 +225,23 @@ def openai_callable(text: str, topics: List[str]) -> str: messages=[ { "role": "user", - "content": f"""Classify the following text {text} - into one of these topics: {topics}. - Format the response as JSON with the following schema: - {{"topic": "topic_name"}}""", + "content": f"""Given a text, fill out the provided json schema with a confidence that the topic is relevant to the text. + + Text + ---- + {text} + + Schema + ------ + {self._json_schema} + + Complete Schema + --------------- + + """, }, ], ) - return response.output self._llm_callable = openai_callable @@ -215,19 +250,46 @@ def openai_callable(text: str, topics: List[str]) -> str: else: raise ValueError("llm_callable must be a string or a Callable") - def get_topic_zero_shot( - self, text: str, candidate_topics: List[str] - ) -> Tuple[str, float]: - result = self.classifier(text, candidate_topics) + def get_topic_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]: + """Gets the topics found through the zero shot classifier + + Args: + text (str): The text to classify + candidate_topics (List[str]): The potential topics to look for + + Returns: + list[str]: The resulting topics found that meet the given threshold + """ + result = self._classifier(text, candidate_topics) topics = result["labels"] scores = result["scores"] - return topics, scores + found_topics = [] + for topic, score in zip(topics, scores): + if score > self._zero_shot_threshold: + found_topics.append(topic) + return found_topics def validate( self, value: str, metadata: Optional[Dict[str, Any]] = {} ) -> ValidationResult: + """Validates that a string contains at least one valid topic and no invalid topics. + + Args: + value (str): The given string to classify + metadata (Optional[Dict[str, Any]], optional): _description_. Defaults to {}. + + Raises: + ValueError: If a topic is invalid and valid + ValueError: If no valid topics are set + ValueError: If there is no llm or zero shot classifier set + + Returns: + ValidationResult: PassResult if a topic is restricted and valid, + FailResult otherwise + """ valid_topics = set(self._valid_topics) invalid_topics = set(self._invalid_topics) + all_topics = list(set(valid_topics) | set(invalid_topics)) # throw if valid and invalid topics are empty if not valid_topics: @@ -242,27 +304,34 @@ def validate( # Check which model(s) to use if self._disable_classifier and self._disable_llm: # Error, no model set raise ValueError("Either classifier or llm must be enabled.") - elif ( - not self._disable_classifier and not self._disable_llm - ): - if invalid_topics:# Use ensemble (Zero-Shot + Ensemble) - return self.get_topic_ensemble(value, list(invalid_topics)) - elif self._disable_classifier and not self._disable_llm: # Use only LLM - if invalid_topics: - return self.get_topic_llm(value, list(invalid_topics)) + + # Use ensemble (Zero-Shot + Ensemble) + elif not self._disable_classifier and not self._disable_llm: + found_topics = self.get_topic_ensemble(value, all_topics) + + # Use only LLM + elif self._disable_classifier and not self._disable_llm: + found_topics = self.get_topic_llm(value, all_topics) # Use only Zero-Shot - topics, scores = self.get_topic_zero_shot( - value, list(invalid_topics) + list(valid_topics) - ) - succesfully_on_topic = [] - for score, topic in zip(scores, topics): - if score > self._model_threshold and topic in self._valid_topics: - succesfully_on_topic.append(topic) - if score > self._model_threshold and topic in self._invalid_topics: - return FailResult( - error_message=f"Invalid {topic} was found to be relevant." - ) - if not succesfully_on_topic: + elif not self._disable_classifier and self._disable_llm: + found_topics = self.get_topic_zero_shot(value, all_topics) + + # Determine if valid or invalid topics were found + invalid_topics_found = [] + valid_topics_found = [] + for topic in found_topics: + if topic in self._valid_topics: + valid_topics_found.append(topic) + elif topic in self._invalid_topics: + invalid_topics_found.append(topic) + + # Require at least one valid topic and no invalid topics + if invalid_topics_found: + return FailResult( + error_message=f"Invalid topics found: {invalid_topics_found}" + ) + if not valid_topics_found: return FailResult(error_message="No valid topic was found.") + return PassResult() From 25c550a6830bb8e678d30b3bc2561ffd08f97fa9 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 16:44:09 -0700 Subject: [PATCH 12/25] comments cleanup --- validator/main.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/validator/main.py b/validator/main.py index 8b9fe1b..aa9516b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -181,11 +181,11 @@ def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: return (api_key, api_base) - # @retry( - # wait=wait_random_exponential(min=1, max=60), - # stop=stop_after_attempt(5), - # reraise=True, - # ) + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(5), + reraise=True, + ) def call_llm(self, text: str, topics: List[str]) -> str: """Call the LLM with the given prompt. @@ -301,19 +301,19 @@ def validate( if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") - # Check which model(s) to use + # Verify at least one is enabled if self._disable_classifier and self._disable_llm: # Error, no model set raise ValueError("Either classifier or llm must be enabled.") - # Use ensemble (Zero-Shot + Ensemble) + # Case: both enabled/ensemble (Zero-Shot + Ensemble) elif not self._disable_classifier and not self._disable_llm: found_topics = self.get_topic_ensemble(value, all_topics) - # Use only LLM + # Case: Only use LLM elif self._disable_classifier and not self._disable_llm: found_topics = self.get_topic_llm(value, all_topics) - # Use only Zero-Shot + # Case: Only use Zero-Shot elif not self._disable_classifier and self._disable_llm: found_topics = self.get_topic_zero_shot(value, all_topics) From 77c828239becbce756cd028525e9619f1f229fdb Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 16:47:36 -0700 Subject: [PATCH 13/25] adding hallucination check for topics the llm returns --- validator/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index aa9516b..99d5d4b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -160,12 +160,14 @@ def get_topic_ensemble(self, text: str, candidate_topics: List[str]) -> list[str return list(set(zero_shot_topics + llm_topics)) def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: - response = self.call_llm(text, candidate_topics) + response = self.call_llm(text) topics = json.loads(response) found_topics = [] for topic, data in topics.items(): if data["present"] and data["confidence"] > self._llm_threshold: - found_topics.append(topic) + # Verify the llm didn't hallucinate a topic. + if topic in candidate_topics: + found_topics.append(topic) return found_topics def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: @@ -186,7 +188,7 @@ def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: stop=stop_after_attempt(5), reraise=True, ) - def call_llm(self, text: str, topics: List[str]) -> str: + def call_llm(self, text: str) -> str: """Call the LLM with the given prompt. Expects a function that takes a string and returns a string. From cccebfb1118a249f53ef0eb521257b8ec0eecc19 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 16:52:27 -0700 Subject: [PATCH 14/25] more docstrings --- validator/main.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index 99d5d4b..7f2b11b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -132,7 +132,19 @@ def __init__( self._valid_topics, self._invalid_topics ) - def _create_json_schema(self, valid_topics: list, invalid_topics: list): + def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: + """Creates a json schema that an LLM will fill out. The json schema contains + one of each of the provided topics, as well as a blank 'present' and 'confidence' + for the llm to fill in. + + Args: + valid_topics (list): The valid topics to provide as one of the json schema + invalid_topics (list): Invalid topics to provide as one of the json schema + + Returns: + str: The resulting json schema with unfilled data types + """ + json_schema = [] for topic in set(valid_topics + invalid_topics): json_schema.append( @@ -159,7 +171,17 @@ def get_topic_ensemble(self, text: str, candidate_topics: List[str]) -> list[str return list(set(zero_shot_topics + llm_topics)) - def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationResult: + def get_topic_llm(self, text: str, candidate_topics: List[str]) -> list[str]: + """Returns a list of the topics identified in the given text using an LLM + callable + + Args: + text (str): The input text to classify topics. + candidate_topics (List[str]): The topics to identify if present in the text. + + Returns: + list[str]: The topics found in the input text. + """ response = self.call_llm(text) topics = json.loads(response) found_topics = [] @@ -171,6 +193,11 @@ def get_topic_llm(self, text: str, candidate_topics: List[str]) -> ValidationRes return found_topics def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: + """Returns neccessary data for api calls. + + Returns: + Tuple[Optional[str], Optional[str]]: api key and api base values + """ kwargs = {} context_copy = contextvars.copy_context() for key, context_var in context_copy.items(): From 457a1a7e4c77db13f6bd67503608cfdeca76f11d Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 17:37:04 -0700 Subject: [PATCH 15/25] update docstring --- validator/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 7f2b11b..0e4f19c 100644 --- a/validator/main.py +++ b/validator/main.py @@ -59,12 +59,17 @@ class RestrictToTopic(Validator): disable_classifier (bool, Optional, defaults to False): controls whether to use the Zero-Shot model. At least one of disable_classifier and disable_llm must be False. + classifier_api_endpoint (str, Optional, defaults to None): An API endpoint + to recieve post requests that will be used when provided. If not provided, a + local model will be initialized. disable_llm (bool, Optional, defaults to False): controls whether to use the LLM fallback. At least one of disable_classifier and disable_llm must be False. - model_threshold (float, Optional, defaults to 0.5): The threshold used to + zero_shot_threshold (float, Optional, defaults to 0.5): The threshold used to determine whether to accept a topic from the Zero-Shot model. Must be a number between 0 and 1. + llm_threshold (int, Optional, defaults to 3): The threshold used to determine + if a topic exists based on the provided llm api. Must be between 0 and 5. """ def __init__( From 46bec4527a58efab18ca70558b11a15932afcdae Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 17:43:33 -0700 Subject: [PATCH 16/25] cleanup if organization --- validator/main.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/validator/main.py b/validator/main.py index 0e4f19c..eee7765 100644 --- a/validator/main.py +++ b/validator/main.py @@ -335,21 +335,17 @@ def validate( if bool(valid_topics.intersection(invalid_topics)): raise ValueError("A topic cannot be valid and invalid at the same time.") - # Verify at least one is enabled - if self._disable_classifier and self._disable_llm: # Error, no model set - raise ValueError("Either classifier or llm must be enabled.") - - # Case: both enabled/ensemble (Zero-Shot + Ensemble) - elif not self._disable_classifier and not self._disable_llm: - found_topics = self.get_topic_ensemble(value, all_topics) - - # Case: Only use LLM + # Ensemble method + if not self._disable_classifier and not self._disable_llm: + found_topics = self.get_topics_ensemble(value, invalid_topics) + # LLM Classifier Only elif self._disable_classifier and not self._disable_llm: - found_topics = self.get_topic_llm(value, all_topics) - - # Case: Only use Zero-Shot + found_topics = self.get_topics_llm(value, invalid_topics) + # Zero Shot Classifier Only elif not self._disable_classifier and self._disable_llm: - found_topics = self.get_topic_zero_shot(value, all_topics) + found_topics, _ = self.get_topic_zero_shot(value, invalid_topics) + else: + raise ValueError("Either classifier or llm must be enabled.") # Determine if valid or invalid topics were found invalid_topics_found = [] From becb2314026b46a76183d25efbcb988b58555c92 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 17:51:12 -0700 Subject: [PATCH 17/25] fix bad function call and cleanup --- validator/main.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/validator/main.py b/validator/main.py index eee7765..9bae569 100644 --- a/validator/main.py +++ b/validator/main.py @@ -157,7 +157,7 @@ def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: ) return str(json_schema) - def get_topic_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. @@ -169,14 +169,14 @@ def get_topic_ensemble(self, text: str, candidate_topics: List[str]) -> list[str list[str]: The found topics """ # Find topics based on zero shot model - zero_shot_topics = self.get_topic_zero_shot(text, candidate_topics) + zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics) # Find topics based on llm - llm_topics = self.get_topic_llm(text, candidate_topics) + llm_topics = self.get_topics_llm(text, candidate_topics) return list(set(zero_shot_topics + llm_topics)) - def get_topic_llm(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: """Returns a list of the topics identified in the given text using an LLM callable @@ -284,7 +284,7 @@ def openai_callable(text: str) -> str: else: raise ValueError("llm_callable must be a string or a Callable") - def get_topic_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]: """Gets the topics found through the zero shot classifier Args: @@ -337,13 +337,13 @@ def validate( # Ensemble method if not self._disable_classifier and not self._disable_llm: - found_topics = self.get_topics_ensemble(value, invalid_topics) + found_topics = self.get_topics_ensemble(value, all_topics) # LLM Classifier Only elif self._disable_classifier and not self._disable_llm: - found_topics = self.get_topics_llm(value, invalid_topics) + found_topics = self.get_topics_llm(value, all_topics) # Zero Shot Classifier Only elif not self._disable_classifier and self._disable_llm: - found_topics, _ = self.get_topic_zero_shot(value, invalid_topics) + found_topics = self.get_topics_zero_shot(value, all_topics) else: raise ValueError("Either classifier or llm must be enabled.") From 35757f716ab53770666db8bf65f950b6e7035086 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 18:31:16 -0700 Subject: [PATCH 18/25] improve device check slightly --- validator/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index 9bae569..bf00a0e 100644 --- a/validator/main.py +++ b/validator/main.py @@ -76,7 +76,7 @@ def __init__( self, valid_topics: List[str], invalid_topics: Optional[List[str]] = [], - device: Optional[int] = -1, + device: Optional[Union[str, int]] = -1, model: Optional[str] = "facebook/bart-large-mnli", llm_callable: Union[str, Callable, None] = None, disable_classifier: Optional[bool] = False, @@ -106,7 +106,9 @@ def __init__( else: self._invalid_topics = invalid_topics - self._device = device if device == "mps" else to_int(device) + self._device = ( + device.lower() if device.lower() in ["cpu", "mps"] else int(device) + ) self._model = model self._disable_classifier = disable_classifier self._disable_llm = disable_llm From 42cf0968b1948af7915501838789d259b7b1cf46 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 30 May 2024 21:07:46 -0700 Subject: [PATCH 19/25] switched llm call from prompting to function calling --- validator/main.py | 63 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/validator/main.py b/validator/main.py index bf00a0e..04905e6 100644 --- a/validator/main.py +++ b/validator/main.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from guardrails.utils.casting_utils import to_int -from guardrails.utils.openai_utils import OpenAIClient from guardrails.validator_base import ( FailResult, PassResult, @@ -11,6 +10,7 @@ Validator, register_validator, ) +from openai import OpenAI from tenacity import retry, stop_after_attempt, wait_random_exponential from transformers import pipeline @@ -107,7 +107,9 @@ def __init__( self._invalid_topics = invalid_topics self._device = ( - device.lower() if device.lower() in ["cpu", "mps"] else int(device) + str(device).lower() + if str(device).lower() in ["cpu", "mps"] + else int(device) ) self._model = model self._disable_classifier = disable_classifier @@ -135,7 +137,7 @@ def __init__( # TODO api endpoint ... - self._json_schema = self._create_json_schema( + self._json_schema, self._tools = self._create_json_schema( self._valid_topics, self._invalid_topics ) @@ -151,13 +153,38 @@ def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: Returns: str: The resulting json schema with unfilled data types """ + tools = [ + { + "type": "function", + "function": { + "name": "is_topic_relevant", + "description": "Determine if the provided text is about a topic, with a confidence score.", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Simply the repeated name of the given topic.", + }, + "present": { + "type": "boolean", + "description": "If the given topic is discussed in the given text.", + }, + "confidence": { + "type": "integer", + "description": "The confidence level of the topic being present in the text, from 1-5", + }, + }, + "required": ["name", "present", "confidence"], + }, + }, + }, + ] json_schema = [] for topic in set(valid_topics + invalid_topics): - json_schema.append( - {topic: {"present": "[bool]", "confidence": "[int, 1, 5]"}} - ) - return str(json_schema) + json_schema.append({"topic": topic}) + return json_schema, tools def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: """Finds the topics in the input text based on if it is determined by the zero @@ -189,14 +216,13 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: Returns: list[str]: The topics found in the input text. """ - response = self.call_llm(text) - topics = json.loads(response) + topics = self.call_llm(text) found_topics = [] - for topic, data in topics.items(): - if data["present"] and data["confidence"] > self._llm_threshold: + for llm_result in topics: + if llm_result["present"] and llm_result["confidence"] > self._llm_threshold: # Verify the llm didn't hallucinate a topic. - if topic in candidate_topics: - found_topics.append(topic) + if llm_result["name"] in candidate_topics: + found_topics.append(llm_result["name"]) return found_topics def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: @@ -255,13 +281,14 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: def openai_callable(text: str) -> str: api_key, api_base = self.get_client_args() - response = OpenAIClient(api_key, api_base).create_chat_completion( + client = OpenAI() + response = client.chat.completions.create( model=llm_callable, response_format={"type": "json_object"}, messages=[ { "role": "user", - "content": f"""Given a text, fill out the provided json schema with a confidence that the topic is relevant to the text. + "content": f"""Given a series of topics, determine if the topic is present in the provided text. Return the result as json. Text ---- @@ -277,8 +304,12 @@ def openai_callable(text: str) -> str: """, }, ], + tools=self._tools, ) - return response.output + tool_calls = [] + for tool_call in response.choices[0].message.tool_calls: + tool_calls.append(json.loads(tool_call.function.arguments)) + return tool_calls self._llm_callable = openai_callable elif isinstance(llm_callable, Callable): From 3dc53cf7db1aea5f983a2c11caac1c2791d13a66 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Fri, 31 May 2024 00:30:04 -0700 Subject: [PATCH 20/25] adding load dot env to args to correctly set api key --- validator/main.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/validator/main.py b/validator/main.py index 04905e6..dcbfac5 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,8 +1,9 @@ import contextvars import json +import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from guardrails.utils.casting_utils import to_int +from dotenv import load_dotenv from guardrails.validator_base import ( FailResult, PassResult, @@ -225,23 +226,26 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: found_topics.append(llm_result["name"]) return found_topics - def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: + def get_client_args(self) -> str: """Returns neccessary data for api calls. Returns: - Tuple[Optional[str], Optional[str]]: api key and api base values + str: api key """ - kwargs = {} - context_copy = contextvars.copy_context() - for key, context_var in context_copy.items(): - if key.name == "kwargs" and isinstance(kwargs, dict): - kwargs = context_var - break - api_key = kwargs.get("api_key") - api_base = kwargs.get("api_base") + load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + kwargs = {} + context_copy = contextvars.copy_context() + for key, context_var in context_copy.items(): + if key.name == "kwargs" and isinstance(kwargs, dict): + kwargs = context_var + break - return (api_key, api_base) + api_key = kwargs.get("api_key") + + return api_key @retry( wait=wait_random_exponential(min=1, max=60), @@ -280,8 +284,8 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: ) def openai_callable(text: str) -> str: - api_key, api_base = self.get_client_args() - client = OpenAI() + api_key = self.get_client_args() + client = OpenAI(api_key=api_key) response = client.chat.completions.create( model=llm_callable, response_format={"type": "json_object"}, From 793537496aace90376a867b3f9270e420a8518d8 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Fri, 31 May 2024 14:07:09 -0700 Subject: [PATCH 21/25] list[str] -> typing.List[str] --- validator/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/validator/main.py b/validator/main.py index dcbfac5..bbf1887 100644 --- a/validator/main.py +++ b/validator/main.py @@ -187,7 +187,7 @@ def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: json_schema.append({"topic": topic}) return json_schema, tools - def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. @@ -196,7 +196,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st candidate_topics (List[str]): The topics to search for in the input text Returns: - list[str]: The found topics + List[str]: The found topics """ # Find topics based on zero shot model zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics) @@ -206,7 +206,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st return list(set(zero_shot_topics + llm_topics)) - def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_llm(self, text: str, candidate_topics: List[str]) -> List[str]: """Returns a list of the topics identified in the given text using an LLM callable @@ -215,7 +215,7 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: candidate_topics (List[str]): The topics to identify if present in the text. Returns: - list[str]: The topics found in the input text. + List[str]: The topics found in the input text. """ topics = self.call_llm(text) found_topics = [] @@ -321,7 +321,7 @@ def openai_callable(text: str) -> str: else: raise ValueError("llm_callable must be a string or a Callable") - def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> List[str]: """Gets the topics found through the zero shot classifier Args: @@ -329,7 +329,7 @@ def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[s candidate_topics (List[str]): The potential topics to look for Returns: - list[str]: The resulting topics found that meet the given threshold + List[str]: The resulting topics found that meet the given threshold """ result = self._classifier(text, candidate_topics) topics = result["labels"] From 165a2c116bc5942e2574f691be487e75ae638d79 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Fri, 31 May 2024 14:25:15 -0700 Subject: [PATCH 22/25] updating to gpt-4o, no function calling --- validator/main.py | 98 ++++++++++++----------------------------------- 1 file changed, 25 insertions(+), 73 deletions(-) diff --git a/validator/main.py b/validator/main.py index bbf1887..6fcae55 100644 --- a/validator/main.py +++ b/validator/main.py @@ -142,51 +142,6 @@ def __init__( self._valid_topics, self._invalid_topics ) - def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: - """Creates a json schema that an LLM will fill out. The json schema contains - one of each of the provided topics, as well as a blank 'present' and 'confidence' - for the llm to fill in. - - Args: - valid_topics (list): The valid topics to provide as one of the json schema - invalid_topics (list): Invalid topics to provide as one of the json schema - - Returns: - str: The resulting json schema with unfilled data types - """ - tools = [ - { - "type": "function", - "function": { - "name": "is_topic_relevant", - "description": "Determine if the provided text is about a topic, with a confidence score.", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Simply the repeated name of the given topic.", - }, - "present": { - "type": "boolean", - "description": "If the given topic is discussed in the given text.", - }, - "confidence": { - "type": "integer", - "description": "The confidence level of the topic being present in the text, from 1-5", - }, - }, - "required": ["name", "present", "confidence"], - }, - }, - }, - ] - - json_schema = [] - for topic in set(valid_topics + invalid_topics): - json_schema.append({"topic": topic}) - return json_schema, tools - def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. @@ -217,13 +172,11 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> List[str]: Returns: List[str]: The topics found in the input text. """ - topics = self.call_llm(text) + llm_topics = self.call_llm(text, candidate_topics) found_topics = [] - for llm_result in topics: - if llm_result["present"] and llm_result["confidence"] > self._llm_threshold: - # Verify the llm didn't hallucinate a topic. - if llm_result["name"] in candidate_topics: - found_topics.append(llm_result["name"]) + for llm_topic in llm_topics: + if llm_topic in candidate_topics: + found_topics.append(llm_topic) return found_topics def get_client_args(self) -> str: @@ -252,7 +205,7 @@ def get_client_args(self) -> str: stop=stop_after_attempt(5), reraise=True, ) - def call_llm(self, text: str) -> str: + def call_llm(self, text: str, topics: List[str]) -> str: """Call the LLM with the given prompt. Expects a function that takes a string and returns a string. @@ -262,7 +215,7 @@ def call_llm(self, text: str) -> str: Returns: response (str): String representing the LLM response. """ - return self._llm_callable(text) + return self._llm_callable(text, topics) def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: """Set the LLM callable. @@ -273,17 +226,17 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: """ if llm_callable is None: - llm_callable = "gpt-3.5-turbo" + llm_callable = "gpt-4o" if isinstance(llm_callable, str): - if llm_callable not in ["gpt-3.5-turbo", "gpt-4"]: + if llm_callable not in ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]: raise ValueError( - "llm_callable must be one of 'gpt-3.5-turbo' or 'gpt-4'." + "llm_callable must be one of 'gpt-3.5-turbo', 'gpt-4', or 'gpt-4o'" "If you want to use a custom LLM, please provide a callable." "Check out ProvenanceV1 documentation for an example." ) - def openai_callable(text: str) -> str: + def openai_callable(text: str, topics: List[str]) -> str: api_key = self.get_client_args() client = OpenAI(api_key=api_key) response = client.chat.completions.create( @@ -292,28 +245,27 @@ def openai_callable(text: str) -> str: messages=[ { "role": "user", - "content": f"""Given a series of topics, determine if the topic is present in the provided text. Return the result as json. - - Text - ---- - {text} + "content": f""" + Given a text and a list of topics, return a valid json list of which topics are present in the text. If none, just return an empty list. + + Output Format: + ------------- + "topics_present": [] - Schema - ------ - {self._json_schema} + Text: + ---- + "{text}" - Complete Schema - --------------- + Topics: + ------ + {topics} - """, + Result: + ------ """, }, ], - tools=self._tools, ) - tool_calls = [] - for tool_call in response.choices[0].message.tool_calls: - tool_calls.append(json.loads(tool_call.function.arguments)) - return tool_calls + return json.loads(response.choices[0].message.content)["topics_present"] self._llm_callable = openai_callable elif isinstance(llm_callable, Callable): From e19f39cdb9925ab986c4e4646f884bf82520369f Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Fri, 31 May 2024 14:30:35 -0700 Subject: [PATCH 23/25] updating to gpt-4o, no function calling --- validator/main.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index 6fcae55..ae52a83 100644 --- a/validator/main.py +++ b/validator/main.py @@ -138,9 +138,6 @@ def __init__( # TODO api endpoint ... - self._json_schema, self._tools = self._create_json_schema( - self._valid_topics, self._invalid_topics - ) def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero From f670aa966f61ee9c322a2c41a8a8932340c70251 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:43:16 -0700 Subject: [PATCH 24/25] fixing typing docs --- validator/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/validator/main.py b/validator/main.py index ae52a83..e552e85 100644 --- a/validator/main.py +++ b/validator/main.py @@ -46,8 +46,8 @@ class RestrictToTopic(Validator): (one or many). invalid_topics (List[str], Optional, defaults to []): topics that the text cannot be about. - device (int, Optional, defaults to -1): Device ordinal for CPU/GPU - supports for Zero-Shot classifier. Setting this to -1 will leverage + device (Optional[Union[str, int]], Optional, defaults to -1): Device ordinal for + CPU/GPU supports for Zero-Shot classifier. Setting this to -1 will leverage CPU, a positive will run the Zero-Shot model on the associated CUDA device id. model (str, Optional, defaults to 'facebook/bart-large-mnli'): The @@ -55,7 +55,7 @@ class RestrictToTopic(Validator): list of all models here: https://huggingface.co/models?pipeline_tag=zero-shot-classification llm_callable (Union[str, Callable, None], Optional, defaults to - 'gpt-3.5-turbo'): Either the name of the OpenAI model, or a callable + 'gpt-4o'): Either the name of the OpenAI model, or a callable that takes a prompt and returns a response. disable_classifier (bool, Optional, defaults to False): controls whether to use the Zero-Shot model. At least one of disable_classifier and @@ -138,7 +138,6 @@ def __init__( # TODO api endpoint ... - def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. From 800f546e69e728775babad082b03500e91bda746 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:44:33 -0700 Subject: [PATCH 25/25] fixing typo theshold -> threshold --- validator/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index e552e85..8a796cc 100644 --- a/validator/main.py +++ b/validator/main.py @@ -85,7 +85,7 @@ def __init__( disable_llm: Optional[bool] = False, on_fail: Optional[Callable[..., Any]] = None, zero_shot_threshold: Optional[float] = 0.5, - llm_theshold: Optional[int] = 3, + llm_threshold: Optional[int] = 3, ): super().__init__( valid_topics=valid_topics, @@ -98,7 +98,7 @@ def __init__( llm_callable=llm_callable, on_fail=on_fail, zero_shot_threshold=zero_shot_threshold, - llm_theshold=llm_theshold, + llm_threshold=llm_threshold, ) self._valid_topics = valid_topics @@ -121,7 +121,7 @@ def __init__( if self._zero_shot_threshold < 0 or self._zero_shot_threshold > 1: raise ValueError("zero_shot_threshold must be a number between 0 and 1") - self._llm_threshold = llm_theshold + self._llm_threshold = llm_threshold if self._llm_threshold < 0 or self._llm_threshold > 5: raise ValueError("llm_threshold must be a number between 0 and 5") self.set_callable(llm_callable)