Improve device handling #251

Merged: 5 commits, Nov 18, 2023
5 changes: 3 additions & 2 deletions micro_sam/evaluation/inference.py
@@ -137,7 +137,7 @@ def _run_inference_with_prompts_for_image(
 def get_predictor(
     checkpoint_path: Union[str, os.PathLike],
     model_type: str,
-    device: Optional[str] = None,
+    device: Optional[Union[str, torch.device]] = None,
     return_state: bool = False,
     is_custom_model: Optional[bool] = None,
 ) -> SamPredictor:
Expand All @@ -146,12 +146,13 @@ def get_predictor(
     Args:
         checkpoint_path: The checkpoint filepath.
         model_type: The type of the model, either vit_h, vit_b or vit_l.
+        device: The device to use.
         return_state: Whether to return the complete state of the checkpoint in addition to the predictor.
         is_custom_model: Whether this is a custom model or not.
     Returns:
         The segment anything predictor.
     """
-    device = util._get_device(device)
+    device = util.get_device(device)

     # By default we check if the model follows the torch_em checkpoint naming scheme to check whether it is a
     # custom model or not. This can be overridden by passing True or False for is_custom_model.
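As a quick illustration of the updated signature, here is a minimal sketch of calling `get_predictor` with a `torch.device`; the checkpoint path is a placeholder for this example, not a file from the PR:

```python
import torch
from micro_sam.evaluation.inference import get_predictor

# "./checkpoint.pt" is a placeholder path; a torch.device is now accepted directly.
predictor = get_predictor("./checkpoint.pt", model_type="vit_b", device=torch.device("cpu"))
```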
9 changes: 5 additions & 4 deletions micro_sam/training/util.py
@@ -2,32 +2,33 @@
 from typing import List, Optional, Union

 import numpy as np
+import torch

 from ..prompt_generators import PointAndBoxPromptGenerator
-from ..util import get_centers_and_bounding_boxes, get_sam_model, segmentation_to_one_hot, _get_device
+from ..util import get_centers_and_bounding_boxes, get_sam_model, segmentation_to_one_hot, get_device
 from .trainable_sam import TrainableSAM


 def get_trainable_sam_model(
     model_type: str = "vit_h",
-    device: Optional[str] = None,
+    device: Optional[Union[str, torch.device]] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     freeze: Optional[List[str]] = None,
 ) -> TrainableSAM:
     """Get the trainable sam model.

     Args:
         model_type: The type of the segment anything model.
-        device: The device to use for training.
         checkpoint_path: Path to a custom checkpoint from which to load the model weights.
         freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
             By default nothing is frozen and the full model is updated.
+        device: The device to use for training.

     Returns:
         The trainable segment anything model.
     """
     # set the device here so that the correct one is passed to TrainableSAM below
-    device = _get_device(device)
+    device = get_device(device)
     _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True)

     # freeze components of the model if freeze was passed
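A usage sketch for the updated `get_trainable_sam_model`; the freeze values follow the docstring above, and it is assumed the vit_b weights can be downloaded or are cached in your environment:

```python
import torch
from micro_sam.training.util import get_trainable_sam_model

# Freeze only the image encoder, so the prompt encoder and mask decoder are trained.
# A torch.device object is now a valid value for the device argument.
trainable_sam = get_trainable_sam_model(
    model_type="vit_b",
    device=torch.device("cpu"),
    freeze=["image_encoder"],
)
```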
34 changes: 23 additions & 11 deletions micro_sam/util.py
@@ -158,17 +158,29 @@ def _get_default_device():
     return device


-def _get_device(device=None):
+def get_device(device=None) -> Union[str, torch.device]:
+    """Get the torch device.
+
+    If no device is passed, the default device for your system is used.
+    Otherwise it is checked whether the device you passed is supported.
+
+    Args:
+        device: The input device.
+
+    Returns:
+        The device.
+    """
     if device is None or device == "auto":
         device = _get_default_device()
     else:
-        if device.lower() == "cuda":
+        device_type = device if isinstance(device, str) else device.type
+        if device_type.lower() == "cuda":
             if not torch.cuda.is_available():
                 raise RuntimeError("PyTorch CUDA backend is not available.")
-        elif device.lower() == "mps":
+        elif device_type.lower() == "mps":
             if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
-        elif device.lower() == "cpu":
+        elif device_type.lower() == "cpu":
             pass  # cpu is always available
         else:
             raise RuntimeError(f"Unsupported device: {device}\n"
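The renamed helper is now part of the public API. A minimal sketch of the accepted inputs, assuming a machine where only the CPU backend is guaranteed to be available:

```python
import torch
from micro_sam.util import get_device

device = get_device()  # auto-select: cuda if available, else mps, else cpu
device = get_device("cpu")  # strings are validated against the available backends
device = get_device(torch.device("cpu"))  # torch.device objects are checked via their .type
# get_device("cuda") raises a RuntimeError when no CUDA backend is available.
```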
@@ -180,7 +192,7 @@ def _available_devices():
     available_devices = []
     for i in ["cuda", "mps", "cpu"]:
         try:
-            device = _get_device(i)
+            device = get_device(i)
         except RuntimeError:
             pass
         else:
@@ -190,7 +202,7 @@

 def get_sam_model(
     model_type: str = _DEFAULT_MODEL,
-    device: Optional[str] = None,
+    device: Optional[Union[str, torch.device]] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_sam: bool = False,
 ) -> SamPredictor:
@@ -209,16 +221,16 @@
     https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

     Args:
-        device: The device for the model. If none is given, will use GPU if available.
         model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
+        device: The device for the model. If none is given, will use GPU if available.
         checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
         return_sam: Return the sam model object as well as the predictor.

     Returns:
         The segment anything predictor.
     """
     checkpoint = _get_checkpoint(model_type, checkpoint_path)
-    device = _get_device(device)
+    device = get_device(device)

     # Our custom model types have a suffix "_...". This suffix needs to be stripped
     # before calling sam_model_registry.
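Correspondingly, `get_sam_model` now accepts either form for the device. A short sketch (the model weights are fetched via pooch on first use, which is assumed to succeed here):

```python
import torch
from micro_sam.util import get_sam_model

# Equivalent calls after this change; passing None instead would auto-select the device.
predictor = get_sam_model(model_type="vit_b", device="cpu")
predictor = get_sam_model(model_type="vit_b", device=torch.device("cpu"))
```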
@@ -255,7 +267,7 @@ def find_class(self, module, name):
 def get_custom_sam_model(
     checkpoint_path: Union[str, os.PathLike],
     model_type: str = "vit_h",
-    device: Optional[str] = None,
+    device: Optional[Union[str, torch.device]] = None,
     return_sam: bool = False,
     return_state: bool = False,
 ) -> SamPredictor:
@@ -266,8 +278,8 @@ def get_custom_sam_model(

     Args:
         checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
-        device: The device for the model. If none is given, will use GPU if available.
         model_type: The SegmentAnything model to use.
+        device: The device for the model. If none is given, will use GPU if available.
         return_sam: Return the sam model object as well as the predictor.
         return_state: Return the full state of the checkpoint in addition to the predictor.

@@ -280,7 +292,7 @@ def get_custom_sam_model(
     custom_pickle = pickle
     custom_pickle.Unpickler = _CustomUnpickler

-    device = _get_device(device)
+    device = get_device(device)
     sam = sam_model_registry[model_type]()

     # load the model state, ignoring any attributes that can't be found by pickle
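The same pattern applies to custom checkpoints; the path below is a hypothetical fine-tuned checkpoint used only for illustration:

```python
from micro_sam.util import get_custom_sam_model

# "./finetuned_vit_b.pt" is a hypothetical checkpoint path for illustration.
predictor = get_custom_sam_model("./finetuned_vit_b.pt", model_type="vit_b", device="cpu")
```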
6 changes: 3 additions & 3 deletions test/test_instance_segmentation.py
@@ -39,7 +39,7 @@ def write_object(center, radius):

     @staticmethod
     def _get_model(image, model_type):
-        predictor = util.get_sam_model(model_type=model_type, device=util._get_device(None))
+        predictor = util.get_sam_model(model_type=model_type, device=util.get_device(None))
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         return predictor, image_embeddings

@@ -91,7 +91,7 @@ def test_tiled_automatic_mask_generator(self):
         from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

         # Release all unoccupied cached memory, tiling requires a lot of memory
-        device = util._get_device(None)
+        device = util.get_device(None)
         if device == "cuda":
             import torch.cuda
             torch.cuda.empty_cache()
@@ -158,7 +158,7 @@ def test_tiled_embedding_mask_generator(self):
         from micro_sam.instance_segmentation import _TiledEmbeddingMaskGenerator

         # Release all unoccupied cached memory, tiling requires a lot of memory
-        device = util._get_device(None)
+        device = util.get_device(None)
         if device == "cuda":
             import torch.cuda
             torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion test/test_prompt_based_segmentation.py
@@ -20,7 +20,7 @@ def _get_input(shape=(256, 256)):

     @staticmethod
     def _get_model(image, model_type):
-        predictor = util.get_sam_model(model_type=model_type, device=util._get_device(None))
+        predictor = util.get_sam_model(model_type=model_type, device=util.get_device(None))
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         util.set_precomputed(predictor, image_embeddings)
         return predictor
16 changes: 16 additions & 0 deletions test/test_util.py
@@ -3,6 +3,7 @@
 from shutil import rmtree

 import numpy as np
+import torch
 import zarr

 from skimage.data import binary_blobs
@@ -63,6 +64,21 @@ def test_segmentation_to_one_hot(self):

         self.assertTrue(np.allclose(mask, expected_mask))

+    def test_get_device(self):
+        from micro_sam.util import get_device
+
+        # check that calling get_device without an argument works
+        get_device()
+
+        # check that passing the device as a string works
+        device = get_device("cpu")
+        self.assertEqual(device, "cpu")
+
+        # check that passing the device as a torch.device works
+        device = get_device(torch.device("cpu"))
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertEqual(device.type, "cpu")
+

 if __name__ == "__main__":
     unittest.main()