fixes for downstream test failures and MR comments
amorrisonTT committed Feb 25, 2025
1 parent 787a129 commit 5215684
Showing 9 changed files with 70 additions and 83 deletions.
30 changes: 3 additions & 27 deletions tests/ttnn/unit_tests/operations/test_slice.py
@@ -746,35 +746,11 @@ def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, dev
assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)


-@pytest.mark.skip("#15796 1D tiled support is not available as of yet")
-@pytest.mark.parametrize(
-    "input_shape, dim, start, end, step, layout",
-    (([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT),),  # Difference in expected shape as it's a 1D tensor
-)
-def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
-    torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
-
-    slice_obj = slice(start, end, step)
-
-    # Prepare indices for slicing in the specified dimension
-    indices = [slice(None)] * len(input_shape)  # By default, select all elements along every dimension
-    indices[dim] = slice_obj  # Apply slicing to the target dimension
-    indices = tuple(indices)
-
-    # Apply slicing to the input_tensor
-    torch_output_tensor = torch_input[indices]
-
-    ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
-    ttnn_output = ttnn_tensor[indices]
-
-    ttnn_output_tensor = ttnn.from_device(ttnn_output).to_torch_with_logical_shape()
-
-    assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)


@pytest.mark.parametrize(
"input_shape, input_start, input_ends, input_stride",
(
+([3], [0, 0], [-1, -1], [1, 1]),
+([1, 7], [0, 0], [-1, -1], [1, 1]),
([3234, 4], [0, 2], [3234, 4], [1, 1]),
([196, 196, 2], [0, 0, 1], [196, 196, 2], [1, 1, 1]),
([1, 23, 40], [0, 0, 39], [1, 23, 40], [1, 1, 1]),
@@ -788,7 +764,7 @@ def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
"input_memory_config",
(ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
)
-def test_slice_adversarial_flexible(
+def test_slice_pytorch2_former_failures(
input_shape, input_start, input_ends, input_stride, layout, input_memory_config, device
):
if layout == ttnn.TILE_LAYOUT:
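The 1D `[3]` case removed above survives as the first new parametrization of test_slice_pytorch2_former_failures. A minimal sketch of what that case exercises, assuming a device fixture and rank-matched begins/ends (the helper name run_1d_slice is illustrative, not part of the diff):

    import torch
    import ttnn

    def run_1d_slice(device):
        # shape [3], begins [0], ends [-1], step [1] -- the former adversarial case
        torch_input = torch.randn([3], dtype=torch.bfloat16)
        torch_output = torch_input[0:-1]

        ttnn_input = ttnn.from_torch(torch_input, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16)
        ttnn_output = ttnn.slice(ttnn_input, [0], [-1], [1])
        return torch_output, ttnn.to_torch(ttnn_output)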
10 changes: 0 additions & 10 deletions tests/ttnn/unit_tests/test_getitem.py
@@ -76,16 +76,6 @@ def test_getitem_2d(device, height, width, input_layout, on_device):
assert torch.allclose(torch_output_tensor, output_tensor)


-def test_getitem_scalar_output():
-    torch_input_tensor = torch.rand((16, 32), dtype=torch.bfloat16)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor)
-
-    with pytest.raises(RuntimeError) as e:
-        input_tensor[0, 0]
-    assert "Host tensor slice cannot return a scalar or empty tensor" in str(e.value)
-

@pytest.mark.parametrize("batch_sizes", [(), (1, 1)])
@pytest.mark.parametrize("height", [32, 64])
@pytest.mark.parametrize("width", [32, 96])
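This deletion pairs with the new host-slicing fallback added to ttnn/ttnn/operations/core.py at the bottom of this commit: host tensors no longer categorically raise on __getitem__, so the expected-error test goes away. A sketch of the access pattern the fallback is meant to allow, assuming unit-stride slices on a host tensor:

    import torch
    import ttnn

    torch_input = torch.rand((16, 32), dtype=torch.bfloat16)
    host_tensor = ttnn.from_torch(torch_input)  # no device argument: stays on host

    window = host_tensor[2:4, 0:32]  # unit stride: routed through the unpad fallback
    # host_tensor[::2, :] would raise RuntimeError -- non-unit strides are
    # still rejected for host tensors by the new __getitem__ branch.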
4 changes: 2 additions & 2 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -1212,8 +1212,8 @@ void pytensor_module(py::module& m_tensor) {
.def(
"unpad",
[](const Tensor& self,
-const std::array<uint32_t, 4>& output_tensor_start,
-const std::array<uint32_t, 4>& output_tensor_end) {
+const ttnn::SmallVector<uint32_t>& output_tensor_start,
+const ttnn::SmallVector<uint32_t>& output_tensor_end) {
return self.unpad(ttnn::Shape(output_tensor_start), ttnn::Shape(output_tensor_end));
},
R"doc(
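Switching the binding's parameters from std::array<uint32_t, 4> to ttnn::SmallVector lifts the rank-4 restriction on the Python-facing unpad. A sketch of a call the old signature could not express, with illustrative shapes (not taken from the diff):

    import torch
    import ttnn

    # Rank-2 start/end vectors are now representable at the binding layer.
    t = ttnn.from_torch(torch.rand((64, 64), dtype=torch.bfloat16), layout=ttnn.ROW_MAJOR_LAYOUT)
    cropped = t.unpad([0, 0], [32, 32])  # top-left 32x32 window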
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -16,8 +16,6 @@
#include "cpp/ttnn/operations/data_movement/common/common.hpp"
#include "cpp/ttnn/operations/data_movement/transpose/transpose.hpp"
#include "cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp"
-// #include "cpp/ttnn/operations/data_movement/slice/slice.hpp"
-// #include "cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp"
#include "cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"

#include <ranges>
@@ -94,7 +94,7 @@ void SliceDeviceOperation::validate_with_output_tensors(
TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error");
}
if (!output_tensors.empty() && output_tensors[0].has_value()) {
-const auto output_shape_required = compute_output_specs(input_tensors)[0].logical_shape();
+const auto output_shape_required = compute_output_specs(input_tensors)[0].padded_shape();
const auto& out_tensor = output_tensors[0].value();
TT_FATAL(
out_tensor.get_padded_shape() == output_shape_required,
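Comparing against the padded rather than the logical shape matches what a preallocated output buffer actually stores: under TILE_LAYOUT the last two dims are rounded up to multiples of 32. A small illustration, reusing shapes from the test parametrization above:

    TILE = 32

    def padded(shape):
        # only the last two dims are tiled; leading dims keep their logical extent
        return shape[:-2] + [((d + TILE - 1) // TILE) * TILE for d in shape[-2:]]

    assert padded([1, 23, 40]) == [1, 32, 64]
    assert padded([3234, 4]) == [3264, 32]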
@@ -35,28 +35,28 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
auto input_shape = input_tensor.get_logical_shape();
auto output_shape = output_tensor.get_logical_shape();

-uint32_t padded_row_size_bytes = input_shape[-1] * input_tensor.element_size();
-uint32_t unpadded_row_size_bytes = output_shape[-1] * input_tensor.element_size();
+uint32_t input_row_size_bytes = input_shape[-1] * input_tensor.element_size();
+uint32_t output_row_size_bytes = output_shape[-1] * input_tensor.element_size();

std::uint32_t num_dims = static_cast<std::uint32_t>(input_shape.rank());
-std::vector<uint32_t> num_unpadded_sticks_per_dim(num_dims);
-std::vector<uint32_t> num_padded_sticks_per_dim(num_dims);
+std::vector<uint32_t> num_output_sticks_per_dim(num_dims);
+std::vector<uint32_t> num_input_sticks_per_dim(num_dims);
std::vector<uint32_t> id_per_dim(num_dims);

std::vector<uint32_t> accumulated_total_per_dim(num_dims);

// TODO: Remove first element of these arrays and update kernel accordingly
// This currently just matches tile version where we iterate over the row as well
-num_unpadded_sticks_per_dim[0] = 1;
-num_padded_sticks_per_dim[0] = 0;
+num_output_sticks_per_dim[0] = 1;
+num_input_sticks_per_dim[0] = 0;
accumulated_total_per_dim[0] = 1;

for (int32_t i = 1; i < num_dims; i++) {
-uint32_t num_unpadded_dim = output_shape[-(i + 1)];
+uint32_t num_output_dim = output_shape[-(i + 1)];
uint32_t num_total_dim = input_shape[-(i + 1)];
-uint32_t num_padded_dim = (num_total_dim - num_unpadded_dim) * accumulated_total_per_dim[i - 1];
-num_unpadded_sticks_per_dim[i] = num_unpadded_dim;
-num_padded_sticks_per_dim[i] = num_padded_dim;
+uint32_t num_input_dim = (num_total_dim - num_output_dim) * accumulated_total_per_dim[i - 1];
+num_output_sticks_per_dim[i] = num_output_dim;
+num_input_sticks_per_dim[i] = num_input_dim;
accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1];
}
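For readers following the rename: num_output_sticks_per_dim counts the rows to copy at each dimension, and num_input_sticks_per_dim counts the rows to skip past afterwards, scaled by the accumulated extent of the faster-varying dims. A Python rendering of the same loop, for illustration only:

    def sticks_per_dim(input_shape, output_shape):
        ndim = len(input_shape)
        num_output = [1] + [0] * (ndim - 1)  # dim 0 kept to match the tile kernel
        num_input = [0] * ndim
        accumulated = [1] * ndim
        for i in range(1, ndim):
            out_d = output_shape[-(i + 1)]
            total_d = input_shape[-(i + 1)]
            num_output[i] = out_d                                  # sticks to read
            num_input[i] = (total_d - out_d) * accumulated[i - 1]  # sticks to skip
            accumulated[i] = total_d * accumulated[i - 1]
        return num_output, num_input, accumulated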

@@ -71,22 +71,22 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
uint32_t begins_bytes = output_tensor_start[-1] * input_tensor.element_size();
uint32_t misalignment = begins_bytes % SRC_BUFFER_ALIGNMENT;

-uint32_t unpadded_row_size_bytes_offset = tt::round_up(unpadded_row_size_bytes, ALIGNMENT);
+uint32_t output_row_size_bytes_offset = tt::round_up(output_row_size_bytes, ALIGNMENT);
uint32_t start_addr = input_tensor.buffer()->address();
std::vector<uint32_t> common_reader_kernel_args = {
start_addr + begins_bytes - misalignment, // read from nearest aligned address
-padded_row_size_bytes,
-unpadded_row_size_bytes,
-unpadded_row_size_bytes_offset,
+input_row_size_bytes,
+output_row_size_bytes,
+output_row_size_bytes_offset,
num_dims,
0,
0,
0,
0};
common_reader_kernel_args.insert(
-common_reader_kernel_args.end(), num_unpadded_sticks_per_dim.begin(), num_unpadded_sticks_per_dim.end());
+common_reader_kernel_args.end(), num_output_sticks_per_dim.begin(), num_output_sticks_per_dim.end());
common_reader_kernel_args.insert(
-common_reader_kernel_args.end(), num_padded_sticks_per_dim.begin(), num_padded_sticks_per_dim.end());
+common_reader_kernel_args.end(), num_input_sticks_per_dim.begin(), num_input_sticks_per_dim.end());

std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> ret_val(num_cores_total);
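The addr_offset = 5 used further down indexes into this vector. Its layout, spelled out (the zeros at indices 5-8 are placeholders patched per core; the index names are descriptive, not from the source):

    COMMON_READER_ARG_LAYOUT = [
        "src_addr_aligned",              # 0: start_addr + begins_bytes - misalignment
        "input_row_size_bytes",          # 1
        "output_row_size_bytes",         # 2
        "output_row_size_bytes_offset",  # 3
        "num_dims",                      # 4
        "start_id",                      # 5: patched per core
        "num_sticks_per_core",           # 6: patched per core
        "num_sticks_per_core_read",      # 7: patched per core
        "num_read_per_barrier",          # 8: patched per core
    ]
    # ...followed by num_output_sticks_per_dim, then num_input_sticks_per_dim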

@@ -108,22 +108,22 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
if (num_sticks_per_core != 0) {
auto num_sticks_per_core_pad32 = num_sticks_per_core + (32 - num_sticks_per_core % 32) % 32;
num_sticks_per_core_read = tt::tt_metal::merge_num_sticks_to_read(
-num_sticks_per_core_pad32, unpadded_row_size_bytes_offset, max_read_size);
+num_sticks_per_core_pad32, output_row_size_bytes_offset, max_read_size);
num_read_per_barrier = num_sticks_per_core_pad32 / num_sticks_per_core_read;
}

-id_per_dim[0] = num_sticks_written % num_unpadded_sticks_per_dim[0];
-uint32_t unpadded_written = num_sticks_written / num_unpadded_sticks_per_dim[0];
+id_per_dim[0] = num_sticks_written % num_output_sticks_per_dim[0];
+uint32_t output_written = num_sticks_written / num_output_sticks_per_dim[0];
uint32_t start_id = id_per_dim[0] + start_offset;

for (uint32_t j = 1; j < num_dims; j++) {
-id_per_dim[j] = unpadded_written % num_unpadded_sticks_per_dim[j];
-unpadded_written = unpadded_written / num_unpadded_sticks_per_dim[j];
+id_per_dim[j] = output_written % num_output_sticks_per_dim[j];
+output_written = output_written / num_output_sticks_per_dim[j];
start_id += id_per_dim[j] * accumulated_total_per_dim[j - 1];
}
std::vector<uint32_t> reader_kernel_args = common_reader_kernel_args;
//
-uint32_t addr_offset = 5; // input buffer addr, padded_row_size_bytes, unpadded_row_size_bytes, num_dims
+uint32_t addr_offset = 5; // input buffer addr, input_row_size_bytes, output_row_size_bytes, num_dims
reader_kernel_args[addr_offset++] = start_id;
reader_kernel_args[addr_offset++] = num_sticks_per_core;
reader_kernel_args[addr_offset++] = num_sticks_per_core_read;
Expand All @@ -132,8 +132,8 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_

std::vector<uint32_t> writer_kernel_args = {
output_buffer->address(),
-unpadded_row_size_bytes,
-unpadded_row_size_bytes_offset,
+output_row_size_bytes,
+output_row_size_bytes_offset,
num_sticks_per_core,
num_sticks_per_core_read,
num_read_per_barrier,
@@ -155,7 +155,7 @@ operation::ProgramWithCallbacks slice_rm_multi_core(
// This should allocate a DRAM buffer on the device
tt::tt_metal::IDevice* device = a.device();

-uint32_t num_unpadded_sticks = output.volume() / output.get_logical_shape()[-1];
+uint32_t num_output_sticks = output.volume() / output.get_logical_shape()[-1];

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
@@ -164,23 +164,23 @@
CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1});
uint32_t num_cores_total = num_cores_x * num_cores_y;
auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_per_core_group_1, num_sticks_per_core_group_2] =
-tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks);
+tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_sticks);

tt::tt_metal::Buffer* src0_buffer = a.buffer();

tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());

-uint32_t padded_row_size_bytes = a.get_logical_shape()[-1] * a.element_size();
-uint32_t unpadded_row_size_bytes = output_shape[-1] * a.element_size();
+uint32_t input_row_size_bytes = a.get_logical_shape()[-1] * a.element_size();
+uint32_t output_row_size_bytes = output_shape[-1] * a.element_size();

tt::tt_metal::Buffer* dst_buffer = output.buffer();
TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");

bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;

-uint32_t src_stick_size = padded_row_size_bytes;
-uint32_t dst_stick_size = unpadded_row_size_bytes;
+uint32_t src_stick_size = input_row_size_bytes;
+uint32_t dst_stick_size = output_row_size_bytes;

uint32_t src0_cb_index = 0;
uint32_t max_read_size = 4096;
@@ -200,7 +200,7 @@
if (misalignment != 0) {
ALIGNMENT *= 2;
}
-uint32_t cb_page_size = tt::round_up(unpadded_row_size_bytes, ALIGNMENT);
+uint32_t cb_page_size = tt::round_up(output_row_size_bytes, ALIGNMENT);

uint32_t num_input_pages = num_sticks_per_core_group_1 > num_sticks_per_core_group_2 ? num_sticks_per_core_group_1
: num_sticks_per_core_group_2;
@@ -265,15 +265,15 @@
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
-uint32_t num_unpadded_sticks = dst_tensor.volume() / dst_tensor.get_logical_shape()[-1];
+uint32_t num_output_sticks = dst_tensor.volume() / dst_tensor.get_logical_shape()[-1];
auto
[num_cores,
all_cores,
core_group_1,
core_group_2,
num_sticks_per_core_group_1,
num_sticks_per_core_group_2] =
-tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks);
+tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_sticks);

const auto tensor_start =
static_cast<const ttnn::operations::data_movement::SliceDeviceOperation*>(operation)->slice_start;
14 changes: 9 additions & 5 deletions ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp
@@ -2,12 +2,14 @@
//
// SPDX-License-Identifier: Apache-2.0

-#include "ttnn/common/constants.hpp"
#include "slice.hpp"
#include "device/slice_op.hpp"
+#include "ttnn/common/queue_id.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/common/queue_id.hpp"
#include "cpp/ttnn/operations/creation.hpp"
+#include "ttnn/common/constants.hpp"
+#include "cpp/ttnn/operations/data_movement/copy/copy.hpp"
#include "cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"
#include "cpp/ttnn/operations/data_movement/common/common.hpp"
@@ -107,7 +109,10 @@ ttnn::Tensor SliceOperation::invoke(
}
rm_only = !no_step || !aligned_begins || !aligned_ends || one_dimensional;
if (rm_only) {
-TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Strided slice is not supported for BFLOAT8 tensors");
+if (!no_step) {
+    TT_FATAL(
+        input.get_dtype() != DataType::BFLOAT8_B, "Strided slice is not supported for BFLOAT8 tensors");
+}
input = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, memory_config, (IDevice*)nullptr);
}
}
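The old guard rejected every non-BFLOAT16 input whenever the row-major fallback was taken, which also blocked unstrided BFLOAT8_B slices the fallback can serve. The new guard fires only for genuinely strided slices. The resulting decision, sketched as plain Python (illustrative names, not the ttnn API):

    def validate_rm_fallback(dtype: str, no_step: bool) -> None:
        # Unit-stride bfloat8_b slices now pass through the row-major
        # fallback; only strided bfloat8_b slices are rejected.
        if not no_step and dtype == "bfloat8_b":
            raise RuntimeError("Strided slice is not supported for BFLOAT8 tensors")

    validate_rm_fallback("bfloat8_b", no_step=True)   # permitted after this commit
    # validate_rm_fallback("bfloat8_b", no_step=False)  # still raises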
@@ -173,7 +178,6 @@ ttnn::Tensor SliceOperation::invoke(
return SliceOperation::invoke<T>(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor);
}


template <typename T, std::size_t N>
ttnn::Tensor SliceOperation::invoke(
QueueId queue_id,
@@ -243,7 +247,7 @@ template ttnn::Tensor SliceOperation::invoke<uint32_t>(
const std::optional<Tensor>& optional_output_tensor);

template ttnn::Tensor SliceOperation::invoke<uint32_t, 4>(
-uint8_t queue_id,
+QueueId queue_id,
const ttnn::Tensor& input_tensor,
const std::array<uint32_t, 4>& output_tensor_start,
const std::array<uint32_t, 4>& output_tensor_end,
@@ -260,7 +264,7 @@ template ttnn::Tensor SliceOperation::invoke<uint32_t, 4>(
const std::optional<Tensor>& optional_output_tensor);

template ttnn::Tensor SliceOperation::invoke<uint32_t, 3>(
-uint8_t queue_id,
+QueueId queue_id,
const ttnn::Tensor& input_tensor,
const std::array<uint32_t, 3>& output_tensor_start,
const std::array<uint32_t, 3>& output_tensor_end,
@@ -507,7 +507,10 @@ std::vector<std::optional<Tensor>> ExecuteBackwardConcat::invoke(
if (are_required_outputs[0]) {
ttnn::SmallVector<uint32_t> start_index = {0, 0, 0, 0};
ttnn::SmallVector<uint32_t> end_index = {
-input.padded_shape()[0], input.padded_shape()[1], input.padded_shape()[2], input.padded_shape()[3]};
+input.get_logical_shape()[0],
+input.get_logical_shape()[1],
+input.get_logical_shape()[2],
+input.get_logical_shape()[3]};
ttnn::SmallVector<uint32_t> step = {1, 1, 1, 1};
ttnn::slice(queue_id, grad, start_index, end_index, step, std::nullopt, input_grad);
grad_tensor[0] = input_grad;
@@ -525,7 +528,10 @@
start_index_2 = {0, 0, 0, input.padded_shape()[3]};
}
ttnn::SmallVector<uint32_t> end_index_2 = {
-grad.padded_shape()[0], grad.padded_shape()[1], grad.padded_shape()[2], grad.padded_shape()[3]};
+grad.get_logical_shape()[0],
+grad.get_logical_shape()[1],
+grad.get_logical_shape()[2],
+grad.get_logical_shape()[3]};
ttnn::SmallVector<uint32_t> step_2 = {1, 1, 1, 1};
ttnn::slice(queue_id, grad, start_index_2, end_index_2, step_2, std::nullopt, other_grad);
grad_tensor[1] = other_grad;
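Bounding the backward slices by the logical rather than the padded shape keeps tile padding out of the returned gradients. An illustration with concrete shapes (examples only, assuming the usual padding to 32 under TILE_LAYOUT):

    # grad from concatenating two [1, 1, 30, 32] tensors along dim 2:
    #   logical shape [1, 1, 60, 32], padded shape [1, 1, 64, 32]
    end_index = [1, 1, 30, 32]    # input_grad bound: logical, not the padded [1, 1, 32, 32]
    end_index_2 = [1, 1, 60, 32]  # other_grad bound: logical, not the padded [1, 1, 64, 32]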
15 changes: 14 additions & 1 deletion ttnn/ttnn/operations/core.py
@@ -20,6 +20,14 @@ def _golden_function(input_tensor: ttnn.Tensor, slices):
return output_tensor


+def _host_slice_with_unpad(input_tensor: ttnn.Tensor, begins, ends) -> ttnn.Tensor:
+    """Hacky fallback to old `unpad` methods for host based accessing"""
+
+    working_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT).unpad(begins, ends)
+    working_tensor = ttnn.view(working_tensor, [e - b for e, b in zip(ends, begins)])
+    return ttnn.to_layout(working_tensor, input_tensor.get_layout())


@ttnn.register_python_operation(
name="ttnn.Tensor.__getitem__",
is_method=True,
@@ -132,7 +140,12 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
slice_step.append(stp)

# 5) Perform the slicing
-    output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step)
+    if ttnn.is_tensor_storage_on_device(input_tensor):
+        output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step)
+    else:
+        if not all([s == 1 for s in slice_step]):
+            raise RuntimeError("Host tensors cannot be accessed with non-unit stride")
+        output = _host_slice_with_unpad(input_tensor, slice_start, slice_end)

# 6) Squeeze out all dimensions that were indexed by an integer.
# We do this from left to right, adjusting each subsequent dimension index
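Taken together with the __getitem__ plumbing above, host tensors now support unit-stride indexing end to end. A usage sketch, assuming the ROW_MAJOR round-trip inside _host_slice_with_unpad restores the original layout:

    import torch
    import ttnn

    host = ttnn.from_torch(torch.rand((4, 32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT)

    window = host[1:3, :, 0:16]  # unit stride: handled by _host_slice_with_unpad
    first = host[0]              # integer index: sliced, then the dim is squeezed
    # host[::2] raises RuntimeError("Host tensors cannot be accessed with non-unit stride")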
