From 0f805e688d27a15740d8e71dd7ef3fe793613f22 Mon Sep 17 00:00:00 2001
From: Fernando Cervantes Sanchez
Date: Wed, 8 Jan 2025 11:03:20 -0500
Subject: [PATCH 01/10] Implementation of models registration

---
 notebooks/README.md                           |  2 +
 notebooks/microsam_activelearning.py          | 61 ++++++++++++++++++
 src/napari_activelearning/__init__.py         |  4 +-
 src/napari_activelearning/_acquisition.py     | 34 ++++++----
 src/napari_activelearning/_interface.py       | 62 +++++++++++-------
 .../{_models.py => _models_impl.py}           | 59 ++++++++---------
 ...interface.py => _models_impl_interface.py} |  4 +-
 src/napari_activelearning/_tests/conftest.py  | 14 +----
 .../_tests/test_acquisition.py                | 63 ++++++++++++++-----
 src/napari_activelearning/_widgets.py         | 41 ++++++++----
 10 files changed, 229 insertions(+), 115 deletions(-)
 create mode 100644 notebooks/README.md
 create mode 100644 notebooks/microsam_activelearning.py
 rename src/napari_activelearning/{_models.py => _models_impl.py} (94%)
 rename src/napari_activelearning/{_models_interface.py => _models_impl_interface.py} (99%)

diff --git a/notebooks/README.md b/notebooks/README.md
new file mode 100644
index 0000000..8aa741d
--- /dev/null
+++ b/notebooks/README.md
@@ -0,0 +1,2 @@
+# Example notebooks for integrating Active Learning into custom deep learning models
+
diff --git a/notebooks/microsam_activelearning.py b/notebooks/microsam_activelearning.py
new file mode 100644
index 0000000..45166b1
--- /dev/null
+++ b/notebooks/microsam_activelearning.py
@@ -0,0 +1,61 @@
+from micro_sam import util
+from micro_sam import automatic_segmentation as msas
+
+import napari_activelearning as al
+
+
+class TunableMicroSAM(al.TunableMethodWidget):
+    def __init__(self):
+        super(TunableMicroSAM, self).__init__()
+        self._sam_predictor = None
+        self._sam_instance_segmenter = None
+
+    def _model_init(self):
+        # Lazily initialize the predictor and segmenter only once.
+        if self._sam_predictor is not None:
+            return
+
+        (self._sam_predictor,
+         self._sam_instance_segmenter) = msas.get_predictor_and_segmenter(
+            model_type='vit_t',
+            device=util.get_device("cpu"),
+            amg=True,
+            checkpoint=None,
+            stability_score_offset=1.0
+        )
+
+    def _get_transform(self):
+        return lambda x: x
+
+    def _run_pred(self, img, *args, **kwargs):
+        self._model_init()
+
+        segmentation_mask = msas.automatic_instance_segmentation(
+            predictor=self._sam_predictor,
+            segmenter=self._sam_instance_segmenter,
+            input_path=img,
+            ndim=2,
+            verbose=False
+        )
+
+        return segmentation_mask
+
+    def _run_eval(self, img, *args, **kwargs):
+        self._model_init()
+
+        segmentation_mask = msas.automatic_instance_segmentation(
+            predictor=self._sam_predictor,
+            segmenter=self._sam_instance_segmenter,
+            input_path=img,
+            ndim=2,
+            verbose=False
+        )
+
+        return segmentation_mask
+
+    def _fine_tune(self, train_data, train_labels, test_data, test_labels):
+        self._model_init()
+        return None
+
+
+def register_microsam():
+    al.register_model("micro-sam", TunableMicroSAM)
diff --git a/src/napari_activelearning/__init__.py b/src/napari_activelearning/__init__.py
index d14a75c..a30574a 100644
--- a/src/napari_activelearning/__init__.py
+++ b/src/napari_activelearning/__init__.py
@@ -1,7 +1,7 @@
 from ._utils import *
 from ._layers import *
 from ._acquisition import *
-from ._models import *
+from ._models_impl import *
 from ._interface import *
-from ._models_interface import *
+from ._models_impl_interface import *
 from ._widgets import *
diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py
index d797d6b..1bcf959 100644
--- a/src/napari_activelearning/_acquisition.py
+++ 
b/src/napari_activelearning/_acquisition.py @@ -1,5 +1,4 @@ from typing import Optional, Iterable, Tuple, Callable, Union -from functools import partial import random from pathlib import Path import numpy as np @@ -20,8 +19,7 @@ from ._layers import ImageGroupsManager, ImageGroup, LayersGroup from ._labels import LabelsManager, LabelItem -from ._utils import (get_dataloader, save_zarr, downsample_image, - StaticPatchSampler) +from ._utils import get_dataloader, save_zarr, downsample_image def compute_BALD(probs): @@ -232,7 +230,7 @@ def segment(self, img, *args, **kwargs): return out -class FineTuningMethod: +class TunableMethod(SegmentationMethod): def __init__(self): self._num_workers = 0 super().__init__() @@ -328,15 +326,10 @@ def fine_tune(self, dataset_metadata_list: Iterable[dict], return train_data, train_labels, test_data, test_labels -class TunableMethod(SegmentationMethod, FineTuningMethod): - def __init__(self): - super().__init__() - - class AcquisitionFunction: def __init__(self, image_groups_manager: ImageGroupsManager, labels_manager: LabelsManager, - tunable_segmentation_method: TunableMethod): + tunable_segmentation_methods: dict): self._patch_sizes = {} self._max_samples = 1 self._MC_repetitions = 3 @@ -347,7 +340,8 @@ def __init__(self, image_groups_manager: ImageGroupsManager, self.image_groups_manager = image_groups_manager self.labels_manager = labels_manager - self.tunable_segmentation_method = tunable_segmentation_method + self.tunable_segmentation_method = None + self._tunable_segmentation_methods = tunable_segmentation_methods super().__init__() @@ -441,10 +435,23 @@ def _prepare_datasets_metadata( return dataset_metadata + def set_model(self, selected_model): + tunable_segmentation_method_cls =\ + self._tunable_segmentation_methods.get(selected_model, None) + + if tunable_segmentation_method_cls is not None: + self.tunable_segmentation_method =\ + tunable_segmentation_method_cls() + else: + self.tunable_segmentation_method = None + def compute_acquisition(self, dataset_metadata, acquisition_fun, segmentation_out, sampled_mask=None, segmentation_only=False): + if self.tunable_segmentation_method is None: + return + model_spatial_axes = [ ax for ax in self.model_axes @@ -563,6 +570,9 @@ def compute_acquisition_layers( segmentation_group_name: Optional[str] = "segmentation", segmentation_only: bool = False, ): + if self.tunable_segmentation_method is None: + return + if run_all: for idx in range(self.image_groups_manager.groups_root.childCount() ): @@ -769,6 +779,8 @@ def fine_tune(self): self.image_groups_manager.groups_root.child(idx), range(self.image_groups_manager.groups_root.childCount())) )) + if self.tunable_segmentation_method is None: + return if not image_groups: return False diff --git a/src/napari_activelearning/_interface.py b/src/napari_activelearning/_interface.py index ea1c915..9303559 100644 --- a/src/napari_activelearning/_interface.py +++ b/src/napari_activelearning/_interface.py @@ -834,10 +834,10 @@ def __init__(self): class AcquisitionFunctionWidget(AcquisitionFunction, QWidget): def __init__(self, image_groups_manager: ImageGroupsManagerWidget, labels_manager: LabelsManagerWidget, - tunable_segmentation_method: TunableMethodWidget): + tunable_segmentation_methods: dict): super().__init__(image_groups_manager, labels_manager, - tunable_segmentation_method) + tunable_segmentation_methods) self.patch_sizes_mspn = MultiSpinBox() @@ -855,7 +855,6 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, patch_sizes_chk = 
QCheckBox("Edit patch sizes") patch_sizes_chk.setChecked(False) - patch_sizes_chk.toggled.connect(self._show_patch_sizes) spatial_input_axes = self.input_axes if "C" in spatial_input_axes: @@ -864,41 +863,33 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, spatial_input_axes = "".join(spatial_input_axes) self.patch_sizes_mspn.axes = spatial_input_axes - self.patch_sizes_mspn.sizesChanged.connect(self._set_patch_size) self.max_samples_spn = QSpinBox(minimum=1, maximum=10000, value=self._max_samples, singleStep=10) - self.max_samples_spn.valueChanged.connect(self._set_max_samples) self.MC_repetitions_spn = QSpinBox(minimum=2, maximum=100, value=self._MC_repetitions, singleStep=10) - self.MC_repetitions_spn.valueChanged.connect(self._set_MC_repetitions) self.input_axes_le = QLineEdit() - self.input_axes_le.textChanged.connect(self._set_input_axes) self.input_axes_le.setText(self.input_axes) self.model_axes_le = QLineEdit() - self.model_axes_le.textChanged.connect(self._set_model_axes) self.model_axes_le.setText(self.model_axes) - self.execute_selected_btn = QPushButton("Run on selected image groups") - self.execute_selected_btn.clicked.connect( - partial(self.compute_acquisition_layers, run_all=False) - ) + self.methods_cmb = QComboBox() + self.methods_cmb.setEditable(False) + for method_name in tunable_segmentation_methods.keys(): + self.methods_cmb.addItem(method_name) + self.execute_selected_btn = QPushButton("Run on selected image groups") self.execute_all_btn = QPushButton("Run on all image groups") - self.execute_all_btn.clicked.connect( - partial(self.compute_acquisition_layers, run_all=True) - ) self.image_pb = QProgressBar() self.patch_pb = QProgressBar() self.finetuning_btn = QPushButton("Fine tune model") - self.finetuning_btn.clicked.connect(self.fine_tune) acquisition_lyt = QGridLayout() acquisition_lyt.addWidget(patch_sizes_chk, 0, 0) @@ -911,18 +902,33 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, acquisition_lyt.addWidget(self.input_axes_le, 4, 1) acquisition_lyt.addWidget(QLabel("Model axes"), 4, 2) acquisition_lyt.addWidget(self.model_axes_le, 4, 3) - acquisition_lyt.addWidget(self.tunable_segmentation_method, 5, 0, 1, 4) - acquisition_lyt.addWidget(self.execute_selected_btn, 6, 0) - acquisition_lyt.addWidget(self.execute_all_btn, 6, 1) - acquisition_lyt.addWidget(self.finetuning_btn, 7, 1) - acquisition_lyt.addWidget(QLabel("Image queue:"), 8, 0, 1, 1) - acquisition_lyt.addWidget(self.image_pb, 8, 1, 1, 3) - acquisition_lyt.addWidget(QLabel("Patch queue:"), 9, 0, 1, 1) - acquisition_lyt.addWidget(self.patch_pb, 9, 1, 1, 3) + acquisition_lyt.addWidget(self.methods_cmb, 5, 0, 1, 4) + acquisition_lyt.addWidget(self.execute_selected_btn, 7, 0) + acquisition_lyt.addWidget(self.execute_all_btn, 7, 1) + acquisition_lyt.addWidget(self.finetuning_btn, 8, 1) + acquisition_lyt.addWidget(QLabel("Image queue:"), 9, 0, 1, 1) + acquisition_lyt.addWidget(self.image_pb, 9, 1, 1, 3) + acquisition_lyt.addWidget(QLabel("Patch queue:"), 10, 0, 1, 1) + acquisition_lyt.addWidget(self.patch_pb, 10, 1, 1, 3) self.setLayout(acquisition_lyt) self.patch_sizes_widget.setVisible(False) + patch_sizes_chk.toggled.connect(self._show_patch_sizes) + self.patch_sizes_mspn.sizesChanged.connect(self._set_patch_size) + self.max_samples_spn.valueChanged.connect(self._set_max_samples) + self.MC_repetitions_spn.valueChanged.connect(self._set_MC_repetitions) + self.input_axes_le.textChanged.connect(self._set_input_axes) + 
self.model_axes_le.textChanged.connect(self._set_model_axes) + self.methods_cmb.currentIndexChanged.connect(self._set_model) + self.execute_selected_btn.clicked.connect( + partial(self.compute_acquisition_layers, run_all=False) + ) + self.execute_all_btn.clicked.connect( + partial(self.compute_acquisition_layers, run_all=True) + ) + self.finetuning_btn.clicked.connect(self.fine_tune) + def _show_patch_sizes(self, show: bool): self.patch_sizes_widget.setVisible(show) @@ -955,3 +961,11 @@ def _set_input_axes(self): def _set_model_axes(self): self.model_axes = self.model_axes_le.text() + + def _set_model(self, selected_model_index: int): + if self.tunable_segmentation_method is not None: + self.layout().removeWidget(self.tunable_segmentation_method) + self.tunable_segmentation_method.deleteLater() + + self.set_model(self.methods_cmb.itemText(selected_model_index)) + self.layout().addWidget(self.tunable_segmentation_method, 6, 0, 1, 4) diff --git a/src/napari_activelearning/_models.py b/src/napari_activelearning/_models_impl.py similarity index 94% rename from src/napari_activelearning/_models.py rename to src/napari_activelearning/_models_impl.py index 94c5345..04a01d8 100644 --- a/src/napari_activelearning/_models.py +++ b/src/napari_activelearning/_models_impl.py @@ -4,8 +4,7 @@ import zarrdataset as zds -from ._acquisition import SegmentationMethod, FineTuningMethod, add_dropout - +from ._acquisition import TunableMethod, add_dropout try: import cellpose @@ -33,7 +32,7 @@ def _compute_transform(self, image: np.ndarray) -> np.ndarray: axis=self._channel_axis) return img_t - class CellposeSegmentation(SegmentationMethod): + class CellposeTunable(TunableMethod): def __init__(self): super().__init__() @@ -41,7 +40,6 @@ def __init__(self): self._model_dropout = None self.refresh_model = True - self._transform = None self._pretrained_model = None @@ -50,6 +48,25 @@ def __init__(self): self._channel_axis = 2 self._channels = [0, 0] + self._batch_size = 8 + self._learning_rate = 0.005 + self._n_epochs = 20 + self._weight_decay = 1e-5 + self._momentum = 0.9 + self._SGD = False + self._rgb = False + self._normalize = True + self._compute_flows = False + self._save_path = None + self._save_every = 100 + self._nimg_per_epoch = None + self._nimg_test_per_epoch = None + self._rescale = True + self._scale_range = None + self._bsize = 224 + self._min_train_masks = 5 + self._model_name = None + def _model_init(self): gpu = torch.cuda.is_available() and self._gpu if self._pretrained_model is None: @@ -70,7 +87,9 @@ def _model_init(self): self._model_dropout = models.CellposeModel( gpu=gpu, model_type=model_type, - pretrained_model=str(self._pretrained_model) + pretrained_model=(str(self._pretrained_model) + if self._pretrained_model is not None + else None) ) self._model_dropout.mkldnn = False self._model_dropout.net.mkldnn = False @@ -112,29 +131,6 @@ def _run_eval(self, img, *args, **kwargs): channels=self._channels) return seg - class CellposeTunable(CellposeSegmentation, FineTuningMethod): - def __init__(self): - super().__init__() - - self._batch_size = 8 - self._learning_rate = 0.005 - self._n_epochs = 20 - self._weight_decay = 1e-5 - self._momentum = 0.9 - self._SGD = False - self._rgb = False - self._normalize = True - self._compute_flows = False - self._save_path = None - self._save_every = 100 - self._nimg_per_epoch = None - self._nimg_test_per_epoch = None - self._rescale = True - self._scale_range = None - self._bsize = 224 - self._min_train_masks = 5 - self._model_name = None - def 
_get_transform(self): if self._transform is None: self._transform = CellposeTransform(self._channels, @@ -208,7 +204,7 @@ def _compute_transform(self, image: np.ndarray) -> np.ndarray: return image_t -class SimpleSegmentation(SegmentationMethod): +class SimpleTunable(TunableMethod): def __init__(self): super().__init__() self._channel_axis = 2 @@ -236,11 +232,6 @@ def _run_eval(self, img, *args, **kwargs): labels = skimage.measure.label(img_g > self._threshold) return labels - -class SimpleTunable(SimpleSegmentation, FineTuningMethod): - def __init__(self): - super().__init__() - def _get_transform(self): if self._transform is None: self._model_init() diff --git a/src/napari_activelearning/_models_interface.py b/src/napari_activelearning/_models_impl_interface.py similarity index 99% rename from src/napari_activelearning/_models_interface.py rename to src/napari_activelearning/_models_impl_interface.py index 43ac774..0a2f573 100644 --- a/src/napari_activelearning/_models_interface.py +++ b/src/napari_activelearning/_models_impl_interface.py @@ -5,10 +5,10 @@ from qtpy.QtWidgets import QWidget, QGridLayout, QScrollArea, QCheckBox from functools import partial -from ._models import USING_CELLPOSE, SimpleTunable +from ._models_impl import USING_CELLPOSE, SimpleTunable if USING_CELLPOSE: - from ._models import CellposeTunable + from ._models_impl import CellposeTunable def cellpose_segmentation_parameters_widget(): @magicgui(auto_call=True) diff --git a/src/napari_activelearning/_tests/conftest.py b/src/napari_activelearning/_tests/conftest.py index 378f35f..fc59b7c 100644 --- a/src/napari_activelearning/_tests/conftest.py +++ b/src/napari_activelearning/_tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import patch from pathlib import Path, PureWindowsPath import numpy as np import zarr @@ -7,24 +7,12 @@ from napari.layers import Image from napari.layers._source import Source -from napari.layers._multiscale_data import MultiScaleData from napari_activelearning._layers import (LayerChannel, LayersGroup, ImageGroup, ImageGroupsManager) from napari_activelearning._labels import LabelsManager, LabelGroup, LabelItem -from napari_activelearning._acquisition import AcquisitionFunction -from napari_activelearning._models import SimpleTunable - - -@pytest.fixture(scope="package") -def tunable_segmentation_method(): - method = SimpleTunable() - method._run_pred = MagicMock(return_value=np.random.random((10, 10))) - method._run_eval = MagicMock(return_value=np.random.randint(0, 2, (10, 10))) - method._fine_tune = MagicMock(return_value=None) - return method @pytest.fixture(scope="package") diff --git a/src/napari_activelearning/_tests/test_acquisition.py b/src/napari_activelearning/_tests/test_acquisition.py index 4e0372e..b66630e 100644 --- a/src/napari_activelearning/_tests/test_acquisition.py +++ b/src/napari_activelearning/_tests/test_acquisition.py @@ -1,10 +1,11 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock import numpy as np from napari_activelearning._acquisition import (AcquisitionFunction, compute_acquisition_fun, compute_segmentation, - add_multiscale_output_layer) + add_multiscale_output_layer, + TunableMethod) from napari_activelearning._layers import LayerChannel try: @@ -14,10 +15,31 @@ USING_PYTORCH = False -def test_compute_acquisition_fun(tunable_segmentation_method): +class TestTunableMethod(TunableMethod): + def __init__(self): + super(TestTunableMethod, self).__init__() + + def 
_get_transform(self): + return lambda x: x + + def _run_pred(self, img, *args, **kwargs): + return np.random.random((10, 10)) + + def _run_eval(self, img, *args, **kwargs): + return np.random.randint(0, 2, (10, 10)) + + def _fine_tune(self, train_data, train_labels, test_data, test_labels): + return None + + +def test_compute_acquisition_fun(): img = np.random.random((10, 10, 3)) img_sp = np.random.random((10, 10)) MC_repetitions = 3 + tunable_segmentation_method = TunableMethod() + tunable_segmentation_method._run_pred = MagicMock( + return_value=np.random.random((10, 10)) + ) result = compute_acquisition_fun(tunable_segmentation_method, img, img_sp, MC_repetitions) @@ -25,9 +47,14 @@ def test_compute_acquisition_fun(tunable_segmentation_method): assert tunable_segmentation_method._run_pred.call_count == MC_repetitions -def test_compute_segmentation(tunable_segmentation_method): +def test_compute_segmentation(): img = np.random.random((1, 1, 1, 10, 10, 3)) labels_offset = 1 + tunable_segmentation_method = TunableMethod() + tunable_segmentation_method._run_eval = MagicMock( + return_value=np.random.randint(0, 2, (10, 10)) + ) + result = compute_segmentation(tunable_segmentation_method, img, labels_offset) expected_segmentation = tunable_segmentation_method.segment(img) @@ -39,14 +66,15 @@ def test_compute_segmentation(tunable_segmentation_method): def test_compute_acquisition(image_groups_manager, labels_manager, - tunable_segmentation_method, make_napari_viewer): viewer = make_napari_viewer() viewer.dims.axis_labels = ['t', 'z', 'y', 'x'] - acquisition_function = AcquisitionFunction(image_groups_manager, - labels_manager, - tunable_segmentation_method) + acquisition_function = AcquisitionFunction( + image_groups_manager, + labels_manager, + {"test": TestTunableMethod} + ) dataset_metadata = { "images": {"source_axes": "TCZYX", "axes": "TZYXC"}, @@ -56,6 +84,7 @@ def test_compute_acquisition(image_groups_manager, labels_manager, segmentation_out = np.zeros((1, 1, 1, 10, 10)) segmentation_only = False + acquisition_function.set_model("test") acquisition_function.input_axes = "TZYX" acquisition_function.model_axes = "YXC" acquisition_function.patch_sizes = {"T": 1, "Z": 1, "Y": 10, "X": 10} @@ -126,7 +155,6 @@ def test_add_multiscale_output_layer(single_scale_type_variant_array, def test_prepare_datasets_metadata(image_groups_manager, labels_manager, - tunable_segmentation_method, simple_image_group, make_napari_viewer): image_group, _, _ = simple_image_group @@ -135,10 +163,12 @@ def test_prepare_datasets_metadata(image_groups_manager, labels_manager, viewer = make_napari_viewer() viewer.dims.axis_labels = ['t', 'z', 'y', 'x'] - acquisition_function = AcquisitionFunction(image_groups_manager, - labels_manager, - tunable_segmentation_method) + acquisition_function = AcquisitionFunction( + image_groups_manager, + labels_manager, + {"test": TestTunableMethod}) + acquisition_function.set_model("test") acquisition_function._patch_sizes = {"T": 1, "X": 5, "Y": 5, "Z": 1} acquisition_function.input_axes = "TZYX" acquisition_function.model_axes = "YXC" @@ -187,7 +217,6 @@ def test_prepare_datasets_metadata(image_groups_manager, labels_manager, def test_compute_acquisition_layers(image_groups_manager, labels_manager, - tunable_segmentation_method, make_napari_viewer, simple_image_group, labels_group, @@ -224,8 +253,9 @@ def test_compute_acquisition_layers(image_groups_manager, labels_manager, acquisition_function = AcquisitionFunction( image_groups_manager, labels_manager, - 
tunable_segmentation_method) + {"test": TestTunableMethod}) + acquisition_function.set_model("test") acquisition_function._patch_sizes = {"T": 1, "X": 5, "Y": 5, "Z": 1} acquisition_function.input_axes = "TZYX" acquisition_function.model_axes = "YXC" @@ -249,7 +279,6 @@ def test_compute_acquisition_layers(image_groups_manager, labels_manager, def test_fine_tune(image_groups_manager, simple_image_group, labels_manager, - tunable_segmentation_method, multiscale_layer_channel, multiscale_layers_group, labels_group, @@ -289,15 +318,15 @@ def test_fine_tune(image_groups_manager, simple_image_group, acquisition_function = AcquisitionFunction( image_groups_manager, labels_manager, - tunable_segmentation_method + {"test": TestTunableMethod} ) + acquisition_function.set_model("test") acquisition_function._patch_sizes = {"T": 1, "X": 10, "Y": 10, "Z": 1} acquisition_function.input_axes = "TZYX" acquisition_function.model_axes = "YXC" assert acquisition_function.fine_tune() - assert tunable_segmentation_method._fine_tune.called image_group.removeChild(multiscale_layers_group) image_group.labels_group = None diff --git a/src/napari_activelearning/_widgets.py b/src/napari_activelearning/_widgets.py index 6591369..3780f1f 100644 --- a/src/napari_activelearning/_widgets.py +++ b/src/napari_activelearning/_widgets.py @@ -1,23 +1,42 @@ +from ._acquisition import TunableMethod from ._interface import (ImageGroupsManagerWidget, LabelsManagerWidget, AcquisitionFunctionWidget) -from ._models import USING_CELLPOSE -from ._models_interface import SimpleTunableWidget +from ._models_impl import USING_CELLPOSE +from ._models_impl_interface import SimpleTunableWidget -if USING_CELLPOSE: - from ._models_interface import CellposeTunableWidget - - SEGMENTATION_METHOD_CLASS = CellposeTunableWidget - -else: - SEGMENTATION_METHOD_CLASS = SimpleTunableWidget CURRENT_IMAGE_GROUPS_MANAGER = None CURRENT_LABEL_GROUPS_MANAGER = None CURRENT_SEGMENTATION_METHOD = None CURRENT_ACQUISITION_FUNCTION = None +models_registry: dict[str] = { + "None selected": None +} + + +def register_model(model_name: str, model: TunableMethod): + global CURRENT_ACQUISITION_FUNCTION + if model_name in models_registry: + return + + models_registry[model_name] = model + + if CURRENT_ACQUISITION_FUNCTION is not None: + CURRENT_ACQUISITION_FUNCTION = AcquisitionFunctionWidget( + image_groups_manager=get_image_groups_manager_widget(), + labels_manager=get_label_groups_manager_widget(), + tunable_segmentation_methods=models_registry, + ) + + +register_model("simple", SimpleTunableWidget) +if USING_CELLPOSE: + from ._models_impl_interface import CellposeTunableWidget + register_model("cellpose", CellposeTunableWidget) + def get_image_groups_manager_widget(): global CURRENT_IMAGE_GROUPS_MANAGER @@ -43,12 +62,10 @@ def get_acquisition_function_widget(): global CURRENT_ACQUISITION_FUNCTION if CURRENT_ACQUISITION_FUNCTION is None: - segmentation_method = SEGMENTATION_METHOD_CLASS() - CURRENT_ACQUISITION_FUNCTION = AcquisitionFunctionWidget( image_groups_manager=get_image_groups_manager_widget(), labels_manager=get_label_groups_manager_widget(), - tunable_segmentation_method=segmentation_method, + tunable_segmentation_methods=models_registry, ) return CURRENT_ACQUISITION_FUNCTION From 166b274787d0e223bcaf189debef4ea4e46a396e Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Wed, 8 Jan 2025 13:33:44 -0500 Subject: [PATCH 02/10] Updated drop out insertion method --- notebooks/microsam_activelearning.py | 16 ++++++++++-- 
src/napari_activelearning/_acquisition.py | 30 ++++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/notebooks/microsam_activelearning.py b/notebooks/microsam_activelearning.py index 45166b1..23ce20c 100644 --- a/notebooks/microsam_activelearning.py +++ b/notebooks/microsam_activelearning.py @@ -23,6 +23,18 @@ def _model_init(self): stability_score_offset=1.0 ) + (self._sam_predictor_dropout, + self._sam_instance_segmenter_dropout) =\ + msas.get_predictor_and_segmenter( + model_type='vit_t', + device=util.get_device("cpu"), + amg=True, + checkpoint=None, + stability_score_offset=1.0) + + al.add_dropout(self._sam_predictor_dropout.model) + + def _get_transform(self): return lambda x: x @@ -43,8 +55,8 @@ def _run_eval(self, img, *args, **kwargs): self._model_init() segmentation_mask = msas.automatic_instance_segmentation( - predictor=self._sam_predictor, - segmenter=self._sam_instance_segmenter, + predictor=self._sam_predictor_dropout, + segmenter=self._sam_instance_segmenter_dropout, input_path=img, ndim=2, verbose=False diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 1bcf959..3237b74 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -197,7 +197,35 @@ def add_dropout(net, p=0.05): for module in net.modules(): if isinstance(module, torch.nn.Sequential): for l_idx, layer in enumerate(module): - if isinstance(layer, torch.nn.ReLU): + if isinstance(layer, (torch.nn.Threshold, + torch.nn.ReLU, + torch.nn.RReLU, + torch.nn.Hardtanh, + torch.nn.ReLU6, + torch.nn.Sigmoid, + torch.nn.Hardsigmoid, + torch.nn.Tanh, + torch.nn.SiLU, + torch.nn.Mish, + torch.nn.Hardswish, + torch.nn.ELU, + torch.nn.CELU, + torch.nn.SELU, + torch.nn.GLU, + torch.nn.GELU, + torch.nn.Hardshrink, + torch.nn.LeakyReLU, + torch.nn.LogSigmoid, + torch.nn.Softplus, + torch.nn.Softshrink, + torch.nn.MultiheadAttention, + torch.nn.PReLU, + torch.nn.Softsign, + torch.nn.Tanhshrink, + torch.nn.Softmin, + torch.nn.Softmax, + torch.nn.Softmax2d, + torch.nn.LogSoftmax)): break else: continue From c5aa2e713c338d8026ecc7d012aa1c2e7d8b08f0 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Thu, 9 Jan 2025 14:06:38 -0500 Subject: [PATCH 03/10] Updated probability computation for micro-sam --- notebooks/microsam_activelearning.py | 58 ++++++++++++++++++----- src/napari_activelearning/_acquisition.py | 4 ++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/notebooks/microsam_activelearning.py b/notebooks/microsam_activelearning.py index 23ce20c..32d924d 100644 --- a/notebooks/microsam_activelearning.py +++ b/notebooks/microsam_activelearning.py @@ -1,6 +1,9 @@ +import numpy as np +import torch +import time + from micro_sam import util from micro_sam import automatic_segmentation as msas - import napari_activelearning as al @@ -17,7 +20,9 @@ def _model_init(self): (self._sam_predictor, self._sam_instance_segmenter) = msas.get_predictor_and_segmenter( model_type='vit_t', - device=util.get_device("cpu"), + device=util.get_device("cuda" + if torch.cuda.is_available() + else "cpu"), amg=True, checkpoint=None, stability_score_offset=1.0 @@ -27,7 +32,9 @@ def _model_init(self): self._sam_instance_segmenter_dropout) =\ msas.get_predictor_and_segmenter( model_type='vit_t', - device=util.get_device("cpu"), + device=util.get_device("cuda" + if torch.cuda.is_available() + else "cpu"), amg=True, checkpoint=None, stability_score_offset=1.0) @@ -41,26 +48,55 @@ def 
_get_transform(self): def _run_pred(self, img, *args, **kwargs): self._model_init() - segmentation_mask = msas.automatic_instance_segmentation( - predictor=self._sam_predictor, - segmenter=self._sam_instance_segmenter, - input_path=img, + e_time = time.perf_counter() + img_embeddings = util.precompute_image_embeddings( + predictor=self._sam_predictor_dropout, + input_=img, + save_path=None, ndim=2, - verbose=False + tile_shape=None, + halo=None, + verbose=False, ) + e_time = time.perf_counter() - e_time - return segmentation_mask + e_time = time.perf_counter() + self._sam_instance_segmenter_dropout.initialize( + image=img, + image_embeddings=img_embeddings + ) + e_time = time.perf_counter() - e_time + + e_time = time.perf_counter() + masks = self._sam_instance_segmenter_dropout.generate() + e_time = time.perf_counter() - e_time + + e_time = time.perf_counter() + probs = np.zeros(img.shape[:2], dtype=np.float32) + for mask in masks: + probs = np.where( + mask["segmentation"], + mask["predicted_iou"], + probs + ) + e_time = time.perf_counter() - e_time + + probs = torch.from_numpy(probs).sigmoid().numpy() + + return probs def _run_eval(self, img, *args, **kwargs): self._model_init() + e_time = time.perf_counter() segmentation_mask = msas.automatic_instance_segmentation( - predictor=self._sam_predictor_dropout, - segmenter=self._sam_instance_segmenter_dropout, + predictor=self._sam_predictor, + segmenter=self._sam_instance_segmenter, input_path=img, ndim=2, verbose=False ) + e_time = time.perf_counter() - e_time return segmentation_mask diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 3237b74..61a7abe 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -46,6 +46,9 @@ def compute_acquisition_superpixel(probs, super_pixel_labels): for sp_l in super_pixel_indices: mask = super_pixel_labels == sp_l u_val = np.sum(mutual_info[mask]) / np.sum(mask) + if np.isnan(u_val): + u_val = 0.0 + u_sp_lab = np.where(mask, u_val, u_sp_lab) return u_sp_lab @@ -232,6 +235,7 @@ def add_dropout(net, p=0.05): dropout_layer = torch.nn.Dropout(p=p, inplace=True) module.insert(l_idx + 1, DropoutEvalOverrider(dropout_layer)) + else: def add_dropout(net, p=0.05): pass From 69b7255cd4892f3d7d3fbbaa2db937f7e0f67be4 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 10 Jan 2025 16:52:13 -0500 Subject: [PATCH 04/10] Moved to data loaders for fine tuning --- notebooks/microsam_activelearning.py | 46 +++- src/napari_activelearning/_acquisition.py | 203 ++++++++++++------ src/napari_activelearning/_interface.py | 2 + src/napari_activelearning/_models_impl.py | 47 +++- .../_models_impl_interface.py | 7 +- .../_tests/test_acquisition.py | 9 +- 6 files changed, 230 insertions(+), 84 deletions(-) diff --git a/notebooks/microsam_activelearning.py b/notebooks/microsam_activelearning.py index 32d924d..d27fc53 100644 --- a/notebooks/microsam_activelearning.py +++ b/notebooks/microsam_activelearning.py @@ -1,3 +1,5 @@ +from typing import Iterable, Union + import numpy as np import torch import time @@ -43,7 +45,7 @@ def _model_init(self): def _get_transform(self): - return lambda x: x + return lambda x: x, None def _run_pred(self, img, *args, **kwargs): self._model_init() @@ -100,9 +102,47 @@ def _run_eval(self, img, *args, **kwargs): return segmentation_mask - def _fine_tune(self, train_data, train_labels, test_data, test_labels): + # def _fine_tune(self, data_loader, + # train_data_proportion: float = 0.8) 
-> bool: + def _fine_tune(self, train_dataloader, val_dataloader) -> bool: self._model_init() - return None + + # self._pretrained_model = train.train_seg( + # self._model.net, + # train_data=train_data, + # train_labels=train_labels, + # train_probs=None, + # test_data=test_data, + # test_labels=test_labels, + # test_probs=None, + # load_files=False, + # batch_size=self._batch_size, + # learning_rate=self._learning_rate, + # n_epochs=self._n_epochs, + # weight_decay=self._weight_decay, + # momentum=self._momentum, + # SGD=self._SGD, + # channels=self._channels, + # channel_axis=self._channel_axis, + # rgb=self._rgb, + # normalize=self._normalize, + # compute_flows=self._compute_flows, + # save_path=self._save_path, + # save_every=self._save_every, + # nimg_per_epoch=self._nimg_per_epoch, + # nimg_test_per_epoch=self._nimg_test_per_epoch, + # rescale=self._rescale, + # scale_range=self._scale_range, + # bsize=self._bsize, + # min_train_masks=self._min_train_masks, + # model_name=self._model_name + # ) + + # if isinstance(self._pretrained_model, tuple): + # self._pretrained_model = self._pretrained_model[0] + + # self.refresh_model = True + return True def register_microsam(): diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 61a7abe..751433c 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -1,15 +1,16 @@ from typing import Optional, Iterable, Tuple, Callable, Union -import random + from pathlib import Path import numpy as np +import random import math +import dask.array as da import zarrdataset as zds -import dask.array as da try: import torch - from torch.utils.data import DataLoader + from torch.utils.data import DataLoader, ChainDataset, random_split USING_PYTORCH = True except ModuleNotFoundError: USING_PYTORCH = False @@ -88,6 +89,7 @@ def add_multiscale_output_layer( contrast_limits: Optional[Iterable[float]] = None, colormap: Optional[str] = None, use_as_input_labels: bool = False, + use_as_sampling_mask: bool = False, add_func: Optional[Callable] = napari.Viewer.add_image ): if output_filename: @@ -117,7 +119,7 @@ def add_multiscale_output_layer( opacity=0.8, scale=list(scale.values()), translate=tuple( - reference_scale.get(ax, 1) / 2.0 + (reference_scale.get(ax, 1) - 1) / 2.0 if reference_scale.get(ax, 1) > 1 else 0 for ax in reference_source_axes ), @@ -145,7 +147,7 @@ def add_multiscale_output_layer( source_axes=axes, use_as_input_image=False, use_as_input_labels=use_as_input_labels, - use_as_sampling_mask=False + use_as_sampling_mask=use_as_sampling_mask ) output_channel = output_layers_group.add_layer( @@ -271,91 +273,141 @@ def _get_transform(self): raise NotImplementedError("This method requies to be overriden by a " "derived class.") - def _fine_tune(self, train_data, train_labels, test_data, test_labels): + # def _fine_tune(self, dataset_metadata_list: Iterable[dict], + # train_data_proportion: float = 0.8) -> bool: + def _fine_tune(self, train_dataloader, val_dataloader) -> bool: raise NotImplementedError("This method requies to be overriden by a " "derived class.") def fine_tune(self, dataset_metadata_list: Iterable[dict], train_data_proportion: float = 0.8, - patch_sizes: Union[dict, int] = 256, - model_axes="YXC"): - train_data = [] - test_data = [] - train_labels = [] - test_labels = [] + patch_sizes: Union[dict, int] = 256): + + transform, labels_transform = self._get_transform() - transform = self._get_transform() + worker_init_fn = None + + if 
len(dataset_metadata_list) == 1: + sampling_mask = np.copy( + dataset_metadata_list[0]["masks"]["filenames"] + ) + + sampling_locations = np.nonzero(sampling_mask) + sampling_locations = np.ravel_multi_index(sampling_locations, + sampling_mask.shape) + sampling_locations = np.random.choice( + sampling_locations, + size=int(train_data_proportion * len(sampling_locations)), + replace=False + ) + sampling_locations = np.unravel_index(sampling_locations, + sampling_mask.shape) + + train_mask = np.zeros_like(sampling_mask) + train_mask[sampling_locations] = True + val_mask = np.bitwise_xor(train_mask, sampling_mask) - for dataset_metadata in dataset_metadata_list: patch_sampler = zds.PatchSampler( patch_size=patch_sizes, - spatial_axes=dataset_metadata["labels"]["axes"], + spatial_axes=dataset_metadata_list[0]["labels"]["axes"], min_area=0.01 ) - dataset = zds.ZarrDataset( - list(dataset_metadata.values()), + dataset_metadata_list[0]["masks"]["filenames"] = train_mask + + train_datasets = zds.ZarrDataset( + list(dataset_metadata_list[0].values()), return_positions=False, draw_same_chunk=False, patch_sampler=patch_sampler, shuffle=True, ) - dataset.add_transform("images", zds.ToDtype(np.float32)) - dataset.add_transform("labels", zds.ToDtype(np.int32)) + dataset_metadata_list[0]["masks"]["filenames"] = val_mask - if USING_PYTORCH: - dataloader = DataLoader( - dataset, - num_workers=self._num_workers, - worker_init_fn=zds.zarrdataset_worker_init_fn - ) - else: - dataloader = dataset - - drop_axis = tuple( - ax_idx - for ax_idx, ax in enumerate( - dataset_metadata["images"]["axes"]) - if ax != "C" and ax not in model_axes + val_datasets = zds.ZarrDataset( + list(dataset_metadata_list[0].values()), + return_positions=False, + draw_same_chunk=False, + patch_sampler=patch_sampler, + shuffle=True, ) - drop_label_axis = tuple( - ax_idx - for ax_idx, ax in enumerate( - dataset_metadata["labels"]["axes"]) - if ax != "C" and ax not in model_axes - ) + train_datasets.add_transform("images", zds.ToDtype(np.float32)) + if transform: + train_datasets.add_transform("images", transform) + + train_datasets.add_transform("labels", zds.ToDtype(np.int32)) + if labels_transform: + train_datasets.add_transform("labels", labels_transform) + + val_datasets.add_transform("images", zds.ToDtype(np.float32)) + if transform: + val_datasets.add_transform("images", transform) - for img, lab in dataloader: - if USING_PYTORCH: - img = img[0].numpy() - lab = lab[0].numpy() + val_datasets.add_transform("labels", zds.ToDtype(np.int32)) + if labels_transform: + val_datasets.add_transform("labels", labels_transform) - if len(drop_axis): - img = img.squeeze(drop_axis) + worker_init_fn = zds.zarrdataset_worker_init_fn - if len(drop_label_axis): - lab = lab.squeeze(drop_label_axis) + else: + train_datasets = [] + val_datasets = [] + + training_indices = np.random.choice( + len(dataset_metadata_list), + int(train_data_proportion * len(dataset_metadata_list)) + ).tolist() + + for idx, dataset_metadata in enumerate(dataset_metadata_list): + patch_sampler = zds.PatchSampler( + patch_size=patch_sizes, + spatial_axes=dataset_metadata["labels"]["axes"], + min_area=0.01 + ) + + dataset = zds.ZarrDataset( + list(dataset_metadata.values()), + return_positions=False, + draw_same_chunk=False, + patch_sampler=patch_sampler, + shuffle=True, + ) + + dataset.add_transform("images", zds.ToDtype(np.float32)) + if transform: + dataset.add_transform("images", transform) - img = transform(img) + dataset.add_transform("labels", zds.ToDtype(np.int32)) + if 
labels_transform: + dataset.add_transform("labels", labels_transform) - if random.random() <= train_data_proportion: - train_data.append(img) - train_labels.append(lab) + if idx in training_indices: + train_datasets.append(dataset) else: - test_data.append(img) - test_labels.append(lab) + val_datasets.append(dataset) - if not test_data: - # Take at least one sample at random from the train dataset - test_data_idx = random.randrange(0, len(train_data)) - test_data = [train_data.pop(test_data_idx)] - test_labels = [train_labels.pop(test_data_idx)] + train_datasets = ChainDataset(train_datasets) + val_datasets = ChainDataset(val_datasets) + worker_init_fn = zds.chained_zarrdataset_worker_init_fn - self._fine_tune(train_data, train_labels, test_data, test_labels) + if USING_PYTORCH: + train_dataloader = DataLoader( + train_datasets, + num_workers=self._num_workers, + worker_init_fn=worker_init_fn + ) + val_dataloader = DataLoader( + val_datasets, + num_workers=self._num_workers, + worker_init_fn=worker_init_fn + ) + else: + train_dataloader = train_datasets + val_dataloader = val_datasets - return train_data, train_labels, test_data, test_labels + return self._fine_tune(train_dataloader, val_dataloader) class AcquisitionFunction: @@ -773,6 +825,7 @@ def compute_acquisition_layers( reference_source_axes=output_axes, reference_scale=sampling_output_scale, output_filename=output_filename, + use_as_sampling_mask=True, add_func=viewer.add_labels ) @@ -843,10 +896,6 @@ def fine_tune(self): (label_layers_group, "labels") ] - if (sampling_mask_layers_group is not None - and image_group.labels_group is None): - layer_types.append((sampling_mask_layers_group, "masks")) - displayed_source_axes = input_layers_group.source_axes displayed_shape = { ax: ax_s @@ -860,6 +909,25 @@ def fine_tune(self): output_axes.remove("C") output_axes = "".join(output_axes) + if sampling_mask_layers_group is not None: + layer_types.append((sampling_mask_layers_group, "masks")) + else: + self.image_groups_manager.mask_generator.active_image_group =\ + image_group + self.image_groups_manager.mask_generator.set_patch_size( + [self._patch_sizes.get(ax, 1) + for ax in output_axes] + ) + + self.image_groups_manager.mask_generator.generate_mask_layer() + + sampling_mask_layers_group = image_group.child( + image_group.sampling_mask_layers_group + ) + sampling_mask_layers_group.child(0).layer.data[:] = 1 + + layer_types.append((sampling_mask_layers_group, "masks")) + dataset_metadata = self._prepare_datasets_metadata( displayed_shape, layer_types, @@ -867,10 +935,9 @@ def fine_tune(self): dataset_metadata_list.append(dataset_metadata) - self.tunable_segmentation_method.fine_tune( + success = self.tunable_segmentation_method.fine_tune( dataset_metadata_list, - patch_sizes=self._patch_sizes, - model_axes=self.model_axes + patch_sizes=self._patch_sizes ) self.compute_acquisition_layers( @@ -879,4 +946,4 @@ def fine_tune(self): segmentation_only=True ) - return True + return success diff --git a/src/napari_activelearning/_interface.py b/src/napari_activelearning/_interface.py index 9303559..6d210ce 100644 --- a/src/napari_activelearning/_interface.py +++ b/src/napari_activelearning/_interface.py @@ -929,6 +929,8 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, ) self.finetuning_btn.clicked.connect(self.fine_tune) + self.patch_sizes_mspn.update_spin_boxes() + def _show_patch_sizes(self, show: bool): self.patch_sizes_widget.setVisible(show) diff --git a/src/napari_activelearning/_models_impl.py 
b/src/napari_activelearning/_models_impl.py index 04a01d8..0460097 100644 --- a/src/napari_activelearning/_models_impl.py +++ b/src/napari_activelearning/_models_impl.py @@ -1,10 +1,10 @@ -import os import numpy as np +import random import skimage import zarrdataset as zds -from ._acquisition import TunableMethod, add_dropout +from ._acquisition import TunableMethod, add_dropout, USING_PYTORCH try: import cellpose @@ -135,11 +135,38 @@ def _get_transform(self): if self._transform is None: self._transform = CellposeTransform(self._channels, self._channel_axis) - return self._transform - - def _fine_tune(self, train_data, train_labels, test_data, test_labels): + return self._transform, None + + # def _preload_data(self, data_loader, + # train_data_proportion: float = 0.8): + def _preload_data(self, dataloader): + raw_data = [] + label_data = [] + for img, lab in dataloader: + if USING_PYTORCH: + img = img[0].numpy() + lab = lab[0].numpy() + + raw_data.append(img) + label_data.append(lab) + + return raw_data, label_data + + # def _fine_tune(self, data_loader, + # train_data_proportion: float = 0.8) -> bool: + def _fine_tune(self, train_dataloader, val_dataloader) -> bool: self._model_init() + (train_data, + train_labels) = self._preload_data( + train_dataloader + ) + + (test_data, + test_labels) = self._preload_data( + val_dataloader + ) + self._pretrained_model = train.train_seg( self._model.net, train_data=train_data, @@ -176,6 +203,8 @@ def _fine_tune(self, train_data, train_labels, test_data, test_labels): self.refresh_model = True + return True + USING_CELLPOSE = True except ModuleNotFoundError: @@ -236,8 +265,12 @@ def _get_transform(self): if self._transform is None: self._model_init() - return self._transform + return self._transform, None - def _fine_tune(self, train_data, train_labels, test_data, test_labels): + # def _fine_tune(self, data_loader, + # train_data_proportion: float = 0.8) -> bool: + def _fine_tune(self, train_dataloader, val_dataloader) -> bool: if self._transform is None: self._model_init() + + return True diff --git a/src/napari_activelearning/_models_impl_interface.py b/src/napari_activelearning/_models_impl_interface.py index 0a2f573..76e9c1f 100644 --- a/src/napari_activelearning/_models_impl_interface.py +++ b/src/napari_activelearning/_models_impl_interface.py @@ -259,9 +259,10 @@ def _show_segmentation_parameters(self, show: bool): def _show_finetuning_parameters(self, show: bool): self._finetuning_parameters_scr.setVisible(show) - def _fine_tune(self, train_data, train_labels, test_data, test_labels): - super()._fine_tune(train_data, train_labels, test_data, - test_labels) + # def _fine_tune(self, data_loader, train_data_proportion: float = 0.8): + # super()._fine_tune(data_loader, train_data_proportion) + def _fine_tune(self, train_dataloader, val_dataloader): + super()._fine_tune(train_dataloader, val_dataloader) self._segmentation_parameters.pretrained_model.value =\ self._pretrained_model diff --git a/src/napari_activelearning/_tests/test_acquisition.py b/src/napari_activelearning/_tests/test_acquisition.py index b66630e..0adf4ec 100644 --- a/src/napari_activelearning/_tests/test_acquisition.py +++ b/src/napari_activelearning/_tests/test_acquisition.py @@ -20,7 +20,7 @@ def __init__(self): super(TestTunableMethod, self).__init__() def _get_transform(self): - return lambda x: x + return lambda x: x, None def _run_pred(self, img, *args, **kwargs): return np.random.random((10, 10)) @@ -28,8 +28,9 @@ def _run_pred(self, img, *args, **kwargs): def 
_run_eval(self, img, *args, **kwargs): return np.random.randint(0, 2, (10, 10)) - def _fine_tune(self, train_data, train_labels, test_data, test_labels): - return None + # def _fine_tune(self, train_data, train_labels, test_data, test_labels): + def _fine_tune(self, train_dataloader, val_dataloader): + return True def test_compute_acquisition_fun(): @@ -131,6 +132,7 @@ def test_add_multiscale_output_layer(single_scale_type_variant_array, contrast_limits = [0, 1] colormap = "gray" use_as_input_labels = False + use_as_sampling_mask = False viewer = make_napari_viewer() add_func = viewer.add_image @@ -148,6 +150,7 @@ def test_add_multiscale_output_layer(single_scale_type_variant_array, contrast_limits, colormap, use_as_input_labels, + use_as_sampling_mask, add_func ) From 5b5644a6b704565e2d3e03cc048935da764b1976 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Mon, 13 Jan 2025 11:23:55 -0500 Subject: [PATCH 05/10] Changed to take number of dimensions from layers data --- src/napari_activelearning/_layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/napari_activelearning/_layers.py b/src/napari_activelearning/_layers.py index 2e3fa07..e45de19 100644 --- a/src/napari_activelearning/_layers.py +++ b/src/napari_activelearning/_layers.py @@ -30,7 +30,7 @@ def __init__(self, layer: Layer, channel: int = 0, self.channel = channel source_axes = list(source_axes) - source_axes = source_axes[-self.layer.ndim:] + source_axes = source_axes[-self.layer.data.ndim:] self.source_axes = "".join(source_axes) def _update_name(self, event): @@ -88,7 +88,7 @@ def source_axes(self): @source_axes.setter def source_axes(self, source_axes: str): - if "C" in source_axes and self.layer.ndim != len(source_axes): + if "C" in source_axes and self.layer.data.ndim != len(source_axes): source_axes = list(source_axes) source_axes.remove("C") source_axes = "".join(source_axes) @@ -110,7 +110,7 @@ def shape(self): @property def ndim(self): - return self.layer.ndim + return self.layer.data.ndim @property def scale(self): From 0cf9a93d71487bd40102c87e0dcd243461d454f0 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Tue, 14 Jan 2025 10:36:28 -0500 Subject: [PATCH 06/10] Testing micro-sam fine tuning --- notebooks/microsam_activelearning.py | 67 +++++++++-------------- src/napari_activelearning/_acquisition.py | 19 +++++-- 2 files changed, 39 insertions(+), 47 deletions(-) diff --git a/notebooks/microsam_activelearning.py b/notebooks/microsam_activelearning.py index d27fc53..1122d2b 100644 --- a/notebooks/microsam_activelearning.py +++ b/notebooks/microsam_activelearning.py @@ -1,11 +1,12 @@ -from typing import Iterable, Union - import numpy as np import torch import time +from torch_em.transform.label import PerObjectDistanceTransform + from micro_sam import util from micro_sam import automatic_segmentation as msas +import micro_sam.training as sam_training import napari_activelearning as al @@ -43,9 +44,13 @@ def _model_init(self): al.add_dropout(self._sam_predictor_dropout.model) - def _get_transform(self): - return lambda x: x, None + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, + foreground=True, instances=True, min_size=25 + ) + + return lambda x: (255.0 * x).astype(np.uint8), label_transform def _run_pred(self, img, *args, **kwargs): self._model_init() @@ -102,46 +107,26 @@ def _run_eval(self, img, *args, **kwargs): return segmentation_mask - # def _fine_tune(self, data_loader, - # 
train_data_proportion: float = 0.8) -> bool: def _fine_tune(self, train_dataloader, val_dataloader) -> bool: self._model_init() - # self._pretrained_model = train.train_seg( - # self._model.net, - # train_data=train_data, - # train_labels=train_labels, - # train_probs=None, - # test_data=test_data, - # test_labels=test_labels, - # test_probs=None, - # load_files=False, - # batch_size=self._batch_size, - # learning_rate=self._learning_rate, - # n_epochs=self._n_epochs, - # weight_decay=self._weight_decay, - # momentum=self._momentum, - # SGD=self._SGD, - # channels=self._channels, - # channel_axis=self._channel_axis, - # rgb=self._rgb, - # normalize=self._normalize, - # compute_flows=self._compute_flows, - # save_path=self._save_path, - # save_every=self._save_every, - # nimg_per_epoch=self._nimg_per_epoch, - # nimg_test_per_epoch=self._nimg_test_per_epoch, - # rescale=self._rescale, - # scale_range=self._scale_range, - # bsize=self._bsize, - # min_train_masks=self._min_train_masks, - # model_name=self._model_name - # ) - - # if isinstance(self._pretrained_model, tuple): - # self._pretrained_model = self._pretrained_model[0] - - # self.refresh_model = True + train_dataloader.shuffle = True + val_dataloader.shuffle = False + + # Run training. + sam_training.train_sam( + name="microsam_activelearning", + model_type="vit_t", + train_loader=train_dataloader, + val_loader=val_dataloader, + n_epochs=2, + n_objects_per_batch=25, + with_segmentation_decoder=True, + device=util.get_device("cuda" + if torch.cuda.is_available() + else "cpu"), + ) + return True diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 751433c..f098223 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -264,6 +264,14 @@ def segment(self, img, *args, **kwargs): return out +class MyZarrDataset(zds.ZarrDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __len__(self): + return len(self._toplefts) + + class TunableMethod(SegmentationMethod): def __init__(self): self._num_workers = 0 @@ -315,7 +323,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[dict], dataset_metadata_list[0]["masks"]["filenames"] = train_mask - train_datasets = zds.ZarrDataset( + train_datasets = MyZarrDataset( list(dataset_metadata_list[0].values()), return_positions=False, draw_same_chunk=False, @@ -325,7 +333,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[dict], dataset_metadata_list[0]["masks"]["filenames"] = val_mask - val_datasets = zds.ZarrDataset( + val_datasets = MyZarrDataset( list(dataset_metadata_list[0].values()), return_positions=False, draw_same_chunk=False, @@ -361,13 +369,13 @@ def fine_tune(self, dataset_metadata_list: Iterable[dict], ).tolist() for idx, dataset_metadata in enumerate(dataset_metadata_list): - patch_sampler = zds.PatchSampler( + patch_sampler = MyZarrDataset( patch_size=patch_sizes, spatial_axes=dataset_metadata["labels"]["axes"], min_area=0.01 ) - dataset = zds.ZarrDataset( + dataset = MyZarrDataset( list(dataset_metadata.values()), return_positions=False, draw_same_chunk=False, @@ -726,8 +734,7 @@ def compute_acquisition_layers( is_multiscale=True ) - acquisition_fun_grp = acquisition_root["" - "acquisition_fun/0"] + acquisition_fun_grp = acquisition_root["acquisition_fun/0"] else: acquisition_fun_grp = None From 3fe88e910292667bfe607e4472a1251b9b91c614 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 17 Jan 2025 16:59:24 -0500 Subject: 
[PATCH 07/10] Working on tutorial

---
 docs/tutorials/cellpose_tutorial.qmd | 637 +++++++++++++++++++++++++++
 1 file changed, 637 insertions(+)

diff --git a/docs/tutorials/cellpose_tutorial.qmd b/docs/tutorials/cellpose_tutorial.qmd
index 4a4d7bd..d5561d9 100644
--- a/docs/tutorials/cellpose_tutorial.qmd
+++ b/docs/tutorials/cellpose_tutorial.qmd
@@ -18,3 +18,640 @@ execute:
 jupyter: python3
 ---
+# 1 Image groups management
+
+## 1.1 Load a sample image
+
+You can use the cells 3D image sample from napari's built-in samples.
+
+```
+File > Open Sample > napari builtins > Cells (3D+2Ch)
+```
+
+``` {python}
+#| echo: false
+from PIL import Image, ImageDraw, ImageFont
+import napari
+from napari.utils import nbscreenshot
+import napari_activelearning as al
+```
+``` {python}
+#| echo: false
+viewer = napari.Viewer()
+_ = viewer.open_sample(plugin="napari", sample="cells3d")
+```
+``` {python}
+#| echo: false
+nbscreenshot(viewer)
+```
+
+## 1.2 Add the _Image Groups Manager_ widget to napari's window
+
+You can find the _Image groups manager_ under the _Active Learning_ plugin in napari's Plugins menu.
+
+```
+Plugins > Active Learning > Image groups manager
+```
+
+``` {python}
+#| echo: false
+image_groups_mgr, acquisition_fun_cfg, labels_mgr = al.get_active_learning_widget()
+```
+``` {python}
+#| echo: false
+image_groups_mgr_dw = viewer.window.add_dock_widget(image_groups_mgr)
+```
+``` {python}
+#| echo: false
+nbscreenshot(viewer)
+```
+
+## 1.3 Create an _Image Group_ containing _nuclei_ and _membrane_ layers
+
+- Select the _nuclei_ and _membrane_ layers and click the _New Image Group_ button on the _Image Groups Manager_ widget.
+
+``` {python}
+#| echo: false
+viewer.layers.selection.clear()
+viewer.layers.selection.add(viewer.layers["nuclei"])
+viewer.layers.selection.add(viewer.layers["membrane"])
+image_groups_mgr.create_group()
+```
+``` {python}
+#| echo: false
+nbscreenshot(viewer)
+```
+
+## 1.4 Edit the image group properties
+
+:::: {.columns}
+
+::: {.column width=0.3}
+- Select the newly created image group; it will appear as "images" in the _Image groups manager_ widget.
+
+- Click the _Edit group properties_ checkbox.
+
+- Make sure that _Axes order_ is "CZYX"; otherwise, edit it and press _Enter_ to update the axes names.
+:::
+
+::: {.column width=0.7}
+``` {python}
+#| echo: false
+image_groups_mgr.image_groups_editor._show_editor(True)
+image_groups_mgr._active_image_group.child(0).setSelected(True)
+image_groups_mgr.image_groups_editor.edit_axes_le.setText("CZYX")
+image_groups_mgr.image_groups_editor.update_source_axes()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw white and green highlight rectangles
+draw.rectangle([70, 250, 310, 280], outline="white", width=5)
+draw.rectangle([70, 250, 310, 280], outline="green", width=2)
+
+cropped_image
+```
+:::
+
+::::
+
+# 2 Segment the managed image groups
+
+## 2.1 Add the _Acquisition function configuration_ widget to napari's window
+
+The _Acquisition function configuration_ widget is under the _Active Learning_ plugin in napari's Plugins menu. 
+ +``` +Plugins > Active Learning > Acquisition function configuration +``` + +``` {python} +#| echo: false +acquisition_fun_cfg_dw = viewer.window.add_dock_widget(acquisition_fun_cfg) +viewer.window._qt_window.tabifyDockWidget(image_groups_mgr_dw, acquisition_fun_cfg_dw) + +image_groups_mgr_dw.setWindowTitle("Image groups manager") +acquisition_fun_cfg_dw.setWindowTitle("Acquisition function configuration") +``` +``` {python} +#| echo: false +acquisition_fun_cfg_dw.raise_() +``` +``` {python} +#| echo: false +nbscreenshot(viewer) +``` + +## 2.2 Define sampling configuration + +:::: {.columns} + +::: {.column width=0.3} +### 2.2.1 Set the axes of the sampling space + +1. Make sure "Input axes" are set to "ZYX" + +::: {.callout-note} +This specifies that the samples will be taken from those axes. +::: + +2. Change the "Model axes" to "YXC" + +::: {.callout-note} +The model axes are the axes on which the segmentation model is implemented. `Cellpose` (for 2D images) expects the input image to have "Y" and "X" spatial axes, and a trailing channel axes "C". +::: + +::: + +::: {.column width=0.3} + +``` {python} +#| echo: false +acquisition_fun_cfg.input_axes_le.setText("ZYX") +acquisition_fun_cfg._set_input_axes() +acquisition_fun_cfg.model_axes_le.setText("YXC") +acquisition_fun_cfg._set_model_axes() +``` +``` {python} +#| echo: false +screenshot = viewer.screenshot(canvas_only=False, flash=False) + +image = Image.fromarray(screenshot) + +org_height, org_width = screenshot.shape[:2] + +roi = (org_width * 0.7, 0, org_width, org_height) + +# Crop the image +cropped_image = image.crop(roi) + +draw = ImageDraw.Draw(cropped_image) + +draw.rectangle([60, 435, 405, 475], outline="white", width=5) +draw.rectangle([60, 435, 405, 475], outline="green", width=2) + +position = (250, 395) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "1", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "1", fill="white", font=font) + +draw.rectangle([405, 435, 570, 475], outline="white", width=5) +draw.rectangle([405, 435, 570, 475], outline="green", width=2) + +position = (490, 395) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "2", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "2", fill="white", font=font) + +cropped_image +``` +::: + +:::: + +--- + +### 2.2.2 Set the size of the sampling patch + +:::: {.columns} + +::: {.column width=0.70} + +``` {python} +#| echo: false +acquisition_fun_cfg._show_patch_sizes(True) + +acquisition_fun_cfg.patch_sizes_mspn.sizes = {"Z": 1, "Y": 256, "X": 256} +``` +``` {python} +#| echo: false +screenshot = viewer.screenshot(canvas_only=False, flash=False) + +image = Image.fromarray(screenshot) + +org_height, org_width = screenshot.shape[:2] + +roi = (org_width * 0.7, 0, org_width, org_height) + +# Crop the image +cropped_image = image.crop(roi) + +draw = ImageDraw.Draw(cropped_image) + +# Draw a red rectangle +draw.rectangle([70, 270, 470, 700], outline="white", width=5) +draw.rectangle([70, 270, 470, 700], outline="green", width=2) + +cropped_image +``` + +::: + +::: {.column width=0.30} + +- Click the "Edit patch size" checkbox +- Change the patch size of "X" and "Y" to 256, and the "Z" axis to 1. + +:::{.callout-note} +This directs the Active Learning plugin to sample at random patches of size $256\times256$ pixels, and $1$ slice of depth. 
+
+## 2.3 Define the maximum number of samples to extract
+
+:::: {.columns}
+
+::: {.column width=0.7}
+
+``` {python}
+#| echo: false
+acquisition_fun_cfg.max_samples_spn.setValue(4)
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([60, 855, 405, 890], outline="white", width=5)
+draw.rectangle([60, 855, 405, 890], outline="green", width=2)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+
+:::
+
+::: {.column width=0.30}
+- Set the "Maximum samples" to $4$ and press _Enter_
+
+:::{.callout-note}
+This tells the Active Learning plugin to process at most _four_ samples at random from the whole image.
+:::
+
+:::
+
+::::
+
+
+## 2.4 Configure the segmentation method
+
+:::: {.columns}
+
+::: {.column width=0.3}
+
+1. Use the dropdown labeled "None selected" to select the `cellpose` method
+
+2. Click the "Advanced segmentation parameters" checkbox
+
+3. Change the second channel to 1 (the right spin box in the "channels" row)
+
+::: {.callout-note}
+This tells `cellpose` to segment the first channel ($0$) and use the second channel ($1$) as a helper channel.
+:::
+
+:::
+
+::: {.column width=0.7}
+``` {python}
+#| echo: false
+acquisition_fun_cfg.methods_cmb.setCurrentIndex(2)
+acquisition_fun_cfg.tunable_segmentation_method.advanced_segmentation_options_chk.setChecked(True)
+acquisition_fun_cfg.tunable_segmentation_method._segmentation_parameters.channels.value = (0, 1)
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([55, 550, 570, 580], outline="white", width=5)
+draw.rectangle([55, 550, 570, 580], outline="green", width=2)
+
+position = (550, 580)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "1", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "1", fill="white", font=font)
+
+draw.rectangle([65, 585, 310, 620], outline="white", width=5)
+draw.rectangle([65, 585, 310, 620], outline="green", width=2)
+
+position = (320, 580)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "2", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "2", fill="white", font=font)
+
+draw.rectangle([390, 750, 550, 780], outline="white", width=5)
+draw.rectangle([390, 750, 550, 780], outline="green", width=2)
+
+position = (450, 710)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "3", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "3", fill="white", font=font)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+:::
+
+::::
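+
+---
+
+Under the hood these settings are handed to Cellpose. The following is only a rough sketch of an equivalent direct call, not the plugin's internal code (it assumes a `cellpose` 2.x install, its `cyto` model, and a hypothetical patch; Cellpose's own documentation describes the exact channel convention):
+
+```python
+import numpy as np
+from cellpose import models
+
+img = np.random.random((256, 256, 2))  # hypothetical Y, X, C patch
+
+model = models.CellposeModel(model_type="cyto")
+masks, flows, styles = model.eval(
+    img,
+    channels=[0, 1],  # mirrors the "channels" values set in the widget
+)
+```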
+## 2.5 Execute the segmentation method on all image groups
+
+:::: {.columns}
+
+::: {.column width=0.3}
+- Click the "Run on all image groups" button
+
+::: {.callout-note}
+To execute the segmentation only on specific image groups, select the desired image groups in the _Image groups manager_ widget and use the "Run on selected image groups" button instead.
+:::
+
+:::
+
+::: {.column width=0.7}
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([225, 970, 385, 1005], outline="white", width=5)
+draw.rectangle([225, 970, 385, 1005], outline="green", width=2)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+``` {python}
+#| echo: false
+_ = acquisition_fun_cfg.compute_acquisition_layers(run_all=True)
+```
+:::
+
+::::
+
+## 2.6 Inspect the segmentation layer
+
+::: {.callout-note}
+Because the input image is 3D, you might have to move the Z slider at the bottom of napari's window to look at the samples that have been segmented.
+:::
+
+``` {python}
+#| echo: false
+labels_mgr.focus_region(
+    labels_mgr.labels_group_root.child(0).child(0)
+)
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+draw = ImageDraw.Draw(image)
+
+# Draw a red rectangle
+draw.rectangle([295, 1080, 1380, 1110], outline="white", width=5)
+draw.rectangle([295, 1080, 1380, 1110], outline="green", width=2)
+
+image
+```
+
+
+# 3 Segment masked regions only
+
+## 3.1 Create a mask to restrict the sampling space
+
+### 3.1.1 Add a mask layer to the image group
+
+:::: {.columns}
+::: {.column width=0.3}
+- Switch to the "Image groups manager" tab
+- Click the "Edit mask properties" checkbox
+:::
+
+::: {.column width=0.7}
+``` {python}
+#| echo: false
+image_groups_mgr_dw.raise_()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([45, 375, 220, 410], outline="white", width=5)
+draw.rectangle([45, 375, 220, 410], outline="green", width=2)
+
+cropped_image
+```
+``` {python}
+image_groups_mgr.mask_generator._show_editor(True)
+```
+
+:::
+
+::::
+
+### 3.1.2 Create a low-resolution mask for the associated image
+
+:::: {.columns}
+:::{.column width=0.3}
+1. Set the mask scale to $256$ for the "X" and "Y" axes, and a scale of $1$ for the "Z" axis
+
+2. Click the "Create mask" button
+
+:::{.callout-note}
+This creates a low-resolution mask where each pixel corresponds to a $256\times256$-pixel region in the input image.
+Because the mask is low-resolution, it uses less space (both in RAM and on disk).
+:::
+
+:::
+
+:::{.column width=0.7}
+``` {python}
+#| echo: false
+image_groups_mgr.mask_generator.patch_sizes_mspn.sizes = {"Z": 1, "Y": 256, "X": 256}
+image_groups_mgr.mask_generator.generate_mask_layer()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([60, 450, 550, 640], outline="white", width=5)
+draw.rectangle([60, 450, 550, 640], outline="green", width=2)
+
+position = (500, 400)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "1", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "1", fill="white", font=font)
+
+draw.rectangle([50, 675, 560, 710], outline="white", width=5)
+draw.rectangle([50, 675, 560, 710], outline="green", width=2)
+
+position = (530, 640)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "2", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "2", fill="white", font=font)
+
+cropped_image
+```
+:::
+::::
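+
+---
+
+The mask's shape follows directly from the image shape and the per-axis scales (a sketch, assuming a $60\times256\times256$ Z, Y, X volume like the sample used here):
+
+```python
+import math
+
+image_shape = {"Z": 60, "Y": 256, "X": 256}  # assumed sample dimensions
+mask_scale = {"Z": 1, "Y": 256, "X": 256}    # scales set in the widget
+
+# Each mask pixel covers `scale` image pixels along its axis.
+mask_shape = {ax: math.ceil(image_shape[ax] / mask_scale[ax])
+              for ax in image_shape}
+print(mask_shape)  # {'Z': 60, 'Y': 1, 'X': 1}
+```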
+---
+
+### 3.1.3 Specify the samplable regions
+
+- Draw a mask on slices $27$ to $30$ in the "Z" axis.
+
+:::{.callout-note}
+You can move the slider at the bottom of napari's window to navigate between slices in the "Z" axis.
+:::
+
+``` {python}
+#| echo: false
+viewer.camera.center = (27, 128, 128)
+viewer.dims.current_step = (27, 128, 128)
+
+viewer.layers["images mask"].data[0, 0, 27:31, 0, 0] = 1
+viewer.layers["images mask"].refresh()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+image
+```
+
+
+## 3.2 Execute the segmentation process on the masked regions
+
+:::: {.columns}
+::: {.column width=0.3}
+- Go back to the "Acquisition function configuration" widget
+- Click the "Run on all image groups" button again
+
+::: {.callout-note}
+Because the image group has a defined mask, samples will be extracted at random from within the masked regions only.
+::: + +::: + +::: {.column width=0.7} +``` {python} +#| echo: false +acquisition_fun_cfg_dw.raise_() +``` +``` {python} +#| echo: false +screenshot = viewer.screenshot(canvas_only=False, flash=False) + +image = Image.fromarray(screenshot) + +org_height, org_width = screenshot.shape[:2] + +roi = (org_width * 0.7, 0, org_width, org_height) + +# Crop the image +cropped_image = image.crop(roi) + +draw = ImageDraw.Draw(cropped_image) + +# Draw a red rectangle +draw.rectangle([225, 970, 385, 1005], outline="white", width=5) +draw.rectangle([225, 970, 385, 1005], outline="green", width=2) + +cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +``` +``` {python} +#| echo: false +_ = acquisition_fun_cfg.compute_acquisition_layers(run_all=True) +``` + +::: + +:::: + +## 3.3 Inspect the masked segmentation output + +``` {python} +#| echo: false +screenshot = viewer.screenshot(canvas_only=False, flash=False) + +image = Image.fromarray(screenshot) +image +``` From e40e18e910cb6eb9c7723962e449f3c02798899a Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 17 Jan 2025 16:59:44 -0500 Subject: [PATCH 08/10] Fixing multiple sampling masks active --- src/napari_activelearning/_layers.py | 33 +++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/napari_activelearning/_layers.py b/src/napari_activelearning/_layers.py index e45de19..17c40f8 100644 --- a/src/napari_activelearning/_layers.py +++ b/src/napari_activelearning/_layers.py @@ -1,7 +1,8 @@ from typing import Iterable, Union, Optional import operator +from functools import partial from pathlib import Path -from qtpy.QtCore import Qt +from qtpy.QtCore import Qt, QObject, Signal, Slot from qtpy.QtWidgets import QTreeWidgetItem import numpy as np @@ -151,6 +152,10 @@ def selected(self, is_selected: bool): viewer.layers.selection.remove(self.layer) +class LayersGroupSignals(QObject): + updated_usage = Signal((int,)) + + class LayersGroup(QTreeWidgetItem): def __init__(self, layers_group_name: str, source_axes: Optional[str] = None, @@ -171,6 +176,9 @@ def __init__(self, layers_group_name: str, super().__init__() + self.layers_group_signals = LayersGroupSignals() + self._signal_emited = False + self.layers_group_name = layers_group_name self.use_as_input_image = use_as_input_image self.use_as_input_labels = use_as_input_labels @@ -389,6 +397,12 @@ def _set_usage(self): if self._use_as_sampling_mask: use_as.append("Sampling mask") + if not self._signal_emited and self.parent() is not None: + sampling_mask_idx = self.parent().indexOfChild(self) + self._signal_emited = True + self.layers_group_signals.updated_usage.emit(sampling_mask_idx) + self._signal_emited = False + self.setText(1, "/".join(use_as)) @property @@ -700,6 +714,10 @@ def sampling_mask_layers_group(self, sampling_mask_idx: Union[int, None]): if sampling_mask_idx is not None: self.child(sampling_mask_idx).use_as_sampling_mask = True + @Slot(int) + def _update_sampling_mask_layers_group(self, sampling_mask_idx): + self.sampling_mask_layers_group = sampling_mask_idx + @property def labels_group(self): return self._labels_group @@ -725,6 +743,9 @@ def getLayersGroup(self, layers_group_name: str): def takeChild(self, index: int): child = super(ImageGroup, self).takeChild(index) if isinstance(child, LayersGroup): + child.layers_group_signals.updated_usage.disconnect( + partial(self._update_sampling_mask_layers_group, self) + ) child.takeChildren() return child @@ -733,12 +754,18 @@ def 
takeChildren(self):
         children = super(ImageGroup, self).takeChildren()
         for child in children:
             if isinstance(child, LayersGroup):
+                child.layers_group_signals.updated_usage.disconnect(
+                    partial(self._update_sampling_mask_layers_group, self)
+                )
                 child.takeChildren()
 
         return children
 
     def removeChild(self, child: QTreeWidgetItem):
         if isinstance(child, LayersGroup):
+            child.layers_group_signals.updated_usage.disconnect(
+                partial(self._update_sampling_mask_layers_group, self)
+            )
             child.takeChildren()
 
         super(ImageGroup, self).removeChild(child)
@@ -770,6 +797,10 @@ def add_layers_group(self, layers_group_name: Optional[str] = None,
 
         self.addChild(new_layers_group)
 
+        new_layers_group.layers_group_signals.updated_usage.connect(
+            partial(self._update_sampling_mask_layers_group, self)
+        )
+
         new_layers_group.setExpanded(True)
 
         return new_layers_group

From 8c57654a7743800f1e3549ddeb609a8eeaa6a131 Mon Sep 17 00:00:00 2001
From: Fernando Cervantes Sanchez
Date: Fri, 24 Jan 2025 14:13:31 -0500
Subject: [PATCH 09/10] Major improvements to interface

---
 docs/tutorials/cellpose_tutorial.qmd          | 518 +++++-
 src/napari_activelearning/_acquisition.py     |  25 +-
 src/napari_activelearning/_interface.py       |   3 +-
 src/napari_activelearning/_labels.py          |   2 +-
 src/napari_activelearning/_layers.py          |  63 +-
 src/napari_activelearning/_models_impl.py     |  13 +-
 .../_models_impl_interface.py                 |   5 +-
 src/napari_activelearning/_tests/conftest.py  |  13 +-
 .../_tests/test_layers.py                     |  20 +-
 9 files changed, 528 insertions(+), 134 deletions(-)

diff --git a/docs/tutorials/cellpose_tutorial.qmd b/docs/tutorials/cellpose_tutorial.qmd
index d5561d9..2ca5a05 100644
--- a/docs/tutorials/cellpose_tutorial.qmd
+++ b/docs/tutorials/cellpose_tutorial.qmd
@@ -8,13 +8,6 @@ format:
   revealjs:
     code-fold: false
     progress: true
     controls: true
     fontsize: 22pt
 
-  ipynb:
-    roc: true
-
-execute:
-  echo:
-    true
-
 jupyter: python3
 ---
@@ -30,10 +23,11 @@ File > Open Sample > napari builtins > Cells (3D+2Ch)
 ```
 
 ``` {python}
 #| echo: false
-from PIL import Image, ImageDraw, ImageFont
 import napari
 from napari.utils import nbscreenshot
+from PIL import Image, ImageDraw, ImageFont
 import napari_activelearning as al
+import zarr
 ```
 ``` {python}
 #| echo: false
@@ -167,11 +161,7 @@ nbscreenshot(viewer)
 This specifies that the samples will be taken from those axes.
 :::
 
-2. Change the "Model axes" to "YXC"
-
-::: {.callout-note}
-The model axes are the axes on which the segmentation model operates. `Cellpose` (for 2D images) expects the input image to have "Y" and "X" spatial axes, and a trailing channel axis "C".
`Cel #| echo: false acquisition_fun_cfg.input_axes_le.setText("ZYX") acquisition_fun_cfg._set_input_axes() -acquisition_fun_cfg.model_axes_le.setText("YXC") +acquisition_fun_cfg.model_axes_le.setText("CYX") acquisition_fun_cfg._set_model_axes() ``` ``` {python} @@ -199,19 +189,19 @@ cropped_image = image.crop(roi) draw = ImageDraw.Draw(cropped_image) -draw.rectangle([60, 435, 405, 475], outline="white", width=5) -draw.rectangle([60, 435, 405, 475], outline="green", width=2) +draw.rectangle([60, 390, 405, 425], outline="white", width=5) +draw.rectangle([60, 390, 405, 425], outline="green", width=2) -position = (250, 395) +position = (250, 345) font = ImageFont.truetype("arialbd.ttf", size=36) draw.text(position, "1", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "1", fill="white", font=font) -draw.rectangle([405, 435, 570, 475], outline="white", width=5) -draw.rectangle([405, 435, 570, 475], outline="green", width=2) +draw.rectangle([405, 390, 570, 425], outline="white", width=5) +draw.rectangle([405, 390, 570, 425], outline="green", width=2) -position = (490, 395) +position = (490, 345) font = ImageFont.truetype("arialbd.ttf", size=36) draw.text(position, "2", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) @@ -253,8 +243,8 @@ cropped_image = image.crop(roi) draw = ImageDraw.Draw(cropped_image) # Draw a red rectangle -draw.rectangle([70, 270, 470, 700], outline="white", width=5) -draw.rectangle([70, 270, 470, 700], outline="green", width=2) +draw.rectangle([70, 240, 470, 600], outline="white", width=5) +draw.rectangle([70, 240, 470, 600], outline="green", width=2) cropped_image ``` @@ -300,8 +290,8 @@ cropped_image = image.crop(roi) draw = ImageDraw.Draw(cropped_image) # Draw a red rectangle -draw.rectangle([60, 855, 405, 890], outline="white", width=5) -draw.rectangle([60, 855, 405, 890], outline="green", width=2) +draw.rectangle([60, 740, 405, 770], outline="white", width=5) +draw.rectangle([60, 740, 405, 770], outline="green", width=2) cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) ``` @@ -326,16 +316,24 @@ This tells the Active Learning plugin to process at most _four_ samples at rando ::: {.column width=0.3} -1. Use the dropdown with label "None selected" to select the `cellpose` method +1. Use the "Model" dropdown to select the `cellpose` method 2. Click the "Advanced segmentation parameters" checkbox -3. Change the second channel to 1 (the right spin box in the "channels" row) +3. Change the "Channel axis" to $0$ +::: {.callout-note} +This makes `cellpose` to use the first axis as "Color" channel. +::: + +4. Change the second channel to $1$ (the right spin box in the "channels" row) ::: {.callout-note} This tells `cellpose` to segment the first channel ($0$) and use the second channel ($1$) as help channel. ::: +5. 
Choose the "nuclei" model from the dropdown + + ::: ::: {.column width=0.7} @@ -343,7 +341,9 @@ This tells `cellpose` to segment the first channel ($0$) and use the second chan #| echo: false acquisition_fun_cfg.methods_cmb.setCurrentIndex(2) acquisition_fun_cfg.tunable_segmentation_method.advanced_segmentation_options_chk.setChecked(True) +acquisition_fun_cfg.tunable_segmentation_method._segmentation_parameters.channel_axis.value = 0 acquisition_fun_cfg.tunable_segmentation_method._segmentation_parameters.channels.value = (0, 1) +acquisition_fun_cfg.tunable_segmentation_method._segmentation_parameters.model_type.value = "nuclei" ``` ``` {python} #| echo: false @@ -361,33 +361,51 @@ cropped_image = image.crop(roi) draw = ImageDraw.Draw(cropped_image) # Draw a red rectangle -draw.rectangle([55, 550, 570, 580], outline="white", width=5) -draw.rectangle([55, 550, 570, 580], outline="green", width=2) +draw.rectangle([55, 490, 570, 520], outline="white", width=5) +draw.rectangle([55, 490, 570, 520], outline="green", width=2) -position = (550, 580) +position = (550, 450) font = ImageFont.truetype("arialbd.ttf", size=36) draw.text(position, "1", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "1", fill="white", font=font) -draw.rectangle([65, 585, 310, 620], outline="white", width=5) -draw.rectangle([65, 585, 310, 620], outline="green", width=2) +draw.rectangle([65, 525, 310, 560], outline="white", width=5) +draw.rectangle([65, 525, 310, 560], outline="green", width=2) -position = (320, 580) +position = (320, 525) font = ImageFont.truetype("arialbd.ttf", size=36) draw.text(position, "2", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "2", fill="white", font=font) -draw.rectangle([390, 750, 550, 780], outline="white", width=5) -draw.rectangle([390, 750, 550, 780], outline="green", width=2) +draw.rectangle([390, 655, 550, 690], outline="white", width=5) +draw.rectangle([390, 655, 550, 690], outline="green", width=2) -position = (450, 710) +position = (400, 545) font = ImageFont.truetype("arialbd.ttf", size=36) draw.text(position, "3", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "3", fill="white", font=font) +draw.rectangle([235, 580, 550, 615], outline="white", width=5) +draw.rectangle([235, 580, 550, 615], outline="green", width=2) + +position = (450, 615) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "4", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "4", fill="white", font=font) + +draw.rectangle([70, 725, 550, 765], outline="white", width=5) +draw.rectangle([70, 725, 550, 765], outline="green", width=2) + +position = (450, 770) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "5", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "5", fill="white", font=font) + cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) ``` ::: @@ -424,8 +442,8 @@ cropped_image = image.crop(roi) draw = ImageDraw.Draw(cropped_image) # Draw a red rectangle -draw.rectangle([225, 970, 385, 1005], outline="white", width=5) -draw.rectangle([225, 970, 385, 1005], outline="green", width=2) +draw.rectangle([245, 850, 405, 880], outline="white", width=5) +draw.rectangle([245, 850, 405, 880], outline="green", width=2) cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), 
Image.Resampling.LANCZOS)
 ```
@@ -498,8 +516,8 @@ cropped_image = image.crop(roi)
 
 draw = ImageDraw.Draw(cropped_image)
 
 # Draw a red rectangle
-draw.rectangle([45, 375, 220, 410], outline="white", width=5)
-draw.rectangle([45, 375, 220, 410], outline="green", width=2)
+draw.rectangle([65, 380, 225, 415], outline="white", width=5)
+draw.rectangle([65, 380, 225, 415], outline="green", width=2)
 
 cropped_image
 ```
@@ -511,7 +529,9 @@ image_groups_mgr.mask_generator._show_editor(True)
 :::
 
 ::::
 
-### 3.1.2 Create a low-resolution mask for the associated image
+---
+
+### 3.1.2 Create a low-resolution mask for its corresponding image
 
 :::: {.columns}
 :::{.column width=0.3}
@@ -548,8 +568,8 @@ cropped_image = image.crop(roi)
 
 draw = ImageDraw.Draw(cropped_image)
 
 # Draw a red rectangle
-draw.rectangle([60, 450, 550, 640], outline="white", width=5)
-draw.rectangle([60, 450, 550, 640], outline="green", width=2)
+draw.rectangle([60, 450, 550, 600], outline="white", width=5)
+draw.rectangle([60, 450, 550, 600], outline="green", width=2)
 
 position = (500, 400)
 font = ImageFont.truetype("arialbd.ttf", size=36)
 draw.text(position, "1", fill="green", font=font)
 font = ImageFont.truetype("arial.ttf", size=36)
 draw.text(position, "1", fill="white", font=font)
 
-draw.rectangle([50, 675, 560, 710], outline="white", width=5)
-draw.rectangle([50, 675, 560, 710], outline="green", width=2)
+draw.rectangle([70, 620, 560, 650], outline="white", width=5)
+draw.rectangle([70, 620, 560, 650], outline="green", width=2)
 
-position = (530, 640)
+position = (530, 650)
 font = ImageFont.truetype("arialbd.ttf", size=36)
 draw.text(position, "2", fill="green", font=font)
 font = ImageFont.truetype("arial.ttf", size=36)
 draw.text(position, "2", fill="white", font=font)
@@ -632,8 +652,8 @@ cropped_image = image.crop(roi)
 
 draw = ImageDraw.Draw(cropped_image)
 
 # Draw a red rectangle
-draw.rectangle([225, 970, 385, 1005], outline="white", width=5)
-draw.rectangle([225, 970, 385, 1005], outline="green", width=2)
+draw.rectangle([245, 850, 400, 890], outline="white", width=5)
+draw.rectangle([245, 850, 400, 890], outline="green", width=2)
 
 cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
 ```
@@ -655,3 +675,407 @@ screenshot = viewer.screenshot(canvas_only=False, flash=False)
 image = Image.fromarray(screenshot)
 image
 ```
+
+# 4 Fine tune the segmentation model
+
+## 4.1 Add the _Label groups manager_ widget to napari's window
+
+You can find the _Label groups manager_ under the _Active Learning_ plugin in napari's plugins menu.
+
+```
+Plugins > Active Learning > Label groups manager
+```
+
+``` {python}
+#| echo: false
+labels_mgr_dw = viewer.window.add_dock_widget(labels_mgr)
+viewer.window._qt_window.tabifyDockWidget(acquisition_fun_cfg_dw, labels_mgr_dw)
+
+labels_mgr_dw.setWindowTitle("Label groups manager")
+```
+``` {python}
+#| echo: false
+labels_mgr_dw.raise_()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+image
+```
+
+
+## 4.2 Edit segmented patches
+
+### 4.2.1 Select a segmented patch to edit
+
+- You can double-click on any segmented patch in the viewer (e.g., on slice $27$)
+
+``` {python}
+#| echo: false
+labels_mgr.labels_table_tw.setCurrentItem(labels_mgr.labels_group_root.child(1).child(3), 0)
+labels_mgr.focus_region(
+    labels_mgr.labels_group_root.child(1).child(3),
+    edit_focused_label=True
+)
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+image
+```
+
+---
+
+### 4.2.2 Use napari's layer controls to edit the objects in the current patch
+
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+draw = ImageDraw.Draw(image)
+
+# Draw a red rectangle
+draw.rectangle([5, 20, 280, 450], outline="white", width=5)
+draw.rectangle([5, 20, 280, 450], outline="green", width=2)
+
+image
+```
+
+---
+
+### 4.2.3 Commit changes to the labels layer
+
+:::: {.columns}
+::: {.column width=0.3}
+
+- Once you have finished editing the labels, click the "Commit changes" button on the _Label groups manager_
+:::
+
+::: {.column width=0.7}
+``` {python}
+#| echo: false
+z_tmp = zarr.open(r"C:\Users\cervaf\Documents\Logging\activelearning_logs\membrane.zarr\segmentation\0", mode="r")
+```
+``` {python}
+#| echo: false
+viewer.layers["Labels edit"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(3)._position]
+labels_mgr.commit()
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = screenshot.shape[:2]
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([305, 900, 570, 940], outline="white", width=5)
+draw.rectangle([305, 900, 570, 940], outline="green", width=2)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+:::
+::::
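+
+---
+
+The segmentation shown here is backed by a Zarr array (as the cells above suggest), so committed labels can also be inspected outside napari (a sketch; the path below is hypothetical and depends on where your image group stores its segmentation layer):
+
+```python
+import zarr
+
+# Hypothetical location of a segmentation layer written by the plugin.
+seg = zarr.open("path/to/image_group.zarr/segmentation/0", mode="r")
+print(seg.shape, seg.dtype)
+```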
+
+## 4.3 Navigate between segmented patches
+
+:::: {.columns}
+::: {.column width=0.3}
+
+1. Expand the second group of labels
+
+2. Double-click on any of the nested items to open it for editing
+
+3. Use the navigation buttons to move between segmented patches
+
+4. Continue editing the segmentation in the current patch and commit the changes when finished
+:::
+
+::: {.column width=0.7}
+
+``` {python}
+#| echo: false
+labels_mgr.labels_table_tw.expandItem(labels_mgr.labels_group_root.child(1))
+labels_mgr.labels_table_tw.setCurrentItem(labels_mgr.labels_group_root.child(1).child(0), 0)
+labels_mgr.focus_region(
+    labels_mgr.labels_group_root.child(1).child(0),
+    edit_focused_label=True
+)
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = image.height, image.width
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([70, 115, 270, 140], outline="white", width=5)
+draw.rectangle([70, 115, 270, 140], outline="green", width=2)
+
+position = (50, 105)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "1", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "1", fill="white", font=font)
+
+draw.rectangle([110, 135, 560, 215], outline="white", width=5)
+draw.rectangle([110, 135, 560, 215], outline="green", width=2)
+
+position = (80, 150)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "2", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "2", fill="white", font=font)
+
+draw.rectangle([60, 870, 570, 905], outline="white", width=5)
+draw.rectangle([60, 870, 570, 905], outline="green", width=2)
+
+position = (290, 820)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "3", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "3", fill="white", font=font)
+
+# Draw a red rectangle
+draw.rectangle([305, 900, 570, 940], outline="white", width=5)
+draw.rectangle([305, 900, 570, 940], outline="green", width=2)
+
+position = (270, 905)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "4", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "4", fill="white", font=font)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+``` {python}
+#| echo: false
+for c_idx in range(labels_mgr.labels_group_root.child(1).childCount()):
+    labels_mgr.focus_region(
+        labels_mgr.labels_group_root.child(1).child(c_idx),
+        edit_focused_label=True
+    )
+    viewer.layers["Labels edit"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(c_idx)._position]
+    labels_mgr.commit()
+
+```
+
+:::
+
+::::
+
+## 4.4 Set up the fine-tuning configuration
+
+### 4.4.1 Use the _Acquisition function configuration_ widget to configure the fine-tuning process
+
+1. Go to the "Acquisition function configuration" widget
+
+2. Click the "Advanced fine tuning parameters" checkbox
+
+3. Change the "save path" to a location where you want to store the fine-tuned model
Change the "save path" to a location where you want to store the fine tuned model + +``` {python} +#| echo: false +acquisition_fun_cfg_dw.raise_() +``` +``` {python} +#| echo: false +acquisition_fun_cfg.tunable_segmentation_method.advanced_finetuning_options_chk.setChecked(True) +acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.save_path.value = "C:/Users/Public/Documents/models" +``` +``` {python} +#| echo: false +screenshot = viewer.screenshot(canvas_only=False, flash=False) + +image = Image.fromarray(screenshot) + +org_height, org_width = image.height, image.width + +roi = (org_width * 0.7, 0, org_width, org_height) + +# Crop the image +cropped_image = image.crop(roi) + +draw = ImageDraw.Draw(cropped_image) + +# Draw a red rectangle +draw.rectangle([185, 965, 385, 1000], outline="white", width=5) +draw.rectangle([185, 965, 385, 1000], outline="green", width=2) + +position = (200, 920) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "1", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "1", fill="white", font=font) + +draw.rectangle([75, 690, 540, 725], outline="white", width=5) +draw.rectangle([75, 690, 540, 725], outline="green", width=2) + +position = (40, 490) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "2", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "2", fill="white", font=font) + +draw.rectangle([65, 495, 300, 535], outline="white", width=5) +draw.rectangle([65, 495, 300, 535], outline="green", width=2) + +position = (50, 690) +font = ImageFont.truetype("arialbd.ttf", size=36) +draw.text(position, "3", fill="green", font=font) +font = ImageFont.truetype("arial.ttf", size=36) +draw.text(position, "3", fill="white", font=font) + +cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +``` + +--- + +### 4.4.2 Set the *learning rate* and *batch size* + +1. Scroll the _Advanced fine tuning parameters_ widget down to show more parameters + +3. Change the "model name" to "nuclei_ft" + +3. Set the "batch size" to $3$ + +4. Change the "learning rate" to $0.0001$ + +::: {.callout-note} +You can modify other parameters for the training process here, such as the number of training epochs. 
+:::
+
+``` {python}
+#| echo: false
+vertical_scroll_bar = acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters_scr.verticalScrollBar()
+vertical_scroll_bar.setValue(vertical_scroll_bar.maximum())
+
+acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.model_name.value = "nuclei_ft"
+acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.batch_size.value = 3
+acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.learning_rate.value = 0.0001
+```
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = image.height, image.width
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([535, 525, 565, 845], outline="white", width=5)
+draw.rectangle([535, 525, 565, 845], outline="green", width=2)
+
+position = (515, 480)
+font = ImageFont.truetype("arialbd.ttf", size=36)
+draw.text(position, "1", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=36)
+draw.text(position, "1", fill="white", font=font)
+
+draw.rectangle([50, 725, 540, 755], outline="white", width=5)
+draw.rectangle([50, 725, 540, 755], outline="green", width=2)
+
+position = (20, 720)
+font = ImageFont.truetype("arialbd.ttf", size=34)
+draw.text(position, "2", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=34)
+draw.text(position, "2", fill="white", font=font)
+
+draw.rectangle([50, 755, 540, 780], outline="white", width=5)
+draw.rectangle([50, 755, 540, 780], outline="green", width=2)
+
+position = (20, 750)
+font = ImageFont.truetype("arialbd.ttf", size=34)
+draw.text(position, "3", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=34)
+draw.text(position, "3", fill="white", font=font)
+
+draw.rectangle([50, 780, 540, 810], outline="white", width=5)
+draw.rectangle([50, 780, 540, 810], outline="green", width=2)
+
+position = (20, 780)
+font = ImageFont.truetype("arialbd.ttf", size=34)
+draw.text(position, "4", fill="green", font=font)
+font = ImageFont.truetype("arial.ttf", size=34)
+draw.text(position, "4", fill="white", font=font)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+
+## 4.5 Execute the fine-tuning process
+
+- Click the "Fine tune model" button to run the training process.
+
+::: {.callout-note}
+Depending on your computer's resources (RAM, CPU), this process might take several minutes to complete. With a dedicated GPU device, it can instead finish in a matter of seconds.
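+:::
+
+Once training finishes, the fine-tuned weights are stored under the configured save path. As a sketch of how they could later be reloaded directly in Cellpose (assuming a `cellpose` 2.x install; the exact file name under the save path depends on how the training backend names it):
+
+```python
+from cellpose import models
+
+# Hypothetical path, built from the save path and model name set above.
+ft_model = models.CellposeModel(
+    pretrained_model="C:/Users/Public/Documents/models/nuclei_ft"
+)
+```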
+
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+
+image = Image.fromarray(screenshot)
+
+org_height, org_width = image.height, image.width
+
+roi = (org_width * 0.7, 0, org_width, org_height)
+
+# Crop the image
+cropped_image = image.crop(roi)
+
+draw = ImageDraw.Draw(cropped_image)
+
+# Draw a red rectangle
+draw.rectangle([245, 880, 405, 920], outline="white", width=5)
+draw.rectangle([245, 880, 405, 920], outline="green", width=2)
+
+cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)
+```
+``` {python}
+#| echo: false
+acquisition_fun_cfg.fine_tune()
+```
+
+## 4.6 Review the fine-tuned segmentation
+
+``` {python}
+#| echo: false
+screenshot = viewer.screenshot(canvas_only=False, flash=False)
+image = Image.fromarray(screenshot)
+image
+```
\ No newline at end of file
Commit changes to the labels layer\n", + "\n", + ":::: {.columns}\n", + "::: {.column width=0.3}\n", + "\n", + "- Once you have finished editing the labels, click the \"Commit changes\" button on the _Label groups manager_\n", + ":::\n", + "\n", + "::: {.column width=0.7}" + ], + "id": "336a7cea" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "z_tmp = zarr.open(r\"C:\\Users\\cervaf\\Documents\\Logging\\activelearning_logs\\membrane.zarr\\segmentation\\0\", mode=\"r\")" + ], + "id": "8a16faa5", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "labels_mgr.labels_group_root.child(1).child(3)._position\n", + "viewer.layers[\"Labels edit\"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(3)._position]\n", + "labels_mgr.commit()" + ], + "id": "6934c785", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "\n", + "image = Image.fromarray(screenshot)\n", + "\n", + "org_height, org_width = screenshot.shape[:2]\n", + "\n", + "roi = (org_width * 0.7, 0, org_width, org_height)\n", + "\n", + "# Crop the image\n", + "cropped_image = image.crop(roi)\n", + "\n", + "draw = ImageDraw.Draw(cropped_image)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([305, 900, 570, 940], outline=\"white\", width=5)\n", + "draw.rectangle([305, 900, 570, 940], outline=\"green\", width=2)\n", + "\n", + "cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)" + ], + "id": "b55d1c5a", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::\n", + "::::\n", + "\n", + "## 4.3 Navigate between segmented patches\n", + "\n", + ":::: {.columns}\n", + "::: {.column width=0.3}\n", + "\n", + "1. Expand the second group of labels \n", + "\n", + "2. Double-click on any of the nested items to open it for editing\n", + "\n", + "3. Use the navigation buttons to move between segmented patches\n", + "\n", + "4. 
Continue editing the segmentation in the current patch and commit the changes when finish\n", + ":::\n", + "\n", + "::: {.column width=0.7}\n" + ], + "id": "d13279b2" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "labels_mgr.labels_table_tw.expandItem(labels_mgr.labels_group_root.child(1))\n", + "labels_mgr.labels_table_tw.setCurrentItem(labels_mgr.labels_group_root.child(1).child(0), 0)\n", + "labels_mgr.focus_region(\n", + " labels_mgr.labels_group_root.child(1).child(0),\n", + " edit_focused_label=True\n", + ")" + ], + "id": "98650afb", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "\n", + "image = Image.fromarray(screenshot)\n", + "\n", + "org_height, org_width = image.height, image.width\n", + "\n", + "roi = (org_width * 0.7, 0, org_width, org_height)\n", + "\n", + "# Crop the image\n", + "cropped_image = image.crop(roi)\n", + "\n", + "draw = ImageDraw.Draw(cropped_image)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([70, 115, 270, 140], outline=\"white\", width=5)\n", + "draw.rectangle([70, 115, 270, 140], outline=\"green\", width=2)\n", + "\n", + "position = (50, 105)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([110, 135, 560, 215], outline=\"white\", width=5)\n", + "draw.rectangle([110, 135, 560, 215], outline=\"green\", width=2)\n", + "\n", + "position = (80, 150)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"2\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"2\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([60, 870, 570, 905], outline=\"white\", width=5)\n", + "draw.rectangle([60, 870, 570, 905], outline=\"green\", width=2)\n", + "\n", + "position = (290, 820)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"3\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"3\", fill=\"white\", font=font)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([305, 900, 570, 940], outline=\"white\", width=5)\n", + "draw.rectangle([305, 900, 570, 940], outline=\"green\", width=2)\n", + "\n", + "position = (270, 905)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"4\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"4\", fill=\"white\", font=font)\n", + "\n", + "cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)" + ], + "id": "d53dfb99", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "for c_idx in range(labels_mgr.labels_group_root.child(1).childCount()):\n", + " labels_mgr.focus_region(\n", + " labels_mgr.labels_group_root.child(1).child(c_idx),\n", + " edit_focused_label=True\n", + " )\n", + " labels_mgr.labels_group_root.child(1).child(c_idx)._position\n", + " viewer.layers[\"Labels edit\"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(c_idx)._position]\n", + " labels_mgr.commit()" + ], + "id": 
"dd3f78e2", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::\n", + "\n", + "::::\n", + "\n", + "## 4.4 Setup fine tuning configuration\n", + "\n", + "### 4.4.1 Use the _Acquisition function configuration_ widget to set the configuration for executing the fine tuning process\n", + "\n", + "1. Go to the \"Acquisition function configuration\" widget\n", + "\n", + "2. Click the \"Advanced fine tuning parameters\" checkbox\n", + "\n", + "3. Change the \"save path\" to a location where you want to store the fine tuned model\n" + ], + "id": "572025f2" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "acquisition_fun_cfg_dw.raise_()" + ], + "id": "00c7dbdd", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "acquisition_fun_cfg.tunable_segmentation_method.advanced_finetuning_options_chk.setChecked(True)\n", + "acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.save_path.value = \"C:/Users/Public/Documents/models\"" + ], + "id": "4bd127a4", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "\n", + "image = Image.fromarray(screenshot)\n", + "\n", + "org_height, org_width = image.height, image.width\n", + "\n", + "roi = (org_width * 0.7, 0, org_width, org_height)\n", + "\n", + "# Crop the image\n", + "cropped_image = image.crop(roi)\n", + "\n", + "draw = ImageDraw.Draw(cropped_image)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([185, 965, 385, 1000], outline=\"white\", width=5)\n", + "draw.rectangle([185, 965, 385, 1000], outline=\"green\", width=2)\n", + "\n", + "position = (200, 920)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([75, 690, 540, 725], outline=\"white\", width=5)\n", + "draw.rectangle([75, 690, 540, 725], outline=\"green\", width=2)\n", + "\n", + "position = (40, 490)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"2\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"2\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([65, 495, 300, 535], outline=\"white\", width=5)\n", + "draw.rectangle([65, 495, 300, 535], outline=\"green\", width=2)\n", + "\n", + "position = (50, 690)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"3\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"3\", fill=\"white\", font=font)\n", + "\n", + "cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)" + ], + "id": "b937b746", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "### 4.4.2 Set the *learning rate* and *batch size*\n", + "\n", + "1. Scroll the _Advanced fine tuning parameters_ widget down to show more parameters\n", + "\n", + "3. Change the \"model name\" to \"nuclei_ft\"\n", + "\n", + "3. Set the \"batch size\" to $3$\n", + "\n", + "4. 
Change the \"learning rate\" to $0.0001$\n", + "\n", + "::: {.callout-note}\n", + "You can modify other parameters for the training process here, such as the number of training epochs.\n", + ":::\n" + ], + "id": "8bb63596" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "vertical_scroll_bar = acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters_scr.verticalScrollBar()\n", + "vertical_scroll_bar.setValue(vertical_scroll_bar.maximum())\n", + " \n", + "acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.model_name.value = \"nuclei_ft\"\n", + "acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.batch_size.value = 3\n", + "acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters.learning_rate.value = 0.0001" + ], + "id": "b7a06529", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "\n", + "image = Image.fromarray(screenshot)\n", + "\n", + "org_height, org_width = image.height, image.width\n", + "\n", + "roi = (org_width * 0.7, 0, org_width, org_height)\n", + "\n", + "# Crop the image\n", + "cropped_image = image.crop(roi)\n", + "\n", + "draw = ImageDraw.Draw(cropped_image)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([535, 525, 565, 845], outline=\"white\", width=5)\n", + "draw.rectangle([535, 525, 565, 845], outline=\"green\", width=2)\n", + "\n", + "position = (515, 480)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=36)\n", + "draw.text(position, \"1\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([50, 725, 540, 755], outline=\"white\", width=5)\n", + "draw.rectangle([50, 725, 540, 755], outline=\"green\", width=2)\n", + "\n", + "position = (20, 720)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=34)\n", + "draw.text(position, \"2\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=34)\n", + "draw.text(position, \"2\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([50, 755, 540, 780], outline=\"white\", width=5)\n", + "draw.rectangle([50, 755, 540, 780], outline=\"green\", width=2)\n", + "\n", + "position = (20, 750)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=34)\n", + "draw.text(position, \"3\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=34)\n", + "draw.text(position, \"3\", fill=\"white\", font=font)\n", + "\n", + "draw.rectangle([50, 780, 540, 810], outline=\"white\", width=5)\n", + "draw.rectangle([50, 780, 540, 810], outline=\"green\", width=2)\n", + "\n", + "position = (20, 780)\n", + "font = ImageFont.truetype(\"arialbd.ttf\", size=34)\n", + "draw.text(position, \"4\", fill=\"green\", font=font)\n", + "font = ImageFont.truetype(\"arial.ttf\", size=34)\n", + "draw.text(position, \"4\", fill=\"white\", font=font)\n", + "\n", + "cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)" + ], + "id": "3d266f79", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4.5 Execute the fine tuning process\n", + "\n", + "- Click the \"Fine tune model\" button to run the training process.\n", + "\n", + "::: {.callout-note}\n", + "Depending on your computer resources (RAM, CPU), this process 
might take some minutes to complete. If you have a dedicated GPU device, this can take a couple of seconds instead.\n", + ":::\n" + ], + "id": "1f2c277e" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "\n", + "image = Image.fromarray(screenshot)\n", + "\n", + "org_height, org_width = image.height, image.width\n", + "\n", + "roi = (org_width * 0.7, 0, org_width, org_height)\n", + "\n", + "# Crop the image\n", + "cropped_image = image.crop(roi)\n", + "\n", + "draw = ImageDraw.Draw(cropped_image)\n", + "\n", + "# Draw a red rectangle\n", + "draw.rectangle([245, 880, 405, 920], outline=\"white\", width=5)\n", + "draw.rectangle([245, 880, 405, 920], outline=\"green\", width=2)\n", + "\n", + "cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS)" + ], + "id": "b6826668", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "acquisition_fun_cfg.fine_tune()" + ], + "id": "213cb5cb", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4.6 Review the fine tuned segmentation\n" + ], + "id": "3ee08673" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "#| echo: false\n", + "screenshot = viewer.screenshot(canvas_only=False, flash=False)\n", + "image = Image.fromarray(screenshot)\n", + "image" + ], + "id": "aec994c5", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python3", + "language": "python", + "display_name": "Python 3 (ipykernel)", + "path": "C:\\Users\\cervaf\\AppData\\Local\\miniforge3\\envs\\activelearning\\share\\jupyter\\kernels\\python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index f098223..2b7f5be 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -137,19 +137,14 @@ def add_multiscale_output_layer( if isinstance(new_output_layer, list): new_output_layer = new_output_layer[0] - output_layers_group = image_group.getLayersGroup( - layers_group_name + output_layers_group = image_group.add_layers_group( + layers_group_name, + source_axes=axes, + use_as_input_image=False, + use_as_input_labels=use_as_input_labels, + use_as_sampling_mask=use_as_sampling_mask ) - if output_layers_group is None: - output_layers_group = image_group.add_layers_group( - layers_group_name, - source_axes=axes, - use_as_input_image=False, - use_as_input_labels=use_as_input_labels, - use_as_sampling_mask=use_as_sampling_mask - ) - output_channel = output_layers_group.add_layer( new_output_layer ) @@ -281,8 +276,6 @@ def _get_transform(self): raise NotImplementedError("This method requies to be overriden by a " "derived class.") - # def _fine_tune(self, dataset_metadata_list: Iterable[dict], - # train_data_proportion: float = 0.8) -> bool: def _fine_tune(self, train_dataloader, val_dataloader) -> bool: raise NotImplementedError("This method requies to be overriden by a " "derived class.") @@ -364,7 +357,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[dict], val_datasets = [] training_indices = np.random.choice( - len(dataset_metadata_list), + len(dataset_metadata_list), int(train_data_proportion * len(dataset_metadata_list)) ).tolist() @@ -852,8 +845,8 @@ def 
compute_acquisition_layers( ) if (not segmentation_only - and image_group is not None - and image_group.labels_group is None): + and image_group is not None): + # and image_group.labels_group is None): new_label_group = self.labels_manager.add_labels( segmentation_channel, img_sampling_positions diff --git a/src/napari_activelearning/_interface.py b/src/napari_activelearning/_interface.py index 6d210ce..33e7f83 100644 --- a/src/napari_activelearning/_interface.py +++ b/src/napari_activelearning/_interface.py @@ -902,7 +902,8 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, acquisition_lyt.addWidget(self.input_axes_le, 4, 1) acquisition_lyt.addWidget(QLabel("Model axes"), 4, 2) acquisition_lyt.addWidget(self.model_axes_le, 4, 3) - acquisition_lyt.addWidget(self.methods_cmb, 5, 0, 1, 4) + acquisition_lyt.addWidget(QLabel("Model"), 5, 0) + acquisition_lyt.addWidget(self.methods_cmb, 5, 1, 1, 3) acquisition_lyt.addWidget(self.execute_selected_btn, 7, 0) acquisition_lyt.addWidget(self.execute_all_btn, 7, 1) acquisition_lyt.addWidget(self.finetuning_btn, 8, 1) diff --git a/src/napari_activelearning/_labels.py b/src/napari_activelearning/_labels.py index 8328dba..97c9a45 100644 --- a/src/napari_activelearning/_labels.py +++ b/src/napari_activelearning/_labels.py @@ -384,7 +384,7 @@ def focus_and_edit_region(self, layer, event): []): for label in map(lambda idx: label_group.child(idx), range(label_group.childCount())): - if all(ax_pos.start <= ax_coord <= ax_pos.stop + if all(ax_pos.start <= ax_coord < ax_pos.stop for ax_pos, ax_coord in zip(label.position, curr_pos)): clicked_label = label break diff --git a/src/napari_activelearning/_layers.py b/src/napari_activelearning/_layers.py index 17c40f8..e7f2168 100644 --- a/src/napari_activelearning/_layers.py +++ b/src/napari_activelearning/_layers.py @@ -152,17 +152,8 @@ def selected(self, is_selected: bool): viewer.layers.selection.remove(self.layer) -class LayersGroupSignals(QObject): - updated_usage = Signal((int,)) - - class LayersGroup(QTreeWidgetItem): - def __init__(self, layers_group_name: str, - source_axes: Optional[str] = None, - use_as_input_image: Optional[bool] = False, - use_as_input_labels: Optional[bool] = False, - use_as_sampling_mask: Optional[bool] = False - ): + def __init__(self): self._layers_group_name = None self._use_as_input_image = False @@ -176,15 +167,8 @@ def __init__(self, layers_group_name: str, super().__init__() - self.layers_group_signals = LayersGroupSignals() self._signal_emited = False - self.layers_group_name = layers_group_name - self.use_as_input_image = use_as_input_image - self.use_as_input_labels = use_as_input_labels - self.use_as_sampling_mask = use_as_sampling_mask - self.source_axes = source_axes - self.updated = True def _update_source_axes(self): @@ -390,17 +374,27 @@ def _set_usage(self): use_as = [] if self._use_as_input_image: use_as.append("Input") + if not self._signal_emited and self.parent() is not None: + current_idx = self.parent().indexOfChild(self) + self._signal_emited = True + self.parent().input_layers_group = current_idx + self._signal_emited = False if self._use_as_input_labels: use_as.append("Labels") + if not self._signal_emited and self.parent() is not None: + current_idx = self.parent().indexOfChild(self) + self._signal_emited = True + self.parent().labels_layers_group = current_idx + self._signal_emited = False if self._use_as_sampling_mask: use_as.append("Sampling mask") if not self._signal_emited and self.parent() is not None: - sampling_mask_idx = 
self.parent().indexOfChild(self) + current_idx = self.parent().indexOfChild(self) self._signal_emited = True - self.layers_group_signals.updated_usage.emit(sampling_mask_idx) + self.parent().sampling_mask_layers_group = current_idx self._signal_emited = False self.setText(1, "/".join(use_as)) @@ -714,10 +708,6 @@ def sampling_mask_layers_group(self, sampling_mask_idx: Union[int, None]): if sampling_mask_idx is not None: self.child(sampling_mask_idx).use_as_sampling_mask = True - @Slot(int) - def _update_sampling_mask_layers_group(self, sampling_mask_idx): - self.sampling_mask_layers_group = sampling_mask_idx - @property def labels_group(self): return self._labels_group @@ -743,9 +733,6 @@ def getLayersGroup(self, layers_group_name: str): def takeChild(self, index: int): child = super(ImageGroup, self).takeChild(index) if isinstance(child, LayersGroup): - child.layers_group_signals.updated_usage.disconnect( - partial(self._update_sampling_mask_layers_group, self) - ) child.takeChildren() return child @@ -754,18 +741,12 @@ def takeChildren(self): children = super(ImageGroup, self).takeChildren() for child in children: if isinstance(child, LayersGroup): - child.layers_group_signals.updated_usage.disconnect( - partial(self._update_sampling_mask_layers_group, self) - ) child.takeChildren() return children def removeChild(self, child: QTreeWidgetItem): if isinstance(child, LayersGroup): - child.layers_group_signals.updated_usage.disconnect( - partial(self._update_sampling_mask_layers_group, self) - ) child.takeChildren() super(ImageGroup, self).removeChild(child) @@ -787,19 +768,15 @@ def add_layers_group(self, layers_group_name: Optional[str] = None, elif use_as_sampling_mask: layers_group_name = "masks" - new_layers_group = LayersGroup( - layers_group_name, - source_axes=source_axes, - use_as_input_image=use_as_input_image, - use_as_input_labels=use_as_input_labels, - use_as_sampling_mask=use_as_sampling_mask - ) + new_layers_group = LayersGroup() self.addChild(new_layers_group) + new_layers_group.layers_group_name = layers_group_name + new_layers_group.source_axes = source_axes - new_layers_group.layers_group_signals.updated_usage.connect( - partial(self._update_sampling_mask_layers_group, self) - ) + new_layers_group.use_as_input_image = use_as_input_image + new_layers_group.use_as_input_labels = use_as_input_labels + new_layers_group.use_as_sampling_mask = use_as_sampling_mask new_layers_group.setExpanded(True) @@ -1037,7 +1014,7 @@ def update_channels(self, channel: Optional[int] = None): if not self._active_layer_channel or not self._active_layers_group: return - if channel: + if channel is not None: self._edit_channel = channel prev_channel = self._active_layer_channel.channel diff --git a/src/napari_activelearning/_models_impl.py b/src/napari_activelearning/_models_impl.py index 0460097..95704df 100644 --- a/src/napari_activelearning/_models_impl.py +++ b/src/napari_activelearning/_models_impl.py @@ -132,28 +132,21 @@ def _run_eval(self, img, *args, **kwargs): return seg def _get_transform(self): - if self._transform is None: - self._transform = CellposeTransform(self._channels, - self._channel_axis) - return self._transform, None + return lambda x: x, None - # def _preload_data(self, data_loader, - # train_data_proportion: float = 0.8): def _preload_data(self, dataloader): raw_data = [] label_data = [] for img, lab in dataloader: if USING_PYTORCH: - img = img[0].numpy() - lab = lab[0].numpy() + img = img[0].numpy().squeeze() + lab = lab[0].numpy().squeeze() raw_data.append(img) 
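+            # append the matching label patch, kept index-aligned with raw_data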
label_data.append(lab) return raw_data, label_data - # def _fine_tune(self, data_loader, - # train_data_proportion: float = 0.8) -> bool: def _fine_tune(self, train_dataloader, val_dataloader) -> bool: self._model_init() diff --git a/src/napari_activelearning/_models_impl_interface.py b/src/napari_activelearning/_models_impl_interface.py index 76e9c1f..ba24c30 100644 --- a/src/napari_activelearning/_models_impl_interface.py +++ b/src/napari_activelearning/_models_impl_interface.py @@ -105,7 +105,7 @@ def cellpose_finetuning_parameters( learning_rate: Annotated[float, {"widget_type": "FloatSpinBox", "min": 1e-10, "max": 1.0, - "step": 1e-3}] = 0.005, + "step": 1e-10}] = 0.005, n_epochs: Annotated[int, {"widget_type": "SpinBox", "min": 1, "max": 10000}] = 20): @@ -259,10 +259,9 @@ def _show_segmentation_parameters(self, show: bool): def _show_finetuning_parameters(self, show: bool): self._finetuning_parameters_scr.setVisible(show) - # def _fine_tune(self, data_loader, train_data_proportion: float = 0.8): - # super()._fine_tune(data_loader, train_data_proportion) def _fine_tune(self, train_dataloader, val_dataloader): super()._fine_tune(train_dataloader, val_dataloader) + self._segmentation_parameters.model_type.value = "custom" self._segmentation_parameters.pretrained_model.value =\ self._pretrained_model diff --git a/src/napari_activelearning/_tests/conftest.py b/src/napari_activelearning/_tests/conftest.py index fc59b7c..98968f2 100644 --- a/src/napari_activelearning/_tests/conftest.py +++ b/src/napari_activelearning/_tests/conftest.py @@ -241,12 +241,11 @@ def multiscale_layer_channel(multiscale_layer): @pytest.fixture(scope="function") def multiscale_layers_group(multiscale_layer_channel): - layers_group_mock = LayersGroup("segmentation", - source_axes="TZYX", - use_as_input_image=False, - use_as_input_labels=True, - use_as_sampling_mask=False) + layers_group_mock = LayersGroup() + layers_group_mock.layers_group_name = "segmentation" + layers_group_mock.source_axes = "TZYX" layers_group_mock.addChild(multiscale_layer_channel) + layers_group_mock.use_as_input_labels = True layers_group_mock.source_axes = "TZYX" return layers_group_mock @@ -266,9 +265,11 @@ def simple_image_group(single_scale_array): image_group = ImageGroup("simple_group") - layers_group = LayersGroup("simple_layers_group") + layers_group = LayersGroup() image_group.addChild(layers_group) + layers_group.layers_group_name = "simple_layers_group" + layer_channel = layers_group.add_layer(layer, 0, "TCZYX") image_group.input_layers_group = 0 diff --git a/src/napari_activelearning/_tests/test_layers.py b/src/napari_activelearning/_tests/test_layers.py index e89d322..a71fb21 100644 --- a/src/napari_activelearning/_tests/test_layers.py +++ b/src/napari_activelearning/_tests/test_layers.py @@ -165,8 +165,8 @@ def test_update_source_data(single_scale_layer): def test_layers_group_default_initialization(): - group = LayersGroup(layers_group_name="default_group") - assert group.layers_group_name == "default_group" + group = LayersGroup() + assert group.layers_group_name is None assert group.use_as_input_image is False assert group.use_as_sampling_mask is False assert group._source_axes_no_channels is None @@ -181,12 +181,14 @@ def test_layers_group_properties(single_scale_layer, make_napari_viewer): viewer = make_napari_viewer() viewer.layers.append(layer) - layers_group = LayersGroup("sample_layers_group") + layers_group = LayersGroup() layers_group.add_layer(layer, channel=0, source_axes="TCZYX") image_group = ImageGroup() 
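+    # attach the layers group to a parent image group before naming it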
image_group.addChild(layers_group) + layers_group.layers_group_name = "sample_layers_group" + assert layers_group.layers_group_name == "sample_layers_group" layers_group.layers_group_name = "new_sample_layers_group" assert layers_group.layers_group_name == "new_sample_layers_group" @@ -244,7 +246,8 @@ def test_update_layers_group_source_data(single_scale_memory_layer, viewer = make_napari_viewer() viewer.layers.append(layer) - layers_group = LayersGroup("sample_layers_group") + layers_group = LayersGroup() + layers_group.layers_group_name = "sample_layers_group" layers_group.add_layer(layer, 0, "TCZYX") layers_group.add_layer(layer, 1, "TCZYX") @@ -263,7 +266,8 @@ def test_update_layers_group_channels(single_scale_memory_layer, viewer = make_napari_viewer() viewer.layers.append(layer) - layers_group = LayersGroup("sample_layers_group") + layers_group = LayersGroup() + layers_group.layers_group_name = "sample_layers_group" layer_channel_1 = layers_group.add_layer(layer, 0, "TCZYX") layer_channel_2 = layers_group.add_layer(layer, 1, "TCZYX") @@ -322,11 +326,13 @@ def test_managed_layers_image_group_root(single_scale_memory_layer, image_group = ImageGroup("image_group") group_root.addChild(image_group) - layers_group_1 = LayersGroup("layers_group_1") + layers_group_1 = LayersGroup() image_group.addChild(layers_group_1) + layers_group_1.layers_group_name = "layers_group_1" - layers_group_2 = LayersGroup("layers_group_2") + layers_group_2 = LayersGroup() image_group.addChild(layers_group_2) + layers_group_2.layers_group_name = "layers_group_2" layer_channel_1 = layers_group_1.add_layer(layer, 0, "TCZYX") layer_channel_2 = layers_group_2.add_layer(layer, 0, "TCZYX") From e33f85bdd5ade9cbd9a69b25c92ae982485e60c4 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 24 Jan 2025 17:05:37 -0500 Subject: [PATCH 10/10] Finishing details for tutorial --- docs/tutorials/cellpose_tutorial.qmd | 193 +++++++++++------- .../_tests/test_layers.py | 3 +- src/napari_activelearning/napari.yaml | 5 + 3 files changed, 128 insertions(+), 73 deletions(-) diff --git a/docs/tutorials/cellpose_tutorial.qmd b/docs/tutorials/cellpose_tutorial.qmd index 2ca5a05..10c9c2f 100644 --- a/docs/tutorials/cellpose_tutorial.qmd +++ b/docs/tutorials/cellpose_tutorial.qmd @@ -1,5 +1,5 @@ --- -title: "Tutorial: How to fine tune a Cellpose model" +title: "Tutorial: How to fine tune a Cellpose model with the Active Learning plugin for Napari" author: Fernando Cervantes (fernando.cervantes@jax.org) format: revealjs: @@ -11,11 +11,11 @@ format: jupyter: python3 --- -# 1 Image groups management +# 1. Image groups management -## 1.1 Load a sample image +## 1.1. Load a sample image -You can use the cells 3D image sample from the napari's built-in samples. +You can use the cells 3D image sample from napari's built-in samples. ``` File > Open Sample > napari builtins > Cells (3D+2Ch) @@ -32,6 +32,7 @@ import zarr ``` {python} #| echo: false viewer = napari.Viewer() +viewer.window._qt_window.showMaximized() _ = viewer.open_sample(plugin="napari", sample="cells3d") ``` ``` {python} @@ -39,9 +40,9 @@ _ = viewer.open_sample(plugin="napari", sample="cells3d") nbscreenshot(viewer) ``` -## 1.2 Add the _Image Groups Manager_ widget to napari's window +## 1.2. Add the _Image Groups Manager_ widget to napari's window -You can find the _Image group manager_ under the _Active Learning_ plugin in the napari's plugins menu. +You can find the _Image groups manager_ under the _Active Learning_ plugin in napari's plugins menu. 
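+
+If you are driving napari from a script, the widget can also be created and docked programmatically. The following is only a sketch; it assumes that `ImageGroupsManagerWidget` is re-exported at the package root:
+
+``` {python}
+#| eval: false
+import napari_activelearning as al
+
+# Instantiate the Image groups manager and dock it into the viewer window
+image_groups_mgr = al.ImageGroupsManagerWidget()
+viewer.window.add_dock_widget(image_groups_mgr)
+```
+
+To open it from the GUI instead, use the menu path: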
``` Plugins > Active Learning > Image groups manager @@ -60,7 +61,7 @@ image_groups_mgr_dw = viewer.window.add_dock_widget(image_groups_mgr) nbscreenshot(viewer) ``` -## 1.3 Create an _Image Group_ containing _nuclei_ and _membrane_ layers +## 1.3. Create an _Image Group_ containing _nuclei_ and _membrane_ layers - Select the _nuclei_ and _membrane_ layer and click the _New Image Group_ button on the _Image Groups Manager_ widget. @@ -76,7 +77,7 @@ image_groups_mgr.create_group() nbscreenshot(viewer) ``` -## 1.4 Edit the image group properties +## 1.4. Edit the image group properties :::: {.columns} @@ -121,16 +122,22 @@ cropped_image :::: -# 2 Segment the managed image groups +# 2. Segmentation on image groups -## 2.1 Add the _Acquisition function configuration_ widget to napari's window +## 2.1. Add the _Acquisition function configuration_ widget to napari's window -The _Acquisition function configuration_ is under the _Active Learning_ plugin in the napari's plugins menu. +The _Acquisition function configuration_ is under the _Active Learning_ plugin in napari's plugins menu. ``` Plugins > Active Learning > Acquisition function configuration ``` +::: {.callout-caution} +The _Acquisition function configuration_ widget will be docked under the _Image groups manager_ widget, which might reduce the space to show the content of both widgets properly. +However, these widgets can be _un-docked_ from their current place and docked into another more convenient location within napari's window. +Also, these can be _re-docked_ as tabs as is illustrated in this tutorial. +::: + ``` {python} #| echo: false acquisition_fun_cfg_dw = viewer.window.add_dock_widget(acquisition_fun_cfg) @@ -148,12 +155,13 @@ acquisition_fun_cfg_dw.raise_() nbscreenshot(viewer) ``` -## 2.2 Define sampling configuration +## 2.2. Define sampling configuration :::: {.columns} ::: {.column width=0.3} -### 2.2.1 Set the axes of the sampling space + +### 2.2.1. Set the axes of the sampling space 1. Make sure "Input axes" are set to "ZYX" @@ -163,6 +171,10 @@ This specifies that the samples will be taken from those axes. 2. Change the "Model axes" to "CYX" +::: {.callout-note} +And this specifies that sample's axes will be permuted to match the "Model axes" order. +::: + ::: ::: {.column width=0.3} @@ -215,7 +227,7 @@ cropped_image --- -### 2.2.2 Set the size of the sampling patch +### 2.2.2. Set the size of the sampling patch :::: {.columns} @@ -264,11 +276,11 @@ This directs the Active Learning plugin to sample at random patches of size $256 :::: -## 2.3 Define the maximum number of samples to extract +## 2.3. Define the maximum number of samples to extract :::: {.columns} -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false @@ -293,12 +305,12 @@ draw = ImageDraw.Draw(cropped_image) draw.rectangle([60, 740, 405, 770], outline="white", width=5) draw.rectangle([60, 740, 405, 770], outline="green", width=2) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7*0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ::: -::: {.column width=0.30} +::: {.column width=0.6} - Set the "Maximum samples" to $4$ and press _Enter_ :::{.callout-note} @@ -310,17 +322,18 @@ This tells the Active Learning plugin to process at most _four_ samples at rando :::: -## 2.4 Configure the segmentation method +## 2.4. Configure the segmentation method :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} -1. 
Use the "Model" dropdown to select the `cellpose` method +1. Use the "Model" dropdown list to select the `cellpose` method 2. Click the "Advanced segmentation parameters" checkbox 3. Change the "Channel axis" to $0$ + ::: {.callout-note} This makes `cellpose` to use the first axis as "Color" channel. ::: @@ -331,12 +344,12 @@ This makes `cellpose` to use the first axis as "Color" channel. This tells `cellpose` to segment the first channel ($0$) and use the second channel ($1$) as help channel. ::: -5. Choose the "nuclei" model from the dropdown +5. Choose the "nuclei" model from the dropdown list ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false acquisition_fun_cfg.methods_cmb.setCurrentIndex(2) @@ -406,17 +419,17 @@ draw.text(position, "5", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "5", fill="white", font=font) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ::: :::: -## 2.5 Execute the segmentation method on all image groups +## 2.5. Execute the segmentation method on all image groups :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} - Click the "Run on all image groups" ::: {.callout-note} @@ -425,7 +438,7 @@ To execute the segmentation only on specific image groups, select the desired im ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false screenshot = viewer.screenshot(canvas_only=False, flash=False) @@ -445,7 +458,7 @@ draw = ImageDraw.Draw(cropped_image) draw.rectangle([245, 850, 405, 880], outline="white", width=5) draw.rectangle([245, 850, 405, 880], outline="green", width=2) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ``` {python} #| echo: false @@ -455,7 +468,7 @@ _ = acquisition_fun_cfg.compute_acquisition_layers(run_all=True) :::: -## 2.6 Inspect the segmentation layer +## 2.6. Inspect the segmentation layer ::: {.callout-note} Because the input image is 3D, you might have to slide the Z index on the bottom of napari's window to look at the samples that have been segmented. @@ -483,19 +496,19 @@ image ``` -# 3 Segment masked regions only +# 3. Segment masked regions only -## 3.1 Create a mask to restrict the sampling space +## 3.1. Create a mask to restrict the sampling space -### 3.1.1 Add a mask layer to the image group +### 3.1.1. Add a mask layer to the image group :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} - Switch to the "Image groups manager" tab - Click the "Edit mask properties" checkbox ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false image_groups_mgr_dw.raise_() @@ -531,22 +544,22 @@ image_groups_mgr.mask_generator._show_editor(True) --- -### 3.1.2 Create a low resolution mask for its corresponding image +### 3.1.2. Create a low resolution mask for its corresponding image :::: {.columns} -:::{.column width=0.3} +:::{.column width=0.6} 1. Set the mask scale to $256$ for the "X" and "Y" axes, and a scale of $1$ for the "Z" axis 2. Click the "Create mask" button :::{.callout-note} -This creates a low-resolution mask where each pixel corresponds to a $256\times256$ pixels region in the input image. -Because the mask is low-resolution, it uses less space (in memory RAM and disk). 
+This creates a low-resolution mask where each of its pixels corresponds to a $256\times256$ pixels region in the input image. +Because the mask is low-resolution, it uses less space in memory RAM and/or disk. ::: ::: -:::{.column width=0.7} +:::{.column width=0.4} ``` {python} #| echo: false image_groups_mgr.mask_generator.patch_sizes_mspn.sizes = {"Z": 1, "Y": 256, "X": 256} @@ -592,9 +605,9 @@ cropped_image :::: -## 3.1.3 Specify the samplable regions +## 3.1.3. Specify the samplable regions -- Draw a mask on slices $27$ to $30$ in the "Z" axis. +- Draw a mask that covers slices $27$ to $30$ in the "Z" axis. :::{.callout-note} You can move the slider at the bottom of napari's window to navigate between slices in the "Z" axis. @@ -618,10 +631,10 @@ image ``` -## 3.2 Execute the segmentation process on the masked regions +## 3.2. Execute the segmentation process on the masked regions :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} - Go back to the "Acquisition function configuration" widget - Click the "Run on all image groups" button again @@ -631,7 +644,7 @@ Because the image group has a defined mask, samples will be extracted at random ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false acquisition_fun_cfg_dw.raise_() @@ -655,7 +668,7 @@ draw = ImageDraw.Draw(cropped_image) draw.rectangle([245, 850, 400, 890], outline="white", width=5) draw.rectangle([245, 850, 400, 890], outline="green", width=2) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ``` {python} #| echo: false @@ -666,7 +679,7 @@ _ = acquisition_fun_cfg.compute_acquisition_layers(run_all=True) :::: -## 3.3 Inspect the masked segmentation output +## 3.3. Inspect the masked segmentation output ``` {python} #| echo: false @@ -678,9 +691,9 @@ image # 4. Fine tune the segmentation model -## 4.1 Add the _Label groups manager_ widget to napari's window +## 4.1. Add the _Label groups manager_ widget to napari's window -You can find the _Label groups manager_ under the _Active Learning_ plugin in the napari's plugins menu. +You can find the _Label groups manager_ under the _Active Learning_ plugin in napari's plugins menu. ``` Plugins > Active Learning > Label groups manager @@ -706,7 +719,7 @@ image ``` -## 4.2 Edit segmented patches +## 4.2. Edit segmented patches ### 4.2.1. Select a segmented patch to edit @@ -744,8 +757,8 @@ image draw = ImageDraw.Draw(image) # Draw a red rectangle -draw.rectangle([5, 20, 280, 450], outline="white", width=5) -draw.rectangle([5, 20, 280, 450], outline="green", width=2) +draw.rectangle([5, 20, 280, 430], outline="white", width=5) +draw.rectangle([5, 20, 280, 430], outline="green", width=2) image ``` @@ -755,24 +768,18 @@ image ### 4.2.3. 
Commit changes to the labels layer :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} - Once you have finished editing the labels, click the "Commit changes" button on the _Label groups manager_ ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false z_tmp = zarr.open(r"C:\Users\cervaf\Documents\Logging\activelearning_logs\membrane.zarr\segmentation\0", mode="r") ``` ``` {python} #| echo: false -labels_mgr.labels_group_root.child(1).child(3)._position -viewer.layers["Labels edit"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(3)._position] -labels_mgr.commit() -``` -``` {python} -#| echo: false screenshot = viewer.screenshot(canvas_only=False, flash=False) image = Image.fromarray(screenshot) @@ -790,15 +797,22 @@ draw = ImageDraw.Draw(cropped_image) draw.rectangle([305, 900, 570, 940], outline="white", width=5) draw.rectangle([305, 900, 570, 940], outline="green", width=2) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) +``` +``` {python} +#| echo: false +labels_mgr.labels_group_root.child(1).child(3)._position +viewer.layers["Labels edit"].data[:] = z_tmp[labels_mgr.labels_group_root.child(1).child(3)._position] +labels_mgr.commit() ``` + ::: :::: -## 4.3 Navigate between segmented patches +## 4.3. Navigate between segmented patches :::: {.columns} -::: {.column width=0.3} +::: {.column width=0.6} 1. Expand the second group of labels @@ -809,7 +823,7 @@ cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resam 4. Continue editing the segmentation in the current patch and commit the changes when finish ::: -::: {.column width=0.7} +::: {.column width=0.4} ``` {python} #| echo: false @@ -873,7 +887,7 @@ draw.text(position, "4", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "4", fill="white", font=font) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ``` {python} #| echo: false @@ -892,9 +906,13 @@ for c_idx in range(labels_mgr.labels_group_root.child(1).childCount()): :::: -## 4.4 Setup fine tuning configuration +## 4.4. Setup fine tuning configuration + +### 4.4.1. Use the _Acquisition function configuration_ widget to configure the parameters for the fine tuning process + +:::: {.columns} -### 4.4.1 Use the _Acquisition function configuration_ widget to set the configuration for executing the fine tuning process +::: {.column width=0.6} 1. Go to the "Acquisition function configuration" widget @@ -902,6 +920,10 @@ for c_idx in range(labels_mgr.labels_group_root.child(1).childCount()): 3. Change the "save path" to a location where you want to store the fine tuned model +::: + +::: {.column width=0.4} + ``` {python} #| echo: false acquisition_fun_cfg_dw.raise_() @@ -954,12 +976,20 @@ draw.text(position, "3", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=36) draw.text(position, "3", fill="white", font=font) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` +::: + +:::: + --- -### 4.4.2 Set the *learning rate* and *batch size* +### 4.4.2. 
Set the *learning rate* and *batch size* + +:::: {.columns} + +::: {.column width=0.6} 1. Scroll the _Advanced fine tuning parameters_ widget down to show more parameters @@ -973,6 +1003,10 @@ cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resam You can modify other parameters for the training process here, such as the number of training epochs. ::: +::: + +::: {.column width=0.4} + ``` {python} #| echo: false vertical_scroll_bar = acquisition_fun_cfg.tunable_segmentation_method._finetuning_parameters_scr.verticalScrollBar() @@ -1034,17 +1068,28 @@ draw.text(position, "4", fill="green", font=font) font = ImageFont.truetype("arial.ttf", size=34) draw.text(position, "4", fill="white", font=font) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` -## 4.5 Execute the fine tuning process +::: + +:::: + +## 4.5. Execute the fine tuning process + +:::: {.columns} +::: {.column width=0.6} - Click the "Fine tune model" button to run the training process. ::: {.callout-note} Depending on your computer resources (RAM, CPU), this process might take some minutes to complete. If you have a dedicated GPU device, this can take a couple of seconds instead. ::: +::: + +::: {.column width=0.4} + ``` {python} #| echo: false screenshot = viewer.screenshot(canvas_only=False, flash=False) @@ -1064,14 +1109,18 @@ draw = ImageDraw.Draw(cropped_image) draw.rectangle([245, 880, 405, 920], outline="white", width=5) draw.rectangle([245, 880, 405, 920], outline="green", width=2) -cropped_image.resize((int(0.15 * org_width), int(org_height * 0.5)), Image.Resampling.LANCZOS) +cropped_image.resize((int(0.7 * 0.3 * org_width), int(org_height * 0.7)), Image.Resampling.LANCZOS) ``` ``` {python} #| echo: false acquisition_fun_cfg.fine_tune() ``` -## 4.6 Review the fine tuned segmentation +::: + +:::: + +## 4.6. 
Review the fine tuned segmentation ``` {python} #| echo: false diff --git a/src/napari_activelearning/_tests/test_layers.py b/src/napari_activelearning/_tests/test_layers.py index a71fb21..833f459 100644 --- a/src/napari_activelearning/_tests/test_layers.py +++ b/src/napari_activelearning/_tests/test_layers.py @@ -190,9 +190,10 @@ def test_layers_group_properties(single_scale_layer, make_napari_viewer): layers_group.layers_group_name = "sample_layers_group" assert layers_group.layers_group_name == "sample_layers_group" + assert image_group.group_name == "sample_layers_group" + layers_group.layers_group_name = "new_sample_layers_group" assert layers_group.layers_group_name == "new_sample_layers_group" - assert image_group.group_name == "new_sample_layers_group" assert not layers_group.use_as_input_image assert not layers_group.use_as_sampling_mask diff --git a/src/napari_activelearning/napari.yaml b/src/napari_activelearning/napari.yaml index 01a6bd7..784f020 100644 --- a/src/napari_activelearning/napari.yaml +++ b/src/napari_activelearning/napari.yaml @@ -15,6 +15,9 @@ contributions: - id: napari-activelearning.make_label_groups_manager_widget python_name: napari_activelearning:get_label_groups_manager_widget title: Label groups manager + - id: napari-activelearning.make_activelearning_widget + python_name: napari_activelearning:get_active_learning_widget + title: Active learning widget widgets: - command: napari-activelearning.make_image_groups_manager_widget display_name: Image groups manager @@ -22,3 +25,5 @@ contributions: display_name: Acquisition function configuration - command: napari-activelearning.make_label_groups_manager_widget display_name: Label groups manager + - command: napari-activelearning.make_activelearning_widget + display_name: Active learning widget