Skip to content

Commit

Permalink
Make grid search functionality more flexible WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jan 2, 2024
1 parent ae46a20 commit be4a296
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 19 deletions.
1 change: 1 addition & 0 deletions micro_sam/evaluation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def get_predictor(
return predictor


# TODO batch this computation
def precompute_all_embeddings(
predictor: SamPredictor,
image_paths: List[Union[str, os.PathLike]],
Expand Down
66 changes: 50 additions & 16 deletions micro_sam/evaluation/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd

from elf.evaluation import mean_segmentation_accuracy
from elf.io import open_file
from tqdm import tqdm

from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
Expand Down Expand Up @@ -63,7 +64,6 @@ def default_grid_search_values_instance_segmentation_with_decoder(
boundary_distance_threshold_values: Optional[List[float]] = None,
distance_smoothing_values: Optional[List[float]] = None,
min_size_values: Optional[List[float]] = None,

) -> Dict[str, List[float]]:

if center_distance_threshold_values is None:
Expand All @@ -89,9 +89,18 @@ def default_grid_search_values_instance_segmentation_with_decoder(
}


def _grid_search(
segmenter, gs_combinations, gt, image_name, result_path, fixed_generate_kwargs, verbose,
):
# TODO update all arguments and description
def grid_search_iteration(
segmenter,
gs_combinations,
gt: np.ndarray,
image_name: str,
fixed_generate_kwargs,
result_path: Optional[Union[str, os.PathLike]],
verbose: bool = False,
) -> pd.DataFrame:
"""
"""
net_list = []
for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
generate_kwargs = gs_kwargs | fixed_generate_kwargs
Expand All @@ -111,16 +120,32 @@ def _grid_search(
img_gs_df = pd.concat(net_list)
img_gs_df.to_csv(result_path, index=False)

return img_gs_df


def _load_image(path, key, roi):
    """Load image data from `path`, optionally restricted to a region.

    Args:
        path: Location of the file to read.
        key: Internal dataset key. If None, the file is read via `imageio.imread`;
            otherwise it is opened with `open_file` and the dataset stored
            under `key` is read.
        roi: Index expression applied to the loaded data, or None to load everything.

    Returns:
        The loaded data, cropped to `roi` if one was given.
    """
    if key is not None:
        # Container format: index the dataset directly so that only the
        # requested selection is read from the file.
        with open_file(path, "r") as container:
            dataset = container[key]
            if roi is None:
                return dataset[:]
            return dataset[roi]

    # Plain image format: read the whole image and crop afterwards.
    data = imageio.imread(path)
    return data if roi is None else data[roi]


def run_instance_segmentation_grid_search(
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
grid_search_values: Dict[str, List],
image_paths: List[Union[str, os.PathLike]],
gt_paths: List[Union[str, os.PathLike]],
embedding_dir: Union[str, os.PathLike],
result_dir: Union[str, os.PathLike],
embedding_dir: Optional[Union[str, os.PathLike]],
fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
verbose_gs: bool = False,
image_key: Optional[str] = None,
gt_key: Optional[str] = None,
rois: Optional[Tuple[slice, ...]] = None,
) -> None:
"""Run grid search for automatic mask generation.
Expand All @@ -144,10 +169,15 @@ def run_instance_segmentation_grid_search(
grid_search_values: The grid search values for parameters of the `generate` function.
image_paths: The input images for the grid search.
gt_paths: The ground-truth segmentation for the grid search.
embedding_dir: Folder to cache the image embeddings.
result_dir: Folder to cache the evaluation results per image.
embedding_dir: Folder to cache the image embeddings.
fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
image_key: Key for loading the image data from a more complex file format like HDF5.
If not given a simple image format like tif is assumed.
gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
If not given a simple image format like tif is assumed.
rois: Regions of interest to restrict the evaluation to.
"""
assert len(image_paths) == len(gt_paths)
fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
Expand All @@ -167,10 +197,10 @@ def run_instance_segmentation_grid_search(
]

os.makedirs(result_dir, exist_ok=True)
predictor = segmenter._predictor
predictor = getattr(segmenter, "_predictor", None)

for image_path, gt_path in tqdm(
zip(image_paths, gt_paths), desc="Run instance segmentation grid-search", total=len(image_paths)
for i, (image_path, gt_path) in tqdm(
enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
):
image_name = Path(image_path).stem
result_path = os.path.join(result_dir, f"{image_name}.csv")
Expand All @@ -182,16 +212,20 @@ def run_instance_segmentation_grid_search(
assert os.path.exists(image_path), image_path
assert os.path.exists(gt_path), gt_path

image = imageio.imread(image_path)
gt = imageio.imread(gt_path)
image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])

embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
segmenter.initialize(image, image_embeddings)
if embedding_dir is None:
segmenter.initialize(image)
else:
assert predictor is not None
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
segmenter.initialize(image, image_embeddings)

_grid_search(
grid_search_iteration(
segmenter, gs_combinations, gt, image_name,
result_path=result_path, fixed_generate_kwargs=fixed_generate_kwargs, verbose=verbose_gs,
fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
)


Expand Down
25 changes: 22 additions & 3 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __getitem__(self, index):
return np.zeros(block_shape, dtype="float32")


# FIXME remove the shape argument
def mask_data_to_segmentation(
masks: List[Dict[str, Any]],
shape: Tuple[int, ...],
Expand All @@ -69,7 +70,9 @@ def mask_data_to_segmentation(
"""

masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
segmentation = np.zeros(shape[:2], dtype="uint32")
# we could also get the shape from the crop box
shape = next(iter(masks))["segmentation"].shape
segmentation = np.zeros(shape, dtype="uint32")

def require_numpy(mask):
return mask.cpu().numpy() if torch.is_tensor(mask) else mask
Expand Down Expand Up @@ -872,16 +875,32 @@ def initialize(
def _to_masks(self, segmentation, output_mode):
if output_mode != "binary_mask":
raise NotImplementedError

props = regionprops(segmentation)
crop_box = [0, segmentation.shape[1], 0, segmentation.shape[0]]
ndim = segmentation.ndim
assert ndim in (2, 3)

shape = segmentation.shape
if ndim == 2:
crop_box = [0, shape[1], 0, shape[0]]
else:
crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]

# go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
def to_bbox(bbox):
def to_bbox_2d(bbox):
    """Reorder a 2D box from [y0, x0, y1, x1] to [x0, w, y0, h]."""
    min_row, min_col = bbox[0], bbox[1]
    max_row, max_col = bbox[2], bbox[3]
    width = max_col - min_col
    height = max_row - min_row
    return [min_col, width, min_row, height]

def to_bbox_3d(bbox):
    """Reorder a 3D box from [z0, y0, x0, z1, y1, x1] to [x0, w, y0, h, z0, d].

    Args:
        bbox: Sequence of six values: the minimum corner (z0, y0, x0)
            followed by the maximum corner (z1, y1, x1).

    Returns:
        A list [x0, w, y0, h, z0, d], where w, h and d are the extents
        along x, y and z respectively.
    """
    z0, y0, x0 = bbox[0], bbox[1], bbox[2]
    w = bbox[5] - x0
    h = bbox[4] - y0
    # The depth is measured along z, so it must be relative to z0
    # (the previous version subtracted y0 here).
    d = bbox[3] - z0
    return [x0, w, y0, h, z0, d]

to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
masks = [
{
"segmentation": segmentation == prop.label,
Expand Down

0 comments on commit be4a296

Please sign in to comment.