diff --git a/pyproject.toml b/pyproject.toml index db09dcd2c1..0c6c05686f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "splink" -version = "0.2.9" +version = "0.3.0" description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage." authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"] license = "MIT" diff --git a/splink/truth.py b/splink/truth.py index 92d964c7e0..36126b595e 100644 --- a/splink/truth.py +++ b/splink/truth.py @@ -2,6 +2,8 @@ from functools import reduce from pyspark.sql import DataFrame +import pyspark.sql.functions as f +from pyspark.sql import Window altair_installed = True try: @@ -17,10 +19,16 @@ SparkSession = None -def _sql_gen_unique_id_keygen(table:str, uid_col1:str, uid_col2:str): +def _sql_gen_unique_id_keygen( + table: str, + uid_col1: str, + uid_col2: str, + source_dataset1: str = None, + source_dataset2: str = None, +): """Create a composite unique id for a pairwise comparisons This is a concatenation of the unique id of each record - of the pairwise comparison. + of the pairwise comparisons. The composite unique id is agnostic to the ordering i.e. it treats: @@ -36,138 +44,255 @@ def _sql_gen_unique_id_keygen(table:str, uid_col1:str, uid_col2:str): table (str): name of table uid_col1 (str): name of unique id col 1 uid_col2 (str): name of unique id col 2 + source_dataset1 (str, optional): Name of source dataset column if exists. Defaults to None. + source_dataset2 (str, optional): [description]. Defaults to None. Returns: str: sql case expression that outputs the composite unique_id """ + if source_dataset1: + concat_1 = f"concat({table}.{source_dataset1}, '_',{table}.{uid_col1})" + concat_2 = f"concat({table}.{source_dataset2}, '_',{table}.{uid_col2})" + else: + concat_1 = f"{table}.{uid_col1}" + concat_2 = f"{table}.{uid_col2}" + return f""" case - when {table}.{uid_col1} > {table}.{uid_col2} then concat({table}.{uid_col2}, '-', {table}.{uid_col1}) - else concat({table}.{uid_col1}, '-', {table}.{uid_col2}) + when {concat_1} > {concat_2} + then concat({concat_2}, '-', {concat_1}) + else concat({concat_1}, '-', {concat_2}) end """ -def _check_df_labels(df_labels, settings): - """Check the df_labels provided contains the expected columns +def _get_score_colname(df_e, score_colname=None): + if score_colname: + return score_colname + elif "tf_adjusted_match_prob" in df_e.columns: + return "tf_adjusted_match_prob" + elif "match_probability" in df_e.columns: + return "match_probability" + else: + raise ValueError("There doesn't appear to be a score column in df_e") + + +def dedupe_splink_scores( + df_e_with_dupes: DataFrame, + unique_id_colname: str, + score_colname: str = None, + selection_fn: str = "abs_val", +): + """Sometimes, multiple Splink jobs with different blocking rules are combined + into a single dataset of edges. Sometimes,the same pair of nodes will be + scored multiple times, once by each job. We need to deduplicate this dataset + so each pair of nodes appears only once + + Args: + df_e_with_dupes (DataFrame): Dataframe with dupes + unique_id_colname (str): Unique id column name e.g. unique_id + score_colname (str, optional): Which column contains scores? If none, inferred from + df_e_with_dupes.columns. Defaults to None. + selection_fn (str, optional): Where we have several different scores for a given + pair of records, how do we decide the final score? + Options are 'abs_val' and 'mean'. 
+ abs_val: Take the value furthest from 0.5 i.e. the value that expresses most certainty + mean: Take the mean of all values + Defaults to 'abs_val'. """ - cols = df_labels.columns - colname = settings["unique_id_column_name"] + # Looking in blocking.py, the position of unique ids + # (whether they appear in _l or _r) is guaranteed + # in blocking outputs so we don't need to worry about + # inversions + + # This is not the case for labelled data - hence the need + # _sql_gen_unique_id_keygen to join labels to df_e - assert f"{colname}_l" in cols, f"{colname}_l should be a column in df_labels" - assert f"{colname}_r" in cols, f"{colname}_l should be a column in df_labels" - assert ( - "clerical_match_score" in cols - ), f"clerical_match_score should be a column in df_labels" + possible_vals = ["abs_val", "mean"] + if selection_fn not in possible_vals: + raise ValueError( + f"selection function should be in {possible_vals}, you passed {selection_fn}" + ) + + score_colname = _get_score_colname(df_e_with_dupes, score_colname) + + if selection_fn == "abs_val": + df_e_with_dupes = df_e_with_dupes.withColumn( + "absval", f.expr(f"0.5 - abs({score_colname})") + ) + + win_spec = Window.partitionBy( + [f"{unique_id_colname}_l", f"{unique_id_colname}_r"] + ).orderBy(f.col("absval").desc()) + df_e_with_dupes = df_e_with_dupes.withColumn( + "ranking", f.row_number().over(win_spec) + ) + df_e = df_e_with_dupes.filter(f.col("ranking") == 1) + df_e = df_e.drop("absval") + df_e = df_e.drop("ranking") + + if selection_fn == "mean": + win_spec = Window.partitionBy( + [f"{unique_id_colname}_l", f"{unique_id_colname}_r"] + ).orderBy(f.col(score_colname).desc()) -def _get_score_colname(settings): - score_colname = "match_probability" - for c in settings["comparison_columns"]: - if c["term_frequency_adjustments"] is True: - score_colname = "tf_adjusted_match_prob" - return score_colname + df_e_with_dupes = df_e_with_dupes.withColumn( + "ranking", f.row_number().over(win_spec) + ) + + df_e_with_dupes = df_e_with_dupes.withColumn( + score_colname, + f.avg(score_colname).over( + win_spec.rowsBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + ), + ) + df_e = df_e_with_dupes.filter(f.col("ranking") == 1) + df_e = df_e.drop("ranking") -def _join_labels_to_results(df_labels, df_e, settings, spark): + return df_e - # df_labels is a dataframe like: - # | unique_id_l | unique_id_r | clerical_match_score | - # |:------------|:------------|---------------------:| - # | id1 | id2 | 0.9 | - # | id1 | id3 | 0.1 | - # df_e is a dataframe like - # | unique_id_l| unique_id_r| tf_adjusted_match_prob | - # |:-----------|:-----------|-----------------------:| - # | id1 | id2 | 0.85 | - # | id1 | id3 | 0.2 | - # | id2 | id3 | 0.1 | - settings = complete_settings_dict(settings, None) +def labels_with_splink_scores( + df_labels, + df_e, + unique_id_colname, + spark, + score_colname=None, + join_on_source_dataset=False, + retain_all_cols=False, +): + """Create a dataframe with clerical labels set against splink scores - _check_df_labels(df_labels, settings) + Assumes uniqueness of pairs of identifiers in both datasets - e.g. 
+ if you have duplicate clerical labels or splink scores, you should + deduplicate them first - uid_colname = settings["unique_id_column_name"] + Args: + df_labels: a dataframe like: + | unique_id_l | unique_id_r | clerical_match_score | + |:------------|:------------|---------------------:| + | id1 | id2 | 0.9 | + | id1 | id3 | 0.1 | + df_e: a dataframe like + | unique_id_l| unique_id_r| tf_adjusted_match_prob | + |:-----------|:-----------|-----------------------:| + | id1 | id2 | 0.85 | + | id1 | id3 | 0.2 | + | id2 | id3 | 0.1 | + unique_id_colname (str): Unique id column name e.g. unique_id + spark : SparkSession + score_colname (float, optional): Allows user to explicitly state the column name + in the Splink dataset containing the Splink score. If none will be inferred + join_on_source_dataset (bool, optional): In certain scenarios (e.g. linking two tables), the IDs may be unique only within the input table + Where this is the case, you should include columns 'source_dataset_l' and 'source_dataset_r' + and set join_on_source_dataset=True, which will include the source dataset in the join key Defaults to False. + retain_all_cols (bool, optional): Retain all columns in input datasets. Defaults to False. - # If settings has tf_afjustments, use tf_adjusted_match_prob else use match_probability - score_colname = _get_score_colname(settings) + Returns: + DataFrame: Like: + | unique_id_l | unique_id_r | clerical_match_score | tf_adjusted_match_prob | found_by_blocking | + |--------------:|--------------:|-----------------------:|-------------------------:|:--------------------| + | 0 | 1 | 1 | 0.999566 | True | + | 0 | 2 | 1 | 0.999566 | True | + | 0 | 3 | 1 | 0.999989 | True | - # The join is trickier than it looks because there's no guarantee of which way around the two ids are - # it could be id1, id2 in df_labels and id2,id1 in df_e + """ + score_colname = _get_score_colname(df_e, score_colname) - uid_col_l = f"{uid_colname}_l" - uid_col_r = f"{uid_colname}_r" + uid_col_l = f"{unique_id_colname}_l" + uid_col_r = f"{unique_id_colname}_r" df_labels.createOrReplaceTempView("df_labels") df_e.createOrReplaceTempView("df_e") - sql = f""" - select + if join_on_source_dataset: + labels_key = _sql_gen_unique_id_keygen( + "df_labels", uid_col_l, uid_col_r, "source_dataset_l", "source_dataset_r" + ) + df_e_key = _sql_gen_unique_id_keygen( + "df_e", uid_col_l, uid_col_r, "source_dataset_l", "source_dataset_r" + ) + select_cols = f""" + df_e.source_dataset_l, + df_e.{uid_col_l}, + df_e.source_dataset_r, + df_e.{uid_col_r} + """ + else: + labels_key = _sql_gen_unique_id_keygen("df_labels", uid_col_l, uid_col_r) + df_e_key = _sql_gen_unique_id_keygen("df_e", uid_col_l, uid_col_r) + select_cols = f""" - df_labels.{uid_col_l}, - df_labels.{uid_col_r}, - clerical_match_score, + df_e.{uid_col_l}, - case - when {score_colname} is null then 0 - else {score_colname} - end as {score_colname}, + df_e.{uid_col_r} + """ - case - when {score_colname} is null then false - else true - end as found_by_blocking + if retain_all_cols: + cols1 = [f"df_e.{c} as df_e__{c}" for c in df_e.columns if c != score_colname] + cols2 = [ + f"df_labels.{c} as df_labels__{c}" + for c in df_labels.columns + if c != "clerical_match_score" + ] + cols1.extend(cols2) + select_smt = ", ".join(cols1) - from df_labels - left join df_e - on {_sql_gen_unique_id_keygen('df_labels', uid_col_l, uid_col_r)} - = {_sql_gen_unique_id_keygen('df_e', uid_col_l, uid_col_r)} + sql = f""" + select - """ + {select_smt}, + clerical_match_score, - return 
spark.sql(sql) + case + when {score_colname} is null then 0 + else {score_colname} + end as {score_colname}, + case + when {score_colname} is null then false + else true + end as found_by_blocking -def _categorise_scores_into_truth_cats( - df_e_with_labels, threshold_pred, settings, spark, threshold_actual=0.5 -): - """Take a dataframe with clerical labels and splink predictions and - label each row with truth categories (true positive, true negative etc) - """ - # df_e_with_labels is a dataframe like - # | unique_id_l | unique_id_r | clerical_match_score | tf_adjusted_match_prob | - # |:------------------|:--------------|-----------------------:|-------------------------:| - # | id1 | id2 | 0.9 | 0.85 | - # | id1 | id3 | 0.1 | 0.2 | + from df_labels + left join df_e + on {labels_key} + = {df_e_key} + """ - df_e_with_labels.createOrReplaceTempView("df_e_with_labels") + else: + sql = f""" + select - score_colname = _get_score_colname(settings) + {select_cols}, + clerical_match_score, - pred = f"({score_colname} >= {threshold_pred})" + case + when {score_colname} is null then 0 + else {score_colname} + end as {score_colname}, - actual = f"(clerical_match_score >= {threshold_actual})" + case + when {score_colname} is null then false + else true + end as found_by_blocking - sql = f""" - select - *, - cast ({threshold_pred} as float) as truth_threshold, - {actual} = 1.0 as P, - {actual} = 0.0 as N, - {pred} = 1.0 and {actual} = 1.0 as TP, - {pred} = 0.0 and {actual} = 0.0 as TN, - {pred} = 1.0 and {actual} = 0.0 as FP, - {pred} = 0.0 and {actual} = 1.0 as FN - from - df_e_with_labels + from df_labels + left join df_e + on {labels_key} + = {df_e_key} - """ + """ return spark.sql(sql) @@ -215,12 +340,11 @@ def _summarise_truth_cats(df_truth_cats, spark): def df_e_with_truth_categories( - df_labels: DataFrame, - df_e: DataFrame, - settings: dict, - threshold_pred: float, + df_labels_with_splink_scores, + threshold_pred, spark: SparkSession, threshold_actual: float = 0.5, + score_colname: str = None, ): """Join Splink's predictions to clerically labelled data and categorise rows by truth category (false positive, true positive etc.) @@ -228,90 +352,85 @@ def df_e_with_truth_categories( Note that df_labels Args: - df_labels (DataFrame): A dataframe of clerically labelled data - with ids that match the unique_id_column sepcified in the - splink settings object. If the column is called unique_id - df_labels should look like: - | unique_id_l | unique_id_r | clerical_match_score | - |:------------|:------------|---------------------:| - | id1 | id2 | 0.9 | - | id1 | id3 | 0.1 | - df_e (DataFrame): Splink output of scored pairwise record comparisons - | unique_id_l| unique_id_r| tf_adjusted_match_prob | - |:-----------|:-----------|-----------------------:| - | id1 | id2 | 0.85 | - | id1 | id3 | 0.2 | - | id2 | id3 | 0.1 | - settings (dict): splink settings dictionary + df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores + usually the output of the truth.labels_with_splink_scores function threshold_pred (float): Threshold to use in categorising Splink predictions into match or no match spark (SparkSession): SparkSession object threshold_actual (float, optional): Threshold to use in categorising clerical match scores into match or no match. Defaults to 0.5. + score_colname (float, optional): Allows user to explicitly state the column name + in the Splink dataset containing the Splink score. 
If none will be inferred Returns: DataFrame: Dataframe of labels associated with truth category """ - df_labels = _join_labels_to_results(df_labels, df_e, settings, spark) - df_e_t = _categorise_scores_into_truth_cats( - df_labels, threshold_pred, settings, spark, threshold_actual - ) - return df_e_t + + df_labels_with_splink_scores.createOrReplaceTempView("df_labels_with_splink_scores") + + score_colname = _get_score_colname(df_labels_with_splink_scores) + + pred = f"({score_colname} >= {threshold_pred})" + + actual = f"(clerical_match_score >= {threshold_actual})" + + sql = f""" + select + *, + cast ({threshold_pred} as float) as truth_threshold, + {actual} = 1.0 as P, + {actual} = 0.0 as N, + {pred} = 1.0 and {actual} = 1.0 as TP, + {pred} = 0.0 and {actual} = 0.0 as TN, + {pred} = 1.0 and {actual} = 0.0 as FP, + {pred} = 0.0 and {actual} = 1.0 as FN + + from + df_labels_with_splink_scores + + """ + + return spark.sql(sql) def truth_space_table( - df_labels: DataFrame, - df_e: DataFrame, - settings: dict, + df_labels_with_splink_scores: DataFrame, spark: SparkSession, threshold_actual: float = 0.5, + score_colname: str = None, ): """Create a table of the ROC space i.e. truth table statistics for each discrimination threshold Args: - df_labels (DataFrame): A dataframe of clerically labelled data - with ids that match the unique_id_column sepcified in the - splink settings object. If the column is called unique_id - df_labels should look like: - | unique_id_l | unique_id_r | clerical_match_score | - |:------------|:------------|---------------------:| - | id1 | id2 | 0.9 | - | id1 | id3 | 0.1 | - df_e (DataFrame): Splink output of scored pairwise record comparisons - | unique_id_l| unique_id_r| tf_adjusted_match_prob | - |:-----------|:-----------|-----------------------:| - | id1 | id2 | 0.85 | - | id1 | id3 | 0.2 | - | id2 | id3 | 0.1 | - settings (dict): splink settings dictionary - spark (SparkSession): SparkSession object + df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores + usually the output of the truth.labels_with_splink_scores function threshold_actual (float, optional): Threshold to use in categorising clerical match scores into match or no match. Defaults to 0.5. + score_colname (float, optional): Allows user to explicitly state the column name + in the Splink dataset containing the Splink score. If none will be inferred Returns: DataFrame: Table of 'truth space' i.e. 
truth categories for each threshold level """ - df_labels_results = _join_labels_to_results(df_labels, df_e, settings, spark) - # This is used repeatedly to generate the roc curve - df_labels_results.persist() + df_labels_with_splink_scores.persist() # We want percentiles of score to compute - score_colname = _get_score_colname(settings) + score_colname = _get_score_colname(df_labels_with_splink_scores, score_colname) percentiles = [x / 100 for x in range(0, 101)] - values_distinct = df_labels_results.select(score_colname).distinct() + values_distinct = df_labels_with_splink_scores.select(score_colname).distinct() thresholds = values_distinct.stat.approxQuantile(score_colname, percentiles, 0.0) thresholds.append(1.01) thresholds = sorted(set(thresholds)) roc_dfs = [] for thres in thresholds: - df_e_t = _categorise_scores_into_truth_cats( - df_labels_results, thres, settings, spark, threshold_actual + df_e_t = df_e_with_truth_categories( + df_labels_with_splink_scores, thres, spark, threshold_actual, score_colname ) df_roc_row = _summarise_truth_cats(df_e_t, spark) roc_dfs.append(df_roc_row) @@ -321,9 +440,7 @@ def truth_space_table( def roc_chart( - df_labels: DataFrame, - df_e: DataFrame, - settings: dict, + df_labels_with_splink_scores: DataFrame, spark: SparkSession, threshold_actual: float = 0.5, x_domain: list = None, @@ -333,21 +450,8 @@ def roc_chart( """Create a ROC chart from labelled data Args: - df_labels (DataFrame): A dataframe of clerically labelled data - with ids that match the unique_id_column sepcified in the - splink settings object. If the column is called unique_id - df_labels should look like: - | unique_id_l | unique_id_r | clerical_match_score | - |:------------|:------------|---------------------:| - | id1 | id2 | 0.9 | - | id1 | id3 | 0.1 | - df_e (DataFrame): Splink output of scored pairwise record comparisons - | unique_id_l| unique_id_r| tf_adjusted_match_prob | - |:-----------|:-----------|-----------------------:| - | id1 | id2 | 0.85 | - | id1 | id3 | 0.2 | - | id2 | id3 | 0.1 | - settings (dict): splink settings dictionary + df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores + usually the output of the truth.labels_with_splink_scores function spark (SparkSession): SparkSession object threshold_actual (float, optional): Threshold to use in categorising clerical match scores into match or no match. Defaults to 0.5. @@ -393,7 +497,7 @@ def roc_chart( } data = truth_space_table( - df_labels, df_e, settings, spark, threshold_actual=threshold_actual + df_labels_with_splink_scores, spark, threshold_actual=threshold_actual ).toPandas() if not x_domain: @@ -406,7 +510,7 @@ def roc_chart( roc_chart_def["encoding"]["x"]["scale"] = {"domain": x_domain} - data = data.to_dict(orient="rows") + data = data.to_dict(orient="records") roc_chart_def["data"]["values"] = data @@ -417,9 +521,7 @@ def roc_chart( def precision_recall_chart( - df_labels, - df_e, - settings, + df_labels_with_splink_scores, spark, threshold_actual=0.5, domain=None, @@ -429,21 +531,8 @@ def precision_recall_chart( """Create a precision recall chart from labelled data Args: - df_labels (DataFrame): A dataframe of clerically labelled data - with ids that match the unique_id_column sepcified in the - splink settings object. 
If the column is called unique_id - df_labels should look like: - | unique_id_l | unique_id_r | clerical_match_score | - |:------------|:------------|---------------------:| - | id1 | id2 | 0.9 | - | id1 | id3 | 0.1 | - df_e (DataFrame): Splink output of scored pairwise record comparisons - | unique_id_l| unique_id_r| tf_adjusted_match_prob | - |:-----------|:-----------|-----------------------:| - | id1 | id2 | 0.85 | - | id1 | id3 | 0.2 | - | id2 | id3 | 0.1 | - settings (dict): splink settings dictionary + df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores + usually the output of the truth.labels_with_splink_scores function spark (SparkSession): SparkSession object threshold_actual (float, optional): Threshold to use in categorising clerical match scores into match or no match. Defaults to 0.5. @@ -492,10 +581,10 @@ def precision_recall_chart( pr_chart_def["encoding"]["x"]["scale"]["domain"] = domain data = truth_space_table( - df_labels, df_e, settings, spark, threshold_actual=threshold_actual + df_labels_with_splink_scores, spark, threshold_actual=threshold_actual ).toPandas() - data = data.to_dict(orient="rows") + data = data.to_dict(orient="records") pr_chart_def["data"]["values"] = data diff --git a/tests/test_truth.py b/tests/test_truth.py index 2462f55ab4..bda53ab51e 100644 --- a/tests/test_truth.py +++ b/tests/test_truth.py @@ -4,22 +4,18 @@ from pyspark.sql import Row from splink.truth import ( + labels_with_splink_scores, df_e_with_truth_categories, truth_space_table, roc_chart, precision_recall_chart, + dedupe_splink_scores, ) import pytest -def test_roc(spark): - - settings = { - "link_type": "dedupe_only", - "unique_id_column_name": "person_id", - "comparison_columns": [{"col_name": "col1"}], - } +def test_roc_1(spark): df_e = [ {"person_id_l": 1, "person_id_r": 11, "match_probability": 0.0}, @@ -36,14 +32,18 @@ def test_roc(spark): ] df_labels = spark.createDataFrame(Row(**x) for x in df_labels) - df_truth = df_e_with_truth_categories(df_labels, df_e, settings, 0.5, spark) + df_labels_with_splink_scores = labels_with_splink_scores( + df_labels, df_e, "person_id", spark + ) + + df_truth = df_e_with_truth_categories(df_labels_with_splink_scores, 0.5, spark) df_truth = df_truth.toPandas() f1 = df_truth["person_id_l"] == 1 f2 = df_truth["person_id_r"] == 11 - row = df_truth[f1 & f2].to_dict(orient="rows")[0] + row = df_truth[f1 & f2].to_dict(orient="records")[0] assert row["P"] is False assert row["N"] is True @@ -52,14 +52,14 @@ def test_roc(spark): f1 = df_truth["person_id_l"] == 3 f2 = df_truth["person_id_r"] == 13 - row = df_truth[f1 & f2].to_dict(orient="rows")[0] + row = df_truth[f1 & f2].to_dict(orient="records")[0] assert row["P"] is False assert row["N"] is True assert row["TP"] is False assert row["FP"] is True - df_roc = truth_space_table(df_labels, df_e, settings, spark) + df_roc = truth_space_table(df_labels_with_splink_scores, spark) df_roc = df_roc.toPandas() # Note that our critiera are great than or equal to @@ -67,7 +67,7 @@ def test_roc(spark): f1 = df_roc["truth_threshold"] > 0.39 f2 = df_roc["truth_threshold"] < 0.41 - row = df_roc[f1 & f2].to_dict(orient="rows")[0] + row = df_roc[f1 & f2].to_dict(orient="records")[0] # FPR = FP/(FP+TN) = FP/N # At 0.4 we have @@ -82,6 +82,146 @@ def test_roc(spark): # Check no errors from charting functions - df_roc = truth_space_table(df_labels, df_e, settings, spark) - roc_chart(df_labels, df_e, settings, spark) - precision_recall_chart(df_labels, df_e, settings, spark) \ No 
newline at end of file + roc_chart(df_labels_with_splink_scores, spark) + precision_recall_chart(df_labels_with_splink_scores, spark) + + +def test_join(spark): + + df_e = [ + { + "source_dataset_l": "t1", + "person_id_l": 1, + "source_dataset_r": "t2", + "person_id_r": 1, + "tf_adjusted_match_prob": 0.0, + }, + { + "source_dataset_l": "t1", + "person_id_l": 1, + "source_dataset_r": "t2", + "person_id_r": 2, + "tf_adjusted_match_prob": 0.4, + }, + { + "source_dataset_l": "t2", + "person_id_l": 1, + "source_dataset_r": "t1", + "person_id_r": 2, + "tf_adjusted_match_prob": 1.0, + }, + ] + df_e = spark.createDataFrame(Row(**x) for x in df_e) + + df_labels = [ + { + "source_dataset_l": "t2", + "person_id_l": 1, + "source_dataset_r": "t1", + "person_id_r": 1, + "clerical_match_score": 0.1, + }, + { + "source_dataset_l": "t2", + "person_id_l": 1, + "source_dataset_r": "t1", + "person_id_r": 2, + "clerical_match_score": 0.45, + }, + { + "source_dataset_l": "t1", + "person_id_l": 1, + "source_dataset_r": "t2", + "person_id_r": 2, + "clerical_match_score": 0.01, + }, + ] + df_labels = spark.createDataFrame(Row(**x) for x in df_labels) + + df_labels_with_splink_scores = labels_with_splink_scores( + df_labels, df_e, "person_id", spark, join_on_source_dataset=True + ) + + df_pd = df_labels_with_splink_scores.toPandas() + + f1 = df_pd["source_dataset_l"] == "t1" + f2 = df_pd["person_id_l"] == 1 + f3 = df_pd["source_dataset_r"] == "t2" + f4 = df_pd["person_id_r"] == 1 + + row = df_pd[f1 & f2 & f3 & f4].to_dict(orient="records")[0] + + assert pytest.approx(row["tf_adjusted_match_prob"]) == 0.0 + assert pytest.approx(row["clerical_match_score"]) == 0.1 + + f1 = df_pd["source_dataset_l"] == "t2" + f2 = df_pd["person_id_l"] == 1 + f3 = df_pd["source_dataset_r"] == "t1" + f4 = df_pd["person_id_r"] == 2 + + row = df_pd[f1 & f2 & f3 & f4].to_dict(orient="records")[0] + + assert pytest.approx(row["tf_adjusted_match_prob"]) == 1.0 + assert pytest.approx(row["clerical_match_score"]) == 0.45 + + +def test_dedupe(spark): + + df_e = [ + {"person_id_l": 1, "person_id_r": 2, "match_probability": 0.1}, + {"person_id_l": 1, "person_id_r": 2, "match_probability": 0.8}, + {"person_id_l": 1, "person_id_r": 2, "match_probability": 0.3}, + {"person_id_l": 1, "person_id_r": 4, "match_probability": 1.0}, + {"person_id_l": 5, "person_id_r": 6, "match_probability": 0.0}, + ] + df_e = spark.createDataFrame(Row(**x) for x in df_e) + + df = dedupe_splink_scores(df_e, "person_id", selection_fn="abs_val") + + df_pd = df.toPandas() + + f1 = df_pd["person_id_l"] == 1 + f2 = df_pd["person_id_r"] == 2 + + row = df_pd[f1 & f2].to_dict(orient="records")[0] + + assert row["match_probability"] == 0.1 + + f1 = df_pd["person_id_l"] == 1 + f2 = df_pd["person_id_r"] == 4 + + row = df_pd[f1 & f2].to_dict(orient="records")[0] + + assert row["match_probability"] == 1.0 + + f1 = df_pd["person_id_l"] == 5 + f2 = df_pd["person_id_r"] == 6 + + row = df_pd[f1 & f2].to_dict(orient="records")[0] + + assert row["match_probability"] == 0.0 + + df = dedupe_splink_scores(df_e, "person_id", selection_fn="mean") + + df_pd = df.toPandas() + + f1 = df_pd["person_id_l"] == 1 + f2 = df_pd["person_id_r"] == 2 + + row = df_pd[f1 & f2].to_dict(orient="records")[0] + + assert pytest.approx(row["match_probability"]) == 0.4 + + f1 = df_pd["person_id_l"] == 1 + f2 = df_pd["person_id_r"] == 4 + + row = df_pd[f1 & f2].to_dict(orient="records")[0] + + assert row["match_probability"] == 1.0 + + f1 = df_pd["person_id_l"] == 5 + f2 = df_pd["person_id_r"] == 6 + + row = df_pd[f1 
& f2].to_dict(orient="records")[0] + + assert row["match_probability"] == 0.0
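
For reviewers, a minimal sketch (not part of the patch) of how the reworked `truth.py` API chains together, mirroring the calls exercised in `tests/test_truth.py`. It assumes an active SparkSession `spark`, a Splink edges dataframe `df_e` with `person_id_l`/`person_id_r` and a `match_probability` (or `tf_adjusted_match_prob`) column, and a clerical labels dataframe `df_labels` with the same id columns plus `clerical_match_score`; those inputs are placeholders for illustration only.

```python
# Usage sketch for the new labels/scores workflow introduced in this diff.
from splink.truth import (
    dedupe_splink_scores,
    labels_with_splink_scores,
    df_e_with_truth_categories,
    truth_space_table,
    roc_chart,
    precision_recall_chart,
)

# If several blocking jobs scored the same pair more than once, collapse to one
# row per pair first: 'abs_val' keeps the score furthest from 0.5, 'mean' averages.
df_e = dedupe_splink_scores(df_e, "person_id", selection_fn="abs_val")

# Join clerical labels to Splink scores. Set join_on_source_dataset=True when ids
# are only unique within each input table and source_dataset_l/_r columns exist.
df_labels_with_scores = labels_with_splink_scores(
    df_labels, df_e, "person_id", spark, join_on_source_dataset=False
)

# Categorise each labelled pair as TP/TN/FP/FN at a single prediction threshold...
df_truth = df_e_with_truth_categories(df_labels_with_scores, 0.5, spark)

# ...or sweep thresholds to build the truth space and chart it.
df_roc = truth_space_table(df_labels_with_scores, spark, threshold_actual=0.5)
roc_chart(df_labels_with_scores, spark)
precision_recall_chart(df_labels_with_scores, spark)
```

Note the ordering: deduplication and the label/score join happen once up front, and every downstream function now takes the joined `df_labels_with_splink_scores` dataframe instead of the old `(df_labels, df_e, settings)` triple.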