Skip to content

Commit

Permalink
Minor update to support peft kwargs in qualitative comparsion scripts (
Browse files Browse the repository at this point in the history
…#845)

Make qualitative comparison scripts flexible!
  • Loading branch information
anwai98 authored Jan 29, 2025
1 parent d930618 commit 031d9fb
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions micro_sam/evaluation/model_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Optional, Union
from typing import Optional, Union, Dict, Any

import h5py
import numpy as np
Expand Down Expand Up @@ -124,6 +124,9 @@ def generate_data_for_model_comparison(
checkpoint1: Optional[Union[str, os.PathLike]] = None,
checkpoint2: Optional[Union[str, os.PathLike]] = None,
checkpoint3: Optional[Union[str, os.PathLike]] = None,
peft_kwargs1: Optional[Dict[str, Any]] = None,
peft_kwargs2: Optional[Dict[str, Any]] = None,
peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
"""Generate samples for qualitative model comparison.
Expand All @@ -149,11 +152,11 @@ def generate_data_for_model_comparison(
get_box_prompts=True,
)

predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1)
predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2)
predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)

if model_type3 is not None:
predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3)
predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)
else:
predictor3 = None

Expand Down Expand Up @@ -262,8 +265,8 @@ def _overlay_points(im, prompt, radius):


def _compare_eval(
f, eval_result, advantage_column, n_images_per_sample, prefix,
sample_name, plot_folder, point_radius, outline_dilation, have_model3,
f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name,
plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
n_rows = result.shape[0]
Expand Down Expand Up @@ -313,7 +316,11 @@ def plot_ax(axis, i, row):
else:
prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])

im = _enhance_image(image[bb])
if enhance_image:
im = _enhance_image(image[bb])
else:
im = image[bb]

gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb]

if have_model3:
Expand Down Expand Up @@ -364,7 +371,7 @@ def plot_ax(axis, i, row):

def _compare_prompts(
f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
point_radius, outline_dilation, have_model3,
point_radius, outline_dilation, have_model3, enhance_image,
):
box_eval = _evaluate_samples(f, prefix, min_size)
if plot_folder is None:
Expand All @@ -376,16 +383,16 @@ def _compare_prompts(
os.makedirs(plot_folder2, exist_ok=True)
_compare_eval(
f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1,
point_radius, outline_dilation, have_model3,
point_radius, outline_dilation, have_model3, enhance_image,
)
_compare_eval(
f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2,
point_radius, outline_dilation, have_model3,
point_radius, outline_dilation, have_model3, enhance_image,
)


def _compare_models(
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
sample_name = Path(path).stem
with h5py.File(path, "r") as f:
Expand All @@ -396,11 +403,11 @@ def _compare_models(
plot_folder_box = os.path.join(plot_folder, "box")
_compare_prompts(
f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points,
point_radius, outline_dilation, have_model3,
point_radius, outline_dilation, have_model3, enhance_image,
)
_compare_prompts(
f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box,
point_radius, outline_dilation, have_model3,
point_radius, outline_dilation, have_model3, enhance_image,
)


Expand All @@ -412,6 +419,7 @@ def model_comparison(
point_radius: int = 4,
outline_dilation: int = 0,
have_model3=False,
enhance_image=True,
) -> None:
"""Create images for a qualitative model comparision.
Expand All @@ -422,11 +430,13 @@ def model_comparison(
plot_folder: The folder where to save the plots. If not given the plots will be displayed.
point_radius: The radius of the point overlay.
outline_dilation: The dilation factor of the outline overlay.
enhance_image: Whether to enhance the input image.
"""
files = glob(os.path.join(output_folder, "*.h5"))
for path in tqdm(files):
_compare_models(
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
path, n_images_per_sample, min_size, plot_folder, point_radius,
outline_dilation, have_model3, enhance_image,
)


Expand Down

0 comments on commit 031d9fb

Please sign in to comment.