Skip to content

Commit

Permalink
Check null count too in sum aggregation (#17964)
Browse files Browse the repository at this point in the history
Closes #17963. We should also check the null count when doing a sum aggregation to match polars.

Also adds `size` and `null_count` directly to cudf.polars Column class.

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #17964
  • Loading branch information
Matt711 authored Feb 13, 2025
1 parent 359d936 commit ee74e2d
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 33 deletions.
22 changes: 15 additions & 7 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""A column, with some properties."""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
name: str | None = None,
):
self.obj = column
self.is_scalar = self.obj.size() == 1
self.is_scalar = self.size == 1
self.name = name
self.set_sorted(is_sorted=is_sorted, order=order, null_order=null_order)

Expand All @@ -70,9 +70,7 @@ def obj_scalar(self) -> plc.Scalar:
If the column is not length-1.
"""
if not self.is_scalar:
raise ValueError(
f"Cannot convert a column of length {self.obj.size()} to scalar"
)
raise ValueError(f"Cannot convert a column of length {self.size} to scalar")
return plc.copying.get_element(self.obj, 0)

def rename(self, name: str | None, /) -> Self:
Expand Down Expand Up @@ -242,7 +240,7 @@ def set_sorted(
-------
Self with metadata set.
"""
if self.obj.size() <= 1:
if self.size <= 1:
is_sorted = plc.types.Sorted.YES
self.is_sorted = is_sorted
self.order = order
Expand All @@ -268,7 +266,7 @@ def copy(self) -> Self:
def mask_nans(self) -> Self:
"""Return a shallow copy of self with nans masked out."""
if plc.traits.is_floating_point(self.obj.type()):
old_count = self.obj.null_count()
old_count = self.null_count
mask, new_count = plc.transform.nans_to_nulls(self.obj)
result = type(self)(self.obj.with_mask(mask, new_count))
if old_count == new_count:
Expand All @@ -288,3 +286,13 @@ def nan_count(self) -> int:
)
).as_py()
return 0

@property
def size(self) -> int:
"""Return the size of the column."""
return self.obj.size()

@property
def null_count(self) -> int:
"""Return the number of Null values in the column."""
return self.obj.null_count()
6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _count(self, column: Column) -> Column:
plc.Column.from_scalar(
plc.interop.from_arrow(
pa.scalar(
column.obj.size() - column.obj.null_count(),
column.size - column.null_count,
type=plc.interop.to_arrow(self.dtype),
),
),
Expand All @@ -181,7 +181,7 @@ def _count(self, column: Column) -> Column:
)

def _sum(self, column: Column) -> Column:
if column.obj.size() == 0:
if column.size == 0 or column.null_count == column.size:
return Column(
plc.Column.from_scalar(
plc.interop.from_arrow(
Expand Down Expand Up @@ -224,7 +224,7 @@ def _first(self, column: Column) -> Column:
return Column(plc.copying.slice(column.obj, [0, 1])[0])

def _last(self, column: Column) -> Column:
n = column.obj.size()
n = column.size
return Column(plc.copying.slice(column.obj, [n - 1, n])[0])

def do_evaluate(
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
Expand Down Expand Up @@ -98,7 +98,7 @@ def do_evaluate(
)
lop = left.obj
rop = right.obj
if left.obj.size() != right.obj.size():
if left.size != right.size:
if left.is_scalar:
lop = left.obj_scalar
elif right.is_scalar:
Expand Down
8 changes: 4 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
Expand Down Expand Up @@ -191,7 +191,7 @@ def do_evaluate(
is_any = self.name is BooleanFunction.Name.Any
agg = plc.aggregation.any() if is_any else plc.aggregation.all()
result = plc.reduce.reduce(column.obj, agg, self.dtype)
if not ignore_nulls and column.obj.null_count() > 0:
if not ignore_nulls and column.null_count > 0:
# Truth tables
# Any All
# | F U T | F U T
Expand All @@ -218,14 +218,14 @@ def do_evaluate(
(column,) = columns
return Column(
plc.unary.is_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
column.obj.null_mask(), column.null_count
)
)
elif self.name is BooleanFunction.Name.IsNotNan:
(column,) = columns
return Column(
plc.unary.is_not_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
column.obj.null_mask(), column.null_count
)
)
elif self.name is BooleanFunction.Name.IsFirstDistinct:
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/selection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
Expand Down Expand Up @@ -50,7 +50,7 @@ def do_evaluate(
n = df.num_rows
if hi >= n or lo < -n:
raise ValueError("gather indices are out of bounds")
if indices.obj.null_count():
if indices.null_count:
bounds_policy = plc.copying.OutOfBoundsPolicy.NULLIFY
obj = plc.replace.replace_nulls(
indices.obj,
Expand Down
8 changes: 4 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def do_evaluate(
(child,) = self.children
column = child.evaluate(df, context=context, mapping=mapping)
delimiter, ignore_nulls = self.options
if column.obj.null_count() > 0 and not ignore_nulls:
if column.null_count > 0 and not ignore_nulls:
return Column(plc.Column.all_null_like(column.obj, 1))
return Column(
plc.strings.combine.join_strings(
Expand All @@ -228,7 +228,7 @@ def do_evaluate(
pat = arg.evaluate(df, context=context, mapping=mapping)
pattern = (
pat.obj_scalar
if pat.is_scalar and pat.obj.size() != column.obj.size()
if pat.is_scalar and pat.size != column.size
else pat.obj
)
return Column(plc.strings.find.contains(column.obj, pattern))
Expand Down Expand Up @@ -298,7 +298,7 @@ def do_evaluate(
plc.strings.find.ends_with(
column.obj,
suffix.obj_scalar
if column.obj.size() != suffix.obj.size() and suffix.is_scalar
if column.size != suffix.size and suffix.is_scalar
else suffix.obj,
)
)
Expand All @@ -308,7 +308,7 @@ def do_evaluate(
plc.strings.find.starts_with(
column.obj,
prefix.obj_scalar
if column.obj.size() != prefix.obj.size() and prefix.is_scalar
if column.size != prefix.size and prefix.is_scalar
else prefix.obj,
)
)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def do_evaluate(
else plc.types.Order.DESCENDING
)
null_order = plc.types.NullOrder.BEFORE
if column.obj.null_count() > 0 and (n := column.obj.size()) > 1:
if column.null_count > 0 and (n := column.size) > 1:
# PERF: This invokes four stream synchronisations!
has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
has_nulls_last = not plc.copying.get_element(
Expand Down
6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
"""
if len(columns) == 0:
return []
lengths: set[int] = {column.obj.size() for column in columns}
lengths: set[int] = {column.size for column in columns}
if lengths == {1}:
if target_length is None:
return list(columns)
Expand All @@ -116,7 +116,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
)
return [
column
if column.obj.size() != 1
if column.size != 1
else Column(
plc.Column.from_scalar(column.obj_scalar, nrows),
is_sorted=plc.types.Sorted.YES,
Expand Down Expand Up @@ -820,7 +820,7 @@ def do_evaluate(
) -> DataFrame: # pragma: no cover; not exposed by polars yet
"""Evaluate and return a dataframe."""
columns = broadcast(*(e.evaluate(df) for e in exprs))
assert all(column.obj.size() == 1 for column in columns)
assert all(column.size == 1 for column in columns)
return DataFrame(columns)


Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/tests/containers/test_column.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_mask_nans(typeid):
values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype))
column = Column(plc.interop.from_arrow(values))
masked = column.mask_nans()
assert column.obj.null_count() == masked.obj.null_count()
assert column.null_count == masked.null_count


def test_mask_nans_float():
Expand Down
5 changes: 3 additions & 2 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def test_agg_singleton(op):
assert_gpu_result_equal(q)


def test_sum_empty_zero():
df = pl.LazyFrame({"a": pl.Series(values=[], dtype=pl.Int32())})
@pytest.mark.parametrize("data", [[], [None], [None, 2, 3, None]])
def test_sum_empty_zero(data):
df = pl.LazyFrame({"a": pl.Series(values=data, dtype=pl.Int32())})
q = df.select(pl.col("a").sum())
assert_gpu_result_equal(q)
6 changes: 3 additions & 3 deletions python/cudf_polars/tests/utils/test_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations
Expand Down Expand Up @@ -26,7 +26,7 @@ def test_broadcast_all_scalar(target):
expected = 1 if target is None else target

assert [c.name for c in result] == [f"col{i}" for i in range(3)]
assert all(column.obj.size() == expected for column in result)
assert all(column.size == expected for column in result)


def test_invalid_target_length():
Expand Down Expand Up @@ -73,4 +73,4 @@ def test_broadcast_with_scalars(nrows):

result = broadcast(*columns)
assert [c.name for c in result] == [f"col{i}" for i in range(3)]
assert all(column.obj.size() == nrows for column in result)
assert all(column.size == nrows for column in result)

0 comments on commit ee74e2d

Please sign in to comment.