diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py
index 00613ae89..d0974defc 100644
--- a/pandasai/agent/base.py
+++ b/pandasai/agent/base.py
@@ -29,6 +29,7 @@
 from ..config import Config
 from ..constants import LOCAL_SOURCE_TYPES
 from ..data_loader.duck_db_connection_manager import DuckDBConnectionManager
+from ..data_loader.semantic_layer_schema import Source
 from ..query_builders.base_query_builder import BaseQueryBuilder
 from ..query_builders.sql_parser import SQLParser
 from .state import AgentState
@@ -69,7 +70,7 @@ def __init__(
             )
 
         if isinstance(dfs, list):
-            sources = [df.schema.source for df in dfs]
+            sources = [df.schema.source or df._loader.source for df in dfs]
             if not BaseQueryBuilder.check_compatible_sources(sources):
                 raise ValueError(
                     f"The sources of these datasets: {dfs} are not compatibles"
@@ -120,23 +121,6 @@ def execute_code(self, code: str) -> dict:
 
         return code_executor.execute_and_return_result(code)
 
-    @staticmethod
-    def _parse_correct_table_name(query: str, dfs: List[VirtualDataFrame]) -> str:
-        table_mapping = {
-            df.schema.name: df.query_builder._get_table_expression() for df in dfs
-        }
-
-        return SQLParser.replace_table_and_column_names(query, table_mapping)
-
-    def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
-        try:
-            db_manager = DuckDBConnectionManager()
-            for df in self._state.dfs:
-                db_manager.register(df.schema.name, df)
-            return db_manager.sql(query).df()
-        except duckdb.Error as e:
-            raise RuntimeError(f"SQL execution failed: {e}") from e
-
     def _execute_sql_query(self, query: str) -> pd.DataFrame:
         """
         Executes an SQL query on registered DataFrames.
@@ -150,14 +134,26 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
         if not self._state.dfs:
             raise ValueError("No DataFrames available to register for query execution.")
 
-        df0 = self._state.dfs[0]
-        source = df0.schema.source or None
+        db_manager = DuckDBConnectionManager()
+
+        table_mapping = {}
+        df_executor = None
+
+        for df in self._state.dfs:
+            if hasattr(df, "query_builder"):
+                # df is a valid dataset with query builder, loader and execute_sql_query method
+                table_mapping[df.schema.name] = df.query_builder._get_table_expression()
+                df_executor = df.execute_sql_query
+            else:
+                # dataset created from loading a csv, no query builder available
+                db_manager.register(df.schema.name, df)
+
+        final_query = SQLParser.replace_table_and_column_names(query, table_mapping)
 
-        if source and source.type in LOCAL_SOURCE_TYPES:
-            return self._execute_local_sql_query(query)
+        if not df_executor:
+            return db_manager.sql(final_query).df()
         else:
-            query = self._parse_correct_table_name(query, self._state.dfs)
-            return df0.execute_sql_query(query)
+            return df_executor(final_query)
 
     def execute_with_retries(self, code: str) -> Any:
         """Execute the code with retry logic."""
diff --git a/tests/unit_tests/agent/test_agent.py b/tests/unit_tests/agent/test_agent.py
index de11199e6..c1dd21447 100644
--- a/tests/unit_tests/agent/test_agent.py
+++ b/tests/unit_tests/agent/test_agent.py
@@ -378,16 +378,6 @@ def test_train_method_with_code_but_no_queries(self, agent):
         with pytest.raises(ValueError):
             agent.train(codes)
 
-    def test_execute_local_sql_query_success(self, agent, sample_df):
-        query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
-        expected_result = pd.DataFrame({"total": [3]})
-        result = agent._execute_local_sql_query(query)
-        pd.testing.assert_frame_equal(result, expected_result)
-
-    def test_execute_local_sql_query_failure(self, agent):
-        with pytest.raises(RuntimeError, match="SQL execution failed"):
-            agent._execute_local_sql_query("wrong query;")
-
     def test_execute_sql_query_success_local(self, agent, sample_df):
         query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
         expected_result = pd.DataFrame({"total": [3]})