diff --git a/src/matchbox/client/models/linkers/splinklinker.py b/src/matchbox/client/models/linkers/splinklinker.py index 0996b2b0..798ed468 100644 --- a/src/matchbox/client/models/linkers/splinklinker.py +++ b/src/matchbox/client/models/linkers/splinklinker.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, List, Optional, Type +import pyarrow as pa from pandas import DataFrame from pydantic import BaseModel, ConfigDict, Field, model_validator from splink import DuckDBAPI, SettingsCreator @@ -227,21 +228,19 @@ def link(self, left: DataFrame = None, right: DataFrame = None) -> DataFrame: threshold_match_probability=self.settings.threshold ) - return ( - res.as_pandas_dataframe() - .convert_dtypes(dtype_backend="pyarrow") - .rename( - columns={ - f"{self.settings.left_id}_l": "left_id", - f"{self.settings.right_id}_r": "right_id", - "match_probability": "probability", - } - ) - .assign( - left_id=lambda df: df.left_id.apply(self._id_dtype_l), - right_id=lambda df: df.right_id.apply(self._id_dtype_r), - ) - .filter(["left_id", "right_id", "probability"]) - .drop_duplicates() - .reset_index(drop=True) + df = res.as_pandas_dataframe().drop_duplicates() + + return pa.table( + [ + pa.array( + df[f"{self.settings.left_id}_l"].apply(self._id_dtype_l), + type=pa.uint64(), + ), + pa.array( + df[f"{self.settings.right_id}_r"].apply(self._id_dtype_r), + type=pa.uint64(), + ), + pa.array(df["match_probability"], type=pa.float32()), + ], + names=["left_id", "right_id", "probability"], ) diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py index da05603f..961db4aa 100644 --- a/src/matchbox/client/models/models.py +++ b/src/matchbox/client/models/models.py @@ -6,14 +6,11 @@ from matchbox.client.models.dedupers.base import Deduper from matchbox.client.models.linkers.base import Linker from matchbox.client.results import ( - ClusterResults, ModelMetadata, ModelType, - ProbabilityResults, Results, ) from matchbox.common.exceptions import MatchboxResolutionError -from matchbox.common.transform import to_clusters from matchbox.server import MatchboxDBAdapter, inject_backend from matchbox.server.base import MatchboxModelAdapter @@ -74,18 +71,6 @@ def insert_model(self) -> None: ) self._connect() - @property - @ensure_connection - def probabilities(self) -> ProbabilityResults: - """Retrieve probabilities associated with the model from the database.""" - return self._model.probabilities - - @property - @ensure_connection - def clusters(self) -> ClusterResults: - """Retrieve clusters associated with the model from the database.""" - return self._model.clusters - @property @ensure_connection def results(self) -> Results: @@ -128,8 +113,8 @@ def ancestors_cache(self, ancestors_cache: dict[str, float]) -> None: """Set the ancestors cache of the model.""" self._model.ancestors_cache = ancestors_cache - def calculate_probabilities(self) -> ProbabilityResults: - """Calculate probabilities for the model.""" + def run(self) -> Results: + """Execute the model pipeline and return results.""" if self.metadata.type == ModelType.LINKER: if self.right_data is None: raise MatchboxResolutionError("Right dataset required for linking") @@ -140,23 +125,12 @@ def calculate_probabilities(self) -> ProbabilityResults: else: results = self.model_instance.dedupe(data=self.left_data) - return ProbabilityResults( - dataframe=results, + return Results( + probabilities=results, model=self, metadata=self.metadata, ) - def calculate_clusters(self, probabilities: ProbabilityResults) -> 
ClusterResults: - """Calculate clusters for the model based on probabilities.""" - return to_clusters(results=probabilities) - - def run(self) -> Results: - """Execute the model pipeline and return results.""" - probabilities = self.calculate_probabilities() - clusters = self.calculate_clusters(probabilities) - - return Results(model=self, probabilities=probabilities, clusters=clusters) - @inject_backend def make_model( diff --git a/src/matchbox/client/results.py b/src/matchbox/client/results.py index bdccb517..f178b22f 100644 --- a/src/matchbox/client/results.py +++ b/src/matchbox/client/results.py @@ -1,16 +1,16 @@ import logging -from abc import ABC, abstractmethod from enum import StrEnum -from typing import TYPE_CHECKING, Any, Hashable, TypeVar +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Hashable, ParamSpec, TypeVar -import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc from dotenv import find_dotenv, load_dotenv -from pandas import DataFrame -from pydantic import BaseModel, ConfigDict, field_validator, model_validator -from sqlalchemy import Table +from pandas import ArrowDtype, DataFrame +from pydantic import BaseModel, ConfigDict, field_validator -from matchbox.common.db import Cluster, Probability -from matchbox.common.hash import columns_to_value_ordered_hash +from matchbox.common.hash import IntMap +from matchbox.common.transform import to_clusters from matchbox.server.base import MatchboxDBAdapter, inject_backend if TYPE_CHECKING: @@ -19,6 +19,8 @@ Model = Any T = TypeVar("T", bound=Hashable) +P = ParamSpec("P") +R = TypeVar("R") logic_logger = logging.getLogger("mb_logic") @@ -43,263 +45,190 @@ class ModelMetadata(BaseModel): right_source: str | None = None # Only used for linker models -class ResultsBaseDataclass(BaseModel, ABC): - """Base class for results dataclasses. +def calculate_clusters(func: Callable[P, R]) -> Callable[P, R]: + """Decorator to calculate clusters if it hasn't been already.""" - Model is required during construction and calculation, but not when loading - from storage. - """ - - model_config = ConfigDict(arbitrary_types_allowed=True) - - dataframe: DataFrame - model: Model | None = None - metadata: ModelMetadata - - _expected_fields: list[str] - - @model_validator(mode="after") - def _check_dataframe(self) -> Table: - """Verifies the table contains the expected fields.""" - table_fields = set(self.dataframe.columns) - expected_fields = set(self._expected_fields) - - if table_fields != expected_fields: - raise ValueError(f"Expected {expected_fields}. \n" f"Found {table_fields}.") - - return self - - @abstractmethod - def inspect_with_source(self) -> DataFrame: - """Enriches the results with the source data.""" - return + @wraps(func) + def wrapper(self: "Results", *args: P.args, **kwargs: P.kwargs) -> R: + if not self.clusters: + im = IntMap() + self.clusters = to_clusters( + results=self.probabilities, dtype=pa.int64, hash_func=im.index + ) + return func(self, *args, **kwargs) - @abstractmethod - def to_df(self) -> DataFrame: - """Returns the results as a DataFrame.""" - return + return wrapper - @abstractmethod - def to_records(self) -> list[Probability | Cluster]: - """Returns the results as a list of records suitable for insertion.""" - return +class Results(BaseModel): + """Results of a model run. -class ProbabilityResults(ResultsBaseDataclass): - """Probabilistic matches produced by linkers and dedupers. + Contains: - There are pairs of records/clusters with a probability of being a match. 
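Reviewer note: the new `calculate_clusters` decorator lazily fills `Results.clusters` the first time a cluster-consuming method runs. A minimal sketch of the same memoising pattern, with no pydantic or Matchbox imports — `LazyResults`, `ensure_built` and `expensive_build` are made-up stand-ins, not names from this PR:

```python
from functools import wraps
from typing import Any, Callable


def ensure_built(func: Callable[..., Any]) -> Callable[..., Any]:
    """Compute `self.clusters` once, on first use, then reuse it."""

    @wraps(func)
    def wrapper(self: "LazyResults", *args: Any, **kwargs: Any) -> Any:
        if self.clusters is None:
            self.clusters = self.expensive_build()  # only runs once
        return func(self, *args, **kwargs)

    return wrapper


class LazyResults:
    def __init__(self, probabilities: list):
        self.probabilities = probabilities
        self.clusters = None

    def expensive_build(self) -> dict:
        # Stand-in for to_clusters(): group every pair under a parent ID
        return {
            i: {left, right}
            for i, (left, right, _) in enumerate(self.probabilities)
        }

    @ensure_built
    def cluster_count(self) -> int:
        return len(self.clusters)


results = LazyResults([(1, 2, 0.9), (3, 4, 0.8)])
print(results.cluster_count())  # clusters built here, on demand -> 2
print(results.cluster_count())  # cached: expensive_build is not re-run
```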
- The hash is the hash of the sorted left and right ids. + * The probabilities of each pair being a match + * (Optional) The clusters of connected components at each threshold - _expected_fields enforces the shape of the dataframe. + Model is required during construction and calculation, but not when loading + from storage. - Args: - dataframe (DataFrame): the DataFrame holding the results - model (Model): x + Allows users to easily interrogate the outputs of models, explore decisions on + choosing thresholds for clustering, and upload the results to Matchbox. """ - _expected_fields: list[str] = [ - "id", - "left_id", - "right_id", - "probability", - ] - - @field_validator("dataframe", mode="before") - @classmethod - def results_to_hash(cls, dataframe: pd.DataFrame) -> pd.DataFrame: - """Adds an ID column to the dataframe if it doesn't already exist. - - * Reattaches hashes from the backend - * Uses them to create the new ID column - """ - id_exists = "id" in dataframe.columns - l_is_int = pd.api.types.is_integer_dtype(dataframe["left_id"]) - r_is_int = pd.api.types.is_integer_dtype(dataframe["right_id"]) - - if id_exists and l_is_int and r_is_int: - return dataframe - - @inject_backend - def _make_id_hasher(backend: MatchboxDBAdapter): - """Closure for converting int columns to hash using a lookup.""" - lookup: dict[int, bytes] = {} - - def _hash_column(df: pd.DataFrame, column_name: str) -> None: - hashed_column = f"{column_name}_hashed" - unique_ids = df[column_name].unique().tolist() - - lookup.update(backend.cluster_id_to_hash(ids=unique_ids)) + model_config = ConfigDict(arbitrary_types_allowed=True) - df[hashed_column] = ( - df[column_name].map(lookup).astype("binary[pyarrow]") - ) - df.drop(columns=[column_name], inplace=True) - df.rename(columns={hashed_column: column_name}, inplace=True) + probabilities: pa.Table + clusters: pa.Table | None = None + model: Model | None = None + metadata: ModelMetadata - return _hash_column + @field_validator("probabilities", mode="before") + @classmethod + def check_probabilities(cls, value: pa.Table | DataFrame) -> pa.Table: + """Verifies the probabilities table contains the expected fields.""" + if isinstance(value, DataFrame): + value = pa.Table.from_pandas(value) - hash_column = _make_id_hasher() + if not isinstance(value, pa.Table): + raise ValueError("Expected a pandas DataFrame or pyarrow Table.") - # Update lookup with left_id, then convert to hash - if l_is_int: - hash_column(df=dataframe, column_name="left_id") + table_fields = set(value.column_names) + expected_fields = {"left_id", "right_id", "probability"} + optional_fields = {"id"} - # Update lookup with right_id, then convert to hash - if r_is_int: - hash_column(df=dataframe, column_name="right_id") + if table_fields - optional_fields != expected_fields: + raise ValueError(f"Expected {expected_fields}. 
\n" f"Found {table_fields}.") - # Create ID column if it doesn't exist and hash the values - if not id_exists: - dataframe[["left_id", "right_id"]] = dataframe[ - ["left_id", "right_id"] - ].astype("binary[pyarrow]") - dataframe["id"] = columns_to_value_ordered_hash( - data=dataframe, columns=["left_id", "right_id"] + # If a process produces floats, we multiply by 100 and coerce to uint8 + if pa.types.is_floating(value["probability"].type): + probability_uint8 = pc.cast( + pc.multiply(value["probability"], 100), + options=pc.CastOptions( + target_type=pa.uint8(), allow_float_truncate=True + ), ) - dataframe["id"] = dataframe["id"].astype("binary[pyarrow]") - return dataframe + if pc.max(probability_uint8).as_py() > 100: + p_max = pc.max(value["probability"]).as_py() + p_min = pc.min(value["probability"]).as_py() + raise ValueError(f"Probability range misconfigured: [{p_min}, {p_max}]") - def inspect_with_source( - self, left_data: DataFrame, left_key: str, right_data: DataFrame, right_key: str - ) -> DataFrame: - """Enriches the results with the source data.""" - df = ( - self.to_df() - .filter(["left_id", "right_id", "probability"]) - .assign( - left_id=lambda d: d.left_id.apply(str), - right_id=lambda d: d.right_id.apply(str), + value = value.set_column( + i=value.schema.get_field_index("probability"), + field_="probability", + column=probability_uint8, ) - .merge( - left_data.assign(**{left_key: lambda d: d[left_key].apply(str)}), - how="left", - left_on="left_id", - right_on=left_key, + + if "id" in table_fields: + return value.cast( + pa.schema( + [ + ("id", pa.uint64()), + ("left_id", pa.uint64()), + ("right_id", pa.uint64()), + ("probability", pa.uint8()), + ] + ) ) - .drop(columns=[left_key]) - .merge( - right_data.assign(**{right_key: lambda d: d[right_key].apply(str)}), - how="left", - left_on="right_id", - right_on=right_key, + + return value.cast( + pa.schema( + [ + ("left_id", pa.uint64()), + ("right_id", pa.uint64()), + ("probability", pa.uint8()), + ] ) - .drop(columns=[right_key]) ) - return df - - def to_df(self) -> DataFrame: - """Returns the results as a DataFrame.""" - df = self.dataframe.assign( - left=self.model.metadata.left_source, - right=self.model.metadata.right_source, - model=self.metadata.name, - ).convert_dtypes(dtype_backend="pyarrow")[ - ["model", "left", "left_id", "right", "right_id", "probability"] - ] - - return df - - @inject_backend - def to_records(self, backend: MatchboxDBAdapter | None) -> set[Probability]: - """Returns the results as a list of records suitable for insertion. - - If given a backend, will validate the records against the database. - """ - # Optional validation - if backend: - backend.validate_hashes(hashes=self.dataframe.left_id.unique().tolist()) - backend.validate_hashes(hashes=self.dataframe.right_id.unique().tolist()) - - return { - Probability(hash=row[0], left=row[1], right=row[2], probability=row[3]) - for row in self.dataframe[ - ["id", "left_id", "right_id", "probability"] - ].to_numpy() - } - - -class ClusterResults(ResultsBaseDataclass): - """Cluster data produced by using to_clusters on ProbabilityResults. - - This is the connected components of the probabilistic matches at every - threshold of probabilitity. The parent is the hash of the sorted children. - - _expected_fields enforces the shape of the dataframe. 
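Reviewer note: the `check_probabilities` validator above converts float match probabilities into integer percentages so they fit the new `uint8` column (and the `SMALLINT` storage later in this PR). A standalone sketch of that conversion path using only pyarrow; the toy table and its values are illustrative:

```python
import pyarrow as pa
import pyarrow.compute as pc

probs = pa.table(
    {
        "left_id": pa.array([1, 2, 3], type=pa.uint64()),
        "right_id": pa.array([4, 5, 6], type=pa.uint64()),
        "probability": pa.array([0.81, 0.95, 1.0], type=pa.float32()),
    }
)

# Scale to 0-100 and truncate to uint8, mirroring the validator
as_uint8 = pc.cast(
    pc.multiply(probs["probability"], 100),
    options=pc.CastOptions(target_type=pa.uint8(), allow_float_truncate=True),
)

if pc.max(as_uint8).as_py() > 100:
    raise ValueError("probabilities must be in [0, 1] before scaling")

probs = probs.set_column(
    probs.schema.get_field_index("probability"), "probability", as_uint8
)
print(probs.schema)  # probability: uint8
```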
- - Args: - dataframe (DataFrame): the DataFrame holding the results - model (Model): x - """ - - _expected_fields: list[str] = ["parent", "child", "threshold"] - - def inspect_with_source( + def _merge_with_source_data( self, + base_df: DataFrame, + base_df_cols: list[str], left_data: DataFrame, left_key: str, right_data: DataFrame, right_key: str, + left_merge_col: str, + right_merge_col: str, ) -> DataFrame: - """Enriches the results with the source data.""" + """Helper method to merge results with source data frames.""" return ( - self.to_df() - .filter(["parent", "child", "probability"]) - .map(str) + base_df.filter(base_df_cols) .merge( - left_data.assign(**{left_key: lambda d: d[left_key].apply(str)}), + left_data, how="left", - left_on="child", + left_on=left_merge_col, right_on=left_key, ) .drop(columns=[left_key]) .merge( - right_data.assign(**{right_key: lambda d: d[right_key].apply(str)}), + right_data, how="left", - left_on="child", + left_on=right_merge_col, right_on=right_key, ) .drop(columns=[right_key]) ) - def to_df(self) -> DataFrame: - """Returns the results as a DataFrame.""" - return self.dataframe.copy().convert_dtypes(dtype_backend="pyarrow") - - def to_records(self) -> set[Cluster]: - """Returns the results as a list of records suitable for insertion.""" - # Preprocess the dataframe - pre_prep_df = ( - self.dataframe[["parent", "child", "threshold"]] - .groupby(["parent", "threshold"], as_index=False)["child"] - .agg(list) - .copy() + def probabilities_to_pandas(self) -> DataFrame: + """Returns the probability results as a DataFrame.""" + df = ( + self.probabilities.to_pandas(types_mapper=ArrowDtype) + .assign( + left=self.model.metadata.left_source, + right=self.model.metadata.right_source, + model=self.metadata.name, + ) + .convert_dtypes(dtype_backend="pyarrow")[ + ["model", "left", "left_id", "right", "right_id", "probability"] + ] ) - return { - Cluster(parent=row[0], children=row[1], threshold=row[2]) - for row in pre_prep_df.to_numpy() - } - - -class Results(BaseModel): - """A container for the results of a model run. + return df - Contains all the information any backend will need to store the results. 
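Reviewer note: `_merge_with_source_data` is essentially two left joins keyed on the pair IDs. A toy pandas illustration of the frame `inspect_probabilities` hands back — the column names (`company_name`, `id`) are invented for the example:

```python
from pandas import DataFrame

probabilities = DataFrame(
    {"left_id": [1, 2], "right_id": [10, 20], "probability": [90, 75]}
)
left_data = DataFrame({"id": [1, 2], "company_name": ["Acme Ltd", "Globex"]})
right_data = DataFrame({"id": [10, 20], "company_name": ["ACME LIMITED", "Globex plc"]})

enriched = (
    probabilities.merge(left_data, how="left", left_on="left_id", right_on="id")
    .drop(columns=["id"])
    .merge(right_data, how="left", left_on="right_id", right_on="id")
    .drop(columns=["id"])
)
print(enriched)  # one row per pair, with *_x / *_y columns from each source
```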
- """ + def inspect_probabilities( + self, left_data: DataFrame, left_key: str, right_data: DataFrame, right_key: str + ) -> DataFrame: + """Enriches the probability results with the source data.""" + return self._merge_with_source_data( + base_df=self.probabilities_to_pandas(), + base_df_cols=["left_id", "right_id", "probability"], + left_data=left_data, + left_key=left_key, + right_data=right_data, + right_key=right_key, + left_merge_col="left_id", + right_merge_col="right_id", + ) - model_config = ConfigDict(arbitrary_types_allowed=True) + @calculate_clusters + def clusters_to_pandas(self) -> DataFrame: + """Returns the cluster results as a DataFrame.""" + return self.clusters.to_pandas(types_mapper=ArrowDtype) - probabilities: ProbabilityResults - clusters: ClusterResults + @calculate_clusters + def inspect_clusters( + self, + left_data: DataFrame, + left_key: str, + right_data: DataFrame, + right_key: str, + ) -> DataFrame: + """Enriches the cluster results with the source data.""" + return self._merge_with_source_data( + base_df=self.clusters_to_pandas(), + base_df_cols=["parent", "child", "probability"], + left_data=left_data, + left_key=left_key, + right_data=right_data, + right_key=right_key, + left_merge_col="child", + right_merge_col="child", + ) @inject_backend def to_matchbox(self, backend: MatchboxDBAdapter) -> None: """Writes the results to the Matchbox database.""" - if self.probabilities.model != self.clusters.model: - raise ValueError("Probabilities and clusters must be from the same model.") - - self.clusters.model.insert_model() - self.clusters.model.results = self + self.model.insert_model() + self.model.results = self diff --git a/src/matchbox/common/hash.py b/src/matchbox/common/hash.py index 4101d37b..841b5f9d 100644 --- a/src/matchbox/common/hash.py +++ b/src/matchbox/common/hash.py @@ -70,13 +70,13 @@ def hash_data(data: str) -> bytes: return HASH_FUNC(prep_for_hash(data)).digest() -def list_to_value_ordered_hash(list_: list[T]) -> bytes: - """Returns a single hash of a list ordered by its values. +def hash_values(*values: tuple[T, ...]) -> bytes: + """Returns a single hash of a tuple of items ordered by its values. List must be sorted as the different orders of value must produce the same hash. """ try: - sorted_vals = sorted(list_) + sorted_vals = sorted(values) except TypeError as e: raise TypeError("Can only order lists or columns of the same datatype.") from e @@ -104,15 +104,18 @@ def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series hashed_records = [] for record in bytes_records: - hashed_vals = list_to_value_ordered_hash(record.values()) + hashed_vals = hash_values(*record.values()) hashed_records.append(hashed_vals) return Series(hashed_records) class IntMap: - """ - A data structure taking unordered sets of integers, and mapping them a to an ID that + """A data structure to map integers without collisions within a dedicated space. + + A stand-in for hashing integers within pa.int64. + + Takes unordered sets of integers, and maps them a to an ID that 1) is a negative integer; 2) does not collide with other IDs generated by other instances of this class, as long as they are initialised with a different salt. @@ -122,9 +125,12 @@ class IntMap: (which will be negative). The salt allows to work with a parallel execution model, where each worker maintains their separate ID space, as long as each worker operates on disjoint subsets of positive integers. 
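Reviewer note: `hash_values` keeps the old order-invariance guarantee from `list_to_value_ordered_hash`: the same set of children always hashes to the same parent. A self-contained sketch of that property — it uses `hashlib.sha256` directly and skips the byte preparation the real `HASH_FUNC` pipeline does, so treat it as an approximation:

```python
import hashlib


def hash_values_sketch(*values: bytes) -> bytes:
    """Order-invariant digest: sort first, then hash the concatenation."""
    try:
        sorted_vals = sorted(values)
    except TypeError as e:
        raise TypeError("Can only order values of the same datatype.") from e
    digest = hashlib.sha256()
    for val in sorted_vals:
        digest.update(val)
    return digest.digest()


a = hash_values_sketch(b"left", b"right")
b = hash_values_sketch(b"right", b"left")
assert a == b  # parent hash is independent of child order
```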
+ + Args: + salt (optional): A positive integer to salt the Cantor pairing function """ - def __init__(self, salt: int): + def __init__(self, salt: int = 42): self.mapping: dict[frozenset[int], int] = {} if salt < 0: raise ValueError("The salt must be a positive integer") diff --git a/src/matchbox/common/logging.py b/src/matchbox/common/logging.py new file mode 100644 index 00000000..724bb9a6 --- /dev/null +++ b/src/matchbox/common/logging.py @@ -0,0 +1,30 @@ +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) + + +def get_console(): + """Get the console instance.""" + return Console() + + +def build_progress_bar(console: Console | None = None) -> Progress: + """Create a progress bar.""" + if console is None: + console = get_console() + + return Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) diff --git a/src/matchbox/common/transform.py b/src/matchbox/common/transform.py index 72a10e6f..fc32a7ca 100644 --- a/src/matchbox/common/transform.py +++ b/src/matchbox/common/transform.py @@ -9,22 +9,9 @@ import pyarrow.compute as pc import rustworkx as rx from dotenv import find_dotenv, load_dotenv -from pandas import DataFrame -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) - -from matchbox.client.results import ClusterResults, ProbabilityResults -from matchbox.common.hash import ( - IntMap, - list_to_value_ordered_hash, -) + +from matchbox.common.hash import hash_values +from matchbox.common.logging import build_progress_bar T = TypeVar("T", bound=Hashable) @@ -34,37 +21,48 @@ load_dotenv(dotenv_path) -def to_clusters(results: ProbabilityResults) -> ClusterResults: +def to_clusters( + results: pa.Table, + dtype: pa.DataType = pa.binary, + hash_func: Callable[[*tuple[T, ...]], T] = hash_values, +) -> pa.Table: """ Converts probabilities into a list of connected components formed at each threshold. + Args: + results: Arrow table with columns ['left_id', 'right_id', 'probability'] + dtype: Arrow data type for the parent and child columns + hash_func: Function to hash the parent and child values + Returns: - ClusterResults sorted by threshold descending. + Arrow table of parent, child, threshold, sorted by probability descending. 
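Reviewer note: the `to_clusters` body just below sweeps thresholds in descending order, adds that threshold's edges to a rustworkx graph, and re-reads the connected components. The same sweep in isolation, with toy integer IDs and plain lists printed instead of hashed parents:

```python
import rustworkx as rx

edges = [  # (left_id, right_id, probability as an integer percentage)
    (1, 2, 95),
    (2, 3, 90),
    (4, 5, 90),
]

graph = rx.PyGraph()
node_index: dict = {}

for threshold in sorted({p for *_, p in edges}, reverse=True):
    for left, right, prob in edges:
        if prob != threshold:
            continue
        for value in (left, right):
            if value not in node_index:
                node_index[value] = graph.add_node(value)
        graph.add_edge(node_index[left], node_index[right], None)

    # Components are re-read once every edge at this threshold is in place
    for component in rx.connected_components(graph):
        members = sorted(graph.get_node_data(n) for n in component)
        print(threshold, members)
# e.g. 95 [1, 2] then, at 90, [1, 2, 3] and [4, 5]
```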
""" G = rx.PyGraph() added: dict[bytes, int] = {} components: dict[str, list] = {"parent": [], "child": [], "threshold": []} - # Sort probabilities descending and group by probability - edges_df = ( - results.dataframe.sort_values("probability", ascending=False) - .filter(["left_id", "right_id", "probability"]) - .astype( - {"left_id": "large_binary[pyarrow]", "right_id": "large_binary[pyarrow]"} - ) - ) + # Sort probabilities descending and select relevant columns + results = results.sort_by([("probability", "descending")]) + results = results.select(["left_id", "right_id", "probability"]) # Get unique probability thresholds, sorted - thresholds = sorted(edges_df["probability"].unique()) + thresholds = pc.unique(results.column("probability")).sort(order="descending") # Process edges grouped by probability threshold for prob in thresholds: - threshold_edges = edges_df[edges_df["probability"] == prob] + threshold_edges = results.filter(pc.equal(results.column("probability"), prob)) + # Get state before adding this batch of edges old_components = {frozenset(comp) for comp in rx.connected_components(G)} # Add all nodes and edges at this probability threshold - edge_values = threshold_edges[["left_id", "right_id"]].values + edge_values = list( + zip( + threshold_edges.column("left_id").to_pylist(), + threshold_edges.column("right_id").to_pylist(), + strict=True, + ) + ) for left, right in edge_values: for hash_val in (left, right): if hash_val not in added: @@ -78,17 +76,18 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults: # For each changed component, add ALL members at current threshold for comp in changed_components: - children = sorted([G.get_node_data(n) for n in comp]) - parent = list_to_value_ordered_hash(children) + children = [G.get_node_data(n) for n in comp] + parent = hash_func(*children) components["parent"].extend([parent] * len(children)) components["child"].extend(children) components["threshold"].extend([prob] * len(children)) - return ClusterResults( - dataframe=DataFrame(components).convert_dtypes(dtype_backend="pyarrow"), - model=results.model, - metadata=results.metadata, + return pa.Table.from_pydict( + components, + schema=pa.schema( + [("parent", dtype()), ("child", dtype()), ("threshold", pa.uint8())] + ), ) @@ -108,7 +107,6 @@ def graph_results( - A list mapping the 'left' probabilities column to node indices in the graph - A list mapping the 'right' probabilities column to node indices in the graph """ - # Create index to use in graph unique = pc.unique( pa.concat_arrays( @@ -147,6 +145,11 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: Returns a table with an additional column, component. """ + # Handle empty probabilities + if len(probabilities) == 0: + empty_components = pa.array([], type=pa.int64()) + return probabilities.append_column("component", empty_components) + graph, left_indices, _ = graph_results(probabilities) components = rx.connected_components(graph) @@ -212,7 +215,9 @@ def get_components(self) -> list[set[T]]: def component_to_hierarchy( - table: pa.Table, salt: int, dtype: pa.DataType = pa.uint64 + table: pa.Table, + dtype: pa.DataType = pa.binary, + hash_func: Callable[[*tuple[T, ...]], T] = hash_values, ) -> pa.Table: """ Convert pairwise probabilities into a hierarchical representation. 
@@ -221,6 +226,8 @@ def component_to_hierarchy( Args: table: Arrow Table with columns ['left', 'right', 'probability'] + dtype: Arrow data type for the parent and child columns + hash_func: Function to hash the parent and child values Returns: Arrow Table with columns ['parent', 'child', 'probability'] @@ -231,9 +238,9 @@ def component_to_hierarchy( probs = ascending_probs[::-1] djs = DisjointSet[int]() # implements connected components - im = IntMap(salt=salt) # generates IDs for new clusters current_roots: dict[int, set[int]] = defaultdict(set) # tracks ultimate parents hierarchy: list[tuple[int, int, float]] = [] # the output of this function + seen_components: set[frozenset[int]] = set() # track previously seen component sets for threshold in probs: # Get current probability rows @@ -247,7 +254,7 @@ def component_to_hierarchy( strict=True, ): djs.union(left, right) - parent = im.index(left, right) + parent = hash_func(left, right) hierarchy.extend([(parent, left, threshold), (parent, right, threshold)]) current_roots[left].add(parent) current_roots[right].add(parent) @@ -256,10 +263,14 @@ def component_to_hierarchy( if len(children) <= 2: continue # Skip pairs already handled by pairwise probabilities - if im.has_mapping(*children): - continue # Skip unchanged components from previous thresholds + # Skip if we've seen this exact component before + frozen_children = frozenset(children) + if frozen_children in seen_components: + continue + + seen_components.add(frozen_children) - parent = im.index(*children) + parent = hash_func(*children) prev_roots: set[int] = set() for child in children: prev_roots.update(current_roots[child]) @@ -281,7 +292,8 @@ def component_to_hierarchy( def to_hierarchical_clusters( probabilities: pa.Table, proc_func: Callable[[pa.Table, pa.DataType], pa.Table] = component_to_hierarchy, - dtype: pa.DataType = pa.int64, + hash_func: Callable[[*tuple[T, ...]], T] = hash_values, + dtype: pa.DataType = pa.binary, timeout: int = 300, ) -> pa.Table: """ @@ -291,20 +303,22 @@ def to_hierarchical_clusters( probabilities: Arrow table with columns ['component', 'left', 'right', 'probability'] proc_func: Function to process each component + hash_func: Function to hash the parent and child values + dtype: Arrow data type for the parent and child columns timeout: Maximum seconds to wait for each component to process Returns: Arrow table with columns ['parent', 'child', 'probability'] """ - console = Console() - progress_columns = [ - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeElapsedColumn(), - TimeRemainingColumn(), - ] + # Handle empty probabilities + if len(probabilities) == 0: + return pa.table( + { + "parent": pa.array([], type=dtype()), + "child": pa.array([], type=dtype()), + "probability": pa.array([], type=pa.uint8()), + } + ) probabilities = probabilities.sort_by([("component", "ascending")]) components = pc.unique(probabilities["component"]) @@ -318,7 +332,7 @@ def to_hierarchical_clusters( indices = [] start_idx = 0 - with Progress(*progress_columns, console=console) as progress: + with build_progress_bar() as progress: split_task = progress.add_task( "[cyan]Splitting tables...", total=len(component_col) ) @@ -339,15 +353,17 @@ def to_hierarchical_clusters( # Process components in parallel results = [] - with Progress(*progress_columns, console=console) as progress: + with build_progress_bar() as progress: process_task = progress.add_task( 
"[green]Processing components...", total=len(component_tables) ) with ProcessPoolExecutor(max_workers=n_cores) as executor: futures = [ - executor.submit(proc_func, component_table, salt=salt, dtype=dtype) - for salt, component_table in enumerate(component_tables) + executor.submit( + proc_func, component_table, hash_func=hash_func, dtype=dtype + ) + for component_table in component_tables ] for future in futures: diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index db1747ec..b5144389 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -192,14 +192,6 @@ class MatchboxModelAdapter(ABC): hash: bytes name: str - @property - @abstractmethod - def probabilities(self) -> ProbabilityResults: ... - - @property - @abstractmethod - def clusters(self) -> ClusterResults: ... - @property @abstractmethod def results(self) -> Results: ... diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 412ac34c..05fd879e 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -5,7 +5,7 @@ from sqlalchemy import Engine, and_, bindparam, delete, func, or_, select from sqlalchemy.orm import Session -from matchbox.client.results import ClusterResults, ProbabilityResults, Results +from matchbox.client.results import Results from matchbox.common.db import Match, Source, SourceWarehouse from matchbox.common.exceptions import ( MatchboxDataError, @@ -30,10 +30,7 @@ insert_results, ) from matchbox.server.postgresql.utils.query import match, query -from matchbox.server.postgresql.utils.results import ( - get_model_clusters, - get_model_probabilities, -) +from matchbox.server.postgresql.utils.results import get_model_results if TYPE_CHECKING: from pandas import DataFrame as PandasDataFrame @@ -132,22 +129,10 @@ def name(self) -> str: session.add(self.resolution) return self.resolution.name - @property - def probabilities(self) -> ProbabilityResults: - """Retrieve probabilities for this model.""" - return get_model_probabilities( - engine=MBDB.get_engine(), resolution=self.resolution - ) - - @property - def clusters(self) -> ClusterResults: - """Retrieve clusters for this model.""" - return get_model_clusters(engine=MBDB.get_engine(), resolution=self.resolution) - @property def results(self) -> Results: """Retrieve results for this model.""" - return Results(probabilities=self.probabilities, clusters=self.clusters) + return get_model_results(engine=MBDB.get_engine(), resolution=self.resolution) @results.setter def results(self, results: Results) -> None: diff --git a/src/matchbox/server/postgresql/benchmark/generate_tables.py b/src/matchbox/server/postgresql/benchmark/generate_tables.py index c18dadb1..56c86faa 100644 --- a/src/matchbox/server/postgresql/benchmark/generate_tables.py +++ b/src/matchbox/server/postgresql/benchmark/generate_tables.py @@ -4,68 +4,22 @@ import click import pyarrow as pa +import pyarrow.compute as pc import pyarrow.parquet as pq from matchbox.common.factories import generate_dummy_probabilities -from matchbox.common.hash import HASH_FUNC +from matchbox.common.hash import HASH_FUNC, hash_data, hash_values from matchbox.common.transform import ( attach_components_to_probabilities, to_hierarchical_clusters, ) - - -class IDCreator: - """ - A generator of incremental integer IDs from positive and negative integers. - - Positive integers will be returned as they are, while a new ID will be generated - for each negative integer. 
- """ - - def __init__(self, start: int): - self.id_map = dict() - self._next_int = start - - def create(self, temp_ids: list[int]) -> list[int]: - results = [] - for ti in temp_ids: - if ti >= 0: - results.append(ti) - elif ti in self.id_map: - results.append(self.id_map[ti]) - else: - self.id_map[ti] = self._next_int - results.append(self._next_int) - self._next_int += 1 - - return results - - def reset_mapping(self): - self.__init__(self._next_int) - - return self +from matchbox.server.postgresql.utils.insert import HashIDMap def _hash_list_int(li: list[int]) -> list[bytes]: return [HASH_FUNC(str(i).encode("utf-8")).digest() for i in li] -def _unique_clusters( - all_parents: Iterable[int], all_probabilities: Iterable[int] -) -> tuple[list[int], list[float]]: - ll = set() - clusters = [] - probabilities = [] - for parent, prob in zip(all_parents, all_probabilities, strict=True): - if parent in ll: - continue - else: - ll.add(parent) - clusters.append(parent) - probabilities.append(prob / 100) - return clusters, probabilities - - def generate_sources() -> pa.Table: """ Generate sources table. @@ -181,12 +135,12 @@ def generate_result_tables( left_ids: Iterable[int], right_ids: Iterable[int] | None, resolution_id: int, - id_creator: IDCreator, + next_id: int, n_components: int, n_probs: int, prob_min: float = 0.6, prob_max: float = 1, -) -> tuple[list[int], pa.Table, pa.Table, pa.Table]: +) -> tuple[list[int], pa.Table, pa.Table, pa.Table, int]: """ Generate probabilities, contains and clusters tables. @@ -194,62 +148,105 @@ def generate_result_tables( left_ids: list of IDs for rows to dedupe, or for left rows to link right_ids: list of IDs for right rows to link resolution_id: ID of resolution for this dedupe or link model - id_creator: an IDCreator instance + next_id: the next ID to use when generating IDs n_components: number of implied connected components n_probs: total number of probability edges to be generated prob_min: minimum value for probabilities to be generated prob_max: maximum value for probabilities to be generated Returns: - Tuple with 1 list of top-level clusters and 3 PyArrow tables, for probabilities, - contains and clusters + Tuple with 1 list of top-level clusters, 3 PyArrow tables, for probabilities, + contains and clusters, and the next ID to use for future calls """ probs = generate_dummy_probabilities( left_ids, right_ids, [prob_min, prob_max], n_components, n_probs ) - clusters = to_hierarchical_clusters(attach_components_to_probabilities(probs)) + # Create a lookup table for hashes + all_probs = pa.concat_arrays( + [probs["left"].combine_chunks(), probs["right"].combine_chunks()] + ) + lookup = pa.table( + { + "id": all_probs, + "hash": pa.array( + [hash_data(p) for p in all_probs.to_pylist()], type=pa.binary() + ), + } + ) + + hm = HashIDMap(start=next_id, lookup=lookup) - indexed_parents = id_creator.create(clusters["parent"].to_pylist()) - indexed_children = id_creator.create(clusters["child"].to_pylist()) + # Join hashes, probabilities and components + probs_with_ccs = attach_components_to_probabilities( + pa.table( + { + "left": hm.get_hashes(probs["left"]), + "right": hm.get_hashes(probs["right"]), + "probability": probs["probability"], + } + ) + ) - final_clusters, final_probs = _unique_clusters( - indexed_parents, clusters["probability"].to_numpy() + # Calculate hierarchies + hierarchy = to_hierarchical_clusters( + probabilities=probs_with_ccs, + hash_func=hash_values, + dtype=pa.binary, ) - source_entries = left_ids if right_ids is None else left_ids + 
right_ids - set_children = set(indexed_children) - top_clusters = [c for c in final_clusters + source_entries if c not in set_children] + # Shape into tables + parent_ids = hm.get_ids(hierarchy["parent"]) + child_ids = hm.get_ids(hierarchy["child"]) + unique_parent_ids = pc.unique(parent_ids) + unique_child_ids = pc.unique(child_ids) probabilities_table = pa.table( { "resolution": pa.array( - [resolution_id] * len(final_clusters), type=pa.uint64() + [resolution_id] * hierarchy.shape[0], type=pa.uint64() ), - "cluster": pa.array(final_clusters, type=pa.uint64()), - "probability": pa.array(final_probs, type=pa.float64()), + "cluster": parent_ids, + "probability": hierarchy["probability"], } ) contains_table = pa.table( { - "parent": pa.array(indexed_parents, type=pa.uint64()), - "child": pa.array(indexed_children, type=pa.uint64()), + "parent": parent_ids, + "child": child_ids, } ) clusters_table = pa.table( { - "cluster_id": pa.array(final_clusters, type=pa.uint64()), - "cluster_hash": pa.array(_hash_list_int(final_clusters), type=pa.binary()), - "dataset": pa.array([None] * len(final_clusters), type=pa.uint64()), + "cluster_id": unique_parent_ids, + "cluster_hash": hm.get_hashes(unique_parent_ids), + "dataset": pa.array([None] * len(unique_parent_ids), type=pa.uint64()), "source_pk": pa.array( - [None] * len(final_clusters), type=pa.list_(pa.string()) + [None] * len(unique_parent_ids), type=pa.list_(pa.string()) ), } ) - return (top_clusters, probabilities_table, contains_table, clusters_table) + # Compute top clusters + parents_not_children = pc.filter( + unique_parent_ids, pc.invert(pc.is_in(unique_parent_ids, unique_child_ids)) + ) + sources_not_children = pc.filter( + all_probs, pc.invert(pc.is_in(all_probs, unique_child_ids)) + ) + top_clusters = pc.unique( + pa.concat_arrays([parents_not_children, sources_not_children]) + ) + + return ( + top_clusters, + probabilities_table, + contains_table, + clusters_table, + hm.next_int, + ) def generate_all_tables( @@ -279,36 +276,43 @@ def generate_all_tables( clusters_source1 = generate_cluster_source(0, source_len) clusters_source2 = generate_cluster_source(source_len, source_len * 2) - id_creator = IDCreator(source_len * 2) - top_clusters1, probabilities_dedupe1, contains_dedupe1, clusters_dedupe1 = ( - generate_result_tables( - clusters_source1["cluster_id"].to_pylist(), - None, - 3, - id_creator, - dedupe_components, - dedupe_len, - ) + ( + top_clusters1, + probabilities_dedupe1, + contains_dedupe1, + clusters_dedupe1, + next_id1, + ) = generate_result_tables( + left_ids=clusters_source1["cluster_id"].to_pylist(), + right_ids=None, + resolution_id=3, + next_id=source_len * 2, + n_components=dedupe_components, + n_probs=dedupe_len, ) - top_clusters2, probabilities_dedupe2, contains_dedupe2, clusters_dedupe2 = ( - generate_result_tables( - clusters_source2["cluster_id"].to_pylist(), - None, - 4, - id_creator.reset_mapping(), - dedupe_components, - dedupe_len, - ) + ( + top_clusters2, + probabilities_dedupe2, + contains_dedupe2, + clusters_dedupe2, + next_id2, + ) = generate_result_tables( + left_ids=clusters_source2["cluster_id"].to_pylist(), + right_ids=None, + resolution_id=4, + next_id=next_id1, + n_components=dedupe_components, + n_probs=dedupe_len, ) - _, probabilities_link, contains_link, clusters_link = generate_result_tables( - top_clusters1, - top_clusters2, - 5, - id_creator.reset_mapping(), - link_components, - link_len, + _, probabilities_link, contains_link, clusters_link, _ = generate_result_tables( + left_ids=top_clusters1, + 
right_ids=top_clusters2, + resolution_id=5, + next_id=next_id2, + n_components=link_components, + n_probs=link_len, ) probabilities = pa.concat_tables( diff --git a/src/matchbox/server/postgresql/orm.py b/src/matchbox/server/postgresql/orm.py index 69514185..a7725e0b 100644 --- a/src/matchbox/server/postgresql/orm.py +++ b/src/matchbox/server/postgresql/orm.py @@ -2,6 +2,7 @@ BIGINT, FLOAT, INTEGER, + SMALLINT, VARCHAR, CheckConstraint, Column, @@ -276,7 +277,7 @@ class Probabilities(CountMixin, MBDB.MatchboxBase): cluster = Column( BIGINT, ForeignKey("clusters.cluster_id", ondelete="CASCADE"), primary_key=True ) - probability = Column(FLOAT, nullable=False) + probability = Column(SMALLINT, nullable=False) # Relationships proposed_by = relationship("Resolutions", back_populates="probabilities") @@ -284,5 +285,5 @@ class Probabilities(CountMixin, MBDB.MatchboxBase): # Constraints __table_args__ = ( - CheckConstraint("probability BETWEEN 0 AND 1", name="valid_probability"), + CheckConstraint("probability BETWEEN 0 AND 100", name="valid_probability"), ) diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index 1192e956..9e9273e9 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -1,18 +1,21 @@ import logging -from collections import defaultdict -from itertools import count -import numpy as np -import pandas as pd -from sqlalchemy import Engine, bindparam, delete, select +import pyarrow as pa +import pyarrow.compute as pc +from sqlalchemy import Engine, delete, exists, select, union from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session +from sqlalchemy.sql.selectable import Select -from matchbox.client.results import ClusterResults, ProbabilityResults, Results -from matchbox.common.db import Source +from matchbox.client.results import Results +from matchbox.common.db import Source, sql_to_df from matchbox.common.graph import ResolutionNodeType -from matchbox.common.hash import dataset_to_hashlist, list_to_value_ordered_hash +from matchbox.common.hash import dataset_to_hashlist, hash_values +from matchbox.common.transform import ( + attach_components_to_probabilities, + to_hierarchical_clusters, +) from matchbox.server.postgresql.orm import ( Clusters, Contains, @@ -26,6 +29,79 @@ logic_logger = logging.getLogger("mb_logic") +class HashIDMap: + """An object to help map between IDs and hashes. + + When given a set of IDs, returns their hashes. If any ID doesn't have a hash, + it will error. + + When given a set of hashes, it will return their IDs. If any don't have IDs, it + will create one and return it as part of the set. 
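Reviewer note: `HashIDMap.get_hashes`/`get_ids` below, and `_map_ids` later in this file, all lean on the same Arrow idiom — `index_in` to locate each value in a lookup column, then `take` to pull the mapped column. In isolation, with made-up IDs and hashes:

```python
import pyarrow as pa
import pyarrow.compute as pc

lookup = pa.table(
    {
        "id": pa.array([10, 11, 12], type=pa.uint64()),
        "hash": pa.array([b"aaa", b"bbb", b"ccc"], type=pa.binary()),
    }
)

ids = pa.array([12, 10, 12], type=pa.uint64())

# Position of each ID in the lookup column (null where the ID is unknown)
indices = pc.index_in(ids, lookup["id"])
hashes = pc.take(lookup["hash"], indices)

print(hashes.to_pylist())  # [b'ccc', b'aaa', b'ccc']
print(pc.any(pc.is_null(indices)).as_py())  # False: every ID was found
```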
+ + Args: + start: The first integer to use for new IDs + lookup (optional): A lookup table to use for existing hashes + """ + + def __init__(self, start: int, lookup: pa.Table = None): + self.next_int = start + if not lookup: + self.lookup = pa.Table.from_arrays( + [ + pa.array([], type=pa.uint64()), + pa.array([], type=pa.binary()), + pa.array([], type=pa.bool_()), + ], + names=["id", "hash", "new"], + ) + else: + new_column = pa.array([False] * lookup.shape[0], type=pa.bool_()) + self.lookup = pa.Table.from_arrays( + [lookup["id"], lookup["hash"], new_column], names=["id", "hash", "new"] + ) + + def get_hashes(self, ids: pa.UInt64Array) -> pa.LargeBinaryArray: + """Returns the hashes of the given IDs.""" + indices = pc.index_in(ids, self.lookup["id"]) + + if pc.any(pc.is_null(indices)).as_py(): + m_mask = pc.is_null(indices) + m_ids = pc.filter(ids, m_mask) + + raise ValueError( + f"The following IDs were not found in lookup table: {m_ids.to_pylist()}" + ) + + return pc.take(self.lookup["hash"], indices) + + def get_ids(self, hashes: pa.LargeBinaryArray) -> pa.UInt64Array: + """Returns the IDs of the given hashes, assigning new IDs for unknown hashes.""" + indices = pc.index_in(hashes, self.lookup["hash"]) + new_hashes = pc.unique(pc.filter(hashes, pc.is_null(indices))) + + if len(new_hashes) > 0: + new_ids = pa.array( + range(self.next_int, self.next_int + len(new_hashes)), + type=pa.uint64(), + ) + + new_entries = pa.Table.from_arrays( + [ + new_ids, + new_hashes, + pa.array([True] * len(new_hashes), type=pa.bool_()), + ], + names=["id", "hash", "new"], + ) + + self.next_int += len(new_hashes) + self.lookup = pa.concat_tables([self.lookup, new_entries]) + + indices = pc.index_in(hashes, self.lookup["hash"]) + + return pc.take(self.lookup["id"], indices) + + def insert_dataset(dataset: Source, engine: Engine, batch_size: int) -> None: """Indexes a dataset from your data warehouse within Matchbox.""" @@ -142,12 +218,10 @@ def insert_model( """ logic_logger.info(f"[{model}] Registering model") with Session(engine) as session: - resolution_hash = list_to_value_ordered_hash( - [ - left.resolution_hash, - right.resolution_hash, - bytes(model, encoding="utf-8"), - ] + resolution_hash = hash_values( + left.resolution_hash, + right.resolution_hash, + bytes(model, encoding="utf-8"), ) # Check if resolution exists @@ -222,76 +296,176 @@ def _create_closure_entries(parent_resolution: Resolutions) -> None: logic_logger.info(f"[{model}] Done!") -def _cluster_results_to_hierarchical( - probabilities: ProbabilityResults, - clusters: ClusterResults, -) -> pd.DataFrame: - """ - Converts results to a hierarchical structure by building up from base components. +def _map_ids( + array: pa.Array, + lookup: pa.Table, + source: str, + target: str, +) -> pa.Array: + """Maps values in an array via a lookup. 
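Reviewer note: a quick sanity-check of `HashIDMap` as I read it — unknown hashes get fresh sequential IDs and are flagged `new`, while `get_hashes` only resolves IDs already in the lookup. The values are invented; the import path matches the one used by the benchmark script in this PR:

```python
import pyarrow as pa

from matchbox.server.postgresql.utils.insert import HashIDMap

hm = HashIDMap(start=100)

ids = hm.get_ids(pa.array([b"aaa", b"bbb", b"aaa"], type=pa.binary()))
print(ids.to_pylist())  # [100, 101, 100] - one new ID per distinct hash

hashes = hm.get_hashes(pa.array([101], type=pa.uint64()))
print(hashes.to_pylist())  # [b'bbb']

print(hm.lookup["new"].to_pylist())  # [True, True] - both rows created here
print(hm.next_int)  # 102
```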
Args: - probabilities: Original pairwise probabilities containing base components - clusters: Connected components at each threshold + array: Array of values to map + lookup: Table of values to map to + source: Name of the column to map from + target: Name of the column to map to Returns: - Pandas DataFrame of (parent, child, threshold) tuples representing the hierarchy + Array of mapped values """ - prob_df = probabilities.dataframe - cluster_df = clusters.dataframe + indices = pc.index_in(array, lookup[source]) + return pc.take(lookup[target], indices) - # Sort thresholds in descending order - thresholds = sorted(cluster_df["threshold"].unique(), reverse=True) - hierarchy: list[tuple[int, int, float]] = [] - ultimate_parents: dict[int, set[int]] = defaultdict(set) +def _get_resolution_related_clusters(resolution_id: int) -> Select: + """ + Get cluster hashes and IDs for a resolution, its parents, and siblings. - # Process each threshold level - for threshold in thresholds: - threshold_float = float(threshold) + * When a parent is a dataset, retrieves the data via the Sources table. + * When a parent is a model, retrieves the data via the Probabilities table. - # Add new pairwise relationships at this threshold - current_probs = prob_df[prob_df["probability"] == threshold_float] + This corresponds to all possible existing clusters that a resolution might ever be + able to link together, or propose. - for _, row in current_probs.iterrows(): - parent = row["id"] - left_id = row["left_id"] - right_id = row["right_id"] + Args: + resolution_id: The ID of the resolution to query - hierarchy.extend( - [ - (parent, left_id, threshold_float), - (parent, right_id, threshold_float), - ] + Returns: + List of tuples containing (cluster_hash, cluster_id) + """ + direct_resolution = select(Resolutions.resolution_id).where( + Resolutions.resolution_id == resolution_id + ) + + parent_resolutions = select(ResolutionFrom.parent).where( + ResolutionFrom.child == resolution_id + ) + + sibling_resolutions = ( + select(ResolutionFrom.child) + .where( + ResolutionFrom.parent.in_( + select(ResolutionFrom.parent).where( + ResolutionFrom.child == resolution_id + ) ) + ) + .where(ResolutionFrom.child != resolution_id) + ) + + resolution_set = union( + direct_resolution, parent_resolutions, sibling_resolutions + ).cte("resolution_set") - ultimate_parents[left_id].add(parent) - ultimate_parents[right_id].add(parent) + # Main query + base_query = ( + select(Clusters.cluster_hash.label("hash"), Clusters.cluster_id.label("id")) + .distinct() + .select_from(Clusters) + .join(Probabilities, Probabilities.cluster == Clusters.cluster_id, isouter=True) + ) + + # Subquery for source datasets + source_datasets = select(resolution_set.c.resolution_id).join( + Sources, Sources.resolution_id == resolution_set.c.resolution_id + ) + + # Subquery for model resolutions + model_resolutions = select(resolution_set.c.resolution_id).where( + ~exists() + .select_from(Sources) + .where(Sources.resolution_id == resolution_set.c.resolution_id) + ) + + # Combine conditions + final_query = base_query.where( + (Clusters.dataset.in_(source_datasets)) + | (Probabilities.resolution.in_(model_resolutions)) + ) + + return final_query + + +def _results_to_insert_tables( + resolution: Resolutions, results: Results, engine: Engine +) -> tuple[pa.Table, pa.Table, pa.Table]: + """Takes Results and returns three Arrow tables that can be inserted exactly. 
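Reviewer note: `_get_resolution_related_clusters` above is built from a union of resolution selects plus an anti-join via `exists()`. A compile-only sketch of those SQLAlchemy shapes on a toy `resolution_from` table — no database needed, and the table and column names are simplified from the real ORM:

```python
from sqlalchemy import Column, Integer, MetaData, Table, exists, select, union

metadata = MetaData()
edges = Table(
    "resolution_from",
    metadata,
    Column("parent", Integer),
    Column("child", Integer),
)

resolution_id = 3
direct = select(edges.c.parent).where(edges.c.child == resolution_id)
siblings = select(edges.c.child).where(
    edges.c.parent.in_(select(edges.c.parent).where(edges.c.child == resolution_id))
)
resolution_set = union(direct, siblings).cte("resolution_set")

# Anti-join: rows in the resolution set with no matching edge row
stmt = select(resolution_set.c.parent).where(
    ~exists().where(edges.c.parent == resolution_set.c.parent)
)
print(stmt)  # compiled SELECT ... WHERE NOT EXISTS (...) SQL
```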
- # Process clusters at this threshold - current_clusters = cluster_df[cluster_df["threshold"] == threshold_float] + Returns: + A tuple containing: + * A Clusters update Arrow table + * A Contains update Arrow table + * A Probabilities update Arrow table + """ + logic_logger.info(f"[{resolution.name}] Wrangling data to insert tables") - # Group by parent to process components together - for parent, group in current_clusters.groupby("parent"): - children = set(group["child"]) - if len(children) <= 2: - continue # Skip pairs already handled by pairwise probabilities + # Create ID-Hash lookup for existing probabilities + lookup = sql_to_df( + stmt=_get_resolution_related_clusters(resolution.resolution_id), + engine=engine, + return_type="arrow", + ) + lookup = lookup.cast(pa.schema([("hash", pa.binary()), ("id", pa.uint64())])) + + hm = HashIDMap(start=Clusters.next_id(), lookup=lookup) + + # Join hashes, probabilities and components + probs_with_ccs = attach_components_to_probabilities( + pa.table( + { + "left": hm.get_hashes(results.probabilities["left_id"]), + "right": hm.get_hashes(results.probabilities["right_id"]), + "probability": results.probabilities["probability"], + } + ) + ) - current_ultimate_parents: set[int] = set() - for child in children: - current_ultimate_parents.update(ultimate_parents[child]) + # Calculate hierarchies + hierarchy = to_hierarchical_clusters( + probabilities=probs_with_ccs, + hash_func=hash_values, + dtype=pa.binary, + ) - for up in current_ultimate_parents: - hierarchy.append((parent, up, threshold_float)) + # Create Probabilities Arrow table to insert, containing all generated probabilities + probabilities = pa.table( + { + "resolution": pa.array( + [resolution.resolution_id] * hierarchy.shape[0], + type=pa.uint64(), + ), + "cluster": hm.get_ids(hierarchy["parent"]), + "probability": hierarchy["probability"], + } + ) - for child in children: - ultimate_parents[child] = {parent} + # Create Clusters Arrow table to insert, containing only new clusters + new_hashes = pc.filter(hm.lookup["hash"], hm.lookup["new"]) + clusters = pa.table( + { + "cluster_id": pc.filter(hm.lookup["id"], hm.lookup["new"]), + "cluster_hash": new_hashes, + "dataset": pa.nulls(len(new_hashes), type=pa.uint64()), + "source_pk": pa.nulls(len(new_hashes), type=pa.list_(pa.string())), + } + ) - # Sort hierarchy by threshold (descending), then parent, then child - return ( - pd.DataFrame(hierarchy, columns=["parent", "child", "threshold"]) - .sort_values(["threshold", "parent", "child"], ascending=[False, True, True]) - .reset_index(drop=True) + # Create Contains Arrow table to insert, containing only new contains edges + # Recall that clusters are defined by their parents, so all existing clusters + # already have the same parent-child relationships as were calculated here + hierarchy_new = hierarchy.filter( + pa.compute.is_in(hierarchy["parent"], value_set=new_hashes) ) + contains = pa.table( + { + "parent": hm.get_ids(hierarchy_new["parent"]), + "child": hm.get_ids(hierarchy_new["child"]), + } + ) + + logic_logger.info(f"[{resolution.name}] Wrangling complete!") + + return clusters, contains, probabilities def insert_results( @@ -324,46 +498,10 @@ def insert_results( f"[{resolution.name}] Writing results data with batch size {batch_size:,}" ) - # Get the lookup of existing database values and generate new ones - hierarchy = _cluster_results_to_hierarchical( - probabilities=results.probabilities, clusters=results.clusters - ) - hashes = np.unique( - 
np.concatenate([hierarchy["parent"].unique(), hierarchy["child"].unique()]) - ).tolist() - lookup: dict[bytes, int | None] = {hash: None for hash in hashes} - - with Session(engine) as session: - data_inner_join = ( - session.query(Clusters) - .filter( - Clusters.cluster_hash.in_( - bindparam( - "ins_ids", - hashes, - expanding=True, - ) - ) - ) - .all() - ) - - gen_cluster_id = count(Clusters.next_id()) - - lookup.update({item.cluster_hash: item.cluster_id for item in data_inner_join}) - lookup = {k: next(gen_cluster_id) if v is None else v for k, v in lookup.items()} - - hierarchy["parent_id"] = ( - hierarchy["parent"].apply(lambda i: lookup[i]).astype("int32[pyarrow]") - ) - hierarchy["child_id"] = ( - hierarchy["child"].apply(lambda i: lookup[i]).astype("int32[pyarrow]") + clusters, contains, probabilities = _results_to_insert_tables( + resolution=resolution, results=results, engine=engine ) - hierarchy_unique_parents = hierarchy[ - ["parent_id", "parent", "threshold"] - ].drop_duplicates() - with Session(engine) as session: try: # Clear existing probabilities for this resolution @@ -385,70 +523,45 @@ def insert_results( with engine.connect() as conn: try: - total_records = results.clusters.dataframe.shape[0] logic_logger.info( - f"[{resolution.name}] Inserting {total_records:,} results objects" - ) - - cluster_records: list[tuple[int, bytes, None, None]] = list( - zip( - hierarchy_unique_parents["parent_id"], - hierarchy_unique_parents["parent"], - [None] * hierarchy_unique_parents.shape[0], - [None] * hierarchy_unique_parents.shape[0], - strict=True, - ) - ) - contains_records: list[tuple[int, int]] = list( - zip( - hierarchy["parent_id"], - hierarchy["child_id"], - strict=True, - ) - ) - probability_records: list[tuple[int, int, float]] = list( - zip( - [resolution.resolution_id] * hierarchy_unique_parents.shape[0], - hierarchy_unique_parents["parent_id"], - hierarchy_unique_parents["threshold"], - strict=True, - ) + f"[{resolution.name}] Inserting {clusters.shape[0]:,} results " + "objects" ) batch_ingest( - records=cluster_records, + records=[tuple(c.values()) for c in clusters.to_pylist()], table=Clusters, conn=conn, batch_size=batch_size, ) logic_logger.info( - f"[{resolution.name}] Successfully inserted {len(cluster_records)} " + f"[{resolution.name}] Successfully inserted {clusters.shape[0]} " "objects into Clusters table" ) batch_ingest( - records=contains_records, + records=[tuple(c.values()) for c in contains.to_pylist()], table=Contains, conn=conn, batch_size=batch_size, ) logic_logger.info( - f"[{resolution.name}] Successfully inserted {len(contains_records)} " + f"[{resolution.name}] Successfully inserted {contains.shape[0]} " "objects into Contains table" ) batch_ingest( - records=probability_records, + records=[tuple(c.values()) for c in probabilities.to_pylist()], table=Probabilities, conn=conn, batch_size=batch_size, ) logic_logger.info( - f"[{resolution.name}] Successfully inserted {len(probability_records)} " - "objects into Probabilities table" + f"[{resolution.name}] Successfully inserted " + f"{probabilities.shape[0]} objects into Probabilities table" ) except SQLAlchemyError as e: diff --git a/src/matchbox/server/postgresql/utils/results.py b/src/matchbox/server/postgresql/utils/results.py index 11a0528e..542bba18 100644 --- a/src/matchbox/server/postgresql/utils/results.py +++ b/src/matchbox/server/postgresql/utils/results.py @@ -1,20 +1,12 @@ from typing import NamedTuple -import pandas as pd -import pyarrow as pa -import rustworkx as rx from sqlalchemy 
import Engine, and_, case, func, select from sqlalchemy.orm import Session -from matchbox.client.results import ( - ClusterResults, - ModelMetadata, - ModelType, - ProbabilityResults, -) +from matchbox.client.results import ModelMetadata, ModelType, Results +from matchbox.common.db import sql_to_df from matchbox.common.graph import ResolutionNodeType from matchbox.server.postgresql.orm import ( - Clusters, Contains, Probabilities, ResolutionFrom, @@ -84,11 +76,9 @@ def _get_source_info(engine: Engine, resolution_id: int) -> SourceInfo: ) -def get_model_probabilities( - engine: Engine, resolution: Resolutions -) -> ProbabilityResults: +def get_model_results(engine: Engine, resolution: Resolutions) -> Results: """ - Recover the model's ProbabilityResults. + Recover the model's pairwise probabilities and return as Results. For each probability this model assigned: - Get its two immediate children @@ -100,7 +90,7 @@ def get_model_probabilities( resolution: Resolution of type model to query Returns: - ProbabilityResults containing the original pairwise probabilities + Results containing the original pairwise probabilities """ if resolution.type != ResolutionNodeType.MODEL: raise ValueError("Expected resolution of type model") @@ -195,162 +185,6 @@ def get_model_probabilities( pairs.c.probability, ) - results = session.execute(final_select).fetchall() - - df = pd.DataFrame( - results, columns=["id", "left_id", "right_id", "probability"] - ).astype( - { - "id": pd.ArrowDtype(pa.uint64()), - "left_id": pd.ArrowDtype(pa.uint64()), - "right_id": pd.ArrowDtype(pa.uint64()), - "probability": pd.ArrowDtype(pa.float32()), - } - ) - - return ProbabilityResults(dataframe=df, metadata=metadata) - - -def _get_all_leaf_descendants(graph: rx.PyDiGraph, node_id: int) -> set[int]: - """Get all leaf descendant node IDs of a given node in the graph.""" - descendants = set() - to_process = [node_id] - - while to_process: - current = to_process.pop() - children = [edge[1] for edge in graph.out_edges(current)] - - if not children: - descendants.add(current) - else: - to_process.extend(children) - - return descendants - - -def get_model_clusters(engine: Engine, resolution: Resolutions) -> ClusterResults: - """ - Recover the model's Clusters. - - Clusters are the connected components of the model at every threshold. - - While they're stored in a hierarchical structure, we need to recover the - original components, where all child IDs are leaf Clusters. 
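Reviewer note (on the `batch_ingest` calls in utils/insert.py above): the Arrow tables are handed over as plain row tuples via `to_pylist()`, so the column order in the table schema has to line up with the ORM column order. A quick illustration with an invented clusters table:

```python
import pyarrow as pa

clusters = pa.table(
    {
        "cluster_id": pa.array([100, 101], type=pa.uint64()),
        "cluster_hash": pa.array([b"aaa", b"bbb"], type=pa.binary()),
        "dataset": pa.nulls(2, type=pa.uint64()),
        "source_pk": pa.nulls(2, type=pa.list_(pa.string())),
    }
)

# dict.values() follows column order, so each tuple matches the schema order
records = [tuple(row.values()) for row in clusters.to_pylist()]
print(records[0])  # (100, b'aaa', None, None)
```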
- - Args: - engine: SQLAlchemy engine - model: Resolution of type model to query - - Returns: - A ClusterResults object containing connected components and model metadata - """ - if resolution.type != ResolutionNodeType.MODEL: - raise ValueError("Expected resolution of type model") - - source_info: SourceInfo = _get_source_info( - engine=engine, resolution_id=resolution.resolution_id - ) - - with Session(engine) as session: - # Build metadata - left = session.get(Resolutions, source_info.left) - right = ( - session.get(Resolutions, source_info.right) if source_info.right else None - ) - - metadata = ModelMetadata( - name=resolution.name, - description=resolution.description or "", - type=ModelType.DEDUPER if source_info.right is None else ModelType.LINKER, - left_source=left.name, - right_source=right.name if source_info.right else None, - ) - - # Get all clusters and their relationships for this resolution - hierarchy_query = ( - select(Contains.parent, Contains.child, Probabilities.probability) - .join( - Probabilities, - and_( - Probabilities.cluster == Contains.parent, - Probabilities.resolution == resolution.resolution_id, - ), - ) - .order_by(Probabilities.probability.desc()) - ) - - hierarchy = session.execute(hierarchy_query).fetchall() + df = sql_to_df(stmt=final_select, engine=engine, return_type="arrow") - # Get all leaf nodes (clusters with no children) and their IDs - leaf_query = select(Clusters.cluster_id, Clusters.source_pk).where( - ~Clusters.cluster_id.in_(select(Contains.parent).distinct()) - ) - leaf_nodes = { - row.cluster_id: row.source_pk[0] if row.source_pk else None - for row in session.execute(leaf_query) - } - - # Get unique thresholds and components at each threshold - threshold_query = ( - select(Probabilities.cluster, Probabilities.probability) - .where(Probabilities.resolution == resolution.resolution_id) - .order_by(Probabilities.probability.desc()) - ) - threshold_components = session.execute(threshold_query).fetchall() - - # Build directed graph of the full hierarchy - graph = rx.PyDiGraph() - nodes: dict[bytes, int] = {} # node_hash -> node_id - - def get_node_id(id: bytes) -> int: - if id not in nodes: - nodes[id] = graph.add_node(id) - return nodes[id] - - for parent, child, prob in hierarchy: - parent_id = get_node_id(parent) - child_id = get_node_id(child) - graph.add_edge(parent_id, child_id, prob) - - # Process each threshold level - components: list[tuple[bytes, bytes, float]] = [] - seen_combinations = set() - - threshold_groups = {} - for comp, thresh in threshold_components: - if thresh not in threshold_groups: - threshold_groups[thresh] = [] - threshold_groups[thresh].append(comp) - - # Process thresholds in descending order - for threshold in sorted(threshold_groups.keys(), reverse=True): - for component in threshold_groups[threshold]: - component_id = get_node_id(component) - - leaf_ids = _get_all_leaf_descendants(graph, component_id) - - leaf_hashes = { - graph.get_node_data(leaf_id) - for leaf_id in leaf_ids - if graph.get_node_data(leaf_id) in leaf_nodes - } - - for leaf in leaf_hashes: - if leaf_nodes[leaf] is not None: - relation = (component, leaf, threshold) - if relation not in seen_combinations: - components.append(relation) - seen_combinations.add(relation) - - df = pd.DataFrame(components, columns=["parent", "child", "threshold"]).astype( - { - "parent": pd.ArrowDtype(pa.uint64()), - "child": pd.ArrowDtype(pa.uint64()), - "threshold": pd.ArrowDtype(pa.float32()), - } - ) - - return ClusterResults( - dataframe=df, - 
metadata=metadata, - ) + return Results(probabilities=df, metadata=metadata) diff --git a/test/client/test_dedupers.py b/test/client/test_dedupers.py index b96f00d7..590e5b78 100644 --- a/test/client/test_dedupers.py +++ b/test/client/test_dedupers.py @@ -1,3 +1,5 @@ +import pyarrow as pa +import pyarrow.compute as pc import pytest from pandas import DataFrame @@ -78,34 +80,30 @@ def test_dedupers( results = model.run() - deduped_df = results.probabilities.to_df() - deduped_df_with_source = results.probabilities.inspect_with_source( + result_with_source = results.inspect_probabilities( left_data=df, left_key="id", right_data=df, right_key="id" ) - assert isinstance(deduped_df, DataFrame) - assert deduped_df.shape[0] == fx_data.tgt_prob_n + assert isinstance(results.probabilities, pa.Table) + assert results.probabilities.shape[0] == fx_data.tgt_prob_n - assert isinstance(deduped_df_with_source, DataFrame) + assert isinstance(result_with_source, DataFrame) for field in fields: - assert deduped_df_with_source[field + "_x"].equals( - deduped_df_with_source[field + "_y"] - ) + assert result_with_source[field + "_x"].equals(result_with_source[field + "_y"]) # 3. Correct number of clusters are resolved - clusters_dupes_df = results.clusters.to_df() - clusters_dupes_df_with_source = results.clusters.inspect_with_source( + clusters_with_source = results.inspect_clusters( left_data=df, left_key="id", right_data=df, right_key="id" ) - assert isinstance(clusters_dupes_df, DataFrame) - assert clusters_dupes_df.parent.nunique() == fx_data.tgt_clus_n + assert isinstance(results.clusters, pa.Table) + assert pc.count_distinct(results.clusters["parent"]).as_py() == fx_data.tgt_clus_n - assert isinstance(clusters_dupes_df_with_source, DataFrame) + assert isinstance(clusters_with_source, DataFrame) for field in fields: - assert clusters_dupes_df_with_source[field + "_x"].equals( - clusters_dupes_df_with_source[field + "_y"] + assert clusters_with_source[field + "_x"].equals( + clusters_with_source[field + "_y"] ) # 4. Probabilities and clusters are inserted correctly @@ -113,7 +111,7 @@ def test_dedupers( results.to_matchbox(backend=matchbox_postgres) model = matchbox_postgres.get_model(model=deduper_name) - assert model.probabilities.dataframe.shape[0] == fx_data.tgt_prob_n + assert model.results.probabilities.shape[0] == fx_data.tgt_prob_n model.truth = 0.0 diff --git a/test/client/test_linkers.py b/test/client/test_linkers.py index 822182a5..b9839635 100644 --- a/test/client/test_linkers.py +++ b/test/client/test_linkers.py @@ -1,3 +1,5 @@ +import pyarrow as pa +import pyarrow.compute as pc import pytest from pandas import DataFrame from splink import SettingsCreator @@ -114,35 +116,33 @@ def test_linkers( results = model.run() - linked_df = results.probabilities.to_df() - linked_df_with_source = results.probabilities.inspect_with_source( + result_with_source = results.inspect_probabilities( left_data=df_l, left_key="id", right_data=df_r, right_key="id", ) - assert isinstance(linked_df, DataFrame) - assert linked_df.shape[0] == fx_data.tgt_prob_n + assert isinstance(results.probabilities, pa.Table) + assert results.probabilities.shape[0] == fx_data.tgt_prob_n - assert isinstance(linked_df_with_source, DataFrame) + assert isinstance(result_with_source, DataFrame) for field_l, field_r in zip(fields_l, fields_r, strict=True): - assert linked_df_with_source[field_l].equals(linked_df_with_source[field_r]) + assert result_with_source[field_l].equals(result_with_source[field_r]) # 3. 
Correct number of clusters are resolved - clusters_links_df = results.clusters.to_df() - clusters_links_df_with_source = results.clusters.inspect_with_source( + clusters_with_source = results.inspect_clusters( left_data=df_l, left_key="id", right_data=df_r, right_key="id", ) - assert isinstance(clusters_links_df, DataFrame) - assert clusters_links_df.parent.nunique() == fx_data.tgt_clus_n + assert isinstance(results.clusters, pa.Table) + assert pc.count_distinct(results.clusters["parent"]).as_py() == fx_data.tgt_clus_n - assert isinstance(clusters_links_df_with_source, DataFrame) + assert isinstance(clusters_with_source, DataFrame) for field_l, field_r in zip(fields_l, fields_r, strict=True): # When we enrich the ClusterResults in a deduplication job, every child # id will match something in the source data, because we're only using @@ -157,7 +157,7 @@ def unique_non_null(s): return s.dropna().unique() cluster_vals = ( - clusters_links_df_with_source.filter(["parent", field_l, field_r]) + clusters_with_source.filter(["parent", field_l, field_r]) .groupby("parent") .agg( { @@ -178,7 +178,7 @@ def unique_non_null(s): results.to_matchbox(backend=matchbox_postgres) model = matchbox_postgres.get_model(model=linker_name) - assert model.probabilities.dataframe.shape[0] == fx_data.tgt_prob_n + assert model.results.probabilities.shape[0] == fx_data.tgt_prob_n model.truth = 0.0 diff --git a/test/common/test_transform.py b/test/common/test_transform.py index 8e7ef3d9..88b46273 100644 --- a/test/common/test_transform.py +++ b/test/common/test_transform.py @@ -1,15 +1,12 @@ -from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager +from functools import lru_cache from itertools import chain -from typing import Any, Iterator -from unittest.mock import patch +from typing import Any import pyarrow as pa import pyarrow.compute as pc import pytest from matchbox.common.factories import generate_dummy_probabilities -from matchbox.common.hash import IntMap from matchbox.common.transform import ( attach_components_to_probabilities, component_to_hierarchy, @@ -17,7 +14,8 @@ ) -def _combine_strings(self, *n: str) -> str: +@lru_cache(maxsize=None) +def _combine_strings(*n: str) -> str: """ Combine n strings into a single string, with a cache. Meant to replace `matchbox.common.hash.IntMap.index` @@ -28,32 +26,8 @@ def _combine_strings(self, *n: str) -> str: Returns: A single string """ - value_set = frozenset(n) - if value_set in self.mapping: - return self.mapping[value_set] - letters = set(chain.from_iterable(n)) - - new_id = "".join(sorted(letters)) - self.mapping[value_set] = new_id - return new_id - - -@contextmanager -def parallel_pool_for_tests( - max_workers: int = 2, timeout: int = 30 -) -> Iterator[ThreadPoolExecutor]: - """Context manager for safe parallel execution in tests using threads. 
- - Args: - max_workers: Maximum number of worker threads - timeout: Maximum seconds to wait for each task - """ - with ThreadPoolExecutor(max_workers=max_workers) as executor: - try: - yield executor - finally: - executor.shutdown(wait=False, cancel_futures=True) + return "".join(sorted(letters)) @pytest.mark.parametrize( @@ -86,6 +60,20 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]): assert len(pc.unique(with_components["component"])) == parameters["num_components"] +def test_empty_attach_components_to_probabilities(): + probabilities = pa.table( + { + "left": [], + "right": [], + "probability": [], + } + ) + + with_components = attach_components_to_probabilities(probabilities=probabilities) + + assert len(with_components) == 0 + + @pytest.mark.parametrize( ("probabilities", "hierarchy"), [ @@ -165,57 +153,54 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]): def test_component_to_hierarchy( probabilities: dict[str, list[str | float]], hierarchy: set[tuple[str, str, int]] ): - with patch.object(IntMap, "index", _combine_strings): - probabilities_table = ( - pa.Table.from_pydict(probabilities) - .cast( - pa.schema( - [ - ("left", pa.string()), - ("right", pa.string()), - ("probability", pa.uint8()), - ] - ) + probabilities_table = ( + pa.Table.from_pydict(probabilities) + .cast( + pa.schema( + [ + ("left", pa.string()), + ("right", pa.string()), + ("probability", pa.uint8()), + ] ) - .sort_by([("probability", "descending")]) ) + .sort_by([("probability", "descending")]) + ) - parents, children, probs = zip(*hierarchy, strict=False) + parents, children, probs = zip(*hierarchy, strict=False) - hierarchy_true = ( - pa.table( - [parents, children, probs], names=["parent", "child", "probability"] - ) - .cast( - pa.schema( - [ - ("parent", pa.string()), - ("child", pa.string()), - ("probability", pa.uint8()), - ] - ) - ) - .sort_by( + hierarchy_true = ( + pa.table([parents, children, probs], names=["parent", "child", "probability"]) + .cast( + pa.schema( [ - ("probability", "descending"), - ("parent", "ascending"), - ("child", "ascending"), + ("parent", pa.string()), + ("child", pa.string()), + ("probability", pa.uint8()), ] ) - .filter(pc.is_valid(pc.field("parent"))) ) - - hierarchy = component_to_hierarchy( - probabilities_table, salt=1, dtype=pa.string - ).sort_by( + .sort_by( [ ("probability", "descending"), ("parent", "ascending"), ("child", "ascending"), ] ) + .filter(pc.is_valid(pc.field("parent"))) + ) - assert hierarchy.equals(hierarchy_true) + hierarchy = component_to_hierarchy( + table=probabilities_table, dtype=pa.string, hash_func=_combine_strings + ).sort_by( + [ + ("probability", "descending"), + ("parent", "ascending"), + ("child", "ascending"), + ] + ) + + assert hierarchy.equals(hierarchy_true) @pytest.mark.parametrize( @@ -305,16 +290,12 @@ def test_hierarchical_clusters(input_data, expected_hierarchy): ) # Run and compare - with ( - patch( - "matchbox.common.transform.ProcessPoolExecutor", - lambda *args, **kwargs: parallel_pool_for_tests(timeout=30), - ), - patch.object(IntMap, "index", _combine_strings), - ): - result = to_hierarchical_clusters( - probabilities, dtype=pa.string, proc_func=component_to_hierarchy - ) + result = to_hierarchical_clusters( + probabilities, + dtype=pa.string, + proc_func=component_to_hierarchy, + hash_func=_combine_strings, + ) result = result.sort_by( [ diff --git a/test/conftest.py b/test/conftest.py index 7c82675e..ae48e4c4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,58 
@@ +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from typing import Iterator +from unittest.mock import patch + +import pytest +from rich.console import Console +from rich.progress import Progress + pytest_plugins = [ "test.fixtures.data", "test.fixtures.db", "test.fixtures.graph", ] + + +@contextmanager +def parallel_pool_for_tests( + max_workers: int = 2, timeout: int = 30 +) -> Iterator[ThreadPoolExecutor]: + """Context manager for safe parallel execution in tests using threads. + + Args: + max_workers: Maximum number of worker threads + timeout: Maximum seconds to wait for each task + """ + with ThreadPoolExecutor(max_workers=max_workers) as executor: + try: + yield executor + finally: + executor.shutdown(wait=False, cancel_futures=True) + + +@pytest.fixture(scope="session", autouse=True) +def patch_multiprocessing() -> Iterator[None]: + """Patch ProcessPoolExecutor to use ThreadPoolExecutor in tests.""" + with patch( + "matchbox.common.transform.ProcessPoolExecutor", + lambda *args, **kwargs: parallel_pool_for_tests(timeout=30), + ): + yield + + +@pytest.fixture(scope="session", autouse=True) +def patch_rich_console() -> Iterator[None]: + """Patch Rich console for quiet output in tests.""" + quiet_console = Console(quiet=True) + + console_patch = patch( + "matchbox.common.logging.get_console", return_value=quiet_console + ) + progress_patch = patch( + "matchbox.common.logging.build_progress_bar", + return_value=Progress(console=quiet_console), + ) + + with console_patch, progress_patch: + yield diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index ff6f2f95..339597fc 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -1,16 +1,13 @@ from collections import defaultdict from typing import Callable +import pyarrow.compute as pc import pytest from dotenv import find_dotenv, load_dotenv from pandas import DataFrame from matchbox.client.helpers.selector import match, query, selector, selectors -from matchbox.client.results import ( - ClusterResults, - ProbabilityResults, - Results, -) +from matchbox.client.results import Results from matchbox.common.db import Match, Source, SourceColumn from matchbox.common.exceptions import ( MatchboxDataError, @@ -19,7 +16,6 @@ ) from matchbox.common.graph import ResolutionGraph from matchbox.common.hash import HASH_FUNC -from matchbox.common.transform import to_clusters from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter from ..fixtures.db import SetupDatabaseCallable @@ -80,8 +76,6 @@ def test_model_properties(self): assert naive_crn.id assert naive_crn.hash assert naive_crn.name - assert naive_crn.probabilities - assert naive_crn.clusters assert naive_crn.results assert isinstance(naive_crn.truth, float) # otherwise we assert 0.0 assert naive_crn.ancestors @@ -299,32 +293,59 @@ def test_insert_model(self): assert self.backend.models.count() == models_count + 3 - def test_model_get_probabilities(self): - """Test that a model's ProbabilityResults can be retrieved.""" + def test_model_results(self): + """Test that a model's Results can be set and retrieved.""" self.setup_database("dedupe") naive_crn = self.backend.get_model(model="naive_test.crn") - assert isinstance(naive_crn.probabilities, ProbabilityResults) - assert len(naive_crn.probabilities.dataframe) > 0 - assert naive_crn.probabilities.metadata.name == "naive_test.crn" - self.backend.validate_ids(ids=naive_crn.probabilities.dataframe["id"].to_list()) - self.backend.validate_ids( - 
ids=naive_crn.probabilities.dataframe["left_id"].to_list() + # Retrieve + pre_results = naive_crn.results + + assert isinstance(pre_results, Results) + assert len(pre_results.probabilities) > 0 + assert pre_results.metadata.name == "naive_test.crn" + + self.backend.validate_ids(ids=pre_results.probabilities["id"].to_pylist()) + self.backend.validate_ids(ids=pre_results.probabilities["left_id"].to_pylist()) + self.backend.validate_ids(ids=pre_results.probabilities["right_id"].to_pylist()) + + # Set + target_row = pre_results.probabilities.to_pylist()[0] + target_id = target_row["id"] + target_left_id = target_row["left_id"] + target_right_id = target_row["right_id"] + + matches_id_mask = pc.not_equal(pre_results.probabilities["id"], target_id) + matches_left_mask = pc.not_equal( + pre_results.probabilities["left_id"], target_left_id ) - self.backend.validate_ids( - ids=naive_crn.probabilities.dataframe["right_id"].to_list() + matches_right_mask = pc.not_equal( + pre_results.probabilities["right_id"], target_right_id ) - def test_model_get_clusters(self): - """Test that a model's ClusterResults can be retrieved.""" - self.setup_database("dedupe") - naive_crn = self.backend.get_model(model="naive_test.crn") - assert isinstance(naive_crn.clusters, ClusterResults) - assert len(naive_crn.clusters.dataframe) > 0 - assert naive_crn.clusters.metadata.name == "naive_test.crn" + combined_mask = pc.and_( + pc.and_(matches_id_mask, matches_left_mask), matches_right_mask + ) + df_probabilities_truncated = pre_results.probabilities.filter(combined_mask) - self.backend.validate_ids(ids=naive_crn.clusters.dataframe["parent"].to_list()) - self.backend.validate_ids(ids=naive_crn.clusters.dataframe["child"].to_list()) + results = Results( + probabilities=df_probabilities_truncated.select( + ["left_id", "right_id", "probability"] + ), + model=pre_results.model, + metadata=pre_results.metadata, + ) + + naive_crn.results = results + + # Retrieve again + post_results = naive_crn.results + + # Check difference + assert len(pre_results.probabilities) != len(post_results.probabilities) + + # Check similarity + assert pre_results.metadata.name == post_results.metadata.name def test_model_truth(self): """Test that a model's truth can be set and retrieved.""" @@ -360,73 +381,6 @@ def test_model_ancestors(self): assert truth_found - def test_model_results(self): - """Test that a model's Results can be set and retrieved.""" - self.setup_database("dedupe") - naive_crn = self.backend.get_model(model="naive_test.crn") - - # Retrieve - pre_results = naive_crn.results - - assert isinstance(pre_results, Results) - assert len(pre_results.probabilities.dataframe) > 0 - assert pre_results.probabilities.metadata.name == "naive_test.crn" - assert len(pre_results.clusters.dataframe) > 0 - assert pre_results.clusters.metadata.name == "naive_test.crn" - - # Set - target_row = pre_results.probabilities.dataframe.iloc[0] - target_id = target_row["id"] - target_left_id = target_row["left_id"] - target_right_id = target_row["right_id"] - - matches_id_mask = pre_results.probabilities.dataframe["id"] != target_id - matches_left_mask = ( - pre_results.probabilities.dataframe["left_id"] != target_left_id - ) - matches_right_mask = ( - pre_results.probabilities.dataframe["right_id"] != target_right_id - ) - - df_probabilities_truncated = pre_results.probabilities.dataframe[ - matches_id_mask & matches_left_mask & matches_right_mask - ].copy() - - probabilities_truncated = ProbabilityResults( - dataframe=df_probabilities_truncated[ - 
["left_id", "right_id", "probability"] - ].reset_index( - drop=True - ), # Reset so adding ID doesn't try to match old index - model=pre_results.probabilities.model, - metadata=pre_results.probabilities.metadata, - ) - - results = Results( - probabilities=probabilities_truncated, - clusters=to_clusters(results=probabilities_truncated), - ) - - naive_crn.results = results - - # Retrieve again - post_results = naive_crn.results - - # Check difference - assert len(pre_results.probabilities.dataframe) != len( - post_results.probabilities.dataframe - ) - assert len(pre_results.clusters.dataframe) != len( - post_results.clusters.dataframe - ) - - # Check similarity - assert ( - pre_results.probabilities.metadata.name - == post_results.probabilities.metadata.name - ) - assert pre_results.clusters.metadata.name == post_results.clusters.metadata.name - def test_model_ancestors_cache(self): """Test that a model's ancestors cache can be set and retrieved.""" self.setup_database("link") diff --git a/test/server/test_postgresql.py b/test/server/test_postgresql.py index b8c6992b..f20463ad 100644 --- a/test/server/test_postgresql.py +++ b/test/server/test_postgresql.py @@ -1,215 +1,13 @@ from typing import Iterable +import pyarrow as pa import pytest -import rustworkx as rx -from pandas import DataFrame from sqlalchemy import text -from matchbox.client.results import ( - ClusterResults, - ModelMetadata, - ModelType, - ProbabilityResults, -) from matchbox.server.postgresql.benchmark.generate_tables import generate_all_tables from matchbox.server.postgresql.benchmark.init_schema import create_tables, empty_schema from matchbox.server.postgresql.db import MBDB -from matchbox.server.postgresql.utils.insert import _cluster_results_to_hierarchical - - -@pytest.fixture -def model_metadata(): - return ModelMetadata( - name="test_model", - description="Test model metadata", - type=ModelType.DEDUPER, - left_source="left", - ) - - -def create_results(prob_data: dict, cluster_data: dict, metadata: ModelMetadata): - """Helper to create ProbabilityResults and ClusterResults from test data.""" - prob_df = DataFrame(prob_data) - cluster_df = DataFrame(cluster_data) - - return ( - ProbabilityResults( - dataframe=prob_df.convert_dtypes(dtype_backend="pyarrow"), - metadata=metadata, - ), - ClusterResults( - dataframe=cluster_df.convert_dtypes(dtype_backend="pyarrow"), - metadata=metadata, - ), - ) - - -def verify_hierarchy(hierarchy: list[tuple[bytes, bytes, float]]) -> None: - """ - Verify each item has exactly one ultimate parent at each relevant threshold. 
- - Args: - hierarchy: List of (parent, child, threshold) relationships - """ - # Group relationships by threshold - thresholds = sorted({t for _, _, t in hierarchy}, reverse=True) - - for threshold in thresholds: - # Build graph of relationships at this threshold - graph = rx.PyDiGraph() - nodes = {} # hash -> node_id - - # Add all nodes first - edges = [(p, c) for p, c, t in hierarchy if t >= threshold] - items = set() # Track individual items (leaves) - - for parent, child in edges: - if parent not in nodes: - nodes[parent] = graph.add_node(parent) - if child not in nodes: - nodes[child] = graph.add_node(child) - # If this child never appears as a parent, it's an item - if child not in {p for p, _ in edges}: - items.add(child) - - # Add edges - for parent, child in edges: - graph.add_edge(nodes[parent], nodes[child], None) - - # For each item, find its ultimate parents - for item in items: - item_node = nodes[item] - ancestors = set() - - # Find all ancestors that have no parents themselves - for node in graph.node_indices(): - node_hash = graph.get_node_data(node) - if ( - rx.has_path(graph, node, item_node) or node == item_node - ) and graph.in_degree(node) == 0: - ancestors.add(node_hash) - - assert len(ancestors) == 1, ( - f"Item {item} has {len(ancestors)} ultimate parents at " - f"threshold {threshold}: {ancestors}" - ) - - -@pytest.mark.parametrize( - ("prob_data", "cluster_data", "expected_relations"), - [ - # Test case 1: Equal probability components - ( - { - "id": ["ab", "bc", "cd"], - "left_id": ["a", "b", "c"], - "right_id": ["b", "c", "d"], - "probability": [1.0, 1.0, 1.0], - }, - { - "parent": ["abcd", "abcd", "abcd", "abcd"], - "child": ["a", "b", "c", "d"], - "threshold": [1.0, 1.0, 1.0, 1.0], - }, - { - ("ab", "a", 1.0), - ("ab", "b", 1.0), - ("bc", "b", 1.0), - ("bc", "c", 1.0), - ("cd", "c", 1.0), - ("cd", "d", 1.0), - ("abcd", "ab", 1.0), - ("abcd", "bc", 1.0), - ("abcd", "cd", 1.0), - }, - ), - # Test case 2: Asymmetric probability components - ( - { - "id": ["wx", "xy", "yz"], - "left_id": ["w", "x", "y"], - "right_id": ["x", "y", "z"], - "probability": [0.9, 0.85, 0.8], - }, - { - "parent": [ - "wx", - "wx", - "wxy", - "wxy", - "wxy", - "wxyz", - "wxyz", - "wxyz", - "wxyz", - ], - "child": ["w", "x", "w", "x", "y", "w", "x", "y", "z"], - "threshold": [0.9, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8], - }, - { - ("wx", "w", 0.9), - ("wx", "x", 0.9), - ("xy", "x", 0.85), - ("xy", "y", 0.85), - ("wxy", "wx", 0.85), - ("wxy", "xy", 0.85), - ("yz", "y", 0.8), - ("yz", "z", 0.8), - ("wxyz", "wxy", 0.8), - ("wxyz", "yz", 0.8), - }, - ), - # Test case 3: Empty input - ( - { - "id": [], - "left_id": [], - "right_id": [], - "probability": [], - }, - { - "parent": [], - "child": [], - "threshold": [], - }, - set(), - ), - # Test case 4: Single two-item component - ( - { - "id": ["xy"], - "left_id": ["x"], - "right_id": ["y"], - "probability": [0.9], - }, - { - "parent": ["xy", "xy"], - "child": ["x", "y"], - "threshold": [0.9, 0.9], - }, - { - ("xy", "x", 0.9), - ("xy", "y", 0.9), - }, - ), - ], - ids=["equal_prob", "asymmetric_prob", "empty", "single_component"], -) -def test_cluster_results_to_hierarchical( - prob_data, cluster_data, expected_relations, model_metadata -): - """Test hierarchical clustering with various input scenarios.""" - prob_results, cluster_results = create_results( - prob_data, cluster_data, model_metadata - ) - - hierarchy = _cluster_results_to_hierarchical(prob_results, cluster_results) - actual_relations = set((p, c, t) for p, c, t in 
hierarchy.itertuples(index=False)) - - assert actual_relations == expected_relations - - if actual_relations: # Skip verification for empty case - verify_hierarchy(hierarchy.itertuples(index=False)) +from matchbox.server.postgresql.utils.insert import HashIDMap def test_benchmark_init_schema(): @@ -260,4 +58,41 @@ def array_encode(array: Iterable[str]): for c in df.columns: if df[c].dtype == "uint64": df[c] = df[c].astype("int64") - df.to_sql(table_name, con, schema) + df.to_sql(name=table_name, con=con, schema=schema) + + +def test_hash_id_map(): + """Test HashIDMap core functionality including basic operations.""" + # Initialize with some existing mappings + lookup = pa.Table.from_arrays( + [ + pa.array([1, 2], type=pa.uint64()), + pa.array([b"hash1", b"hash2"], type=pa.binary()), + ], + names=["id", "hash"], + ) + hash_map = HashIDMap(start=100, lookup=lookup) + + # Test getting existing hashes + ids = pa.array([2, 1], type=pa.uint64()) + hashes = hash_map.get_hashes(ids) + assert hashes.to_pylist() == [b"hash2", b"hash1"] + + # Test getting mix of existing and new hashes + input_hashes = pa.array([b"hash1", b"new_hash", b"hash2"], type=pa.binary()) + returned_ids = hash_map.get_ids(input_hashes) + + # Verify results + id_list = returned_ids.to_pylist() + assert id_list[0] == 1 # Existing hash1 + assert id_list[2] == 2 # Existing hash2 + assert id_list[1] == 100 # New hash got next available ID + + # Verify lookup table was updated correctly + assert hash_map.lookup.shape == (3, 3) + assert hash_map.next_int == 101 + + # Test error handling for missing IDs + with pytest.raises(ValueError) as exc_info: + hash_map.get_hashes(pa.array([999], type=pa.uint64())) + assert "not found in lookup table" in str(exc_info.value)
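
A minimal sketch of the contract that test_hash_id_map exercises, assuming only what its assertions require: a constructor taking start and a lookup table with "id" (uint64) and "hash" (binary) columns, get_hashes and get_ids methods, a next_int counter, and a ValueError mentioning "not found in lookup table" for unknown IDs. The class name SketchHashIDMap, the third "new" column, and all internals are illustrative guesses, not the implementation in matchbox.server.postgresql.utils.insert.

    import pyarrow as pa
    import pyarrow.compute as pc


    class SketchHashIDMap:
        """Maps binary hashes to integer IDs, minting new IDs from `start`."""

        def __init__(self, start: int, lookup: pa.Table):
            self.next_int = start
            # Assumed third column marking rows minted by this object; the test
            # only checks that the lookup ends up with three columns.
            self.lookup = lookup.append_column(
                "new", pa.array([False] * lookup.shape[0], type=pa.bool_())
            )

        def get_hashes(self, ids: pa.Array) -> pa.ChunkedArray:
            """Return the hash for each ID, raising if any ID is unknown."""
            indices = pc.index_in(ids, value_set=self.lookup["id"])
            if pc.any(pc.is_null(indices)).as_py():
                missing = ids.filter(pc.is_null(indices)).to_pylist()
                raise ValueError(f"IDs {missing} not found in lookup table")
            return pc.take(self.lookup["hash"], indices)

        def get_ids(self, hashes: pa.Array) -> pa.Array:
            """Return an ID for each hash, minting new IDs for unseen hashes."""
            indices = pc.index_in(hashes, value_set=self.lookup["hash"])
            ids: list[int] = []
            new_rows: dict[str, list] = {"id": [], "hash": [], "new": []}
            for value, index in zip(hashes.to_pylist(), indices.to_pylist()):
                if index is None:
                    new_id = self.next_int
                    self.next_int += 1
                    new_rows["id"].append(new_id)
                    new_rows["hash"].append(value)
                    new_rows["new"].append(True)
                    ids.append(new_id)
                else:
                    ids.append(self.lookup["id"][index].as_py())
            if new_rows["id"]:
                addition = pa.table(
                    {
                        "id": pa.array(new_rows["id"], type=pa.uint64()),
                        "hash": pa.array(new_rows["hash"], type=pa.binary()),
                        "new": pa.array(new_rows["new"], type=pa.bool_()),
                    }
                )
                self.lookup = pa.concat_tables([self.lookup, addition])
            return pa.array(ids, type=pa.uint64())

Fed the same fixtures as the test (start=100, two existing rows), this sketch yields [b"hash2", b"hash1"] for IDs [2, 1], assigns [1, 100, 2] for [b"hash1", b"new_hash", b"hash2"], leaves the lookup at shape (3, 3) with next_int == 101, and raises on unknown IDs, matching the assertions above under the stated assumptions.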