Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Aug 2, 2024
1 parent 4b22839 commit dfd3a85
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 56 deletions.
18 changes: 9 additions & 9 deletions vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from langgraph.constants import END, Send
from langgraph.graph import StateGraph
from tqdm.auto import tqdm
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.dashboard._response_models.dashboard import DashboardPlanner
from vizro_ai.dashboard._pydantic_output import _get_pydantic_model
from vizro_ai.dashboard._response_models.dashboard import DashboardPlan
from vizro_ai.dashboard._response_models.df_info import DfInfo, _create_df_info_content, _get_df_info
from vizro_ai.dashboard._response_models.page import PagePlanner
from vizro_ai.dashboard._response_models.page import PagePlan
from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata, _execute_step
from vizro_ai.utils.helper import DebugFailure

Expand Down Expand Up @@ -47,7 +47,7 @@ class GraphState(BaseModel):
messages: List[BaseMessage]
dfs: List[pd.DataFrame]
all_df_metadata: AllDfMetadata
dashboard_plan: Optional[DashboardPlanner] = None
dashboard_plan: Optional[DashboardPlan] = None
pages: Annotated[List, operator.add]
dashboard: Optional[vm.Dashboard] = None

Expand All @@ -72,7 +72,7 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, AllDf

llm = config["configurable"].get("model", None)
try:
df_name = _get_pydantic_output(
df_name = _get_pydantic_model(
query=query,
llm_model=llm,
response_model=DfInfo,
Expand All @@ -91,7 +91,7 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, AllDf
return {"all_df_metadata": all_df_metadata}


def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlanner]:
def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlan]:
"""Generate a dashboard plan."""
node_desc = "Generate dashboard plan"
pbar = tqdm(total=2, desc=node_desc)
Expand All @@ -106,10 +106,10 @@ def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, Dash
None,
)
try:
dashboard_plan = _get_pydantic_output(
dashboard_plan = _get_pydantic_model(
query=query,
llm_model=llm,
response_model=DashboardPlanner,
response_model=DashboardPlan,
df_info=all_df_metadata.get_schemas_and_samples(),
)
except (DebugFailure, ValidationError) as e:
Expand Down Expand Up @@ -137,7 +137,7 @@ class BuildPageState(BaseModel):
"""

all_df_metadata: AllDfMetadata
page_plan: Optional[PagePlanner] = None
page_plan: Optional[PagePlan] = None


def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List[vm.Page]]:
Expand Down
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the _get_pydantic_output for the Vizro AI dashboard."""
"""Contains the _get_pydantic_model for the Vizro AI dashboard."""

# ruff: noqa: F821

Expand Down Expand Up @@ -63,7 +63,7 @@ def _create_message_content(
return message_content


def _get_pydantic_output(
def _get_pydantic_model(
query: str,
llm_model: BaseChatModel,
response_model: BaseModel,
Expand Down Expand Up @@ -94,5 +94,5 @@ def _get_pydantic_output(

model = _get_llm_model()
component_description = "Create a card with the following content: 'Hello, world!'"
res = _get_pydantic_output(query=component_description, llm_model=model, response_model=vm.Card)
res = _get_pydantic_model(query=component_description, llm_model=model, response_model=vm.Card)
print(res) # noqa: T201
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field
from vizro.tables import dash_ag_grid
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.dashboard._response_models.types import CompType
from vizro_ai.dashboard._pydantic_output import _get_pydantic_model
from vizro_ai.dashboard._response_models.types import ComponentType
from vizro_ai.utils.helper import DebugFailure

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +20,7 @@
class ComponentPlan(BaseModel):
"""Component plan model."""

component_type: CompType
component_type: ComponentType
component_description: str = Field(
...,
description="""
Expand Down Expand Up @@ -62,7 +62,7 @@ def create(self, model, all_df_metadata) -> Union[vm.Card, vm.AgGrid, vm.Figure]
The Card uses the dcc.Markdown component from Dash as its underlying text component.
Create a card based on the card description: {self.component_description}.
"""
result_proxy = _get_pydantic_output(query=card_prompt, llm_model=model, response_model=vm.Card)
result_proxy = _get_pydantic_model(query=card_prompt, llm_model=model, response_model=vm.Card)
proxy_dict = result_proxy.dict()
proxy_dict["id"] = self.component_id
return vm.Card.parse_obj(proxy_dict)
Expand Down Expand Up @@ -92,7 +92,6 @@ def create(self, model, all_df_metadata) -> Union[vm.Card, vm.AgGrid, vm.Figure]
component_type="Card",
component_description="Create a card says 'this is worldwide GDP'.",
component_id="gdp_card",
page_id="page1",
df_name="N/A",
)
component = component_plan.create(model, all_df_metadata)
Expand Down
18 changes: 8 additions & 10 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Controls plan model."""

import logging
from typing import List, Union
from typing import List, Optional

import pandas as pd
import vizro.models as vm
Expand All @@ -10,8 +10,8 @@
from pydantic.v1 import BaseModel, Field, ValidationError, create_model, root_validator, validator
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field, ValidationError, create_model, root_validator, validator
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.dashboard._response_models.types import CtrlType
from vizro_ai.dashboard._pydantic_output import _get_pydantic_model
from vizro_ai.dashboard._response_models.types import ControlType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,16 +84,14 @@ def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_compon
result_proxy = _create_filter_proxy(
df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components
)
proxy = _get_pydantic_output(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema)
return vm.Filter.parse_obj(
proxy.dict(exclude={"selector": {"id": True, "actions": True, "_add_key": True}, "id": True, "type": True})
)
proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema)
return vm.Filter.parse_obj(proxy.dict(exclude_unset=True))


class ControlPlan(BaseModel):
"""Control plan model."""

control_type: CtrlType
control_type: ControlType
control_description: str = Field(
...,
description="""
Expand All @@ -105,12 +103,12 @@ class ControlPlan(BaseModel):
df_name: str = Field(
...,
description="""
The name of the dataframe that this component will use.
The name of the dataframe that the target component will use.
If the dataframe is not used, please specify that.
""",
)

def create(self, model, controllable_components, all_df_metadata) -> Union[vm.Filter, None]:
def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]:
"""Create the control."""
filter_prompt = f"""
Create a filter from the following instructions: <{self.control_description}>. Do not make up
Expand Down
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from pydantic.v1 import BaseModel, Field
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field
from vizro_ai.dashboard._response_models.page import PagePlanner
from vizro_ai.dashboard._response_models.page import PagePlan

logger = logging.getLogger(__name__)


class DashboardPlanner(BaseModel):
class DashboardPlan(BaseModel):
"""Dashboard plan model."""

title: str = Field(
Expand All @@ -22,4 +22,4 @@ class DashboardPlanner(BaseModel):
make a short and concise title from the content of the pages.
""",
)
pages: List[PagePlanner]
pages: List[PagePlan]
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel, Field


DF_SUM_PROMPT = """
DF_SUMMARY_PROMPT = """
Inspect the provided data and give a short unique name to the dataset. \n
dataframe sample: \n ------- \n {df_sample} \n ------- \n
Here is the data schema: \n ------- \n {df_schema} \n ------- \n
Expand All @@ -27,15 +27,15 @@ class DfInfo(BaseModel):


def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], pd.DataFrame]:
"""Get the dataframe schema and head info as strings."""
"""Get the dataframe schema and sample."""
formatted_pairs = dict(df.dtypes.astype(str))
df_sample = df.sample(5, replace=True, random_state=19)
return formatted_pairs, df_sample


def _create_df_info_content(df_schema: Dict[str, str], df_sample: pd.DataFrame, current_df_names: List[str]) -> dict:
"""Create the message content for the dataframe summarization."""
return DF_SUM_PROMPT.format(df_sample=df_sample, df_schema=df_schema, current_df_names=current_df_names)
return DF_SUMMARY_PROMPT.format(df_sample=df_sample, df_schema=df_schema, current_df_names=current_df_names)


if __name__ == "__main__":
Expand Down
29 changes: 13 additions & 16 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Layout plan model."""

import logging
from typing import List
from typing import List, Optional

import vizro.models as vm

Expand All @@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)


def _convert_to_grid(layout_grid_template_areas, component_ids) -> List[List[int]]:
def _convert_to_grid(layout_grid_template_areas: List[str], component_ids: List[str]) -> List[List[int]]:
component_map = {component: index for index, component in enumerate(component_ids)}
grid = []

Expand All @@ -26,8 +26,13 @@ def _convert_to_grid(layout_grid_template_areas, component_ids) -> List[List[int
try:
grid_row.append(component_map[cell])
except KeyError:
logger.warning(f"Component {cell} not found in component_ids: {component_ids}")
grid_row.append(-1)
logger.warning(
f"""
[FALLBACK] Component {cell} not found in component_ids: {component_ids}.
Returning default values.
"""
)
return []
grid.append(grid_row)

return grid
Expand All @@ -36,13 +41,6 @@ def _convert_to_grid(layout_grid_template_areas, component_ids) -> List[List[int
class LayoutPlan(BaseModel):
"""Layout plan model, which only applies to Vizro Components(Graph, AgGrid, Card)."""

layout_description: str = Field(
...,
description="""
Description of the layout of Vizro Components(Graph, AgGrid, Card).
Include everything that seems to relate to this layout. If layout not specified, describe layout as N/A.
""",
)
layout_grid_template_areas: List[str] = Field(
[],
description="""
Expand All @@ -57,9 +55,9 @@ class LayoutPlan(BaseModel):
""",
)

def create(self, component_ids: List[str]):
def create(self, component_ids: List[str]) -> Optional[vm.Layout]:
"""Create the layout."""
if self.layout_description == "N/A":
if not self.layout_grid_template_areas:
return None

try:
Expand All @@ -72,7 +70,7 @@ def create(self, component_ids: List[str]):
f"""
[FALLBACK] Build failed for `Layout`, returning default values. Try rephrase the prompt or select a different model.
Error details: {e}
Relevant prompt: {self.layout_description}, which was parsed as layout_grid_template_areas:
Relevant layout_grid_template_areas:
{self.layout_grid_template_areas}
"""
)
Expand All @@ -88,8 +86,7 @@ def create(self, component_ids: List[str]):

model = _get_llm_model()
layout_plan = LayoutPlan(
layout_description="Create a layout with a graph on the left and a card on the right.",
layout_grid_template_areas=["graph1 card2 card2", "graph1 . card1"],
)
layout = layout_plan.create(model, component_ids=["graph1", "card1", "card2"])
layout = layout_plan.create(component_ids=["graph1", "card1", "card2"])
print(layout) # noqa: T201
8 changes: 3 additions & 5 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
logger = logging.getLogger(__name__)


class PagePlanner(BaseModel):
class PagePlan(BaseModel):
"""Page plan model."""

title: str = Field(
Expand All @@ -28,7 +28,6 @@ class PagePlanner(BaseModel):
make a concise and descriptive title from the components.
""",
)
page_id: str = Field(..., description="Unique identifier for the page being planned.")
components_plan: List[ComponentPlan] = Field(
..., description="List of components. Must contain at least one component."
)
Expand Down Expand Up @@ -162,6 +161,7 @@ def create(self, model, all_df_metadata) -> Union[vm.Page, None]:
try:
page = vm.Page(title=title, components=components, controls=controls, layout=layout)
except Exception as e:
# TODO: This Exception might be redundant. Check if it can be removed.
if any("Number of page and grid components need to be the same" in error["msg"] for error in e.errors()):
logger.warning(
"""
Expand Down Expand Up @@ -197,9 +197,8 @@ def create(self, model, all_df_metadata) -> Union[vm.Page, None]:
)
}
)
page_plan = PagePlanner(
page_plan = PagePlan(
title="Worldwide GDP",
page_id="page1",
components_plan=[
ComponentPlan(
component_type="Card",
Expand All @@ -216,7 +215,6 @@ def create(self, model, all_df_metadata) -> Union[vm.Page, None]:
)
],
layout_plan=LayoutPlan(
layout_description="N/A",
layout_grid_template_areas=[],
),
unsupported_specs=[],
Expand Down
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# TODO make available in documentation

# Complete list: ["AgGrid", "Button", "Card", "Container", "Graph", "Table", "Tabs"]
CompType = Literal["AgGrid", "Card", "Graph"]
ComponentType = Literal["AgGrid", "Card", "Graph"]
"""Component types currently supported by Vizro-AI."""

# Complete list: ["Filter", "Parameter"]
CtrlType = Literal["Filter"]
ControlType = Literal["Filter"]
"""Control types currently supported by Vizro-AI."""

0 comments on commit dfd3a85

Please sign in to comment.