diff --git a/splink/internals/spark/database_api.py b/splink/internals/spark/database_api.py index 750825dfe2..5de36f8847 100644 --- a/splink/internals/spark/database_api.py +++ b/splink/internals/spark/database_api.py @@ -246,6 +246,14 @@ def _repartition_if_needed(self, spark_df, templated_name): num_partitions = math.ceil(num_partitions / 4) elif templated_name == "__splink__blocked_id_pairs": num_partitions = math.ceil(num_partitions / 6) + elif templated_name == "__splink__distinct_clusters_at_threshold": + num_partitions = 1 + elif templated_name == "__splink__nodes_in_play": + num_partitions = math.ceil(num_partitions / 10) + elif templated_name == "__splink__edges_in_play": + num_partitions = math.ceil(num_partitions / 10) + elif templated_name == "__splink__clusters_at_threshold": + num_partitions = math.ceil(num_partitions / 10) if re.fullmatch(r"|".join(names_to_repartition), templated_name): spark_df = spark_df.repartition(num_partitions) @@ -266,6 +274,10 @@ def _break_lineage_and_repartition(self, spark_df, templated_name, physical_name r"__splink__df_connected_components_df", r"__splink__blocked_id_pairs", r"__splink__marginal_exploded_ids_blocking_rule.*", + r"__splink__nodes_in_play", + r"__splink__edges_in_play", + r"__splink__clusters_at_threshold", + r"__splink__distinct_clusters_at_threshold", ] if re.fullmatch(r"|".join(regex_to_persist), templated_name):