diff --git a/exercise.ipynb b/exercise.ipynb index c921a1c..c9d0365 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -1,11 +1,9 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "95227378-32db-418d-aa85-bf29eb8ad145", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -25,14 +23,12 @@ "- If you are running this in jupyter lab, every markdown header is collapsible and all cells are collapsed by default. Just click on the left of a cell to expand it, and just make sure to expand until the code cells show. The headings unfortunately do not collapse in jupyter notebook, but will still give you an idea of breaks between exercises. \n", "\n", "\n", - "- Most TODOs build off the previous TODOs and require copying over classes/functions before adding more to them. It could be annoying to constantly scroll back and forth so each TODO and Task header can link to the previous and next section to make it easier to move around freely. **Note:** This only works for uncollapsed cells. If you collapse a cell containing TODO 1 and then try to return to it using the TODO 2 link it won't work.\n", - "\n", + "- Most TODOs build off the previous TODOs and require copying over classes/functions before adding more to them. It could be annoying to constantly scroll back and forth so you can click the table of contents on the left hand side to easily navigate between sections.\n", "\n", "- If you have questions please let us know. Have fun!!" ] }, { - "attachments": {}, "cell_type": "markdown", "id": "230f4cec", "metadata": { @@ -46,7 +42,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "8812ec06-d766-4809-8d5f-aade999bbbc4", "metadata": { @@ -91,7 +86,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "3c7bedd4-985d-41a9-b4c6-7c5669b716f0", "metadata": { @@ -101,13 +95,10 @@ "source": [ "## Task 1.0: Creating a simple model\n", "\n", - "- Let's start by creating a simple model similar to the one in the semantic segmentation exercise. We will then improve it in the subsequent checkpoints\n", - "\n", - "Click [here](#task-10-creating-a-simple-model) to go to the next task\n" + "- Let's start by creating a simple model similar to the one in the semantic segmentation exercise. 
We will then improve it in the subsequent checkpoints" ] }, { - "attachments": {}, "cell_type": "markdown", "id": "172f39f0-5270-4d48-a42e-61029d99ea72", "metadata": { @@ -177,11 +168,27 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's define the decive we'll be using throughout the notebook\n", + "# Let's define the device we'll be using throughout the notebook\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cfdebc1-b8d1-4ed9-9e71-cb3ec35305cf", + "metadata": {}, + "outputs": [], + "source": [ + "# convenience functions for viewing labels as rgb, and reading files into numpy arrays\n", + "from skimage import color\n", + "from skimage.io import imread\n", + "\n", + "# utility function to view labels as rgb lut with matplotlib\n", + "# eg plt.imshow(create_lut(labels))\n", + "from utils import create_lut" + ] + }, { "cell_type": "code", "execution_count": null, @@ -225,7 +232,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "00ea412d-b780-4028-a963-3012a29b6127", "metadata": { @@ -256,13 +262,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's define simple augmentation pipeline\n", + "# import albumentations library\n", + "import albumentations as A\n", "\n", "file = random.choice(train_nuclei)\n", "\n", "full_mask_nuclei = imread(file)\n", "full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]\n", "\n", + "# Define simple augmentation pipeline\n", "transform = A.Compose([\n", " A.RandomCrop(width=64, height=64),\n", " A.HorizontalFlip(p=0.5),\n", @@ -285,10 +293,12 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "2aec8caf", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "
Task 1.3: Create fg/bg representation
\n", " \n", @@ -303,10 +313,70 @@ ] }, { - "attachments": {}, + "cell_type": "code", + "execution_count": null, + "id": "ab58aa51-348c-4802-87aa-7f40a57c9eb6", + "metadata": {}, + "outputs": [], + "source": [ + "# import utility function to erode borders\n", + "from utils import erode\n", + "\n", + "# visualize the representation (repeatedly run cell)\n", + "\n", + "file = random.choice(train_nuclei)\n", + "\n", + "full_mask_nuclei = imread(file)\n", + "full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]\n", + "\n", + "# Define simple augmentation pipeline\n", + "transform = A.Compose([\n", + " A.RandomCrop(width=64, height=64),\n", + " A.HorizontalFlip(p=0.5),\n", + " A.VerticalFlip(p=0.5)\n", + " ])\n", + "\n", + "transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)\n", + " \n", + "aug_raw, aug_mask = transformed['image'], transformed['mask']\n", + "\n", + "# erode label borders\n", + "eroded_labels = erode(\n", + " aug_mask,\n", + " iterations=1,\n", + " border_value=1)\n", + "\n", + "# create fg/bg mask\n", + "labels_two_class = (eroded_labels != 0).astype(np.float32)\n", + "\n", + "# check num classes and pixel counts\n", + "print(np.unique(labels_two_class, return_counts=True))\n", + "\n", + "fig, axes = plt.subplots(1,3,figsize=(15, 15),sharex=True,sharey=True,squeeze=False)\n", + "\n", + "axes[0][0].imshow(aug_raw, cmap='gray')\n", + "axes[0][0].title.set_text('Raw')\n", + "\n", + "axes[0][0].imshow(create_lut(aug_mask), alpha=0.5)\n", + "axes[0][0].title.set_text('Labels')\n", + "\n", + "axes[0][1].imshow(aug_raw, cmap='gray')\n", + "axes[0][1].title.set_text('Raw')\n", + "\n", + "axes[0][1].imshow(create_lut(eroded_labels), alpha=0.5)\n", + "axes[0][1].title.set_text('Eroded labels')\n", + "\n", + "axes[0][2].imshow(labels_two_class)\n", + "axes[0][2].title.set_text('Foreground / background')" + ] + }, + { "cell_type": "markdown", "id": "02c3ec11", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "
Task 1.4: Create simple dataset
\n", "\n", @@ -316,10 +386,10 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "f61492b1", "metadata": { + "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -334,6 +404,7 @@ "* The `augment_data` method should take in a raw and mask array and return the augmented raw and mask arrays.\n", "* In the `__getitem__` method, we should only augment our data if our split is `train`. So you will need to also add split as an attribute in the `__init__` method\n", "* After augmenting your data, create a fg / bg representation\n", + "* Erode your fg / bg representation by a pixel to introduce a seperation or explicit boundary between touching labels\n", "* Make sure to return your fg/bg as float32 for training\n", "* Make sure to add a dummy channel dimension to your arrays for training (Pytorch assumes we have tensor shape batch, channel, height, width). We will add a batch dimension later once we create a data loader.\n", "* For now we will just use the nuclei channel for training, so make sure to slice the correct channel of the raw data before returning\n", @@ -363,7 +434,7 @@ " # make sure to add your crop size\n", " self.crop_size = ...\n", "\n", - " # using the root dir, split and mask create a path to files and sort it \n", + " # using root_dir, split and mask create a path to files and sort it \n", " # Hint: natsorted glob and os libraries could come in handy\n", " self.mask_files = ... # load mask files into sorted list\n", " self.raw_files = ... # load image files into sorted list\n", @@ -399,7 +470,7 @@ " \n", " # augment your data if split mode is train\n", " \n", - " mask = ... # erode your labels, cast to float32. Hint: use function that returns just the mask\n", + " mask = ... # erode your labels, cast to float32.\n", "\n", " raw = ... # add channel dimension to comply with pytorch standard (C, H, W)\n", " mask = ... # add channel dimension\n", @@ -430,19 +501,14 @@ "\n", "raw, mask = train_dataset[random.randrange(len(train_dataset))]\n", "\n", - "labels = erode(\n", - " mask,\n", - " iterations=1,\n", - " border_value=1)\n", - "\n", - "labels_two_class = (labels != 0).astype(np.float32)\n", + "labels_two_class = (mask != 0).astype(np.float32)\n", "\n", "fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)\n", "\n", "axes[0][0].imshow(raw.squeeze(), cmap='gray')\n", "axes[0][0].title.set_text('Raw')\n", "\n", - "axes[0][0].imshow(create_lut(mask.squeeze().astype(int)), alpha=0.5)\n", + "axes[0][0].imshow(create_lut(relabel_cc(mask.squeeze().astype(int))), alpha=0.5)\n", "axes[0][0].title.set_text('Segmentation')\n", "\n", "axes[0][1].imshow(labels_two_class.squeeze())\n", @@ -470,7 +536,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "594c20e0-4cc6-419f-a1d8-8290ba39aa9e", "metadata": { @@ -480,10 +545,10 @@ "source": [ "
Task 1.5: Create shallow network, visualize receptive field
\n", " \n", - "- Let's create a shallow two level U-Net and visualize the receptive field. We will see later how this receptive field changes as we add more layers and change our input image size\n", + "- Let's create a shallow two level U-Net and visualize the receptive field. We will see later how this receptive field changes as we add more levels and change our input image size\n", " \n", " \n", - "- The receptive field tells us how much of the image the network is looking at in each layer -- this is the amount of spatial context that the network can use to create predictions.\n", + "- The receptive field tells us how much of the image the network is looking at in each level -- this is the amount of spatial context that the network can use to create predictions.\n", " \n", " \n", "- Run the following cell to see the networks receptive field. Try changing the downsampling factors to see how it affects the receptive field (eg try combinations of [1,1], [3,3], [4,4], etc)\n", @@ -500,12 +565,16 @@ "metadata": {}, "outputs": [], "source": [ + "# import a UNet class\n", + "from unet import UNet\n", + "\n", "raw, mask = train_dataset[random.randrange(len(train_dataset))]\n", "\n", "net_t = raw\n", "fovs = []\n", "d_factors = [[2,2],[2,2]]\n", "\n", + "# create unet\n", "net = UNet(in_channels=1,\n", " num_fmaps=6,\n", " fmap_inc_factors=2,\n", @@ -513,6 +582,7 @@ " padding='same'\n", " )\n", "\n", + "# get unet fovs\n", "for level in range(len(d_factors)+1):\n", " fov_tmp, _ = net.rec_fov(level , (1, 1), 1)\n", " fovs.append(fov_tmp[0])\n", @@ -522,6 +592,7 @@ "\n", "plt.imshow(np.squeeze(raw), cmap='gray')\n", "\n", + "# visualize receptive field\n", "for idx, fov_t in enumerate(fovs):\n", " print(\"Field of view at depth {}: {:3d} (color: {})\".format(idx+1, fov_t, colors[idx]))\n", " xmin = raw.shape[1]/2 - fov_t/2\n", @@ -536,7 +607,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "55352c82-130c-4452-badc-5be3e7ec6a7f", "metadata": { @@ -552,30 +622,29 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "02950213", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", "##### **TODO (2)**\n", " \n", "- Decide how many output channels to have, remember we are starting with a binary task\n", - "- What loss function and final layer activation should we use? Think back to the semantic segmentation exercise.\n", + "- What loss function and final level activation should we use? Think back to the semantic segmentation exercise.\n", "- What type should we ensure our tensors to be? You can see see a list of tensor types [here](https://pytorch.org/docs/stable/tensors.html) - maybe the equivalent of 32-bit floating point :)\n", "- For our model, we will create a two level U-Net with the following parameters: (to learn more about the torch layers click [here](https://pytorch.org/docs/stable/nn.html))\n", - " - downsample by a factor of 2 in each layer\n", + " - downsample by a factor of 2 in each level\n", " - single input channel\n", " - 32 input feature maps \n", - " - multiply by a factor of 2 between layers\n", + " - multiply by a factor of 2 between levels\n", " - `same` padding (this gives us the same input and output shapes) \n", " - Since our Unet will have the same number of output features as input features, we need to add a final convolution to get to our desired output feature maps. 
We should use a final convolution with kernel size of 1 \n", " - To see parameter defs you can run `UNet?`\n", - "- How many trainable parameters does our network have? \n", - "\n", - "* Click [here](#first-todo) if you need to go back to the previous **TODO**\n", - "* Click [here](#third-todo) if you need to go to the next **TODO**\n" + "- How many trainable parameters does our network have? \n" ] }, { @@ -623,7 +692,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "3b62b36d-00f0-4812-bf59-a94a6b9eebed", "metadata": { @@ -656,23 +724,24 @@ " if train_step:\n", " optimizer.zero_grad()\n", " \n", - " # forward\n", + " # forward - pass data through model to get logits\n", " logits = model(feature)\n", " \n", " if prediction_type == \"three_class\":\n", " label=torch.squeeze(label,1)\n", " \n", - " # final activation\n", + " # pass logits through final activation to get predictions\n", " predicted = activation(logits)\n", "\n", - " # pass through loss\n", + " # pass predictions through loss, compare to ground truth\n", " loss_value = loss_fn(input=predicted, target=label)\n", " \n", - " # backward if training mode\n", + " # if training mode, backprop and optimizer step\n", " if train_step:\n", " loss_value.backward()\n", " optimizer.step()\n", "\n", + " # return outputs and loss\n", " outputs = {\n", " 'pred': predicted,\n", " 'logits': logits,\n", @@ -725,10 +794,12 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "c734f964", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", @@ -739,9 +810,7 @@ "- The train data loader should use a batch size of 4 (shape = 4, c, h, w) and our val/test data loaders should use a batch size of 1. Set shuffle and pin memory to `True` in the train loader. (for more info on dataloaders see [here](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader))\n", "- For now lets just train for 1000 steps\n", "- Use a learning rate of 1e-4 and an Adam optimizer\n", - "- Use the `train` function with all the required parameters to train the model\n", - "- Click [here](#second-todo) if you need to go back to the previous **TODO**\n", - "- Click [here](#fourth-todo) if you need to go to the next **TODO**" + "- Use the `train` function with all the required parameters to train the model" ] }, { @@ -783,17 +852,6 @@ "# run training loop... (eg call train)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "540fe7f0", - "metadata": {}, - "outputs": [], - "source": [ - "# run training loop\n", - "train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -817,7 +875,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "bba0107c-5a21-421f-8c33-fda6da798048", "metadata": { @@ -843,6 +900,12 @@ "metadata": {}, "outputs": [], "source": [ + "# utility function to relabel connected components\n", + "from utils import relabel_cc\n", + "\n", + "# function to perform otsu thresholding on predictions\n", + "from skimage.filters import threshold_otsu\n", + "\n", "# make sure net is in eval mode so we don't backprop\n", "net.eval()\n", "\n", @@ -906,7 +969,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "e027aa06", "metadata": { @@ -920,7 +982,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "de5f9fbe-33b7-4993-8060-91f2c44e7445", "metadata": { @@ -942,15 +1003,10 @@ " 3. Increase the input size to our network\n", " 4. 
Use a bigger network (eg increase layers, number of feature maps)\n", " 5. Train for longer\n", - " 6. Use a better post-processing strategy (e.g. seeded watershed)\n", - "\n", - "\n", - "- Click [here](#task-10-creating-a-simple-model) to go back to the previous task\n", - "- Click [here](#task-20-improving-the-model) to go back to the next task" + " 6. Use a better post-processing strategy (e.g. seeded watershed)" ] }, { - "attachments": {}, "cell_type": "markdown", "id": "02aa001e-797d-4d6f-9735-b81ad1af2ddc", "metadata": { @@ -1007,7 +1063,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "340d8bc9-825f-4c44-91a7-c08a0206bca9", "metadata": { @@ -1041,6 +1096,9 @@ "metadata": {}, "outputs": [], "source": [ + "# utility functions to compute signed distance transform and edge affinities\n", + "from utils import compute_sdt, compute_affinities\n", + "\n", "# compute each representation and visualize\n", "\n", "file = random.choice(train_nuclei)\n", @@ -1056,17 +1114,26 @@ "transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)\n", " \n", "aug_raw, aug_mask = transformed['image'], transformed['mask']\n", - " \n", + "\n", + "# get eroded labels\n", "labels, border = erode_border(\n", " aug_mask,\n", " iterations=1,\n", " border_value=1)\n", "\n", + "# get fg/bg classes\n", "labels_two_class = (labels != 0)\n", + "\n", + "# get border pixels\n", "border[border!=0] = 2\n", "\n", + "# combine to get three class \n", "labels_three_class = (labels_two_class + border)\n", + "\n", + "# compute signed distance transform\n", "sdt = compute_sdt(labels)\n", + "\n", + "# compute edge affinities\n", "affs = compute_affinities(labels, nhood=[[0,1],[1,0]])\n", "\n", "fig, axes = plt.subplots(1,5,figsize=(20, 10),sharex=True,sharey=True,squeeze=False)\n", @@ -1086,10 +1153,12 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "ec48dc5b", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", @@ -1159,6 +1228,14 @@ " def get_padding(self, crop_size, padding_size):\n", " \n", " padding = ... # calculate your padding\n", + " \n", + " # hint:\n", + " \n", + " # 1. get padding quotient (crop size / padding size)\n", + " # 2. if crop size IS NOT evenly divisible by padding size, our padding\n", + " # is equal to (padding size * (padding quotient + 1))\n", + " # 3. if crop size IS evenly divisible by padding size, out padding is\n", + " # is equal to our crop size\n", " \n", " return padding\n", "\n", @@ -1206,11 +1283,13 @@ "cell_type": "code", "execution_count": null, "id": "aeced379-98f3-4d32-b4ba-c59cf579a748", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "prediction_type = ... # try each of prediction types we defined in create_target function\n", - "crop_size = ...\n", + "crop_size = ... # try some different crop sizes, how does the crop size affect the padding? 
\n", "\n", "train_dataset = TissueNetDataset(\n", " root_dir='woodshole',\n", @@ -1235,7 +1314,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "e33002a1-0937-4f89-a0c4-33de721960ae", "metadata": { @@ -1251,10 +1329,10 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "357d9786", "metadata": { + "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -1288,9 +1366,7 @@ " - We are doing regression again, what should our loss be?\n", " - Our outputs will be between 0 and 1, what activation should we use?\n", " - We can use the same dtype as Sdt\n", - " - We will have both x and y affinities, so how many channels should we have?\n", - "- Click [here](#fourth-todo) if you need to go back to the previous **TODO**\n", - "- Click [here](#sixth-todo) if you need to go to the next **TODO**" + " - We will have both x and y affinities, so how many channels should we have?" ] }, { @@ -1318,7 +1394,7 @@ " ... # get params\n", " \n", " else:\n", - " raise RuntimeError(\"invalid prediction type\")\n", + " raise ValueError('Choose from one of the following prediction types: two_class, three_class, sdt, affs')\n", " \n", " params = ... # get dict\n", " \n", @@ -1340,7 +1416,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "3533b935-dea0-4115-91d6-d273362db9ad", "metadata": { @@ -1386,7 +1461,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "7f795ce0-440f-4c27-ba5a-3b0638b3abca", "metadata": { @@ -1449,10 +1523,12 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "a306ac4d", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", @@ -1465,9 +1541,7 @@ "- How many trainable parameters do we have now? How does this compare to when we used two layers and a mult factor of 2 instead of 3?\n", "- Create your datasets and loaders as before. Increase crop patch size (eg 64 -> 128) \n", "- Use the same learning rate and optimizer as before \n", - "- Train for longer (eg 1000 -> 3000 steps)\n", - "- Click [here](#fifth-todo) if you need to go back to the previous **TODO**\n", - "- Click [here](#seventh-todo) if you need to go to the next **TODO**" + "- Train for longer (eg 1000 -> 3000 steps)" ] }, { @@ -1477,7 +1551,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "prediction_type = ... \n", "params = get_hyperparams(prediction_type)\n", "\n", @@ -1549,18 +1622,9 @@ "cell_type": "code", "execution_count": null, "id": "18fb5f75", - "metadata": {}, - "outputs": [], - "source": [ - "# run training loop\n", - "train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype, prediction_type)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17e7dd63", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# run training loop\n", @@ -1616,7 +1680,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "51aad2aa", "metadata": { @@ -1630,7 +1693,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "a3ffca7a-7dce-476e-99f2-c351b02aaade", "metadata": { @@ -1646,14 +1708,10 @@ "- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.\n", "\n", "\n", - "- Finally, up until now we were just using a single input channel to our network - but since we have multiple channels in our raw data we should leverage them. 
You will get the opportunity to put everything together to try to improve your model.\n", - "\n", - "- Click [here](#task-20-improving-the-model) to go to the previous task\n", - "- Click [here](#task-30-post-processing--further-improvements) to go to the next task" + "- Finally, up until now we were just using a single input channel to our network - but since we have multiple channels in our raw data we should leverage them. You will get the opportunity to put everything together to try to improve your model." ] }, { - "attachments": {}, "cell_type": "markdown", "id": "9e6665f2-1a3b-42d5-ac8c-27183cb5abce", "metadata": { @@ -1661,7 +1719,7 @@ "tags": [] }, "source": [ - "
Task 3.1: Introduce watershed
\n", + "
Task 3.1: Watershed
\n", " \n", "- Before we were just thresholding our predictions and then relabeling connected components. This is a totally fine approach in the cases where we don't have touching objects. Now we will use a better approach commonly used for instance segmentation called seeded watershed. See here for a nice overview: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_watershed.html\n", " \n", @@ -1672,7 +1730,7 @@ "- Because of this, it is often not sufficient to use watershed alone on complex datasets. In most cases the resulting objects are referred to as fragments (or supervoxels), which can then be stitched together using the underlying predictions as edge weights through a process called agglomeration.\n", " \n", " \n", - "- Agglomeration is out of the scope of this exercise, but you can find a nice overview here: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_boundary_merge.html\n", + "- Agglomeration is out of the scope of this exercise, but you can find a nice overview here: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_boundary_merge.html. A good challenge for future learning is to implement your own agglomeration using predictions and supervoxels. \n", "
" ] }, @@ -1683,6 +1741,12 @@ "metadata": {}, "outputs": [], "source": [ + "# utility functions to get boundary mask, and seg using watershed\n", + "from utils import get_boundary_mask, watershed_from_boundary_distance\n", + "\n", + "# function to compute exact euclidean distance transform\n", + "from scipy.ndimage import distance_transform_edt\n", + "\n", "# get segmentations\n", "\n", "net.eval()\n", @@ -1699,10 +1763,14 @@ " \n", " # feel free to try different thresholds\n", " thresh = np.mean(pred)\n", - " \n", + " \n", + " # get boundary mask\n", " boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)\n", + " \n", + " # get boundary distances\n", " boundary_distances = distance_transform_edt(boundary_mask)\n", " \n", + " # get segmentation\n", " seg = watershed_from_boundary_distance(\n", " boundary_distances,\n", " boundary_mask\n", @@ -1726,7 +1794,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "f8e17356-b7df-4cb3-98ad-5231f0b915f8", "metadata": { @@ -1765,6 +1832,9 @@ "metadata": {}, "outputs": [], "source": [ + "# import utility function to evaluate several metrics\n", + "from utils import evaluate\n", + "\n", "# Evaluate on a single batch\n", "\n", "net.eval()\n", @@ -1818,7 +1888,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "6ff4f94a-c301-4007-864a-8bc1b494b7cd", "metadata": { @@ -1898,7 +1967,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "5f4389b9", "metadata": { @@ -1914,10 +1982,12 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "8c2c19d7", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", @@ -1928,9 +1998,7 @@ "- By default our raw data is (C, H, W) but albumentations expects (H,W,C) for rgb data\n", "- Our mask data is H,W by default so we need to add a channel dimension\n", "- Following augmentation, we then need to get both our raw and mask back back to (C,H,W) for training\n", - "- Make sure when creating your target representation that you handle the mask channel correctly (eg pass mask[0] instead of mask)\n", - "- Click [here](#sixth_todo) if you need to go back to the previous **TODO**\n", - "- Click [here](#eighth_todo) if you need to go to the next **TODO**" + "- Make sure when creating your target representation that you handle the mask channel correctly (eg pass mask[0] instead of mask)" ] }, { @@ -2073,18 +2141,18 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "c8dd0bee", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "\n", "\n", "##### **TODO (8)** \n", " \n", - "- Create your network, hyperparameters and data loaders as last time (eg 3 levels, 3 mult factor, 128 crop size, 3k iterations, etc). Make sure to use the correct number of input channels to your network!!!\n", - "- Click [here](#seventh_todo) if you need to go back to the previous **TODO**\n", - "- Click [here](#final_todo) if you need to go to the next **TODO**" + "- Create your network, hyperparameters and data loaders as last time (eg 3 levels, 3 mult factor, 128 crop size, 3k iterations, etc). Make sure to use the correct number of input channels to your network!!!" 
] }, { @@ -2173,10 +2241,10 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "6e630227", "metadata": { + "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -2255,7 +2323,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "97ea658f-230c-4554-94ea-a719368cc1da", "metadata": { @@ -2276,7 +2343,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "4b1d5765", "metadata": { @@ -2290,704 +2356,36 @@ ] }, { - "attachments": {}, - "cell_type": "markdown", - "id": "4b6ace09-2b95-4ad8-bde0-483486c16818", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "## Task 4.0: Auxiliary learning\n", - "\n", - "- Auxiliary learning is a powerful technique that can help to improve the results of our main objective by providing a helper task. Up until now, we have only shown our model representations of the data that are boundary specific. But the data is a lot richer than that - these objects have distinct shapes that could be leveraged in order to better learn the boundaries.\n", - "\n", - "- Click [here](#task-30-post-processing--further-improvements) to go to the previous task.\n", - "- Click [here](#task-50-bonus-exercises-and-further-learning) to go to the bonus task." - ] - }, - { - "attachments": {}, "cell_type": "markdown", - "id": "0b593025-1e9a-4cf1-96ef-c177a7c450a6", + "id": "6f63a89b", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ - "
Task 4.1: Cellpose
\n", - " \n", - "- In [**Cellpose**](https://cellpose.readthedocs.io/en/latest/), cells are turned into flow representations. We create these flow representations by simulating diffusion from the center of the cell to get the spatial gradients for each pixel that point towards the center of the cell. During test time, we use the flows as a dynamical system and all pixels that converge to the same point are defined as the pixels in a given cell. The flows shown below are represented by an HSV colormap used in the optic flow literature.\n", - " \n", - " \n", - "- We also predict the foregroud / background -- the two classes you predicted in exercise 1. In Cellpose we call this the cell probability. We threshold this to decide which pixels are in cells -- we only use these pixels to run the dynamical system.\n", - " \n", - " \n", - "- The flow representation allows the learning of non-convex shapes, because pixels can flow around corners. It also prevents merging, as flows for two cells that are touching are opposite.\n", - "
\n", - "\n", - "![cellpose_flows](static/cellpose_flows.png)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "1e6891e8", - "metadata": {}, - "source": [ - "\n", - "\n", - "##### **TODO (9)**\n", - "\n", - "- Use relevant pretrained model for prediction:\n", - " - get the Model class from `cellpose.models`. \n", - " **Hint** Load the weights for TissueNet from Cellpose, refer to the [documentation](https://cellpose.readthedocs.io/en/latest/models.html#other-built-in-models)\n", - " - call `model.eval` method with correct parameters\n", - " - for the nuclei model, also set the correct channels \n", - " **Hint** refer to the [documentation](https://cellpose.readthedocs.io/en/latest/settings.html#channels)\n", - "- Click [here](#eighth_todo) if you need to go back to the previous **TODO**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f7b37aa", - "metadata": {}, - "outputs": [], - "source": [ - "from cellpose import models\n", - "\n", - "# create a cellpose model on the gpu\n", - "# use a built-in model trained on tissuenet\n", - "# (the first time you run this cell the model will download)\n", - "\n", - "model = # get Cellpose model\n", - "\n", - "test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)\n", - "test_loader = DataLoader(test_dataset, batch_size=1)\n", - "\n", - "### IMPORTANT: these are the channels used for the segmentation\n", - "# the first one is the channel to segment, and the second one is the optional nuclear channel\n", - "# red = 1\n", - "# green = 2\n", - "# blue = 3\n", + "### Further learning\n", "\n", - "channels = [2, 1]\n", - "# Diameter parameter for tissuenet dataset\n", - "diameter = 25\n", + "* Instance segmentation can be challenging and this exercise just scratches the surface of what is possible.\n", "\n", - "masks_cp = []\n", - "for idx, (image, mask) in enumerate(test_loader):\n", - " image = image.cpu().detach().numpy()\n", - " mask_cp, flows, styles = # call model in evaluation mode with image, diameter and channels parameters\n", - " masks_cp.append(mask_cp)\n", - " \n", - " fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)\n", - " \n", - " image = np.squeeze(image)\n", - " image = np.vstack((image, np.zeros_like(image)[:1]))\n", - " image = image.transpose(1,2,0)\n", - " \n", - " axes[0][0].imshow(image)\n", - " axes[0][0].title.set_text('Raw')\n", - " \n", - " axes[0][1].imshow(create_lut(mask_cp))\n", - " axes[0][1].title.set_text('Predicted Labels')\n", + "* There are bonus exercises (without todos) in the `bonus_exercises.ipynb` notebook. There are examples for CellPose and Local Shape Descriptors\n", "\n", - " axes[0][2].imshow(flows[0])\n", - " axes[0][2].title.set_text('Predicted cellpose')\n", - " \n", - " axes[0][3].imshow(flows[2])\n", - " axes[0][3].title.set_text('Predicted cell probability')\n", + "* This notebook assumes images that fit into memory but often times this is not the case (especially in biology). \n", + " 1. To see an example for predicting over an image in chunks and stitching the results together, see this [notebook](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/3_tile_and_stitch.ipynb)\n", + " 2. 
For a more advanced library that makes it easier to do machine learning on massive datasets, see gunpowder (navigate to the tutorials, or browse the API): https://funkelab.github.io/gunpowder\n", " \n", " \n", - " if idx == 2:\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22bb2f3e", - "metadata": {}, - "outputs": [], - "source": [ - "# we could also use the nuclear channel ONLY and run a nuclear model in cellpose\n", + "* We did not cover more complex loss functions. Here are some nice explanations / implementations of other loss functions that are useful for instance segmentation: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook\n", "\n", - "channels = ... # set correct channels\n", - "diameter = 20\n", "\n", - "# initialize nuclei model (can also try \"cyto\" model if this doesn't work)\n", - "# the \"nuclei\" model in cellpose has been trained on lots of nuclear data (but not the tissuenet dataset)\n", - "# the \"cyto\" model in cellpose has been trained on many cellular images (but not the tissuenet dataset)\n", - "model = ... # Get the model for nuclei\n", + "* A more complex (but powerful) approach is called metric learning. This can be seen in last years [exercise](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/2_instance_segmentation.ipynb)\n", "\n", - "masks_cp = []\n", - "for idx, (image, mask) in enumerate(test_loader):\n", - " image = image.cpu().detach().numpy()\n", - " mask_cp, flows, styles = # call model in evaluation mode with image, diameter and channels parameters\n", - " masks_cp.append(mask_cp)\n", - " \n", - " fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)\n", - " \n", - " image = np.squeeze(image)\n", - " image = np.vstack((image, np.zeros_like(image)[:1]))\n", - " image = image.transpose(1,2,0)\n", - " \n", - " axes[0][0].imshow(image)\n", - " axes[0][0].title.set_text('Raw')\n", - " \n", - " axes[0][1].imshow(create_lut(mask_cp))\n", - " axes[0][1].title.set_text('Predicted Labels')\n", "\n", - " axes[0][2].imshow(flows[0])\n", - " axes[0][2].title.set_text('Predicted cellpose')\n", - " \n", - " axes[0][3].imshow(flows[2])\n", - " axes[0][3].title.set_text('Predicted cell probability')\n", + "* We did not cover stardist in this tutorial, and barely scratched the surface on cellpose and lsds. For more tutorials on:\n", + " 1. Stardist: https://github.com/maweigert/tutorials/tree/main/stardist\n", + " 2. CellPose: https://github.com/MouseLand/cellpose#run-cellpose-10-without-local-python-installation\n", + " 3. LSDs: https://github.com/funkelab/lsd#notebooks\n", " \n", - " if idx == 2:\n", - " break" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "d0b50f2a", - "metadata": {}, - "source": [ - "
Checkpoint 4
\n", - "\n", - "
" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "411bab38", - "metadata": {}, - "source": [ - "## Task 5.0: Bonus exercises and further learning\n", - "\n", - "- These exercises do not have todos, feel free to run them to get a sense of a few more tricks you can use for instance segmentation. " + "### Good luck on your instance segmentation endeavors!!" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "4593600a", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "##### **Local Shape Descriptors**\n", - "\n", - "- Another example of auxiliary learning is [**LSDs**](https://localshapedescriptors.github.io/). This embedding encodes object shape similarly but is computed in a defined gaussian constrained to each label. This allows for consistent gradients regardless of object shapes which makes it a good candidate for segmentation of complex objects such as neurons in large electron microscopy datasets. \n", - "\n", - "\n", - "- The LSDs are combined with nearest neighbor affinities to improve the boundary representations. The improved affinities then produce nice segmentations when using a hierarchical agglomeration approach and can be easily parallelized to allow for scaling to massive volumes. \n", - "\n", - "- For this exercise we'd need to add the dependencies to the env. Please follow the readme in the [**github repo**](https://github.com/funkelab/lsd)\n", - "\n", - "![example_image](static/lsd_schematic.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf570532-444e-411b-b452-7576f873d2d6", - "metadata": {}, - "outputs": [], - "source": [ - "# import lsds, calculate on a small patch and visualize the descriptor components\n", - "\n", - "from lsd.train import local_shape_descriptor\n", - "\n", - "file = random.choice(train_nuclei)\n", - "\n", - "nuclei = imread(file)[0:64, 0:64]\n", - "raw = imread(file.replace('_nuclei_masks', ''))[:, 0:64, 0:64]\n", - "\n", - "#just to visualize\n", - "raw = np.vstack((raw, np.zeros_like(raw)[:1]))\n", - "raw = raw.transpose(1,2,0)\n", - "\n", - "lsds = local_shape_descriptor.get_local_shape_descriptors(\n", - " segmentation=nuclei,\n", - " sigma=(5,)*2,\n", - " voxel_size=(1,)*2)\n", - "\n", - "fig, axes = plt.subplots(\n", - " 1,\n", - " 6,\n", - " figsize=(20, 20),\n", - " sharex=False,\n", - " sharey=True,\n", - " squeeze=False)\n", - " \n", - "axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')\n", - "axes[0][0].title.set_text('Mean offset Y')\n", - "\n", - "axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')\n", - "axes[0][1].title.set_text('Mean offset X')\n", - "\n", - "axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')\n", - "axes[0][2].title.set_text('Covariance Y-Y')\n", - "\n", - "axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')\n", - "axes[0][3].title.set_text('Covariance X-X')\n", - "\n", - "axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')\n", - "axes[0][4].title.set_text('Covariance Y-X')\n", - "\n", - "axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')\n", - "axes[0][5].title.set_text('Size')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "795f27af-1f0d-4d1e-a519-1ea4666793e9", - "metadata": {}, - "outputs": [], - "source": [ - "# slightly modify our dataset just for simplicity\n", - "\n", - "class TissueNetDataset(Dataset):\n", - " def __init__(self,\n", - " root_dir,\n", - " split='train',\n", - " mask='nuclei',\n", - " crop_size=None,\n", - " val_split=False,\n", - " padding_size=8\n", - " ):\n", - " 
\n", - " self.split = split\n", - " self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))\n", - " self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]\n", - " self.crop_size = crop_size\n", - " self.padding_size = padding_size\n", - " \n", - " if split == 'test':\n", - " if val_split:\n", - " self.mask_files = self.mask_files[:10]\n", - " self.raw_files = self.raw_files[:10]\n", - " else:\n", - " self.mask_files = self.mask_files[10:]\n", - " self.raw_files = self.raw_files[10:]\n", - "\n", - " def __len__(self):\n", - " return len(self.raw_files)\n", - " \n", - " def get_padding(self, crop_size, padding_size):\n", - " \n", - " # quotient\n", - " q = int(crop_size / padding_size)\n", - " \n", - " if crop_size % padding_size != 0:\n", - " padding = (padding_size * (q + 1))\n", - " else:\n", - " padding = crop_size\n", - " \n", - " return padding\n", - " \n", - " def augment_data(self, raw, mask, padding):\n", - " \n", - " transform = A.Compose([\n", - " A.RandomCrop(\n", - " width=self.crop_size,\n", - " height=self.crop_size),\n", - " A.PadIfNeeded(\n", - " min_height=padding,\n", - " min_width=padding,\n", - " p=1,\n", - " border_mode=0),\n", - " A.HorizontalFlip(p=0.3),\n", - " A.VerticalFlip(p=0.3),\n", - " A.RandomRotate90(p=0.3),\n", - " A.Transpose(p=0.3),\n", - " A.RandomBrightnessContrast(p=0.3)\n", - " ])\n", - "\n", - " transformed = transform(image=raw, mask=mask)\n", - "\n", - " raw, mask = transformed['image'], transformed['mask']\n", - " \n", - " return raw, mask\n", - "\n", - " def __getitem__(self, idx):\n", - " raw_file = self.raw_files[idx]\n", - " mask_file = self.mask_files[idx]\n", - " \n", - " raw = imread(raw_file)\n", - " mask = imread(mask_file)\n", - "\n", - " raw = raw.transpose([1,2,0])\n", - " \n", - " mask = np.expand_dims(mask, axis=0)\n", - " mask = mask.transpose([1,2,0])\n", - " \n", - " # just do this regardless of split to make val/test faster for demo purposes\n", - " padding = self.get_padding(self.crop_size, self.padding_size)\n", - " raw, mask = self.augment_data(raw, mask, padding)\n", - " \n", - " raw = raw.transpose([2,0,1])\n", - " mask = mask.transpose([2,0,1])\n", - " \n", - " mask, border = erode_border(\n", - " mask[0],\n", - " iterations=1,\n", - " border_value=1)\n", - "\n", - " affs = compute_affinities(mask, nhood=[[0,1],[1,0]])\n", - " \n", - " lsds = local_shape_descriptor.get_local_shape_descriptors(\n", - " segmentation=mask,\n", - " sigma=(5,)*2,\n", - " voxel_size=(1,)*2)\n", - "\n", - " lsds = lsds.astype(np.float32)\n", - " affs = affs.astype(np.float32)\n", - " \n", - " return raw, lsds, affs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e0d178b0-b3a6-438c-ac67-5ba357a45a9e", - "metadata": {}, - "outputs": [], - "source": [ - "# visualize batch\n", - "\n", - "train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)\n", - "\n", - "raw, lsds, affs = train_dataset[random.randrange(len(train_dataset))]\n", - "\n", - "raw = np.vstack((raw, np.zeros_like(raw)[:1]))\n", - "raw = raw.transpose(1,2,0)\n", - "\n", - "fig, axes = plt.subplots(\n", - " 1,\n", - " 7,\n", - " figsize=(20, 20),\n", - " sharex=False,\n", - " sharey=True,\n", - " squeeze=False)\n", - " \n", - "axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')\n", - "axes[0][0].title.set_text('Mean offset Y')\n", - "\n", - "axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')\n", - "axes[0][1].title.set_text('Mean offset X')\n", - "\n", - 
"axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')\n", - "axes[0][2].title.set_text('Covariance Y-Y')\n", - "\n", - "axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')\n", - "axes[0][3].title.set_text('Covariance X-X')\n", - "\n", - "axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')\n", - "axes[0][4].title.set_text('Covariance Y-X')\n", - "\n", - "axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')\n", - "axes[0][5].title.set_text('Size')\n", - "\n", - "axes[0][6].imshow(np.squeeze(affs[0]+affs[1]), cmap='jet')\n", - "axes[0][6].title.set_text('Affs')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "855fe179-8ed2-45d3-b0b6-d03f761bffd4", - "metadata": {}, - "outputs": [], - "source": [ - "# we need two output heads for our network, one for lsds and one for affinities\n", - "# to do this we will subclass torch.nn.Module and create our UNet inside\n", - "# before we had a single final convolution. Now we have one for each head.\n", - "# then in the forward pass we pass our image through our unet and then the output through each head\n", - "\n", - "class MtlsdModel(torch.nn.Module):\n", - "\n", - " def __init__(\n", - " self,\n", - " in_channels,\n", - " num_fmaps,\n", - " fmap_inc_factors,\n", - " downsample_factors,\n", - " padding='same'\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.unet = UNet(\n", - " in_channels=in_channels,\n", - " num_fmaps=num_fmaps,\n", - " fmap_inc_factors=fmap_inc_factors,\n", - " downsample_factors=downsample_factors,\n", - " padding=padding)\n", - "\n", - " self.lsd_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=6, kernel_size=1)\n", - " self.aff_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=2, kernel_size=1)\n", - "\n", - " def forward(self, input):\n", - "\n", - " z = self.unet(input)\n", - " lsds = self.lsd_head(z)\n", - " affs = self.aff_head(z)\n", - "\n", - " return lsds, affs\n", - "\n", - "# We want to combine the lsds and affs losses and minimize the sum\n", - "# we can do this by subclassing our loss function (torch.nn.MSELoss) and overriding the forward method\n", - "\n", - "class CombinedLoss(torch.nn.MSELoss):\n", - "\n", - " def __init__(self):\n", - " super(CombinedLoss, self).__init__()\n", - "\n", - " def forward(self, lsds_prediction, lsds_target, affs_prediction, affs_target):\n", - "\n", - " loss1 = super(CombinedLoss, self).forward(lsds_prediction,lsds_target)\n", - " loss2 = super(CombinedLoss, self).forward(affs_prediction, affs_target)\n", - " \n", - " return loss1 + loss2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "52eeeabc-09f4-4570-90fb-c4784d9c31bf", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(42)\n", - "\n", - "d_factors = [[2,2],[2,2],[2,2]]\n", - "\n", - "in_channels=2\n", - "num_fmaps=32\n", - "fmap_inc_factors=4\n", - "\n", - "net = MtlsdModel(in_channels,num_fmaps,fmap_inc_factors,d_factors)\n", - "\n", - "loss_fn = CombinedLoss().to(device)\n", - "\n", - "net = net.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "542a39b6-5c07-42bb-acde-10fcdea0a8c5", - "metadata": {}, - "outputs": [], - "source": [ - "training_steps = 3000\n", - "logdir = os.path.join(\"logs\", datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n", - "writer = SummaryWriter(logdir)\n", - "\n", - "net = net.to(device)\n", - "dtype = torch.FloatTensor\n", - "\n", - "# set optimizer\n", - "learning_rate = 1e-4\n", - "optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", - "\n", - "# set 
activation\n", - "activation = torch.nn.Sigmoid()\n", - "\n", - "### create datasets\n", - "\n", - "train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)\n", - "test_dataset = TissueNetDataset(root_dir='woodshole', split='test', crop_size=128)\n", - "val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, crop_size=64)\n", - "\n", - "batch_size = 4\n", - "\n", - "# make dataloaders\n", - "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)\n", - "test_loader = DataLoader(test_dataset, batch_size=1)\n", - "val_loader = DataLoader(val_dataset, batch_size=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c191c382-6301-46ae-84ec-cdaa54d11c19", - "metadata": {}, - "outputs": [], - "source": [ - "# update our training step to have two logits and two predictions\n", - "\n", - "def model_step(model, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation, train_step=True):\n", - " \n", - " # zero gradients if training\n", - " if train_step:\n", - " optimizer.zero_grad()\n", - " \n", - " # forward\n", - " lsd_logits, affs_logits = model(feature)\n", - "\n", - " loss_value = loss_fn(lsd_logits, gt_lsds, affs_logits, gt_affs)\n", - " \n", - " # backward if training mode\n", - " if train_step:\n", - " loss_value.backward()\n", - " optimizer.step()\n", - " \n", - " lsd_output = activation(lsd_logits)\n", - " affs_output = activation(affs_logits)\n", - " \n", - " outputs = {\n", - " 'pred_lsds': lsd_output,\n", - " 'pred_affs': affs_output,\n", - " 'lsds_logits': lsd_logits,\n", - " 'affs_logits': affs_logits,\n", - " }\n", - " \n", - " return loss_value, outputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9f0fb35-41e6-4ee9-a550-45ba27fb143e", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# update our training loop to do both lsds and affs\n", - "\n", - "# set flags\n", - "net.train() \n", - "loss_fn.train()\n", - "step = 0\n", - "\n", - "with tqdm(total=training_steps) as pbar:\n", - " while step < training_steps:\n", - " # reset data loader to get random augmentations\n", - " np.random.seed()\n", - " tmp_loader = iter(train_loader)\n", - " for feature, gt_lsds, gt_affs in tmp_loader:\n", - " gt_lsds = gt_lsds.to(device)\n", - " gt_affs = gt_affs.to(device)\n", - " feature = feature.to(device)\n", - " \n", - " #print(label.shape, feature.shape)\n", - " \n", - " loss_value, pred = model_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation)\n", - " writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)\n", - " step += 1\n", - " pbar.update(1)\n", - " \n", - " if step % 100 == 0:\n", - " net.eval()\n", - " tmp_val_loader = iter(test_loader)\n", - " acc_loss = []\n", - " for feature, gt_lsds, gt_affs in tmp_val_loader: \n", - " gt_lsds = gt_lsds.to(device)\n", - " gt_affs = gt_affs.to(device)\n", - " feature = feature.to(device)\n", - " loss_value, _ = model_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation, train_step=False)\n", - " acc_loss.append(loss_value.cpu().detach().numpy())\n", - " writer.add_scalar('val_loss',np.mean(acc_loss),step) \n", - " net.train()\n", - " print(np.mean(acc_loss))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7636afc0-7bed-4808-9c66-5dc52ba95796", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# visualize a few predictions - have the lsds helped to improve the affinities?\n", - "# For a future 
challenge you could try using a weighted combined loss and watershed + agglomeration to get strong segmentations\n", - "\n", - "net.eval()\n", - "\n", - "activation = torch.nn.Sigmoid()\n", - "\n", - "for idx, (image, gt_lsds, gt_affs) in enumerate(test_loader):\n", - " image = image.to(device)\n", - " lsds_logits, affs_logits = net(image)\n", - " pred_lsds = activation(lsds_logits)\n", - " pred_affs = activation(affs_logits)\n", - " \n", - " image = np.squeeze(image.cpu())\n", - " gt_lsds = np.squeeze(gt_lsds.cpu().numpy())\n", - " gt_affs = np.squeeze(gt_affs.cpu().numpy())\n", - " \n", - " pred_lsds = np.squeeze(pred_lsds.cpu().detach().numpy())\n", - " pred_affs = np.squeeze(pred_affs.cpu().detach().numpy())\n", - " \n", - " fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)\n", - " \n", - " image = np.vstack((image, np.zeros_like(image)[:1]))\n", - " image = image.transpose(1,2,0)\n", - " \n", - " axes[0][0].imshow(image)\n", - " axes[0][0].title.set_text('Raw')\n", - " \n", - " axes[0][1].imshow(np.squeeze(pred_lsds[0]), cmap='jet')\n", - " axes[0][1].imshow(np.squeeze(pred_lsds[1]), cmap='jet', alpha=0.5)\n", - " axes[0][1].title.set_text('Mean offsets')\n", - "\n", - " axes[0][2].imshow(np.squeeze(pred_affs[0]+pred_affs[1]), cmap='jet')\n", - " axes[0][2].title.set_text('Affs')\n", - " \n", - " if idx == 2:\n", - " break" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "6f63a89b", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "### Further learning\n", - "\n", - "* Instance segmentation can be challenging and this exercise just scratches the surface of what is possible.\n", - "\n", - "\n", - "* This notebook assumes images that fit into memory but often times this is not the case (especially in biology). \n", - " 1. To see an example for predicting over an image in chunks and stitching the results together, see this [notebook](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/3_tile_and_stitch.ipynb)\n", - " 2. For a more advanced library that makes it easier to do machine learning on massive datasets, see gunpowder (navigate to the tutorials, or browse the API): https://funkelab.github.io/gunpowder\n", - " \n", - " \n", - "* We did not cover more complex loss functions. Here are some nice explanations / implementations of other loss functions that are useful for instance segmentation: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook\n", - "\n", - "\n", - "* A more complex (but powerful) approach is called metric learning. This can be seen in last years [exercise](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/2_instance_segmentation.ipynb)\n", - "\n", - "\n", - "* We did not cover stardist in this tutorial, and barely scratched the surface on cellpose and lsds. For more tutorials on:\n", - " 1. Stardist: https://github.com/maweigert/tutorials/tree/main/stardist\n", - " 2. CellPose: https://github.com/MouseLand/cellpose#run-cellpose-10-without-local-python-installation\n", - " 3. LSDs: https://github.com/funkelab/lsd#notebooks\n", - " \n", - "### Good luck on your instance segmentation endeavors!!" 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "76cf7bf2", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/solutions.ipynb b/solutions.ipynb index 2e1c399..f52aace 100644 --- a/solutions.ipynb +++ b/solutions.ipynb @@ -404,6 +404,7 @@ "* The `augment_data` method should take in a raw and mask array and return the augmented raw and mask arrays.\n", "* In the `__getitem__` method, we should only augment our data if our split is `train`. So you will need to also add split as an attribute in the `__init__` method\n", "* After augmenting your data, create a fg / bg representation\n", + "* Erode your fg / bg representation by a pixel to introduce a seperation or explicit boundary between touching labels\n", "* Make sure to return your fg/bg as float32 for training\n", "* Make sure to add a dummy channel dimension to your arrays for training (Pytorch assumes we have tensor shape batch, channel, height, width). We will add a batch dimension later once we create a data loader.\n", "* For now we will just use the nuclei channel for training, so make sure to slice the correct channel of the raw data before returning\n", @@ -433,7 +434,7 @@ " # make sure to add your crop size\n", " self.crop_size = ...\n", "\n", - " # using the root dir, split and mask create a path to files and sort it \n", + " # using root_dir, split and mask create a path to files and sort it \n", " # Hint: natsorted glob and os libraries could come in handy\n", " self.mask_files = ... # load mask files into sorted list\n", " self.raw_files = ... # load image files into sorted list\n", @@ -575,19 +576,14 @@ "\n", "raw, mask = train_dataset[random.randrange(len(train_dataset))]\n", "\n", - "labels = erode(\n", - " mask,\n", - " iterations=1,\n", - " border_value=1)\n", - "\n", - "labels_two_class = (labels != 0).astype(np.float32)\n", + "labels_two_class = (mask != 0).astype(np.float32)\n", "\n", "fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)\n", "\n", "axes[0][0].imshow(raw.squeeze(), cmap='gray')\n", "axes[0][0].title.set_text('Raw')\n", "\n", - "axes[0][0].imshow(create_lut(mask.squeeze().astype(int)), alpha=0.5)\n", + "axes[0][0].imshow(create_lut(relabel_cc(mask.squeeze().astype(int))), alpha=0.5)\n", "axes[0][0].title.set_text('Segmentation')\n", "\n", "axes[0][1].imshow(labels_two_class.squeeze())\n", @@ -624,10 +620,10 @@ "source": [ "
Task 1.5: Create shallow network, visualize receptive field
\n", " \n", - "- Let's create a shallow two level U-Net and visualize the receptive field. We will see later how this receptive field changes as we add more layers and change our input image size\n", + "- Let's create a shallow two level U-Net and visualize the receptive field. We will see later how this receptive field changes as we add more levels and change our input image size\n", " \n", " \n", - "- The receptive field tells us how much of the image the network is looking at in each layer -- this is the amount of spatial context that the network can use to create predictions.\n", + "- The receptive field tells us how much of the image the network is looking at in each level -- this is the amount of spatial context that the network can use to create predictions.\n", " \n", " \n", "- Run the following cell to see the networks receptive field. Try changing the downsampling factors to see how it affects the receptive field (eg try combinations of [1,1], [3,3], [4,4], etc)\n", @@ -713,13 +709,13 @@ "##### **TODO (2)**\n", " \n", "- Decide how many output channels to have, remember we are starting with a binary task\n", - "- What loss function and final layer activation should we use? Think back to the semantic segmentation exercise.\n", + "- What loss function and final level activation should we use? Think back to the semantic segmentation exercise.\n", "- What type should we ensure our tensors to be? You can see see a list of tensor types [here](https://pytorch.org/docs/stable/tensors.html) - maybe the equivalent of 32-bit floating point :)\n", "- For our model, we will create a two level U-Net with the following parameters: (to learn more about the torch layers click [here](https://pytorch.org/docs/stable/nn.html))\n", - " - downsample by a factor of 2 in each layer\n", + " - downsample by a factor of 2 in each level\n", " - single input channel\n", " - 32 input feature maps \n", - " - multiply by a factor of 2 between layers\n", + " - multiply by a factor of 2 between levels\n", " - `same` padding (this gives us the same input and output shapes) \n", " - Since our Unet will have the same number of output features as input features, we need to add a final convolution to get to our desired output feature maps. We should use a final convolution with kernel size of 1 \n", " - To see parameter defs you can run `UNet?`\n",