diff --git a/vizro-ai/src/vizro_ai/_vizro_ai.py b/vizro-ai/src/vizro_ai/_vizro_ai.py index fabc07f12..be9697ea6 100644 --- a/vizro-ai/src/vizro_ai/_vizro_ai.py +++ b/vizro-ai/src/vizro_ai/_vizro_ai.py @@ -58,7 +58,7 @@ def _lazy_get_component(self, component_class: Any) -> Any: # TODO configure co return self.components_instances[component_class] def _run_plot_tasks( - self, df: pd.DataFrame, user_input: str, max_debug_retry: int = 3, explain: bool = False + self, df: pd.DataFrame, user_input: str, max_debug_retry: int = 3, explain: bool = False, return_elements: bool = False ) -> PlotOutputs: """Task execution.""" chart_type_pipeline = self.pipeline_manager.chart_type_pipeline @@ -86,7 +86,7 @@ def _run_plot_tasks( ) fig_object = _exec_code_and_retrieve_fig( - code=code_string, local_args={"df": df}, show_fig=_is_jupyter(), is_notebook_env=_is_jupyter() + code=code_string, local_args={"df": df}, show_fig=_is_jupyter(), is_notebook_env=_is_jupyter(), return_elements=return_elements ) if explain: business_insights, code_explanation = self._lazy_get_component(GetCodeExplanation).run( @@ -137,7 +137,7 @@ def plot( # pylint: disable=too-many-arguments # noqa: PLR0913 """ vizro_plot = self._run_plot_tasks( - df=df, user_input=user_input, explain=explain, max_debug_retry=max_debug_retry + df=df, user_input=user_input, explain=explain, max_debug_retry=max_debug_retry, return_elements=return_elements ) if not explain: diff --git a/vizro-ai/src/vizro_ai/utils/helper.py b/vizro-ai/src/vizro_ai/utils/helper.py index 2de34b845..17a2d5cc0 100644 --- a/vizro-ai/src/vizro_ai/utils/helper.py +++ b/vizro-ai/src/vizro_ai/utils/helper.py @@ -60,7 +60,7 @@ def _debug_helper( def _exec_code_and_retrieve_fig( - code: str, local_args: Optional[Dict] = None, show_fig: bool = False, is_notebook_env: bool = True + code: str, local_args: Optional[Dict] = None, show_fig: bool = False, is_notebook_env: bool = True, return_elements: bool = False ) -> go.Figure: """Execute code in notebook with correct namespace and return fig object. @@ -76,10 +76,11 @@ def _exec_code_and_retrieve_fig( """ from IPython import get_ipython - if show_fig and "\nfig.show()" not in code: - code += "\nfig.show()" - elif not show_fig: - code = code.replace("fig.show()", "") + if return_elements: + if show_fig and "\nfig.show()" not in code: + code += "\nfig.show()" + elif not show_fig: + code = code.replace("fig.show()", "") namespace = get_ipython().user_ns if is_notebook_env else globals()