Skip to content

Commit

Permalink
Merge branch 'main' into dev/make-charts-transparent
Browse files Browse the repository at this point in the history
  • Loading branch information
huong-li-nguyen committed Feb 12, 2025
2 parents f4966ef + 42034f5 commit b4fbb6a
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
9 changes: 7 additions & 2 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vizro_ai.dashboard._graph.dashboard_creation import _create_and_compile_graph
from vizro_ai.dashboard._pydantic_output import _get_pydantic_model # TODO: make general, ie remove from dashboard
from vizro_ai.dashboard.utils import DashboardOutputs, _extract_overall_imports_and_code, _register_data
from vizro_ai.plot._response_models import ChartPlan, ChartPlanFactory
from vizro_ai.plot._response_models import BaseChartPlan, ChartPlan, ChartPlanFactory
from vizro_ai.utils.helper import _get_df_info

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,6 +65,7 @@ def plot(
max_debug_retry: int = 1,
return_elements: bool = False,
validate_code: bool = True,
_minimal_output: bool = False,
) -> Union[go.Figure, ChartPlan]:
"""Plot visuals using vizro via english descriptions, english to chart translation.
Expand All @@ -75,12 +76,16 @@ def plot(
return_elements: Flag to return ChartPlan pydantic model that includes all
possible elements generated. Defaults to `False`.
validate_code: Flag if produced code should be executed to validate it. Defaults to `True`.
_minimal_output: Internal flag to exclude chart insights and code explanations and
skip validation. Defaults to `False`.
Returns:
go.Figure or ChartPlan pydantic model
"""
response_model = ChartPlanFactory(data_frame=df) if validate_code else ChartPlan
chart_plan = BaseChartPlan if _minimal_output else ChartPlan
response_model = ChartPlanFactory(data_frame=df, chart_plan=chart_plan) if validate_code else chart_plan

_, df_sample = _get_df_info(df, n_sample=10)
response = _get_pydantic_model(
query=user_input,
Expand Down
97 changes: 58 additions & 39 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,37 @@ def _exec_code(code: str, namespace: dict) -> dict:
return namespace


class ChartPlan(BaseModel):
"""Chart plan model."""
def _test_execute_chart_code(data_frame: pd.DataFrame):
def validator(v, values):
"""Test the execution of the chart code."""
imports = "\n".join(values.get("imports", []))
code_to_validate = imports + "\n\n" + v
try:
_safeguard_check(code_to_validate)
except Exception as e:
raise ValueError(
f"Produced code failed the safeguard validation: <{e}>. Please check the code and try again."
)
try:
namespace = globals()
namespace = _exec_code(code_to_validate, namespace)
custom_chart = namespace[f"{CUSTOM_CHART_NAME}"]
fig = custom_chart(data_frame.sample(10, replace=True))
except Exception as e:
raise ValueError(
f"Produced code execution failed the following error: <{e}>. Please check the code and try again, "
f"alternatively try with a more powerful model."
)
assert isinstance(fig, go.Figure), (
f"Expected chart code to return a plotly go.Figure object, but got {type(fig)}"
)
return v

return validator


class BaseChartPlan(BaseModel):
"""Base chart plan model with core fields."""

chart_type: str = Field(
...,
Expand Down Expand Up @@ -90,16 +119,6 @@ class ChartPlan(BaseModel):
must be done within the function.
""",
)
chart_insights: str = Field(
...,
description="""
Insights to what the chart explains or tries to show. Ideally concise and between 30 and 60 words.""",
)
code_explanation: str = Field(
...,
description="""
Explanation of the code steps used for `chart_code` field.""",
)

_additional_vizro_imports: list[str] = PrivateAttr(ADDITIONAL_IMPORTS)

Expand Down Expand Up @@ -177,37 +196,37 @@ def code_vizro(self):
return self._get_complete_code(vizro=True)


class ChartPlan(BaseChartPlan):
"""Extended chart plan model with additional explanatory fields."""

chart_insights: str = Field(
...,
description="""
Insights to what the chart explains or tries to show.
Ideally concise and between 30 and 60 words.""",
)
code_explanation: str = Field(
...,
description="""
Explanation of the code steps used for `chart_code` field.""",
)


class ChartPlanFactory:
def __new__(cls, data_frame: pd.DataFrame) -> ChartPlan: # TODO: change to ChartPlanDynamic
def _test_execute_chart_code(v, values):
"""Test the execution of the chart code."""
imports = "\n".join(values.get("imports", []))
code_to_validate = imports + "\n\n" + v
try:
_safeguard_check(code_to_validate)
except Exception as e:
raise ValueError(
f"Produced code failed the safeguard validation: <{e}>. Please check the code and try again."
)
try:
namespace = globals()
namespace = _exec_code(code_to_validate, namespace)
custom_chart = namespace[f"{CUSTOM_CHART_NAME}"]
fig = custom_chart(data_frame.sample(10, replace=True))
except Exception as e:
raise ValueError(
f"Produced code execution failed the following error: <{e}>. Please check the code and try again, "
f"alternatively try with a more powerful model."
)
assert isinstance(fig, go.Figure), (
f"Expected chart code to return a plotly go.Figure object, but got {type(fig)}"
)
return v
def __new__(cls, data_frame: pd.DataFrame, chart_plan: type[BaseChartPlan] = ChartPlan) -> type[BaseChartPlan]:
"""Creates a chart plan model with additional validation.
Args:
data_frame: DataFrame to use for validation
chart_plan: Chart plan model to run extended validation against. Defaults to ChartPlan.
Returns:
Chart plan model with additional validation
"""
return create_model(
"ChartPlanDynamic",
__base__=chart_plan,
__validators__={
"validator1": validator("chart_code", allow_reuse=True)(_test_execute_chart_code),
"validator1": validator("chart_code", allow_reuse=True)(_test_execute_chart_code(data_frame)),
},
__base__=ChartPlan,
)
104 changes: 103 additions & 1 deletion vizro-ai/tests/unit/vizro-ai/plot/test_response_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pandas as pd
import plotly.express as ppx
import pytest
import vizro.plotly.express as px

from vizro_ai.plot._response_models import ChartPlan, ChartPlanFactory
from vizro_ai.plot._response_models import BaseChartPlan, ChartPlan, ChartPlanFactory

df = px.data.iris()

Expand All @@ -21,6 +22,18 @@ def chart_plan():
)


@pytest.fixture
def sample_df():
return pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})


@pytest.fixture
def valid_chart_code():
return """def custom_chart(data_frame):
import plotly.express as px
return px.scatter(data_frame, x='x', y='y')"""


class TestChartPlanInstantiation:
"""Tests for the ChartPlan class instantiation."""

Expand Down Expand Up @@ -227,3 +240,92 @@ def {expected_chart_name}(data_frame):
def test_get_fig_object(self, chart_plan, vizro, expected_fig):
fig = chart_plan.get_fig_object(data_frame=df, vizro=vizro)
assert fig == expected_fig


def test_chart_plan_factory_with_base_chart_plan(sample_df, valid_chart_code):
"""Test factory creates validated BaseChartPlan."""
ValidatedModel = ChartPlanFactory(data_frame=sample_df, chart_plan=BaseChartPlan)

assert issubclass(ValidatedModel, BaseChartPlan)
assert ValidatedModel is not BaseChartPlan

instance = ValidatedModel(
chart_type="scatter", imports=["import plotly.express as px"], chart_code=valid_chart_code
)
assert isinstance(instance, BaseChartPlan)


def test_chart_plan_factory_with_chart_plan(sample_df, valid_chart_code):
"""Test factory creates validated ChartPlan."""
ValidatedModel = ChartPlanFactory(data_frame=sample_df, chart_plan=ChartPlan)

assert issubclass(ValidatedModel, ChartPlan)
assert ValidatedModel is not ChartPlan

instance = ValidatedModel(
chart_type="scatter",
imports=["import plotly.express as px"],
chart_code=valid_chart_code,
chart_insights="Test insights",
code_explanation="Test explanation",
)
assert isinstance(instance, ChartPlan)


def test_chart_plan_factory_validation_failure(sample_df):
"""Test factory validation fails with invalid code."""
ValidatedModel = ChartPlanFactory(data_frame=sample_df, chart_plan=BaseChartPlan)

with pytest.raises(ValueError, match="The chart code must be wrapped in a function named"):
ValidatedModel(chart_type="scatter", imports=["import plotly.express as px"], chart_code="invalid_code")


def test_base_chart_plan_without_validation(valid_chart_code):
"""Test BaseChartPlan without validation."""
instance = BaseChartPlan(chart_type="scatter", imports=["import plotly.express as px"], chart_code=valid_chart_code)
assert isinstance(instance, BaseChartPlan)


def test_chart_plan_without_validation(valid_chart_code):
"""Test ChartPlan without validation."""
instance = ChartPlan(
chart_type="scatter",
imports=["import plotly.express as px"],
chart_code=valid_chart_code,
chart_insights="Test insights",
code_explanation="Test explanation",
)
assert isinstance(instance, ChartPlan)
assert instance.chart_insights == "Test insights"
assert instance.code_explanation == "Test explanation"


def test_chart_plan_factory_preserves_fields(sample_df, valid_chart_code):
"""Test factory preserves all fields from base class."""
ValidatedModel = ChartPlanFactory(data_frame=sample_df, chart_plan=ChartPlan)

instance = ValidatedModel(
chart_type="scatter",
imports=["import plotly.express as px"],
chart_code=valid_chart_code,
chart_insights="Test insights",
code_explanation="Test explanation",
)

# Check all fields are preserved
assert instance.chart_type == "scatter"
assert instance.imports == ["import plotly.express as px"]
assert instance.chart_code == valid_chart_code
assert instance.chart_insights == "Test insights"
assert instance.code_explanation == "Test explanation"


def test_base_chart_plan_no_explanatory_fields(valid_chart_code):
"""Test BaseChartPlan doesn't have explanatory fields."""
instance = BaseChartPlan(chart_type="scatter", imports=["import plotly.express as px"], chart_code=valid_chart_code)

with pytest.raises(AttributeError):
_ = instance.chart_insights

with pytest.raises(AttributeError):
_ = instance.code_explanation

0 comments on commit b4fbb6a

Please sign in to comment.