Skip to content

Commit

Permalink
Download models with pooch (#276)
Browse files Browse the repository at this point in the history
Use pooch to download model weights and simplify some other functionality
---------

Co-authored-by: Constantin Pape <constantin.pape@informatik.uni-goettingen.de>
Co-authored-by: Constantin Pape <c.pape@gmx.net>
3 people authored Nov 23, 2023
1 parent 23e974f commit c2a4e54
Showing 8 changed files with 123 additions and 114 deletions.
2 changes: 1 addition & 1 deletion development/benchmark.py
Original file line number Diff line number Diff line change
@@ -180,7 +180,7 @@ def main():
args = parser.parse_args()

model_type = args.model_type
device = util._get_device(args.device)
device = util.get_device(args.device)
print("Running benchmarks for", model_type)
print("with device:", device)

2 changes: 1 addition & 1 deletion micro_sam/evaluation/model_comparison.py
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ def generate_data_for_model_comparison(
output_folder: The folder where the samples will be saved.
model_type1: The first model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
model_type1: The second model to use for comparison.
model_type2: The second model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
n_samples: The number of samples to draw from the dataloader.
"""
2 changes: 1 addition & 1 deletion micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
@@ -158,7 +158,7 @@ def main():
parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
parser.add_argument("-i", "--input_path", required=True)
parser.add_argument("-o", "--output_path", required=True)
parser.add_argument("-m", "--model_type", default="vit_h")
parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL)
parser.add_argument("-c", "--checkpoint_path", default=None)
parser.add_argument("-k", "--key")
parser.add_argument(
34 changes: 16 additions & 18 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from enum import Enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Literal

from magicgui import magic_factory, widgets
from napari.qt.threading import thread_worker
import pooch
import zarr
from zarr.errors import PathNotFoundError

@@ -14,7 +12,7 @@
ImageEmbeddings,
get_sam_model,
precompute_image_embeddings,
_MODEL_URLS,
models,
_DEFAULT_MODEL,
_available_devices,
get_cache_directory,
@@ -23,21 +21,17 @@
if TYPE_CHECKING:
import napari

Model = Enum("Model", _MODEL_URLS)
available_devices_list = ["auto"] + _available_devices()


@magic_factory(
pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'},
call_button="Compute image embeddings",
device = {"choices": available_devices_list},
save_path={"mode": "d"}, # choose a directory
)
def embedding_widget(
pbar: widgets.ProgressBar,
image: "napari.layers.Image",
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL,
device: Literal[tuple(["auto"] + _available_devices())] = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
@@ -54,8 +48,9 @@ def embedding_widget(

@thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
def _compute_image_embedding(state, image_data, save_path, ndim=None,
device="auto", model=Model.__getitem__(_DEFAULT_MODEL),
optional_custom_weights=None):
device="auto", model=_DEFAULT_MODEL,
optional_custom_weights=None,
):
# Make sure save directory exists and is an empty directory
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
@@ -71,19 +66,22 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
"The user selected 'save_path' is not a zarr array "
f"or empty directory: {save_path}"
)

# Initialize the model
state.predictor = get_sam_model(device=device, model_type=model.name,
checkpoint_path=optional_custom_weights)
state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights)
# Compute the image embeddings
state.image_embeddings = precompute_image_embeddings(
predictor = state.predictor,
input_ = image_data,
save_path = str(save_path),
predictor=state.predictor,
input_=image_data,
save_path=save_path,
ndim=ndim,
)
return state # returns napari._qt.qthreading.FunctionWorker

return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
return _compute_image_embedding(
state, image.data, save_path, ndim=ndim, device=device, model=model,
optional_custom_weights=optional_custom_weights
)


@magic_factory(
4 changes: 3 additions & 1 deletion micro_sam/training/util.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,9 @@ def get_trainable_sam_model(
"""Get the trainable sam model.
Args:
model_type: The type of the segment anything model.
model_type: The segment anything model that should be finetuned.
The weights of this model will be used for initialization, unless a
custom weight file is passed via `checkpoint_path`.
device: The device to use for training.
checkpoint_path: Path to a custom checkpoint from which to load the model weights.
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
187 changes: 98 additions & 89 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
@@ -8,13 +8,11 @@
import pickle
import warnings
from collections import OrderedDict
from shutil import copyfileobj
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import pooch
import requests
import torch
import vigra
import zarr
@@ -36,45 +34,15 @@
except ImportError:
from tqdm import tqdm

_MODEL_URLS = {
# the default segment anything models
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
# first version of finetuned models on zenodo
"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
}

_CHECKSUMS = {
# the default segment anything models
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
# first version of finetuned models on zenodo
"vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
"vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
"vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c",
"vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81",
}
# this is required so that the downloaded file is not called 'download'
_DOWNLOAD_NAMES = {
"vit_t": "vit_t_mobile_sam.pth",
"vit_h_lm": "vit_h_lm.pth",
"vit_b_lm": "vit_b_lm.pth",
"vit_h_em": "vit_h_em.pth",
"vit_b_em": "vit_b_em.pth",
}

# this is the default model used in micro_sam
# currently set to the default vit_h
_DEFAULT_MODEL = "vit_h"

# The valid model types. Each type corresponds to the architecture of the
# vision transformer used within SAM.
_MODEL_TYPES = ("vit_h", "vit_b", "vit_l", "vit_t")


# TODO define the proper type for image embeddings
ImageEmbeddings = Dict[str, Any]
@@ -90,53 +58,62 @@ def get_cache_directory() -> None:
cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory))
return cache_directory


#
# Functionality for model download and export
#

def microsam_cachedir():
"""Return the micro-sam cache directory.
def _download(url, path, model_type):
with requests.get(url, stream=True, verify=True) as r:
if r.status_code != 200:
r.raise_for_status()
raise RuntimeError(f"Request to {url} returned status code {r.status_code}")
file_size = int(r.headers.get("Content-Length", 0))
desc = f"Download {url} to {path}"
if file_size == 0:
desc += " (unknown file size)"
with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
copyfileobj(r_raw, f)

# validate the checksum
expected_checksum = _CHECKSUMS[model_type]
if expected_checksum is None:
return
with open(path, "rb") as f:
file_ = f.read()
checksum = hashlib.sha256(file_).hexdigest()
if checksum != expected_checksum:
raise RuntimeError(
"The checksum of the download does not match the expected checksum."
f"Expected: {expected_checksum}, got: {checksum}"
)
print("Download successful and checksums agree.")
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to
the MICROSAM_CACHEDIR os environment variable since the last time.
"""
cache_directory = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
return cache_directory

def _get_checkpoint(model_type, checkpoint_path=None):
if checkpoint_path is None:
checkpoint_url = _MODEL_URLS[model_type]
checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1])
checkpoint_folder = os.path.join(get_cache_directory(), "models")
checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name)

# download the checkpoint if necessary
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_folder, exist_ok=True)
_download(checkpoint_url, checkpoint_path, model_type)
elif not os.path.exists(checkpoint_path):
raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
def models():
"""Return the segmentation models registry.
return checkpoint_path
We recreate the model registry every time this function is called,
so any user changes to the default micro-sam cache directory location
are respected.
"""
models = pooch.create(
path=os.path.join(microsam_cachedir(), "models"),
base_url="",
registry={
# the default segment anything models
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
# the model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
# first version of finetuned models on zenodo
"vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
"vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
"vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c",
"vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81",
},
# Now specify custom URLs for some of the files in the registry.
urls={
# the default segment anything models
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
# the model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
# first version of finetuned models on zenodo
"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
},
)
return models


def _get_default_device():
@@ -208,9 +185,16 @@ def get_sam_model(
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.
This function will download the required model checkpoint or load it from file if it
was already downloaded.
This location can be changed by setting the environment variable: MICROSAM_CACHEDIR.
This function will download the required model or load it from the cached weight file.
The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
The name of the requested model can be set via `model_type`.
See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
for an overview of the available models.
Alternatively this function can also load a model from weights stored in a local filepath.
The corresponding file path is given via `checkpoint_path`. In this case `model_type`
must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models'
inside your default cache directory, eg:
@@ -222,30 +206,52 @@ def get_sam_model(
Args:
model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
To get a list of all available model names you can call `get_model_names`.
device: The device for the model. If none is given will use GPU if available.
checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
checkpoint_path: The path to a file with weights that should be used instead of using the
weights corresponding to `model_type`. If given, `model_type` must match the architecture
corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
Returns:
The segment anything predictor.
"""
checkpoint = _get_checkpoint(model_type, checkpoint_path)
device = get_device(device)

# Our custom model types have a suffix "_...". This suffix needs to be stripped
# We support passing a local filepath to a checkpoint.
# In this case we do not download any weights but just use the local weight file,
# as it is, without copying it over anywhere or checking its hashes.

# checkpoint_path has not been passed, we download a known model and derive the correct
# URL from the model_type. If the model_type is invalid pooch will raise an error.
if checkpoint_path is None:
model_registry = models()
checkpoint = model_registry.fetch(model_type)
# checkpoint_path has been passed, we use it instead of downloading a model.
else:
# Check if the file exists and raise an error otherwise.
# We can't check any hashes here, and we don't check if the file is actually a valid weight file.
# (If it isn't the model creation will fail below.)
if not os.path.exists(checkpoint_path):
raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
checkpoint = checkpoint_path

# Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
# before calling sam_model_registry.
model_type_ = model_type[:5]
assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t")
if model_type == "vit_t" and not VIT_T_SUPPORT:
abbreviated_model_type = model_type[:5]
if abbreviated_model_type not in _MODEL_TYPES:
raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
raise RuntimeError(
"mobile_sam is required for the vit-tiny."
"You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
)

sam = sam_model_registry[model_type_](checkpoint=checkpoint)
sam = sam_model_registry[abbreviated_model_type](checkpoint=checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.model_type = model_type
predictor.model_type = abbreviated_model_type
if return_sam:
return predictor, sam
return predictor
@@ -278,7 +284,7 @@ def get_custom_sam_model(
Args:
checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
model_type: The SegmentAnything model to use.
model_type: The SegmentAnything model_type for the given checkpoint.
device: The device for the model. If none is given will use GPU if available.
return_sam: Return the sam model object as well as the predictor.
return_state: Return the full state of the checkpoint in addition to the predictor.
@@ -328,7 +334,7 @@ def export_custom_sam_model(
Args:
checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
model_type: The SegmentAnything model type to use (vit_h, vit_b or vit_l).
model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
save_path: Where to save the exported model.
"""
_, state = get_custom_sam_model(
@@ -343,7 +349,9 @@ def export_custom_sam_model(


def get_model_names() -> Iterable:
return _MODEL_URLS.keys()
model_registry = models()
model_names = model_registry.registry.keys()
return model_names


#
@@ -600,6 +608,7 @@ def precompute_image_embeddings(
assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file."

if save_path is not None:
save_path = str(save_path)
data_signature = _compute_data_signature(input_)

f = zarr.open(save_path, "a")
4 changes: 2 additions & 2 deletions test/test_sam_annotator/test_widgets.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
import zarr

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.sam_annotator._widgets import embedding_widget, Model
from micro_sam.sam_annotator._widgets import embedding_widget
from micro_sam.util import _compute_data_signature


@@ -22,7 +22,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path):
layer = viewer.open_sample('napari', 'camera')[0]
my_widget = embedding_widget()
# run image embedding widget
worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path)
worker = my_widget(image=layer, model="vit_t", device="cpu", save_path=tmp_path)
worker.await_workers() # blocks until thread worker is finished the embedding
# Check in-memory state - predictor
assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor))
2 changes: 1 addition & 1 deletion test/test_util.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ def check_predictor(predictor):

# check predictor with checkpoint path (using the cached model)
checkpoint_path = os.path.join(
get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth"
get_cache_directory(), "models", "vit_t" if VIT_T_SUPPORT else "vit_b"
)
predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path)
check_predictor(predictor)

0 comments on commit c2a4e54

Please sign in to comment.