Skip to content

Commit

Permalink
Hungarian assignment for fix frames and segment stitching
Browse files Browse the repository at this point in the history
done, now the default assignment mode.
  • Loading branch information
isaacrobinson2000 committed Jul 13, 2023
1 parent 7ed8720 commit 56ff46d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 12 deletions.
13 changes: 3 additions & 10 deletions diplomat/predictors/fpe/frame_passes/fix_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def create_fix_frame(
)

if(algorithm == "greedy"):
print("Old Algorithm...")
select_mask = np.zeros((num_outputs, fb_data.num_bodyparts), bool)
for __ in range(fb_data.num_bodyparts):
# Compute the shortest node paths for every skeleton...
Expand Down Expand Up @@ -290,7 +289,6 @@ def create_fix_frame(
del score_graph[other_part][best_part]
del score_graph[best_part][other_part]
else:
print("New algorithm...")
select_mask = np.zeros(fb_data.num_bodyparts // num_outputs, dtype=bool)
select_mask[fixed_group] = True

Expand All @@ -306,22 +304,17 @@ def create_fix_frame(

grouped_skel_scores = skel_scores.reshape((num_outputs, -1, num_outputs))
net_part_type_error = np.nanmin(grouped_skel_scores, axis=2).sum(axis=0)
print(grouped_skel_scores)
print(net_part_type_error)

min_group = cls._masked_argmin(net_part_type_error, ~select_mask)[0]

select_mask[min_group] = True
print(min_group)
print(grouped_skel_scores[:, min_group, :].reshape(num_outputs, num_outputs))
opt_rows, opt_cols = linear_sum_assignment(
grouped_skel_scores[:, min_group, :].reshape(num_outputs, num_outputs)
)
print(opt_rows, opt_cols)

for row_idx, col_idx in zip(opt_rows, opt_cols):
new_i = min_group + row_idx
best_part = min_group + col_idx
new_i = min_group * num_outputs + row_idx
best_part = min_group * num_outputs + col_idx
fixed_frame[new_i] = fb_data.frames[frame_idx][best_part]
fixed_frame[new_i].disable_occluded = True

Expand Down Expand Up @@ -584,7 +577,7 @@ def get_config_options(cls) -> ConfigSpec:
"Specify the fixed frame manually by setting to an integer index."
),
"skeleton_assignment_algorithm": (
"greedy",
"hungarian",
tc.Literal("greedy", "hungarian"),
"The algorithm to use for assigning body parts to skeletons when creating the fix frame."
)
Expand Down
4 changes: 2 additions & 2 deletions diplomat/predictors/sfpe/segmented_frame_pass_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,10 +1340,10 @@ def get_settings(cls) -> ConfigSpec:
"The mode to utilize during sparsification."
),
"assignment_algorithm": (
"greedy",
"hungarian",
type_casters.Literal("greedy", "hungarian"),
"The algorithm to use for assigning parts to bodies and stitching parts/bodies across segments."
"Greedy is faster, hungarian provides better results."
"Greedy is faster/simpler, hungarian provides better results."
)
}

Expand Down

0 comments on commit 56ff46d

Please sign in to comment.