Commit

update to work with new compiler. Also upgraded to sam2.1
Alex Overbeek committed Feb 6, 2025
1 parent a518e58 commit a94ed7e
Showing 61 changed files with 9,809 additions and 262 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -9,4 +9,4 @@ _C.*
output/*
checkpoints/*.pt
GroundingDINO/weights/*.pth
.whl
*.whl
10 changes: 10 additions & 0 deletions README.md
@@ -112,8 +112,18 @@ echo $CUDA_HOME
3. Install the Nakama Pyzed Wrapper as a package or clone it into your project:
[Nakama Pyzed Wrapper](https://bitbucket.org/ctw-bw/nakama_pyzed_wrapper/src/master/)


**NOTE:** Neural depth can produce errors when the model is first optimized at runtime, most often when working in virtual environments. In that situation, your best option is to manually force the installation of the neural depth model.

Follow this link for the manual installation commands:
https://support.stereolabs.com/hc/en-us/articles/9747407795223-How-can-I-optimize-the-ZED-SDK-AI-models-manually

Additionally, there may be old models in `/usr/local/zed/resources` that need to be removed; delete them and then try optimizing the models again.
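
As a purely illustrative aid (not part of this repository), the sketch below lists whatever files currently sit in the ZED resources folder so you can review them before deleting; the path comes from the note above, and removing files there may require sudo:

```python
from pathlib import Path

# Assumed default ZED SDK install location (see note above); adjust if yours differs.
ZED_RESOURCES = Path("/usr/local/zed/resources")

if ZED_RESOURCES.is_dir():
    for model_file in sorted(p for p in ZED_RESOURCES.iterdir() if p.is_file()):
        print(model_file)  # review these, remove stale optimized models, then re-optimize
else:
    print(f"{ZED_RESOURCES} not found; is the ZED SDK installed?")
```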
---



## ZED Camera Prediction

The main pipeline code is located at:
60 changes: 44 additions & 16 deletions checkpoints/download_ckpts.sh
@@ -6,26 +6,54 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Use either wget or curl to download the checkpoints
if command -v wget &> /dev/null; then
CMD="wget"
elif command -v curl &> /dev/null; then
CMD="curl -L -O"
else
echo "Please install wget or curl to download the checkpoints."
exit 1
fi

# Define the URLs for SAM 2 checkpoints
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"

# Define the URLs for the checkpoints
BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
sam2_hiera_t_url="${BASE_URL}sam2_hiera_tiny.pt"
sam2_hiera_s_url="${BASE_URL}sam2_hiera_small.pt"
sam2_hiera_b_plus_url="${BASE_URL}sam2_hiera_base_plus.pt"
sam2_hiera_l_url="${BASE_URL}sam2_hiera_large.pt"
# Download each of the four checkpoints using wget
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }

# echo "Downloading sam2_hiera_small.pt checkpoint..."
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }

# Download each of the four checkpoints using wget
echo "Downloading sam2_hiera_tiny.pt checkpoint..."
wget $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }

# echo "Downloading sam2_hiera_large.pt checkpoint..."
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }

# Define the URLs for SAM 2.1 checkpoints
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"

# SAM 2.1 checkpoints
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }

echo "Downloading sam2_hiera_small.pt checkpoint..."
wget $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
echo "Downloading sam2.1_hiera_small.pt checkpoint..."
$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }

echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
wget $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }

echo "Downloading sam2_hiera_large.pt checkpoint..."
wget $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
echo "Downloading sam2.1_hiera_large.pt checkpoint..."
$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }

echo "All checkpoints are downloaded successfully."
echo "All checkpoints are downloaded successfully."
23 changes: 13 additions & 10 deletions configurations/sam2_configs/sam2_hiera_s.yaml
@@ -92,6 +92,7 @@ model:
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
@@ -100,7 +101,9 @@ model:
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
@@ -112,13 +115,13 @@ model:
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true

# # Compilation flag
# compile_image_encoder: false
# Compilation flag
# compile_image_encoder: False

# compilation flags
compile_image_encoder: True
compile_memory_encoder: True
compile_memory_attention: True
compile_prompt_encoder: False
compile_mask_decoder: False

# compilation settings
compile_image_encoder: true
compile_memory_encoder: false
compile_memory_attention: false
compile_prompt_encoder: false
compile_mask_decoder: false
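
For context on the commit message ("update to work with new compiler"): flags such as `compile_image_encoder` typically control whether a submodule is wrapped with `torch.compile` when the model is assembled. A minimal sketch of that mechanism with a hypothetical stand-in module, not the actual SAM 2 builder:

```python
import torch
import torch.nn as nn

compile_image_encoder = True  # mirrors the flag in the YAML above (hypothetical wiring)

# Stand-in for the real image encoder.
image_encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())

if compile_image_encoder:
    # torch.compile returns an optimized wrapper; compilation happens on the first call.
    image_encoder = torch.compile(image_encoder)

out = image_encoder(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 16, 64, 64])
```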
7 changes: 4 additions & 3 deletions configurations/sam2_zed_small.yaml
@@ -10,11 +10,12 @@ grounding_dino:
checkpoint_path: "./GroundingDINO/weights/groundingdino_swint_ogc.pth"

sam2:
checkpoint: "./checkpoints/sam2_hiera_small.pt"
# vos_optimized: true
checkpoint: "./checkpoints/sam2.1_hiera_small.pt"
model_cfg: "sam2_configs/sam2_hiera_s.yaml"

camera:
connection_type: "svo"
connection_type: "id"
serial_number: 0
svo_input_filename: "./output/output.svo2"
sender_ip: "127.0.0.1"
@@ -29,7 +30,7 @@ camera:


depth:
refine_depth: true
refine_depth: false
max_occlusion_percentage: 0.6

results:
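
For reference, a standalone sketch of reading the fields changed above with PyYAML; the repository presumably has its own config loader, so this is only illustrative:

```python
import yaml  # PyYAML

with open("configurations/sam2_zed_small.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["sam2"]["checkpoint"])         # ./checkpoints/sam2.1_hiera_small.pt
print(cfg["camera"]["connection_type"])  # now "id" instead of "svo"
print(cfg["depth"]["refine_depth"])      # now False
```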
21 changes: 0 additions & 21 deletions configurations/sam2_zed_tiny.yaml

This file was deleted.

Binary file modified output/norm_depth.png
Binary file modified output/refined_depth.png
Binary file modified output/test.png
4 changes: 3 additions & 1 deletion sam2/__init__.py
@@ -5,5 +5,7 @@
# LICENSE file in the root directory of this source tree.

from hydra import initialize_config_module
from hydra.core.global_hydra import GlobalHydra

initialize_config_module("sam2_configs", version_base="1.2")
if not GlobalHydra.instance().is_initialized():
    initialize_config_module("sam2", version_base="1.2")
183 changes: 183 additions & 0 deletions sam2/benchmark.py
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm
import cv2

from sam2.build_sam import build_sam2_camera_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t_512"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_camera_predictor(
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)

#################################
cap = cv2.VideoCapture("notebooks/videos/aquarium/aquarium.mp4")
num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
ret, frame = cap.read()
width, height = frame.shape[:2][::-1]

predictor.load_first_frame(frame)
if_init = True

using_point = False  # if True, use a point prompt
using_box = True  # if True, use a box prompt
using_mask = False  # if True, use a mask prompt

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)
# Let's add a positive click at (x, y) = (210, 350) to get started

# using point prompt
points = np.array([[670, 247]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], dtype=np.int32)
bbox = np.array([[600, 214], [765, 286]], dtype=np.float32)

frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


if using_point:
    _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

elif using_box:
    _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
        frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
    )

elif using_mask:
    mask_img_path = "masks/aquarium/aquarium_mask.png"
    mask = cv2.imread(mask_img_path, cv2.IMREAD_GRAYSCALE)
    mask = mask / 255

    _, out_obj_ids, out_mask_logits = predictor.add_new_mask(
        frame_idx=ann_frame_idx, obj_id=ann_obj_id, mask=mask
    )

vis_gap = 30

# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
total, count = 0, 0
torch.cuda.empty_cache()

with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            cap = cv2.VideoCapture("notebooks/videos/aquarium/aquarium.mp4")
            start = time.time()
            for _ in tqdm(range(int(num_frames)), desc="Tracking", leave=False, total=num_frames):
                ret, frame = cap.read()
                ann_frame_idx += 1
                if not ret:
                    break
                width, height = frame.shape[:2][::-1]

                out_obj_ids, out_mask_logits = predictor.track(frame)

            cap.release()
            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)



##########################################

# # Initialize with video
# video_dir = "notebooks/videos/bedroom"
# # scan all the JPEG frame names in this directory
# frame_names = [
# p
# for p in os.listdir(video_dir)
# if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
# ]
# frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# inference_state = predictor.init_state(video_path=video_dir)


# # Number of runs, warmup etc
# warm_up, runs = 5, 25
# verbose = True
# num_frames = len(frame_names)
# total, count = 0, 0
# torch.cuda.empty_cache()

# # We will select an object with a click.
# # See video_predictor_example.ipynb for more detailed explanation
# ann_frame_idx, ann_obj_id = 0, 1
# # Add a positive click at (x, y) = (210, 350)
# # For labels, `1` means positive click
# points = np.array([[210, 350]], dtype=np.float32)
# labels = np.array([1], np.int32)

# _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
# inference_state=inference_state,
# frame_idx=ann_frame_idx,
# obj_id=ann_obj_id,
# points=points,
# labels=labels,
# )

# # Warmup and then average FPS over several runs
# with torch.autocast("cuda", torch.bfloat16):
# with torch.inference_mode():
# for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
# start = time.time()
# # Start tracking
# for (
# out_frame_idx,
# out_obj_ids,
# out_mask_logits,
# ) in predictor.propagate_in_video(inference_state):
# pass

# end = time.time()
# total += end - start
# count += 1
# if i == warm_up - 1:
# print("Warmup FPS: ", count * num_frames / total)
# total = 0
# count = 0

# print("FPS: ", count * num_frames / total)
