From 93dd4884d23286320b0dc214d7136396c767adfc Mon Sep 17 00:00:00 2001
From: Avi-Robusta <97387909+Avi-Robusta@users.noreply.github.com>
Date: Tue, 8 Oct 2024 19:45:30 +0300
Subject: [PATCH] Bugfix - backwards compatibility with model names for openai (#153)

--model="openai/gpt-4o" would throw an exception due to no longer including
the prefix openai in the model_cost table
https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json

[Screenshot: Screen Shot 2024-10-08 at 19 20 18]
---
 holmes/core/tool_calling_llm.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py
index f03fd45..42135fb 100644
--- a/holmes/core/tool_calling_llm.py
+++ b/holmes/core/tool_calling_llm.py
@@ -82,18 +82,32 @@ def check_llm(self, model, api_key):
         if not model_requirements["keys_in_environment"]:
             raise Exception(f"model {model} requires the following environment variables: {model_requirements['missing_keys']}")
 
+    def _strip_model_prefix(self) -> str:
+        """
+        Helper function to strip 'openai/' prefix from model name if it exists.
+        model cost is taken from here which does not have the openai prefix
+        https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
+        """
+        model_name = self.model
+        if model_name.startswith('openai/'):
+            model_name = model_name[len('openai/'):]  # Strip the 'openai/' prefix
+        return model_name
+
+
     # this unfortunately does not seem to work for azure if the deployment name is not a well-known model name
     #if not litellm.supports_function_calling(model=model):
     #    raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
 
     def get_context_window_size(self) -> int:
-        return litellm.model_cost[self.model]['max_input_tokens']
+        model_name = self._strip_model_prefix()
+        return litellm.model_cost[model_name]['max_input_tokens']
 
     def count_tokens_for_message(self, messages: list[dict]) -> int:
         return litellm.token_counter(model=self.model, messages=messages)
 
     def get_maximum_output_token(self) -> int:
-        return litellm.model_cost[self.model]['max_output_tokens']
+        model_name = self._strip_model_prefix()
+        return litellm.model_cost[model_name]['max_output_tokens']
 
     def call(self, system_prompt, user_prompt, post_process_prompt: Optional[str] = None, response_format: dict = None) -> LLMResult:
         messages = [
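
For reference, a minimal standalone sketch of the lookup behavior the patch restores, assuming litellm is installed and that litellm.model_cost keys omit the 'openai/' prefix as the commit message describes. The names _strip_model_prefix (copied from the patch) and max_input_tokens (a hypothetical wrapper, not in the repository) are illustrative only.

import litellm

def _strip_model_prefix(model: str) -> str:
    # Illustrative copy of the patch's helper: drop a leading 'openai/' if present.
    prefix = 'openai/'
    return model[len(prefix):] if model.startswith(prefix) else model

def max_input_tokens(model: str) -> int:
    # Hypothetical wrapper (not in the repo). Before the patch, looking up
    # 'openai/gpt-4o' directly in litellm.model_cost raised a KeyError;
    # stripping the prefix makes the lookup hit the 'gpt-4o' entry instead.
    return litellm.model_cost[_strip_model_prefix(model)]['max_input_tokens']

print(max_input_tokens('openai/gpt-4o'))  # same result as looking up 'gpt-4o'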