Commit

Refactored results and adapter but no unit tests run yet

Will Langdale committed Oct 30, 2024
1 parent 4efd84c commit 0826fa9
Showing 10 changed files with 709 additions and 614 deletions.
195 changes: 80 additions & 115 deletions src/matchbox/common/results.py
@@ -1,5 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
+from enum import StrEnum
 from typing import TYPE_CHECKING, Any, List

 import rustworkx as rx
@@ -11,11 +12,11 @@
 from matchbox.server.base import MatchboxDBAdapter, inject_backend
 from matchbox.server.models import Cluster, Probability
 from pandas import DataFrame
-from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import BaseModel, ConfigDict, field_validator, model_validator
 from sqlalchemy import Table

 if TYPE_CHECKING:
-    from matchbox.models.models import Model
+    from matchbox.models.models import Model, ModelMetadata
 else:
     Model = Any
@@ -25,19 +26,43 @@
 load_dotenv(dotenv_path)


+class ModelType(StrEnum):
+    """Enumeration of supported model types."""
+
+    LINKER = "linker"
+    DEDUPER = "deduper"
+
+
+class ModelMetadata(BaseModel):
+    """Metadata for a model."""
+
+    name: str
+    description: str
+    type: ModelType
+    left_source: str
+    right_source: str | None = None  # Only used for linker models
+
+
 class ResultsBaseDataclass(BaseModel, ABC):
     """Base class for results dataclasses.
+
+    Model is required during construction and calculation, but not when loading
+    from storage.
     """

     model_config = ConfigDict(arbitrary_types_allowed=True)

     dataframe: DataFrame
-    model: Model
+    model: Model | None
+    metadata: ModelMetadata

     _expected_fields: list[str]

     @model_validator(mode="after")
     def _check_dataframe(self) -> Table:
         """Verifies the table contains the expected fields."""
-        table_fields = sorted(self.dataframe.columns)
-        expected_fields = sorted(self._expected_fields)
+        table_fields = set(self.dataframe.columns)
+        expected_fields = set(self._expected_fields)

         if table_fields != expected_fields:
             raise ValueError(f"Expected {expected_fields}. \n" f"Found {table_fields}.")
@@ -63,7 +88,8 @@ def to_records(self) -> list[Probability | Cluster]:
 class ProbabilityResults(ResultsBaseDataclass):
     """Probabilistic matches produced by linkers and dedupers.

     Inherits the following attributes from ResultsBaseDataclass.

     There are pairs of records/clusters with a probability of being a match.
     The hash is the hash of the sorted left and right ids.

     _expected_fields enforces the shape of the dataframe.
@@ -73,11 +99,26 @@ class ProbabilityResults(ResultsBaseDataclass):
"""

_expected_fields: list[str] = [
"hash",
"left_id",
"right_id",
"probability",
]

@field_validator("dataframe")
@classmethod
def add_hash(cls, dataframe: DataFrame) -> DataFrame:
"""Adds a hash column to the dataframe if it doesn't already exist."""
if "hash" not in dataframe.columns:
dataframe[["left_id", "right_id"]] = dataframe[
["left_id", "right_id"]
].astype("binary[pyarrow]")
dataframe["hash"] = columns_to_value_ordered_hash(
data=dataframe, columns=["left_id", "right_id"]
)
dataframe["hash"] = dataframe["hash"].astype("binary[pyarrow]")
return dataframe[["hash", "left_id", "right_id", "probability"]]

def inspect_with_source(
self, left_data: DataFrame, left_key: str, right_data: DataFrame, right_key: str
) -> DataFrame:
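
With the new `field_validator`, callers no longer precompute the pair hash: construction backfills it. A rough usage sketch, assuming a pandas/pyarrow environment (data and names invented; `model=None` leans on the new `Model | None` annotation):

```python
import pandas as pd

from matchbox.common.results import ModelMetadata, ModelType, ProbabilityResults

pairs = pd.DataFrame(
    {
        "left_id": [b"\x01", b"\x02"],
        "right_id": [b"\x02", b"\x03"],
        "probability": [0.9, 0.7],
    }
)

# add_hash fires on validation: ids are cast to binary[pyarrow] and a
# value-ordered hash of (left_id, right_id) is inserted as "hash"
results = ProbabilityResults(
    dataframe=pairs,
    model=None,
    metadata=ModelMetadata(
        name="toy_deduper",
        description="illustration only",
        type=ModelType.DEDUPER,
        left_source="toy_source",
    ),
)
assert list(results.dataframe.columns) == ["hash", "left_id", "right_id", "probability"]
```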
@@ -130,28 +171,19 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> set[Probability]:
         backend.validate_hashes(hashes=self.dataframe.left_id.unique().tolist())
         backend.validate_hashes(hashes=self.dataframe.right_id.unique().tolist())

-        # Preprocess the dataframe
-        pre_prep_df = self.dataframe[["left_id", "right_id", "probability"]].copy()
-        pre_prep_df[["left_id", "right_id"]] = pre_prep_df[
-            ["left_id", "right_id"]
-        ].astype("binary[pyarrow]")
-        pre_prep_df["sha1"] = columns_to_value_ordered_hash(
-            data=pre_prep_df, columns=["left_id", "right_id"]
-        )
-        pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]")
-
         return {
             Probability(hash=row[0], left=row[1], right=row[2], probability=row[3])
-            for row in pre_prep_df[
-                ["sha1", "left_id", "right_id", "probability"]
+            for row in self.dataframe[
+                ["hash", "left_id", "right_id", "probability"]
             ].to_numpy()
         }


 class ClusterResults(ResultsBaseDataclass):
     """Cluster data produced by using to_clusters on ProbabilityResults.

     Inherits the following attributes from ResultsBaseDataclass.

     This is the connected components of the probabilistic matches at every
     threshold of probability. The parent is the hash of the sorted children.

     _expected_fields enforces the shape of the dataframe.
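
For reference, a toy dataframe in the shape `ClusterResults` expects, inferred from the `to_clusters` output below (hashes invented): a pair forming at 0.9 that absorbs a third record at 0.8:

```python
import pandas as pd

clusters = pd.DataFrame(
    {
        # parent: hash of the sorted children; child: a member hash;
        # threshold: probability at which the component first forms
        "parent": [b"\xaa", b"\xaa", b"\xbb", b"\xbb", b"\xbb"],
        "child": [b"\x01", b"\x02", b"\x01", b"\x02", b"\x03"],
        "threshold": [0.9, 0.9, 0.8, 0.8, 0.8],
    }
)
```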
@@ -211,134 +243,67 @@ def to_records(self) -> set[Cluster]:


 class Results(BaseModel):
     """A container for the results of a model run.

     Contains all the information any backend will need to store the results.
     """

     model_config = ConfigDict(arbitrary_types_allowed=True)

-    model: Model
     probabilities: ProbabilityResults
     clusters: ClusterResults

     @inject_backend
     def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
         """Writes the results to the Matchbox database."""
-        self.model.insert_model()
-        self.model.clusters = self
+        if self.probabilities.model != self.clusters.model:
+            raise ValueError("Probabilities and clusters must be from the same model.")
+
+        self.clusters.model.insert_model()
+        self.clusters.model.results = self


-def process_components(
-    G: rx.PyGraph,
-    current_pairs: set[bytes],
-    added: dict[bytes, int],
-    pair_children: dict[bytes, list[bytes]],
-) -> list[tuple[bytes, list[bytes], int]]:
-    """
-    Process connected components in the current graph state.
-
-    Identifies which 2-item components have merged into larger components.
-
-    Returns:
-        List of (parent_hash, children, size) for components > 2 items
-        where children includes both individual items and parent hashes
-    """
-    new_components = []
-    component_with_size = [
-        (component, len(component)) for component in rx.connected_components(G)
-    ]
-
-    for component, size in component_with_size:
-        if size <= 2:
-            continue
-
-        # Get all node hashes in component
-        node_hashes = [G.get_node_data(node) for node in component]
-
-        # Find which 2-item parents are part of this component
-        component_pairs = {
-            pair
-            for pair in current_pairs
-            if any(G.has_node(added[h]) for h in node_hashes)
-        }
-
-        # Children are individual nodes not in pairs, plus the pair parents
-        children = component_pairs | {
-            h
-            for h in node_hashes
-            if not any(h in pair_children[p] for p in component_pairs)
-        }
-
-        parent_hash = list_to_value_ordered_hash(sorted(children))
-        new_components.append((parent_hash, list(children)))
-
-    return new_components
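
`to_matchbox` now routes everything through the clusters' model, and the new guard rejects mismatched result pairs. A sketch of the intended call pattern, with variable names assumed and the no-argument call resting on my reading of `@inject_backend`:

```python
# Bundle both result types from one model run and persist them together.
# The backend argument is assumed to be supplied by @inject_backend.
results = Results(
    probabilities=probability_results,  # a ProbabilityResults from a linker/deduper
    clusters=to_clusters(probability_results),
)
results.to_matchbox()
```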


 def to_clusters(results: ProbabilityResults) -> ClusterResults:
     """
-    Takes a models probabilistic outputs and turns them into clusters.
-
-    Performs connected components at decreasing thresholds from 1.0 to return every
-    possible component in a hierarchical tree.
-
-    * Stores all two-item components with their original probabilities
-    * For larger components, stores the individual items and two-item parent hashes
-      as children, with a new parent hash
-
-    Args:
-        results: ProbabilityResults object
+    Converts probabilities into a list of connected components formed at each threshold.

     Returns:
-        ClusterResults object
+        ClusterResults sorted by threshold descending.
     """
     G = rx.PyGraph()
     added: dict[bytes, int] = {}
-    pair_children: dict[bytes, list[bytes]] = {}
-    current_pairs: set[bytes] = set()
-    seen_larger: set[bytes] = set()

-    clusters = {"parent": [], "child": [], "threshold": []}
+    components: dict[str, list] = {"parent": [], "child": [], "threshold": []}

-    # 1. Create all 2-item components with original probabilities
-    initial_edges = (
-        results.dataframe.filter(["left_id", "right_id", "probability"])
+    # Sort probabilities descending and process in order of decreasing probability
+    edges = (
+        results.dataframe.sort_values("probability", ascending=False)
+        .filter(["left_id", "right_id", "probability"])
         .astype({"left_id": "binary[pyarrow]", "right_id": "binary[pyarrow]"})
         .itertuples(index=False, name=None)
     )

-    for left, right, prob in initial_edges:
+    for left, right, prob in edges:
         # Add nodes if not seen before
         for hash_val in (left, right):
             if hash_val not in added:
                 idx = G.add_node(hash_val)
                 added[hash_val] = idx

-        children = sorted([left, right])
-        parent_hash = list_to_value_ordered_hash(children)
-
-        pair_children[parent_hash] = children
-        current_pairs.add(parent_hash)
-
-        clusters["parent"].extend([parent_hash] * 2)
-        clusters["child"].extend(children)
-        clusters["threshold"].extend([prob] * 2)
-
-    # 2. Process at each probability threshold
-    sorted_probabilities = sorted(
-        results.dataframe["probability"].unique(), reverse=True
-    )
-
-    for threshold in sorted_probabilities:
-        # Find new larger components at this threshold
-        new_components = process_components(G, current_pairs)
-
-        # Add new components to results
-        for parent_hash, children in new_components:
-            if parent_hash not in seen_larger:
-                seen_larger.add(parent_hash)
-                clusters["parent"].extend([parent_hash] * len(children))
-                clusters["child"].extend(children)
-                clusters["threshold"].extend([threshold] * len(children))
+        # Get state, add edge, get new state, add anything new to results
+        old_components = {frozenset(comp) for comp in rx.connected_components(G)}
+        G.add_edge(added[left], added[right], None)
+        new_components = {frozenset(comp) for comp in rx.connected_components(G)}
+
+        for comp in new_components - old_components:
+            children = sorted([G.get_node_data(n) for n in comp])
+            parent = list_to_value_ordered_hash(children)
+
+            components["parent"].extend([parent] * len(children))
+            components["child"].extend(children)
+            components["threshold"].extend([prob] * len(children))

     return ClusterResults(
-        dataframe=DataFrame(clusters).convert_dtypes(dtype_backend="pyarrow"),
+        dataframe=DataFrame(components).convert_dtypes(dtype_backend="pyarrow"),
         model=results.model,
+        metadata=results.metadata,
     )
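
The rewrite replaces the two-pass pair/merge bookkeeping with a single sweep: edges are added in descending probability order, and any component that appears for the first time is recorded at that probability. A self-contained toy run of that idea using only rustworkx (data invented):

```python
import rustworkx as rx

# Three records: a-b at 0.9, b-c at 0.8, a-c at 0.7
edges = [(b"a", b"b", 0.9), (b"b", b"c", 0.8), (b"a", b"c", 0.7)]

G = rx.PyGraph()
added: dict[bytes, int] = {}
for left, right, prob in sorted(edges, key=lambda e: e[2], reverse=True):
    for h in (left, right):
        if h not in added:
            added[h] = G.add_node(h)
    # Record any component that this edge creates for the first time
    old = {frozenset(c) for c in rx.connected_components(G)}
    G.add_edge(added[left], added[right], None)
    new = {frozenset(c) for c in rx.connected_components(G)}
    for comp in new - old:
        members = sorted(G.get_node_data(n) for n in comp)
        print(prob, members)
# 0.9 [b'a', b'b']       -- the pair forms at 0.9
# 0.8 [b'a', b'b', b'c'] -- the triple forms at 0.8
# (the 0.7 edge creates no new component, so nothing is recorded)
```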
30 changes: 0 additions & 30 deletions src/matchbox/models/linkers/base.py
@@ -1,12 +1,9 @@
 import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict

 from pandas import DataFrame
 from pydantic import BaseModel, Field, ValidationInfo, field_validator

-from matchbox.common.results import ProbabilityResults
-
-
 class LinkerSettings(BaseModel):
     """
@@ -51,30 +48,3 @@ def prepare(self, left: DataFrame, right: DataFrame) -> None:
     @abstractmethod
     def link(self, left: DataFrame, right: DataFrame) -> DataFrame:
         return
-
-
-def make_linker(
-    link_run_name: str,
-    description: str,
-    linker: Linker,
-    linker_settings: Dict[str, Any],
-    left_data: DataFrame,
-    left_source: str,
-    right_data: DataFrame,
-    right_source: str,
-) -> Callable[[DataFrame, DataFrame], ProbabilityResults]:
-    linker_instance = linker.from_settings(**linker_settings)
-    linker_instance.prepare(left=left_data, right=right_data)
-
-    def linker(
-        left_data: DataFrame = left_data, right_data: DataFrame = right_data
-    ) -> ProbabilityResults:
-        return ProbabilityResults(
-            dataframe=linker_instance.link(left=left_data, right=right_data),
-            run_name=link_run_name,
-            description=description,
-            left=left_source,
-            right=right_source,
-        )
-
-    return linker
(Diffs for the remaining eight changed files are not shown.)