diff --git a/src/matchbox/common/results.py b/src/matchbox/common/results.py index 5ac2a784..1c0aa7ff 100644 --- a/src/matchbox/common/results.py +++ b/src/matchbox/common/results.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from enum import StrEnum from typing import TYPE_CHECKING, Any, List import rustworkx as rx @@ -11,11 +12,11 @@ from matchbox.server.base import MatchboxDBAdapter, inject_backend from matchbox.server.models import Cluster, Probability from pandas import DataFrame -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator from sqlalchemy import Table if TYPE_CHECKING: - from matchbox.models.models import Model + from matchbox.models.models import Model, ModelMetadata else: Model = Any @@ -25,19 +26,43 @@ load_dotenv(dotenv_path) +class ModelType(StrEnum): + """Enumeration of supported model types.""" + + LINKER = "linker" + DEDUPER = "deduper" + + +class ModelMetadata(BaseModel): + """Metadata for a model.""" + + name: str + description: str + type: ModelType + left_source: str + right_source: str | None = None  # Only used for linker models + + class ResultsBaseDataclass(BaseModel, ABC): + """Base class for results dataclasses. + + Model is required during construction and calculation, but not when loading + from storage. + """ + model_config = ConfigDict(arbitrary_types_allowed=True) dataframe: DataFrame - model: Model + model: Model | None = None + metadata: ModelMetadata _expected_fields: list[str] @model_validator(mode="after") def _check_dataframe(self) -> Table: """Verifies the table contains the expected fields.""" - table_fields = sorted(self.dataframe.columns) - expected_fields = sorted(self._expected_fields) + table_fields = set(self.dataframe.columns) + expected_fields = set(self._expected_fields) if table_fields != expected_fields: raise ValueError(f"Expected {expected_fields}. \n" f"Found {table_fields}.") @@ -63,7 +88,8 @@ def to_records(self) -> list[Probability | Cluster]: class ProbabilityResults(ResultsBaseDataclass): """Probabilistic matches produced by linkers and dedupers. - Inherits the following attributes from ResultsBaseDataclass. + These are pairs of records/clusters with a probability of being a match. + The hash is the hash of the sorted left and right ids. _expected_fields enforces the shape of the dataframe. 
@@ -73,11 +99,26 @@ class ProbabilityResults(ResultsBaseDataclass): """ _expected_fields: list[str] = [ + "hash", "left_id", "right_id", "probability", ] + @field_validator("dataframe") + @classmethod + def add_hash(cls, dataframe: DataFrame) -> DataFrame: + """Adds a hash column to the dataframe if it doesn't already exist.""" + if "hash" not in dataframe.columns: + dataframe[["left_id", "right_id"]] = dataframe[ + ["left_id", "right_id"] + ].astype("binary[pyarrow]") + dataframe["hash"] = columns_to_value_ordered_hash( + data=dataframe, columns=["left_id", "right_id"] + ) + dataframe["hash"] = dataframe["hash"].astype("binary[pyarrow]") + return dataframe[["hash", "left_id", "right_id", "probability"]] + def inspect_with_source( self, left_data: DataFrame, left_key: str, right_data: DataFrame, right_key: str ) -> DataFrame: @@ -130,20 +171,10 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> set[Probability]: backend.validate_hashes(hashes=self.dataframe.left_id.unique().tolist()) backend.validate_hashes(hashes=self.dataframe.right_id.unique().tolist()) - # Preprocess the dataframe - pre_prep_df = self.dataframe[["left_id", "right_id", "probability"]].copy() - pre_prep_df[["left_id", "right_id"]] = pre_prep_df[ - ["left_id", "right_id"] - ].astype("binary[pyarrow]") - pre_prep_df["sha1"] = columns_to_value_ordered_hash( - data=pre_prep_df, columns=["left_id", "right_id"] - ) - pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]") - return { Probability(hash=row[0], left=row[1], right=row[2], probability=row[3]) - for row in pre_prep_df[ - ["sha1", "left_id", "right_id", "probability"] + for row in self.dataframe[ + ["hash", "left_id", "right_id", "probability"] ].to_numpy() } @@ -151,7 +182,8 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> set[Probability]: class ClusterResults(ResultsBaseDataclass): """Cluster data produced by using to_clusters on ProbabilityResults. - Inherits the following attributes from ResultsBaseDataclass. + These are the connected components of the probabilistic matches at every + threshold of probability. The parent is the hash of the sorted children. _expected_fields enforces the shape of the dataframe. @@ -211,134 +243,67 @@ def to_records(self) -> set[Cluster]: class Results(BaseModel): + """A container for the results of a model run. + + Contains all the information any backend will need to store the results. + """ + model_config = ConfigDict(arbitrary_types_allowed=True) - model: Model probabilities: ProbabilityResults clusters: ClusterResults @inject_backend def to_matchbox(self, backend: MatchboxDBAdapter) -> None: """Writes the results to the Matchbox database.""" - self.model.insert_model() - self.model.clusters = self - - -def process_components( - G: rx.PyGraph, - current_pairs: set[bytes], - added: dict[bytes, int], - pair_children: dict[bytes, list[bytes]], -) -> list[tuple[bytes, list[bytes], int]]: - """ - Process connected components in the current graph state. - - Identifies which 2-item components have merged into larger components. 
+ if self.probabilities.model != self.clusters.model: + raise ValueError("Probabilities and clusters must be from the same model.") - Returns: - List of (parent_hash, children, size) for components > 2 items - where children includes both individual items and parent hashes - """ - new_components = [] - component_with_size = [ - (component, len(component)) for component in rx.connected_components(G) - ] - - for component, size in component_with_size: - if size <= 2: - continue - - # Get all node hashes in component - node_hashes = [G.get_node_data(node) for node in component] - - # Find which 2-item parents are part of this component - component_pairs = { - pair - for pair in current_pairs - if any(G.has_node(added[h]) for h in node_hashes) - } - - # Children are individual nodes not in pairs, plus the pair parents - children = component_pairs | { - h - for h in node_hashes - if not any(h in pair_children[p] for p in component_pairs) - } - - parent_hash = list_to_value_ordered_hash(sorted(children)) - new_components.append((parent_hash, list(children))) - - return new_components + self.clusters.model.insert_model() + self.clusters.model.results = self def to_clusters(results: ProbabilityResults) -> ClusterResults: """ - Takes a models probabilistic outputs and turns them into clusters. - - Performs connected components at decreasing thresholds from 1.0 to return every - possible component in a hierarchical tree. - - * Stores all two-item components with their original probabilities - * For larger components, stores the individual items and two-item parent hashes - as children, with a new parent hash - - Args: - results: ProbabilityResults object + Converts probabilities into a list of connected components formed at each threshold. Returns: - ClusterResults object + ClusterResults sorted by threshold descending. """ G = rx.PyGraph() added: dict[bytes, int] = {} - pair_children: dict[bytes, list[bytes]] = {} - current_pairs: set[bytes] = set() - seen_larger: set[bytes] = set() - - clusters = {"parent": [], "child": [], "threshold": []} + components: dict[str, list] = {"parent": [], "child": [], "threshold": []} - # 1. Create all 2-item components with original probabilities - initial_edges = ( - results.dataframe.filter(["left_id", "right_id", "probability"]) + # Sort probabilities descending and process in order of decreasing probability + edges = ( + results.dataframe.sort_values("probability", ascending=False) + .filter(["left_id", "right_id", "probability"]) .astype({"left_id": "binary[pyarrow]", "right_id": "binary[pyarrow]"}) .itertuples(index=False, name=None) ) - for left, right, prob in initial_edges: + for left, right, prob in edges: + # Add nodes if not seen before for hash_val in (left, right): if hash_val not in added: idx = G.add_node(hash_val) added[hash_val] = idx - children = sorted([left, right]) - parent_hash = list_to_value_ordered_hash(children) - - pair_children[parent_hash] = children - current_pairs.add(parent_hash) - - clusters["parent"].extend([parent_hash] * 2) - clusters["child"].extend(children) - clusters["threshold"].extend([prob] * 2) - + # Get state, add edge, get new state, add anything new to results + old_components = {frozenset(comp) for comp in rx.connected_components(G)} G.add_edge(added[left], added[right], None) + new_components = {frozenset(comp) for comp in rx.connected_components(G)} - # 2. 
Process at each probability threshold - sorted_probabilities = sorted( - results.dataframe["probability"].unique(), reverse=True - ) - - for threshold in sorted_probabilities: - # Find new larger components at this threshold - new_components = process_components(G, current_pairs) + for comp in new_components - old_components: + children = sorted([G.get_node_data(n) for n in comp]) + parent = list_to_value_ordered_hash(children) - # Add new components to results - for parent_hash, children in new_components: - if parent_hash not in seen_larger: - seen_larger.add(parent_hash) - clusters["parent"].extend([parent_hash] * len(children)) - clusters["child"].extend(children) - clusters["threshold"].extend([threshold] * len(children)) + components["parent"].extend([parent] * len(children)) + components["child"].extend(children) + components["threshold"].extend([prob] * len(children)) return ClusterResults( - dataframe=DataFrame(clusters).convert_dtypes(dtype_backend="pyarrow"), + dataframe=DataFrame(components).convert_dtypes(dtype_backend="pyarrow"), model=results.model, + metadata=results.metadata, ) diff --git a/src/matchbox/models/linkers/base.py b/src/matchbox/models/linkers/base.py index 4a60fdbe..0b939177 100644 --- a/src/matchbox/models/linkers/base.py +++ b/src/matchbox/models/linkers/base.py @@ -1,12 +1,9 @@ import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Dict from pandas import DataFrame from pydantic import BaseModel, Field, ValidationInfo, field_validator -from matchbox.common.results import ProbabilityResults - class LinkerSettings(BaseModel): """ @@ -51,30 +48,3 @@ def prepare(self, left: DataFrame, right: DataFrame) -> None: @abstractmethod def link(self, left: DataFrame, right: DataFrame) -> DataFrame: return - - -def make_linker( - link_run_name: str, - description: str, - linker: Linker, - linker_settings: Dict[str, Any], - left_data: DataFrame, - left_source: str, - right_data: DataFrame, - right_source: str, -) -> Callable[[DataFrame, DataFrame], ProbabilityResults]: - linker_instance = linker.from_settings(**linker_settings) - linker_instance.prepare(left=left_data, right=right_data) - - def linker( - left_data: DataFrame = left_data, right_data: DataFrame = right_data - ) -> ProbabilityResults: - return ProbabilityResults( - dataframe=linker_instance.link(left=left_data, right=right_data), - run_name=link_run_name, - description=description, - left=left_source, - right=right_source, - ) - - return linker diff --git a/src/matchbox/models/models.py b/src/matchbox/models/models.py index afa99002..c67b0dda 100644 --- a/src/matchbox/models/models.py +++ b/src/matchbox/models/models.py @@ -1,14 +1,13 @@ -from enum import StrEnum from functools import wraps from typing import Any, Callable, ParamSpec, TypeVar -import numpy as np -from pandas import DataFrame, Series -from pydantic import BaseModel +from pandas import DataFrame from matchbox.common.exceptions import MatchboxModelError from matchbox.common.results import ( ClusterResults, + ModelMetadata, + ModelType, ProbabilityResults, Results, to_clusters, @@ -22,23 +21,6 @@ R = TypeVar("R") -class ModelType(StrEnum): - """Enumeration of supported model types.""" - - LINKER = "linker" - DEDUPER = "deduper" - - -class ModelMetadata(BaseModel): - """Metadata for a model.""" - - name: str - description: str - type: ModelType - left_source: str - right_source: str | None = None # Only used for linker models - - def ensure_connection(func: Callable[P, R]) -> Callable[P, R]: """Decorator to ensure 
model connection before method execution.""" @@ -99,61 +81,25 @@ def insert_model(self) -> None: @ensure_connection def probabilities(self) -> ProbabilityResults: """Retrieve probabilities associated with the model from the database.""" - n = len(self._model.probabilities) - - left_arr = np.empty(n, dtype="object") - right_arr = np.empty(n, dtype="object") - prob_arr = np.empty(n, dtype="float64") - - for i, prob in enumerate(self._model.probabilities): - left_arr[i] = prob.left - right_arr[i] = prob.right - prob_arr[i] = prob.probability - - df = DataFrame( - { - "left_id": Series(left_arr, dtype="binary[pyarrow]"), - "right_id": Series(right_arr, dtype="binary[pyarrow]"), - "probability": Series(prob_arr, dtype="float64[pyarrow]"), - } - ) - return ProbabilityResults(dataframe=df, model=self) + return self._model.probabilities @property @ensure_connection def clusters(self) -> ClusterResults: """Retrieve clusters associated with the model from the database.""" - total_rows = sum(len(cluster.children) for cluster in self._model.clusters) - - parent_arr = np.empty(total_rows, dtype="object") - child_arr = np.empty(total_rows, dtype="object") - threshold_arr = np.empty(total_rows, dtype="float64") - - idx = 0 - for cluster in self._model.clusters: - n_children = len(cluster.children) - # Set parent, repeated for each child - parent_arr[idx : idx + n_children] = cluster.parent - # Set children - child_arr[idx : idx + n_children] = cluster.children - # Set threshold, repeated for each child) - threshold_arr[idx : idx + n_children] = cluster.threshold - idx += n_children - - df = DataFrame( - { - "parent": Series(parent_arr, dtype="binary[pyarrow]"), - "child": Series(child_arr, dtype="binary[pyarrow]"), - "threshold": Series(threshold_arr, dtype="float64[pyarrow]"), - } - ) - return ClusterResults(dataframe=df, model=self) + return self._model.clusters + + @property + @ensure_connection + def results(self) -> Results: + """Retrieve results associated with the model from the database.""" + return self._model.results - @clusters.setter + @results.setter @ensure_connection - def clusters(self, clusters: ClusterResults) -> None: - """Insert clusters associated with the model into the backend database.""" - self._model.clusters = clusters.to_records(backend=self._backend) + def results(self, results: Results) -> None: + """Write results associated with the model to the database.""" + self._model.results = results @property @ensure_connection @@ -200,9 +146,7 @@ def calculate_probabilities(self) -> ProbabilityResults: return ProbabilityResults( dataframe=results, model=self, - description=self.metadata.description, - left=self.metadata.left_source, - right=self.metadata.right_source or self.metadata.left_source, + metadata=self.metadata, ) def calculate_clusters(self, probabilities: ProbabilityResults) -> ClusterResults: diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index 0c1ef4da..34e68827 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -19,7 +19,8 @@ from rustworkx import PyDiGraph from sqlalchemy import Engine -from matchbox.server.models import Cluster, Probability, Source +from matchbox.common.results import ClusterResults, ProbabilityResults, Results +from matchbox.server.models import Source if TYPE_CHECKING: from pandas import DataFrame as PandasDataFrame @@ -170,22 +171,33 @@ class ListableAndCountable(Countable, Listable): class MatchboxModelAdapter(ABC): - """An abstract base class for Matchbox model adapters.""" + """An abstract 
base class for Matchbox model adapters. + + Must be able to recover probabilities and clusters from the database, + but ultimately doesn't care how they're stored. + + Creates these with the pairwise probabilities and the connected components + of those pairs calculated at every threshold. + """ hash: bytes name: str @property @abstractmethod - def probabilities(self) -> set[Probability]: ... + def probabilities(self) -> ProbabilityResults: ... + + @property + @abstractmethod + def clusters(self) -> ClusterResults: ... @property @abstractmethod - def clusters(self) -> set[Cluster]: ... + def results(self) -> Results: ... - @clusters.setter + @results.setter @abstractmethod - def clusters(self, clusters: set[Cluster]) -> None: ... + def results(self, results: Results) -> None: ... @property @abstractmethod diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 8fe295b4..f3c7405b 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, ConfigDict from rustworkx import PyDiGraph -from sqlalchemy import Engine, and_, bindparam, func +from sqlalchemy import Engine, and_, bindparam, func, select from sqlalchemy.orm import Session from matchbox.common.exceptions import ( @@ -10,12 +10,15 @@ MatchboxDatasetError, MatchboxModelError, ) +from matchbox.common.results import ClusterResults, ProbabilityResults, Results from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter -from matchbox.server.models import Cluster, Probability, Source, SourceWarehouse +from matchbox.server.models import Source, SourceWarehouse from matchbox.server.postgresql.db import MBDB, MatchboxPostgresSettings from matchbox.server.postgresql.orm import ( Clusters, + Contains, Models, + ModelsFrom, Probabilities, Sources, ) @@ -23,8 +26,13 @@ from matchbox.server.postgresql.utils.insert import ( insert_dataset, insert_model, + insert_results, +) +from matchbox.server.postgresql.utils.query import query +from matchbox.server.postgresql.utils.results import ( + get_model_clusters, + get_model_probabilities, ) -from matchbox.server.postgresql.utils.query import get_model_probabilities, query if TYPE_CHECKING: from pandas import DataFrame as PandasDataFrame @@ -36,24 +44,6 @@ ArrowTable = Any -# TODO: Implement cluster getter/setter -# As part of this need to implement insert_clusters - -# TODO: Filtered classes will no longer work, rethink 'em - -# TODO: Redo ancestor cache attribute -- now works different in the ORM -# Double check I updated this in insert_model, pretty sure I did - -# TODO: At last, can rewrite the query function to use the new structures -# 1. For each dataset in the selector -# a. Get the model tree (now very easy) -# b. Resolve any threshold discrepancies -# c. Filter Clusters and Contains by the model tree and thresholds -# d. Recurse down the Clusters and Contains to get the ultimate hash per record -# e. Join this to the actual dataset -# 2. 
Stack 'em and return - - class FilteredClusters(BaseModel): """Wrapper class for filtered cluster queries""" @@ -76,15 +66,11 @@ class FilteredProbabilities(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) over_truth: bool = False - mb_model: Models | None = None def count(self) -> int: with MBDB.get_session() as session: query = session.query(func.count()).select_from(Probabilities) - if self.mb_model is not None: - query = query.filter(Probabilities.model == self.mb_model.hash) - if self.over_truth: query = query.join(Models, Probabilities.model == Models.hash).filter( and_( @@ -111,19 +97,24 @@ def name(self) -> str: return self.model.name @property - def probabilities(self) -> set[Probability]: + def probabilities(self) -> ProbabilityResults: """Retrieve probabilities for this model.""" - return get_model_probabilities(engine=MBDB.get_engine(), model=self.model.hash) + return get_model_probabilities(engine=MBDB.get_engine(), model=self.model) @property - def clusters(self) -> set[Cluster]: + def clusters(self) -> ClusterResults: """Retrieve clusters for this model.""" - pass + return get_model_clusters(engine=MBDB.get_engine(), model=self.model) + + @property + def results(self) -> Results: + """Retrieve results for this model.""" + return Results(probabilities=self.probabilities, clusters=self.clusters) - @clusters.setter - def clusters(self, clusters: set[Cluster]) -> None: - """Insert clusters for this model, which will also insert probabilities.""" - pass + @results.setter + def results(self, results: Results) -> None: + """Inserts results for this model.""" + insert_results(results=results, model=self.model) @property def truth(self) -> float: @@ -139,7 +130,7 @@ def truth(self, truth: float) -> None: session.commit() @property - def ancestors(self) -> dict[str, float]: + def ancestors(self) -> dict[str, float | None]: """ Gets the current truth values of all ancestors. Returns a dict mapping model names to their current truth thresholds. @@ -147,70 +138,59 @@ def ancestors(self) -> dict[str, float]: Unlike ancestors_cache which returns cached values, this property returns the current truth values of all ancestor models. """ - with Session(MBDB.get_engine()) as session: - if not self.model.ancestors: - return {} - - ancestor_models = ( - session.query(Models) - .filter(Models.hash.in_(self.model.ancestors)) - .all() - ) - - return { - model.name: float(model.truth) - for model in ancestor_models - if model.truth is not None - } + return {model.name: model.truth for model in self.model.ancestors} @property def ancestors_cache(self) -> dict[str, float]: """ Gets the cached ancestor thresholds, converting hashes to model names. - Returns a dict mapping model names to their truth thresholds. + + Returns a dictionary mapping model names to their truth thresholds. This is required because each point of truth needs to be stable, so we choose when to update it, caching the ancestor's values in the model itself. 
""" with Session(MBDB.get_engine()) as session: - stored_ancestors = self.model.ancestors_cache or {} - ancestor_models = ( - session.query(Models.hash, Models.name) - .filter( - Models.hash.in_([bytes.fromhex(h) for h in stored_ancestors.keys()]) - ) - .all() + query = ( + select(Models.name, ModelsFrom.truth_cache) + .join(Models, Models.hash == ModelsFrom.parent) + .where(ModelsFrom.child == self.model.hash) + .where(ModelsFrom.truth_cache.isnot(None)) ) - hash_to_name = {m.hash.hex(): m.name for m in ancestor_models} return { - hash_to_name[hash_str]: float(threshold) - for hash_str, threshold in stored_ancestors.items() + name: truth_cache for name, truth_cache in session.execute(query).all() } @ancestors_cache.setter def ancestors_cache(self, new_values: dict[str, float]) -> None: """ Updates the cached ancestor thresholds. - Takes a dict mapping model names to their truth thresholds. - Only updates the float values, preserving the existing hash structure. + + Args: + new_values: Dictionary mapping model names to their truth thresholds """ + with Session(MBDB.get_engine()) as session: - name_to_hash = { - m.name: m.hash.hex() - for m in session.query(Models) - .filter(Models.name.in_(new_values.keys())) + model_names = list(new_values.keys()) + name_to_hash = dict( + session.query(Models.name, Models.hash) + .filter(Models.name.in_(model_names)) .all() - } + ) - # Update only the values in the existing JSON structure - current = self.model.ancestors_cache or {} - for name, threshold in new_values.items(): - if hash_str := name_to_hash.get(name): - current[hash_str] = threshold + for model_name, truth_value in new_values.items(): + parent_hash = name_to_hash.get(model_name) + if parent_hash is None: + raise ValueError(f"Model '{model_name}' not found in database") + + session.execute( + ModelsFrom.__table__.update() + .where(ModelsFrom.parent == parent_hash) + .where(ModelsFrom.child == self.model.hash) + .values(truth_cache=truth_value) + ) - self.model.ancestors_cache = current - session.add(self.model) session.commit() @classmethod @@ -234,8 +214,8 @@ def __init__(self, settings: MatchboxPostgresSettings): self.models = Models self.data = FilteredClusters(has_dataset=True) self.clusters = FilteredClusters(has_dataset=False) + self.merges = Contains self.creates = FilteredProbabilities(over_truth=True) - self.merges = FilteredProbabilities() self.proposes = FilteredProbabilities() def query( @@ -399,10 +379,22 @@ def insert_model( MatchboxDataError if, for a linker, the source models weren't found in the database """ + with Session(MBDB.get_engine()) as session: + left_model = session.query(Models).filter(Models.name == left).first() + if not left_model: + raise MatchboxModelError(model_name=left) + + # Overwritten with actual right model if in a link job + right_model = left_model + if right: + right_model = session.query(Models).filter(Models.name == right).first() + if not right_model: + raise MatchboxModelError(model_name=right) + insert_model( model=model, - left=left, - right=right, + left=left_model, + right=right_model, description=description, engine=MBDB.get_engine(), ) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 98fb78e4..9cf683a6 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -100,7 +100,7 @@ def _batches( def batch_ingest( - records: list[tuple], + records: list[tuple[Any]], table: DeclarativeMeta, conn: Connection, batch_size: int, diff --git 
a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index 4063b6c8..fb93de01 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -8,9 +8,9 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session -from matchbox.common.exceptions import MatchboxModelError from matchbox.common.hash import dataset_to_hashlist, list_to_value_ordered_hash -from matchbox.server.models import Probability, Source +from matchbox.common.results import ClusterResults, Results +from matchbox.server.models import Source from matchbox.server.postgresql.orm import ( Clusters, Contains, @@ -102,16 +102,20 @@ def insert_dataset(dataset: Source, engine: Engine, batch_size: int) -> None: def insert_model( - model: str, left: str, description: str, engine: Engine, right: str | None = None + model: str, + left: Models, + right: Models, + description: str, + engine: Engine, ) -> None: """Writes a model to Matchbox with a default truth value of 1.0. Args: model: Name of the new model left: Name of the left parent model + right: Name of the right parent model. Same as left in a dedupe job description: Model description engine: SQLAlchemy engine instance - right: Optional name of the right parent model Raises: MatchboxModelError if the specified parent models don't exist. @@ -121,18 +125,7 @@ """ logic_logger.info(f"[{model}] Registering model") with Session(engine) as session: - left_model = session.query(Models).filter(Models.name == left).first() - if not left_model: - raise MatchboxModelError(model_name=left) - - # Overwritten with actual right model if in a link job - right_model = left_model - if right: - right_model = session.query(Models).filter(Models.name == right).first() - if not right_model: - raise MatchboxModelError(model_name=right) - - model_hash = list_to_value_ordered_hash([left_model.hash, right_model.hash]) + model_hash = list_to_value_ordered_hash([left.hash, right.hash]) # Create new model new_model = Models( @@ -145,7 +138,7 @@ session.add(new_model) session.flush() - def create_closure_entries(parent_model: Models) -> None: + def _create_closure_entries(parent_model: Models) -> None: """Create closure entries for the new model.""" session.add( ModelsFrom( @@ -173,63 +166,152 @@ def create_closure_entries(parent_model: Models) -> None: ) # Create model lineage entries - create_closure_entries(left_model) + _create_closure_entries(parent_model=left) - if right_model != left_model: - create_closure_entries(right_model) + if right != left: + _create_closure_entries(parent_model=right) session.commit() logic_logger.info(f"[{model}] Done!") -def insert_probabilities( - model: str, +def _cluster_results_to_hierarchical( + clusters: ClusterResults, +) -> list[tuple[bytes, bytes, float]]: + """ + Converts a ClusterResults object to a more efficient hierarchical structure for PostgreSQL. + + * Two-item components are given a threshold of their original pairwise probability + * Larger components are stored in a hierarchical structure, where if their children + are a known component at a higher threshold, they reference that component + + This allows all results to be recovered from the database, albeit inefficiently, + but allows simple and efficient querying of clusters at any threshold. 
+ + This function requires that: + + * ClusterResults are sorted by threshold descending + * Two-item components' thresholds are the original pairwise probabilities + + Args: + clusters: ClusterResults object to convert + + Returns: + A list of (parent, child, threshold) tuples ready for insertion + """ + parents = [] + children = [] + thresholds = [] + + # hash -> (threshold, is_component) + component_info: dict[bytes, tuple[float, bool]] = {} + + # Process components in descending threshold order, relying on the pre-sorted input + for threshold, group in clusters.dataframe.groupby("threshold", sort=False): + current_components = set() + + # Process all parents at this threshold at once + for parent, parent_children in group.groupby("parent")["child"]: + child_hashes = frozenset(parent_children) + + # Partition children into original and subcomponents + original = [] + subcomponents = [] + + for child in child_hashes: + if child in component_info: + prev_threshold, is_comp = component_info[child] + if prev_threshold >= threshold and is_comp: + subcomponents.append(child) + continue + original.append(child) + + parents.extend([parent] * len(original)) + children.extend(original) + thresholds.extend([threshold] * len(original)) + + parents.extend([parent] * len(subcomponents)) + children.extend(subcomponents) + thresholds.extend([threshold] * len(subcomponents)) + + # Mark this parent as a component + component_info[parent] = (threshold, True) + current_components.add(parent) + + # Mark original children as non-components at this threshold + for child in original: + if child not in component_info: + component_info[child] = (threshold, False) + + return list(zip(parents, children, thresholds, strict=True)) + + +def insert_results( + model: Models, engine: Engine, - probabilities: list[Probability], + results: Results, batch_size: int, ) -> None: """ - Writes probabilities and their associated clusters to Matchbox. + Writes a Results object to Matchbox. + + The PostgreSQL backend stores clusters in a hierarchical structure, where + each component references its child components at equal or higher thresholds. + + This means two-item components are stored with their original pairwise + probabilities as their thresholds. + + This allows easy querying of clusters at any threshold. Args: - model: Name of the model to associate probabilities with + model: Model object to associate results with engine: SQLAlchemy engine instance - probabilities: List of Probability objects to insert + results: A Results object batch_size: Number of records to insert in each batch Raises: MatchboxModelError if the specified model doesn't exist. 
""" - logic_logger.info(f"{model} Writing probability data with batch size {batch_size}") + logic_logger.info( + f"[{model.name}] Writing results data with batch size {batch_size}" + ) with Session(engine) as session: - db_model = session.query(Models).filter_by(name=model).first() - if db_model is None: - raise MatchboxModelError(model_name=model) - - model_hash = db_model.hash - try: # Clear existing probabilities for this model session.execute( - delete(Probabilities).where(Probabilities.model == model_hash) + delete(Probabilities).where(Probabilities.model == model.hash) ) session.commit() - logic_logger.info(f"{model} Removed old probabilities") + logic_logger.info(f"[{model.name}] Removed old probabilities") except SQLAlchemyError as e: session.rollback() - logic_logger.error(f"{model} Failed to clear old probabilities: {str(e)}") + logic_logger.error( + f"[{model.name}] Failed to clear old probabilities: {str(e)}" + ) raise with engine.connect() as conn: try: - total_records = len(probabilities) - logic_logger.info(f"{model} Inserting {total_records} probability objects") + total_records = results.clusters.dataframe.shape[0] + logic_logger.info( + f"[{model.name}] Inserting {total_records} probability objects" + ) + + cluster_records: list[tuple[bytes, None, None]] = [] + contains_records: list[tuple[bytes, bytes]] = [] + probability_records: list[tuple[bytes, bytes, float]] = [] - cluster_records = [(prob.hash, None, None) for prob in probabilities] + for parent, child, threshold in _cluster_results_to_hierarchical( + clusters=results.clusters + ): + cluster_records.append((parent, None, None)) + contains_records.append((parent, child)) + probability_records.append((model.hash, parent, threshold)) batch_ingest( records=cluster_records, @@ -238,14 +320,10 @@ def insert_probabilities( batch_size=batch_size, ) - contains_records = [] - for prob in probabilities: - contains_records.extend( - [ - (prob.hash, prob.left), - (prob.hash, prob.right), - ] - ) + logic_logger.info( + f"[{model.name}] Successfully inserted {len(cluster_records)} " + "objects into Clusters table" + ) batch_ingest( records=contains_records, @@ -254,9 +332,10 @@ def insert_probabilities( batch_size=batch_size, ) - probability_records = [ - (model_hash, prob.hash, prob.probability) for prob in probabilities - ] + logic_logger.info( + f"[{model.name}] Successfully inserted {len(contains_records)} " + "objects into Contains table" + ) batch_ingest( records=probability_records, @@ -266,12 +345,12 @@ def insert_probabilities( ) logic_logger.info( - f"{model} Successfully inserted {total_records} " - "probability objects and their associated clusters" + f"[{model.name}] Successfully inserted {len(probability_records)} " + "objects into Probabilities table" ) except SQLAlchemyError as e: - logic_logger.error(f"{model} Failed to insert data: {str(e)}") + logic_logger.error(f"[{model.name}] Failed to insert data: {str(e)}") raise - logic_logger.info(f"{model} Insert operation complete!") + logic_logger.info(f"[{model.name}] Insert operation complete!") diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 30cd137c..8727a23a 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -4,7 +4,7 @@ import pyarrow as pa from pandas import ArrowDtype, DataFrame from sqlalchemy import Engine, and_, func, select -from sqlalchemy.orm import Session, aliased +from sqlalchemy.orm import Session from 
sqlalchemy.sql.selectable import Select from matchbox.common.db import sql_to_df @@ -12,7 +12,7 @@ MatchboxDatasetError, MatchboxModelError, ) -from matchbox.server.models import Probability, Source +from matchbox.server.models import Source from matchbox.server.postgresql.orm import ( Clusters, Contains, @@ -42,6 +42,16 @@ def key_to_sqlalchemy_label(key: str, source: Source) -> str: return f"{source.db_schema}_{source.db_table}_{key}" +# TODO: At last, can rewrite the query function to use the new structures +# 1. For each dataset in the selector +# a. Get the model tree (now very easy) +# b. Resolve any threshold discrepancies +# c. Filter Clusters and Contains by the model tree and thresholds +# d. Recurse down the Clusters and Contains to get the ultimate hash per record +# e. Join this to the actual dataset +# 2. Stack 'em and return + + def _get_threshold_for_model( model: Models, model_hash: bytes, @@ -323,97 +333,3 @@ def query( ) else: ValueError(f"return_type of {return_type} not valid") - - -def get_model_probabilities(engine: Engine, model_hash: bytes) -> set[Probability]: - """ - Get all original probabilities proposed by a model. - - These are identified by: - - * Exactly two children - * Both children are leaf nodes (not parents in Contains table) - - Orders children consistently based on their origin (dataset or proposing model). - - Args: - model_hash: Hash of the model to query - - Returns: - Set of Probability objects - """ - with Session(engine) as session: - Child = aliased(Clusters) - ChildProb = aliased(Probabilities) - - # Subquery to find clusters that are parents in Contains - parent_clusters = (select(Contains.parent).distinct()).scalar_subquery() - - # Main query - query = ( - select( - Clusters.hash, - Child.hash.label("child_hash"), - Child.dataset.label("source_dataset"), - ChildProb.model.label("source_model"), - Probabilities.probability, - ) - .join( - Probabilities, - and_( - Probabilities.cluster == Clusters.hash, - Probabilities.model == model_hash, - ), - ) - # Join to get children - .join(Contains, Contains.parent == Clusters.hash) - .join(Child, Child.hash == Contains.child) - # Left join to get potential source model - .outerjoin( - ChildProb, - and_(ChildProb.cluster == Child.hash), - ) - # Ensure children are leaf nodes - .where(~Child.hash.in_(parent_clusters)) - # Only get clusters with exactly two children - .having(func.count(Contains.child) == 2) - .group_by( - Clusters.hash, - Child.hash, - Child.dataset, - ChildProb.model, - Probabilities.probability, - ) - ) - - rows = session.execute(query).all() - probabilities: set[Probability] = set() - - # Group by parent hash to pair children - grouped = {} - for row in rows: - if row.hash not in grouped: - grouped[row.hash] = [] - source_hash = ( - row.source_dataset - if row.source_dataset is not None - else row.source_model - ) - grouped[row.hash].append((row.child_hash, source_hash)) - - for parent_hash, children in grouped.items(): - if len(children) == 2: - raise ValueError("Expected exactly two children") - - (left, _), (right, _) = sorted(children, key=lambda x: x[1]) - - probabilities.add( - Probability( - hash=parent_hash, - left=left, - right=right, - probability=rows[0].probability, # Same for both rows - ) - ) - - return probabilities diff --git a/src/matchbox/server/postgresql/utils/results.py b/src/matchbox/server/postgresql/utils/results.py new file mode 100644 index 00000000..800c597c --- /dev/null +++ b/src/matchbox/server/postgresql/utils/results.py @@ -0,0 +1,273 @@ +from typing 
import NamedTuple + +import pandas as pd +import pyarrow as pa +from sqlalchemy import Engine, and_, exists, func, select +from sqlalchemy import text as sqltext + +from matchbox.common.db import sql_to_df +from matchbox.common.results import ( + ClusterResults, + ModelMetadata, + ModelType, + ProbabilityResults, +) +from matchbox.server.postgresql.orm import ( + Clusters, + Contains, + Models, + ModelsFrom, + Probabilities, +) + + +class SourceInfo(NamedTuple): + """Information about a model's sources.""" + + left: Models + right: Models | None + left_ancestors: set[bytes] + right_ancestors: set[bytes] | None + + +def _get_model_parents(engine: Engine, model: Models) -> tuple[Models, Models | None]: + """Get the model's immediate parent models.""" + parent_query = ( + select(Models) + .join(ModelsFrom, Models.hash == ModelsFrom.parent) + .where(ModelsFrom.child == model.hash) + .where(ModelsFrom.level == 1) + ) + + with engine.connect() as conn: + parents = conn.execute(parent_query).fetchall() + + if len(parents) == 1: + return parents[0], None + elif len(parents) == 2: + p1, p2 = parents + # Put dataset first if it exists + if p1.type == "dataset": + return p1, p2 + elif p2.type == "dataset": + return p2, p1 + # Both models, maintain original order + return p1, p2 + else: + raise ValueError( + f"Model {model.name} has unexpected number of parents: {len(parents)}" + ) + + +def _get_source_info(engine: Engine, model: Models) -> SourceInfo: + """Get source models and their ancestry information.""" + left_source, right_source = _get_model_parents(engine, model) + + # Get ancestor sets including the sources themselves + left_ancestors = {m.hash for m in left_source.ancestors} + left_ancestors.add(left_source.hash) + + right_ancestors = None + if right_source: + right_ancestors = {m.hash for m in right_source.ancestors} + right_ancestors.add(right_source.hash) + + return SourceInfo(left_source, right_source, left_ancestors, right_ancestors) + + +def _get_leaf_pair_clusters(engine: Engine, model: Models) -> list[tuple]: + """Get all clusters with exactly two leaf children.""" + # Subquery to identify leaf nodes + leaf_nodes = ~exists().where(Contains.parent == Clusters.hash) + + query = ( + select( + Clusters.hash.label("parent_hash"), + Probabilities.probability, + func.array_agg(Clusters.hash).label("child_hashes"), + func.array_agg(Clusters.dataset).label("child_datasets"), + func.array_agg(Clusters.id).label("child_ids"), + ) + .join( + Probabilities, + and_( + Probabilities.cluster == Clusters.hash, + Probabilities.model == model.hash, + ), + ) + .join(Contains, Contains.parent == Clusters.hash) + .join(Clusters.children) + .where(leaf_nodes) + .group_by(Clusters.hash, Probabilities.probability) + .having(func.count() == 2) + ) + + with engine.connect() as conn: + return conn.execute(query).fetchall() + + +def _determine_hash_order( + engine: Engine, + hashes: list[bytes], + datasets: list[bytes], + left_source: Models, + left_ancestors: set[bytes], +) -> tuple[int, int]: + """Determine which child corresponds to left/right source.""" + # Check dataset assignment first + if datasets[0] == left_source.hash: + return 0, 1 + elif datasets[1] == left_source.hash: + return 1, 0 + + # Check probability ancestry + left_prob_query = ( + select(Probabilities) + .where(Probabilities.cluster == hashes[0]) + .where(Probabilities.model.in_(left_ancestors)) + ) + with engine.connect() as conn: + has_left_prob = conn.execute(left_prob_query).fetchone() is not None + + return (0, 1) if has_left_prob else 
(1, 0) + + +def get_model_probabilities(engine: Engine, model: Models) -> ProbabilityResults: + """ + Recover the model's ProbabilityResults. + + Probabilities are the model's Clusters identified by: + + * Exactly two children + * Both children are leaf nodes (not parents in Contains table) + + Args: + engine: SQLAlchemy engine + model: Model instance to query + + Returns: + A ProbabilityResults object containing pairwise probabilities and model metadata + """ + source_info: SourceInfo = _get_source_info(engine=engine, model=model) + + metadata = ModelMetadata( + name=model.name, + description=model.description or "", + type=ModelType.DEDUPER if source_info.right is None else ModelType.LINKER, + left_source=source_info.left.name, + right_source=source_info.right.name if source_info.right else None, + ) + + results = _get_leaf_pair_clusters(engine=engine, model=model) + + # Process results into pairs + rows: dict[str, list] = { + "hash": [], + "left_id": [], + "right_id": [], + "probability": [], + } + for parent_hash, prob, child_hashes, child_datasets, child_ids in results: + if metadata.type == ModelType.LINKER: + left_idx, right_idx = _determine_hash_order( + engine=engine, + hashes=child_hashes, + datasets=child_datasets, + left_source=source_info.left, + left_ancestors=source_info.left_ancestors, + ) + else: + # For dedupers, order doesn't matter + left_idx, right_idx = 0, 1 + + rows["hash"].append(parent_hash) + rows["left_id"].append(child_ids[left_idx][0]) + rows["right_id"].append(child_ids[right_idx][0]) + rows["probability"].append(prob) + + return ProbabilityResults( + dataframe=pd.DataFrame( + { + "hash": pd.Series(rows["hash"], dtype=pd.ArrowDtype(pa.binary())), + "left_id": pd.Series(rows["left_id"], dtype=pd.ArrowDtype(pa.binary())), + "right_id": pd.Series( + rows["right_id"], dtype=pd.ArrowDtype(pa.binary()) + ), + "probability": pd.Series( + rows["probability"], dtype=pd.ArrowDtype(pa.float32()) + ), + } + ), + metadata=metadata, + ) + + +def get_model_clusters(engine: Engine, model: Models) -> ClusterResults: + """ + Recover the model's Clusters. + + Clusters are the connected components of the model at every threshold. + + While they're stored in a hierarchical structure, we need to recover the + original components, where all child hashes are leaf Clusters. 
+ + Args: + engine: SQLAlchemy engine + model: Model instance to query + + Returns: + A ClusterResults object containing connected components and model metadata + """ + # Get model metadata + source_info = _get_source_info(engine=engine, model=model) + metadata = ModelMetadata( + name=model.name, + description=model.description or "", + type=ModelType.DEDUPER if source_info.right is None else ModelType.LINKER, + left_source=source_info.left.name, + right_source=source_info.right.name if source_info.right else None, + ) + + # Subquery to identify leaf nodes (clusters with no children) + leaf_nodes = ~exists().where(Contains.parent == Clusters.hash) + + # Recursive CTE to get all descendants + descendants = select( + Contains.parent.label("component"), + Contains.child.label("descendant"), + sqltext("1").label("depth"), + ).cte(recursive=True) + + descendants_recursive = descendants.union_all( + select( + descendants.c.component, + Contains.child.label("descendant"), + descendants.c.depth + 1, + ).join(Contains, Contains.parent == descendants.c.descendant) + ) + + # Final query to get all components with their leaf descendants + components_query = ( + select( + Clusters.hash.label("parent"), + descendants_recursive.c.descendant.label("child"), + Probabilities.probability.label("threshold"), + ) + .join( + Probabilities, + and_( + Probabilities.cluster == Clusters.hash, + Probabilities.model == model.hash, + ), + ) + .join(descendants_recursive, descendants_recursive.c.component == Clusters.hash) + .join(Clusters.children) + .where(leaf_nodes) + .order_by(Probabilities.probability.desc()) + .distinct() + ) + + return ClusterResults( + dataframe=sql_to_df(stmt=components_query, engine=engine, return_type="pandas"), + metadata=metadata, + ) diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 7cee7c67..31cfa256 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -1,4 +1,3 @@ -import random from typing import Callable import pytest @@ -10,9 +9,10 @@ MatchboxModelError, ) from matchbox.common.hash import HASH_FUNC +from matchbox.common.results import ClusterResults, ProbabilityResults, Results from matchbox.helpers.selector import query, selector, selectors from matchbox.server.base import MatchboxModelAdapter -from matchbox.server.models import Cluster, Probability, Source +from matchbox.server.models import Source from pandas import DataFrame from ..fixtures.models import ( @@ -71,8 +71,12 @@ def test_model_properties(self): naive_crn = self.backend.get_model(model="naive_test.crn") assert naive_crn.hash assert naive_crn.name - assert isinstance(naive_crn.probabilities.count(), int) - assert isinstance(naive_crn.clusters.count(), int) + assert naive_crn.probabilities + assert naive_crn.clusters + assert naive_crn.results + assert naive_crn.truth + assert naive_crn.ancestors + assert naive_crn.ancestors_cache def test_validate_hashes(self): """Test validating data hashes.""" @@ -210,191 +214,131 @@ def test_insert_model(self): assert self.backend.models.count() == model_count + 3 - def test_model_insert_probabilities(self): - """Test that model insert probabilities are correct.""" + def test_model_get_probabilities(self): + """Test that a model's ProbabilityResults can be retrieved.""" self.setup_database("dedupe") + naive_crn = self.backend.get_model(model="naive_test.crn") + assert isinstance(naive_crn.probabilities, ProbabilityResults) + assert len(naive_crn.probabilities.dataframe) > 0 + assert naive_crn.probabilities.metadata.name == 
"naive_test.crn" - crn = self.warehouse_data[0] - select_crn = selector( - table=str(crn), - fields=["crn"], - engine=crn.database.engine, - ) - df_crn = query( - selector=select_crn, - backend=self.backend, - model=None, - return_type="pandas", - ) + def test_model_get_clusters(self): + """Test that a model's ClusterResults can be retrieved.""" + self.setup_database("dedupe") + naive_crn = self.backend.get_model(model="naive_test.crn") + assert isinstance(naive_crn.clusters, ClusterResults) + assert len(naive_crn.clusters.dataframe) > 0 + assert naive_crn.clusters.metadata.name == "naive_test.crn" - duns = self.warehouse_data[1] - select_duns = selector( - table=str(duns), - fields=["id", "duns"], - engine=duns.database.engine, - ) - df_crn_deduped = query( - selector=select_crn, - backend=self.backend, - model="naive_test.crn", - return_type="pandas", - ) - df_duns_deduped = query( - selector=select_duns, - backend=self.backend, - model="naive_test.duns", - return_type="pandas", - ) + def test_model_truth(self): + """Test that a model's truth can be set and retrieved.""" + self.setup_database("dedupe") + naive_crn = self.backend.get_model(model="naive_test.crn") + # Retrieve + pre_truth = naive_crn.truth - self.backend.insert_model( - "dedupe_1", left=str(crn), description="Test deduper 1" - ) - self.backend.insert_model( - "dedupe_2", left=str(duns), description="Test deduper 1" - ) - self.backend.insert_model( - "link_1", left="dedupe_1", right="dedupe_2", description="Test linker 1" - ) + # Set + naive_crn.truth = 0.5 - # Test dedupe probabilities - dedupe_probabilities = [ - Probability( - hash=HASH_FUNC(random.randbytes(32)).digest(), - left=crn_prob_1, - right=crn_prob_2, - probability=1.0, - ) - for crn_prob_1, crn_prob_2 in zip( - df_crn["hash"].to_list()[:10], - reversed(df_crn["hash"].to_list()[:10]), - strict=True, - ) - ] - dedupe_1 = self.backend.get_model(model="dedupe_1") + # Retrieve again + post_truth = naive_crn.truth - assert dedupe_1.probabilities.count() == 0 + # Check difference + assert pre_truth != post_truth - dedupe_1.insert_probabilities( - probabilities=dedupe_probabilities, - probability_type="deduplications", - batch_size=10, - ) + def test_model_ancestors(self): + """Test that a model's ancestors can be retrieved.""" + self.setup_database("link") + linker_name = "deterministic_naive_test.crn_naive_test.duns" + linker = self.backend.get_model(model=linker_name) - assert dedupe_1.probabilities.count() == 10 + assert isinstance(linker.ancestors, dict) - # Test link probabilities - link_probabilities = [ - Probability( - hash=HASH_FUNC(random.randbytes(32)).digest(), - left=crn_prob, - right=duns_prob, - probability=1.0, - ) - for crn_prob, duns_prob in zip( - df_crn_deduped["hash"].to_list()[:10], - df_duns_deduped["hash"].to_list()[:10], - strict=True, - ) - ] - link_1 = self.backend.get_model(model="link_1") + truth_found = False + for model, truth in linker.ancestors.items(): + if isinstance(truth, float): + # Not all ancestors have truth values, but one must + truth_found = True + assert isinstance(model, str) + assert isinstance(truth, float or None) - assert link_1.probabilities.count() == 0 + assert truth_found - link_1.insert_probabilities( - probabilities=link_probabilities, - probability_type="links", - batch_size=10, - ) + def test_model_results(self): + """Test that a model's Results can be set and retrieved.""" + self.setup_database("dedupe") + naive_crn = self.backend.get_model(model="naive_test.crn") - assert link_1.probabilities.count() == 10 + 
# Retrieve + pre_results = naive_crn.results - def test_model_insert_clusters(self): - """Test that model insert clusters are correct.""" - self.setup_database("dedupe") + assert isinstance(pre_results, Results) + assert len(pre_results.probabilities.dataframe) > 0 + assert pre_results.probabilities.metadata.name == "naive_test.crn" + assert len(pre_results.clusters.dataframe) > 0 + assert pre_results.clusters.metadata.name == "naive_test.crn" - crn = self.warehouse_data[0] - select_crn = selector( - table=str(crn), - fields=["crn"], - engine=crn.database.engine, - ) - df_crn = query( - selector=select_crn, - backend=self.backend, - model=None, - return_type="pandas", - ) + # Set + hash_to_remove = pre_results.probabilities.dataframe["hash"].iloc[0] + df_probabilities_truncated = pre_results.probabilities.dataframe[ + pre_results.probabilities.dataframe["hash"] != hash_to_remove + ] + df_clusters_truncated = pre_results.clusters.dataframe[ + pre_results.clusters.dataframe["parent"] != hash_to_remove + ] - duns = self.warehouse_data[1] - select_duns = selector( - table=str(duns), - fields=["id", "duns"], - engine=duns.database.engine, - ) - df_crn_deduped = query( - selector=select_crn, - backend=self.backend, - model="naive_test.crn", - return_type="pandas", - ) - df_duns_deduped = query( - selector=select_duns, - backend=self.backend, - model="naive_test.duns", - return_type="pandas", + results = Results( + probabilities=ProbabilityResults( + dataframe=df_probabilities_truncated, + model=pre_results.probabilities.model, + metadata=pre_results.probabilities.metadata, + ), + clusters=ClusterResults( + dataframe=df_clusters_truncated, + model=pre_results.clusters.model, + metadata=pre_results.clusters.metadata, + ), ) - self.backend.insert_model( - "dedupe_1", left=str(crn), description="Test deduper 1" - ) - self.backend.insert_model( - "dedupe_2", left=str(duns), description="Test deduper 1" - ) - self.backend.insert_model( - "link_1", left="dedupe_1", right="dedupe_2", description="Test linker 1" - ) + naive_crn.results = results - # Test dedupe clusters - dedupe_clusters = [ - Cluster(parent=crn_prob_1, child=crn_prob_2) - for crn_prob_1, crn_prob_2 in zip( - df_crn["hash"].to_list()[:10], - reversed(df_crn["hash"].to_list()[:10]), - strict=True, - ) - ] - dedupe_1 = self.backend.get_model(model="dedupe_1") + # Retrieve again + post_results = naive_crn.results - assert dedupe_1.clusters.count() == 0 + # Check difference + assert len(pre_results.probabilities.dataframe) != len( + post_results.probabilities.dataframe + ) + assert len(pre_results.clusters.dataframe) != len( + post_results.clusters.dataframe + ) - dedupe_1.insert_clusters( - clusters=dedupe_clusters, - cluster_type="deduplications", - batch_size=10, + # Check similarity + assert ( + pre_results.probabilities.metadata.name + == post_results.probabilities.metadata.name + ) + assert pre_results.clusters.metadata.name == post_results.clusters.metadata.name - assert dedupe_1.clusters.count() == 10 + def test_model_ancestors_cache(self): + """Test that a model's ancestors cache can be set and retrieved.""" + self.setup_database("link") + linker_name = "deterministic_naive_test.crn_naive_test.duns" + linker = self.backend.get_model(model=linker_name) - # Test link clusters - link_clusters = [ - Cluster(parent=crn_prob, child=duns_prob) - for crn_prob, duns_prob in zip( - df_crn_deduped["hash"].to_list()[:10], - df_duns_deduped["hash"].to_list()[:10], - strict=True, - ) - ] - link_1 = self.backend.get_model(model="link_1") + # Retrieve + pre_ancestors_cache = linker.ancestors_cache - assert dedupe_1.clusters.count() == 0 + # Set + updated_ancestors_cache = {k: 0.5 for k in pre_ancestors_cache.keys()} + linker.ancestors_cache = updated_ancestors_cache - link_1.insert_clusters( - clusters=link_clusters, - cluster_type="links", - batch_size=10, - ) + # Retrieve again + post_ancestors_cache = linker.ancestors_cache - assert link_1.clusters.count() == 10 + # Check difference + assert pre_ancestors_cache != post_ancestors_cache + assert post_ancestors_cache == updated_ancestors_cache def test_index( self,