Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Enable dashboard code generation #641

Merged
merged 16 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
81 changes: 55 additions & 26 deletions vizro-ai/examples/example_dashboard.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
"from vizro_ai import VizroAI\n",
"\n",
"# vizro_ai = VizroAI(model=\"gpt-4-turbo\")\n",
"vizro_ai = VizroAI(model=\"gpt-4o\")\n",
"# vizro_ai = VizroAI()"
"# vizro_ai = VizroAI(model=\"gpt-4o-mini\")\n",
"# vizro_ai = VizroAI(model=\"gpt-4o\")\n",
"vizro_ai = VizroAI()"
]
},
{
Expand All @@ -49,6 +50,16 @@
"df2 = px.data.stocks()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94885b88-eb85-4060-bbcb-5fa07100892c",
"metadata": {},
"outputs": [],
"source": [
"df3 = px.data.tips()"
]
},
{
"cell_type": "markdown",
"id": "ec46d4d1-d20b-4351-831d-d3d8ddc5cb70",
Expand All @@ -60,7 +71,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "820a5d0f-a31e-4bbd-a924-9629631cc291",
"id": "119ea726-e426-47a6-b5ce-e209ac32c9b9",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -88,45 +99,49 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0d71e089-8c94-4d12-87bd-d803552acb32",
"id": "eb83882c-148a-460c-983a-56554ef5fc3b",
"metadata": {},
"outputs": [],
"source": [
"dashboard = vizro_ai.dashboard([df1, df2], user_question_2_data)"
"result = vizro_ai.dashboard([df1, df2], user_question_2_data, return_elements=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14477c56-54e9-43a5-9136-25bc950fdf3a",
"metadata": {},
"id": "fa592f43-5966-4832-a4d7-4e0bb5593fcb",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"Vizro().build(dashboard).run()"
"print(result.code)"
]
},
{
"cell_type": "markdown",
"id": "747964b9-fd05-4c5a-a73a-79dae82320b3",
"cell_type": "code",
"execution_count": null,
"id": "7ad22c55-22a6-4c3e-902d-54adba1084f8",
"metadata": {},
"outputs": [],
"source": [
"# Example: 5-page dashboard request"
"Vizro().build(result.dashboard).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "967ff6a4-f138-4643-b993-a72e5cc26de2",
"cell_type": "markdown",
"id": "747964b9-fd05-4c5a-a73a-79dae82320b3",
"metadata": {},
"outputs": [],
"source": [
"df3 = px.data.tips()"
"# Example: 4-page dashboard request\n",
"\n",
"In most cases, using more advanced models produces more stable and accurate output."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb9347f8",
"id": "46f13f78-665b-4729-a415-8d4ff5b8e6f1",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -173,36 +188,50 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f0a0cdfa",
"id": "867f8e68-ab78-401c-a6df-c1a0101fb131",
"metadata": {},
"outputs": [],
"source": [
"Vizro._reset()\n",
"dashboard = vizro_ai.dashboard([df1, df2, df3], user_question_3_data)"
"result = vizro_ai.dashboard([df1, df2, df3], user_question_3_data, return_elements=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3167e996",
"id": "85fc6189-a2ba-435d-b499-79855ffc1833",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"print(result.code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9125393-9971-4d3d-86cf-89f574eaca0f",
"metadata": {},
"outputs": [],
"source": [
"Vizro().build(dashboard).run()"
"Vizro().build(result.dashboard).run()"
]
},
{
"cell_type": "markdown",
"id": "bbf5c920-0432-4415-996f-1acb9d7b6b8a",
"metadata": {},
"source": [
"# Example: Request with unsupported features"
"# Example: Request with unsupported features\n",
"\n",
"You may encounter warnings in the logs indicating that some features requested are currently unsupported by Vizro-AI. Additionally, validation errors might appear in the logs if the specifications are not supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12d5976e",
"id": "b6f2d50d-5044-49bd-8db9-9c611ac8ed9c",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -240,7 +269,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6b4838d1",
"id": "99de2df2-b239-4a6d-b2cf-7131f2ba1559",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -251,7 +280,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f055bec1",
"id": "84f2683c-cf7f-4547-82b3-a702a98cc7bd",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -275,7 +304,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"langchain-openai",
"langgraph>=0.1.2",
"python-dotenv>=1.0.0", # TODO decide env var management to see if we need this
"vizro>=0.1.4", # TODO set upper bound later
"vizro>=0.1.20",
"ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382
"aiohttp>=3.9.2", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-AIOHTTP-6209407
"langchain-core>=0.1.31" # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-LANGCHAINCORE-6370598
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/snyk/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ langchain>=0.1.0, <0.3.0
langchain-openai
langgraph>=0.1.2
python-dotenv>=1.0.0
vizro>=0.1.4
vizro>=0.1.20
ipython>=8.10.0
aiohttp>=3.9.2
langchain-core>=0.1.31
13 changes: 9 additions & 4 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from vizro_ai._llm_models import _get_llm_model, _get_model_name
from vizro_ai.dashboard._graph.dashboard_creation import _create_and_compile_graph
from vizro_ai.dashboard.utils import DashboardOutputs, _register_data
from vizro_ai.dashboard.utils import DashboardOutputs, _extract_custom_functions_and_imports, _register_data
from vizro_ai.plot.components import GetCodeExplanation, GetDebugger
from vizro_ai.plot.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
Expand Down Expand Up @@ -66,15 +66,18 @@ def _run_plot_tasks(
user_input: str,
max_debug_retry: int = 3,
explain: bool = False,
chart_name: Optional[str] = None,
) -> PlotOutputs:
"""Task execution."""
chart_type_pipeline = self.pipeline_manager.chart_type_pipeline
chart_type = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df})

# TODO update to loop through charts for multiple charts creation
if chart_name is None:
lingyielia marked this conversation as resolved.
Show resolved Hide resolved
chart_name = "custom_chart"
plot_pipeline = self.pipeline_manager.plot_pipeline
custom_chart_code = plot_pipeline.run(
initial_args={"chain_input": user_input, "df": df, "chart_type": chart_type}
initial_args={"chain_input": user_input, "df": df, "chart_type": chart_type, "chart_name": chart_name}
)

# TODO add debug in pipeline after getting _debug_helper logic in component
Expand Down Expand Up @@ -188,15 +191,17 @@ def dashboard(
"pages": [],
"dashboard": None,
"messages": [HumanMessage(content=user_input)],
"custom_charts_code": [],
},
config=config,
)
dashboard = message_res["dashboard"]
_register_data(all_df_metadata=message_res["all_df_metadata"])

if return_elements:
# code = _dashboard_code(dashboard) # TODO: `_dashboard_code` to be implemented
dashboard_output = DashboardOutputs(dashboard=dashboard)
chart_code, imports = _extract_custom_functions_and_imports(message_res["custom_charts_code"])
code = dashboard._to_python(extra_callable_defs=chart_code, extra_imports=imports)
dashboard_output = DashboardOutputs(dashboard=dashboard, code=code)
return dashboard_output
else:
return dashboard
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class GraphState(BaseModel):
dashboard_plan: Plan for the dashboard
pages: Vizro pages
dashboard: Vizro dashboard
custom_charts_code: Custom charts code

"""

Expand All @@ -50,6 +51,7 @@ class GraphState(BaseModel):
dashboard_plan: Optional[DashboardPlan] = None
pages: Annotated[List, operator.add]
dashboard: Optional[vm.Dashboard] = None
custom_charts_code: Annotated[List, operator.add]

class Config:
"""Pydantic configuration."""
Expand Down Expand Up @@ -150,9 +152,9 @@ def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List
page_plan = state["page_plan"]

llm = config["configurable"].get("model", None)
page = page_plan.create(model=llm, all_df_metadata=all_df_metadata)
page, custom_chart_code = page_plan.create(model=llm, all_df_metadata=all_df_metadata)

return {"pages": [page]}
return {"pages": [page], "custom_charts_code": [custom_chart_code]}


def _continue_to_pages(state: GraphState) -> List[Send]:
Expand Down
Loading
Loading