diff --git a/.gitignore b/.gitignore
index 831b7c8a..d8b0e0a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,9 @@ datasets.toml
 scratch/
 *.ipynb_checkpoints
 .ruff_cache
+notebooks/
+*.parquet
+.DS_Store
 
 # DuckDB
 *.duckdb
diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py
index c58abd84..c68ede77 100644
--- a/src/matchbox/server/postgresql/utils/insert.py
+++ b/src/matchbox/server/postgresql/utils/insert.py
@@ -207,67 +207,59 @@ def _cluster_results_to_hierarchical(
         clusters: Connected components at each threshold
 
     Returns:
-        List of (parent, child, threshold) tuples representing the hierarchy
+        list of (parent, child, threshold) tuples representing the hierarchy
     """
-    # Create initial hierarchy from base components
     prob_df = probabilities.dataframe
     cluster_df = clusters.dataframe
+    # Sort thresholds in descending order
     thresholds = sorted(cluster_df["threshold"].unique(), reverse=True)
 
-    # Add all clusters corresponding to a simple two-item probability edge
-    hierarchy = []
-    for _, row in prob_df.iterrows():
-        parent, left_id, right_id, prob = row[
-            ["hash", "left_id", "right_id", "probability"]
-        ]
-        hierarchy.extend(
-            [(parent, left_id, float(prob)), (parent, right_id, float(prob))]
-        )
-
-    # Create adjacency structure for quick lookups
-    adj_dict: dict[bytes, set[tuple[bytes, float]]] = defaultdict(set)
-    for parent, child, prob in hierarchy:
-        adj_dict[child].add((parent, prob))
+    hierarchy: list[tuple[bytes, bytes, float]] = []
+    ultimate_parents: dict[bytes, set[bytes]] = defaultdict(set)
 
-    # Process each threshold level, getting clusters at each threshold
+    # Process each threshold level
     for threshold in thresholds:
         threshold_float = float(threshold)
-        current_clusters = cluster_df[cluster_df["threshold"] == threshold]
+        # Add new pairwise relationships at this threshold
+        current_probs = prob_df[prob_df["probability"] == threshold_float]
+
+        for _, row in current_probs.iterrows():
+            parent = row["hash"]
+            left_id = row["left_id"]
+            right_id = row["right_id"]
+
+            hierarchy.extend(
+                [
+                    (parent, left_id, threshold_float),
+                    (parent, right_id, threshold_float),
+                ]
+            )
+
+            ultimate_parents[left_id].add(parent)
+            ultimate_parents[right_id].add(parent)
+
+        # Process clusters at this threshold
+        current_clusters = cluster_df[cluster_df["threshold"] == threshold_float]
 
-        # Group by parent to process each component
+        # Group by parent to process components together
         for parent, group in current_clusters.groupby("parent"):
-            members = set(group["child"])
-            if len(members) <= 2:
-                continue
-
-            seen = set(members)
-            current = set(members)
-            ultimate_parents = set()
-
-            # Keep traversing until we've explored all paths
-            while current:
-                next_level = set()
-                # If any current nodes have no parents above threshold,
-                # they are ultimate parents for this threshold
-                for node in current:
-                    parents = {
-                        p for p, prob in adj_dict[node] if prob >= threshold_float
-                    }
-                    next_parents = parents - seen
-                    if not parents:  # No parents = ultimate parent
-                        ultimate_parents.add(node)
-
-                    next_level.update(next_parents)
-                    seen.update(parents)
-
-                current = next_level
-
-            for up in ultimate_parents:
+            children = set(group["child"])
+            if len(children) <= 2:
+                continue  # Skip pairs already handled by pairwise probabilities
+
+            current_ultimate_parents: set[bytes] = set()
+            for child in children:
+                current_ultimate_parents.update(ultimate_parents[child])
+
+            for up in current_ultimate_parents:
                 hierarchy.append((parent, up, threshold_float))
-                adj_dict[up].add((parent, threshold_float))
 
+            for child in children:
+                ultimate_parents[child] = {parent}
+
+    # Sort hierarchy by threshold (descending), then parent, then child
     return sorted(hierarchy, key=lambda x: (x[2], x[0], x[1]), reverse=True)
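
Review note on the algorithm change: the old implementation built an `adj_dict` adjacency map up front and re-traversed it for every component at every threshold; the rewrite instead keeps a running `ultimate_parents` map from each leaf to its current root cluster(s), so each component only unions its children's entries and then points them at itself. Below is a minimal standalone sketch of that bookkeeping, assuming toy string IDs and hard-coded inputs in place of the real `bytes` hashes and pandas DataFrames:

```python
# Toy sketch, not the real implementation: string IDs stand in for bytes
# hashes, and plain lists/dicts stand in for the pandas DataFrames.
from collections import defaultdict

# Pairwise merges as (parent, left, right, probability).
pairwise = [
    ("ab", "a", "b", 0.9),
    ("cd", "c", "d", 0.8),
]
# A four-item component appears once the threshold drops to 0.7.
components = {0.7: {"abcd": {"a", "b", "c", "d"}}}

hierarchy: list[tuple[str, str, float]] = []
ultimate_parents: dict[str, set[str]] = defaultdict(set)

for threshold in (0.9, 0.8, 0.7):  # descending, as in the real code
    # Record pairwise edges introduced exactly at this threshold.
    for parent, left, right, prob in pairwise:
        if prob == threshold:
            hierarchy += [(parent, left, threshold), (parent, right, threshold)]
            ultimate_parents[left].add(parent)
            ultimate_parents[right].add(parent)

    # Larger components attach to their children's current roots,
    # never to the leaves directly.
    for parent, children in components.get(threshold, {}).items():
        if len(children) <= 2:
            continue
        roots = set().union(*(ultimate_parents[c] for c in children))
        for up in roots:
            hierarchy.append((parent, up, threshold))
        for child in children:
            ultimate_parents[child] = {parent}

# "abcd" is linked to the sub-clusters "ab" and "cd", not to a, b, c, d.
assert ("abcd", "ab", 0.7) in hierarchy
assert ("abcd", "cd", 0.7) in hierarchy
```

Because thresholds are processed in descending order, each child's entry already names the tightest cluster formed so far, which is why the per-threshold pass never needs to walk a graph.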