Skip to content

Commit

Permalink
Solve bugs with benchmark probability generation
Browse files Browse the repository at this point in the history
  • Loading branch information
leo-mazzone committed Jan 7, 2025
1 parent 9173884 commit 77ea3d0
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 42 deletions.
35 changes: 13 additions & 22 deletions src/matchbox/common/factories.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,33 @@
from collections import Counter
from textwrap import dedent
from typing import Any

import numpy as np
import pyarrow as pa
import rustworkx as rx

from matchbox.common.transform import graph_results

def verify_components(table: pa.Table) -> dict:

def verify_components(all_nodes: list[Any], table: pa.Table) -> dict:
"""
Fast verification of connected components using rustworkx.
Args:
all_nodes: list of identities of inputs being matched
table: PyArrow table with 'left', 'right' columns
Returns:
dictionary containing basic component statistics
"""
graph = rx.PyGraph()

unique_nodes = set(table["left"].to_numpy()) | set(table["right"].to_numpy())
graph.add_nodes_from(range(len(unique_nodes)))

node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)}
edges = [
(node_to_idx[left], node_to_idx[right])
for left, right in zip(
table["left"].to_numpy(),
table["right"].to_numpy(),
strict=True,
)
]

graph.add_edges_from_no_data(edges)

graph, _, _ = graph_results(table, all_nodes)
components = rx.connected_components(graph)
component_sizes = Counter(len(component) for component in components)

return {
"num_components": len(components),
"total_nodes": len(unique_nodes),
"total_edges": len(edges),
"total_nodes": graph.num_nodes(),
"total_edges": graph.num_edges(),
"component_sizes": component_sizes,
"min_component_size": min(component_sizes.keys()),
"max_component_size": max(component_sizes.keys()),
Expand Down Expand Up @@ -122,7 +110,7 @@ def calculate_min_max_edges(
max_edges += _max_edges_component(
left_after_min_mod, right_after_min_mod, deduplicate
) * (max_mod - min_mod)
# components where both side have maximum nodes
# components where both side have minimum nodes
min_edges += _min_edges_component(left_div, right_div, deduplicate) * (
num_components - max_mod
)
Expand Down Expand Up @@ -173,6 +161,9 @@ def generate_dummy_probabilities(
)

mode = "dedupe" if deduplicate else "link"

if total_rows == 0:
raise ValueError("At least one edge must be generated")
if total_rows < min_possible_edges:
raise ValueError(
dedent(f"""
Expand Down Expand Up @@ -204,7 +195,7 @@ def generate_dummy_probabilities(
left_components = np.array_split(np.array(left_values), num_components)
right_components = np.array_split(np.array(right_values), num_components)
# For each left-right component pair, the right equals the left rotated by one
right_components = [np.roll(c, 1) for c in right_components]
right_components = [np.roll(c, -1) for c in right_components]

all_edges = []

Expand Down
48 changes: 37 additions & 11 deletions src/matchbox/common/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import multiprocessing
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import Callable, Generic, Hashable, TypeVar
from typing import Callable, Generic, Hashable, Iterable, TypeVar

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -92,14 +92,23 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults:
)


def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
def graph_results(
probabilities: pa.Table, all_nodes: Iterable[int] | None = None
) -> tuple[rx.PyDiGraph, np.ndarray, np.ndarray]:
"""
Takes an Arrow table of probabilities and adds a component column.
Convert probability table to graph representation.
Expects an Arrow table of column, left, right, probability.
Returns a table with an additional column, component.
Args:
probabilities: PyArrow table with 'left', 'right' columns
all_nodes: superset of node identities figuring in probabilities table.
Used to optionally add isolated nodes to the graph.
Returns:
A tuple containing:
- Rustwork directed graph
- A list mapping the 'left' probabilities column to node indices in the graph
- A list mapping the 'right' probabilities column to node indices in the graph
"""

# Create index to use in graph
unique = pc.unique(
pa.concat_arrays(
Expand All @@ -109,8 +118,9 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
]
)
)
left_indices = pc.index_in(probabilities["left"], unique)
right_indices = pc.index_in(probabilities["right"], unique)

left_indices = pc.index_in(probabilities["left"], unique).to_numpy()
right_indices = pc.index_in(probabilities["right"], unique).to_numpy()

# Create and process graph
n_nodes = len(unique)
Expand All @@ -119,9 +129,25 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges)
graph.add_nodes_from(range(n_nodes))

edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True))
if all_nodes is not None:
isolated_nodes = len(set(all_nodes) - set(unique.to_pylist()))
graph.add_nodes_from(range(isolated_nodes))

edges = tuple(zip(left_indices, right_indices, strict=True))
graph.add_edges_from_no_data(edges)

return graph, left_indices, right_indices


def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
"""
Takes an Arrow table of probabilities and adds a component column.
Expects an Arrow table of column, left, right, probability.
Returns a table with an additional column, component.
"""
graph, left_indices, _ = graph_results(probabilities)
components = rx.connected_components(graph)

# Convert components to arrays, map back to input to join, and reattach
Expand All @@ -130,10 +156,10 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
np.arange(len(components)), [len(c) for c in components]
)

node_to_component = np.zeros(len(unique), dtype=np.int64)
node_to_component = np.zeros(graph.num_nodes(), dtype=np.int64)
node_to_component[component_indices] = component_labels

edge_components = pa.array(node_to_component[left_indices.to_numpy()])
edge_components = pa.array(node_to_component[left_indices])

return probabilities.append_column("component", edge_components).sort_by(
[("component", "ascending"), ("probability", "descending")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def generate_result_tables(
indexed_parents, clusters["probability"].to_numpy()
)

set_children = set(indexed_children)
source_entries = left_ids if right_ids is None else left_ids + right_ids
set_children = set(indexed_children)
top_clusters = [c for c in final_clusters + source_entries if c not in set_children]

probabilities_table = pa.table(
Expand Down
25 changes: 19 additions & 6 deletions test/common/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def test_calculate_min_max_edges(
@pytest.mark.parametrize(
("parameters"),
[
{
"left_count": 5,
"right_count": None,
"prob_range": (0.6, 0.8),
"num_components": 3,
"total_rows": 2,
},
{
"left_count": 1000,
"right_count": None,
Expand Down Expand Up @@ -77,7 +84,13 @@ def test_calculate_min_max_edges(
"total_rows": calculate_min_max_edges(1000, 1000, 10, False)[1],
},
],
ids=["dedupe_min", "dedupe_max", "link_min", "link_max"],
ids=[
"dedupe_no_edges",
"dedupe_min",
"dedupe_max",
"link_min",
"link_max",
],
)
def test_generate_dummy_probabilities(parameters: dict[str, Any]):
len_left = parameters["left_count"]
Expand All @@ -103,19 +116,19 @@ def test_generate_dummy_probabilities(parameters: dict[str, Any]):
num_components=n_components,
total_rows=total_rows,
)
report = verify_components(table=probabilities)
report = verify_components(table=probabilities, all_nodes=rand_vals)
p_left = probabilities["left"].to_pylist()
p_right = probabilities["right"].to_pylist()

assert report["num_components"] == n_components

# Link
# Link job
if right_values:
assert set(p_left) == set(left_values)
assert set(p_right) == set(right_values)
assert set(p_left) <= set(left_values)
assert set(p_right) <= set(right_values)
# Dedupe
else:
assert set(p_left) | set(p_right) == set(left_values)
assert set(p_left) | set(p_right) <= set(left_values)

assert (
pc.max(probabilities["probability"]).as_py() / 100
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 77ea3d0

Please sign in to comment.