Skip to content

Commit

Permalink
Fix scaling for visualizer
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 18, 2024
1 parent d47c98a commit 28ccb70
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
30 changes: 16 additions & 14 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,7 +1690,7 @@ def __init__(
keras_model: tf.keras.Model,
crop_size: int,
input_scale: float = 1.0,
precrop_resize: Optional[float] = 1.0,
precrop_resize: float = 1.0,
pad_to_stride: int = 1,
output_stride: Optional[int] = None,
peak_threshold: float = 0.2,
Expand Down Expand Up @@ -1830,6 +1830,13 @@ def call(self, inputs):
# See: https://github.com/tensorflow/tensorflow/issues/6720
centroid_points = (centroid_points / self.input_scale) + 0.5

# resize full images
if self.precrop_resize != 1.0:
full_imgs = sleap.nn.data.resizing.resize_image(
full_imgs, self.precrop_resize
)
centroid_points *= self.precrop_resize

# Store crop offsets.
crop_offsets = centroid_points - (self.crop_size / 2)

Expand Down Expand Up @@ -1906,13 +1913,6 @@ def call(self, inputs):

crop_offsets = centroid_points - (self.crop_size / 2)

# resize full images
if self.precrop_resize:
full_imgs = sleap.nn.data.resizing.resize_image(
full_imgs, self.precrop_resize
)
centroid_points *= self.precrop_resize

# Crop instances around centroids.
bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes(
centroid_points, self.crop_size, self.crop_size
Expand Down Expand Up @@ -1978,10 +1978,10 @@ class FindInstancePeaks(InferenceLayer):
input_scale: Float indicating if the images should be resized before being
passed to the model.
resize_input_image: Bool indicating if the crops should be resized. If
`CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the
images are resized in the `CentroidCropGroundTruth` and this is set to `False`.
However, the output keypoints are adjusted to the actual scale with the
`input_scaling` argument.
`CentroidCropGroundTruth` or `CentroidCrop` is used along with `FindInstancePeaks`,
then the images are resized in the `CentroidCropGroundTruth` or `CentroidCrop`
before cropping and this is set to `False`. However, the output keypoints
are adjusted to the actual scale with the `input_scaling` argument.
output_stride: Output stride of the model, denoting the scale of the output
confidence maps relative to the images (after input scaling). This is used
for adjusting the peak coordinates to the image grid. This will be inferred
Expand Down Expand Up @@ -2170,7 +2170,9 @@ def call(
if "crop_offsets" in inputs:
# Flatten (samples, ?, 2) -> (n_peaks, 2).
crop_offsets = inputs["crop_offsets"].merge_dims(0, 1)
peak_points = peak_points + tf.expand_dims(crop_offsets, axis=1)
peak_points = peak_points + (
tf.expand_dims(crop_offsets, axis=1) / self.input_scale
)

# Group peaks by sample (samples, ?, nodes, 2).
peaks = tf.RaggedTensor.from_value_rowids(
Expand Down Expand Up @@ -2384,7 +2386,7 @@ def _initialize_inference_model(self):
keras_model=self.centroid_model.keras_model,
crop_size=crop_size,
input_scale=self.centroid_config.data.preprocessing.input_scaling,
precrop_resize=None,
precrop_resize=1.0,
pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride,
output_stride=self.centroid_config.model.heads.centroid.output_stride,
peak_threshold=self.peak_threshold,
Expand Down
4 changes: 2 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ def _setup_visualization(self):
# Create an instance peak finding layer.
find_peaks = FindInstancePeaks(
keras_model=self.keras_model,
input_scale=self.config.data.preprocessing.input_scaling,
input_scale=1.0,
peak_threshold=0.2,
refinement="local",
return_confmaps=True,
Expand Down Expand Up @@ -1756,7 +1756,7 @@ def _setup_visualization(self):
# Create an instance peak finding layer.
find_peaks = FindInstancePeaks(
keras_model=self.keras_model,
input_scale=self.config.data.preprocessing.input_scaling,
input_scale=1.0,
peak_threshold=0.2,
refinement="local",
return_confmaps=True,
Expand Down

0 comments on commit 28ccb70

Please sign in to comment.