Skip to content

Commit

Permalink
Revert "WIP"
Browse files Browse the repository at this point in the history
This reverts commit ad4ea5e.
  • Loading branch information
dcherian committed Nov 29, 2023
1 parent ad4ea5e commit 72c7e16
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def slices_from_chunks(chunks):
return product(*slices)


# @memoize
@memoize
def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
"""
Finds groups labels that occur together aka "cohorts"
Expand Down Expand Up @@ -312,13 +312,10 @@ def invert(x) -> tuple[np.ndarray, ...]:

# precompute needed metrics for the quadratic loop below.
items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
import copy
items2 = list(copy.deepcopy(items[1:]))

merged_cohorts = {}
merged_keys: set[tuple] = set()

# import ipdb; ipdb.set_trace()
# Now we iterate starting with the longest number of chunks,
# and then merge in cohorts that are present in a subset of those chunks
# I think this is suboptimal and must fail at some point.
Expand All @@ -330,20 +327,12 @@ def invert(x) -> tuple[np.ndarray, ...]:
new_value = v1
# iterate in reverse since we expect small cohorts
# to be most likely merged in to larger ones
to_delete = []
for idx2, (k2, len_k2, set_k2, v2) in enumerate(reversed(items2)):
if v1 == v2 or k1 == k2:
continue
if (len(set_k2 & new_key) / len_k2) > 0.75:
new_key |= set_k2
new_value += v2
assert items2[len(items2)-idx2 - 1][0] == k2
to_delete.append(len(items2)-idx2 - 1)
merged_keys.update((k2,))
# print(to_delete)
for delete in reversed(sorted(to_delete)):
del items2[delete]

for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
if k2 not in merged_keys:
if (len(set_k2 & new_key) / len_k2) > 0.75:
new_key |= set_k2
new_value += v2
merged_keys.update((k2,))
sorted_ = sorted(new_value)
merged_cohorts[tuple(sorted(new_key))] = sorted_
if idx == 0 and (len(sorted_) == nlabels) and (sorted_ == ilabels).all():
Expand All @@ -352,7 +341,6 @@ def invert(x) -> tuple[np.ndarray, ...]:
# sort by first label in cohort
# This will help when sort=True (default)
# and we have to resort the dask array
print(merged_cohorts)
return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))

else:
Expand Down

0 comments on commit 72c7e16

Please sign in to comment.