From ad81344821bbdca2fde1b5debda4d83d96370876 Mon Sep 17 00:00:00 2001 From: tiruka <33803972+tiruka@users.noreply.github.com> Date: Thu, 16 Jan 2025 01:44:50 +0900 Subject: [PATCH] Feature add new one hot function meeting multi-dimensions (ranks) (#2613) * add one hot with axis and values function * update one hot multidimentional function * implementing on numeric.rs * update one hot method in numeric * update one hot function to deal with additional dims add one hot test * added tests for one hot * modify function name modify format add tests * modify to respond to difference between Tensor type and values type * fix clippy point out and doc test * do refactoring modify comments * update burn book to publish one hot plus method * modify one_hot_plus to one_hot_fill and args names * modify one_hot function in int impl and float impl modify one_hot tests * modify numeric to clear logic * modify miscs due to validation, linnter and formatter * modify documents for tensor api * modify codes to follow review comments * modify codes to follow reviews * modify tests to follow reviews comments * Improve check message --------- Co-authored-by: Guillaume Lagrange --- burn-book/src/building-blocks/tensor.md | 4 +- crates/burn-tensor/src/tensor/api/check.rs | 34 +++--- crates/burn-tensor/src/tensor/api/float.rs | 34 +----- crates/burn-tensor/src/tensor/api/int.rs | 30 ----- crates/burn-tensor/src/tensor/api/numeric.rs | 97 ++++++++++++++++ crates/burn-tensor/src/tests/ops/one_hot.rs | 112 +++++++++++++------ 6 files changed, 193 insertions(+), 118 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fb429ffd0f..8a7c01bbc9 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -228,6 +228,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | +| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | @@ -258,7 +260,6 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | --------------------------------------------- | ---------------------------------- | -| `Tensor::one_hot(index, num_classes, device)` | N/A | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | @@ -296,7 +297,6 @@ Those operations are only available for `Int` tensors. | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | N/A | ### Bool Operations diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d4ab13faf4..8a6fb2ad78 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, Int, Shape, Tensor}; +use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -447,22 +447,8 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - - pub(crate) fn one_hot_tensor( - index_tensor: Tensor, + pub(crate) fn one_hot_tensor>( + index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; @@ -487,6 +473,20 @@ impl TensorCheck { check } + pub(crate) fn one_hot_tensor_rank() -> Self { + let mut check = Self::Ok; + if D + 1 != D2 { + check = check.register( + "One Hot", + TensorError::new( + "The one-hot tensor rank must correspond to the rank of the tensor + 1", + ) + .details(format!("Expected D2={}, got {D2}", D + 1)), + ); + } + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6f59f6e88..b50d0d0596 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,11 +1,8 @@ -use alloc::vec::Vec; -use core::convert::TryInto; - use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; @@ -174,35 +171,6 @@ where ))) } - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let device = Default::default(); - /// let one_hot = Tensor::::one_hot(2, 10, &device); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) - } - /// Applies the matrix multiplication operation. /// /// `C = AB` diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..e882a107c7 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,3 @@ -use crate::check; -use crate::check::TensorCheck; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -29,34 +27,6 @@ where pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - let [num_samples] = self.dims(); - let indices = self.unsqueeze_dim(1); - let values = indices.ones_like(); - Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values) - } } impl Tensor diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 59dc44b7e6..b82175c3fe 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2034,6 +2034,103 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + /// + /// # Arguments + /// + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Float}; + /// fn example>>() { + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); + /// // One-hot encoding + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] + /// } + /// ``` + pub fn one_hot_fill( + self, + num_classes: usize, + on_value: f32, + off_value: f32, + axis: i64, + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); + // Initialize shape from the current tensor dimensions and prepare for modification + let mut shape = self.shape().dims::().to_vec(); + let device = self.device(); + let rank = self.dims().len(); + + // Adjust negative axis to a positive index + let axis = if axis < 0 { + axis + rank as i64 + 1 + } else { + axis + }; + + // Ensure axis is within valid range + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + // Convert the input tensor to integer indices + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); + // Insert the new dimension for the one-hot representation + shape.insert(axis as usize, num_classes); + // Adjust indices to valid range and handle invalid indices + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + // Unsqueeze the indices tensor along the specified axis + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value + let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation + output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) + } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 310399119f..24e8f24b38 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -1,74 +1,114 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { use super::*; - use burn_tensor::{Int, TensorData}; + use burn_tensor::{ + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Float, Int, Numeric, Shape, Tensor, TensorData, + }; #[test] fn float_should_support_one_hot() { - let device = Default::default(); - - let tensor = TestTensor::<1>::one_hot(0, 5, &device); - let expected = TensorData::from([1., 0., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 5, &device); - let expected = TensorData::from([0., 1., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(4, 5, &device); - let expected = TensorData::from([0., 0., 0., 0., 1.]); - tensor.into_data().assert_eq(&expected, false); + let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]); + one_hot_tensor.into_data().assert_eq(&expected, false); + } - let tensor = TestTensor::<1>::one_hot(1, 2, &device); - let expected = TensorData::from([0., 1.]); - tensor.into_data().assert_eq(&expected, false); + #[test] + fn float_should_support_one_hot_index() { + let tensor = TestTensor::<1>::from([2.0]); + let one_hot_tensor: Tensor = tensor.one_hot::<2>(10); + let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); + one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(1, 1, &device); + let tensor = TestTensor::<1>::from([5.0]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(0, 0, &device); + let tensor = TestTensor::<1>::from([0.0]); + let result: Tensor = tensor.one_hot(0); } #[test] fn int_should_support_one_hot() { - let device = Default::default(); - - let index_tensor = TestTensorInt::<1>::arange(0..5, &device); - let one_hot_tensor = index_tensor.one_hot(5); - let expected = TestTensorInt::eye(5, &device).into_data(); + let tensor = TestTensorInt::<1>::from([0, 1, 4]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..6, &device); - let one_hot_tensor = index_tensor.one_hot(5); + let tensor = TestTensorInt::<1>::from([5]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(0); + let tensor = TestTensorInt::<1>::from([2]); + let result: Tensor = tensor.one_hot(0); + } + + #[test] + fn one_hot_fill_with_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); + let expected = TensorData::from(as_type!(IntType: [ + [[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]], + [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + + #[test] + fn one_hot_fill_with_negative_axis_and_indices() { + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] + fn one_hot_fill_with_negative_indices() { + let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); + let expected = TensorData::from(as_type!(FloatType: [ + [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + #[should_panic] - fn int_one_hot_should_panic_when_number_of_classes_is_1() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(1); + #[test] + fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(2, 5.0, 0.0, 3); } }