From 031d9fb1967de4738ffd7ce8ff35e766e93a6bfc Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 29 Jan 2025 14:56:08 +0100 Subject: [PATCH] Minor update to support peft kwargs in qualitative comparsion scripts (#845) Make qualitative comparison scripts flexible! --- micro_sam/evaluation/model_comparison.py | 38 +++++++++++++++--------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index fafe7a6f..d401f79d 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -6,7 +6,7 @@ from tqdm import tqdm from pathlib import Path from functools import partial -from typing import Optional, Union +from typing import Optional, Union, Dict, Any import h5py import numpy as np @@ -124,6 +124,9 @@ def generate_data_for_model_comparison( checkpoint1: Optional[Union[str, os.PathLike]] = None, checkpoint2: Optional[Union[str, os.PathLike]] = None, checkpoint3: Optional[Union[str, os.PathLike]] = None, + peft_kwargs1: Optional[Dict[str, Any]] = None, + peft_kwargs2: Optional[Dict[str, Any]] = None, + peft_kwargs3: Optional[Dict[str, Any]] = None, ) -> None: """Generate samples for qualitative model comparison. @@ -149,11 +152,11 @@ def generate_data_for_model_comparison( get_box_prompts=True, ) - predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1) - predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2) + predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1) + predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2) if model_type3 is not None: - predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3) + predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3) else: predictor3 = None @@ -262,8 +265,8 @@ def _overlay_points(im, prompt, radius): def _compare_eval( - f, eval_result, advantage_column, n_images_per_sample, prefix, - sample_name, plot_folder, point_radius, outline_dilation, have_model3, + f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name, + plot_folder, point_radius, outline_dilation, have_model3, enhance_image, ): result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample] n_rows = result.shape[0] @@ -313,7 +316,11 @@ def plot_ax(axis, i, row): else: prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"]) - im = _enhance_image(image[bb]) + if enhance_image: + im = _enhance_image(image[bb]) + else: + im = image[bb] + gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb] if have_model3: @@ -364,7 +371,7 @@ def plot_ax(axis, i, row): def _compare_prompts( f, prefix, n_images_per_sample, min_size, sample_name, plot_folder, - point_radius, outline_dilation, have_model3, + point_radius, outline_dilation, have_model3, enhance_image, ): box_eval = _evaluate_samples(f, prefix, min_size) if plot_folder is None: @@ -376,16 +383,16 @@ def _compare_prompts( os.makedirs(plot_folder2, exist_ok=True) _compare_eval( f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1, - point_radius, outline_dilation, have_model3, + point_radius, outline_dilation, have_model3, enhance_image, ) _compare_eval( f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2, - point_radius, outline_dilation, have_model3, + point_radius, outline_dilation, have_model3, enhance_image, ) def _compare_models( - path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, + path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image, ): sample_name = Path(path).stem with h5py.File(path, "r") as f: @@ -396,11 +403,11 @@ def _compare_models( plot_folder_box = os.path.join(plot_folder, "box") _compare_prompts( f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points, - point_radius, outline_dilation, have_model3, + point_radius, outline_dilation, have_model3, enhance_image, ) _compare_prompts( f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box, - point_radius, outline_dilation, have_model3, + point_radius, outline_dilation, have_model3, enhance_image, ) @@ -412,6 +419,7 @@ def model_comparison( point_radius: int = 4, outline_dilation: int = 0, have_model3=False, + enhance_image=True, ) -> None: """Create images for a qualitative model comparision. @@ -422,11 +430,13 @@ def model_comparison( plot_folder: The folder where to save the plots. If not given the plots will be displayed. point_radius: The radius of the point overlay. outline_dilation: The dilation factor of the outline overlay. + enhance_image: Whether to enhance the input image. """ files = glob(os.path.join(output_folder, "*.h5")) for path in tqdm(files): _compare_models( - path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, + path, n_images_per_sample, min_size, plot_folder, point_radius, + outline_dilation, have_model3, enhance_image, )