Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jun 13, 2024
1 parent 9f4eda1 commit bc5e501
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 150 deletions.
30 changes: 14 additions & 16 deletions vizro-ai/examples/example_dashboard_code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,18 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af9ffbfc-0efd-411f-9038-4eca24b510ee",
"cell_type": "raw",
"id": "059aa349-f803-4b7f-a01f-45649b40c84d",
"metadata": {},
"outputs": [],
"source": [
"res = vizro_ai.dashboard([df, df_stocks], user_question_2_data)\n",
"res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d74ace77-1754-4cad-a8fb-29afad44fc43",
"cell_type": "raw",
"id": "6bce7b14-cea3-4a49-85bc-ca91ce718310",
"metadata": {},
"outputs": [],
"source": [
"from vizro import Vizro\n",
"Vizro._reset()"
Expand Down Expand Up @@ -132,14 +128,16 @@
]
},
{
"cell_type": "raw",
"id": "239aabcd-5fc9-4cfe-b2d0-a7c62e8c5c17",
"cell_type": "code",
"execution_count": null,
"id": "04171030-72ae-4c7d-bd88-cc0e0c5b363e",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.callbacks import get_openai_callback\n",
"\n",
"with get_openai_callback() as cb:\n",
" vizro_ai.dashboard([df,], user_question_1_data_1filter)\n",
" vizro_ai.dashboard([df, df_stocks], user_question_2_data)\n",
" print(cb)"
]
},
Expand Down Expand Up @@ -182,28 +180,28 @@
" \n",
"\n",
"data_manager[\"world_indicators\"] = population_data\n",
"data_manager[\"tech_stocks_performance\"] = stock_data"
"data_manager[\"stock_price_data\"] = stock_data"
]
},
{
"cell_type": "raw",
"id": "753b8b5a-80b0-43fa-ab0a-483e77edf0f6",
"id": "3ca43805-c738-4f28-a3ca-3ab2a05bc7d5",
"metadata": {},
"source": [
"from vizro import Vizro\n",
"from vizro.models import Page, AgGrid, Dashboard, Card\n",
"from vizro.models import AgGrid, Card, Dashboard, Filter, Layout, Page\n",
"from vizro.tables import dash_ag_grid\n",
"import pandas as pd\n",
"\n",
"dashboard=Dashboard(pages=[Page(id='Data Tables', components=[AgGrid(id='population_by_continent', figure=dash_ag_grid(data_frame='country_demographics_overview')), AgGrid(id='stock_prices', figure=dash_ag_grid(data_frame='tech_stocks_performance'))], title='Data Tables', controls=[]), Page(id='Information Cards', components=[Card(type='card', text=\"**User Inquiry:**\\n\\nThe user expressed enthusiasm about a data dashboard referred to as the 'jolly data dashboard,' which was created using Vizro.\", href=''), Card(type='card', text='To learn more, visit Vizro docs at [Vizro Documentation](https://vizro.readthedocs.io/)', href='')], title='Information Cards', controls=[])], title='My wonderful jolly dashboard showing a lot of info')\n",
"dashboard=Dashboard(pages=[Page(id='Page 1', components=[AgGrid(id='population_per_continent', figure=dash_ag_grid(data_frame='world_indicators')), AgGrid(id='stock_price', figure=dash_ag_grid(data_frame='stock_price_data'))], title='Page 1', layout=Layout(grid=[[0], [1]]), controls=[Filter(column='continent', targets=[], selector=None)]), Page(id='Page 2', components=[Card(type='card', text='The user found the data dashboard created in Vizro to be jolly and amazing.', href=''), Card(type='card', text='Visit the Vizro documentation for more information.', href='https://vizro.com/docs')], title='Page 2', layout=Layout(grid=[[0, 1]]), controls=[])], title='My wonderful jolly dashboard showing a lot of info')\n",
"\n",
"Vizro().build(dashboard).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f32aad5b-d62e-44d4-a9a9-ab7cefb4a04f",
"id": "2e10a41b-eade-43e6-9b3e-da7941b5e2db",
"metadata": {},
"outputs": [],
"source": []
Expand Down
8 changes: 4 additions & 4 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, List, Optional, Union

import pandas as pd
import plotly.graph_objects as go
Expand Down Expand Up @@ -181,11 +181,11 @@ def dashboard(

message_res = runnable.invoke(
{
"dfs": dfs,
"dfs": dfs,
"df_metadata": {},
"dashboard_plan": None,
"dashboard": None,
"messages": [("user", user_input)]
"messages": [("user", user_input)],
}
)
return message_res
return message_res
1 change: 1 addition & 0 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Optional, Union

from langchain_openai import ChatOpenAI

# from langchain_anthropic import ChatAnthropic

# TODO add new wrappers in if new model support is added
Expand Down
127 changes: 79 additions & 48 deletions vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from typing import Any, List, Dict
"""Code generation graph for dashboard generation."""

import logging
import re
from typing import Any, Dict, List

import pandas as pd
from langgraph.graph import END, StateGraph
from vizro.models import Dashboard
from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai.dashboard.nodes.data_summary import DfInfo, df_sum_prompt, _get_df_info
from vizro_ai.dashboard.nodes.imports_builder import ModelSummary, model_sum_prompt, _generate_import_statement
from vizro_ai.dashboard.nodes.core_builder.plan import _get_dashboard_plan, _print_dashboard_plan, DashboardPlanner
from vizro_ai.dashboard.nodes.core_builder.build import DashboardBuilder
from vizro.models import Dashboard

from langchain.globals import set_debug
from vizro_ai.dashboard.nodes.core_builder.plan import DashboardPlanner, _get_dashboard_plan, _print_dashboard_plan
from vizro_ai.dashboard.nodes.data_summary import DfInfo, _get_df_info, df_sum_prompt
from vizro_ai.dashboard.nodes.imports_builder import ModelSummary, _generate_import_statement, model_sum_prompt

try:
from pydantic.v1 import BaseModel, validator
except ImportError: # pragma: no cov
from pydantic import BaseModel, validator


logger = logging.getLogger(__name__)

model_default = "gpt-3.5-turbo"
# model_default = "gpt-4-turbo"
# set_debug(True)
Expand All @@ -41,15 +45,18 @@ class GraphState(BaseModel):
dashboard: Dashboard = None

class Config:
"""Pydantic configuration."""

arbitrary_types_allowed = True

@validator('dfs')
@validator("dfs")
def check_dataframes(cls, v):
"""Check if the dataframes are valid."""
if not isinstance(v, list):
raise ValueError('dfs must be a list')
raise ValueError("dfs must be a list")
for df in v:
if not isinstance(df, pd.DataFrame):
raise ValueError('Each element in dfs must be a Pandas DataFrame')
raise ValueError("Each element in dfs must be a Pandas DataFrame")
return v


Expand All @@ -58,29 +65,27 @@ def _store_df_info(state: GraphState):
Args:
state (dict): The current graph state.
"""
print("*** _store_df_info ***")
logger.info("*** _store_df_info ***")
dfs = state.dfs
messages = state.messages
df_metadata = state.df_metadata
current_df_names = []
for _, df in enumerate(dfs):
df_schema, df_sample = _get_df_info(df)
data_sum_chain = df_sum_prompt | _get_llm_model(model=model_default).with_structured_output(
DfInfo
)
data_sum_chain = df_sum_prompt | _get_llm_model(model=model_default).with_structured_output(DfInfo)

df_name = data_sum_chain.invoke(
{"df_schema": df_schema, "df_sample": df_sample, "messages": messages, "current_df_names": current_df_names}
)

print(f"df_name: {df_name}")
current_df_names.append(df_name)

cleaned_df_name = df_name.dataset_name.lower()
cleaned_df_name = re.sub(r'\W+', '_', cleaned_df_name)
df_id = cleaned_df_name.strip('_')
print(f"df_id: {df_id}")
cleaned_df_name = re.sub(r"\W+", "_", cleaned_df_name)
df_id = cleaned_df_name.strip("_")
logger.info(f"df_name: {df_name} --> df_id: {df_id}")
df_metadata[df_id] = {"df_schema": df_schema, "df_sample": df_sample}

return {"df_metadata": df_metadata}
Expand All @@ -96,7 +101,7 @@ def _compose_imports_code(state: GraphState):
state (dict): New key added to state, generation
"""
print("*** _compose_imports_code ***")
logger.info("*** _compose_imports_code ***")
messages = state.messages
model_sum_chain = model_sum_prompt | _get_llm_model(model=model_default).with_structured_output(ModelSummary)

Expand All @@ -118,8 +123,9 @@ def _dashboard_plan(state: GraphState):
Args:
state (dict): The current graph state
"""
print("*** _dashboard_plan ***")
logger.info("*** _dashboard_plan ***")
messages = state.messages
_, query = messages[0]
df_metadata = state.df_metadata
Expand All @@ -137,17 +143,18 @@ def _build_dashboard(state: GraphState):
Args:
state (dict): The current graph state
"""
print("*** _build_dashboard ***")
logger.info("*** _build_dashboard ***")
df_metadata = state.df_metadata
dashboard_plan = state.dashboard_plan

model = _get_llm_model(model=model_default)
dashboard = DashboardBuilder(
model=model,
df_metadata=df_metadata,
dashboard_plan=dashboard_plan,
).dashboard
model=model,
df_metadata=df_metadata,
dashboard_plan=dashboard_plan,
).dashboard

return {"dashboard": dashboard}

Expand All @@ -157,15 +164,16 @@ def _generate_dashboard_code(state: GraphState):
Args:
state (dict): The current graph state
"""
print("*** _generate_dashboard_code ***")
logger.info("*** _generate_dashboard_code ***")
messages = state.messages
_, import_statement = messages[-1]
dashboard = state.dashboard

dashboard_code_string = dashboard.dict_obj(exclude_unset=True)
full_code_string = f"\n{import_statement}\ndashboard={dashboard_code_string}\n\nVizro().build(dashboard).run()\n"
print(f"full_code_string: \n ------- \n{full_code_string}\n ------- \n")
logger.info(f"full_code_string: \n ------- \n{full_code_string}\n ------- \n")

messages += [
(
Expand Down Expand Up @@ -200,27 +208,50 @@ def _create_and_compile_graph():

if __name__ == "__main__":
test_state = {
'messages': [
('user',
'\nI need a page with a table showing the population per continent \n'
'I also want a page with two \ncards on it. One should be a card saying: '
'`This was the jolly data dashboard, it was created in Vizro which is amazing` \n, '
'and the second card should link to `https://vizro.readthedocs.io/`. The title of '
'the dashboard should be: `My wonderful \njolly dashboard showing a lot of info`.\n'
'The layout of this page should use `grid=[[0,1]]`'),
('assistant',
'from vizro import Vizro\nfrom vizro.models import AgGrid, Card, Dashboard, Page\nfrom vizro.tables import dash_ag_grid\nimport pandas as pd\n'),],
'dfs': [pd.DataFrame(),],
'df_metadata': {'globaldemographics': {'df_schema': {'country': 'object',
'continent': 'object',
'year': 'int64',
'lifeExp': 'float64',
'pop': 'int64',
'gdpPercap': 'float64',
'iso_alpha': 'object',
'iso_num': 'int64'},
'df_sample': '| | country | continent | year | lifeExp | pop | gdpPercap | iso_alpha | iso_num |\n|-----:|:----------|:------------|-------:|----------:|---------:|------------:|:------------|----------:|\n| 215 | Burundi | Africa | 2007 | 49.58 | 8390505 | 430.071 | BDI | 108 |\n| 1545 | Togo | Africa | 1997 | 58.39 | 4320890 | 982.287 | TGO | 768 |\n| 772 | Italy | Europe | 1972 | 72.19 | 54365564 | 12269.3 | ITA | 380 |\n| 1322 | Senegal | Africa | 1962 | 41.454 | 3430243 | 1654.99 | SEN | 686 |\n| 732 | Iraq | Asia | 1952 | 45.32 | 5441766 | 4129.77 | IRQ | 368 |'},},
"messages": [
(
"user",
"\nI need a page with a table showing the population per continent \n"
"I also want a page with two \ncards on it. One should be a card saying: "
"`This was the jolly data dashboard, it was created in Vizro which is amazing` \n, "
"and the second card should link to `https://vizro.readthedocs.io/`. The title of "
"the dashboard should be: `My wonderful \njolly dashboard showing a lot of info`.\n"
"The layout of this page should use `grid=[[0,1]]`",
),
(
"assistant",
"from vizro import Vizro\nfrom vizro.models import AgGrid, Card, Dashboard, Page\nfrom "
"vizro.tables import dash_ag_grid\nimport pandas as pd\n",
),
],
"dfs": [
pd.DataFrame(),
],
"df_metadata": {
"globaldemographics": {
"df_schema": {
"country": "object",
"continent": "object",
"year": "int64",
"lifeExp": "float64",
"pop": "int64",
"gdpPercap": "float64",
"iso_alpha": "object",
"iso_num": "int64",
},
"df_sample": "| | country | continent | year | lifeExp | pop | "
"gdpPercap | iso_alpha | iso_num |\n|-----:|:----------|:------------|-------:"
"|----------:|---------:|------------:|:------------|----------:|\n| 215 | "
"Burundi | Africa | 2007 | 49.58 | 8390505 | 430.071 | BDI"
" | 108 |\n| 1545 | Togo | Africa | 1997 | 58.39 |"
" 4320890 | 982.287 | TGO | 768 |\n| 772 | Italy | Europe"
" | 1972 | 72.19 | 54365564 | 12269.3 | ITA | 380 |\n|"
" 1322 | Senegal | Africa | 1962 | 41.454 | 3430243 | 1654.99 | SEN"
" | 686 |\n| 732 | Iraq | Asia | 1952 | 45.32 | 5441766"
" | 4129.77 | IRQ | 368 |",
},
},
}
sample_state = GraphState(**test_state)
message = _generate_dashboard_code(sample_state)
print(message)
print(message) # noqa: T201
Loading

0 comments on commit bc5e501

Please sign in to comment.