Skip to content

Commit

Permalink
Removed extra drop duplicates function
Browse files Browse the repository at this point in the history
  • Loading branch information
wpfl-dbt committed Jan 7, 2025
1 parent d31b3e0 commit 8394b35
Showing 1 changed file with 1 addition and 47 deletions.
48 changes: 1 addition & 47 deletions src/matchbox/common/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import multiprocessing
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import Callable, Generic, Hashable, Iterable, Literal, TypeVar
from uuid import uuid4
from typing import Callable, Generic, Hashable, Iterable, TypeVar

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -386,48 +385,3 @@ def to_hierarchical_clusters(
)

return pa.concat_tables(results)


def drop_duplicates(
    table: pa.Table,
    on: list[str] | None = None,
    keep: Literal["first", "last"] = "first",
) -> pa.Table:
    """Return ``table`` with duplicate rows removed.

    Rows are considered duplicates when they share the same combination of
    values in the ``on`` columns. For every such combination exactly one row
    is retained: the first or the last occurrence, as selected by ``keep``.

    Adapted from this gist:
    https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e

    An alternative implementation:
    https://github.com/TomScheffers/pyarrow_ops/

    Args:
        table: The PyArrow table to de-duplicate.
        on: Names of the columns that define uniqueness. ``None`` or an
            empty list means every column of ``table`` is used.
        keep: Which occurrence within each duplicate group is retained,
            ``"first"`` or ``"last"``.

    Returns:
        A table holding one row of ``table`` per distinct key combination,
        with the same columns as the input.

    Raises:
        TypeError: If ``table`` is not a ``pa.Table``.
        ValueError: If ``keep`` is neither ``"first"`` nor ``"last"``.
    """
    if not isinstance(table, pa.Table):
        raise TypeError("Parameter 'table' must be a PyArrow Table")
    if keep not in ("first", "last"):
        raise ValueError("Parameter 'keep' must be either 'first' or 'last'")

    # Fall back to all columns when no subset was requested.
    key_columns = on or table.column_names

    # A random suffix keeps the helper column from clashing with a real one.
    position_name = f"index_{uuid4().hex}"
    positions = pa.array(np.arange(table.num_rows, dtype=np.int64))
    indexed = table.append_column(position_name, positions)

    # Threading is disabled for the grouping step; presumably because the
    # "first"/"last" aggregations depend on row order -- confirm against the
    # PyArrow documentation.
    grouped = indexed.group_by(key_columns, use_threads=False)
    survivors = grouped.aggregate([(position_name, keep)])

    # The aggregate output column is read back as "<column>_<aggregation>".
    kept_positions = survivors.column(f"{position_name}_{keep}")

    # Every position was generated from table.num_rows above, so the bounds
    # check is skipped.
    return pc.take(table, kept_positions, boundscheck=False)

0 comments on commit 8394b35

Please sign in to comment.