diff --git a/exercise.ipynb b/exercise.ipynb index c9d0365..3e8dffa 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -1621,7 +1621,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18fb5f75", + "id": "17e7dd63", "metadata": { "tags": [] }, @@ -1702,7 +1702,7 @@ "source": [ "## Task 3.0: Post-processing / further improvements\n", "\n", - "- Before we were just thresholding and relabeling connected components. Now we will see a more advanced post-processing strategy called watershed.\n", + "- Before we were just thresholding and relabeling connected components. Now we will also see a more advanced post-processing strategy called watershed which works with affs.\n", "\n", "\n", "- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.\n", @@ -1767,14 +1767,17 @@ " # get boundary mask\n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n", " \n", - " # get boundary distances\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", - " \n", - " # get segmentation\n", - " seg = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " # get segmentation \n", + " if prediction_type == \"affs\":\n", + " # get boundary distances\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + " \n", + " seg = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " else:\n", + " seg = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -1852,12 +1855,17 @@ " thresh = np.mean(pred)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -1926,12 +1934,17 @@ " thresh = np.mean(thresh)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -2279,12 +2292,17 @@ " thresh = np.mean(pred)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", diff --git a/solutions.ipynb b/solutions.ipynb index f52aace..bb7b306 100644 --- a/solutions.ipynb +++ b/solutions.ipynb @@ -2106,7 +2106,7 @@ "source": [ "## Task 3.0: Post-processing / further improvements\n", "\n", - "- Before we were just thresholding and relabeling connected components. Now we will see a more advanced post-processing strategy called watershed.\n", + "- Before we were just thresholding and relabeling connected components. Now we will also see a more advanced post-processing strategy called watershed which works with affs.\n", "\n", "\n", "- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.\n", @@ -2171,14 +2171,17 @@ " # get boundary mask\n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n", " \n", - " # get boundary distances\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", - " \n", - " # get segmentation\n", - " seg = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " # get segmentation \n", + " if prediction_type == \"affs\":\n", + " # get boundary distances\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + " \n", + " seg = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " else:\n", + " seg = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -2256,12 +2259,17 @@ " thresh = np.mean(pred)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -2330,12 +2338,17 @@ " thresh = np.mean(thresh)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", @@ -2826,12 +2839,17 @@ " thresh = np.mean(pred)\n", " \n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh)\n", - " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", - " pred_labels = watershed_from_boundary_distance(\n", - " boundary_distances,\n", - " boundary_mask\n", - " )\n", + " if prediction_type == \"affs\":\n", + " boundary_distances = distance_transform_edt(boundary_mask)\n", + "\n", + " pred_labels = watershed_from_boundary_distance(\n", + " boundary_distances,\n", + " boundary_mask\n", + " )\n", + " \n", + " else:\n", + " pred_labels = relabel_cc(boundary_mask)\n", " \n", " gt_labels = imread(test_loader.dataset.mask_files[idx])\n", " \n", diff --git a/utils.py b/utils.py index a9ef195..abfa291 100644 --- a/utils.py +++ b/utils.py @@ -184,14 +184,13 @@ def watershed_from_boundary_distance( def get_boundary_mask(pred, prediction_type, thresh=None): - if prediction_type == 'two_class' or prediction_type == 'sdt': + if prediction_type == 'sdt' or prediction_type == 'two_class': # simple threshold boundary_mask = pred > thresh elif prediction_type == 'three_class': - # Return the indices of the maximum values along channel axis, then set mask to cell interior (1) - boundary_mask = np.argmax(pred, axis=0) - boundary_mask = boundary_mask == 1 + # fg = prediction greater than / equal to threshold + boundary_mask = pred[1] > thresh elif prediction_type == 'affs': # take mean of combined affs then threshold @@ -262,4 +261,4 @@ def evaluate(gt_labels: np.ndarray, pred_labels: np.ndarray, th: float = 0.5): precision = tp / max(1, tp + fp) recall = tp / max(1, tp + fn) - return ap, precision, recall, tp, fp, fn \ No newline at end of file + return ap, precision, recall, tp, fp, fn