Skip to content

Commit

Permalink
feat: bitwise-ops-for-tensors (#2498)
Browse files Browse the repository at this point in the history
* feat: bitwise-ops-for-tensors

* add bitwise ops for jit

* patch: address-requested-changes

* feat: jit-binary-int-ops

* cargo lock

* feat: jit-backend bitwise not unary op

* feat: bitwise left shift and right shift ops

* patch: resolve review request changes

* patch: remove-dtype-int-op-desc

* refactor requested changes

* Add bitwise int ops to book + remove dead code

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
  • Loading branch information
quinton11 and laggui authored Jan 24, 2025
1 parent e40c69b commit e73c2d9
Show file tree
Hide file tree
Showing 22 changed files with 1,836 additions and 70 deletions.
145 changes: 78 additions & 67 deletions burn-book/src/building-blocks/tensor.md

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,48 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_argsort(tensor, dim, descending)
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_and(lhs, rhs)
}

fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
B::bitwise_and_scalar(lhs, rhs)
}

fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_or(lhs, rhs)
}

fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
B::bitwise_or_scalar(lhs, rhs)
}

fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_xor(lhs, rhs)
}

fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
B::bitwise_xor_scalar(lhs, rhs)
}

fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_not(tensor)
}

fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_left_shift(lhs, rhs)
}

fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
B::bitwise_left_shift_scalar(lhs, rhs)
}

fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_right_shift(lhs, rhs)
}

fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
B::bitwise_right_shift_scalar(lhs, rhs)
}
}
43 changes: 43 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,47 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
sign(tensor)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_and is not implemented for Candle IntTensor");
}

fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
}

fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_or is not implemented for Candle IntTensor");
}

fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
}

fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
}

fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
}

fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_not is not implemented for Candle IntTensor");
}

fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
}

fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
}

fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
}

fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
}
}
263 changes: 263 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1819,4 +1819,267 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

out
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseAndOps, B::bitwise_and);

let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
let out = lhs.client.tensor_uninitialized(
binary_ops_shape(&lhs.shape, &rhs.shape),
B::IntElem::dtype(),
);

let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::Int(IntOperationDescription::BitwiseAnd(desc.clone())),
BitwiseAndOps::<B>::new(desc),
);

out
}

fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);

let stream = lhs.stream;
let out = lhs
.client
.tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseAndScalar(
desc.clone(),
)),
BitwiseAndOps::<B>::new(desc),
);

out
}

fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseOrOps, B::bitwise_or);

let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
let out = lhs.client.tensor_uninitialized(
binary_ops_shape(&lhs.shape, &rhs.shape),
B::IntElem::dtype(),
);

let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::Int(IntOperationDescription::BitwiseOr(desc.clone())),
BitwiseOrOps::<B>::new(desc),
);

out
}

fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);

let stream = lhs.stream;
let out = lhs
.client
.tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseOrScalar(desc.clone())),
BitwiseOrOps::<B>::new(desc),
);

out
}

fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseXorOps, B::bitwise_xor);

let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
let out = lhs.client.tensor_uninitialized(
binary_ops_shape(&lhs.shape, &rhs.shape),
B::IntElem::dtype(),
);

let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::Int(IntOperationDescription::BitwiseXor(desc.clone())),
BitwiseXorOps::<B>::new(desc),
);

out
}

fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);

let stream = lhs.stream;
let out = lhs
.client
.tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseXorScalar(
desc.clone(),
)),
BitwiseXorOps::<B>::new(desc),
);

out
}

fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_int_ops!(BitwiseNotOps, B::bitwise_not);

let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());

let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseNot(desc.clone())),
BitwiseNotOps::<B>::new(desc),
);

out
}

fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);

let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
let out = lhs.client.tensor_uninitialized(
binary_ops_shape(&lhs.shape, &rhs.shape),
B::IntElem::dtype(),
);

let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShift(
desc.clone(),
)),
BitwiseLeftShiftOps::<B>::new(desc),
);

out
}

fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);

let stream = lhs.stream;
let out = lhs
.client
.tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShiftScalar(
desc.clone(),
)),
BitwiseLeftShiftOps::<B>::new(desc),
);

out
}

fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);

let stream_1 = lhs.stream;
let stream_2 = rhs.stream;
let out = lhs.client.tensor_uninitialized(
binary_ops_shape(&lhs.shape, &rhs.shape),
B::IntElem::dtype(),
);

let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShift(
desc.clone(),
)),
BitwiseRightShiftOps::<B>::new(desc),
);

out
}

fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);

let stream = lhs.stream;
let out = lhs
.client
.tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShiftScalar(
desc.clone(),
)),
BitwiseRightShiftOps::<B>::new(desc),
);

out
}
}
Loading

0 comments on commit e73c2d9

Please sign in to comment.