Skip to content

Commit

Permalink
add relabel_cc post-proc option for two/three_class predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
yajivunev committed Aug 24, 2023
1 parent ca770d3 commit 6e5c99b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 54 deletions.
68 changes: 43 additions & 25 deletions exercise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1621,7 +1621,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "18fb5f75",
"id": "17e7dd63",
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -1702,7 +1702,7 @@
"source": [
"## Task 3.0: Post-processing / further improvements\n",
"\n",
"- Before we were just thresholding and relabeling connected components. Now we will see a more advanced post-processing strategy called watershed.\n",
"- Before we were just thresholding and relabeling connected components. Now we will also see a more advanced post-processing strategy called watershed which works with affs.\n",
"\n",
"\n",
"- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.\n",
Expand Down Expand Up @@ -1767,14 +1767,17 @@
" # get boundary mask\n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n",
" \n",
" # get boundary distances\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" # get segmentation\n",
" seg = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" # get segmentation \n",
" if prediction_type == \"affs\":\n",
" # get boundary distances\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" seg = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" else:\n",
" seg = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -1852,12 +1855,17 @@
" thresh = np.mean(pred)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -1926,12 +1934,17 @@
" thresh = np.mean(thresh)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -2279,12 +2292,17 @@
" thresh = np.mean(pred)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down
66 changes: 42 additions & 24 deletions solutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2106,7 +2106,7 @@
"source": [
"## Task 3.0: Post-processing / further improvements\n",
"\n",
"- Before we were just thresholding and relabeling connected components. Now we will see a more advanced post-processing strategy called watershed.\n",
"- Before we were just thresholding and relabeling connected components. Now we will also see a more advanced post-processing strategy called watershed which works with affs.\n",
"\n",
"\n",
"- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.\n",
Expand Down Expand Up @@ -2171,14 +2171,17 @@
" # get boundary mask\n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n",
" \n",
" # get boundary distances\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" # get segmentation\n",
" seg = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" # get segmentation \n",
" if prediction_type == \"affs\":\n",
" # get boundary distances\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" seg = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" else:\n",
" seg = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -2256,12 +2259,17 @@
" thresh = np.mean(pred)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -2330,12 +2338,17 @@
" thresh = np.mean(thresh)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down Expand Up @@ -2826,12 +2839,17 @@
" thresh = np.mean(pred)\n",
" \n",
" boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
" \n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" if prediction_type == \"affs\":\n",
" boundary_distances = distance_transform_edt(boundary_mask)\n",
"\n",
" pred_labels = watershed_from_boundary_distance(\n",
" boundary_distances,\n",
" boundary_mask\n",
" )\n",
" \n",
" else:\n",
" pred_labels = relabel_cc(boundary_mask)\n",
" \n",
" gt_labels = imread(test_loader.dataset.mask_files[idx])\n",
" \n",
Expand Down
9 changes: 4 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,13 @@ def watershed_from_boundary_distance(

def get_boundary_mask(pred, prediction_type, thresh=None):

if prediction_type == 'two_class' or prediction_type == 'sdt':
if prediction_type == 'sdt' or prediction_type == 'two_class':
# simple threshold
boundary_mask = pred > thresh

elif prediction_type == 'three_class':
# Return the indices of the maximum values along channel axis, then set mask to cell interior (1)
boundary_mask = np.argmax(pred, axis=0)
boundary_mask = boundary_mask == 1
# fg = prediction greater than / equal to threshold
boundary_mask = pred[1] > thresh

elif prediction_type == 'affs':
# take mean of combined affs then threshold
Expand Down Expand Up @@ -262,4 +261,4 @@ def evaluate(gt_labels: np.ndarray, pred_labels: np.ndarray, th: float = 0.5):
precision = tp / max(1, tp + fp)
recall = tp / max(1, tp + fn)

return ap, precision, recall, tp, fp, fn
return ap, precision, recall, tp, fp, fn

0 comments on commit 6e5c99b

Please sign in to comment.