Skip to content

Commit

Permalink
fix(query builders): quoting identifiers by default
Browse files Browse the repository at this point in the history
  • Loading branch information
scaliseraoul committed Mar 5, 2025
1 parent b67caf8 commit c461eaa
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 415 deletions.
5 changes: 3 additions & 2 deletions pandasai/query_builders/base_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sqlglot
from sqlglot import select
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Source
from pandasai.query_builders.sql_transformation_manager import SQLTransformationManager
Expand Down Expand Up @@ -38,7 +39,7 @@ def build_query(self) -> str:
if self.schema.limit:
query = query.limit(self.schema.limit)

return query.sql(pretty=True)
return query.transform(quote_identifiers).sql(pretty=True)

def get_head_query(self, n=5):
query = select(*self._get_columns()).from_(self._get_table_expression())
Expand All @@ -55,7 +56,7 @@ def get_head_query(self, n=5):
# Add LIMIT
query = query.limit(n)

return query.sql(pretty=True)
return query.transform(quote_identifiers).sql(pretty=True)

def get_row_count(self):
    """Build the SQL statement that counts the rows of the table expression.

    Returns:
        str: A pretty-printed ``SELECT COUNT(*) FROM <table expression>``
        statement with its identifiers quoted.
    """
    query = select("COUNT(*)").from_(self._get_table_expression())
    # Quote identifiers the same way build_query() and get_head_query() do,
    # so all queries generated by this builder refer to the table with the
    # same quoted name.
    return query.transform(quote_identifiers).sql(pretty=True)
Expand Down
5 changes: 3 additions & 2 deletions pandasai/query_builders/view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlglot import exp, expressions, parse_one, select
from sqlglot.expressions import Subquery
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers

from ..data_loader.loader import DatasetLoader
from ..data_loader.semantic_layer_schema import SemanticLayerSchema, Transformation
Expand Down Expand Up @@ -79,7 +80,7 @@ def build_query(self) -> str:
query = query.order_by(*self.schema.order_by)
if self.schema.limit:
query = query.limit(self.schema.limit)
return query.sql(pretty=True)
return query.transform(quote_identifiers).sql(pretty=True)

def get_head_query(self, n=5):
"""Get the head query with proper group by column aliasing."""
Expand All @@ -89,7 +90,7 @@ def get_head_query(self, n=5):
query = query.distinct()

query = query.limit(n)
return query.sql(pretty=True)
return query.transform(quote_identifiers).sql(pretty=True)

def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery:
sub_query = parse_one(loader.query_builder.build_query())
Expand Down
7 changes: 1 addition & 6 deletions tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ def test_load_mysql_source(self, mysql_schema):

# Verify the SQL query was executed correctly
mock_execute_query.assert_called_once_with(
"""SELECT
email,
first_name,
timestamp
FROM users
LIMIT 5"""
'SELECT\n "email",\n "first_name",\n "timestamp"\nFROM "users"\nLIMIT 5'
)

# Test executing a custom query
Expand Down
42 changes: 21 additions & 21 deletions tests/unit_tests/query_builders/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def test_base_query_builder(self):

expected = (
"SELECT\n"
" category,\n"
" region,\n"
" SUM(amount) AS total_sales,\n"
" AVG(quantity) AS avg_quantity\n"
"FROM sales\n"
' "category",\n'
' "region",\n'
' SUM("amount") AS "total_sales",\n'
' AVG("quantity") AS "avg_quantity"\n'
'FROM "sales"\n'
"GROUP BY\n"
" category,\n"
" region"
' "category",\n'
' "region"'
)
self.assertEqual(query.strip(), expected.strip())

Expand All @@ -104,14 +104,14 @@ def test_local_query_builder(self):

expected = (
"SELECT\n"
" category,\n"
" region,\n"
" SUM(amount) AS total_sales,\n"
" AVG(quantity) AS avg_quantity\n"
' "category",\n'
' "region",\n'
' SUM("amount") AS "total_sales",\n'
' AVG("quantity") AS "avg_quantity"\n'
"FROM READ_CSV('/mocked/absolute/path')\n"
"GROUP BY\n"
" category,\n"
" region"
' "category",\n'
' "region"'
)
self.assertEqual(query.strip(), expected.strip())

Expand All @@ -121,14 +121,14 @@ def test_sql_query_builder(self):

expected = (
"SELECT\n"
" category,\n"
" region,\n"
" SUM(amount) AS total_sales,\n"
" AVG(quantity) AS avg_quantity\n"
"FROM sales\n"
' "category",\n'
' "region",\n'
' SUM("amount") AS "total_sales",\n'
' AVG("quantity") AS "avg_quantity"\n'
'FROM "sales"\n'
"GROUP BY\n"
" category,\n"
" region"
' "category",\n'
' "region"'
)
self.assertEqual(query.strip(), expected.strip())

Expand Down Expand Up @@ -179,5 +179,5 @@ def test_no_group_by(self):
builder = BaseQueryBuilder(schema)
query = builder.build_query()

expected = "SELECT\n" " category,\n" " amount\n" "FROM sales"
expected = 'SELECT\n "category",\n "amount"\nFROM "sales"'
self.assertEqual(query.strip(), expected.strip())
Loading

0 comments on commit c461eaa

Please sign in to comment.