Skip to content

Commit

Permalink
[CCL] Fix padding issues (#16347)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue

### Problem description
All-gather over tile-layout tensors assumed tile-aligned shapes: inputs whose height/width are not multiples of the 32x32 tile size (or whose gathered dim is not tile-divisible) were rejected by the test gating and mishandled by the op.

### What's changed
Tile-layout inputs with non-tile-aligned height/width are now padded to tile boundaries before the gather (with a temporary typecast to bfloat16 for dtypes `ttnn.pad` cannot handle directly); after the gather, each device's shard is sliced back to its original logical extents and the shards are re-concatenated, so callers see an unpadded result. New odd-shaped test cases cover the previously unsupported shapes.

### Checklist
- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
Aswinmcw authored Jan 8, 2025
1 parent ccb55da commit 204ed99
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 16 deletions.
23 changes: 10 additions & 13 deletions tests/ttnn/unit_tests/operations/ccl/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
return True, "Invalid combination"

if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
if input_shape[dim] % num_devices != 0:
return True, "Unsupported test case"
if tile != (32, 32) and input_dtype != ttnn.bfloat16:
return True, "Tiny tile only supports bfloat16"
Expand All @@ -36,13 +36,7 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
return True, "L1 buffer can't support large tensor sizes"

# Check that each chip has a non-zero amount of data available
min_sized_chunks_on_dim = input_shape[dim]
if dim == 3:
min_sized_chunks_on_dim //= 32
if dim == 2:
if layout == ttnn.TILE_LAYOUT:
min_sized_chunks_on_dim //= 32
if min_sized_chunks_on_dim < num_devices:
if input_shape[dim] < num_devices:
return (
True,
f"Input shape {input_shape} incompatible with {num_devices} on dim {dim} because some chips will have no tensor",
Expand Down Expand Up @@ -159,9 +153,8 @@ def run_all_gather_impl(
input_tensors = torch.chunk(input_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
tt_input_tensors.append(
ttnn.Tensor(t, input_dtype, {}, ttnn.Tile(tile)).to(layout).to(mesh_device.get_devices()[i], mem_config)
)
t = ttnn.from_torch(t, input_dtype, layout=layout, tile=ttnn.Tile(tile))
tt_input_tensors.append(t.to(mesh_device.get_devices()[i], mem_config))

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
if trace_mode:
Expand All @@ -184,7 +177,7 @@ def run_all_gather_impl(
logger.info(f"Done iteration {i}")

for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
tt_output_tensor = ttnn.to_torch(t)
if input_dtype == ttnn.bfloat16:
eq, output = comp_equal(tt_output_tensor, input_tensor)
else:
Expand Down Expand Up @@ -328,6 +321,10 @@ def run_all_gather_on_t3000_impl_tight_loop(
# (4, 2, [4, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
(8, 1, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT),
(8, 1, [1, 1, 8, 1], 2, ttnn.TILE_LAYOUT),
(8, 1, [1, 1, 2, 8], 3, ttnn.TILE_LAYOUT),
(8, 1, [1, 1, 2, 33 * 8], 3, ttnn.TILE_LAYOUT),
(8, 1, [1, 1, 67 * 8, 35], 2, ttnn.TILE_LAYOUT),
# (4, 2, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (4, 2, [4, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (8, 1, [8, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
Expand Down Expand Up @@ -1053,7 +1050,7 @@ def test_all_gather_on_t3000_nightly(
ttnn.MemoryConfig(buffer_type=ttnn.BufferType.L1),
],
)
def test_all_gather_on_t3000_nightly(
def test_all_gather_on_t3000_nightly_pcie(
pcie_mesh_device,
num_devices,
input_shape,
Expand Down
37 changes: 34 additions & 3 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include "ttnn/tensor/tensor_utils.hpp"

#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "eth_l1_address_map.h"
#include "ttnn/cpp/ttnn/operations/copy.hpp"

namespace ttnn {
namespace ccl {
Expand Down Expand Up @@ -187,7 +189,7 @@ void AllGather::validate(const std::vector<Tensor>& input_tensors) const {
}

std::vector<ttnn::TensorSpec> AllGather::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
auto output_shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
auto output_shape = input_tensors[0].get_logical_shape();
output_shape[this->dim] *= this->ring_size;

const auto& input_tensor = input_tensors[0];
Expand Down Expand Up @@ -259,6 +261,8 @@ Tensor all_gather(
operation::launch_op(
[gather_dim,
num_links,
dim,
num_devices,
memory_config,
user_defined_num_workers,
user_defined_num_buffers_per_channel,
Expand All @@ -267,9 +271,29 @@ Tensor all_gather(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
const auto& input_tensor = input_tensors.at(0);
auto input_tensor = input_tensors.at(0);

ttnn::SmallVector<uint32_t> unpad_elements = {
input_tensor.get_logical_shape()[-4],
input_tensor.get_logical_shape()[-3],
input_tensor.get_logical_shape()[-2],
input_tensor.get_logical_shape()[-1]};
bool needs_padding = input_tensor.get_layout() == Layout::TILE &&
(input_tensor.get_logical_shape()[-2] % tt::constants::TILE_HEIGHT != 0 ||
input_tensor.get_logical_shape()[-1] % tt::constants::TILE_WIDTH != 0);
if (needs_padding) {
ttnn::SmallVector<std::pair<uint32_t, uint32_t>> padding = {{0, 0}, {0, 0}, {0, 0}, {0, 0}};
DataType original_dtype = input_tensor.get_dtype();
if (input_tensor.get_dtype() != DataType::BFLOAT16 && input_tensor.get_dtype() != DataType::FLOAT32) {
input_tensor = ttnn::typecast(input_tensor, DataType::BFLOAT16);
}
input_tensor = ttnn::pad(0, input_tensor, padding, 0, false, std::nullopt);
if (original_dtype != input_tensor.get_dtype()) {
input_tensor = ttnn::typecast(input_tensor, original_dtype);
}
}

return operation::run(
auto output_tensor = operation::run(
ttnn::ccl::all_gather_detail::create_all_gather_struct(
input_tensor,
gather_dim,
Expand All @@ -280,9 +304,16 @@ Tensor all_gather(
devices,
ccl_topology),
{input_tensor});

if (needs_padding) {
return ttnn::ccl::unpad_output_tensor(output_tensor, num_devices, unpad_elements, dim);
} else {
return output_tensor;
}
},
{input_tensor},
output_tensors);

return output_tensors.at(0);
}

Expand Down
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#include "ccl_host_datastructures.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/operations/data_movement/concat/concat.hpp"

namespace ttnn {
namespace ccl {
Expand Down Expand Up @@ -81,6 +83,30 @@ std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> get_dev
return {device_index, std::nullopt, std::nullopt}; // Return null if the device is not found
}

std::vector<ttnn::Tensor> unpad_output_tensor(
    const std::vector<ttnn::Tensor>& output_tensor,
    const uint32_t num_devices,
    const ttnn::SmallVector<uint32_t>& unpad_elements,
    const int dim) {
    // Strip gather-time padding from an all-gather result: split the gathered
    // tensor into its num_devices shards along `dim`, slice each shard down to
    // the original (unpadded) logical extents given by `unpad_elements`
    // ({N, C, H, W} of one pre-padding shard), and re-concatenate the slices
    // along `dim`.
    //
    // output_tensor: single-element vector holding the gathered tensor whose
    //   extent along `dim` is num_devices times the padded per-shard size.
    // Returns a single-element vector containing the unpadded concatenation.
    const ttnn::Tensor& gathered = output_tensor.at(0);
    // Hoist the loop-invariant gathered extent along the split dimension.
    const uint32_t gathered_dim_size = gathered.get_logical_shape()[dim];

    ttnn::SmallVector<uint32_t> begins = {0, 0, 0, 0};
    // Slice ends default to the shard's original logical sizes; the entry for
    // `dim` is recomputed per shard below.
    ttnn::SmallVector<uint32_t> ends = unpad_elements;
    const ttnn::SmallVector<uint32_t> step = {1, 1, 1, 1};

    std::vector<ttnn::Tensor> combined_tensors;
    combined_tensors.reserve(num_devices);
    for (uint32_t i = 0; i < num_devices; ++i) {
        begins[dim] = i * gathered_dim_size / num_devices;
        ends[dim] = begins[dim] + unpad_elements[dim];
        combined_tensors.push_back(ttnn::slice(gathered, begins, ends, step));
    }
    return {ttnn::concat(combined_tensors, dim)};
}

RingTopology::RingTopology(
Device const* device,
Topology topology,
Expand Down
5 changes: 5 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> get_dev
const std::vector<Device*>& devices,
const ttnn::ccl::Topology& topology);

std::vector<ttnn::Tensor> unpad_output_tensor(
const std::vector<ttnn::Tensor>& output_tensor,
const uint32_t num_devices,
const ttnn::SmallVector<uint32_t>& unpad_elements,
const int dim);

class LineTopology {
public:
Expand Down

0 comments on commit 204ed99

Please sign in to comment.