diff --git a/src/matchbox/common/factories.py b/src/matchbox/common/factories.py index 6e295d32..71be3427 100644 --- a/src/matchbox/common/factories.py +++ b/src/matchbox/common/factories.py @@ -1,45 +1,33 @@ from collections import Counter from textwrap import dedent +from typing import Any import numpy as np import pyarrow as pa import rustworkx as rx +from matchbox.common.transform import graph_results -def verify_components(table: pa.Table) -> dict: + +def verify_components(all_nodes: list[Any], table: pa.Table) -> dict: """ Fast verification of connected components using rustworkx. Args: + all_nodes: list of identities of inputs being matched table: PyArrow table with 'left', 'right' columns Returns: dictionary containing basic component statistics """ - graph = rx.PyGraph() - - unique_nodes = set(table["left"].to_numpy()) | set(table["right"].to_numpy()) - graph.add_nodes_from(range(len(unique_nodes))) - - node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)} - edges = [ - (node_to_idx[left], node_to_idx[right]) - for left, right in zip( - table["left"].to_numpy(), - table["right"].to_numpy(), - strict=True, - ) - ] - - graph.add_edges_from_no_data(edges) - + graph, _, _ = graph_results(table, all_nodes) components = rx.connected_components(graph) component_sizes = Counter(len(component) for component in components) return { "num_components": len(components), - "total_nodes": len(unique_nodes), - "total_edges": len(edges), + "total_nodes": graph.num_nodes(), + "total_edges": graph.num_edges(), "component_sizes": component_sizes, "min_component_size": min(component_sizes.keys()), "max_component_size": max(component_sizes.keys()), @@ -122,7 +110,7 @@ def calculate_min_max_edges( max_edges += _max_edges_component( left_after_min_mod, right_after_min_mod, deduplicate ) * (max_mod - min_mod) - # components where both side have maximum nodes + # components where both side have minimum nodes min_edges += _min_edges_component(left_div, right_div, 
deduplicate) * ( num_components - max_mod ) @@ -173,6 +161,9 @@ def generate_dummy_probabilities( ) mode = "dedupe" if deduplicate else "link" + + if total_rows == 0: + raise ValueError("At least one edge must be generated") if total_rows < min_possible_edges: raise ValueError( dedent(f""" @@ -204,7 +195,7 @@ def generate_dummy_probabilities( left_components = np.array_split(np.array(left_values), num_components) right_components = np.array_split(np.array(right_values), num_components) # For each left-right component pair, the right equals the left rotated by one - right_components = [np.roll(c, 1) for c in right_components] + right_components = [np.roll(c, -1) for c in right_components] all_edges = [] diff --git a/src/matchbox/common/transform.py b/src/matchbox/common/transform.py index 1ebcc36b..72a10e6f 100644 --- a/src/matchbox/common/transform.py +++ b/src/matchbox/common/transform.py @@ -2,7 +2,7 @@ import multiprocessing from collections import defaultdict from concurrent.futures import ProcessPoolExecutor -from typing import Callable, Generic, Hashable, TypeVar +from typing import Callable, Generic, Hashable, Iterable, TypeVar import numpy as np import pyarrow as pa @@ -92,14 +92,23 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults: ) -def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: +def graph_results( + probabilities: pa.Table, all_nodes: Iterable[int] | None = None +) -> tuple[rx.PyDiGraph, np.ndarray, np.ndarray]: """ - Takes an Arrow table of probabilities and adds a component column. + Convert probability table to graph representation. - Expects an Arrow table of column, left, right, probability. - - Returns a table with an additional column, component. + Args: + probabilities: PyArrow table with 'left', 'right' columns + all_nodes: superset of node identities figuring in probabilities table. + Used to optionally add isolated nodes to the graph. 
+ Returns: + A tuple containing: + - rustworkx undirected graph (rx.PyGraph) + - A NumPy array mapping the 'left' probabilities column to node indices in the graph + - A NumPy array mapping the 'right' probabilities column to node indices in the graph """ + # Create index to use in graph unique = pc.unique( pa.concat_arrays( @@ -109,8 +118,9 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: ] ) ) - left_indices = pc.index_in(probabilities["left"], unique) - right_indices = pc.index_in(probabilities["right"], unique) + + left_indices = pc.index_in(probabilities["left"], unique).to_numpy() + right_indices = pc.index_in(probabilities["right"], unique).to_numpy() # Create and process graph n_nodes = len(unique) @@ -119,9 +129,25 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges) graph.add_nodes_from(range(n_nodes)) - edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True)) + if all_nodes is not None: + isolated_nodes = len(set(all_nodes) - set(unique.to_pylist())) + graph.add_nodes_from(range(isolated_nodes)) + + edges = tuple(zip(left_indices, right_indices, strict=True)) graph.add_edges_from_no_data(edges) + return graph, left_indices, right_indices + + +def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: + """ + Takes an Arrow table of probabilities and adds a component column. + + Expects an Arrow table of column, left, right, probability. + + Returns a table with an additional column, component. 
+ """ + graph, left_indices, _ = graph_results(probabilities) components = rx.connected_components(graph) # Convert components to arrays, map back to input to join, and reattach @@ -130,10 +156,10 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: np.arange(len(components)), [len(c) for c in components] ) - node_to_component = np.zeros(len(unique), dtype=np.int64) + node_to_component = np.zeros(graph.num_nodes(), dtype=np.int64) node_to_component[component_indices] = component_labels - edge_components = pa.array(node_to_component[left_indices.to_numpy()]) + edge_components = pa.array(node_to_component[left_indices]) return probabilities.append_column("component", edge_components).sort_by( [("component", "ascending"), ("probability", "descending")] diff --git a/src/matchbox/server/postgresql/benchmark/generate_tables.py b/src/matchbox/server/postgresql/benchmark/generate_tables.py index f24b594e..c18dadb1 100644 --- a/src/matchbox/server/postgresql/benchmark/generate_tables.py +++ b/src/matchbox/server/postgresql/benchmark/generate_tables.py @@ -217,8 +217,8 @@ def generate_result_tables( indexed_parents, clusters["probability"].to_numpy() ) - set_children = set(indexed_children) source_entries = left_ids if right_ids is None else left_ids + right_ids + set_children = set(indexed_children) top_clusters = [c for c in final_clusters + source_entries if c not in set_children] probabilities_table = pa.table( diff --git a/test/common/test_factories.py b/test/common/test_factories.py index 967f3851..e48aa3e3 100644 --- a/test/common/test_factories.py +++ b/test/common/test_factories.py @@ -48,6 +48,13 @@ def test_calculate_min_max_edges( @pytest.mark.parametrize( ("parameters"), [ + { + "left_count": 5, + "right_count": None, + "prob_range": (0.6, 0.8), + "num_components": 3, + "total_rows": 2, + }, { "left_count": 1000, "right_count": None, @@ -77,7 +84,13 @@ def test_calculate_min_max_edges( "total_rows": calculate_min_max_edges(1000, 1000, 
10, False)[1], }, ], - ids=["dedupe_min", "dedupe_max", "link_min", "link_max"], + ids=[ + "dedupe_no_edges", + "dedupe_min", + "dedupe_max", + "link_min", + "link_max", + ], ) def test_generate_dummy_probabilities(parameters: dict[str, Any]): len_left = parameters["left_count"] @@ -103,19 +116,19 @@ def test_generate_dummy_probabilities(parameters: dict[str, Any]): num_components=n_components, total_rows=total_rows, ) - report = verify_components(table=probabilities) + report = verify_components(table=probabilities, all_nodes=rand_vals) p_left = probabilities["left"].to_pylist() p_right = probabilities["right"].to_pylist() assert report["num_components"] == n_components - # Link + # Link job if right_values: - assert set(p_left) == set(left_values) - assert set(p_right) == set(right_values) + assert set(p_left) <= set(left_values) + assert set(p_right) <= set(right_values) # Dedupe else: - assert set(p_left) | set(p_right) == set(left_values) + assert set(p_left) | set(p_right) <= set(left_values) assert ( pc.max(probabilities["probability"]).as_py() / 100 diff --git a/uv.lock b/uv.lock index 6f2472e6..a3c3f534 100644 --- a/uv.lock +++ b/uv.lock @@ -945,7 +945,6 @@ dependencies = [ { name = "connectorx" }, { name = "duckdb" }, { name = "httpx" }, - { name = "ipywidgets" }, { name = "matplotlib" }, { name = "pandas" }, { name = "psycopg2" }, @@ -970,6 +969,7 @@ server = [ dev = [ { name = "docker" }, { name = "ipykernel" }, + { name = "ipywidgets" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -989,7 +989,6 @@ requires-dist = [ { name = "duckdb", specifier = ">=1.1.1" }, { name = "fastapi", extras = ["standard"], marker = "extra == 'server'", specifier = ">=0.115.0,<0.116.0" }, { name = "httpx", specifier = ">=0.28.0" }, - { name = "ipywidgets", specifier = ">=8.1.5" }, { name = "matplotlib", specifier = ">=3.9.2" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "pg-bulk-ingest", marker = "extra == 'server'", specifier = 
">=0.0.54" }, @@ -1009,6 +1008,7 @@ requires-dist = [ dev = [ { name = "docker", specifier = ">=7.1.0" }, { name = "ipykernel", specifier = ">=6.29.5" }, + { name = "ipywidgets", specifier = ">=8.1.5" }, { name = "pre-commit", specifier = ">=3.8.0" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-cov", specifier = ">=5.0.0" },