diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py
index 5bbf9ca0..f8c64090 100644
--- a/micro_sam/evaluation/inference.py
+++ b/micro_sam/evaluation/inference.py
@@ -170,6 +170,7 @@ def get_predictor(
     return predictor
 
 
+# TODO batch this computation
 def precompute_all_embeddings(
     predictor: SamPredictor,
     image_paths: List[Union[str, os.PathLike]],
diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py
index 623bce8b..6050a946 100644
--- a/micro_sam/evaluation/instance_segmentation.py
+++ b/micro_sam/evaluation/instance_segmentation.py
@@ -12,6 +12,7 @@
 import pandas as pd
 
 from elf.evaluation import mean_segmentation_accuracy
+from elf.io import open_file
 from tqdm import tqdm
 
 from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
@@ -63,7 +64,6 @@ def default_grid_search_values_instance_segmentation_with_decoder(
     boundary_distance_threshold_values: Optional[List[float]] = None,
     distance_smoothing_values: Optional[List[float]] = None,
     min_size_values: Optional[List[float]] = None,
-
 ) -> Dict[str, List[float]]:
 
     if center_distance_threshold_values is None:
@@ -89,9 +89,29 @@ def default_grid_search_values_instance_segmentation_with_decoder(
     }
 
 
-def _grid_search(
-    segmenter, gs_combinations, gt, image_name, result_path, fixed_generate_kwargs, verbose,
-):
+def grid_search_iteration(
+    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
+    gs_combinations: List[Dict[str, Any]],
+    gt: np.ndarray,
+    image_name: str,
+    fixed_generate_kwargs: Dict[str, Any],
+    result_path: Optional[Union[str, os.PathLike]],
+    verbose: bool = False,
+) -> pd.DataFrame:
+    """Run one iteration of the grid search for a single image.
+
+    Args:
+        segmenter: The class implementing the instance segmentation functionality.
+        gs_combinations: The parameter combinations evaluated in the grid search.
+        gt: The ground-truth segmentation for the image.
+        image_name: The name of the image.
+        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
+        result_path: The path for saving the grid search results as csv.
+        verbose: Whether to run the grid search in verbose mode.
+
+    Returns:
+        The grid search results for this image as a dataframe.
+    """
     net_list = []
     for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
         generate_kwargs = gs_kwargs | fixed_generate_kwargs
@@ -111,16 +131,31 @@
     img_gs_df = pd.concat(net_list)
     img_gs_df.to_csv(result_path, index=False)
+    return img_gs_df
 
 
+def _load_image(path, key, roi):
+    if key is None:
+        im = imageio.imread(path)
+        if roi is not None:
+            im = im[roi]
+        return im
+    with open_file(path, "r") as f:
+        im = f[key][:] if roi is None else f[key][roi]
+    return im
+
+
 def run_instance_segmentation_grid_search(
     segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
     grid_search_values: Dict[str, List],
     image_paths: List[Union[str, os.PathLike]],
     gt_paths: List[Union[str, os.PathLike]],
-    embedding_dir: Union[str, os.PathLike],
     result_dir: Union[str, os.PathLike],
+    embedding_dir: Optional[Union[str, os.PathLike]],
     fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
     verbose_gs: bool = False,
+    image_key: Optional[str] = None,
+    gt_key: Optional[str] = None,
+    rois: Optional[List[Tuple[slice, ...]]] = None,
 ) -> None:
     """Run grid search for automatic mask generation.
 
@@ -144,10 +179,15 @@
         grid_search_values: The grid search values for parameters of the `generate` function.
         image_paths: The input images for the grid search.
         gt_paths: The ground-truth segmentation for the grid search.
-        embedding_dir: Folder to cache the image embeddings.
         result_dir: Folder to cache the evaluation results per image.
+        embedding_dir: Folder to cache the image embeddings. If None, the embeddings are not cached.
         fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
         verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
+        image_key: Key for loading the image data from a more complex file format like HDF5.
+            If not given, a simple image format like tif is assumed.
+        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
+            If not given, a simple image format like tif is assumed.
+        rois: Regions of interest to restrict the evaluation to, one per image.
     """
     assert len(image_paths) == len(gt_paths)
     fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
@@ -167,10 +207,10 @@
     ]
     os.makedirs(result_dir, exist_ok=True)
 
-    predictor = segmenter._predictor
+    predictor = getattr(segmenter, "_predictor", None)
 
-    for image_path, gt_path in tqdm(
-        zip(image_paths, gt_paths), desc="Run instance segmentation grid-search", total=len(image_paths)
+    for i, (image_path, gt_path) in tqdm(
+        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
     ):
         image_name = Path(image_path).stem
         result_path = os.path.join(result_dir, f"{image_name}.csv")
@@ -182,16 +222,20 @@
 
         assert os.path.exists(image_path), image_path
         assert os.path.exists(gt_path), gt_path
 
-        image = imageio.imread(image_path)
-        gt = imageio.imread(gt_path)
+        image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
+        gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])
 
-        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
-        image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
-        segmenter.initialize(image, image_embeddings)
+        if embedding_dir is None:
+            segmenter.initialize(image)
+        else:
+            assert predictor is not None
+            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
+            image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
+            segmenter.initialize(image, image_embeddings)
 
-        _grid_search(
+        grid_search_iteration(
             segmenter, gs_combinations, gt, image_name,
-            result_path=result_path, fixed_generate_kwargs=fixed_generate_kwargs, verbose=verbose_gs,
+            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
         )
 
diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index fe9e1bae..bb5e22ef 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -47,6 +47,7 @@ def __getitem__(self, index):
         return np.zeros(block_shape, dtype="float32")
 
 
+# FIXME remove the shape argument
 def mask_data_to_segmentation(
     masks: List[Dict[str, Any]],
     shape: Tuple[int, ...],
@@ -69,7 +70,9 @@
     """
     masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
 
-    segmentation = np.zeros(shape[:2], dtype="uint32")
+    # we could also get the shape from the crop box
+    shape = next(iter(masks))["segmentation"].shape
+    segmentation = np.zeros(shape, dtype="uint32")
 
     def require_numpy(mask):
         return mask.cpu().numpy() if torch.is_tensor(mask) else mask
@@ -872,16 +875,32 @@
     def _to_masks(self, segmentation, output_mode):
         if output_mode != "binary_mask":
             raise NotImplementedError
+
         props = regionprops(segmentation)
-        crop_box = [0, segmentation.shape[1], 0, segmentation.shape[0]]
+        ndim = segmentation.ndim
+        assert ndim in (2, 3)
+
+        shape = segmentation.shape
+        if ndim == 2:
+            crop_box = [0, shape[1], 0, shape[0]]
+        else:
+            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
 
         # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
-        def to_bbox(bbox):
+        def to_bbox_2d(bbox):
             y0, x0 = bbox[0], bbox[1]
             w = bbox[3] - x0
             h = bbox[2] - y0
             return [x0, w, y0, h]
+        def to_bbox_3d(bbox):
+            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
+            w = bbox[5] - x0
+            h = bbox[4] - y0
+            d = bbox[3] - z0
+            return [x0, w, y0, h, z0, d]
+
+        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
 
         masks = [
             {
                 "segmentation": segmentation == prop.label,
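
Usage sketch for the extended grid-search API above (untested; the paths, dataset keys, and ROI are placeholder values, and `segmenter` is assumed to be an already constructed AMG or `InstanceSegmentationWithDecoder` instance):

    from micro_sam.evaluation.instance_segmentation import (
        default_grid_search_values_instance_segmentation_with_decoder,
        run_instance_segmentation_grid_search,
    )

    grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
    rois = [(slice(0, 512), slice(0, 512))]  # one ROI per image (placeholder)

    run_instance_segmentation_grid_search(
        segmenter,
        grid_search_values,
        image_paths=["/data/images/im0.h5"],  # placeholder path
        gt_paths=["/data/labels/im0.h5"],     # placeholder path
        result_dir="./gs_results",
        embedding_dir=None,  # None skips caching the embeddings
        image_key="raw",     # dataset key inside the HDF5 file (placeholder)
        gt_key="labels",     # placeholder
        rois=rois,
    )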
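
And a minimal standalone check of the bbox convention that `to_bbox_3d` relies on: `skimage.measure.regionprops` reports the 3D bounding box as (z0, y0, x0, z1, y1, x1), which is why the depth is computed as bbox[3] - z0:

    import numpy as np
    from skimage.measure import regionprops

    # a single labeled block in a small zyx volume
    seg = np.zeros((4, 8, 8), dtype="uint32")
    seg[1:3, 2:6, 3:7] = 1

    z0, y0, x0, z1, y1, x1 = regionprops(seg)[0].bbox  # (1, 2, 3, 3, 6, 7)
    # SAM-style box [x0, w, y0, h, z0, d], as produced by to_bbox_3d:
    print([x0, x1 - x0, y0, y1 - y0, z0, z1 - z0])  # prints [3, 4, 2, 4, 1, 2]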