Skip to content

Commit

Permalink
Fix reference info path for zero-shot learning (#3354)
Browse files Browse the repository at this point in the history
* Apply #3265

* Updates for current develop

* Fix types, docstring, and comments

* Fix unit tests
  • Loading branch information
sungchul2 authored Apr 22, 2024
1 parent 313e537 commit 0388471
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 156 deletions.
50 changes: 34 additions & 16 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging as log
import os
import pickle
from collections import defaultdict
from copy import deepcopy
from itertools import product
Expand Down Expand Up @@ -625,15 +625,16 @@ def _decide_cascade_results(
class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel):
"""Zero-Shot Visual Prompting model."""

def __init__(
def __init__( # noqa: PLR0913
self,
backbone: Literal["tiny_vit", "vit_b"],
label_info: LabelInfoTypes = NullLabelInfo(),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = VisualPromptingMetricCallable,
torch_compile: bool = False,
root_reference_info: Path | str = "vpm_zsl_reference_infos",
reference_info_dir: Path | str = "reference_infos",
infer_reference_info_root: Path | str = "../.latest/train",
save_outputs: bool = True,
pixel_mean: list[float] | None = [123.675, 116.28, 103.53], # noqa: B006
pixel_std: list[float] | None = [58.395, 57.12, 57.375], # noqa: B006
Expand Down Expand Up @@ -669,7 +670,8 @@ def __init__(
)

self.save_outputs = save_outputs
self.root_reference_info: Path = Path(root_reference_info)
self.reference_info_dir: Path = Path(reference_info_dir)
self.infer_reference_info_root: Path = Path(infer_reference_info_root)

self.register_buffer("pixel_mean", Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", Tensor(pixel_std).view(-1, 1, 1), False)
Expand Down Expand Up @@ -877,21 +879,37 @@ def initialize_reference_info(self) -> None:
self.register_buffer("reference_feats", torch.zeros(0, 1, self.model.embed_dim), False)
self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), False)

def _find_latest_reference_info(self, root: Path) -> str | None:
"""Find latest reference info to be used."""
if not Path.is_dir(root):
return None
if len(stamps := sorted(os.listdir(root), reverse=True)) > 0:
return stamps[0]
return None
def save_reference_info(self, default_root_dir: Path | str) -> None:
"""Save reference info."""
reference_info = {
"reference_feats": self.reference_feats,
"used_indices": self.used_indices,
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
path_reference_info.parent.mkdir(parents=True, exist_ok=True)
# TODO (sungchul): ticket no. 139210
torch.save(reference_info, path_reference_info)
pickle.dump(
{k: v.numpy() for k, v in reference_info.items()},
path_reference_info.with_suffix(".pickle").open("wb"),
)
log.info(f"Saved reference info at {path_reference_info}.")

def load_latest_reference_info(self, device: str | torch.device = "cpu") -> bool:
def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
"""Load latest reference info to be used."""
if (latest_stamp := self._find_latest_reference_info(self.root_reference_info)) is not None:
latest_reference_info = self.root_reference_info / latest_stamp / "reference_info.pt"
reference_info = torch.load(latest_reference_info)
_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)

if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
).is_file():
reference_info = torch.load(path_reference_info)
retval = True
log.info(f"reference info saved at {latest_reference_info} was successfully loaded.")
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
else:
reference_info = {}
retval = False
Expand Down
Loading

0 comments on commit 0388471

Please sign in to comment.