diff --git a/replay_trajectory_classification/decoder.py b/replay_trajectory_classification/decoder.py index 7701321..bb0b7c9 100644 --- a/replay_trajectory_classification/decoder.py +++ b/replay_trajectory_classification/decoder.py @@ -326,9 +326,7 @@ def convert_results_to_xarray( { key: ( dims, - mask(value, is_track_interior) - .squeeze(axis=-1) - .reshape(new_shape, order="F"), + mask(value, is_track_interior).reshape(new_shape, order="F"), ) for key, value in results.items() }, @@ -340,9 +338,7 @@ def convert_results_to_xarray( { key: ( dims, - mask(value, is_track_interior) - .squeeze(axis=-1) - .reshape(new_shape), + mask(value, is_track_interior).reshape(new_shape), ) for key, value in results.items() },