enable dps ops for matmul #15285

Merged Dec 19, 2024 (28 commits)

Changes from all commits
3910d7a
enable dps ops for matmul
asandhupatlaTT Nov 21, 2024
9e8bac0
send opt op tensors as i/p for op launch
asandhupatlaTT Nov 21, 2024
3aed6d1
fix failing test case
asandhupatlaTT Nov 25, 2024
5121894
pass opt op arg for entire call flow
asandhupatlaTT Nov 27, 2024
8059099
refactor
asandhupatlaTT Nov 27, 2024
ed4d0f7
Update ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
asandhupatlaTT Nov 27, 2024
b5e9089
add const to hpp file
asandhupatlaTT Nov 28, 2024
3f18bd0
add missing &; not sure why this happened
asandhupatlaTT Nov 28, 2024
f5854af
add missing &. part 2
asandhupatlaTT Nov 28, 2024
2066ef0
support validate for opt op tensors
asandhupatlaTT Dec 2, 2024
b26cfe4
fix bugs
asandhupatlaTT Dec 3, 2024
267276c
fix typo
asandhupatlaTT Dec 3, 2024
8557a10
undo gather matmul changes
asandhupatlaTT Dec 4, 2024
4d7d2af
undo all gather part 2
asandhupatlaTT Dec 4, 2024
0121d7a
fix typo
asandhupatlaTT Dec 4, 2024
38a9876
final patch before rebase
asandhupatlaTT Dec 4, 2024
87db5fa
add as comment instead of removing it
asandhupatlaTT Dec 5, 2024
194698b
enable compute_output_spec
asandhupatlaTT Dec 6, 2024
0d8b170
enable the checking now
asandhupatlaTT Dec 6, 2024
e8d0e83
my test case passes now
asandhupatlaTT Dec 6, 2024
1ff828c
add similar checks for linear too
asandhupatlaTT Dec 6, 2024
5600b7d
address Eyon's comments
asandhupatlaTT Dec 11, 2024
d76594a
address reviewer comments
asandhupatlaTT Dec 13, 2024
d73bc62
fix compiler errors
asandhupatlaTT Dec 13, 2024
501d164
fix merge conflict typo
asandhupatlaTT Dec 13, 2024
8e64289
address Brian's comments
asandhupatlaTT Dec 18, 2024
6825a3d
fix typo
asandhupatlaTT Dec 18, 2024
9edd1f9
cosmetic fixes
asandhupatlaTT Dec 18, 2024
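
The changes below add destination-passing style (DPS) support to ttnn.matmul and ttnn.linear: callers may hand in a preallocated output tensor via the new optional_output_tensor argument, and the op writes its result there instead of allocating a fresh tensor. A minimal sketch of the call pattern, based on the tests in this PR (the device setup and the 64x64 shapes are illustrative assumptions, not taken from the diff):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

a = ttnn.from_torch(torch.randn((64, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.randn((64, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

# Preallocate the destination; matmul writes into it rather than allocating a new tensor.
out = ttnn.from_torch(torch.zeros((64, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
ttnn.matmul(a, b, optional_output_tensor=out)

result = ttnn.to_torch(out)
ttnn.close_device(device)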
48 changes: 48 additions & 0 deletions tests/ttnn/unit_tests/operations/test_linear.py
@@ -262,3 +262,51 @@ def test_bloom_ff2_linear(device):
    )

    assert ttnn.pearson_correlation_coefficient(torch_output, output) >= 0.9992


@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("m_size", [32, 64])
@pytest.mark.parametrize("k_size", [1024, 2048])
@pytest.mark.parametrize("n_size", [1024, 2048])
@pytest.mark.parametrize("activation", [None, "relu"])
def test_linear_by_passing_in_1D_systolic_array_program_config_and_optional_outout_tensor(
device, batch_size, m_size, k_size, n_size, activation
):
torch.manual_seed(0)

torch_input_tensor_a = torch.randn((batch_size, m_size, k_size), dtype=torch.bfloat16)
torch_input_tensor_b = torch.randn((k_size, n_size), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor_a @ torch_input_tensor_b
if activation == "relu":
torch_output_tensor = torch.relu(torch_output_tensor)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

torch_opt_output_tensor = torch.zeros_like(torch_output_tensor)
optional_output_tensor = ttnn.from_torch(torch_opt_output_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.linear(
input_tensor_a,
input_tensor_b,
activation=activation,
core_grid=device.core_grid,
)

output_tensor = ttnn.to_torch(output_tensor)

ttnn.linear(
input_tensor_a,
input_tensor_b,
activation=activation,
optional_output_tensor=optional_output_tensor,
core_grid=device.core_grid,
)

optional_output_tensor = ttnn.to_torch(optional_output_tensor)

assert len(output_tensor.shape) == len(torch_output_tensor.shape) == len(optional_output_tensor.shape)
assert output_tensor.shape == torch_output_tensor.shape == optional_output_tensor.shape
assert_with_pcc(torch_output_tensor, output_tensor, 0.997)
assert_with_pcc(torch_output_tensor, optional_output_tensor, 0.997)
assert_with_pcc(optional_output_tensor, output_tensor, 0.997)
31 changes: 31 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -2079,3 +2079,34 @@ def test_interleaved_input_sharded_output_matmul(device):
    output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
    output_tensor = ttnn.to_torch(output3)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)


@pytest.mark.parametrize(
    "n_size, c, m, k, n",
    [
        (1, 1, 1024, 64, 512),
    ],
)
def test_optional_output_argument(device, n_size, c, m, k, n):
    torch.manual_seed(0)

    torch_input_tensor_a = torch.rand((n_size, c, m, k), dtype=torch.bfloat16)
    torch_input_tensor_b = torch.rand((n_size, c, k, n), dtype=torch.bfloat16)
    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
    torch_opt_output_tensor = torch.zeros_like(torch_output_tensor)

    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
    optional_output_tensor = ttnn.from_torch(torch_opt_output_tensor, layout=ttnn.TILE_LAYOUT, device=device)

    output = ttnn.matmul(input_tensor_a, input_tensor_b)
    output = ttnn.to_torch(output)

    ttnn.matmul(input_tensor_a, input_tensor_b, optional_output_tensor=optional_output_tensor)
    optional_output_tensor = ttnn.to_torch(optional_output_tensor)

    assert len(output.shape) == len(torch_output_tensor.shape) == len(optional_output_tensor.shape)
    assert output.shape == torch_output_tensor.shape == optional_output_tensor.shape
    assert_with_pcc(torch_output_tensor, output, 0.999)
    assert_with_pcc(torch_output_tensor, optional_output_tensor, 0.999)
    assert_with_pcc(output, optional_output_tensor, 0.999)
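
The validate() changes further down require the preallocated tensor to agree with the spec matmul would otherwise produce: its logical shape, dtype, and memory config must all match, or validation fails with TT_FATAL. A hedged sketch of the failure mode (the RuntimeError mapping and this test name are assumptions, not asserted by this PR):

import pytest
import torch
import ttnn

def test_optional_output_shape_mismatch(device):
    a = ttnn.from_torch(torch.rand((1, 1, 1024, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.rand((1, 1, 64, 512), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    # Wrong shape on purpose: the op's output spec here is (1, 1, 1024, 512).
    wrong_out = ttnn.from_torch(torch.zeros((1, 1, 32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    with pytest.raises(RuntimeError):
        ttnn.matmul(a, b, optional_output_tensor=wrong_out)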
@@ -20,7 +20,8 @@ namespace experimental {

void AllGatherMatmul::validate(
    const std::vector<Tensor>& input_tensors,
-   const std::vector<std::optional<const ttnn::Tensor>>& optional_input_tensors) const {
+   const std::vector<std::optional<const ttnn::Tensor>>& optional_input_tensors,
+   const std::vector<std::optional<Tensor>>& optional_output_tensors) const {
    TT_ASSERT(
        input_tensors.size() == 4,
        "AllGatherMatmul requires 4 input tensors: [input, weight, all_gather_output, datacopy_output]");
@@ -33,7 +34,7 @@ void AllGatherMatmul::validate(
    this->all_gather_struct.validate({input_tensor});

    // Matmul validate.
-   this->matmul_struct.validate({all_gather_output_tensor, weight_tensor}, optional_input_tensors);
+   this->matmul_struct.validate({all_gather_output_tensor, weight_tensor}, optional_input_tensors, {});

    // All Gather Matmul validate
    TT_FATAL(this->all_gather_struct.dim == 3, "AllGatherMatmul requires dim=3 for the AllGather operations.");
@@ -73,7 +74,7 @@ std::vector<ttnn::TensorSpec> AllGatherMatmul::compute_output_specs(const std::v

    // Matmul shape
    ttnn::TensorSpec matmul_output_specs =
-       this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]})[0];
+       this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]}, {})[0];

    return {all_gather_output_shape, matmul_output_specs, datacopy_output_shape};
}

@@ -41,7 +41,8 @@ struct AllGatherMatmul {
    /* General */
    void validate(
        const std::vector<Tensor>& input_tensors,
-       const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
+       const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+       const std::vector<std::optional<Tensor>>& optional_output_tensors = {std::nullopt}) const;
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
    operation::ProgramWithCallbacks create_program(
100 changes: 87 additions & 13 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1003,7 +1003,10 @@ namespace operations {
namespace matmul {

Matmul create_matmul_struct(
-   const Tensor& input_tensor_a, const Tensor& input_tensor_b, const struct Matmul& parameters) {
+   const Tensor& input_tensor_a,
+   const Tensor& input_tensor_b,
+   const struct Matmul& parameters,
+   const std::vector<std::optional<Tensor>>& optional_output_tensors) {
    auto arch = input_tensor_a.device()->arch();
    const bool has_user_grid = parameters.user_core_coord.has_value();
    const bool has_program_config = parameters.program_config.has_value();
@@ -1022,16 +1025,48 @@ Matmul create_matmul_struct(
    bool broadcast_batch =
        parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config));
    TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config");

+   const bool is_optional_output_tensor =
+       !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value();
+   std::optional<DataType> output_dtype = parameters.output_dtype;
+   MemoryConfig output_mem_config = parameters.output_mem_config;
+
+   if (is_optional_output_tensor) {
+       const auto& optional_output_tensor = optional_output_tensors.at(0);
+       if (output_mem_config == operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
+           output_mem_config = optional_output_tensor->memory_config();
+       } else {
+           TT_FATAL(
+               optional_output_tensor->memory_config() == output_mem_config,
+               "Memory config mismatch between optional output tensor {} & output tensor {}",
+               optional_output_tensor->memory_config(),
+               output_mem_config);
+       }
+
+       if (output_dtype.has_value()) {
+           TT_FATAL(
+               optional_output_tensor->get_dtype() == output_dtype.value(),
+               "Type mismatch between optional output tensor {} & output tensor {}",
+               optional_output_tensor->get_dtype(),
+               output_dtype.value());
+       } else {
+           output_dtype = optional_output_tensor->get_dtype();
+       }
+   } else {
+       if (!output_dtype.has_value()) {
+           output_dtype = input_tensor_a.get_dtype();
+       }
+   }

    auto in0_tile = input_tensor_a.get_tensor_spec().tile();
    auto in1_tile = input_tensor_b.get_tensor_spec().tile();
-   tt::tt_metal::Tile output_tile =
-       get_output_tile(parameters.output_mem_config, in0_tile, in1_tile, parameters.output_tile);
+   tt::tt_metal::Tile output_tile = get_output_tile(output_mem_config, in0_tile, in1_tile, parameters.output_tile);

    return Matmul{
        parameters.program_config,
        broadcast_batch,
-       parameters.output_mem_config,
-       parameters.output_dtype.value_or(input_tensor_a.get_dtype()),
+       output_mem_config,
+       output_dtype,
        kernel_config_val,
        parameters.untilize_out,
        parameters.user_core_coord,
@@ -1047,9 +1082,11 @@ Tensor matmul(
    const Tensor& input_tensor_b,
    const std::optional<const Tensor>& bias,
    const struct Matmul& parameters,
-   const uint8_t queue_id) {
+   const uint8_t queue_id,
+   const std::optional<Tensor>& optional_output_tensor) {
    std::vector<std::optional<const Tensor>> optional_input_tensors = {};
    std::vector<Tensor> output_tensors;
+
    if (bias.has_value()) {
        optional_input_tensors.push_back(bias.value());
        output_tensors = {
@@ -1068,21 +1105,23 @@
            const auto& input_tensor_b = input_tensors.at(1);

            return operation::run(
-               create_matmul_struct(input_tensor_a, input_tensor_b, parameters),
+               create_matmul_struct(input_tensor_a, input_tensor_b, parameters, optional_output_tensors),
                {input_tensor_a, input_tensor_b},
                optional_input_tensors,
-               {},
+               optional_output_tensors,
                queue_id);
        },
        {input_tensor_a, input_tensor_b},
        output_tensors,
-       optional_input_tensors);
+       optional_input_tensors,
+       {optional_output_tensor});
    return output_tensors.at(0);
}

void Matmul::validate(
    const std::vector<Tensor>& input_tensors,
-   const std::vector<std::optional<const Tensor>>& optional_input_tensors) const {
+   const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+   const std::vector<std::optional<Tensor>>& optional_output_tensors) const {
    TT_FATAL(input_tensors.size() == 2, "Error");
    const auto& input_tensor_a = input_tensors.at(0);
    const auto& input_tensor_b = input_tensors.at(1);
@@ -1113,6 +1152,28 @@ void Matmul::validate(
        a_shape[-1],
        b_shape[-2]);

+   const bool is_optional_output_tensor = !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value();
+   if (is_optional_output_tensor) {
+       const auto& optional_output_tensor_c = optional_output_tensors.at(0);
+       const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape();
+       const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}).at(0);
+       TT_FATAL(
+           optional_output_tensor_shape == output_tensor_spec.logical_shape(),
+           "Shape of Optional Output Tensor {} doesn't match Output Tensor {}",
+           optional_output_tensor_shape,
+           output_tensor_spec.logical_shape());
+       TT_FATAL(
+           optional_output_tensor_c->get_dtype() == this->output_dtype.value(),
+           "Type mismatch between optional output tensor {} & output tensor {}",
+           optional_output_tensor_c->get_dtype(),
+           this->output_dtype.value());
+       TT_FATAL(
+           optional_output_tensor_c->memory_config() == this->output_mem_config,
+           "Memory config mismatch between optional output tensor {} & output tensor {}",
+           optional_output_tensor_c->memory_config(),
+           this->output_mem_config);
+   }
+
    TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
    TT_FATAL(this->output_tile.has_value(), "Error: output_tile field should have been automatically populated");
    if (this->bcast_batch.value()) {
@@ -1562,7 +1623,18 @@ void Matmul::validate(
        chosen_program_config);
}

-std::vector<ttnn::TensorSpec> Matmul::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::TensorSpec> Matmul::compute_output_specs(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) const {
+   TT_FATAL(
+       optional_output_tensors.size() <= 1,
+       "At most one optional output tensor may be passed when computing Matmul's output specs");
+
+   const bool is_optional_output_tensor = !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value();
+
+   if (is_optional_output_tensor) {
+       return {optional_output_tensors.at(0)->get_tensor_spec()};
+   }
+
    const auto& input_tensor_a = input_tensors.at(0);
    const auto& input_tensor_b = input_tensors.at(1);
    const ttnn::SimpleShape input_shape_a = input_tensor_a.get_logical_shape();
@@ -1587,6 +1659,7 @@ std::vector<ttnn::TensorSpec> Matmul::compute_output_specs(const std::vector<Ten
    auto output_tile = this->output_tile.value();
    auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1];
    auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
+
    TT_FATAL(this->output_dtype.has_value(), "Error");
    if (this->output_mem_config.is_sharded()) {
        MatmulProgramConfig chosen_program_config = get_program_config(input_tensor_a, input_tensor_b, this);
@@ -1733,8 +1806,9 @@ std::vector<ttnn::TensorSpec> Matmul::compute_output_specs(const std::vector<Ten
        output_shape, TensorLayout(output_dtype.value(), PageConfig(Layout::TILE, output_tile), output_mem_config))};
}

-std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-   return operation::default_create_output_tensors(*this, input_tensors, {});
+std::vector<Tensor> Matmul::create_output_tensors(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) const {
+   return operation::default_create_output_tensors(*this, input_tensors, optional_output_tensors);
}

operation::ProgramWithCallbacks Matmul::create_program(
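In create_matmul_struct above, the optional output tensor also participates in resolving the op's output dtype and memory config: an explicitly requested value must match the preallocated tensor, an unset value is inherited from it, and with no optional output the dtype falls back to input A. A plain-Python mirror of that precedence (a sketch for illustration; the names are local to this example, not ttnn API):

def resolve_output_params(requested_dtype, requested_mem_config, default_mem_config, opt_out, input_a_dtype):
    dtype, mem_config = requested_dtype, requested_mem_config
    if opt_out is not None:
        # Memory config: inherit when left at the default, otherwise require an exact match.
        if mem_config == default_mem_config:
            mem_config = opt_out.memory_config
        elif opt_out.memory_config != mem_config:
            raise ValueError("memory config mismatch")
        # Dtype: same rule, with None standing in for "not requested".
        if dtype is None:
            dtype = opt_out.dtype
        elif opt_out.dtype != dtype:
            raise ValueError("dtype mismatch")
    elif dtype is None:
        dtype = input_a_dtype  # fall back to the first input's dtype
    return dtype, mem_config
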
21 changes: 15 additions & 6 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
@@ -177,10 +177,15 @@ struct Matmul {

    void validate(
        const std::vector<Tensor>& input_tensors,
-       const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-   std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
-   std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
-   tt::tt_metal::operation::ProgramWithCallbacks create_program(
+       const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+       const std::vector<std::optional<Tensor>>& optional_output_tensors = {std::nullopt}) const;
+   std::vector<ttnn::TensorSpec> compute_output_specs(
+       const std::vector<Tensor>& input_tensors,
+       const std::vector<std::optional<Tensor>>& optional_output_tensors = {std::nullopt}) const;
+   std::vector<Tensor> create_output_tensors(
+       const std::vector<Tensor>& input_tensors,
+       const std::vector<std::optional<Tensor>>& optional_output_tensors = {std::nullopt}) const;
+   operation::ProgramWithCallbacks create_program(
        const std::vector<Tensor>& input_tensors,
        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
        std::vector<Tensor>& output_tensors) const;
@@ -191,7 +196,10 @@
};

Matmul create_matmul_struct(
-   const Tensor& input_tensor_a, const Tensor& input_tensor_b, const struct Matmul& parameters);
+   const Tensor& input_tensor_a,
+   const Tensor& input_tensor_b,
+   const struct Matmul& parameters,
+   const std::vector<std::optional<Tensor>>& optional_output_tensors = {std::nullopt});

operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helper(
    tt::tt_metal::Program& program,
@@ -221,7 +229,8 @@ Tensor matmul(
    const Tensor& input_tensor_b,
    const std::optional<const Tensor>& bias = std::nullopt,
    const struct Matmul& parameters = Matmul{},
-   const uint8_t queue_id = 0);
+   const uint8_t queue_id = 0,
+   const std::optional<Tensor>& optional_output_tensor = std::nullopt);

} // namespace matmul