Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Feb 11, 2025
1 parent 1d50560 commit dddd2dd
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 20 deletions.
21 changes: 15 additions & 6 deletions vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"llm = None\n",
"# uncomment below to try out different models\n",
"# llm = \"gpt-4o\"\n",
"# llm = \"o1\"\n",
"# llm = \"claude-3-5-sonnet-latest\"\n",
"# llm = \"mistral-large-latest\"\n",
"\n",
"# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n",
Expand All @@ -24,9 +24,7 @@
"\n",
"# from langchain_openai import ChatOpenAI\n",
"# llm = ChatOpenAI(\n",
"# model=llm,\n",
"# disabled_params={\"parallel_tool_calls\": None}\n",
"# )\n",
"# model=\"gpt-4o\")\n",
"\n",
"\n",
"# import os\n",
Expand Down Expand Up @@ -104,18 +102,29 @@
]
},
{
"cell_type": "raw",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vizro_ai.plot(df, \"describe the composition of gdp in continent, and add horizontal line for avg gdp\")"
]
},
{
"cell_type": "raw",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vizro_ai.plot(df, \"show me the geo distribution of life expectancy and set year as animation \")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
6 changes: 1 addition & 5 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,17 @@ def plot(
response_model = ChartPlanFactory(data_frame=df) if validate_code else ChartPlan

_, df_sample = _get_df_info(df, n_sample=10)

response = _get_pydantic_model(
query=user_input,
llm_model=self.model,
response_model=response_model,
df_info=df_sample,
max_retry=max_debug_retry,
)

if return_elements:
return response
else:
# Time get_fig_object
fig = response.get_fig_object(data_frame=df)
return fig
return response.get_fig_object(data_frame=df)

def dashboard(
self,
Expand Down
6 changes: 1 addition & 5 deletions vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def _get_pydantic_model(
"""Get the pydantic output from the LLM model with retry logic."""
for attempt in range(max_retry):
attempt_is_retry = attempt > 0

prompt = _create_prompt(retry=attempt_is_retry)
message_content = _create_message_content(
query, df_info, str(last_validation_error) if attempt_is_retry else None, retry=attempt_is_retry
Expand All @@ -108,10 +107,7 @@ def _get_pydantic_model(
return _handle_google_llm_response(llm_model, response_model, prompt, message_content)

pydantic_llm = prompt | llm_model.with_structured_output(response_model)

result = pydantic_llm.invoke(message_content)

return result
return pydantic_llm.invoke(message_content)

except ValidationError as validation_error:
last_validation_error = validation_error
Expand Down
5 changes: 1 addition & 4 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Code powering the plot command."""

import logging

try:
from pydantic.v1 import BaseModel, Field, PrivateAttr, create_model, validator
except ImportError: # pragma: no cov
Expand Down Expand Up @@ -181,12 +179,11 @@ class ChartPlan(BaseChartPlan):


class ChartPlanFactory:
def __new__(cls, data_frame: pd.DataFrame) -> ChartPlan:
def __new__(cls, data_frame: pd.DataFrame) -> ChartPlan: # TODO: change to ChartPlanDynamic
def _test_execute_chart_code(v, values):
"""Test the execution of the chart code."""
imports = "\n".join(values.get("imports", []))
code_to_validate = imports + "\n\n" + v

try:
_safeguard_check(code_to_validate)
except Exception as e:
Expand Down

0 comments on commit dddd2dd

Please sign in to comment.