Skip to content

Commit

Permalink
Add ntile operation
Browse files: browse the repository at this point in the history
  • Loading branch information
xmnlab committed Jan 12, 2021
1 parent b4b2b8b commit c2bace1
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions ibis_omniscidb/tests/test_window.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,12 +3,8 @@
import ibis
import pandas as pd
import pytest
from ibis.expr import types as ir

# from ibis.expr import window as ww


# ntile function version using pandas
def _ntile(
data: Union[pd.Series, pd.core.groupby.generic.SeriesGroupBy], bucket: int
):
Expand All @@ -30,6 +26,10 @@ def _ntile(
Returns
-------
pandas.Series
Notes
-----
This function is used to test the result from the OmniSci backend.
"""
if isinstance(data, pd.core.groupby.generic.SeriesGroupBy):
return pd.concat([_ntile(group, bucket) for name, group in data])
Expand All @@ -49,7 +49,9 @@ def _ntile(

@pytest.mark.parametrize(
'column_name,group_by,order_by,buckets',
[('string_col', ['string_col'], 'id', 7)],
[
('string_col', ['string_col'], 'id', 7),
],
)
def test_ntile(
con: ibis.omniscidb.OmniSciDBClient,
Expand All @@ -60,15 +62,13 @@ def test_ntile(
order_by: List[str],
buckets: int,
):
def prepare_expr(
t: ir.TableExpr,
column_name: str,
buckets: int,
group_by: List[str],
order_by: List[str],
):
return t.mutate(
val=t[column_name]
result_pd = df_alltypes.copy()
result_pd_grouped = result_pd.sort_values(order_by).groupby(group_by)
result_pd['val'] = _ntile(result_pd_grouped[column_name], buckets)

expr = alltypes.mutate(
val=(
alltypes[column_name]
.ntile(buckets=buckets)
.over(
ibis.window(
Expand All @@ -78,15 +78,9 @@ def prepare_expr(
)
)
)
)

df_grouped = df_alltypes.sort_values(order_by).groupby(group_by)
result_pd = _ntile(df_grouped[column_name], buckets)

expr = prepare_expr(alltypes, column_name, buckets, group_by, order_by)

# result_pd = result_pd.sort_values(order_by).reset_index(drop=True)
result_pd = result_pd.sort_values(order_by).reset_index(drop=True)
result_expr = expr.execute().sort_values(order_by).reset_index(drop=True)

pd.testing.assert_series_equal(
result_pd.astype('int64'), result_expr.val.astype('int64')
)
pd.testing.assert_series_equal(result_pd.val, result_expr.val)

0 comments on commit c2bace1

Please sign in to comment.