diff --git a/ibis_omniscidb/operations.py b/ibis_omniscidb/operations.py index 0b220b3..d22c551 100644 --- a/ibis_omniscidb/operations.py +++ b/ibis_omniscidb/operations.py @@ -1101,6 +1101,7 @@ def formatter(translator, expr): ops.Lag: _shift_like('lag'), ops.Lead: _shift_like('lead', 1), ops.MinRank: lambda *args: 'rank()', + ops.NTile: _window_op_one_param('ntile'), # cume_dist vs percent_rank # https://github.com/ibis-project/ibis/issues/1975 ops.PercentRank: lambda *args: 'cume_dist()', @@ -1118,7 +1119,6 @@ def formatter(translator, expr): ops.CumulativeAny, ops.CumulativeAll, ops.IdenticalTo, - ops.NTile, ops.NthValue, ops.GroupConcat, ops.IsInf, diff --git a/ibis_omniscidb/tests/conftest.py b/ibis_omniscidb/tests/conftest.py index e62b30d..d24fe00 100644 --- a/ibis_omniscidb/tests/conftest.py +++ b/ibis_omniscidb/tests/conftest.py @@ -7,6 +7,8 @@ import pandas import pytest +import ibis_omniscidb + OMNISCIDB_HOST = os.environ.get('IBIS_TEST_OMNISCIDB_HOST', 'localhost') OMNISCIDB_PORT = int(os.environ.get('IBIS_TEST_OMNISCIDB_PORT', 6274)) OMNISCIDB_USER = os.environ.get('IBIS_TEST_OMNISCIDB_USER', 'admin') @@ -25,7 +27,7 @@ def con(): ------- ibis.omniscidb.OmniSciDBClient """ - return ibis.omniscidb.connect( + return ibis_omniscidb.connect( protocol=OMNISCIDB_PROTOCOL, host=OMNISCIDB_HOST, port=OMNISCIDB_PORT, @@ -61,7 +63,7 @@ def test_table(con): def session_con(): """Define a session connection fixture.""" # TODO: fix return issue - return ibis.omniscidb.connect( + return ibis_omniscidb.connect( protocol=OMNISCIDB_PROTOCOL, host=OMNISCIDB_HOST, port=OMNISCIDB_PORT, diff --git a/ibis_omniscidb/tests/test_window.py b/ibis_omniscidb/tests/test_window.py new file mode 100644 index 0000000..aa1fd71 --- /dev/null +++ b/ibis_omniscidb/tests/test_window.py @@ -0,0 +1,86 @@ +from typing import List, Union + +import ibis +import pandas as pd +import pytest + + +def _ntile( + data: Union[pd.Series, pd.core.groupby.generic.SeriesGroupBy], bucket: int +): + """ + NTILE divides given data set into a number of buckets. + + It divides an ordered and grouped data set into a number of buckets + and assigns the appropriate bucket number to each row. + Return an integer ranging from 0 to `bucket - 1`, dividing the + partition as equally as possible. + Adapted from: + https://gist.github.com/xmnlab/2c1f93df1a6c6bde4e32c8579117e9cc + + Parameters + ---------- + data : pandas.core.groupby.generic.SeriesGroupBy or pandas.Series + bucket: int + + Returns + ------- + pandas.Series + + Notes + ----- + This function would be used to test the result from the OmniSci backend. + """ + if isinstance(data, pd.core.groupby.generic.SeriesGroupBy): + return pd.concat([_ntile(group, bucket) for name, group in data]) + + n = data.shape[0] + sub_n = n // bucket + diff = n - (sub_n * bucket) + + result = [] + for i in range(bucket): + sub_result = [i] * (sub_n + (1 if diff else 0)) + result.extend(sub_result) + if diff > 0: + diff -= 1 + return pd.Series(result, index=data.index) + + +@pytest.mark.parametrize( + 'column_name,group_by,order_by,buckets', + [ + ('string_col', ['string_col'], 'id', 7), + ], +) +def test_ntile( + con: ibis.omniscidb.OmniSciDBClient, + alltypes: ibis.expr.types.TableExpr, + df_alltypes: pd.DataFrame, + column_name: str, + group_by: List[str], + order_by: List[str], + buckets: int, +): + result_pd = df_alltypes.copy() + result_pd_grouped = result_pd.sort_values(order_by).groupby(group_by) + result_pd['val'] = _ntile(result_pd_grouped[column_name], buckets) + + expr = alltypes.mutate( + val=( + alltypes[column_name] + .ntile(buckets=buckets) + .over( + ibis.window( + following=0, + group_by=group_by, + order_by=order_by, + ) + ) + ) + ) + + result_pd = result_pd.sort_values(order_by).reset_index(drop=True) + result_expr = expr.execute().sort_values(order_by).reset_index(drop=True) + + pd.testing.assert_series_equal(result_pd.val, result_expr.val)