Skip to content

Commit

Permalink
[Tidy] Use BaseChatModel for type hint (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia authored Jun 27, 2024
1 parent ee767b3 commit 38353ca
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
5 changes: 2 additions & 3 deletions vizro-ai/src/vizro_ai/chains/_llm_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from langchain.prompts import PromptTemplate
from langchain.schema import ChatGeneration, Generation
from langchain.schema.messages import AIMessage

from vizro_ai.chains._llm_models import LLM_MODELS
from langchain_core.language_models.chat_models import BaseChatModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +37,7 @@ class FunctionCallChain(VizroBaseChain, ABC):

def __init__( # noqa: PLR0913
self,
llm: LLM_MODELS,
llm: BaseChatModel,
raw_prompt: str,
partial_vars_map: Optional[Dict[Any, Any]] = None,
llm_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
8 changes: 3 additions & 5 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Dict, Optional, Union

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

# TODO add new wrappers in if new model support is added
LLM_MODELS = Union[ChatOpenAI]

# TODO constant of model inventory, can be converted to yaml and link to docs
PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, LLM_MODELS]]] = {
PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, BaseChatModel]]] = {
"gpt-3.5-turbo-0613": {
"max_tokens": 4096,
"wrapper": ChatOpenAI,
Expand Down Expand Up @@ -41,7 +39,7 @@
DEFAULT_TEMPERATURE = 0


def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> LLM_MODELS:
def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
"""Fetches and initializes an instance of the LLM.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Union

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains import FunctionCallChain
from vizro_ai.chains._llm_models import LLM_MODELS


class VizroAiComponentBase(ABC):
Expand All @@ -22,7 +23,7 @@ class VizroAiComponentBase(ABC):

prompt: str = "default prompt place holder"

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialize Vizro-AI base component.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/chart_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -48,7 +49,7 @@ class GetChartSelection(VizroAiComponentBase):

prompt: str = chart_type_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of Chart Selection components.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/code_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -44,7 +45,7 @@ class GetDebugger(VizroAiComponentBase):

prompt: str = debugging_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of Chart Selection components.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/custom_chart_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -48,7 +49,7 @@ class GetCustomChart(VizroAiComponentBase):

prompt: str = custom_chart_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of custom chart components.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/dataframe_craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -51,7 +52,7 @@ class GetDataFrameCraft(VizroAiComponentBase):

prompt: str = dataframe_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of dataframe crafting components.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -44,7 +45,7 @@ class GetCodeExplanation(VizroAiComponentBase):

prompt: str = code_explanation_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of Code Explanation components.
Args:
Expand Down
5 changes: 3 additions & 2 deletions vizro-ai/src/vizro_ai/components/visual_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.chains._llm_models import LLM_MODELS
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager

Expand Down Expand Up @@ -43,7 +44,7 @@ class GetVisualCode(VizroAiComponentBase):

prompt: str = visual_code_prompt

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialization of Chart Selection components.
Args:
Expand Down
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/task_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from typing import Any, Dict, List, Optional

from vizro_ai.chains._llm_models import LLM_MODELS
from langchain_core.language_models.chat_models import BaseChatModel


class Pipeline:
"""A pipeline to manage the flow of tasks in a sequence."""

def __init__(self, llm: LLM_MODELS):
def __init__(self, llm: BaseChatModel):
"""Initialize the Pipeline.
Args:
Expand Down
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/task_pipeline/_pipeline_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Pipeline Manager."""

from vizro_ai.chains._llm_models import LLM_MODELS
from langchain_core.language_models.chat_models import BaseChatModel
from vizro_ai.components import GetChartSelection, GetCustomChart, GetDataFrameCraft, GetVisualCode
from vizro_ai.task_pipeline._pipeline import Pipeline


class PipelineManager:
"""Task pipeline manager."""

def __init__(self, llm: LLM_MODELS = None):
def __init__(self, llm: BaseChatModel = None):
"""Initialize the Pipeline Manager.
Args:
Expand Down

0 comments on commit 38353ca

Please sign in to comment.