Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/vista3d' into vista3d-export
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed Jul 31, 2024
2 parents 3e4a84b + 24567b9 commit a45deed
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 29 deletions.
34 changes: 32 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,45 @@ The **VISTA3D** is a foundation model trained systematically on 11,454 volumes e
### Out-of box automatic segmentation
For supported 127 classes, the model can perform highly accurate out-of-box segmentation. The fully automated process adopts a patch-based sliding-window inference and only requires a class prompt.
Compared to supervised segmentation models trained on each dataset separately, VISTA3D showed comparable out-of-box performances and strong generalizability ('VISTA3D auto' in Table.1).
<!-- <div align="center"> <img src="" width="800"/> </div> -->
<!-- <div align="center"> <img src="assets/imgs/everything.gif" width="800"/> </div> -->
<div align="center">
<figure>
<img
src="assets/imgs/everything.gif">
<figcaption> NIM Demo supports "Segment Everything" </figcaption>
</figure>
</div>



### Interactive editing
The interactive segmentation is based on user-provided clicks. Each click point will impact a local 3D patch. User can either effectively refine the automatic results with clicks ('VISTA3D auto+point' in Table.1) or simply provide a click without specifying the target class ('VISTA3D point' in Table.1) .
<!-- <div align="center"> <img src="" width="800"/> </div> -->
<div align="center">
<figure>
<img
src="assets/imgs/liver.gif">
<figcaption> Specify a supported class and edit the automatic results </figcaption>
</figure>
</div>
<div align="center">
<figure>
<img
src="assets/imgs/unspecified.gif">
<figcaption> Interactive supported class segmentation without specifying class </figcaption>
</figure>
</div>

### Zero-shot interactive segmentation
VISTA3D is built to produce visually plausible segmentations on previously unseen classes.
This capability makes the model even more flexible and accelerates practical segmentation data curation processes.
<div align="center">
<figure>
<img
src="assets/imgs/zeroshot.gif">
<figcaption> Add a new unseen class and do annotation </figcaption>
</figure>
</div>

### Fine-tuning
VISTA3D checkpoint showed improvements when finetuning in few-shot settings. Once a few annotated examples are provided, user can start finetune with the VISTA3D checkpoint.
Expand Down Expand Up @@ -98,7 +128,7 @@ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://githu

## License

The codebase is under Apache 2.0 Licence.
The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license.

## Reference

Expand Down
Binary file added assets/imgs/everything.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/imgs/liver.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/imgs/unspecified.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/imgs/zeroshot.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 4 additions & 1 deletion scripts/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,16 @@ def on_button_click(event, ax=ax):
print("-- segmenting ---")
self.generate_mask()
print("-- done ---")
print("-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---")
print("-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---")
print("-- Note: CTRL + Right Click will be adding negative points. ---")
print(
"-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---"
)
print(
"-- Note: Click points not matching class prompts will also cause confusion. ---"
)
print("-- Note: CTRL + Right Click will be adding negative points. ---")

self.update_slice(ax)
# self.point_start = len(self.clicked_points)

Expand Down
5 changes: 3 additions & 2 deletions vista3d/modeling/segresnetds.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _forward(self, x: torch.Tensor) -> list[torch.Tensor]:

if self.head_module is not None:
outputs = self.head_module(outputs)

return outputs

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
Expand Down Expand Up @@ -463,7 +464,7 @@ def is_valid_shape(self, x):

def _forward(
self, x: torch.Tensor, with_point, with_label
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
if self.preprocess is not None:
x = self.preprocess(x)

Expand Down Expand Up @@ -522,7 +523,7 @@ def _forward(
return outputs, outputs_auto

def forward(
self, x: torch.Tensor, with_point=True, with_label=True, # **kwargs
self, x: torch.Tensor, with_point=True, with_label=True, **kwargs
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return self._forward(x, with_point, with_label)

Expand Down
27 changes: 3 additions & 24 deletions vista3d/modeling/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import torch
import torch.nn as nn
from monai.utils import optional_import
import time

from scripts.utils.trans_utils import convert_points_to_disc
from scripts.utils.trans_utils import get_largest_connected_component_mask as lcc
Expand All @@ -42,8 +41,7 @@ def __init__(self, image_encoder, class_head, point_head, feature_size):
)
self.auto_freeze = False
self.point_freeze = False
self.engine = None


def precompute_embedding(self, input_images):
"""precompute image embedding, require sliding window inference"""
raise NotImplementedError
Expand Down Expand Up @@ -205,8 +203,6 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False):
param.requires_grad = not point_freeze
self.point_freeze = point_freeze



def forward(
self,
input_images,
Expand Down Expand Up @@ -249,8 +245,6 @@ def forward(
val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
"""
time00 = time.time()

image_size = input_images.shape[-3:]
device = input_images.device
if point_coords is None and class_vector is None:
Expand Down Expand Up @@ -313,16 +307,13 @@ def forward(
enable_all_tactics=True
)

time0 = time.time()
out, out_auto = self.image_encoder(
x=input_images,
with_point=point_coords is not None,
with_label=class_vector is not None,
)
# torch.cuda.synchronize()
# time1 = time.time()
# print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}")
input_images = None
input_images = None

# force releasing memories that set to None
torch.cuda.empty_cache()
if class_vector is not None:
Expand All @@ -333,19 +324,12 @@ def forward(
dynamo=False,
verbose=False,
)
# time2 = time.time()
logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
# torch.cuda.synchronize()
# print(f"Class Head Time: {time.time() - time2}")

if point_coords is not None:
# time3 = time.time()
point_logits = self.point_head(
out, point_coords, point_labels, class_vector=prompt_class
)
# torch.cuda.synchronize()
# print(f"Point Head Time: {time.time() - time3}")
# time4 = time.time()
if patch_coords is None:
logits = self.gaussian_combine(
logits,
Expand All @@ -360,8 +344,6 @@ def forward(
logits = self.connected_components_combine(
logits, point_logits, point_coords, point_labels, mapping_index
)
# torch.cuda.synchronize()
# print(f"Combine Time: {time.time() - time4}")
else:
logits = NINF_VALUE + torch.zeros(
[bs, 1, *image_size], device=device, dtype=out.dtype
Expand All @@ -378,9 +360,6 @@ def forward(
mapping_index,
)

# torch.cuda.synchronize()
# print(f"Total time : {time.time() - time00} shape : {logits.shape}")

if kwargs.get("keep_cache", False) and class_vector is None:
self.image_embeddings = out.detach()
return logits

0 comments on commit a45deed

Please sign in to comment.