Skip to content

Commit

Permalink
[Feat] Refactor of vizro_ai.plot (breaking) (#646)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lingyi Zhang <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2024
1 parent d8e51c4 commit f1e68ae
Show file tree
Hide file tree
Showing 44 changed files with 671 additions and 2,155 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
4 changes: 2 additions & 2 deletions vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
vendor = SUPPORTED_VENDORS[vendor_input]
llm = vendor(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, explain=False, return_elements=True)
ai_outputs = vizro_ai.plot(df, user_prompt, return_elements=True)

return ai_outputs

Expand Down Expand Up @@ -76,7 +76,7 @@ def create_response(ai_response, figure, user_prompt, filename):
vendor_input=vendor_input,
)
ai_code = ai_outputs.code
figure = ai_outputs.figure
figure = ai_outputs.get_fig_object(data_frame=df)
formatted_code = black.format_str(ai_code, mode=black.Mode(line_length=100))

ai_response = "\n".join(["```python", formatted_code, "```"])
Expand Down
44 changes: 29 additions & 15 deletions vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,40 @@
"source": [
"from vizro_ai import VizroAI\n",
"\n",
"vizro_ai = VizroAI()\n",
"llm = None\n",
"# uncomment below to update model to use other models\n",
"# llm = \"gpt-4o\"\n",
"\n",
"# uncomment below to update model to gpt4 if you would like to create more complex chart\n",
"# vizro_ai = VizroAI(model=\"gpt-4-0613\")"
"\n",
"# import os\n",
"# from langchain_anthropic import ChatAnthropic\n",
"# llm = ChatAnthropic(\n",
"# model=\"claude-3-5-sonnet-20240620\",\n",
"# api_key = os.environ.get(\"ANTHROPIC_API_KEY\"),\n",
"# base_url= os.environ.get(\"ANTHROPIC_API_BASE\")\n",
"# )\n",
"\n",
"# import os\n",
"# from langchain_openai import AzureChatOpenAI\n",
"# llm = AzureChatOpenAI(\n",
"# azure_deployment=\"gpt-4-1106-preview\", # or your deployment\n",
"# api_version=\"2024-04-01-preview\", # or your api version\n",
"# temperature=0,\n",
"# max_tokens=None,\n",
"# timeout=None,\n",
"# max_retries=2,\n",
"# azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
"# api_key=os.environ[\"AZURE_OPENAI_API_KEY\"]\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()"
"vizro_ai = VizroAI(model=llm)"
]
},
{
Expand Down Expand Up @@ -67,7 +81,7 @@
"metadata": {},
"outputs": [],
"source": [
"vizro_ai.plot(df, \"describe the composition of gdp in continent, and add horizontal line for avg gdp\", explain=True)"
"vizro_ai.plot(df, \"describe the composition of gdp in continent, and add horizontal line for avg gdp\")"
]
},
{
Expand All @@ -76,7 +90,7 @@
"metadata": {},
"outputs": [],
"source": [
"vizro_ai.plot(df, \"show me the geo distribution of life expectancy and set year as animation \", explain=True)"
"vizro_ai.plot(df, \"show me the geo distribution of life expectancy and set year as animation \")"
]
}
],
Expand All @@ -96,7 +110,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/examples/example_dashboard.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions vizro-ai/hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ dependencies = [
"nbformat>=4.2.0",
"pyhamcrest",
"jupyter",
"langchain_community",
"dash[testing]",
"chromedriver-autoinstaller>=0.6.4",
"urllib3<2.0.0" # helps to resolve bug with urllib3 timeout from vizro-ai integration tests: https://bugs.launchpad.net/python-jenkins/+bug/2018567
# "urllib3<2.0.0", # helps to resolve bug with urllib3 timeout from vizro-ai integration tests: https://bugs.launchpad.net/python-jenkins/+bug/2018567
# Below dependencies useful to test different models
"langchain-community",
"langchain_mistralai",
"langchain-anthropic"
]
installer = "uv"

Expand Down
9 changes: 5 additions & 4 deletions vizro-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ dependencies = [
"pandas",
"tabulate",
"openai>=1.0.0",
"langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class, update to pydantic v2 and remove upper bound
"langchain-openai",
"langgraph>=0.1.2, <0.2.17", # latest version break pydantic v1 compatibility
"langchain>=0.1.0, <0.3.0", # TODO update to pydantic v2 and remove upper bound
"langgraph>=0.1.2, <0.2.17", # TODO update to pydantic v2 and remove upper bound (latest version break pydantic v1 compatibility)
"python-dotenv>=1.0.0", # TODO decide env var management to see if we need this
"vizro>=0.1.20"
"vizro>=0.1.20",
"langchain-openai", # Base dependency, ie minimum model working, but ideally we could get rid of this
"ruff"
]
description = "Vizro-AI is a tool for generating data visualizations"
dynamic = ["version"]
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/src/vizro_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

__all__ = ["VizroAI"]

__version__ = "0.2.4.dev0"
__version__ = "0.3.0.dev0"

# TODO: I think this collides with the VIZRO_LOG_LEVEL setting, as basicConfig can only be set once
logging.basicConfig(level=os.getenv("VIZRO_AI_LOG_LEVEL", "INFO"))
15 changes: 12 additions & 3 deletions vizro-ai/src/vizro_ai/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

try:
from langchain_anthropic import ChatAnthropic
except ImportError:
ChatAnthropic = None

SUPPORTED_MODELS = {
"OpenAI": [
"gpt-4-0613",
Expand All @@ -21,16 +26,20 @@
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
],
"Anthropic": [
"claude-3-sonnet-20240229",
],
}

DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI}
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic}

DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_TEMPERATURE = 0

model_to_vendor = {model: key for key, models in SUPPORTED_MODELS.items() for model in models}


def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
def _get_llm_model(model: Optional[Union[BaseChatModel, str]] = None) -> BaseChatModel:
"""Fetches and initializes an instance of the LLM.
Args:
Expand Down
Loading

0 comments on commit f1e68ae

Please sign in to comment.