Skip to content

Commit

Permalink
Merge pull request #10 from uktrade/feature/splink4
Browse files Browse the repository at this point in the history
Upgrade to Splink 4
  • Loading branch information
leo-mazzone authored Nov 21, 2024
2 parents 29d88d1 + 7bb8e6d commit 5fb6536
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 85 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"pydantic>=2.9.2",
"python-dotenv>=1.0.1",
"rustworkx>=0.15.1",
"splink<4",
"splink>=4.0.5",
"sqlalchemy>=2.0.35",
"tomli>=2.0.1",
]
Expand Down
110 changes: 50 additions & 60 deletions src/matchbox/models/linkers/splinklinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,30 @@
from typing import Any, Dict, List, Optional, Type

from pandas import DataFrame
from pydantic import BaseModel, Field, model_validator
from splink.duckdb.linker import DuckDBLinker
from splink.linker import Linker as SplinkLibLinkerClass
from pydantic import BaseModel, ConfigDict, Field, model_validator
from splink import DuckDBAPI, SettingsCreator
from splink import Linker as SplinkLibLinkerClass
from splink.internals.linker_components.training import LinkerTraining

from matchbox.models.linkers.base import Linker, LinkerSettings

logic_logger = logging.getLogger("mb_logic")


class SplinkLinkerFunction(BaseModel):
"""A method of splink.linker.Linker used to train the linker."""
"""A method of splink.Linker.training used to train the linker."""

function: str
arguments: Dict[str, Any]

@model_validator(mode="after")
def validate_function_and_arguments(self) -> "SplinkLinkerFunction":
if not hasattr(SplinkLibLinkerClass, self.function):
if not hasattr(LinkerTraining, self.function):
raise ValueError(
f"Function {self.function} not found as method of Splink Linker class"
)

splink_linker_func = getattr(SplinkLibLinkerClass, self.function)
splink_linker_func = getattr(LinkerTraining, self.function)
splink_linker_func_param_set = set(
inspect.signature(splink_linker_func).parameters.keys()
)
Expand All @@ -48,18 +49,20 @@ class SplinkSettings(LinkerSettings):
A data class to enforce the Splink linker's settings dictionary shape.
"""

linker_class: Type[SplinkLibLinkerClass] = Field(
default=DuckDBLinker,
model_config = ConfigDict(arbitrary_types_allowed=True)

database_api: Type[DuckDBAPI] = Field(
default=DuckDBAPI,
description="""
A Splink Linker class. Defaults to DuckDBLinker, and has only been tested
with this class.
The Splink DB API, to choose between DuckDB (default) and Spark (untested)
""",
validate_default=True,
)

linker_training_functions: List[SplinkLinkerFunction] = Field(
description="""
A list of dictionaries keyed to functions, with values of the function's
argument dictionary, to be run against the Linker in the order supplied.
A list of dictionaries where keys are the names of methods for
splink.Linker.training and values are dictionaries encoding the arguments of
those methods. Each function will be run in the order supplied.
Example:
Expand All @@ -81,54 +84,39 @@ class SplinkSettings(LinkerSettings):
"""
)
linker_settings: Dict = Field(
linker_settings: SettingsCreator = Field(
description="""
A valid settings dictionary for a Splink linker.
A valid Splink SettingsCreator.
See Splink's documentation for a full description of available settings.
https://moj-analytical-services.github.io/splink/settings_dict_guide.html
The following settings are enforced by the Company Matching Framework:
https://moj-analytical-services.github.io/splink/api_docs/settings_dict_guide.html
* link_type is set to "link_only"
* unique_id_column_name is set to the value of left_id and right_id, which
must match
* link_type must be set to "link_only"
* unique_id_column_name is overridden to the value of left_id and right_id,
which must match
Example:
>>> from splink.duckdb.blocking_rule_library import block_on
... import splink.duckdb.comparison_library as cl
... import splink.duckdb.comparison_template_library as ctl
>>> from splink import SettingsCreator, block_on
... import splink.comparison_library as cl
... import splink.comparison_template_library as ctl
...
... splink_settings={
... "retain_matching_columns": False,
... "retain_intermediate_calculation_columns": False,
... "blocking_rules_to_generate_predictions": [
... \"""
... (l.company_name = r.company_name)
... and (
... l.name_unusual_tokens <> ''
... and r.name_unusual_tokens <> ''
... )
... \""",
... \"""
... (l.postcode = r.postcode)
... and (
... l.postcode <> ''
... and r.postcode <> ''
... )
... \""",
... splink_settings = SettingsCreator(
... retain_matching_columns=False,
... retain_intermediate_calculation_columns=False,
... blocking_rules_to_generate_predictions=[
... block_on("company_name"),
... block_on("postcode"),
... ],
... "comparisons": [
... comparisons=[
... cl.jaro_winkler_at_thresholds(
... "company_name",
... [0.9, 0.6],
... term_frequency_adjustments=True
... ),
... ctl.postcode_comparison("postcode"),
... ],
... }
... ctl.postcode_comparison("postcode"),
... ]
... )
"""
)
threshold: Optional[float] = Field(
Expand Down Expand Up @@ -156,14 +144,15 @@ def check_ids_match(self) -> "SplinkSettings":
)
return self

@model_validator(mode="after")
def check_link_only(self) -> "SplinkSettings":
if self.linker_settings.link_type != "link_only":
raise ValueError('link_type must be set to "link_only"')
return self

@model_validator(mode="after")
def add_enforced_settings(self) -> "SplinkSettings":
enforced_settings = {
"link_type": "link_only",
"unique_id_column_name": self.left_id,
}
for k, v in enforced_settings.items():
self.linker_settings[k] = v
self.linker_settings.unique_id_column_name = self.left_id
return self


Expand All @@ -179,15 +168,13 @@ def from_settings(
cls,
left_id: str,
right_id: str,
linker_class: SplinkLibLinkerClass,
linker_training_functions: List[Dict[str, Any]],
linker_settings: Dict[str, Any],
linker_settings: SettingsCreator,
threshold: float,
) -> "SplinkLinker":
settings = SplinkSettings(
left_id=left_id,
right_id=right_id,
linker_class=linker_class,
linker_training_functions=[
SplinkLinkerFunction(**func) for func in linker_training_functions
],
Expand Down Expand Up @@ -218,14 +205,15 @@ def prepare(self, left: DataFrame, right: DataFrame) -> None:
left[self.settings.left_id] = left[self.settings.left_id].apply(str)
right[self.settings.right_id] = right[self.settings.right_id].apply(str)

self._linker = self.settings.linker_class(
self._linker = SplinkLibLinkerClass(
input_table_or_tables=[left, right],
input_table_aliases=["l", "r"],
settings_dict=self.settings.linker_settings,
settings=self.settings.linker_settings,
db_api=self.settings.database_api(),
)

for func in self.settings.linker_training_functions:
proc_func = getattr(self._linker, func.function)
proc_func = getattr(self._linker.training, func.function)
proc_func(**func.arguments)

def link(self, left: DataFrame = None, right: DataFrame = None) -> DataFrame:
Expand All @@ -235,7 +223,9 @@ def link(self, left: DataFrame = None, right: DataFrame = None) -> DataFrame:
"These values will be ignored"
)

res = self._linker.predict(threshold_match_probability=self.settings.threshold)
res = self._linker.inference.predict(
threshold_match_probability=self.settings.threshold
)

return (
res.as_pandas_dataframe()
Expand Down
47 changes: 47 additions & 0 deletions test/client/test_linkers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
from matchbox import make_model, query
from matchbox.helpers import selectors
from matchbox.models.linkers.splinklinker import SplinkLinkerFunction, SplinkSettings
from matchbox.server.models import Source, SourceWarehouse
from matchbox.server.postgresql import MatchboxPostgres
from pandas import DataFrame
from splink import SettingsCreator

from ..fixtures.db import AddDedupeModelsAndDataCallable, AddIndexedDataCallable
from ..fixtures.models import (
Expand Down Expand Up @@ -187,3 +189,48 @@ def unique_non_null(s):

assert isinstance(clusters, DataFrame)
assert clusters.hash.nunique() == fx_data.unique_n


def test_splink_training_functions():
    """Check SplinkLinkerFunction validation of method names and arguments.

    A known training method with a known argument must validate; an unknown
    method name, or an unknown argument for a known method, must raise
    ValueError.
    """
    known_method = "estimate_u_using_random_sampling"

    # A real training method called with one of its real parameters is accepted
    SplinkLinkerFunction(function=known_method, arguments={"max_pairs": 1e4})

    rejected_specs = (
        # The method name does not exist on the training API
        {"function": "made_up_funcname", "arguments": dict()},
        # The method exists, but the argument name is not in its signature
        {"function": known_method, "arguments": {"foo": "bar"}},
    )
    for spec in rejected_specs:
        with pytest.raises(ValueError):
            SplinkLinkerFunction(**spec)


def test_splink_settings():
    """Check the constraints SplinkSettings enforces on construction.

    Covers three cases: a valid configuration has its
    ``unique_id_column_name`` set to the shared ID column; a ``link_type``
    other than "link_only" is rejected; differing left and right IDs are
    rejected.
    """

    def build_settings(left_id, right_id, link_type):
        # Only the IDs and the link type vary between the cases below
        return SplinkSettings(
            left_id=left_id,
            right_id=right_id,
            linker_training_functions=[],
            linker_settings=SettingsCreator(link_type=link_type),
            threshold=None,
        )

    valid_settings = build_settings("hash", "hash", "link_only")
    assert valid_settings.linker_settings.unique_id_column_name == "hash"

    # Can only use "link_only"
    with pytest.raises(ValueError):
        build_settings("hash", "hash", "dedupe_only")

    # Left and right ID must coincide
    with pytest.raises(ValueError):
        build_settings("hash", "hash2", "link_only")
22 changes: 11 additions & 11 deletions test/fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable

import splink.duckdb.comparison_library as cl
import splink.comparison_library as cl
from matchbox.models.dedupers import NaiveDeduper
from matchbox.models.dedupers.base import Deduper
from matchbox.models.linkers import (
Expand All @@ -10,8 +10,8 @@
)
from matchbox.models.linkers.base import Linker
from pydantic import BaseModel, Field
from splink.duckdb import blocking_rule_library as brl
from splink.duckdb.linker import DuckDBLinker
from splink import SettingsCreator
from splink import blocking_rule_library as brl


class DedupeTestParams(BaseModel):
Expand Down Expand Up @@ -236,21 +236,21 @@ def make_splink_li_settings(data: LinkTestParams) -> dict[str, Any]:
# The m parameter is 1 because we're testing in a deterministic system, and
# many of these tests only have one field, so we can't use expectation
# maximisation to estimate. For testing raw functionality, fine to use 1
linker_settings = {
"retain_matching_columns": False,
"retain_intermediate_calculation_columns": False,
"blocking_rules_to_generate_predictions": [
linker_settings = SettingsCreator(
link_type="link_only",
retain_matching_columns=False,
retain_intermediate_calculation_columns=False,
blocking_rules_to_generate_predictions=[
brl.block_on(field) for field in fields
],
"comparisons": [
cl.exact_match(field, m_probability_exact_match=1) for field in fields
comparisons=[
cl.ExactMatch(field).configure(m_probabilities=[1, 0]) for field in fields
],
}
)

return {
"left_id": "hash",
"right_id": "hash",
"linker_class": DuckDBLinker,
"linker_training_functions": linker_training_functions,
"linker_settings": linker_settings,
"threshold": None,
Expand Down
18 changes: 5 additions & 13 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5fb6536

Please sign in to comment.