Skip to content

Commit

Permalink
cosmetic fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Amruth Sandhupatla <[email protected]>
  • Loading branch information
asandhupatlaTT committed Dec 19, 2024
1 parent 6825a3d commit 9edd1f9
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,25 +1028,29 @@ Matmul create_matmul_struct(

const bool is_optional_output_tensor =
!optional_output_tensors.empty() && optional_output_tensors.at(0).has_value();
auto output_dtype = parameters.output_dtype;
auto output_mem_config = parameters.output_mem_config;
std::optional<DataType> output_dtype = parameters.output_dtype;
MemoryConfig output_mem_config = parameters.output_mem_config;

if (is_optional_output_tensor) {
const auto& optional_output_tensor_c = optional_output_tensors.at(0);
const auto& optional_output_tensor = optional_output_tensors.at(0);
if (output_mem_config == operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
output_mem_config = optional_output_tensor_c->memory_config();
output_mem_config = optional_output_tensor->memory_config();
} else {
TT_FATAL(
optional_output_tensor_c->memory_config() == output_mem_config,
"Memory config mismatch between optional output tensor & output tensor");
optional_output_tensor->memory_config() == output_mem_config,
"Memory config mismatch between optional output tensor {} & output tensor {}",
optional_output_tensor->memory_config(),
output_mem_config);
}

if (output_dtype.has_value()) {
TT_FATAL(
optional_output_tensor_c->get_dtype() == output_dtype.value(),
"Type mismatch between optional output tensor & output tensor");
optional_output_tensor->get_dtype() == output_dtype.value(),
"Type mismatch between optional output tensor {} & output tensor {}",
optional_output_tensor->get_dtype(),
output_dtype.value());
} else {
output_dtype = optional_output_tensor_c->get_dtype();
output_dtype = optional_output_tensor->get_dtype();
}
} else {
if (!output_dtype.has_value()) {
Expand Down Expand Up @@ -1155,13 +1159,19 @@ void Matmul::validate(
const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}).at(0);
TT_FATAL(
optional_output_tensor_shape == output_tensor_spec.logical_shape(),
"Shape of Optional Output Tensor should match Output Tensor");
"Shape of Optional Output Tensor {} doesnt match Output Tensor {}",
optional_output_tensor_shape,
output_tensor_spec.logical_shape());
TT_FATAL(
optional_output_tensor_c->get_dtype() == this->output_dtype.value(),
"Type mismatch between optional output tensor & output tensor");
"Type mismatch between optional output tensor {} & output tensor {}",
optional_output_tensor_c->get_dtype(),
this->output_dtype.value());
TT_FATAL(
optional_output_tensor_c->memory_config() == this->output_mem_config,
"Memory config mismatch between optional output tensor & output tensor");
"Memory config mismatch between optional output tensor {} & output tensor {}",
optional_output_tensor_c->memory_config(),
this->output_mem_config);
}

TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
Expand Down

0 comments on commit 9edd1f9

Please sign in to comment.