Add support for scatter (torch.scatter) #1444

Closed · wants to merge 2 commits
9 changes: 9 additions & 0 deletions candle-core/src/backend.rs
@@ -69,6 +69,15 @@ pub trait BackendStorage: Sized {
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
fn scatter_add(
&self,
_: &Layout,
9 changes: 9 additions & 0 deletions candle-core/src/backprop.rs
@@ -52,6 +52,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
| Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
@@ -384,6 +385,14 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::Scatter(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
// The scatter overwrites `init` at the indexed positions, so the
// incoming gradient must be zeroed out there before it flows back
// to `init` (`src` and `indexes` have matching shapes).
let init_grad = grad.scatter(indexes, &src.zeros_like()?, *dim)?;
*init_sum_grad = init_sum_grad.add(&init_grad)?;

let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
74 changes: 74 additions & 0 deletions candle-core/src/cpu_backend.rs
@@ -900,6 +900,63 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
}
}

struct Scatter<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}

impl<'a, I: IntDType> Map2 for Scatter<'a, I> {
const OP: &'static str = "scatter";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};

let dim = self.dim;
let ids_dims = self.ids_l.dims();
let dst_dims = l1.dims();
let dst_dim_len = dst_dims[dim];
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

let ids_left_len: usize = ids_dims[..dim].iter().product();
let ids_dim_len = ids_dims[dim];
let ids_right_len: usize = ids_dims[dim + 1..].iter().product();

let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
};
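// For every position of the index tensor, copy the matching `src` value
// into `dst`: the coordinates are kept as-is except along `dim`, where
// the index value selects the destination offset.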
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
let start_dst_idx = left_i * dst_right_len * dst_dim_len;
for i in 0..ids_dim_len {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx].as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
size: dst_dim_len,
op: "scatter",
}
.bt())?
}
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
dst[dst_idx] = src[ids_idx]
}
}
}

Ok(dst)
}
}

struct ScatterAdd<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@@ -2587,6 +2644,23 @@ impl BackendStorage for CpuStorage {
}
}

fn scatter(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
match ids {
Self::U8(ids) => Scatter { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => Scatter { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => Scatter { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter")),
}
}

fn scatter_add(
&self,
l: &Layout,
67 changes: 67 additions & 0 deletions candle-core/src/cuda_backend.rs
@@ -1023,6 +1023,58 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
}
}

struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);

impl<'a> Map2InPlace for Scatter<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => {
("scatter_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("scatter_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => {
("scatter_u8", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
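// One thread per (left, right) pair; the scatter dimension is walked
// inside the kernel (assuming the scatter kernels mirror the existing
// scatter_add kernels in candle-kernels).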
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}

struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for ScatterAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1994,6 +2046,21 @@ impl BackendStorage for CudaStorage {
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
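// Copy `self` into a fresh buffer first, then scatter into it in place
// so the original storage is left untouched.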
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
fn scatter_add(
&self,
l: &Layout,
12 changes: 12 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
@@ -116,6 +116,18 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn scatter(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn scatter_add(
&self,
_: &Layout,
12 changes: 12 additions & 0 deletions candle-core/src/dummy_metal_backend.rs
@@ -128,6 +128,18 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}

fn scatter(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn scatter_add(
&self,
_: &Layout,
1 change: 1 addition & 0 deletions candle-core/src/op.rs
@@ -76,6 +76,7 @@ pub enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
Scatter(Tensor, Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
28 changes: 28 additions & 0 deletions candle-core/src/storage.rs
@@ -559,6 +559,34 @@ impl Storage {
}
}

pub(crate) fn scatter(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "scatter")?;
self.same_device(source, "scatter")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}

pub(crate) fn scatter_add(
&self,
l: &Layout,
50 changes: 50 additions & 0 deletions candle-core/src/tensor.rs
@@ -1200,6 +1200,56 @@ impl Tensor {
self.index_select(ids, 0)
}

/// Writes all values from the tensor `src` into `self` at the indices specified
/// in the `indexes` tensor. For each value in `src`, its output index is given
/// by its position in `src` for every dimension other than `dim`, and by the
/// corresponding value in `indexes` along `dim`.
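///
/// A minimal usage sketch along dim 1 (illustrative values, not taken from
/// the test suite):
///
/// ```ignore
/// let init = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let ids = Tensor::new(&[[0u32, 1, 2], [2, 1, 0]], &Device::Cpu)?;
/// let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
/// // out[i][ids[i][j]] = src[i][j]
/// let out = init.scatter(&ids, &src, 1)?;
/// // out = [[1., 2., 3.], [6., 5., 4.]]
/// ```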
pub fn scatter<D: Dim>(&self, indexes: &Self, src: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter")?;
let src_dims = src.dims();
let self_dims = self.dims();
let mismatch = if src_dims.len() != self_dims.len() {
true
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(src_dims.iter()).enumerate() {
if i != dim && d1 != d2 {
mismatch = true;
break;
}
}
mismatch
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
}
.bt())?
}
if indexes.dims() != src.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter (indexes, src)",
lhs: indexes.shape().clone(),
rhs: src.shape().clone(),
}
.bt())?
}
let storage = self.storage().scatter(
self.layout(),
&indexes.storage(),
indexes.layout(),
&src.storage(),
src.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, src, |t1, t2, t3| {
Op::Scatter(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}

pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
let source_dims = source.dims();