diff --git a/flox/core.py b/flox/core.py index b32e3b043..c9f15059a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -214,7 +214,7 @@ def slices_from_chunks(chunks): return product(*slices) -# @memoize +@memoize def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: """ Finds groups labels that occur together aka "cohorts" @@ -312,13 +312,10 @@ def invert(x) -> tuple[np.ndarray, ...]: # precompute needed metrics for the quadratic loop below. items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k) - import copy - items2 = list(copy.deepcopy(items[1:])) merged_cohorts = {} merged_keys: set[tuple] = set() - # import ipdb; ipdb.set_trace() # Now we iterate starting with the longest number of chunks, # and then merge in cohorts that are present in a subset of those chunks # I think this is suboptimal and must fail at some point. @@ -330,20 +327,12 @@ def invert(x) -> tuple[np.ndarray, ...]: new_value = v1 # iterate in reverse since we expect small cohorts # to be most likely merged in to larger ones - to_delete = [] - for idx2, (k2, len_k2, set_k2, v2) in enumerate(reversed(items2)): - if v1 == v2 or k1 == k2: - continue - if (len(set_k2 & new_key) / len_k2) > 0.75: - new_key |= set_k2 - new_value += v2 - assert items2[len(items2)-idx2 - 1][0] == k2 - to_delete.append(len(items2)-idx2 - 1) - merged_keys.update((k2,)) - # print(to_delete) - for delete in reversed(sorted(to_delete)): - del items2[delete] - + for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]): + if k2 not in merged_keys: + if (len(set_k2 & new_key) / len_k2) > 0.75: + new_key |= set_k2 + new_value += v2 + merged_keys.update((k2,)) sorted_ = sorted(new_value) merged_cohorts[tuple(sorted(new_key))] = sorted_ if idx == 0 and (len(sorted_) == nlabels) and (sorted_ == ilabels).all(): @@ -352,7 +341,6 @@ def invert(x) -> tuple[np.ndarray, ...]: # sort by first label in cohort # This will help when sort=True (default) # and we have to resort the dask array - print(merged_cohorts) return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0])) else: