Skip to content

Commit

Permalink
#0: Port all Misc ops to use TensorSpec (#15509)
Browse files Browse the repository at this point in the history
### Ticket

### Problem description
We need to migrate all ops to use `compute_output_specs` with
`TensorSpec` instead of the older `compute_output_shapes`.

### What's changed
Migrated all misc ops to TensorSpec.
Made minor infra upgrades to support the migration.

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12057375995)
- [x] [(Single-card) Demo
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12133160610)
- [x] [(Single-card) Model perf
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12133156506)
- [x] [(Single-card) Device perf
regressions](https://github.com/tenstorrent/tt-metal/actions/runs/12146782009)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored and yieldthought committed Dec 6, 2024
1 parent ad533c4 commit 37c0bb0
Show file tree
Hide file tree
Showing 41 changed files with 265 additions and 273 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ std::vector<Tensor> run_operation(
static_assert(
operation::detail::is_device_operation<OpConfig>(), "ttnn::run_operation can only dispatch Device Operations!");
// Create output tensor vector by examining the number of output shapes created by the device operation
auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors);
auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors, {});
size_t output_shapes_size = 0;
if (std::holds_alternative<std::vector<ttnn::SimpleShape>>(output_shapes)) {
output_shapes_size = std::get<std::vector<ttnn::SimpleShape>>(output_shapes).size();
Expand Down
64 changes: 39 additions & 25 deletions ttnn/cpp/ttnn/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,15 @@ constexpr bool implements_compute_output_shapes() {
template <class T, class... Args>
using has_compute_output_specs_t = decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));

template <class T>
constexpr bool implements_compute_output_specs_with_optional_output_tensors() {
    // Detects whether T provides compute_output_specs(const Tensors&, const OptionalTensors&).
    constexpr bool has_two_arg_overload =
        std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&, const OptionalTensors&>;
    return has_two_arg_overload;
}

template <class T>
constexpr bool implements_compute_output_specs() {
    // True when T provides compute_output_specs in either form: the single-argument
    // (input tensors only) overload, or the overload that also takes preallocated
    // optional output tensors.
    // NOTE(review): the original span contained two consecutive return statements
    // (stale diff residue); the second was unreachable. Only the combined check is kept.
    return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&> ||
           implements_compute_output_specs_with_optional_output_tensors<T>();
}

template <class T, class... Args>
Expand Down Expand Up @@ -441,8 +447,9 @@ struct DeviceOperation final {
}

// TODO: Rename into compute_output_specs in later PR
// Dispatches to the type-erased implementation; output_tensors carries any
// caller-preallocated outputs so ops that implement the two-argument
// compute_output_specs overload can consume them.
inline const ComputedShapes compute_output_shapes(
    const Tensors& input_tensors, const OptionalTensors& output_tensors) const {
    // The original span fused the removed single-argument signature with the new
    // two-argument one; only the two-argument version (matching
    // compute_output_shapes_impl_'s declaration) is kept.
    return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors, output_tensors);
}

inline const OutputTensors create_output_tensors(
Expand Down Expand Up @@ -581,37 +588,44 @@ struct DeviceOperation final {
"Operation must implement either validate or validate_with_output_tensors");
}
}},
compute_output_shapes_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (
detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation cannot implement both compute_output_shapes and compute_output_specs");
} else if constexpr (detail::implements_compute_output_shapes<T>()) {
return operation.compute_output_shapes(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return operation.compute_output_specs(input_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either compute_output_shapes or compute_output_specs");
}
}},
compute_output_shapes_impl_{
[](const storage_t& storage,
const Tensors& input_tensors,
const OptionalTensors& output_tensors) -> const ComputedShapes {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (
detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation cannot implement both compute_output_shapes and compute_output_specs");
} else if constexpr (detail::implements_compute_output_shapes<T>()) {
return operation.compute_output_shapes(input_tensors);
} else if constexpr (detail::implements_compute_output_specs_with_optional_output_tensors<T>()) {
return operation.compute_output_specs(input_tensors, output_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return operation.compute_output_specs(input_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either compute_output_shapes or compute_output_specs");
}
}},
create_output_tensors_impl_{
[](const storage_t& storage,
const Tensors& input_tensors,
const OptionalTensors& output_tensors) -> const OutputTensors {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
"Operation must implement compute_output_shapes or compute_output_specs if it implements "
"create_output_tensors");
return operation.create_output_tensors(input_tensors, output_tensors);
} else if constexpr (detail::implements_create_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
"Operation must implement compute_output_shapes or compute_output_specs if it implements "
"create_output_tensors");
return operation.create_output_tensors(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
Expand Down Expand Up @@ -811,7 +825,7 @@ struct DeviceOperation final {
const Tensors&,
const std::vector<std::optional<const Tensor>>&,
const OptionalTensors&);
const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&);
const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);
const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);

CacheableProgram<OutputTensors> (*create_program_impl_)(
Expand Down
18 changes: 8 additions & 10 deletions ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,18 @@ void FullOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
};

// Builds the output TensorSpec for a `full` op: the shape comes from the op
// attributes, and dtype/layout/memory config are taken from the attributes as well.
// The input tensor args are unused (the output is fully described by attributes).
FullOperation::spec_return_value_t FullOperation::compute_output_specs(
    const operation_attributes_t& operation_attributes, const tensor_args_t&) {
    // The original span also contained the removed compute_output_shapes header and
    // body (unclosed brace); only the TensorSpec-returning version is kept.
    return TensorSpec(
        SimpleShape(operation_attributes.shape),
        TensorLayout(
            operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
};

// Allocates the output tensor described by compute_output_specs on the device of
// the `any` probe tensor (the op has no data-carrying input; `any` supplies the device).
FullOperation::tensor_return_value_t FullOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    // The original span fused the removed shape-based create_device_tensor call with
    // the new spec-based one; only the spec-based allocation is kept.
    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.any.device());
}

std::tuple<FullOperation::operation_attributes_t, FullOperation::tensor_args_t> FullOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct FullOperation {
const Tensor& any;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -54,7 +54,7 @@ struct FullOperation {
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,18 @@ void FullLikeOperation::validate_on_program_cache_hit(
validate(operation_attributes, tensor_args);
}

// Builds the output TensorSpec for `full_like`: shape mirrors the input tensor's
// logical shape, while dtype/layout/memory config come from the op attributes.
FullLikeOperation::spec_return_value_t FullLikeOperation::compute_output_specs(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    // The original span fused the removed shape-returning signature/body with the new
    // spec-returning one; only the TensorSpec version is kept.
    return TensorSpec(
        tensor_args.input.get_logical_shape(),
        TensorLayout(
            operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
}

// Allocates the output tensor described by compute_output_specs on the input
// tensor's device.
FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    // The original span fused the removed shape-based allocation (explicit dtype/
    // layout/memory_config arguments) with the new spec-based one; only the
    // spec-based allocation is kept.
    const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.input.device());
}

std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct FullLikeOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -54,7 +54,7 @@ struct FullLikeOperation {
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ void IndexFillOperation::validate_on_program_cache_hit(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate(operation_attributes, tensor_args);
}
// Builds the output TensorSpec for `index_fill`: shape mirrors the input's logical
// shape, and the input's own tensor layout is reused with only the memory config
// overridden from the op attributes.
IndexFillOperation::spec_return_value_t IndexFillOperation::compute_output_specs(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    // The original span fused the removed shape-returning signature/body with the new
    // spec-returning one; only the TensorSpec version is kept.
    return TensorSpec(
        tensor_args.input.get_logical_shape(),
        tensor_args.input.get_tensor_spec().tensor_layout().with_memory_config(operation_attributes.memory_config));
}
// Allocates the output tensor described by compute_output_specs on the input
// tensor's device.
IndexFillOperation::tensor_return_value_t IndexFillOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    // The original span fused the removed shape-based allocation with the new
    // spec-based one; only the spec-based allocation is kept.
    const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.input.device());
}
std::tuple<IndexFillOperation::operation_attributes_t, IndexFillOperation::tensor_args_t> IndexFillOperation::invoke(
const Tensor& input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct IndexFillOperation {
const Tensor& input;
const Tensor& index;
};
using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;
struct MultiCore {
struct shared_variables_t {
Expand All @@ -46,7 +46,7 @@ struct IndexFillOperation {
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& input,
Expand Down
3 changes: 1 addition & 2 deletions ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ void UpdateCache::validate(const std::vector<Tensor>& input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> UpdateCache::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
std::vector<TensorSpec> UpdateCache::compute_output_specs(const std::vector<Tensor>&) const {
// Do nothing because it's an in-place operation
return {};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct UpdateCache {
UpdateCacheOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor>& input_tensors) const;

void validate(const std::vector<Tensor>& input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;

operation::ProgramWithCallbacks create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,23 @@ void GroupNorm::validate(
TT_FATAL(input_mask.value().get_legacy_shape()[3] % TILE_WIDTH == 0, "Error");
}
}
std::vector<ttnn::SimpleShape> GroupNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
return {input_tensors.at(0).get_logical_shape()};
std::vector<TensorSpec> GroupNorm::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
if (this->program_config.inplace) {
return {input_tensor.get_tensor_spec()};
}
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec();
return {TensorSpec(
input_tensor.get_logical_shape(),
TensorLayout(program_config.out_data_format, PageConfig(program_config.output_layout), mem_config))};
}
// Produces the group-norm output tensor: in-place mode returns the input itself;
// otherwise a fresh device tensor is allocated from the computed output spec on the
// input's device.
std::vector<Tensor> GroupNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
    // The original span fused the removed shape-based else-branch allocation with
    // the new spec-based flow; only the spec-based version is kept.
    const auto& input_tensor = input_tensors.at(0);
    if (this->program_config.inplace) {
        return {input_tensor};
    }
    return {create_device_tensor(this->compute_output_specs(input_tensors).at(0), input_tensor.device())};
}
operation::ProgramWithCallbacks GroupNorm::create_program(
const std::vector<Tensor>& input_tensors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct GroupNorm {
void validate(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors,
Expand Down
Loading

0 comments on commit 37c0bb0

Please sign in to comment.