feat(uav): SAM based segmentation to fix too large boxes on some localizations
danellecline committed Jan 22, 2025
1 parent effd596 commit 69e03f0
Showing 2 changed files with 221 additions and 0 deletions.
36 changes: 36 additions & 0 deletions aipipeline/projects/uav/load_sam_loc.py
@@ -0,0 +1,36 @@
# Utility to load bounding box data from a CSV file and update the corresponding localizations in Tator
import os
from pathlib import Path

import pandas as pd

from aipipeline.db_utils import init_api_project

out_csv = Path("/Users/dcline/aidata/datasets/Baseline/crops/BirdOut") / "birdbox.csv"
host = "mantis.shore.mbari.org"
TATOR_TOKEN = os.environ["TATOR_TOKEN"]
project_name = "901902-uavs"
api, project_id = init_api_project(host=host, token=TATOR_TOKEN, project=project_name)

df = pd.read_csv(out_csv)

# Dimensions of the full-resolution UAV image, used to normalize pixel coordinates for Tator
image_width = 7952
image_height = 5304

# Iterate over the rows and fix the bounding box data
for i, row in df.iterrows():
    try:
        localization = api.get_localization(row["id"])
    except Exception as e:
        print(f"Error getting localization for {row['image']}: {e}")
        continue

    # Convert the pixel offsets measured inside the crop to normalized full-image
    # coordinates, offset by the existing localization position
    x = row["x"] / image_width + localization.x
    y = row["y"] / image_height + localization.y
    w = row["width"] / image_width
    h = row["height"] / image_height
    print(f"Updating {row['image']} to {x},{y},{w},{h}")
    update = {'width': w, 'height': h, 'x': x, 'y': y}
    try:
        api.update_localization(row['id'], localization_update=update)
        print(f"Updated {row['image']}")
    except Exception as e:
        print(f"Error updating {row['image']}: {e}")
        continue
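
The update above assumes Tator stores localizations in normalized image coordinates: a pixel offset measured inside a crop becomes a normalized offset in the full frame once divided by the full-frame dimensions, and adding the crop's existing x/y re-anchors it. A minimal worked example of that mapping, using the frame size from the script and hypothetical crop values:

# Hypothetical example of the crop-to-frame coordinate mapping used above
image_width, image_height = 7952, 5304  # full UAV frame size (from the script)
loc_x, loc_y = 0.25, 0.40               # existing localization top-left, normalized (assumed)
crop_px_x, crop_px_y = 120, 80          # SAM box offset inside the crop, in pixels (assumed)

# Normalize the pixel offset against the full frame, then shift by the crop origin
new_x = crop_px_x / image_width + loc_x   # 120 / 7952 + 0.25 ≈ 0.2651
new_y = crop_px_y / image_height + loc_y  # 80 / 5304 + 0.40 ≈ 0.4151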
185 changes: 185 additions & 0 deletions aipipeline/projects/uav/run_sam_loc.py
@@ -0,0 +1,185 @@
# Utility script to run SAM on images and save the bounding boxes to a CSV file

from pathlib import Path

import cv2
import numpy as np
from ultralytics import SAM, FastSAM

# Load the SAM model; alternative models are kept commented out for experimentation
model_sam2t = SAM(model="sam2_t.pt")
# model_sam21l = SAM(model="sam2.1_l.pt")
# model_sam21t = SAM(model="sam2.1_t.pt")
# model_saml = SAM(model="sam_l.pt")
# model_samm = SAM(model="mobile_sam.pt")
# model_fast = FastSAM(model="FastSAM-x.pt")

display = False  # Set to True to display the images - useful for debugging
sift = cv2.SIFT_create()

# image_path = Path("/Users/dcline/Dropbox/data/UAV/crops/BirdSelect/")
# image_path = Path("/Users/dcline/Dropbox/data/UAV/crops/BirdHard/")
image_path = Path("/Users/dcline/aidata/datasets/Baseline/crops/Bird")
out_path = Path("/Users/dcline/aidata/datasets/Baseline/crops/BirdOut")
# out_path = Path("/Users/dcline/Dropbox/data/UAV/crops/BirdHardOut/")
out_path.mkdir(exist_ok=True)

out_csv = out_path / "birdbox.csv"

# Padding in pixels added around the best bounding box
padding = 10

with out_csv.open("w") as f:
    f.write("id,image,x,y,width,height\n")
    for im in image_path.glob("*.jpg"):
        image = cv2.imread(im.as_posix())
        db_id = int(im.stem)

        # Skip images with too much color variation, since the segmentation may not be accurate
        std_color = np.mean(np.std(image, axis=(0, 1)))
        if std_color > 30:
            continue

        # Skip if the image is too small
        if image.shape[0] < 100 or image.shape[1] < 100:
            continue

        # Threshold in HSV color space
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        hue = hsv[:, :, 0]
        sat = hsv[:, :, 1]
        block_size = 13
        # Adaptively threshold the saturation channel
        binary_mask = cv2.adaptiveThreshold(
            sat.astype(np.uint8),
            255,  # Max pixel value
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,  # Block size (size of the local neighborhood)
            9,  # Constant subtracted from the mean
        )

        # Set the mask to 0 if the saturation is below a threshold
        binary_mask[sat < 40] = 0

        # Invert the mask
        binary_mask = 255 - binary_mask

        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Remove all small contours
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < 300:
                cv2.drawContours(binary_mask, [contour], -1, 0, -1)

        # Remove contours close in hue to the background
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            roi = hue[y:y+h, x:x+w]
            mean_hue = np.mean(roi)
            if mean_hue < 60 or mean_hue > 100:
                cv2.drawContours(binary_mask, [contour], -1, 0, -1)

        # Remove contours too close to the edge
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if x < 10 or y < 10 or x + w > image.shape[1] - 10 or y + h > image.shape[0] - 10:
                cv2.drawContours(binary_mask, [contour], -1, 0, -1)

        # Invert the mask
        output_mask = 255 - binary_mask

        # Create a new image setting the masked region to black
        image_masked = image.copy()
        image_masked[output_mask == 255] = 0
        keypoints, _ = sift.detectAndCompute(image_masked, None)

        # Filter keypoints based on the binary mask
        filtered_keypoints = []
        for kp in keypoints:
            x, y = int(kp.pt[0]), int(kp.pt[1])  # Keypoint coordinates
            if binary_mask[y, x] > 0:  # Check if the keypoint is in the mask region
                filtered_keypoints.append(kp)

        # Convert the filtered cv2 keypoints to a numpy array of (x, y) points
        keypoints = np.array([[int(kp.pt[0]), int(kp.pt[1])] for kp in filtered_keypoints])

        if len(keypoints) == 0:
            print("No keypoints detected.")
            continue

        # Display the keypoints on a copy of the image
        image_kp = image.copy()
        for kp in keypoints:
            cv2.circle(image_kp, (int(kp[0]), int(kp[1])), 2, (0, 0, 255), -1)

        # Gaussian blur the image
        image_blur = cv2.GaussianBlur(image, (9, 9), 0)

        # Run SAM segmentation with the keypoints as positive point prompts
        results = model_sam2t.predict(image_blur, points=keypoints, labels=[1] * len(keypoints), device="cpu")
        # results = model_sam21l.predict(image_blur, points=keypoints, labels=[1] * len(keypoints), device="cpu")

        # Get the largest bounding box that has at least 30% coverage in the masked region
        bbox_best = None
        largest = 0
        for result in results:
            bboxes = result.boxes.xywh
            for bbox in bboxes:
                x, y, w, h = bbox.tolist()
                w = int(w)
                h = int(h)
                # boxes.xywh is center-based, so convert to the top-left corner before cropping
                x = max(0, int(x) - w // 2)
                y = max(0, int(y) - h // 2)

                mask_crop = output_mask[y:y+h, x:x+w]  # Crop the mask
                mask_area = np.sum(mask_crop) / 255
                bbox_area = w * h
                if bbox_area == 0:
                    continue
                coverage = mask_area / bbox_area
                print(f"{im} {x},{y},{w}x{h} {mask_area} {bbox_area} {coverage}")

                # If the width or height is within 1 pixel of the image size, skip;
                # this is either a well-cropped image or the background segment
                if w >= image.shape[1] - 1 or h >= image.shape[0] - 1 or coverage < 0.3:
                    print(f"Skipping {w}x{h} bbox")
                    continue

                if w * h > largest and w > 10 and h > 10:
                    largest = w * h
                    # Add padding to the bounding box and clip it to the image size
                    x = max(0, x - padding)
                    y = max(0, y - padding)
                    w = min(image.shape[1] - x, w + 2 * padding)
                    h = min(image.shape[0] - y, h + 2 * padding)
                    bbox_best = (x, y, w, h)

        if bbox_best:
            x, y, w, h = bbox_best
            cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 1)
            title = f"image_box {im.stem} {x},{y},{w}x{h} {image.shape[0]}x{image.shape[1]}"
        else:
            title = f"image_kp {im.stem} {image.shape[0]}x{image.shape[1]}"

        # Show the mask, the boxed image, and the keypoint image side by side
        im_show = np.concatenate((cv2.cvtColor(output_mask, cv2.COLOR_GRAY2BGR), image, image_kp), axis=1)
        if display:
            cv2.imshow(title, im_show)
            cv2.waitKey(0)

        out_file = out_path / im.name

        if bbox_best:
            cv2.imwrite(out_file.as_posix(), im_show)
            f.write(f"{db_id},{im},{x},{y},{w},{h}\n")
