Merge pull request #2444 from moj-analytical-services/bug/cluster-link-only

Fix clustering in link-only jobs with a source dataset column on Postgres
ADBond authored Oct 3, 2024
2 parents 2936d77 + d05ff82 commit 0166096
Showing 4 changed files with 112 additions and 10 deletions.
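For orientation, a minimal sketch of the scenario this PR addresses: a link_only job whose input frames already carry a source_dataset column, followed by clustering. The sketch is illustrative only, not the library's internals: the records, threshold and DuckDB backend are assumptions (the bug was reported against the Postgres backend), and the API calls mirror those used in the new test file below.

import pandas as pd

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on

records = pd.DataFrame(
    {
        "unique_id": [1, 2, 3],
        "first_name": ["amy", "amy", "bob"],
        "surname": ["jones", "jones", "smith"],
    }
)
df_l, df_r = records.copy(), records.copy()
df_l["source_dataset"] = "my_left_ds"
df_r["source_dataset"] = "my_right_ds"

settings = SettingsCreator(
    link_type="link_only",
    comparisons=[cl.ExactMatch("first_name"), cl.ExactMatch("surname")],
    blocking_rules_to_generate_predictions=[block_on("surname")],
)

linker = Linker([df_l, df_r], settings, db_api=DuckDBAPI())
df_predict = linker.inference.predict()
# Before this fix, the clustering step could build its column list from the wrong
# place when the inputs supplied their own source_dataset column.
df_clusters = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)
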
10 changes: 10 additions & 0 deletions splink/internals/linker.py
@@ -54,6 +54,7 @@
)
from splink.internals.vertically_concatenate import (
compute_df_concat_with_tf,
concat_table_column_names,
)

logger = logging.getLogger(__name__)
@@ -252,6 +253,15 @@ def _source_dataset_column_already_exists(self):
in input_cols
)

@property
def _concat_table_column_names(self) -> list[str]:
"""
Returns the columns actually present in __splink__df_concat table.
Includes source dataset name if it's been created, and logic of additional
columns already taken care of
"""
return concat_table_column_names(self)

@property
def _cache_uid(self):
return self._settings_obj._cache_uid
18 changes: 8 additions & 10 deletions splink/internals/linker_components/clustering.py
@@ -17,7 +17,10 @@
_composite_unique_id_from_edges_sql,
_composite_unique_id_from_nodes_sql,
)
from splink.internals.vertically_concatenate import enqueue_df_concat
from splink.internals.vertically_concatenate import (
concat_table_column_names,
enqueue_df_concat,
)

if TYPE_CHECKING:
from splink.internals.linker import Linker
@@ -135,16 +138,11 @@ def cluster_pairwise_predictions_at_threshold(

enqueue_df_concat(linker, pipeline)

df_obj = next(iter(linker._input_tables_dict.values()))
columns = df_obj.columns_escaped

if linker._settings_obj._get_source_dataset_column_name_is_required():
columns.insert(
1,
linker._settings_obj.column_info_settings.source_dataset_input_column.name,
)
columns = concat_table_column_names(self._linker)
# don't want to include salting column in output if present
columns_without_salt = filter(lambda x: x != "__splink_salt", columns)

select_columns_sql = ", ".join(columns)
select_columns_sql = ", ".join(columns_without_salt)

sql = f"""
select
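In effect, the clustering SQL now derives its SELECT list from the columns actually present in __splink__df_concat and strips the salting column. A standalone illustration of that filtering (the column names here are made up):

# Illustrative only: build the SELECT list while excluding the salting column.
columns = ["unique_id", "source_dataset", "first_name", "surname", "__splink_salt"]
columns_without_salt = [c for c in columns if c != "__splink_salt"]
select_columns_sql = ", ".join(columns_without_salt)
# select_columns_sql == "unique_id, source_dataset, first_name, surname"
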
29 changes: 29 additions & 0 deletions splink/internals/vertically_concatenate.py
@@ -181,6 +181,35 @@ def compute_df_concat(linker: Linker, pipeline: CTEPipeline) -> SplinkDataFrame:
return nodes_with_tf


def concat_table_column_names(linker: Linker) -> list[str]:
"""
Returns the list of column names of the table __splink__df_concat,
without needing to instantiate the table.
"""
source_dataset_input_column = (
linker._settings_obj.column_info_settings.source_dataset_input_column
)

input_tables = linker._input_tables_dict
salting_required = linker._settings_obj.salting_required

df_obj = next(iter(input_tables.values()))
columns = df_obj.columns_escaped
if salting_required:
columns.append("__splink_salt")

if len(input_tables) > 1:
source_dataset_column_already_exists = False
if source_dataset_input_column:
source_dataset_column_already_exists = (
source_dataset_input_column.unquote().name
in [c.unquote().name for c in df_obj.columns]
)
if not source_dataset_column_already_exists:
columns.append("source_dataset")
return columns


def split_df_concat_with_tf_into_two_tables_sqls(
input_tablename: str, source_dataset_col: str, sample_switch: bool = False
) -> list[dict[str, str]]:
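The rule implemented by concat_table_column_names can be summarised in a standalone sketch. This is a simplification under assumptions: it ignores identifier quoting (which the real function handles via unquote()), and the function and parameter names below are invented for illustration.

def concat_columns_sketch(
    first_table_columns: list[str],
    n_input_tables: int,
    salting_required: bool,
    source_dataset_col: str | None = "source_dataset",
) -> list[str]:
    # Start from the columns of the first input table.
    columns = list(first_table_columns)
    # __splink__df_concat gains a salting column when salting is in use.
    if salting_required:
        columns.append("__splink_salt")
    # With multiple input tables, a source_dataset column is added only when the
    # inputs do not already provide one.
    if n_input_tables > 1 and (
        source_dataset_col is None or source_dataset_col not in columns
    ):
        columns.append("source_dataset")
    return columns

# Two link_only inputs that already carry source_dataset: nothing extra is added.
# concat_columns_sketch(["unique_id", "source_dataset", "surname"], 2, False)
# -> ["unique_id", "source_dataset", "surname"]
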
65 changes: 65 additions & 0 deletions tests/test_clustering.py
@@ -0,0 +1,65 @@
import pandas as pd
from pytest import mark

import splink.comparison_library as cl
from splink import Linker, SettingsCreator, block_on

from .decorator import mark_with_dialects_excluding

df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
# we just want to check it runs, so use a small slice of the data
df = df[0:25]
df_l = df.copy()
df_r = df.copy()
df_m = df.copy()
df_l["source_dataset"] = "my_left_ds"
df_r["source_dataset"] = "my_right_ds"
df_m["source_dataset"] = "my_middle_ds"
df_combined = pd.concat([df_l, df_r])


@mark_with_dialects_excluding()
@mark.parametrize(
["link_type", "input_pd_tables"],
[
["dedupe_only", [df]],
["link_only", [df, df]], # no source dataset
["link_only", [df_l, df_r]], # source dataset column
["link_only", [df_combined]], # concatenated frame
["link_only", [df_l, df_m, df_r]],
["link_and_dedupe", [df, df]], # no source dataset
["link_and_dedupe", [df_l, df_r]], # source dataset column
["link_and_dedupe", [df_combined]], # concatenated frame
],
ids=[
"dedupe",
"link_only_no_source_dataset",
"link_only_with_source_dataset",
"link_only_concat",
"link_only_three_tables",
"link_and_dedupe_no_source_dataset",
"link_and_dedupe_with_source_dataset",
"link_and_dedupe_concat",
],
)
def test_clustering(test_helpers, dialect, link_type, input_pd_tables):
helper = test_helpers[dialect]

settings = SettingsCreator(
link_type=link_type,
comparisons=[
cl.ExactMatch("first_name"),
cl.ExactMatch("surname"),
cl.ExactMatch("dob"),
cl.ExactMatch("city"),
],
blocking_rules_to_generate_predictions=[
block_on("surname"),
block_on("dob"),
],
)
linker_input = list(map(helper.convert_frame, input_pd_tables))
linker = Linker(linker_input, settings, **helper.extra_linker_args())

df_predict = linker.inference.predict()
linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)
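
As a follow-up usage note (not part of the test file): reusing linker and df_predict from the test body above, the SplinkDataFrame returned by the clustering call can be materialised for inspection, for example via as_pandas_dataframe.

# Illustrative follow-up: capture and inspect the clusters table.
df_clusters = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)
print(df_clusters.as_pandas_dataframe().head())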
