From e6ec9e9ba4336168ce7874c09d07157be8bbff5a Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Thu, 26 Sep 2024 13:15:54 -0700 Subject: [PATCH] Migrate to `ChatOllama` base class in Ollama provider (#1015) * Added separate `ollama` provider Created a separate file `ollama.py` as a unique provider. Refactored other code accordingly. Also changed the `Ollama` class to `ChatOllama` so that it can support binding tools to the LLM. Updated the imports to come from `langchain_ollama` instead of `langchain_community` Tested on several Ollama models, both LLMs and embedding models: `mxbai-embed-large`, `nomic-embed-text`, `ima/deepseek-math`, `mathstral`, `qwen2-math`, `snowflake-arctic-embed`, `mistral`, `llama3.1`, `starcoder2:15b-instruct` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../jupyter_ai_magics/__init__.py | 2 -- .../jupyter_ai_magics/embedding_providers.py | 14 -------- .../partner_providers/ollama.py | 32 +++++++++++++++++++ .../jupyter_ai_magics/providers.py | 23 +------------ packages/jupyter-ai-magics/pyproject.toml | 5 +-- 5 files changed, 36 insertions(+), 40 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index f43dad09d..6239ecd59 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -5,7 +5,6 @@ BaseEmbeddingsProvider, GPT4AllEmbeddingsProvider, HfHubEmbeddingsProvider, - OllamaEmbeddingsProvider, QianfanEmbeddingsEndpointProvider, ) from .exception import store_exception @@ -21,7 +20,6 @@ BaseProvider, GPT4AllProvider, HfHubProvider, - OllamaProvider, QianfanProvider, TogetherAIProvider, ) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index f1abc7ed1..695465488 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -10,7 +10,6 @@ from langchain_community.embeddings import ( GPT4AllEmbeddings, HuggingFaceHubEmbeddings, - OllamaEmbeddings, QianfanEmbeddingsEndpoint, ) @@ -65,19 +64,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, **model_kwargs) -class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): - id = "ollama" - name = "Ollama" - # source: https://ollama.com/library - models = [ - "nomic-embed-text", - "mxbai-embed-large", - "all-minilm", - "snowflake-arctic-embed", - ] - model_id_key = "model" - - class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): id = "huggingface_hub" name = "Hugging Face Hub" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py new file mode 100644 index 000000000..5babc5adb --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py @@ -0,0 +1,32 @@ +from langchain_ollama import ChatOllama, OllamaEmbeddings + +from ..embedding_providers import BaseEmbeddingsProvider +from ..providers import BaseProvider, EnvAuthStrategy, TextField + + +class OllamaProvider(BaseProvider, ChatOllama): + id = "ollama" + name = "Ollama" + model_id_key = "model" + help = ( + "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. " + "Pass a model's name; for example, `deepseek-coder-v2`." + ) + models = ["*"] + registry = True + fields = [ + TextField(key="base_url", label="Base API URL (optional)", format="text"), + ] + + +class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): + id = "ollama" + name = "Ollama" + # source: https://ollama.com/library + models = [ + "nomic-embed-text", + "mxbai-embed-large", + "all-minilm", + "snowflake-arctic-embed", + ] + model_id_key = "model" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 023e51a62..8a28c5251 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -29,13 +29,7 @@ from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import Runnable from langchain_community.chat_models import QianfanChatEndpoint -from langchain_community.llms import ( - AI21, - GPT4All, - HuggingFaceEndpoint, - Ollama, - Together, -) +from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM @@ -707,21 +701,6 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) -class OllamaProvider(BaseProvider, Ollama): - id = "ollama" - name = "Ollama" - model_id_key = "model" - help = ( - "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. " - "Pass a model's name; for example, `deepseek-coder-v2`." - ) - models = ["*"] - registry = True - fields = [ - TextField(key="base_url", label="Base API URL (optional)", format="text"), - ] - - class TogetherAIProvider(BaseProvider, Together): id = "togetherai" name = "Together AI" diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 9273c2d9a..91ef5699f 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -48,6 +48,7 @@ all = [ "langchain_mistralai", "langchain_nvidia_ai_endpoints", "langchain_openai", + "langchain_ollama", "pillow", "boto3", "qianfan", @@ -61,7 +62,7 @@ anthropic-chat = "jupyter_ai_magics.partner_providers.anthropic:ChatAnthropicPro cohere = "jupyter_ai_magics.partner_providers.cohere:CohereProvider" gpt4all = "jupyter_ai_magics:GPT4AllProvider" huggingface_hub = "jupyter_ai_magics:HfHubProvider" -ollama = "jupyter_ai_magics:OllamaProvider" +ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaProvider" openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider" openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider" azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider" @@ -83,7 +84,7 @@ cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider" mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider" gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider" huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" -ollama = "jupyter_ai_magics:OllamaEmbeddingsProvider" +ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaEmbeddingsProvider" openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider" qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"