Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement of user interface #31

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,124 changes: 1,117 additions & 7 deletions docs/tutorials/cellpose_tutorial.qmd

Large diffs are not rendered by default.

1,623 changes: 1,623 additions & 0 deletions docs/tutorials/cellpose_tutorial.quarto_ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions notebooks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Example notebooks for integrating Active Learning to custom deep learning models

134 changes: 134 additions & 0 deletions notebooks/microsam_activelearning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import numpy as np
import torch
import time

from torch_em.transform.label import PerObjectDistanceTransform

from micro_sam import util
from micro_sam import automatic_segmentation as msas
import micro_sam.training as sam_training
import napari_activelearning as al


class TunableMicroSAM(al.TunableMethodWidget):
def __init__(self):
super(TunableMicroSAM, self).__init__()
self._predictor = None
self._amg = None

def _model_init(self):
if self._amg is not None:
return

(self._sam_predictor,
self._sam_instance_segmenter) = msas.get_predictor_and_segmenter(
model_type='vit_t',
device=util.get_device("cuda"
if torch.cuda.is_available()
else "cpu"),
amg=True,
checkpoint=None,
stability_score_offset=1.0
)

(self._sam_predictor_dropout,
self._sam_instance_segmenter_dropout) =\
msas.get_predictor_and_segmenter(
model_type='vit_t',
device=util.get_device("cuda"
if torch.cuda.is_available()
else "cpu"),
amg=True,
checkpoint=None,
stability_score_offset=1.0)

al.add_dropout(self._sam_predictor_dropout.model)

def _get_transform(self):
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False,
foreground=True, instances=True, min_size=25
)

return lambda x: (255.0 * x).astype(np.uint8), label_transform

def _run_pred(self, img, *args, **kwargs):
self._model_init()

e_time = time.perf_counter()
img_embeddings = util.precompute_image_embeddings(
predictor=self._sam_predictor_dropout,
input_=img,
save_path=None,
ndim=2,
tile_shape=None,
halo=None,
verbose=False,
)
e_time = time.perf_counter() - e_time

e_time = time.perf_counter()
self._sam_instance_segmenter_dropout.initialize(
image=img,
image_embeddings=img_embeddings
)
e_time = time.perf_counter() - e_time

e_time = time.perf_counter()
masks = self._sam_instance_segmenter_dropout.generate()
e_time = time.perf_counter() - e_time

e_time = time.perf_counter()
probs = np.zeros(img.shape[:2], dtype=np.float32)
for mask in masks:
probs = np.where(
mask["segmentation"],
mask["predicted_iou"],
probs
)
e_time = time.perf_counter() - e_time

probs = torch.from_numpy(probs).sigmoid().numpy()

return probs

def _run_eval(self, img, *args, **kwargs):
self._model_init()

e_time = time.perf_counter()
segmentation_mask = msas.automatic_instance_segmentation(
predictor=self._sam_predictor,
segmenter=self._sam_instance_segmenter,
input_path=img,
ndim=2,
verbose=False
)
e_time = time.perf_counter() - e_time

return segmentation_mask

def _fine_tune(self, train_dataloader, val_dataloader) -> bool:
self._model_init()

train_dataloader.shuffle = True
val_dataloader.shuffle = False

# Run training.
sam_training.train_sam(
name="microsam_activelearning",
model_type="vit_t",
train_loader=train_dataloader,
val_loader=val_dataloader,
n_epochs=2,
n_objects_per_batch=25,
with_segmentation_decoder=True,
device=util.get_device("cuda"
if torch.cuda.is_available()
else "cpu"),
)

return True


def register_microsam():
al.register_model("micro-sam", TunableMicroSAM)
4 changes: 2 additions & 2 deletions src/napari_activelearning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ._utils import *
from ._layers import *
from ._acquisition import *
from ._models import *
from ._models_impl import *
from ._interface import *
from ._models_interface import *
from ._models_impl_interface import *
from ._widgets import *
Loading
Loading