From 28ccb70747b040a3b7a302f02b4d97f00a3b5802 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 18 Dec 2024 11:20:48 -0800 Subject: [PATCH] Fix scaling for visualizer --- sleap/nn/inference.py | 30 ++++++++++++++++-------------- sleap/nn/training.py | 4 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 4f00419a5..df061a289 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -1690,7 +1690,7 @@ def __init__( keras_model: tf.keras.Model, crop_size: int, input_scale: float = 1.0, - precrop_resize: Optional[float] = 1.0, + precrop_resize: float = 1.0, pad_to_stride: int = 1, output_stride: Optional[int] = None, peak_threshold: float = 0.2, @@ -1830,6 +1830,13 @@ def call(self, inputs): # See: https://github.com/tensorflow/tensorflow/issues/6720 centroid_points = (centroid_points / self.input_scale) + 0.5 + # resize full images + if self.precrop_resize != 1.0: + full_imgs = sleap.nn.data.resizing.resize_image( + full_imgs, self.precrop_resize + ) + centroid_points *= self.precrop_resize + # Store crop offsets. crop_offsets = centroid_points - (self.crop_size / 2) @@ -1906,13 +1913,6 @@ def call(self, inputs): crop_offsets = centroid_points - (self.crop_size / 2) - # resize full images - if self.precrop_resize: - full_imgs = sleap.nn.data.resizing.resize_image( - full_imgs, self.precrop_resize - ) - centroid_points *= self.precrop_resize - # Crop instances around centroids. bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes( centroid_points, self.crop_size, self.crop_size @@ -1978,10 +1978,10 @@ class FindInstancePeaks(InferenceLayer): input_scale: Float indicating if the images should be resized before being passed to the model. resize_input_image: Bool indicating if the crops should be resized. If - `CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the - images are resized in the `CentroidCropGroundTruth` and this is set to `False`. - However, the output keypoints are adjusted to the actual scale with the - `input_scaling` argument. + `CentroidCropGroundTruth` or `CentroidCrop` is used along with `FindInstancePeaks`, + then the images are resized in the `CentroidCropGroundTruth` or `CentroidCrop` + before cropping and this is set to `False`. However, the output keypoints + are adjusted to the actual scale with the `input_scaling` argument. output_stride: Output stride of the model, denoting the scale of the output confidence maps relative to the images (after input scaling). This is used for adjusting the peak coordinates to the image grid. This will be inferred @@ -2170,7 +2170,9 @@ def call( if "crop_offsets" in inputs: # Flatten (samples, ?, 2) -> (n_peaks, 2). crop_offsets = inputs["crop_offsets"].merge_dims(0, 1) - peak_points = peak_points + tf.expand_dims(crop_offsets, axis=1) + peak_points = peak_points + ( + tf.expand_dims(crop_offsets, axis=1) / self.input_scale + ) # Group peaks by sample (samples, ?, nodes, 2). peaks = tf.RaggedTensor.from_value_rowids( @@ -2384,7 +2386,7 @@ def _initialize_inference_model(self): keras_model=self.centroid_model.keras_model, crop_size=crop_size, input_scale=self.centroid_config.data.preprocessing.input_scaling, - precrop_resize=None, + precrop_resize=1.0, pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride, output_stride=self.centroid_config.model.heads.centroid.output_stride, peak_threshold=self.peak_threshold, diff --git a/sleap/nn/training.py b/sleap/nn/training.py index f56d8cf46..7d32dd797 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -1315,7 +1315,7 @@ def _setup_visualization(self): # Create an instance peak finding layer. find_peaks = FindInstancePeaks( keras_model=self.keras_model, - input_scale=self.config.data.preprocessing.input_scaling, + input_scale=1.0, peak_threshold=0.2, refinement="local", return_confmaps=True, @@ -1756,7 +1756,7 @@ def _setup_visualization(self): # Create an instance peak finding layer. find_peaks = FindInstancePeaks( keras_model=self.keras_model, - input_scale=self.config.data.preprocessing.input_scaling, + input_scale=1.0, peak_threshold=0.2, refinement="local", return_confmaps=True,