Skip to content

Commit

Permalink
improve layout creation
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 23, 2024
1 parent 4cac939 commit fe48708
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 57 deletions.
11 changes: 5 additions & 6 deletions vizro-ai/src/vizro_ai/dashboard/response_models/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,30 @@ def create(self, model, df_metadata) -> ComponentType:
from vizro_ai import VizroAI

vizro_ai = VizroAI(model=model)
component_id_unique = self.component_id + "_" + self.page_id # id to be referenced by layout

try:
if self.component_type == "Graph":
return vm.Graph(
id=self.component_id + "_" + self.page_id,
id=component_id_unique,
figure=vizro_ai.plot(df=df_metadata.get_df(self.df_name), user_input=self.component_description),
)
elif self.component_type == "AgGrid":
return vm.AgGrid(
id=self.component_id + "_" + self.page_id, figure=dash_ag_grid(data_frame=self.df_name)
)
return vm.AgGrid(id=component_id_unique, figure=dash_ag_grid(data_frame=self.df_name))
elif self.component_type == "Card":
result_proxy = _get_pydantic_output(
query=self.component_description, llm_model=model, response_model=vm.Card
)
proxy_dict = result_proxy.dict()
proxy_dict["id"] = self.component_id + "_" + self.page_id # id to be used by layout
proxy_dict["id"] = component_id_unique
return vm.Card.parse_obj(proxy_dict)

except DebugFailure as e:
logger.warning(
f"Failed to build component: {self.component_id}.\n ------- \n "
f"Reason: {e} \n ------- \n Relevant prompt: `{self.component_description}`"
)
return vm.Card(id=self.component_id, text=f"Failed to build component: {self.component_id}")
return vm.Card(id=component_id_unique, text=f"Failed to build component: {self.component_id}")


if __name__ == "__main__":
Expand Down
18 changes: 10 additions & 8 deletions vizro-ai/src/vizro_ai/dashboard/response_models/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def validate_targets(v):
raise ValueError(f"targets must be one of {available_components}")
return v

def validate_targets_not_empty(v):
"""Validate the targets not empty."""
if available_components == []:
raise ValueError(
"This might be due to the filter target is not found in the available components. "
"returning default values."
)
return v

def validate_column(v):
"""Validate the column."""
if v not in df_cols:
Expand All @@ -46,6 +55,7 @@ def validate_column(v):
__validators__={
"validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets),
"validator2": validator("column", allow_reuse=True)(validate_column),
"validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty),
},
__base__=vm.Filter,
)
Expand Down Expand Up @@ -100,14 +110,6 @@ def create(self, model, available_components, df_metadata):
df_schema=_df_schema,
available_components=available_components,
)
if res.targets == []:
logger.warning(
f"Filter control failed to create, "
f"related user input: `{self.control_description}`."
f"This might be due to the filter target is not found in the available components. "
"returning default values."
)
return None
return res
else:
logger.warning(f"Control type {self.control_type} not recognized.")
Expand Down
69 changes: 26 additions & 43 deletions vizro-ai/src/vizro_ai/dashboard/response_models/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,31 @@
from langchain_core.language_models.chat_models import BaseChatModel

try:
from pydantic.v1 import BaseModel, Field, create_model
from pydantic.v1 import BaseModel, Field, ValidationError
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field, create_model
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.utils.helper import DebugFailure
from pydantic import BaseModel, Field, ValidationError

logger = logging.getLogger(__name__)


def _convert_layout_to_grid(layout_grid_template_areas):
# TODO: Programmatically convert layout_grid_template_areas to grid
pass
def _convert_to_grid(layout_grid_template_areas, component_ids):
component_map = {component: index for index, component in enumerate(component_ids)}
grid = []

for row in layout_grid_template_areas:
grid_row = []
for cell in row.split():
if cell == ".":
grid_row.append(-1)
else:
try:
grid_row.append(component_map[cell])
except KeyError:
logger.warning(f"Component {cell} not found in component_ids: {component_ids}")
grid_row.append(-1)
grid.append(grid_row)

def _create_layout_proxy(component_ids, layout_grid_template_areas) -> BaseModel:
"""Create a layout proxy model."""

def validate_grid(v):
"""Validate the grid."""
expected_grid = _convert_layout_to_grid(layout_grid_template_areas)
if v != expected_grid:
logger.warning(f"Calculated grid: {expected_grid}, got: {v}")
return v

return create_model(
"LayoutProxyModel",
grid=(
List[List[int]],
Field(None, description="Grid specification to arrange components on screen."),
),
__validators__={
# "validator1": validator("grid", pre=True, allow_reuse=True)(validate_grid),
},
__base__=vm.Layout,
)
return grid


class LayoutPlan(BaseModel):
Expand All @@ -62,30 +52,23 @@ class LayoutPlan(BaseModel):

def create(self, model: BaseChatModel, component_ids: List[str]) -> Union[vm.Layout, None]:
"""Create the layout."""
layout_prompt = (
f"Create a layout from the following instructions: {self.layout_description}. Do not make up "
f"a layout if not requested. If a layout_grid_template_areas is provided, translate it into "
f"a matrix of integers where each integer represents a unique component (starting from 0). "
f"When translating, match the layout_grid_template_areas element string to the same name in "
f"{component_ids} and use the index of {component_ids} to replace the element string. "
f"replace '.' with -1 to represent empty spaces. Here is the grid template areas: \n ------- \n"
f" {self.layout_grid_template_areas}\n ------- \n"
)
if self.layout_description == "N/A":
return None

try:
result_proxy = _create_layout_proxy(
component_ids=component_ids, layout_grid_template_areas=self.layout_grid_template_areas
grid = _convert_to_grid(
layout_grid_template_areas=self.layout_grid_template_areas, component_ids=component_ids
)
proxy = _get_pydantic_output(query=layout_prompt, llm_model=model, response_model=result_proxy)
actual = vm.Layout.parse_obj(proxy.dict(exclude={"id": True}))
except DebugFailure as e:
actual = vm.Layout(grid=grid)
except ValidationError as e:
logger.warning(
f"Build failed for `Layout`, returning default values. Try rephrase the prompt or "
f"select a different model. \n ------- \n Error details: {e} \n ------- \n "
f"Relevant prompt: `{self.layout_description}`"
f"Relevant prompt: `{self.layout_description}`, which was parsed as layout_grid_template_areas:"
f" {self.layout_grid_template_areas}"
)
if grid:
logger.warning(f"Calculated grid which caused the error: {grid}")
actual = None

return actual
Expand Down

0 comments on commit fe48708

Please sign in to comment.