Skip to content

Commit

Permalink
Merge pull request #288 from leggedrobotics/dev/nature_hiking/feat/ad…
Browse files Browse the repository at this point in the history
…aptive_anomaly

Adding dynamic thresholding with GMM model working
  • Loading branch information
RobinSchmid7 authored Jan 25, 2024
2 parents dfa3106 + 77da852 commit 53ec625
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from wild_visual_navigation.utils import ConfidenceGenerator
from wild_visual_navigation.learning.utils import AnomalyLoss
from wild_visual_navigation_msgs.msg import ChannelInfo
from sklearn.mixture import GaussianMixture

import rospy
from sensor_msgs.msg import Image, CameraInfo, CompressedImage
Expand All @@ -26,6 +27,7 @@
from termcolor import colored
import signal
import sys
import cv2


class WvnFeatureExtractor:
Expand All @@ -47,8 +49,6 @@ def __init__(self):

self.model = get_model(self.exp_cfg["model"]).to(self.device)
self.model.eval()



if not self.anomaly_detection:
self.confidence_generator = ConfidenceGenerator(
Expand Down Expand Up @@ -290,6 +290,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
out_trav = prediction.reshape(H, W, -1)[:, :, 0]

# Publish traversability
# TODO: edit this part for dynamic thresholding, make this available optional
if self.scale_traversability:
# Apply piecewise linear scaling 0->0; threshold->0.5; 1->1
traversability = out_trav.clone()
Expand All @@ -301,8 +302,41 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
traversability[~m] *= 0.5 / (1 - self.traversability_threshold)
traversability[~m] += 0.5
traversability = traversability.clip(0, 1)
# TODO Check if this was a bug
out_trav = traversability

# # TODO: make this optional
if False:
loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1)

loss_reco_flat = loss_reco.flatten().detach().cpu().numpy().reshape(-1, 1)

# Fit a 1D GMM with k=2
gmm_k2 = GaussianMixture(n_components=2, random_state=0)
gmm_k2.fit(loss_reco_flat)

gmm_labels = gmm_k2.predict(loss_reco_flat)
# Assume the cluster with the larger mean loss is the unconfident one
unconfident_cluster = gmm_k2.means_.argmax()

# Create a mask from GMM predictions
unconf_mask = (gmm_labels == unconfident_cluster).reshape(H, W)
unconf_mask = torch.from_numpy(unconf_mask).to(torch_image.device)
conf_mask = ~unconf_mask

conf_mask_np = conf_mask.cpu().numpy()

conf_mask_np = conf_mask_np.astype(np.uint8) * 255

# Save the confident mask as an image
# cv2.imwrite("/home/rschmid/conf_mask.jpg", conf_mask_np.astype(np.uint8) * 255)

# Publish the confident mask
conf_mask_msg = rc.numpy_to_ros_image(conf_mask_np, "passthrough")
conf_mask_msg.header = image_msg.header
conf_mask_msg.width = conf_mask_np.shape[0]
conf_mask_msg.height = conf_mask_np.shape[1]
self.camera_handler[cam]["conf_pub"].publish(conf_mask_msg)

else:
loss, loss_aux, trav = self.traversability_loss(None, prediction)

Expand Down Expand Up @@ -336,15 +370,15 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
self.camera_handler[cam]["input_pub"].publish(msg)

# Publish confidence
if self.camera_topics[cam]["publish_confidence"]:
loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1)
confidence = self.confidence_generator.inference_without_update(x=loss_reco)
out_confidence = confidence.reshape(H, W)
msg = rc.numpy_to_ros_image(out_confidence.cpu().numpy(), "passthrough")
msg.header = image_msg.header
msg.width = out_confidence.shape[0]
msg.height = out_confidence.shape[1]
self.camera_handler[cam]["conf_pub"].publish(msg)
# if self.camera_topics[cam]["publish_confidence"]:
# loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1)
# confidence = self.confidence_generator.inference_without_update(x=loss_reco)
# out_confidence = confidence.reshape(H, W)
# msg = rc.numpy_to_ros_image(out_confidence.cpu().numpy(), "passthrough")
# msg.header = image_msg.header
# msg.width = out_confidence.shape[0]
# msg.height = out_confidence.shape[1]
# self.camera_handler[cam]["conf_pub"].publish(msg)

# Publish features and feature_segments
if self.camera_topics[cam]["use_for_training"]:
Expand Down

0 comments on commit 53ec625

Please sign in to comment.