Feature: add new one-hot function supporting multiple dimensions (ranks) (#2613)

* add one-hot function with axis and values

* update multidimensional one-hot function

* implement in numeric.rs

* update one hot method in numeric

* update one-hot function to handle additional dims
add one-hot test

* add tests for one-hot

* modify function name
modify format
add tests

* handle the difference between the Tensor type and the values type

* fix clippy warnings and doc test

* refactor code
update comments

* update Burn book to document the one_hot_plus method

* rename one_hot_plus to one_hot_fill and update argument names

* update the one_hot function in the int and float impls
update one_hot tests

* clarify logic in numeric

* fix miscellaneous issues flagged by validation, linter, and formatter

* update documentation for the tensor API

* update code to follow review comments

* update code to follow review feedback

* update tests to follow review comments

* Improve check message

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
tiruka and laggui authored Jan 15, 2025
1 parent 93cafc4 commit ad81344
Showing 6 changed files with 193 additions and 118 deletions.
burn-book/src/building-blocks/tensor.md (4 changes: 2 additions & 2 deletions)
@@ -228,6 +228,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.neg()` or `-tensor` | `-tensor` |
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |
| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A |
| `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
@@ -258,7 +260,6 @@ Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| --------------------------------------------- | ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
@@ -296,7 +297,6 @@ Those operations are only available for `Int` tensors.
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
| `tensor.cartesian_grid(shape, device)` | N/A |
| `tensor.one_hot(num_classes)` | N/A |

### Bool Operations

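To make the two new table entries concrete, here is a minimal usage sketch; it assumes the `NdArray` backend (with the `ndarray` feature enabled), but any backend behaves the same:

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = Default::default();

    // `one_hot` now lives on numeric tensors, so it works for Int and Float alike.
    let indices: Tensor<NdArray, 1, Int> = Tensor::from_ints([0, 2, 1], &device);
    let one_hot: Tensor<NdArray, 2, Int> = indices.one_hot(3);
    println!("{}", one_hot.to_data());
    // [[1, 0, 0], [0, 0, 1], [0, 1, 0]]

    // `one_hot_fill` additionally controls the fill values and the insertion axis.
    let labels: Tensor<NdArray, 1> = Tensor::from_floats([1.0, 0.0], &device);
    let encoded: Tensor<NdArray, 2> = labels.one_hot_fill(2, 5.0, -5.0, -1);
    println!("{}", encoded.to_data());
    // [[-5.0, 5.0], [5.0, -5.0]]
}
```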
crates/burn-tensor/src/tensor/api/check.rs (34 changes: 17 additions & 17 deletions)
@@ -1,4 +1,4 @@
use crate::{backend::Backend, BasicOps, Int, Shape, Tensor};
use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor};
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
@@ -447,22 +447,8 @@ impl TensorCheck {
check
}

pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self {
let mut check = Self::Ok;
if index >= num_classes {
check = check.register(
"One Hot",
TensorError::new(format!(
"Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})",
)),
);
}

check
}

pub(crate) fn one_hot_tensor<B: Backend>(
index_tensor: Tensor<B, 1, Int>,
pub(crate) fn one_hot_tensor<B: Backend, const D: usize, K: Numeric<B>>(
index_tensor: Tensor<B, D, K>,
num_classes: usize,
) -> Self {
let mut check = Self::Ok;
@@ -487,6 +473,20 @@ impl TensorCheck {
check
}

pub(crate) fn one_hot_tensor_rank<const D: usize, const D2: usize>() -> Self {
let mut check = Self::Ok;
if D + 1 != D2 {
check = check.register(
"One Hot",
TensorError::new(
"The one-hot tensor rank must correspond to the rank of the tensor + 1",
)
.details(format!("Expected D2={}, got {D2}", D + 1)),
);
}
check
}

pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;

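The new `one_hot_tensor_rank` check enforces the const-generic relationship `D2 == D + 1`: a rank-`D` input gains exactly one class dimension. A self-contained sketch of the same invariant, using hypothetical names rather than Burn's internals:

```rust
// One-hot encoding keeps the input shape and inserts one `num_classes` axis,
// so the output rank D2 must be the input rank D plus one.
fn check_one_hot_rank<const D: usize, const D2: usize>() -> Result<(), String> {
    if D + 1 != D2 {
        return Err(format!(
            "The one-hot tensor rank must be the input rank + 1 (expected D2={}, got {D2})",
            D + 1
        ));
    }
    Ok(())
}

fn main() {
    // [batch, seq] -> [batch, seq, classes] is fine; jumping two ranks is not.
    assert!(check_one_hot_rank::<2, 3>().is_ok());
    assert!(check_one_hot_rank::<1, 3>().is_err());
}
```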
crates/burn-tensor/src/tensor/api/float.rs (34 changes: 1 addition & 33 deletions)
@@ -1,11 +1,8 @@
use alloc::vec::Vec;
use core::convert::TryInto;

use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, Shape, TensorData};
use crate::tensor::{Distribution, TensorData};
use crate::Tensor;
use crate::{check, FloatDType};
use crate::{Int, TensorPrimitive};
@@ -174,35 +171,6 @@ where
)))
}

/// Create a one hot tensor.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let one_hot = Tensor::<B, 1>::one_hot(2, 10, &device);
/// println!("{}", one_hot.to_data());
/// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
/// }
/// ```
pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
check!(TensorCheck::one_hot_index(index, num_classes));

let mut dims = [1; D];
dims[D - 1] = num_classes;
let shape = Shape::new(dims);
let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
let tensor = Tensor::zeros(shape, device);
let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
ranges[D - 1] = index..index + 1;

tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
}

/// Applies the matrix multiplication operation.
///
/// `C = AB`
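With the float-only constructor `Tensor::one_hot(index, num_classes, device)` removed, callers build a rank-1 index tensor and use the numeric `one_hot` instead. A migration sketch (assumes the `NdArray` backend; the `reshape` restores the old rank-1 shape):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();

    // Before this commit: Tensor::<B, 1>::one_hot(2, 10, &device)
    // After: encode a single-element index tensor, then flatten if needed.
    let index: Tensor<NdArray, 1> = Tensor::from_floats([2.0], &device);
    let one_hot: Tensor<NdArray, 2> = index.one_hot(10); // shape [1, 10]
    let flat: Tensor<NdArray, 1> = one_hot.reshape([10]); // old rank-1 output
    println!("{}", flat.to_data());
    // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
}
```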
crates/burn-tensor/src/tensor/api/int.rs (30 changes: 0 additions & 30 deletions)
@@ -1,5 +1,3 @@
use crate::check;
use crate::check::TensorCheck;
use crate::{
backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive,
};
@@ -29,34 +27,6 @@ where
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::int_arange_step(range, step, device))
}

/// Create a one hot tensor from an index tensor.
///
/// # Arguments
///
/// * `num_classes` - The number of classes to use in encoding.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);
/// let one_hot = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
/// }
/// ```
pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
let [num_samples] = self.dims();
let indices = self.unsqueeze_dim(1);
let values = indices.ones_like();
Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values)
}
}

impl<const D: usize, B> Tensor<B, D, Int>
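The removed `Int`-specific `one_hot` is likewise covered by the numeric implementation; the visible difference is that the output rank is now a const generic, so it must be inferable or spelled out. A sketch (assumes the `NdArray` backend):

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = Default::default();
    let indices: Tensor<NdArray, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);

    // Old API: `indices.one_hot(4)` was hard-wired to rank 1 -> rank 2.
    // New API: name the output rank via the turbofish (or a type annotation).
    let one_hot = indices.one_hot::<2>(4);
    println!("{}", one_hot.to_data());
    // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
}
```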
crates/burn-tensor/src/tensor/api/numeric.rs (97 changes: 97 additions & 0 deletions)
@@ -2034,6 +2034,103 @@ where
// Assign the original tensor data to the appropriate slice of the padded tensor
padded_tensor.slice_assign(ranges, self)
}
/// Create a one hot tensor.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
/// let one_hot: Tensor<B, 2> = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
/// }
/// ```
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
self.one_hot_fill(num_classes, 1.0, 0.0, -1)
}

/// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis`, supporting inputs of any rank.
///
/// # Arguments
///
/// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
/// * `on_value`: The value to assign for active positions (corresponding to indices).
/// * `off_value`: The value to assign for inactive positions.
/// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
///
/// # Returns
///
/// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Float};
/// fn example<B: Backend<FloatElem: From<f32>>>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
/// // One-hot encoding
/// let tensor: Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
/// println!("{tensor}");
/// // [[[5.0, 0.0, 0.0],
/// // [0.0, 0.0, 5.0]],
/// // [[0.0, 5.0, 0.0],
/// // [0.0, 0.0, 5.0]]]
/// }
/// ```
pub fn one_hot_fill<const D2: usize>(
self,
num_classes: usize,
on_value: f32,
off_value: f32,
axis: i64,
) -> Tensor<B, D2, K> {
check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
// Initialize shape from the current tensor dimensions and prepare for modification
let mut shape = self.shape().dims::<D>().to_vec();
let device = self.device();
let rank = self.dims().len();

// Adjust negative axis to a positive index
let axis = if axis < 0 {
axis + rank as i64 + 1
} else {
axis
};

// Ensure axis is within valid range
if axis < 0 || axis > rank as i64 {
panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
}
// Convert the input tensor to integer indices
let indices: Tensor<B, D, Int> =
Tensor::from_data(self.to_data().convert::<i64>(), &device);
// Insert the new dimension for the one-hot representation
shape.insert(axis as usize, num_classes);
// Wrap negative indices into the valid range, Python-style: i < 0 becomes num_classes + i
let adjusted_indices = indices
.clone()
.mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
.add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
// Unsqueeze the indices tensor along the specified axis
let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);

// Initialize the output tensor with the off_value
let output = Tensor::full(shape.clone(), off_value, &device);

// Prepare the scatter values: `scatter` accumulates, so scattering (on_value - off_value)
// onto the off_value background yields exactly on_value at the hot positions
let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
- Tensor::full(indices_unsqueezed.shape(), off_value, &device);

// Scatter on_value at the appropriate indices to create the one-hot representation
output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
}

/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
///
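The `axis` argument decides where the new class dimension lands, and negative axes wrap Python-style. A short sketch of the resulting shapes (assumes the `NdArray` backend):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // A [2, 2] tensor of class indices; the -1 entry wraps to the last class.
    let indices: Tensor<NdArray, 2> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);

    // axis = -1 appends the class dimension: [2, 2] -> [2, 2, 3]
    let last: Tensor<NdArray, 3> = indices.clone().one_hot_fill(3, 1.0, 0.0, -1);
    assert_eq!(last.dims(), [2, 2, 3]);

    // axis = 0 prepends it instead: [2, 2] -> [3, 2, 2]
    let first: Tensor<NdArray, 3> = indices.one_hot_fill(3, 1.0, 0.0, 0);
    assert_eq!(first.dims(), [3, 2, 2]);
}
```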
(Diff for the sixth changed file did not render.)
