diff --git a/litellm_example.py b/litellm_example.py
index 4182cf1..9e3911a 100644
--- a/litellm_example.py
+++ b/litellm_example.py
@@ -2,4 +2,4 @@
 model = LiteLLMModel()
 
 output = model.run("hey")
-print(output)
\ No newline at end of file
+print(output)
diff --git a/pyproject.toml b/pyproject.toml
index d62251d..736964b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "swarm-models"
-version = "0.1.5"
+version = "0.1.6"
 description = "Swarm Models - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez "]
diff --git a/swarm_models/lite_llm_model.py b/swarm_models/lite_llm_model.py
index 9f4a9c0..36c51b9 100644
--- a/swarm_models/lite_llm_model.py
+++ b/swarm_models/lite_llm_model.py
@@ -1,17 +1,20 @@
 from litellm import completion, acompletion
 from loguru import logger
 
+
 class LiteLLMModel:
     """
     This class represents a LiteLLMModel.
     It is used to interact with the LLM model for various tasks.
     """
+
     def __init__(
         self,
         model_name: str = "gpt-4o",
         system_prompt: str = None,
         stream: bool = False,
         temperature: float = 0.5,
+        max_tokens: int = 4000,
     ):
         """
         Initialize the LiteLLMModel with the given parameters.
@@ -20,18 +23,21 @@ def __init__(
         self.system_prompt = system_prompt
         self.stream = stream
         self.temperature = temperature
+        self.max_tokens = max_tokens
 
     def _prepare_messages(self, task: str) -> list:
         """
         Prepare the messages for the given task.
         """
         messages = []
-
+
         if self.system_prompt:  # Check if system_prompt is not None
-            messages.append({"role": "system", "content": self.system_prompt})
-
+            messages.append(
+                {"role": "system", "content": self.system_prompt}
+            )
+
         messages.append({"role": "user", "content": task})
-
+
         return messages
 
     def run(self, task: str, *args, **kwargs):
@@ -39,16 +45,20 @@ def run(self, task: str, *args, **kwargs):
         Run the LLM model for the given task.
         """
         messages = self._prepare_messages(task)
-
+
         response = completion(
             model=self.model_name,
             messages=messages,
             stream=self.stream,
             temperature=self.temperature,
+            max_completion_tokens=self.max_tokens,
+            max_tokens=self.max_tokens,
             *args,
-            **kwargs
+            **kwargs,
         )
-        content = response.choices[0].message.content  # Accessing the content
+        content = response.choices[
+            0
+        ].message.content  # Accessing the content
         return content
 
     def __call__(self, task: str, *args, **kwargs):
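
For reviewers, a minimal usage sketch of the new max_tokens parameter introduced in 0.1.6. This is not part of the patch: it assumes the class is importable from swarm_models.lite_llm_model (per the file path above), that an OPENAI_API_KEY is set in the environment, and the prompt strings are illustrative only.

# Hypothetical usage sketch, not included in this diff.
from swarm_models.lite_llm_model import LiteLLMModel

model = LiteLLMModel(
    model_name="gpt-4o",
    system_prompt="You are a concise assistant.",  # optional; None skips the system message
    temperature=0.5,
    max_tokens=4000,  # new in this patch; forwarded to litellm's completion()
)

# run() builds the message list and returns the completion's text content.
output = model.run("Say hey back in one sentence.")
print(output)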