Skip to content

Commit

Permalink
remove import builder
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 10, 2024
1 parent 5532bce commit e6cf358
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 153 deletions.
1 change: 1 addition & 0 deletions vizro-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class, update to pydantic v2 and remove upper bound
"langchain-openai",
"langgraph",
"black",
"python-dotenv>=1.0.0", # TODO decide env var management to see if we need this
"vizro>=0.1.4", # TODO set upper bound later
"ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382
Expand Down
8 changes: 4 additions & 4 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vizro_ai.chains._llm_models import _get_llm_model, _get_model_name
from vizro_ai.components import GetCodeExplanation, GetDebugger
from vizro_ai.dashboard.graph.code_generation import _create_and_compile_graph
from vizro_ai.dashboard.utils import DashboardOutputs
from vizro_ai.dashboard.utils import DashboardOutputs, _dashboard_code
from vizro_ai.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
DebugFailure,
Expand Down Expand Up @@ -189,8 +189,8 @@ def dashboard(
},
config=config,
)
dashboard_output = DashboardOutputs(
dashboard=message_res["dashboard"], code=message_res["messages"][-1].content
)
dashboard = message_res["dashboard"]
code = _dashboard_code(dashboard) # TODO: `_dashboard_code` to be implemented
dashboard_output = DashboardOutputs(dashboard=dashboard, code=code)

return dashboard_output if return_elements else dashboard_output.dashboard
70 changes: 6 additions & 64 deletions vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
"""Code generation graph for dashboard generation."""

import inspect
import logging
import operator
import re
from typing import Annotated, Any, Dict, List, Union

import pandas as pd
import vizro.models as vm
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.constants import END, Send
from langgraph.graph import StateGraph
from vizro_ai.dashboard.nodes.build import PageBuilder
from vizro_ai.dashboard.nodes.data_summary import DfInfo, _get_df_info, df_sum_prompt
from vizro_ai.dashboard.nodes.imports_builder import ModelSummary, _generate_import_statement, model_sum_prompt
from vizro_ai.dashboard.nodes.plan import (
DashboardPlanner,
PagePlanner,
Expand All @@ -31,14 +29,6 @@
logger.setLevel(logging.INFO)


DASHBOARD_CODE_TEMPLATE = """
{import_statement}
dashboard={dashboard_code_str}
Vizro().build(dashboard).run()
"""


DfMetadata = Dict[str, Dict[str, Union[Dict[str, str], pd.DataFrame]]]
"""Cleaned dataframe names and their metadata."""

Expand Down Expand Up @@ -110,22 +100,6 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet
return {"df_metadata": df_metadata}


def _compose_imports_code(state: GraphState, config: RunnableConfig) -> Dict[str, Messages]:
"""Generate code snippet for imports."""
logger.info("*** _compose_imports_code ***")
messages = state.messages

llm = config["configurable"].get("model", None)
model_sum_chain = model_sum_prompt | llm.with_structured_output(ModelSummary)

vizro_model_summary = model_sum_chain.invoke({"messages": messages})

import_statement = _generate_import_statement(vizro_model_summary)

messages.append(FunctionMessage(content=import_statement, name=inspect.currentframe().f_code.co_name))
return {"messages": messages}


def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlanner]:
"""Generate a dashboard plan."""
logger.info("*** _dashboard_plan ***")
Expand All @@ -139,28 +113,6 @@ def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, Dash
return {"dashboard_plan": dashboard_plan}


def _generate_dashboard_code(state: GraphState) -> Dict[str, Messages]:
"""Generate a dashboard code snippet."""
logger.info("*** _generate_dashboard_code ***")
messages = state.messages
import_statement = messages[-1].content
dashboard = state.dashboard

# TODO: the code to string should come from vizro_core
# Currently, the output code string is a string representation of the dashboard object
dashboard_code_str = repr(dashboard)

messages.append(
FunctionMessage(
content=DASHBOARD_CODE_TEMPLATE.format(
import_statement=import_statement, dashboard_code_str=dashboard_code_str
),
name=inspect.currentframe().f_code.co_name,
)
)
return {"messages": messages}


class BuildPageState(BaseModel):
"""Represents the state of building the page.
Expand Down Expand Up @@ -189,7 +141,7 @@ def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List
return {"pages": [page]}


def continue_to_pages(state: GraphState) -> List[Send]:
def _continue_to_pages(state: GraphState) -> List[Send]:
"""Continue to build pages."""
logger.info("*** build_page ***")
df_metadata = state.df_metadata
Expand All @@ -213,20 +165,15 @@ def _create_and_compile_graph():
graph = StateGraph(GraphState)

graph.add_node("_store_df_info", _store_df_info)
graph.add_node("_compose_imports_code", _compose_imports_code)
graph.add_node("_dashboard_plan", _dashboard_plan)
graph.add_node("_build_page", _build_page)
graph.add_node("_build_dashboard", _build_dashboard)

graph.add_node("_generate_dashboard_code", _generate_dashboard_code)

graph.add_edge("_store_df_info", "_compose_imports_code")
graph.add_edge("_compose_imports_code", "_dashboard_plan")
graph.add_conditional_edges("_dashboard_plan", continue_to_pages)
graph.add_edge("_store_df_info", "_dashboard_plan")
graph.add_conditional_edges("_dashboard_plan", _continue_to_pages)
graph.add_edge("_build_page", "_build_dashboard")

graph.add_edge("_build_dashboard", "_generate_dashboard_code")
graph.add_edge("_generate_dashboard_code", END)
graph.add_edge("_build_dashboard", END)

graph.set_entry_point("_store_df_info")

Expand All @@ -244,14 +191,9 @@ def _create_and_compile_graph():
the dashboard should be: `My wonderful \njolly dashboard showing a lot of info`.\n
The layout of this page should use `grid=[[0,1]]`
"""
previous_fn_res = """"
from vizro import Vizro\nfrom vizro.models import AgGrid, Card, Dashboard, Page\nfrom "
"vizro.tables import dash_ag_grid\nimport pandas as pd\n"
"""
test_state = {
"messages": [
HumanMessage(content=user_input),
FunctionMessage(content=previous_fn_res, name="_compose_imports_code"),
],
"dfs": [
pd.DataFrame(),
Expand Down Expand Up @@ -284,5 +226,5 @@ def _create_and_compile_graph():
},
}
sample_state = GraphState(**test_state)
message = _generate_dashboard_code(sample_state)
message = _dashboard_plan(sample_state)
print(message) # noqa: T201
85 changes: 0 additions & 85 deletions vizro-ai/src/vizro_ai/dashboard/nodes/imports_builder.py

This file was deleted.

19 changes: 19 additions & 0 deletions vizro-ai/src/vizro_ai/dashboard/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
"""Helper Functions For Vizro AI dashboard."""

from dataclasses import dataclass

# import black
from typing import Any

import tqdm.std
import vizro.models as vm

IMPORT_STATEMENTS = (
"import vizro.plotly.express as px\n"
"from vizro.models.types import capture\n"
"import plotly.graph_objects as go\n"
"from vizro.tables import dash_ag_grid\n"
"import vizro.models as vm\n"
)


@dataclass
class DashboardOutputs:
Expand Down Expand Up @@ -46,3 +56,12 @@ def _get_tqdm():
else:
from tqdm import tqdm, trange
return tqdm, trange


def _dashboard_code(dashboard: vm.Dashboard) -> str:
"""Generate dashboard code from dashboard object."""
dashboard_code_str = IMPORT_STATEMENTS + repr(dashboard)

# TODO: use black or ruff to format the code
# formatted_code = black.format_str(dashboard_code_str, mode=black.Mode())
return dashboard_code_str

0 comments on commit e6cf358

Please sign in to comment.