diff --git a/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py b/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py index 9a2fd3cd0..fb6ff457d 100644 --- a/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py +++ b/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List import pandas as pd +from langchain.globals import set_debug from langgraph.graph import END, StateGraph from vizro.models import Dashboard from vizro_ai.chains._llm_models import _get_llm_model @@ -49,11 +50,13 @@ class Config: arbitrary_types_allowed = True + @validator("dfs") @validator("dfs") def check_dataframes(cls, v): """Check if the dataframes are valid.""" if not isinstance(v, list): raise ValueError("dfs must be a list") + raise ValueError("dfs must be a list") for df in v: if not isinstance(df, pd.DataFrame): raise ValueError("Each element in dfs must be a Pandas DataFrame") @@ -75,6 +78,7 @@ def _store_df_info(state: GraphState): for _, df in enumerate(dfs): df_schema, df_sample = _get_df_info(df) data_sum_chain = df_sum_prompt | _get_llm_model(model=model_default).with_structured_output(DfInfo) + data_sum_chain = df_sum_prompt | _get_llm_model(model=model_default).with_structured_output(DfInfo) df_name = data_sum_chain.invoke( {"df_schema": df_schema, "df_sample": df_sample, "messages": messages, "current_df_names": current_df_names} @@ -165,6 +169,7 @@ def _generate_dashboard_code(state: GraphState): Args: state (dict): The current graph state + """ logger.info("*** _generate_dashboard_code ***") messages = state.messages diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/build.py b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/build.py index fa5a8c0a4..9a0e67901 100644 --- a/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/build.py +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/build.py @@ -38,6 +38,9 @@ def _build_components(self): components.append( self._page_plan.components.components[i].create(df_metadata=self._df_metadata, model=self._model) ) + components.append( + self._page_plan.components.components[i].create(df_metadata=self._df_metadata, model=self._model) + ) except DebugFailure as e: components.append( vm.Card( @@ -46,6 +49,7 @@ def _build_components(self): ) return components + @property def layout(self): """Property to get layout.""" @@ -53,6 +57,7 @@ def layout(self): self._layout = self._build_layout() return self._layout + def _build_layout(self): logger.info(f"Building layout: {self._page_plan}") return self._page_plan.layout.create(model=self._model, df_metadata=self._df_metadata) @@ -76,10 +81,13 @@ def _build_controls(self): ): control = self._page_plan.controls.controls[i].create( model=self._model, available_components=self.available_components, df_metadata=self._df_metadata + ) + model=self._model, available_components=self.available_components, df_metadata=self._df_metadata ) if control: controls.append(control) + return controls @property diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/model.py b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/model.py index ad358a34f..01ae2b8a8 100644 --- a/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/model.py +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/model.py @@ -9,6 +9,7 @@ from langchain_core.prompts import ChatPromptTemplate + class ProxyVizroBaseModel(BaseModel): """Proxy model for VizroBaseModel.""" diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/vizro_ai_db.py b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/vizro_ai_db.py new file mode 100644 index 000000000..6bba180df --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/core_builder/vizro_ai_db.py @@ -0,0 +1,20 @@ +from typing import Dict + +from .build import DashboardBuilder +from .plan import _get_dashboard_plan, _print_dashboard_plan + + +class VizroAIDashboard: + def __init__(self, model): + self.model = model + self.dashboard_plan = None + + def _build_dashboard(self, query: str, df_metadata: Dict[str, Dict[str, str]]): + self.dashboard_plan = _get_dashboard_plan(query=query, model=self.model, df_metadata=df_metadata) + _print_dashboard_plan(self.dashboard_plan) + dashboard = DashboardBuilder( + model=self.model, + df_metadata=df_metadata, + dashboard_plan=self.dashboard_plan, + ).dashboard + return dashboard