Split the Splink link_type check out of add_enforced_settings.

SplinkSettings now validates link_type in a dedicated "after" validator
(check_link_only) that raises ValueError unless link_type is "link_only".
add_enforced_settings only sets unique_id_column_name to left_id and no
longer reassigns link_type. Tests are added for SplinkLinkerFunction
(valid function, unknown function name, unknown argument) and for
SplinkSettings (valid settings, wrong link_type, mismatched left/right id).

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. The exact whitespace on the removed
whitespace-only line in the first hunk is assumed to be four spaces -- confirm
against the original commit before applying.

diff --git a/src/matchbox/models/linkers/splinklinker.py b/src/matchbox/models/linkers/splinklinker.py
index c32da61a..ce1e0806 100644
--- a/src/matchbox/models/linkers/splinklinker.py
+++ b/src/matchbox/models/linkers/splinklinker.py
@@ -143,12 +143,15 @@ def check_ids_match(self) -> "SplinkSettings":
                 "left_id and right_id must match in a Splink linker."
             )
         return self
-    
+
     @model_validator(mode="after")
-    def add_enforced_settings(self) -> "SplinkSettings":
+    def check_link_only(self) -> "SplinkSettings":
         if self.linker_settings.link_type != "link_only":
             raise ValueError('link_type must be set to "link_only"')
-        self.linker_settings.link_type = "link_only"
+        return self
+
+    @model_validator(mode="after")
+    def add_enforced_settings(self) -> "SplinkSettings":
         self.linker_settings.unique_id_column_name = self.left_id
         return self
 
diff --git a/test/client/test_linkers.py b/test/client/test_linkers.py
index 5b75f6f6..a2609fd9 100644
--- a/test/client/test_linkers.py
+++ b/test/client/test_linkers.py
@@ -1,9 +1,11 @@
 import pytest
 from matchbox import make_model, query
 from matchbox.helpers import selectors
+from matchbox.models.linkers.splinklinker import SplinkLinkerFunction, SplinkSettings
 from matchbox.server.models import Source, SourceWarehouse
 from matchbox.server.postgresql import MatchboxPostgres
 from pandas import DataFrame
+from splink import SettingsCreator
 
 from ..fixtures.db import AddDedupeModelsAndDataCallable, AddIndexedDataCallable
 from ..fixtures.models import (
@@ -187,3 +189,48 @@ def unique_non_null(s):
 
     assert isinstance(clusters, DataFrame)
     assert clusters.hash.nunique() == fx_data.unique_n
+
+
+def test_splink_training_functions():
+    # You can create a valid SplinkLinkerFunction
+    SplinkLinkerFunction(
+        function="estimate_u_using_random_sampling",
+        arguments={"max_pairs": 1e4},
+    )
+    # You can't reference a function that doesn't exist
+    with pytest.raises(ValueError):
+        SplinkLinkerFunction(function="made_up_funcname", arguments=dict())
+    # You can't pass arguments that don't exist
+    with pytest.raises(ValueError):
+        SplinkLinkerFunction(
+            function="estimate_u_using_random_sampling", arguments={"foo": "bar"}
+        )
+
+def test_splink_settings():
+    valid_settings = SplinkSettings(
+        left_id="hash",
+        right_id="hash",
+        linker_training_functions=[],
+        linker_settings=SettingsCreator(link_type="link_only"),
+        threshold=None,
+    )
+    assert valid_settings.linker_settings.unique_id_column_name == "hash"
+    # Can only use "link_only"
+    with pytest.raises(ValueError):
+        valid_settings = SplinkSettings(
+            left_id="hash",
+            right_id="hash",
+            linker_training_functions=[],
+            linker_settings=SettingsCreator(link_type="dedupe_only"),
+            threshold=None,
+        )
+    # Left and right ID must coincide
+    with pytest.raises(ValueError):
+        valid_settings = SplinkSettings(
+            left_id="hash",
+            right_id="hash2",
+            linker_training_functions=[],
+            linker_settings=SettingsCreator(link_type="link_only"),
+            threshold=None,
+        )
+