diff --git a/src/matchbox/common/transform.py b/src/matchbox/common/transform.py
index 1903bfbb..be120ce2 100644
--- a/src/matchbox/common/transform.py
+++ b/src/matchbox/common/transform.py
@@ -2,8 +2,7 @@
 import multiprocessing
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
-from typing import Callable, Generic, Hashable, Iterable, Literal, TypeVar
-from uuid import uuid4
+from typing import Callable, Generic, Hashable, Iterable, TypeVar
 
 import numpy as np
 import pyarrow as pa
@@ -386,48 +385,3 @@ def to_hierarchical_clusters(
     )
 
     return pa.concat_tables(results)
-
-
-def drop_duplicates(
-    table: pa.Table,
-    on: list[str] | None = None,
-    keep: Literal["first", "last"] = "first",
-) -> pa.Table:
-    """
-    Remove duplicate rows from a PyArrow table based on specified columns.
-
-    This function efficiently removes duplicate rows from a PyArrow table,
-    keeping either the first or last occurrence of each unique combination
-    of values in the specified columns.
-
-    Lifted with love from this gist:
-    https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e
-
-    An alternative:
-    https://github.com/TomScheffers/pyarrow_ops/
-    """
-    if not isinstance(table, pa.Table):
-        raise TypeError("Parameter 'table' must be a PyArrow Table")
-
-    if keep not in ["first", "last"]:
-        raise ValueError("Parameter 'keep' must be either 'first' or 'last'")
-
-    if not on:
-        on = table.column_names
-
-    # Generate a unique column name for row index
-    index_column = f"index_{uuid4().hex}"
-    index_aggregate_column = f"{index_column}_{keep}"
-
-    # Create row numbers
-    num_rows = table.num_rows
-    row_numbers = pa.array(np.arange(num_rows, dtype=np.int64))
-
-    # Append row numbers, group by specified columns, and aggregate
-    unique_indices = (
-        table.append_column(index_column, row_numbers)
-        .group_by(on, use_threads=False)
-        .aggregate([(index_column, keep)])
-    ).column(index_aggregate_column)
-
-    return pc.take(table, unique_indices, boundscheck=False)
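
Note for reviewers: the removed `drop_duplicates` helper deduplicated rows by appending a row-index column, grouping on the key columns with a `"first"`/`"last"` aggregation, and taking the surviving indices. If any caller still needs this after the removal, the same technique can be reproduced standalone. This is a minimal sketch, not part of the diff; `drop_duplicates_sketch`, the `__row_index` sentinel name, and the first-occurrence-only behaviour are illustrative assumptions (the removed code also supported `keep="last"` and a UUID-suffixed column name).

```python
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc


def drop_duplicates_sketch(table: pa.Table, on: list[str]) -> pa.Table:
    """Keep the first row per unique key combination, as the removed helper did."""
    # Hypothetical sentinel name; the removed code used a uuid4 suffix to avoid collisions.
    index_column = "__row_index"

    # Number every row, group by the key columns, and keep the first index per group.
    row_numbers = pa.array(np.arange(table.num_rows, dtype=np.int64))
    unique_indices = (
        table.append_column(index_column, row_numbers)
        .group_by(on, use_threads=False)  # single-threaded grouping keeps first-seen order
        .aggregate([(index_column, "first")])
    ).column(f"{index_column}_first")

    # Gather the surviving rows from the original table.
    return pc.take(table, unique_indices, boundscheck=False)


# Example: deduplicate on column "a", keeping first occurrences.
t = pa.table({"a": [1, 1, 2], "b": ["x", "y", "z"]})
print(drop_duplicates_sketch(t, on=["a"]).to_pydict())
# {'a': [1, 2], 'b': ['x', 'z']}
```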