diff --git a/vizro-ai/examples/example_dashboard.ipynb b/vizro-ai/examples/example_dashboard.ipynb new file mode 100644 index 000000000..abf43cd05 --- /dev/null +++ b/vizro-ai/examples/example_dashboard.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "53e857ce-22bc-49de-9adc-9a2e7c9829cf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dotenv import load_dotenv\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2a25acdd-20c3-4762-b97f-254de1586aeb", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import vizro.plotly.express as px\n", + "\n", + "from vizro import Vizro\n", + "from vizro_ai import VizroAI\n", + "\n", + "vizro_ai = VizroAI(model=\"gpt-4-turbo\")\n", + "# vizro_ai = VizroAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b5e24f1b-e698-40e5-be00-c3a59c53ec65", + "metadata": {}, + "outputs": [], + "source": [ + "df1 = px.data.gapminder()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "449da2ee-c754-420a-ba2e-c9b0ef62d934", + "metadata": {}, + "outputs": [], + "source": [ + "df2 = px.data.stocks()" + ] + }, + { + "cell_type": "markdown", + "id": "ec46d4d1-d20b-4351-831d-d3d8ddc5cb70", + "metadata": {}, + "source": [ + "# Example: Simple dashboard request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6df428c5-28a3-49b6-bac9-f72ade4a34a4", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_2_data = \"\"\"\n", + "I need a page with 1 table.\n", + "The table shows the tech companies stock data.\n", + "\n", + "I need a second page showing 2 cards and one chart.\n", + "The first card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing life expectancy vs. GDP per capita by country. Life expectancy on the y axis, GDP per capita on the x axis, and colored by continent.\n", + "The second card says 'Data spans from 1952 to 2007 across various countries'\n", + "The layout uses a grid of 3 columns and 2 rows.\n", + "\n", + "Row 1: The first row has three columns:\n", + "The first column is occupied by the first card.\n", + "The second and third columns are spanned by the chart.\n", + "\n", + "Row 2: The second row mirrors the layout of the first row with respect to chart, but the first column is occupied by the second card.\n", + "\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31c002d9-8103-43ef-9da5-78d9180075de", + "metadata": {}, + "outputs": [], + "source": [ + "res = vizro_ai.dashboard([df1, df2], user_question_2_data, return_elements=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9996286e-46b0-4a63-8273-5dd4b2a91cd0", + "metadata": {}, + "outputs": [], + "source": [ + "vizro_ai.run_dashboard(res.dashboard, res.metadata)" + ] + }, + { + "cell_type": "markdown", + "id": "747964b9-fd05-4c5a-a73a-79dae82320b3", + "metadata": {}, + "source": [ + "# Example: 5-page dashboard request" + ] + }, + { + "cell_type": "raw", + "id": "067983a0-65c6-48a2-85f3-a4d3db31977b", + "metadata": {}, + "source": [ + "df3 = px.data.tips()" + ] + }, + { + "cell_type": "raw", + "id": "50ccf0bc-0afe-433d-b541-1bbb41b0e7e7", + "metadata": {}, + "source": [ + "user_question_3_data = \"\"\"\n", + "\n", + "I need a page with 1 table, showing the stock prices of GOOG, AAPL, and AMZN over time.\n", + "\n", + "\n", + "I need a second page showing 1 card and 1 chart.\n", + "The card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Layout the card on the left and the chart on the right. The card takes 1/3 of the whole space on the left.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\n", + "\n", + "This page displays the tips dataset. use two different charts to show data\n", + "distributions. one chart should be a bar chart and the other should be a scatter plot.\n", + "first chart is on the left and the second chart is on the right.\n", + "Add a filter to filter data in the scatter plot by smoker.\n", + "\n", + "\n", + "Create 3 cards on this page:\n", + "1. The first card on top says \"This page combines data from various sources including tips, stock prices, and global indicators.\"\n", + "2. The second card says \"Insights from Gapminder dataset.\"\n", + "3. The third card says \"Stock price trends over time.\"\n", + "\n", + "Layout these 3 cards in this way:\n", + "create a grid with 3 columns and 2 rows.\n", + "Row 1: The first row has three columns:\n", + "- The first column is empty.\n", + "- The second and third columns span the area for card 1.\n", + "\n", + "Row 2: The second row also has three columns:\n", + "- The first column is empty.\n", + "- The second column is occupied by the area for card 2.\n", + "- The third column is occupied by the area for card 3.\n", + "\"\"\"" + ] + }, + { + "cell_type": "raw", + "id": "a240b98d-229b-4ed2-8e4d-93de004c6283", + "metadata": {}, + "source": [ + "Vizro._reset()\n", + "res = vizro_ai.dashboard([df1, df2, df3], user_question_3_data, return_elements=True)" + ] + }, + { + "cell_type": "raw", + "id": "7930a7ee-7dff-467a-b888-21b4c502c585", + "metadata": {}, + "source": [ + "vizro_ai.run_dashboard(res.dashboard, res.metadata)" + ] + }, + { + "cell_type": "markdown", + "id": "bbf5c920-0432-4415-996f-1acb9d7b6b8a", + "metadata": {}, + "source": [ + "# Example: Request with unsupported features" + ] + }, + { + "cell_type": "raw", + "id": "c8c278d5-eb8b-4244-843b-92cf63e8866f", + "metadata": {}, + "source": [ + "user_question_2_data = \"\"\"\n", + "\n", + "I need a page showing 2 cards, one chart, and 1 button.\n", + "The first card says 'The Tips dataset provides insights into customer tipping behavior.'\n", + "The chart is a bar chart showing the total bill amount by day. Day on the x axis, total bill amount on the y axis, and colored by time of day.\n", + "The second card says 'Data collected from various days and times.'\n", + "Layout the two cards on the left and the chart on the right. Two cards take 1/3 of the whole space on the left in total.\n", + "The first card is on top of the second card vertically.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "The button would trigger a download action to download the Tips dataset.\n", + "Add a filter to filter the bar chart by `size`.\n", + "Make another tab on this page,\n", + "In this tab, create a card saying \"Tipping patterns and trends.\"\n", + "Group all the above content into the first NavLink.\n", + "\n", + "\n", + "Create two pages:\n", + "1. The first page has a card saying \"Analyzing global development trends.\"\n", + "2. The second page has a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Add a parameter to control the title of the scatter plot, with title options \"Economic Growth vs. Health\" and \"Development Indicators.\"\n", + "Also create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "\n", + "Create one page:\n", + "1. The first page has a card saying \"Stock price trends over time.\"\n", + "Create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "For hosting the dashboard on AWS, which service should I use?\n", + "\"\"\"" + ] + }, + { + "cell_type": "raw", + "id": "f338d294-ec84-49f1-b729-646b0b419108", + "metadata": {}, + "source": [ + "Vizro._reset()\n", + "res = vizro_ai.dashboard([df3, df2], user_question_2_data, return_elements=True)" + ] + }, + { + "cell_type": "raw", + "id": "4dbe8aae-f7ae-4a94-8bce-6a394faabcd4", + "metadata": {}, + "source": [ + "vizro_ai.run_dashboard(res.dashboard, res.metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9408a882-f0bf-4eb6-bd74-ca415f5d83a8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/vizro-ai/examples/example_dashboard.py b/vizro-ai/examples/example_dashboard.py index d24a912f8..8dab8e7ec 100644 --- a/vizro-ai/examples/example_dashboard.py +++ b/vizro-ai/examples/example_dashboard.py @@ -1,9 +1,9 @@ """Example of creating a dashboard using VizroAI.""" - import vizro.plotly.express as px from vizro_ai import VizroAI vizro_ai = VizroAI(model="gpt-4-turbo") +# vizro_ai = VizroAI() gapminder_data = px.data.gapminder() tips_data = px.data.tips() @@ -19,7 +19,7 @@ "page2 displays the tips dataset. use two different charts to help me understand the data " "distributions. one chart should be a bar chart and the other should be a scatter plot. " "first chart is on the left and the second chart is on the right. " - "add a filter to filter data in the scatter plot." + "add a filter to filter data in the scatter plot by smoker." ) res = vizro_ai.dashboard(dfs=dfs, user_input=input_text, return_elements=True) diff --git a/vizro-ai/examples/example_dashboard_code.ipynb b/vizro-ai/examples/example_dashboard_code.ipynb deleted file mode 100644 index 9c914db20..000000000 --- a/vizro-ai/examples/example_dashboard_code.ipynb +++ /dev/null @@ -1,352 +0,0 @@ -{ - "cells": [ - { - "cell_type": "raw", - "id": "19271619-b3ab-4581-9469-4ae4778e5654", - "metadata": {}, - "source": [ - "from langsmith import traceable" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c81ca231-69c0-455e-91d6-252c845b8619", - "metadata": {}, - "outputs": [], - "source": [ - "from dotenv import load_dotenv\n", - "load_dotenv()" - ] - }, - { - "cell_type": "raw", - "id": "53c5da4c-3fac-4aae-ac56-454212b302f4", - "metadata": {}, - "source": [ - "from langchain.globals import set_debug, set_verbose\n", - "set_verbose(True)\n", - "set_debug(True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "632a9013-6872-447d-936f-1dbcd1cc0f12", - "metadata": {}, - "outputs": [], - "source": [ - "from vizro_ai import VizroAI\n", - "vizro_ai = VizroAI()" - ] - }, - { - "cell_type": "raw", - "id": "d23ceb07-7cee-4fd3-ab5a-64fd0966be07", - "metadata": {}, - "source": [ - "from vizro_ai.dashboard.graph.code_generation import _create_and_compile_graph\n", - "from IPython.display import Image\n", - "\n", - "graph = _create_and_compile_graph()\n", - "\n", - "Image(graph.get_graph().draw_mermaid_png())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86cd816b-e043-45af-9f7d-6b99706d6a47", - "metadata": {}, - "outputs": [], - "source": [ - "import vizro.plotly.express as px\n", - "df = px.data.gapminder()\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07c18cfb-9172-45b1-be7e-0085ee1fd66a", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "df_stocks = px.data.stocks()\n", - "df_stocks.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88a055e7-c08b-4eeb-a0e5-a3d69519011d", - "metadata": {}, - "outputs": [], - "source": [ - "user_question_2_data = \"\"\"\n", - "I need a page with 2 tables.\n", - "One table shows the population per continent.\n", - "The second table shows the stock price of some major companies.\n", - "add a filter to filter the world demographic table by continent.\n", - "I also want a second page with two \n", - "cards on it. One should be a card saying: `This was the jolly data dashboard, it was created in Vizro which is amazing` \n", - ", and the second card says `To learn more, visit Vizro docs` and should link to `https://vizro.readthedocs.io/`. The title of the dashboard should be: `My wonderful \n", - "jolly dashboard showing a lot of info`.\n", - "The layout of this page should use `grid=[[0,1]]`\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c7579821-02f5-4be2-9e5c-a533d15c6067", - "metadata": {}, - "outputs": [], - "source": [ - "res = vizro_ai.dashboard([df, df_stocks], user_question_2_data, return_elements=True)\n", - "res" - ] - }, - { - "cell_type": "raw", - "id": "eea8eb34-a023-488e-a66a-99b6e3c7089f", - "metadata": {}, - "source": [ - "user_question_not_implemented = \"\"\"\n", - "I need a page with 1 table and 1 tab area.\n", - "the table shows the population per continent.\n", - "In the tab, I need two \n", - "cards in it. One should be a card saying: `This was the jolly data dashboard, it was created in Vizro which is amazing` \n", - ", and the second card says `To learn more, visit Vizro docs`.\n", - "\"\"\"" - ] - }, - { - "cell_type": "raw", - "id": "6f5a1669-3d21-43b8-a2ef-d058d094e07d", - "metadata": {}, - "source": [ - "res = vizro_ai.dashboard([df], user_question_not_implemented)\n", - "res" - ] - }, - { - "cell_type": "raw", - "id": "90622878-5ff5-45b6-9d94-e52979df35ce", - "metadata": {}, - "source": [ - "from vizro import Vizro\n", - "Vizro._reset()" - ] - }, - { - "cell_type": "raw", - "id": "14382903-6959-4113-b41b-f319783b8845", - "metadata": {}, - "source": [ - "user_question_2_pages = \"\"\"\n", - "Create 2 pages.\n", - "I need a page with one table showing worldwide population and GDP.\n", - "second page has a card saying `this is the second page`. No filter please.\n", - "\"\"\"" - ] - }, - { - "cell_type": "raw", - "id": "f3180a2f-43e1-4e01-88b8-6f7f7789e472", - "metadata": {}, - "source": [ - "res = vizro_ai.dashboard([df], user_question_2_pages)" - ] - }, - { - "cell_type": "markdown", - "id": "600dbaf0-a9ac-451a-bc26-f7bf0317ad54", - "metadata": {}, - "source": [ - "# use langSmith" - ] - }, - { - "cell_type": "raw", - "id": "d3dba222-3089-4b2a-bbf7-89ee2ccca81d", - "metadata": {}, - "source": [ - "user_question_2_data = \"\"\"\n", - "I need a page with 2 tables.\n", - "One table shows the population per continent.\n", - "The second table shows the stock price of some major companies.\n", - "add a filter to filter the world demographic table by continent.\n", - "I also want a second page with two \n", - "cards on it. One should be a card saying: `This was the jolly data dashboard, it was created in Vizro which is amazing` \n", - ", and the second card says `To learn more, visit Vizro docs` and should link to `https://vizro.readthedocs.io/`. The title of the dashboard should be: `My wonderful \n", - "jolly dashboard showing a lot of info`.\n", - "The layout of this page should use `grid=[[0,1]]`\n", - "\"\"\"" - ] - }, - { - "cell_type": "raw", - "id": "c3fccacc-10d4-4f80-85fd-1aa30af4d1fe", - "metadata": {}, - "source": [ - "@traceable\n", - "def traced_dashboard():\n", - " res = vizro_ai.dashboard([df, df_stocks], user_question_2_data)\n", - " return res" - ] - }, - { - "cell_type": "raw", - "id": "b032128a-1605-4a1d-bc0b-caa4f7f03b4d", - "metadata": {}, - "source": [ - "traced_dashboard()" - ] - }, - { - "cell_type": "raw", - "id": "275957de-0da5-4cb0-a5b6-1912f0121d62", - "metadata": {}, - "source": [ - "user_question_gaint = \"\"\"\n", - "I need a dashboard of 10 pages.\n", - "page1: One table shows the population per continent.\n", - "The second table shows the stock price of some major companies.\n", - "add a filter to filter the world demographic table by continent.\n", - "page2: I also want a second page with two \n", - "cards on it. One should be a card saying: `This was the jolly data dashboard, it was created in Vizro which is amazing` \n", - ", and the second card says `To learn more, visit Vizro docs` and should link to `https://vizro.readthedocs.io/`. The title of the dashboard should be: `My wonderful \n", - "jolly dashboard showing a lot of info`.\n", - "page3: One table shows the population per continent. add a filter to filter by year\n", - "page4: One table shows the population per continent. add a filter to filter by continent\n", - "page5: One table shows the population per continent. add a filter to filter by pop\n", - "page6: One table shows the population per continent. add a filter to filter by lifeExp\n", - "page7: one card saying \"this is page 7\"\n", - "page8: one card saying \"this is page 8\"\n", - "page9: table shows the stock price of some major companies. also add a card saying \"Recent trends\"\n", - "page10: one card saying \"thanks for reading\"\n", - "\"\"\"\n", - "\n", - "@traceable\n", - "def giant_dashboard():\n", - " res = vizro_ai.dashboard([df, df_stocks], user_question_gaint)\n", - " return res\n", - "\n", - "giant_dashboard()" - ] - }, - { - "cell_type": "markdown", - "id": "b825c427-7f65-4a04-a430-e72ea93f8fd8", - "metadata": {}, - "source": [ - "# Check cost" - ] - }, - { - "cell_type": "raw", - "id": "a20215e3-882d-4a96-84b0-b298c3bf5e56", - "metadata": {}, - "source": [ - "from langchain_community.callbacks import get_openai_callback\n", - "\n", - "with get_openai_callback() as cb:\n", - " vizro_ai.dashboard([df, df_stocks], user_question_2_data_1filter)\n", - " print(cb)" - ] - }, - { - "cell_type": "raw", - "id": "5b6a3e7e-51db-4038-a2f3-a6869e13a0e9", - "metadata": {}, - "source": [ - "from langchain_community.callbacks import get_openai_callback\n", - "\n", - "with get_openai_callback() as cb:\n", - " vizro_ai.dashboard([df, df_stocks], user_question_2_data)\n", - " print(cb)" - ] - }, - { - "cell_type": "markdown", - "id": "82d8f97f-1cba-4b0e-ba63-2b894858e49a", - "metadata": {}, - "source": [ - "# Test the generated code is working" - ] - }, - { - "cell_type": "raw", - "id": "c244add0-184a-42c1-a591-e48f0df57b68", - "metadata": {}, - "source": [ - "from vizro import Vizro\n", - "Vizro._reset()" - ] - }, - { - "cell_type": "raw", - "id": "1bf72465-432f-482f-9b98-81454a6e9b30", - "metadata": {}, - "source": [ - "from vizro import Vizro\n", - "from vizro.models import AgGrid, Card, Dashboard, Filter, Layout, Page\n", - "from vizro.tables import dash_ag_grid\n", - "import pandas as pd" - ] - }, - { - "cell_type": "raw", - "id": "c9e1fccd-bd99-4ed3-9db2-e24de4f66967", - "metadata": {}, - "source": [ - "# example data manager code to run before running the generated dashboard code\n", - "from vizro.managers import data_manager\n", - "\n", - "def population_data():\n", - " df = px.data.gapminder()\n", - " return df\n", - "\n", - "def stock_data():\n", - " df_stocks = px.data.stocks()\n", - " return df_stocks\n", - " \n", - "\n", - "data_manager[\"world_indicators\"] = population_data\n", - "data_manager[\"stock_prices\"] = stock_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "637ca5de-c6cf-4309-b57b-70f3abb7c3f1", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/vizro-ai/hatch.toml b/vizro-ai/hatch.toml index 66b69e519..258e6c479 100644 --- a/vizro-ai/hatch.toml +++ b/vizro-ai/hatch.toml @@ -25,6 +25,7 @@ VIZRO_AI_LOG_LEVEL = "DEBUG" [envs.default.scripts] example = "cd examples; python example.py" +example-create-dashboard = "cd examples; python example_dashboard.py" lint = "hatch run lint:lint {args:--all-files}" prep-release = [ "hatch version release", diff --git a/vizro-ai/src/vizro_ai/chains/_llm_models.py b/vizro-ai/src/vizro_ai/chains/_llm_models.py index 973f17d06..6d28f8afc 100644 --- a/vizro-ai/src/vizro_ai/chains/_llm_models.py +++ b/vizro-ai/src/vizro_ai/chains/_llm_models.py @@ -70,7 +70,7 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo ) -def _get_model_name(model): +def _get_model_name(model: BaseChatModel) -> str: methods = [ lambda: model.model_name, # OpenAI models lambda: model.model, # Anthropic models diff --git a/vizro-ai/src/vizro_ai/dashboard/response_models/components.py b/vizro-ai/src/vizro_ai/dashboard/response_models/components.py index 4bb15f9aa..68e328dbb 100644 --- a/vizro-ai/src/vizro_ai/dashboard/response_models/components.py +++ b/vizro-ai/src/vizro_ai/dashboard/response_models/components.py @@ -11,10 +11,9 @@ from pydantic import BaseModel, Field from vizro.tables import dash_ag_grid from vizro_ai.dashboard._pydantic_output import _get_pydantic_output +from vizro_ai.dashboard.response_models.types import component_type from vizro_ai.utils.helper import DebugFailure -from .types import component_type - logger = logging.getLogger(__name__) @@ -55,7 +54,13 @@ def create(self, model, df_metadata) -> ComponentType: id=self.component_id + "_" + self.page_id, figure=dash_ag_grid(data_frame=self.df_name) ) elif self.component_type == "Card": - return _get_pydantic_output(query=self.component_description, llm_model=model, response_model=vm.Card) + result_proxy = _get_pydantic_output( + query=self.component_description, llm_model=model, response_model=vm.Card + ) + proxy_dict = result_proxy.dict() + proxy_dict["id"] = self.component_id + "_" + self.page_id # id to be used by layout + return vm.Card.parse_obj(proxy_dict) + except DebugFailure as e: logger.warning( f"Failed to build component: {self.component_id}.\n ------- \n " @@ -75,7 +80,7 @@ def create(self, model, df_metadata) -> ComponentType: component_type="Card", component_description="Create a card says 'this is worldwide GDP'.", component_id="gdp_card", - page_id="1", + page_id="page1", df_name="N/A", ) component = component_plan.create(model, df_metadata) diff --git a/vizro-ai/src/vizro_ai/dashboard/response_models/controls.py b/vizro-ai/src/vizro_ai/dashboard/response_models/controls.py index d5d9f4177..448108728 100644 --- a/vizro-ai/src/vizro_ai/dashboard/response_models/controls.py +++ b/vizro-ai/src/vizro_ai/dashboard/response_models/controls.py @@ -10,8 +10,7 @@ except ImportError: # pragma: no cov from pydantic import BaseModel, Field, ValidationError, create_model, validator from vizro_ai.dashboard._pydantic_output import _get_pydantic_output - -from .types import control_type +from vizro_ai.dashboard.response_models.types import control_type logger = logging.getLogger(__name__) @@ -109,6 +108,7 @@ def create(self, model, available_components, df_metadata): "returning default values." ) return None + return res else: logger.warning(f"Control type {self.control_type} not recognized.") return None @@ -137,7 +137,9 @@ def create(self, model, available_components, df_metadata): ) control_plan = ControlPlan( control_type="Filter", - control_description="Create a filter that filters the data based on the column 'a'.", + control_description="Create a filter that filters the data by column 'a'.", df_name="gdp_chart", ) - control = control_plan.create(model, ["gdp_chart"], df_metadata) + control = control_plan.create( + model, ["gdp_chart"], df_metadata + ) # error: Target gdp_chart not found in model_manager. diff --git a/vizro-ai/src/vizro_ai/dashboard/response_models/layout.py b/vizro-ai/src/vizro_ai/dashboard/response_models/layout.py index 782ba394b..ae686fc1b 100644 --- a/vizro-ai/src/vizro_ai/dashboard/response_models/layout.py +++ b/vizro-ai/src/vizro_ai/dashboard/response_models/layout.py @@ -4,40 +4,44 @@ from typing import List, Union import vizro.models as vm +from langchain_core.language_models.chat_models import BaseChatModel try: - from pydantic.v1 import BaseModel, Field, validator + from pydantic.v1 import BaseModel, Field, create_model except ImportError: # pragma: no cov - from pydantic import BaseModel, Field, validator -import numpy as np -from vizro.models._layout import _get_grid_lines, _get_unique_grid_component_ids, _validate_grid_areas + from pydantic import BaseModel, Field, create_model from vizro_ai.dashboard._pydantic_output import _get_pydantic_output from vizro_ai.utils.helper import DebugFailure logger = logging.getLogger(__name__) -# TODO: try switch to inherit from Layout directly, like FilterProxy -class LayoutProxyModel(BaseModel): - """Proxy model for Layout.""" +def _convert_layout_to_grid(layout_grid_template_areas): + # TODO: Programmatically convert layout_grid_template_areas to grid + pass - grid: List[List[int]] = Field(..., description="Grid specification to arrange components on screen.") - @validator("grid") - def validate_grid(cls, grid): - """Validate the grid.""" - if len({len(row) for row in grid}) > 1: - raise ValueError("All rows must be of same length.") - - # Validate grid type and values - unique_grid_idx = _get_unique_grid_component_ids(grid) - if 0 not in unique_grid_idx or not np.array_equal(unique_grid_idx, np.arange((unique_grid_idx.max() + 1))): - raise ValueError("Grid must contain consecutive integers starting from 0.") +def _create_layout_proxy(component_ids, layout_grid_template_areas) -> BaseModel: + """Create a layout proxy model.""" - # Validates grid areas spanned by components and spaces - component_grid_lines, space_grid_lines = _get_grid_lines(grid) - _validate_grid_areas(component_grid_lines + space_grid_lines) - return grid + def validate_grid(v): + """Validate the grid.""" + expected_grid = _convert_layout_to_grid(layout_grid_template_areas) + if v != expected_grid: + logger.warning(f"Calculated grid: {expected_grid}, got: {v}") + return v + + return create_model( + "LayoutProxyModel", + grid=( + List[List[int]], + Field(None, description="Grid specification to arrange components on screen."), + ), + __validators__={ + # "validator1": validator("grid", pre=True, allow_reuse=True)(validate_grid), + }, + __base__=vm.Layout, + ) class LayoutPlan(BaseModel): @@ -52,24 +56,30 @@ class LayoutPlan(BaseModel): layout_grid_template_areas: List[str] = Field( [], description="Grid template areas for the layout, which adhere to the grid-template-areas CSS property syntax." - "Each unique string should be used to represent a unique component. If no grid template areas are provided, " - "leave this as an empty list.", + "Each unique string ('component_id' and 'page_id' concated by '_') should be used to " + "represent a unique component. If no grid template areas are provided, leave this as an empty list.", ) - def create(self, model) -> Union[vm.Layout, None]: + def create(self, model: BaseChatModel, component_ids: List[str]) -> Union[vm.Layout, None]: """Create the layout.""" layout_prompt = ( f"Create a layout from the following instructions: {self.layout_description}. Do not make up " f"a layout if not requested. If a layout_grid_template_areas is provided, translate it into " - f"a matrix of integers where each integer represents a unique component (starting from 0). replace " - f"'.' with -1 to represent empty spaces. Here is the grid template areas: {self.layout_grid_template_areas}" + f"a matrix of integers where each integer represents a unique component (starting from 0). " + f"When translating, match the layout_grid_template_areas element string to the same name in " + f"{component_ids} and use the index of {component_ids} to replace the element string. " + f"replace '.' with -1 to represent empty spaces. Here is the grid template areas: \n ------- \n" + f" {self.layout_grid_template_areas}\n ------- \n" ) if self.layout_description == "N/A": return None try: - proxy = _get_pydantic_output(query=layout_prompt, llm_model=model, response_model=LayoutProxyModel) - actual = vm.Layout.parse_obj(proxy.dict(exclude={})) + result_proxy = _create_layout_proxy( + component_ids=component_ids, layout_grid_template_areas=self.layout_grid_template_areas + ) + proxy = _get_pydantic_output(query=layout_prompt, llm_model=model, response_model=result_proxy) + actual = vm.Layout.parse_obj(proxy.dict(exclude={"id": True})) except DebugFailure as e: logger.warning( f"Build failed for `Layout`, returning default values. Try rephrase the prompt or " @@ -87,8 +97,7 @@ def create(self, model) -> Union[vm.Layout, None]: model = _get_llm_model() layout_plan = LayoutPlan( layout_description="Create a layout with a graph on the left and a card on the right.", - layout_grid_template_areas=["graph card"], + layout_grid_template_areas=["graph1 card2 card2", "graph1 . card1"], ) - layout = layout_plan.create(model) + layout = layout_plan.create(model, component_ids=["graph1", "card1", "card2"]) print(layout) # noqa: T201 - print(layout.dict()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/response_models/page.py b/vizro-ai/src/vizro_ai/dashboard/response_models/page.py index a9336d1a0..58c96b686 100644 --- a/vizro-ai/src/vizro_ai/dashboard/response_models/page.py +++ b/vizro-ai/src/vizro_ai/dashboard/response_models/page.py @@ -29,7 +29,7 @@ class PagePlanner(BaseModel): ..., description="List of components. Must contain at least one component." ) controls_plan: List[ControlPlan] = Field([], description="Controls of the page.") - layout_plan: LayoutPlan = Field(None, description="Layout of the page.") + layout_plan: LayoutPlan = Field(None, description="Layout of components on the page.") unsupported_specs: List[str] = Field( [], description="List of unsupported specs. If there are any unsupported specs, " @@ -80,15 +80,18 @@ def _build_components(self, model, df_metadata): component_log.close() return components - def _get_layout(self, model): + def _get_layout(self, model, df_metadata): if self._layout is None: - self._layout = self._build_layout(model) + self._layout = self._build_layout(model, df_metadata) return self._layout - def _build_layout(self, model): + def _build_layout(self, model, df_metadata): if self.layout_plan is None: return None - return self.layout_plan.create(model) + return self.layout_plan.create( + model=model, + component_ids=self._get_component_ids(model=model, df_metadata=df_metadata), + ) def _get_controls(self, model, df_metadata): if self._controls is None: @@ -102,6 +105,9 @@ def _available_components(self, model, df_metadata): if isinstance(comp, (vm.Graph, vm.AgGrid)) ] + def _get_component_ids(self, model, df_metadata): + return [comp.id for comp in self._get_components(model=model, df_metadata=df_metadata)] + def _build_controls(self, model, df_metadata): controls = [] with tqdm( @@ -134,9 +140,19 @@ def create(self, model, df_metadata): controls = _execute_step( pbar, page_desc + " --> add controls", self._get_controls(model=model, df_metadata=df_metadata) ) - layout = _execute_step(pbar, page_desc + " --> add layout", self._get_layout(model)) + layout = _execute_step( + pbar, page_desc + " --> add layout", self._get_layout(model=model, df_metadata=df_metadata) + ) - page = vm.Page(title=title, components=components, controls=controls, layout=layout) + try: + page = vm.Page(title=title, components=components, controls=controls, layout=layout) + except Exception as e: + if any("Number of page and grid components need to be the same" in error["msg"] for error in e.errors()): + logger.warning( + "Number of page and grid components need to be the same. " + "Please check the layout and the components." + ) + page = vm.Page(title=title, components=components, controls=controls, layout=None) _execute_step(pbar, page_desc + " --> done", None) pbar.close() return page