Skip to content

Commit

Permalink
fix(coverage): adding unit tests (#1641)
Browse files Browse the repository at this point in the history
* fix(coverage): adding unit tests

* fix(coverage): format

* chore: fix typo

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* chore: rename to test_get_group_by_columns

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* chore: rename to test_column_must_be_defined_for_view

* Update tests/unit_tests/data_loader/test_loader.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* fix(coverage): assertion missing

* fix(coverage): useless optional str

---------

Co-authored-by: Gabriele Venturi <[email protected]>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 27, 2025
1 parent 82a5c8a commit ae431f8
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 11 deletions.
2 changes: 0 additions & 2 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,6 @@ def create(
group_by=group_by,
transformations=parsed_transformations,
)
else:
raise InvalidConfigError("Unable to create schema with the provided params")

schema.description = description or schema.description

Expand Down
5 changes: 1 addition & 4 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,7 @@ def is_column_type_supported(cls, type: str) -> str:

@field_validator("expression")
@classmethod
def is_expression_valid(cls, expr: Optional[str]) -> Optional[str]:
if not expr:
return None

def is_expression_valid(cls, expr: str) -> str:
try:
parse_one(expr)
return expr
Expand Down
3 changes: 0 additions & 3 deletions pandasai/query_builders/view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def normalize_view_column_alias(name: str) -> str:

def _get_group_by_columns(self) -> list[str]:
"""Get the group by columns with proper view column aliasing."""
if not self.schema.group_by:
return []

group_by_cols = []
for col in self.schema.group_by:
group_by_cols.append(self.normalize_view_column_alias(col))
Expand Down
11 changes: 11 additions & 0 deletions tests/unit_tests/data_loader/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pandasai.data_loader.loader import DatasetLoader
from pandasai.data_loader.local_loader import LocalDatasetLoader
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import MaliciousQueryError
from pandasai.query_builders import LocalQueryBuilder


Expand Down Expand Up @@ -101,3 +102,13 @@ def test_build_dataset_csv_schema(self, sample_schema):

assert isinstance(result, DataFrame)
assert "email" in result.columns

def test_malicious_query(self, sample_schema):
    """Executing ``DROP TABLE`` is expected to raise MaliciousQueryError."""
    dataset_loader = LocalDatasetLoader(sample_schema, "test/test")
    rejected_sql = "DROP TABLE"
    with pytest.raises(MaliciousQueryError):
        dataset_loader.execute_query(rejected_sql)

def test_runtime_error(self, sample_schema):
    """Selecting from ``nonexistent_table`` is expected to raise RuntimeError."""
    dataset_loader = LocalDatasetLoader(sample_schema, "test/test")
    failing_sql = "SELECT * FROM nonexistent_table"
    with pytest.raises(RuntimeError):
        dataset_loader.execute_query(failing_sql)
36 changes: 36 additions & 0 deletions tests/unit_tests/data_loader/test_transformation_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from pydantic import ValidationError

from pandasai.data_loader.semantic_layer_schema import (
Column,
SemanticLayerSchema,
Source,
SQLConnectionConfig,
Transformation,
TransformationParams,
)
Expand Down Expand Up @@ -226,3 +229,36 @@ def test_rename_transformation_missing_params():
},
],
)


def test_column_expression_parse_error():
    """``Column.is_expression_valid`` should raise ValueError for this input."""
    malformed_expression = "invalid SELECT FROM sql"
    with pytest.raises(ValueError):
        Column.is_expression_valid(malformed_expression)


def test_incompatible_source():
    """A csv source should not be reported as compatible with a postgres one."""
    csv_source = Source(type="csv", path="path")
    postgres_connection = SQLConnectionConfig(
        host="example.amazonaws.com",
        port=5432,
        user="user",
        password="password",
        database="db",
    )
    postgres_source = Source(
        type="postgres",
        connection=postgres_connection,
        table="table",
    )
    assert not csv_source.is_compatible_source(postgres_source)


def test_source_or_view_error():
    """A schema built from a name alone should fail pydantic validation."""
    schema_kwargs = {"name": "ciao"}
    with pytest.raises(ValidationError):
        SemanticLayerSchema(**schema_kwargs)


def test_column_must_be_defined_for_view():
    """A schema with ``view=True`` and only a name should fail validation."""
    schema_kwargs = {"name": "ciao", "view": True}
    with pytest.raises(ValidationError):
        SemanticLayerSchema(**schema_kwargs)
33 changes: 32 additions & 1 deletion tests/unit_tests/query_builders/test_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pytest
import sqlglot

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.data_loader.semantic_layer_schema import (
SemanticLayerSchema,
Transformation,
)
from pandasai.query_builders import LocalQueryBuilder
from pandasai.query_builders.base_query_builder import BaseQueryBuilder
from pandasai.query_builders.sql_query_builder import SqlQueryBuilder
Expand Down Expand Up @@ -365,3 +368,31 @@ def test_order_by_injection(self, injection, mysql_schema):
query_builder = BaseQueryBuilder(mysql_schema)
with pytest.raises((sqlglot.errors.ParseError, sqlglot.errors.TokenError)):
query_builder.build_query()

def test_build_query_distinct(self, sample_schema):
    """``build_query`` should start with SELECT DISTINCT for remove_duplicates."""
    builder = BaseQueryBuilder(sample_schema)
    dedupe = Transformation(type="remove_duplicates")
    builder.schema.transformations = [dedupe]
    sql = builder.build_query()
    assert sql.startswith("SELECT DISTINCT")

def test_build_query_distinct_head(self, sample_schema):
    """``get_head_query`` should start with SELECT DISTINCT for remove_duplicates."""
    builder = BaseQueryBuilder(sample_schema)
    dedupe = Transformation(type="remove_duplicates")
    builder.schema.transformations = [dedupe]
    head_sql = builder.get_head_query()
    assert head_sql.startswith("SELECT DISTINCT")

def test_build_query_order_by(self, sample_schema):
    """``build_query`` should contain an ORDER BY clause for the schema's order_by."""
    builder = BaseQueryBuilder(sample_schema)
    builder.schema.order_by = ["column"]
    sql = builder.build_query()
    expected_clause = "ORDER BY\n column"
    assert expected_clause in sql

def test_get_group_by_columns(self, sample_schema):
    """``get_head_query`` should contain a GROUP BY clause for the schema's group_by."""
    builder = BaseQueryBuilder(sample_schema)
    builder.schema.group_by = ["parents"]
    head_sql = builder.get_head_query()
    expected_clause = "GROUP BY\n parents"
    assert expected_clause in head_sql
34 changes: 33 additions & 1 deletion tests/unit_tests/query_builders/test_view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import pytest

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.data_loader.semantic_layer_schema import (
SemanticLayerSchema,
Transformation,
)
from pandasai.data_loader.sql_loader import SQLDatasetLoader
from pandasai.query_builders.sql_query_builder import SqlQueryBuilder
from pandasai.query_builders.view_query_builder import ViewQueryBuilder
Expand Down Expand Up @@ -76,13 +79,42 @@ def test_build_query(self, view_query_builder):
) AS parent_children"""
)

def test_build_query_distinct(self, view_query_builder):
    """The view ``build_query`` should start with SELECT DISTINCT when deduping."""
    dedupe = Transformation(type="remove_duplicates")
    view_query_builder.schema.transformations = [dedupe]
    sql = view_query_builder.build_query()
    assert sql.startswith("SELECT DISTINCT")

def test_build_query_distinct_head(self, view_query_builder):
    """The view ``get_head_query`` should start with SELECT DISTINCT when deduping."""
    dedupe = Transformation(type="remove_duplicates")
    view_query_builder.schema.transformations = [dedupe]
    head_sql = view_query_builder.get_head_query()
    assert head_sql.startswith("SELECT DISTINCT")

def test_build_query_order_by(self, view_query_builder):
    """The view ``build_query`` should contain an ORDER BY clause for order_by."""
    view_query_builder.schema.order_by = ["column"]
    sql = view_query_builder.build_query()
    expected_clause = "ORDER BY\n column"
    assert expected_clause in sql

def test_build_query_limit(self, view_query_builder):
    """The view ``build_query`` should contain ``LIMIT 10`` when limit is 10."""
    view_query_builder.schema.limit = 10
    sql = view_query_builder.build_query()
    expected_clause = "LIMIT 10"
    assert expected_clause in sql

def test_get_columns(self, view_query_builder):
    """``_get_columns`` should return each view column as ``<alias> AS <alias>``."""
    actual_columns = view_query_builder._get_columns()
    expected_columns = [
        "parents_id AS parents_id",
        "parents_name AS parents_name",
        "children_name AS children_name",
    ]
    assert actual_columns == expected_columns

def test_get__group_by_columns(self, view_query_builder):
    """``_get_group_by_columns`` should return ``parents.id`` as ``parents_id``."""
    view_query_builder.schema.group_by = ["parents.id"]
    normalized_columns = view_query_builder._get_group_by_columns()
    expected_columns = ["parents_id"]
    assert normalized_columns == expected_columns

def test_get_table_expression(self, view_query_builder):
assert (
view_query_builder._get_table_expression()
Expand Down
21 changes: 21 additions & 0 deletions tests/unit_tests/test_pandasai_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,27 @@ def test_create_valid_dataset_no_params(
assert result.schema.description is None
assert mock_loader_instance.load.call_count == 1

def test_create_valid_dataset_group_by(
    self, sample_df, mock_loader_instance, mock_file_manager
):
    """Test that ``group_by`` passed to ``create`` is kept on the result schema."""
    columns = [
        {"name": "A"},
        {"name": "B", "expression": "avg(B)", "alias": "average_b"},
    ]
    # to_parquet is replaced by a mock for the duration of the call; the mock
    # itself is never inspected, so it is not bound to a name.
    with patch.object(sample_df, "to_parquet"):
        result = pandasai.create(
            "test-org/test-dataset",
            sample_df,
            columns=columns,
            group_by=["A"],
        )
    assert result.schema.group_by == ["A"]

def test_create_invalid(self, sample_df, mock_loader_instance, mock_file_manager):
    """Test that ``create`` raises InvalidConfigError when given only a path."""
    dataset_path = "test-org/test-dataset"
    with pytest.raises(InvalidConfigError):
        pandasai.create(dataset_path)

def test_create_invalid_path_format(self, sample_df):
"""Test creating a dataset with invalid path format."""
with pytest.raises(
Expand Down

0 comments on commit ae431f8

Please sign in to comment.