diff --git a/pyproject.toml b/pyproject.toml index f4d5d094..65f6556c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "pydantic>=2.9.2", "python-dotenv>=1.0.1", "rustworkx>=0.15.1", - "splink<4", + "splink>=4.0.5", "sqlalchemy>=2.0.35", "tomli>=2.0.1", ] diff --git a/src/matchbox/models/linkers/splinklinker.py b/src/matchbox/models/linkers/splinklinker.py index b6d74d22..69116158 100644 --- a/src/matchbox/models/linkers/splinklinker.py +++ b/src/matchbox/models/linkers/splinklinker.py @@ -4,9 +4,10 @@ from typing import Any, Dict, List, Optional, Type from pandas import DataFrame -from pydantic import BaseModel, Field, model_validator -from splink.duckdb.linker import DuckDBLinker -from splink.linker import Linker as SplinkLibLinkerClass +from pydantic import BaseModel, ConfigDict, Field, model_validator +from splink import DuckDBAPI, SettingsCreator +from splink import Linker as SplinkLibLinkerClass +from splink.internals.linker_components.training import LinkerTraining from matchbox.models.linkers.base import Linker, LinkerSettings @@ -14,19 +15,19 @@ class SplinkLinkerFunction(BaseModel): - """A method of splink.linker.Linker used to train the linker.""" + """A method of splink.Linker.training used to train the linker.""" function: str arguments: Dict[str, Any] @model_validator(mode="after") def validate_function_and_arguments(self) -> "SplinkLinkerFunction": - if not hasattr(SplinkLibLinkerClass, self.function): + if not hasattr(LinkerTraining, self.function): raise ValueError( f"Function {self.function} not found as method of Splink Linker class" ) - splink_linker_func = getattr(SplinkLibLinkerClass, self.function) + splink_linker_func = getattr(LinkerTraining, self.function) splink_linker_func_param_set = set( inspect.signature(splink_linker_func).parameters.keys() ) @@ -48,18 +49,20 @@ class SplinkSettings(LinkerSettings): A data class to enforce the Splink linker's settings dictionary shape. """ - linker_class: Type[SplinkLibLinkerClass] = Field( - default=DuckDBLinker, + model_config = ConfigDict(arbitrary_types_allowed=True) + + database_api: Type[DuckDBAPI] = Field( + default=DuckDBAPI, description=""" - A Splink Linker class. Defaults to DuckDBLinker, and has only been tested - with this class. + The Splink DB API, to choose between DuckDB (default) and Spark (untested) """, - validate_default=True, ) + linker_training_functions: List[SplinkLinkerFunction] = Field( description=""" - A list of dictionaries keyed to functions, with values of the function's - argument dictionary, to be run against the Linker in the order supplied. + A list of dictionaries where keys are the names of methods for + splink.Linker.training and values are dictionaries encoding the arguments of + those methods. Each function will be run in the order supplied. Example: @@ -81,54 +84,39 @@ class SplinkSettings(LinkerSettings): """ ) - linker_settings: Dict = Field( + linker_settings: SettingsCreator = Field( description=""" - A valid settings dictionary for a Splink linker. + A valid Splink SettingsCreator. See Splink's documentation for a full description of available settings. - https://moj-analytical-services.github.io/splink/settings_dict_guide.html - - The following settings are enforced by the Company Matching Framework: + https://moj-analytical-services.github.io/splink/api_docs/settings_dict_guide.html - * link_type is set to "link_only" - * unique_id_column_name is set to the value of left_id and right_id, which - must match + * link_type must be set to "link_only" + * unique_id_column_name is overridden to the value of left_id and right_id, + which must match Example: - >>> from splink.duckdb.blocking_rule_library import block_on - ... import splink.duckdb.comparison_library as cl - ... import splink.duckdb.comparison_template_library as ctl + >>> from splink import SettingsCreator, block_on + ... import splink.comparison_library as cl + ... import splink.comparison_template_library as ctl ... - ... splink_settings={ - ... "retain_matching_columns": False, - ... "retain_intermediate_calculation_columns": False, - ... "blocking_rules_to_generate_predictions": [ - ... \""" - ... (l.company_name = r.company_name) - ... and ( - ... l.name_unusual_tokens <> '' - ... and r.name_unusual_tokens <> '' - ... ) - ... \""", - ... \""" - ... (l.postcode = r.postcode) - ... and ( - ... l.postcode <> '' - ... and r.postcode <> '' - ... ) - ... \""", + ... splink_settings = SettingsCreator( + ... retain_matching_columns=False, + ... retain_intermediate_calculation_columns=False, + ... blocking_rules_to_generate_predictions=[ + ... block_on("company_name"), + ... block_on("postcode"), ... ], - ... "comparisons": [ + ... comparisons=[ ... cl.jaro_winkler_at_thresholds( ... "company_name", ... [0.9, 0.6], ... term_frequency_adjustments=True ... ), - ... ctl.postcode_comparison("postcode"), - ... ], - ... } - + ... ctl.postcode_comparison("postcode"), + ... ] + ... ) """ ) threshold: Optional[float] = Field( @@ -156,14 +144,15 @@ def check_ids_match(self) -> "SplinkSettings": ) return self + @model_validator(mode="after") + def check_link_only(self) -> "SplinkSettings": + if self.linker_settings.link_type != "link_only": + raise ValueError('link_type must be set to "link_only"') + return self + @model_validator(mode="after") def add_enforced_settings(self) -> "SplinkSettings": - enforced_settings = { - "link_type": "link_only", - "unique_id_column_name": self.left_id, - } - for k, v in enforced_settings.items(): - self.linker_settings[k] = v + self.linker_settings.unique_id_column_name = self.left_id return self @@ -179,15 +168,13 @@ def from_settings( cls, left_id: str, right_id: str, - linker_class: SplinkLibLinkerClass, linker_training_functions: List[Dict[str, Any]], - linker_settings: Dict[str, Any], + linker_settings: SettingsCreator, threshold: float, ) -> "SplinkLinker": settings = SplinkSettings( left_id=left_id, right_id=right_id, - linker_class=linker_class, linker_training_functions=[ SplinkLinkerFunction(**func) for func in linker_training_functions ], @@ -218,14 +205,15 @@ def prepare(self, left: DataFrame, right: DataFrame) -> None: left[self.settings.left_id] = left[self.settings.left_id].apply(str) right[self.settings.right_id] = right[self.settings.right_id].apply(str) - self._linker = self.settings.linker_class( + self._linker = SplinkLibLinkerClass( input_table_or_tables=[left, right], input_table_aliases=["l", "r"], - settings_dict=self.settings.linker_settings, + settings=self.settings.linker_settings, + db_api=self.settings.database_api(), ) for func in self.settings.linker_training_functions: - proc_func = getattr(self._linker, func.function) + proc_func = getattr(self._linker.training, func.function) proc_func(**func.arguments) def link(self, left: DataFrame = None, right: DataFrame = None) -> DataFrame: @@ -235,7 +223,9 @@ def link(self, left: DataFrame = None, right: DataFrame = None) -> DataFrame: "These values will be ignored" ) - res = self._linker.predict(threshold_match_probability=self.settings.threshold) + res = self._linker.inference.predict( + threshold_match_probability=self.settings.threshold + ) return ( res.as_pandas_dataframe() diff --git a/test/client/test_linkers.py b/test/client/test_linkers.py index 5b75f6f6..1664d6aa 100644 --- a/test/client/test_linkers.py +++ b/test/client/test_linkers.py @@ -1,9 +1,11 @@ import pytest from matchbox import make_model, query from matchbox.helpers import selectors +from matchbox.models.linkers.splinklinker import SplinkLinkerFunction, SplinkSettings from matchbox.server.models import Source, SourceWarehouse from matchbox.server.postgresql import MatchboxPostgres from pandas import DataFrame +from splink import SettingsCreator from ..fixtures.db import AddDedupeModelsAndDataCallable, AddIndexedDataCallable from ..fixtures.models import ( @@ -187,3 +189,48 @@ def unique_non_null(s): assert isinstance(clusters, DataFrame) assert clusters.hash.nunique() == fx_data.unique_n + + +def test_splink_training_functions(): + # You can create a valid SplinkLinkerFunction + SplinkLinkerFunction( + function="estimate_u_using_random_sampling", + arguments={"max_pairs": 1e4}, + ) + # You can't reference a function that doesn't exist + with pytest.raises(ValueError): + SplinkLinkerFunction(function="made_up_funcname", arguments=dict()) + # You can't pass arguments that don't exist + with pytest.raises(ValueError): + SplinkLinkerFunction( + function="estimate_u_using_random_sampling", arguments={"foo": "bar"} + ) + + +def test_splink_settings(): + valid_settings = SplinkSettings( + left_id="hash", + right_id="hash", + linker_training_functions=[], + linker_settings=SettingsCreator(link_type="link_only"), + threshold=None, + ) + assert valid_settings.linker_settings.unique_id_column_name == "hash" + # Can only use "link_only" + with pytest.raises(ValueError): + valid_settings = SplinkSettings( + left_id="hash", + right_id="hash", + linker_training_functions=[], + linker_settings=SettingsCreator(link_type="dedupe_only"), + threshold=None, + ) + # Left and right ID must coincide + with pytest.raises(ValueError): + valid_settings = SplinkSettings( + left_id="hash", + right_id="hash2", + linker_training_functions=[], + linker_settings=SettingsCreator(link_type="link_only"), + threshold=None, + ) diff --git a/test/fixtures/models.py b/test/fixtures/models.py index 06b1f796..d2747d91 100644 --- a/test/fixtures/models.py +++ b/test/fixtures/models.py @@ -1,6 +1,6 @@ from typing import Any, Callable -import splink.duckdb.comparison_library as cl +import splink.comparison_library as cl from matchbox.models.dedupers import NaiveDeduper from matchbox.models.dedupers.base import Deduper from matchbox.models.linkers import ( @@ -10,8 +10,8 @@ ) from matchbox.models.linkers.base import Linker from pydantic import BaseModel, Field -from splink.duckdb import blocking_rule_library as brl -from splink.duckdb.linker import DuckDBLinker +from splink import SettingsCreator +from splink import blocking_rule_library as brl class DedupeTestParams(BaseModel): @@ -236,21 +236,21 @@ def make_splink_li_settings(data: LinkTestParams) -> dict[str, Any]: # The m parameter is 1 because we're testing in a deterministic system, and # many of these tests only have one field, so we can't use expectation # maximisation to estimate. For testing raw functionality, fine to use 1 - linker_settings = { - "retain_matching_columns": False, - "retain_intermediate_calculation_columns": False, - "blocking_rules_to_generate_predictions": [ + linker_settings = SettingsCreator( + link_type="link_only", + retain_matching_columns=False, + retain_intermediate_calculation_columns=False, + blocking_rules_to_generate_predictions=[ brl.block_on(field) for field in fields ], - "comparisons": [ - cl.exact_match(field, m_probability_exact_match=1) for field in fields + comparisons=[ + cl.ExactMatch(field).configure(m_probabilities=[1, 0]) for field in fields ], - } + ) return { "left_id": "hash", "right_id": "hash", - "linker_class": DuckDBLinker, "linker_training_functions": linker_training_functions, "linker_settings": linker_settings, "threshold": None, diff --git a/uv.lock b/uv.lock index 14505b5c..2aa0fe28 100644 --- a/uv.lock +++ b/uv.lock @@ -807,7 +807,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "rustworkx", specifier = ">=0.15.1" }, - { name = "splink", specifier = "<4" }, + { name = "splink", specifier = ">=4.0.5" }, { name = "sqlalchemy", specifier = ">=2.0.35" }, { name = "tomli", specifier = ">=2.0.1" }, ] @@ -1047,12 +1047,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3c/c3/a21d017f41c4d7603c0aa895ad781ea24fa7c9cc412056aee119eb326883/pg_force_execute-0.0.11-py3-none-any.whl", hash = "sha256:250587c0f4c51a2997454442a0f39c2ab4113dc70ebae2015f1556f080595e4a", size = 4492 }, ] -[[package]] -name = "phonetics" -version = "1.0.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/67/a5/d1b6dbcbb05477aa5f0c5e73a7d68c6d23ab098af4461072f00999ed573a/phonetics-1.0.5.tar.gz", hash = "sha256:16263948c82fce1e257964b2ab4adc953f995e0fa7e2e60e6ba336d77a7235ba", size = 8848 } - [[package]] name = "pillow" version = "10.4.0" @@ -1159,8 +1153,6 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, @@ -1633,7 +1625,7 @@ wheels = [ [[package]] name = "splink" -version = "3.9.15" +version = "4.0.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "altair" }, @@ -1641,13 +1633,13 @@ dependencies = [ { name = "igraph" }, { name = "jinja2" }, { name = "jsonschema" }, + { name = "numpy" }, { name = "pandas" }, - { name = "phonetics" }, { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/fe/457c18d9f54e6b34ddd4a908ee90e61b8287006560d36952afec7cae45d9/splink-3.9.15.tar.gz", hash = "sha256:d52a4f2e48567b502621924cbd909f88c6cb88b32442d575deb4c16bbbb2ccad", size = 3655727 } +sdist = { url = "https://files.pythonhosted.org/packages/9e/51/66dd1871f1ed6edaad43dc1121dd1e59d4ef0c5d3cd993b23b5c751ab94e/splink-4.0.5.tar.gz", hash = "sha256:72dbdaa7a1211733018d01a80b87f3bfecd32216a1693b1c67fe31db9034f356", size = 3654992 } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/8a/99cf732fb1a6aac4535e0c4a641c1159a5f1d2fa9bce6452c52f078c7ba5/splink-3.9.15-py3-none-any.whl", hash = "sha256:1b8f557743e633c785fa6da4030821d0cd1ccf03336d662db59f809956f4ec87", size = 3713845 }, + { url = "https://files.pythonhosted.org/packages/9b/68/bb9108f4341e41b95d203c9c8f47d7f52a7d6e96348b83dc3ba1f075e91d/splink-4.0.5-py3-none-any.whl", hash = "sha256:0afc28e12fc863030ad1add89dffa54c91a35b50f14fee64ac78bfa43f5d8866", size = 3717815 }, ] [[package]]