Skip to content

Commit

Permalink
Add overloads for chunk / partition functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 19, 2024
1 parent 9a5c33e commit 115db1e
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions ttnn/cpp/ttnn/tensor/xtensor/partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,13 @@ std::vector<Tensor> chunk_impl(const Tensor& tensor, const TensorLayout& layout,
std::vector<Tensor> chunk(const Tensor& tensor, int num_chunks, int dim) {
const auto& reference_layout = tensor.tensor_spec().tensor_layout();
switch (reference_layout.get_data_type()) {
case DataType::BFLOAT16: return adaptor::chunk_impl<float>(tensor, reference_layout, num_chunks, dim);
case DataType::BFLOAT4_B:
case DataType::BFLOAT8_B:
case DataType::BFLOAT16:
case DataType::FLOAT32: return adaptor::chunk_impl<float>(tensor, reference_layout, num_chunks, dim);
case DataType::INT32: return adaptor::chunk_impl<int32_t>(tensor, reference_layout, num_chunks, dim);
case DataType::UINT8: return adaptor::chunk_impl<uint8_t>(tensor, reference_layout, num_chunks, dim);
case DataType::UINT16: return adaptor::chunk_impl<uint16_t>(tensor, reference_layout, num_chunks, dim);
case DataType::UINT32: return adaptor::chunk_impl<uint32_t>(tensor, reference_layout, num_chunks, dim);
default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type());
}
Expand All @@ -163,9 +167,13 @@ Tensor concat(const std::vector<Tensor>& tensors, int dim) {
TT_FATAL(tensors.size() > 0, "Cannot concatenate an empty list of tensors");
const auto& reference_layout = tensors.front().tensor_spec().tensor_layout();
switch (reference_layout.get_data_type()) {
case DataType::BFLOAT16: return adaptor::concat_impl<float>(tensors, reference_layout, dim);
case DataType::BFLOAT4_B:
case DataType::BFLOAT8_B:
case DataType::BFLOAT16:
case DataType::FLOAT32: return adaptor::concat_impl<float>(tensors, reference_layout, dim);
case DataType::INT32: return adaptor::concat_impl<int32_t>(tensors, reference_layout, dim);
case DataType::UINT8: return adaptor::concat_impl<uint8_t>(tensors, reference_layout, dim);
case DataType::UINT16: return adaptor::concat_impl<uint16_t>(tensors, reference_layout, dim);
case DataType::UINT32: return adaptor::concat_impl<uint32_t>(tensors, reference_layout, dim);
default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type());
}
Expand Down

0 comments on commit 115db1e

Please sign in to comment.