Merge pull request #6 from guardrails-ai/fix_gpt4o_call
Fix typos and docstrings
wylansford authored Jun 3, 2024
2 parents 5049bd3 + 800f546 commit 3096256
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions validator/main.py
@@ -46,16 +46,16 @@ class RestrictToTopic(Validator):
            (one or many).
        invalid_topics (List[str], Optional, defaults to []): topics that the
            text cannot be about.
-       device (int, Optional, defaults to -1): Device ordinal for CPU/GPU
-           supports for Zero-Shot classifier. Setting this to -1 will leverage
+       device (Optional[Union[str, int]], Optional, defaults to -1): Device ordinal for
+           CPU/GPU supports for Zero-Shot classifier. Setting this to -1 will leverage
            CPU, a positive will run the Zero-Shot model on the associated CUDA
            device id.
        model (str, Optional, defaults to 'facebook/bart-large-mnli'): The
            Zero-Shot model that will be used to classify the topic. See a
            list of all models here:
            https://huggingface.co/models?pipeline_tag=zero-shot-classification
        llm_callable (Union[str, Callable, None], Optional, defaults to
-           'gpt-3.5-turbo'): Either the name of the OpenAI model, or a callable
+           'gpt-4o'): Either the name of the OpenAI model, or a callable
            that takes a prompt and returns a response.
        disable_classifier (bool, Optional, defaults to False): controls whether
            to use the Zero-Shot model. At least one of disable_classifier and
@@ -85,7 +85,7 @@ def __init__(
        disable_llm: Optional[bool] = False,
        on_fail: Optional[Callable[..., Any]] = None,
        zero_shot_threshold: Optional[float] = 0.5,
-       llm_theshold: Optional[int] = 3,
+       llm_threshold: Optional[int] = 3,
    ):
        super().__init__(
            valid_topics=valid_topics,
@@ -98,7 +98,7 @@ def __init__(
            llm_callable=llm_callable,
            on_fail=on_fail,
            zero_shot_threshold=zero_shot_threshold,
-           llm_theshold=llm_theshold,
+           llm_threshold=llm_threshold,
        )
        self._valid_topics = valid_topics

@@ -121,7 +121,7 @@ def __init__(
        if self._zero_shot_threshold < 0 or self._zero_shot_threshold > 1:
            raise ValueError("zero_shot_threshold must be a number between 0 and 1")

-       self._llm_threshold = llm_theshold
+       self._llm_threshold = llm_threshold
        if self._llm_threshold < 0 or self._llm_threshold > 5:
            raise ValueError("llm_threshold must be a number between 0 and 5")
        self.set_callable(llm_callable)
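
For reference, a minimal usage sketch of the corrected keyword. The class name, keyword names, and defaults mirror the diff above; the import path and the topic values are illustrative assumptions, not part of this commit.

```python
# Illustrative only: import path and argument values are assumptions;
# the keyword names and defaults come from the diff above.
from validator.main import RestrictToTopic

validator = RestrictToTopic(
    valid_topics=["sports"],           # topics the text may be about
    invalid_topics=["politics"],       # topics the text must not be about
    device=-1,                         # -1 runs the zero-shot model on CPU; a CUDA device id uses GPU
    model="facebook/bart-large-mnli",  # zero-shot classifier
    llm_callable="gpt-4o",             # new default per this commit
    zero_shot_threshold=0.5,           # validated to be between 0 and 1
    llm_threshold=3,                   # corrected spelling; validated to be between 0 and 5
)
```
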
@@ -138,7 +138,6 @@ def __init__(
        # TODO api endpoint
        ...

-
    def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]:
        """Finds the topics in the input text based on if it is determined by the zero
        shot model or the llm.
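
The get_topics_ensemble signature shown above suggests usage along these lines. This sketch reuses the `validator` instance from the previous snippet and is based only on the docstring fragment visible here, so the exact return behavior is an assumption.

```python
# Sketch based on the signature in the diff; actual behavior may differ.
text = "The team clinched the championship in overtime."
candidate_topics = ["sports", "politics", "technology"]

# Per the docstring, returns the candidate topics the text is judged to be
# about, as determined by the zero-shot model or the LLM.
found = validator.get_topics_ensemble(text, candidate_topics)
print(found)  # e.g. ["sports"]
```
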
