Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bioimageio Model Creation #227

Closed
wants to merge 24 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f2fa218
WIP Add bioimage.io model creation
anwai98 Oct 10, 2023
2073119
Update model building script
anwai98 Oct 11, 2023
fbfde8e
Update model predictor adaptor for bioimage models
anwai98 Oct 11, 2023
c62fe93
Refactor modelzoo functionality into submodule
constantinpape Oct 11, 2023
615ffc6
Add first working scripts for bioengine export
constantinpape Oct 11, 2023
09f5f76
Add input prompt transofrms to adaptor
anwai98 Oct 11, 2023
b869a33
Update numpy input saving
anwai98 Oct 11, 2023
96dc3f6
Update bioengine export script
constantinpape Oct 11, 2023
61b076e
Merge branch 'aa-modelzoo' of https://github.com/computational-cell-a…
constantinpape Oct 11, 2023
7d6bfca
Refactor modelzoo export
constantinpape Oct 12, 2023
b470257
Add tempfile for model conversion inputs
anwai98 Oct 12, 2023
de6b245
Add doc-strings to bioengine export functionality
constantinpape Oct 12, 2023
ebba719
Update modelzoo export script
constantinpape Oct 12, 2023
ee831ef
Update url in imjoy test
constantinpape Oct 13, 2023
b28c885
Merge branch 'dev' into aa-modelzoo
constantinpape Mar 14, 2024
390ce23
Update to bioimageio.spec v0.5 WIP
constantinpape Mar 14, 2024
6cccc06
Update example script
constantinpape Mar 15, 2024
b9df849
Update bioimageio export
constantinpape Mar 18, 2024
d911392
Minor fixes
constantinpape Mar 19, 2024
ec56035
Work on export
constantinpape Mar 19, 2024
2d9c88a
More modelzoo updtes
constantinpape Mar 20, 2024
a170511
Add all possible model inputs
constantinpape Mar 21, 2024
d2e4909
Merge branch 'dev' into aa-modelzoo
constantinpape Apr 9, 2024
c181e29
Bioimageio updates WIP
constantinpape Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update to bioimageio.spec v0.5 WIP
constantinpape committed Mar 14, 2024
commit 390ce234262e0c4d6517751bac6d267a90b42f05
4 changes: 2 additions & 2 deletions examples/modelzoo/export_model_for_bioimageio.py
Original file line number Diff line number Diff line change
@@ -3,11 +3,11 @@


def export_model_with_synthetic_data():
    """Export the 'vit_b' model as a bioimage.io package, using synthetic data as test input.

    The package is written to './test_export.zip' under the name 'sam-test-vit-b'.
    """
    # Synthetic image and the matching label image of shape (1024, 1022).
    image, labels = synthetic_data(shape=(1024, 1022))

    export_bioimageio_model(
        image, labels,
        model_type="vit_b", name="sam-test-vit-b",
        output_path="./test_export.zip",
    )

257 changes: 200 additions & 57 deletions micro_sam/modelzoo/bioimageio_export.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile as tmp_file
import numpy as np
from typing import Optional, Union

import bioimageio.spec.model.v0_5 as spec
import numpy as np
import torch

from bioimageio.core.build_spec import build_model
from bioimageio.spec import save_bioimageio_package


from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from .predictor_adaptor import PredictorAdaptor


def _get_model(image, model_type, checkpoint_path):
    """Load the SAM model for `model_type` and set the precomputed embeddings of `image` on its predictor.

    Args:
        image: The image for which the embeddings are precomputed.
        model_type: The type of the SAM model.
        checkpoint_path: The checkpoint that is forwarded to `util.get_sam_model`.

    Returns:
        The predictor and the SAM model, in that order.
    """
    sam_predictor, sam = util.get_sam_model(  # type: ignore
        model_type=model_type,
        return_sam=True,
        checkpoint_path=checkpoint_path,
    )
    # Compute the embeddings once and register them on the predictor.
    util.set_precomputed(sam_predictor, util.precompute_image_embeddings(sam_predictor, image))
    return sam_predictor, sam
# TODO extend the defaults
# Default metadata for the exported model description, looked up by key.
DEFAULTS = dict(
    authors=[
        spec.Author(name=author_name, affiliation="University Goettingen", github_user=github_user)
        for author_name, github_user in (
            ("Anwai Archit", "anwai98"),
            ("Constantin Pape", "constantinpape"),
        )
    ],
    description="Finetuned Segment Anything Model for Microscopy",
    cite=[
        spec.CiteEntry(
            text="Archit et al. Segment Anything for Microscopy",
            doi=spec.Doi("10.1101/2023.08.21.554208"),
        ),
    ],
)


def _create_test_inputs_and_outputs(
@@ -36,34 +42,45 @@ def _create_test_inputs_and_outputs(
# For now we just generate a single box prompt here, but we could also generate more input prompts.
generator = PointAndBoxPromptGenerator(0, 0, 4, False, True)
centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1]) # type: ignore
_, _, box_prompts, _ = generator(masks, [bounding_boxes[1]])
box_prompts = box_prompts.numpy()

save_image_path = input_path.name
np.save(save_image_path, image[None, None])
masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2]) # type: ignore
_, _, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])
box_prompts = box_prompts.numpy()[None]

_, sam_model = _get_model(image, model_type, checkpoint_path)
predictor = PredictorAdaptor(sam_model=sam_model)
predictor = PredictorAdaptor(model_type=model_type)
predictor.load_state_dict(torch.load(checkpoint_path))

save_box_prompt_path = box_path.name
np.save(save_box_prompt_path, box_prompts)

input_ = util._to_image(image).transpose(2, 0, 1)
input_ = util._to_image(image).transpose(2, 0, 1)[None]
save_image_path = input_path.name
np.save(save_image_path, input_)

masks, scores, embeddings = predictor(
input_image=torch.from_numpy(input_)[None],
image_embeddings=None,
box_prompts=torch.from_numpy(box_prompts)[None]
image=torch.from_numpy(input_),
embeddings=None,
box_prompts=torch.from_numpy(box_prompts)
)

np.save(mask_path.name, masks.numpy())
np.save(score_path.name, scores.numpy())
np.save(embed_path.name, embeddings.numpy())

return [save_image_path, save_box_prompt_path], [mask_path.name, score_path.name, embed_path.name]
# TODO autogenerate the cover and return it too.

inputs = {
"image": save_image_path,
"box_prompts": save_box_prompt_path,
}
outputs = {
"mask": mask_path.name,
"score": score_path.name,
"embeddings": embed_path.name
}
return inputs, outputs


# TODO url with documentation for the modelzoo interface, and just add it to defaults
def _write_documentation(doc_path, doc):
with open(doc_path, "w") as f:
if doc is None:
@@ -75,15 +92,19 @@ def _write_documentation(doc_path, doc):
return doc_path


# TODO enable over-riding the authors and citation and tags from kwargs
# TODO support RGB sample inputs
def _get_checkpoint(model_type, checkpoint_path):
if checkpoint_path is None:
model_registry = util.models()
checkpoint_path = model_registry.fetch(model_type)
return checkpoint_path


def export_bioimageio_model(
image: np.ndarray,
label_image: np.ndarray,
model_type: str,
model_name: str,
name: str,
output_path: Union[str, os.PathLike],
doc: Optional[str] = None,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
**kwargs
) -> None:
@@ -97,12 +118,9 @@ def export_bioimageio_model(
label_image: The segmentation corresponding to `image`.
It is used to derive prompt inputs for the model.
model_type: The type of the SAM model.
model_name: The name of the exported model.
name: The name of the exported model.
output_path: Where the exported model is saved.
doc: Documentation for the model.
checkpoint_path: Optional checkpoint for loading the SAM model.
kwargs: optional keyword arguments for the 'build_model' function
that converts to the modelzoo format.
"""
with (
tmp_file(suffix=".md") as tmp_doc_path,
@@ -112,6 +130,7 @@ def export_bioimageio_model(
tmp_file(suffix=".npy") as tmp_score_path,
tmp_file(suffix=".npy") as tmp_embed_path,
):
checkpoint_path = _get_checkpoint(model_type, checkpoint_path=checkpoint_path)
input_paths, result_paths = _create_test_inputs_and_outputs(
image, label_image, model_type, checkpoint_path,
input_path=tmp_input_path,
@@ -120,35 +139,159 @@ def export_bioimageio_model(
score_path=tmp_score_path,
embed_path=tmp_embed_path,
)
checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path)
input_descriptions = [
# First input: the image data.
spec.InputTensorDescr(
id=spec.TensorId("image"),
axes=[
spec.BatchAxis(),
# NOTE: to support 1 and 3 channels we can add another preprocessing.
# Best solution: Have a pre-processing for this! (1C -> RGB)
spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
],
test_tensor=spec.FileDescr(source=input_paths["image"]),
data=spec.IntervalOrRatioDataDescr(type="uint8")
),

# Second input: the box prompts (optional)
spec.InputTensorDescr(
id=spec.TensorId("box_prompts"),
optional=True,
axes=[
spec.BatchAxis(),
spec.IndexAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
),
# TODO double check the axis names
spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
],
test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
data=spec.IntervalOrRatioDataDescr(type="int64")
),

# TODO
# Third input: the point prompts (optional)

# TODO
# Fourth input: the mask prompts (optional)

# Fifth input: the image embeddings (optional)
spec.InputTensorDescr(
id=spec.TensorId("embeddings"),
optional=True,
axes=[
spec.BatchAxis(),
# NOTE: we currently have to specify all the channel names
# (It would be nice to also support size)
spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
],
test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
data=spec.IntervalOrRatioDataDescr(type="float32")
),

]

output_descriptions = [
# First output: The mask predictions.
spec.OutputTensorDescr(
id=spec.TensorId("masks"),
axes=[
spec.BatchAxis(),
spec.IndexAxis(
id=spec.AxisId("object"),
size=spec.SizeReference(
tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object")
)
),
# NOTE: this could be a 3 once we use multi-masking
spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
spec.SpaceOutputAxis(
id=spec.AxisId("y"),
size=spec.SizeReference(
tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
)
),
spec.SpaceOutputAxis(
id=spec.AxisId("x"),
size=spec.SizeReference(
tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
)
)
],
data=spec.IntervalOrRatioDataDescr(type="uint8"),
test_tensor=spec.FileDescr(source=result_paths["mask"])
),

# The score predictions
spec.OutputTensorDescr(
id=spec.TensorId("scores"),
axes=[
spec.BatchAxis(),
spec.IndexAxis(
id=spec.AxisId("object"),
size=spec.SizeReference(
tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object")
)
),
# NOTE: this could be a 3 once we use multi-masking
spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
],
data=spec.IntervalOrRatioDataDescr(type="float32"),
test_tensor=spec.FileDescr(source=result_paths["score"])
),

# The image embeddings
spec.OutputTensorDescr(
id=spec.TensorId("embeddings"),
axes=[
spec.BatchAxis(),
spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
],
data=spec.IntervalOrRatioDataDescr(type="float32"),
test_tensor=spec.FileDescr(source=result_paths["embeddings"])
)
]

# TODO sha256
architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
architecture = spec.ArchitectureFromFileDescr(
source=Path(architecture_path),
callable="PredictorAdaptor",
kwargs={"model_type": model_type}
)

weight_descriptions = spec.WeightsDescr(
pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
source=Path(checkpoint_path),
architecture=architecture,
pytorch_version=spec.Version(torch.__version__),
)
)

doc_path = tmp_doc_path.name
_write_documentation(doc_path, doc)

build_model(
weight_uri=checkpoint, # type: ignore
test_inputs=input_paths,
test_outputs=result_paths,
input_axes=["bcyx", "bic"],
# FIXME this causes some error in build-model
# input_names=["image", "box-prompts"],
output_axes=["bcyx", "bic", "bcyx"],
# FIXME this causes some error in build-model
# output_names=["masks", "scores", "image_embeddings"],
name=model_name,
description="Finetuned Segment Anything models for Microscopy",
authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"},
{"name": "Constantin Pape", "affiliation": "Uni Goettingen"}],
tags=["instance-segmentation", "segment-anything"],
license="CC-BY-4.0",
documentation=doc_path, # type: ignore
cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy",
"doi": "10.1101/2023.08.21.554208"}],
output_path=output_path, # type: ignore
architecture=architecture_path,
**kwargs,
_write_documentation(doc_path, kwargs.get("documentation", None))

# TODO tags, dependencies, other stuff ...
model_description = spec.ModelDescr(
name=name,
description=kwargs.get("description", DEFAULTS["description"]),
authors=kwargs.get("authors", DEFAULTS["authors"]),
cite=kwargs.get("cite", DEFAULTS["cite"]),
license=spec.LicenseId("MIT"),
documentation=Path(doc_path),
git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
inputs=input_descriptions,
outputs=output_descriptions,
weights=weight_descriptions,
)

# TODO actually test the model
# TODO test the model.

save_bioimageio_package(model_description, output_path=output_path)
99 changes: 69 additions & 30 deletions micro_sam/modelzoo/predictor_adaptor.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,94 @@
from typing import Optional
import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
# Avoid import warnings from mobile_sam
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from mobile_sam import sam_model_registry
except ImportError:
from segment_anything import sam_model_registry

# TODO we need to accept and return an additional tensor for the image sizes to support embeddings
class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model is meant to support the same functionality as SamPredictor and to provide
    mask segmentations from box, point or mask input prompts.
    Currently only box prompts are implemented.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support the 'mobile_sam' package has to be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        # NOTE: The SAM model is only reachable through the predictor (`self.sam.model`).
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state, *args, **kwargs):
        """Load `state` into the wrapped SAM model.

        Args:
            state: The state dict of the SAM model.
            args: Additional positional arguments, forwarded to the `load_state_dict` of the SAM model.
            kwargs: Additional keyword arguments (e.g. `strict`), forwarded in the same way.

        Returns:
            The return value of the `load_state_dict` call of the SAM model.
        """
        return self.sam.model.load_state_dict(state, *args, **kwargs)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        # TODO add point and mask prompts
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predict masks for the given image and prompts.

        Args:
            image: torch inputs of dimensions B x C x H x W
            box_prompts: box prompts of dimensions B x OBJECTS x 4
            embeddings: precomputed image embeddings B x 256 x 64 x 64

        Returns:
            The predicted masks (as uint8), the predicted scores and the image embeddings.

        Raises:
            ValueError: If the batch size of `image` is not 1.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Only a batch size of 1 is supported, got a batch size of {batch_size}.")

        # Keep the size of the input image: the boxes are given in its coordinates,
        # and the masks are compared against it below.
        original_size = tuple(image.shape[2:])

        if embeddings is None:
            # No embeddings were passed, so they are computed for this image.
            # This is done on every call, so that embeddings of an image
            # from a previous call are never reused for a different image.
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=original_size)
        else:
            # The embeddings are passed, so we set them, together with the
            # image sizes that the predictor needs to post-process the masks.
            input_size = tuple(self.sam.transform.apply_image_torch(image).shape[2:])
            self.sam.reset_image()
            self.sam.features = embeddings
            self.sam.original_size = original_size
            self.sam.input_size = input_size
            # Also mirror the sizes on the orig_* / input_* attributes of the predictor.
            self.sam.orig_h, self.sam.orig_w = original_size
            self.sam.input_h, self.sam.input_w = input_size
            self.sam.is_image_set = True

        # The box prompts are optional; without them the predictor is called without any prompt.
        boxes = None
        if box_prompts is not None:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=original_size)

        masks, scores, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=boxes,
            multimask_output=False
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings