Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agent): chat with view and other dataset #1660

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..config import Config
from ..constants import LOCAL_SOURCE_TYPES
from ..data_loader.duck_db_connection_manager import DuckDBConnectionManager
from ..data_loader.semantic_layer_schema import Source
from ..query_builders.base_query_builder import BaseQueryBuilder
from ..query_builders.sql_parser import SQLParser
from .state import AgentState
Expand Down Expand Up @@ -69,7 +70,7 @@ def __init__(
)

if isinstance(dfs, list):
sources = [df.schema.source for df in dfs]
sources = [df.schema.source or df._loader.source for df in dfs]
if not BaseQueryBuilder.check_compatible_sources(sources):
raise ValueError(
f"The sources of these datasets: {dfs} are not compatibles"
Expand Down Expand Up @@ -120,23 +121,6 @@ def execute_code(self, code: str) -> dict:

return code_executor.execute_and_return_result(code)

@staticmethod
def _parse_correct_table_name(query: str, dfs: List[VirtualDataFrame]) -> str:
    """Rewrite ``query`` so dataset names point at their table expressions.

    Args:
        query: SQL text that refers to datasets by their ``schema.name``.
        dfs: Datasets whose ``query_builder`` supplies the table expression
            substituted for each name.

    Returns:
        The result of ``SQLParser.replace_table_and_column_names`` applied
        to ``query`` with the name-to-expression mapping built from ``dfs``.
    """
    expressions_by_name = {}
    for dataset in dfs:
        # Read the name before the expression, mirroring dict-display order.
        dataset_name = dataset.schema.name
        expressions_by_name[dataset_name] = (
            dataset.query_builder._get_table_expression()
        )

    return SQLParser.replace_table_and_column_names(query, expressions_by_name)

def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
    """Run ``query`` through DuckDB against the agent's dataframes.

    Every dataframe in ``self._state.dfs`` is registered with a
    ``DuckDBConnectionManager`` under its ``schema.name`` before the query
    is executed.

    Args:
        query: SQL statement to execute.

    Returns:
        The query result converted with ``.df()``.

    Raises:
        RuntimeError: If DuckDB raises ``duckdb.Error`` while registering
            the dataframes or executing the query.
    """
    try:
        connection = DuckDBConnectionManager()
        # Expose each dataframe to DuckDB under its schema name.
        for dataset in self._state.dfs:
            connection.register(dataset.schema.name, dataset)
        relation = connection.sql(query)
        return relation.df()
    except duckdb.Error as e:
        raise RuntimeError(f"SQL execution failed: {e}") from e

def _execute_sql_query(self, query: str) -> pd.DataFrame:
"""
Executes an SQL query on registered DataFrames.
Expand All @@ -150,14 +134,26 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
if not self._state.dfs:
raise ValueError("No DataFrames available to register for query execution.")

df0 = self._state.dfs[0]
source = df0.schema.source or None
db_manager = DuckDBConnectionManager()

table_mapping = {}
df_executor = None

for df in self._state.dfs:
if hasattr(df, "query_builder"):
# df is a valid dataset with query builder, loader and execute_sql_query method
table_mapping[df.schema.name] = df.query_builder._get_table_expression()
df_executor = df.execute_sql_query
else:
# dataset created from loading a csv, no query builder available
db_manager.register(df.schema.name, df)

final_query = SQLParser.replace_table_and_column_names(query, table_mapping)

if source and source.type in LOCAL_SOURCE_TYPES:
return self._execute_local_sql_query(query)
if not df_executor:
return db_manager.sql(final_query).df()
else:
query = self._parse_correct_table_name(query, self._state.dfs)
return df0.execute_sql_query(query)
return df_executor(final_query)

def execute_with_retries(self, code: str) -> Any:
"""Execute the code with retry logic."""
Expand Down
10 changes: 0 additions & 10 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,6 @@ def test_train_method_with_code_but_no_queries(self, agent):
with pytest.raises(ValueError):
agent.train(codes)

def test_execute_local_sql_query_success(self, agent, sample_df):
    """A row count run via ``_execute_local_sql_query`` matches the expected frame."""
    query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
    actual = agent._execute_local_sql_query(query)
    # The test expects a single-row frame whose ``total`` column is 3.
    expected = pd.DataFrame({"total": [3]})
    pd.testing.assert_frame_equal(actual, expected)

def test_execute_local_sql_query_failure(self, agent):
    """A malformed statement must raise ``RuntimeError`` with the failure prefix."""
    malformed_query = "wrong query;"
    expectation = pytest.raises(RuntimeError, match="SQL execution failed")
    with expectation:
        agent._execute_local_sql_query(malformed_query)

def test_execute_sql_query_success_local(self, agent, sample_df):
query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
expected_result = pd.DataFrame({"total": [3]})
Expand Down
Loading