From f670aa966f61ee9c322a2c41a8a8932340c70251 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:43:16 -0700 Subject: [PATCH 1/2] fixing typing docs --- validator/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/validator/main.py b/validator/main.py index ae52a83..e552e85 100644 --- a/validator/main.py +++ b/validator/main.py @@ -46,8 +46,8 @@ class RestrictToTopic(Validator): (one or many). invalid_topics (List[str], Optional, defaults to []): topics that the text cannot be about. - device (int, Optional, defaults to -1): Device ordinal for CPU/GPU - supports for Zero-Shot classifier. Setting this to -1 will leverage + device (Optional[Union[str, int]], Optional, defaults to -1): Device ordinal for + CPU/GPU supports for Zero-Shot classifier. Setting this to -1 will leverage CPU, a positive will run the Zero-Shot model on the associated CUDA device id. model (str, Optional, defaults to 'facebook/bart-large-mnli'): The @@ -55,7 +55,7 @@ class RestrictToTopic(Validator): list of all models here: https://huggingface.co/models?pipeline_tag=zero-shot-classification llm_callable (Union[str, Callable, None], Optional, defaults to - 'gpt-3.5-turbo'): Either the name of the OpenAI model, or a callable + 'gpt-4o'): Either the name of the OpenAI model, or a callable that takes a prompt and returns a response. disable_classifier (bool, Optional, defaults to False): controls whether to use the Zero-Shot model. At least one of disable_classifier and @@ -138,7 +138,6 @@ def __init__( # TODO api endpoint ... - def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. From 800f546e69e728775babad082b03500e91bda746 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:44:33 -0700 Subject: [PATCH 2/2] fixing typo theshold -> threshold --- validator/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index e552e85..8a796cc 100644 --- a/validator/main.py +++ b/validator/main.py @@ -85,7 +85,7 @@ def __init__( disable_llm: Optional[bool] = False, on_fail: Optional[Callable[..., Any]] = None, zero_shot_threshold: Optional[float] = 0.5, - llm_theshold: Optional[int] = 3, + llm_threshold: Optional[int] = 3, ): super().__init__( valid_topics=valid_topics, @@ -98,7 +98,7 @@ def __init__( llm_callable=llm_callable, on_fail=on_fail, zero_shot_threshold=zero_shot_threshold, - llm_theshold=llm_theshold, + llm_threshold=llm_threshold, ) self._valid_topics = valid_topics @@ -121,7 +121,7 @@ def __init__( if self._zero_shot_threshold < 0 or self._zero_shot_threshold > 1: raise ValueError("zero_shot_threshold must be a number between 0 and 1") - self._llm_threshold = llm_theshold + self._llm_threshold = llm_threshold if self._llm_threshold < 0 or self._llm_threshold > 5: raise ValueError("llm_threshold must be a number between 0 and 5") self.set_callable(llm_callable)