Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port all Misc ops to use TensorSpec #15509

Merged
merged 8 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ std::vector<Tensor> run_operation(
static_assert(
operation::detail::is_device_operation<OpConfig>(), "ttnn::run_operation can only dispatch Device Operations!");
// Create output tensor vector by examining the number of output shapes created by the device operation
auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors);
auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors, {});
size_t output_shapes_size = 0;
if (std::holds_alternative<std::vector<ttnn::SimpleShape>>(output_shapes)) {
output_shapes_size = std::get<std::vector<ttnn::SimpleShape>>(output_shapes).size();
Expand Down
64 changes: 39 additions & 25 deletions ttnn/cpp/ttnn/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,15 @@ constexpr bool implements_compute_output_shapes() {
template <class T, class... Args>
using has_compute_output_specs_t = decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));

template <class T>
constexpr bool implements_compute_output_specs_with_optional_output_tensors() {
return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&, const OptionalTensors&>;
}

template <class T>
constexpr bool implements_compute_output_specs() {
return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&>;
return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&> ||
implements_compute_output_specs_with_optional_output_tensors<T>();
}

template <class T, class... Args>
Expand Down Expand Up @@ -441,8 +447,9 @@ struct DeviceOperation final {
}

// TODO: Rename into compute_output_specs in later PR
inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const {
return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors);
inline const ComputedShapes compute_output_shapes(
const Tensors& input_tensors, const OptionalTensors& output_tensors) const {
return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors, output_tensors);
}

inline const OutputTensors create_output_tensors(
Expand Down Expand Up @@ -581,37 +588,44 @@ struct DeviceOperation final {
"Operation must implement either validate or validate_with_output_tensors");
}
}},
compute_output_shapes_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (
detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation cannot implement both compute_output_shapes and compute_output_specs");
} else if constexpr (detail::implements_compute_output_shapes<T>()) {
return operation.compute_output_shapes(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return operation.compute_output_specs(input_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either compute_output_shapes or compute_output_specs");
}
}},
compute_output_shapes_impl_{
[](const storage_t& storage,
const Tensors& input_tensors,
const OptionalTensors& output_tensors) -> const ComputedShapes {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (
detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation cannot implement both compute_output_shapes and compute_output_specs");
} else if constexpr (detail::implements_compute_output_shapes<T>()) {
return operation.compute_output_shapes(input_tensors);
} else if constexpr (detail::implements_compute_output_specs_with_optional_output_tensors<T>()) {
return operation.compute_output_specs(input_tensors, output_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return operation.compute_output_specs(input_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either compute_output_shapes or compute_output_specs");
}
}},
create_output_tensors_impl_{
[](const storage_t& storage,
const Tensors& input_tensors,
const OptionalTensors& output_tensors) -> const OutputTensors {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
"Operation must implement compute_output_shapes or compute_output_specs if it implements "
"create_output_tensors");
return operation.create_output_tensors(input_tensors, output_tensors);
} else if constexpr (detail::implements_create_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
"Operation must implement compute_output_shapes or compute_output_specs if it implements "
"create_output_tensors");
return operation.create_output_tensors(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
Expand Down Expand Up @@ -811,7 +825,7 @@ struct DeviceOperation final {
const Tensors&,
const std::vector<std::optional<const Tensor>>&,
const OptionalTensors&);
const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&);
const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);
const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);

CacheableProgram<OutputTensors> (*create_program_impl_)(
Expand Down
18 changes: 8 additions & 10 deletions ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,18 @@ void FullOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
};

FullOperation::shape_return_value_t FullOperation::compute_output_shapes(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return SimpleShape(operation_attributes.shape);
FullOperation::spec_return_value_t FullOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t&) {
return TensorSpec(
SimpleShape(operation_attributes.shape),
TensorLayout(
operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
};

FullOperation::tensor_return_value_t FullOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
return create_device_tensor(
output_shape,
operation_attributes.dtype,
operation_attributes.layout,
tensor_args.any.device(),
operation_attributes.memory_config);
auto output_spec = compute_output_specs(operation_attributes, tensor_args);
return create_device_tensor(output_spec, tensor_args.any.device());
}

std::tuple<FullOperation::operation_attributes_t, FullOperation::tensor_args_t> FullOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct FullOperation {
const Tensor& any;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -54,7 +54,7 @@ struct FullOperation {
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,18 @@ void FullLikeOperation::validate_on_program_cache_hit(
validate(operation_attributes, tensor_args);
}

FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes(
FullLikeOperation::spec_return_value_t FullLikeOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
return TensorSpec(
tensor_args.input.get_logical_shape(),
TensorLayout(
operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
}

FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input = tensor_args.input;
return create_device_tensor(
output_shape,
operation_attributes.dtype,
operation_attributes.layout,
input.device(),
operation_attributes.memory_config);
const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
return create_device_tensor(output_spec, tensor_args.input.device());
}

std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct FullLikeOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -54,7 +54,7 @@ struct FullLikeOperation {
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ void IndexFillOperation::validate_on_program_cache_hit(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate(operation_attributes, tensor_args);
}
IndexFillOperation::shape_return_value_t IndexFillOperation::compute_output_shapes(
IndexFillOperation::spec_return_value_t IndexFillOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
return TensorSpec(
tensor_args.input.get_logical_shape(),
tensor_args.input.get_tensor_spec().tensor_layout().with_memory_config(operation_attributes.memory_config));
}
IndexFillOperation::tensor_return_value_t IndexFillOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input = tensor_args.input;
return create_device_tensor(
output_shape, input.dtype(), input.layout(), input.device(), operation_attributes.memory_config);
const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
return create_device_tensor(output_spec, tensor_args.input.device());
}
std::tuple<IndexFillOperation::operation_attributes_t, IndexFillOperation::tensor_args_t> IndexFillOperation::invoke(
const Tensor& input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct IndexFillOperation {
const Tensor& input;
const Tensor& index;
};
using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;
struct MultiCore {
struct shared_variables_t {
Expand All @@ -46,7 +46,7 @@ struct IndexFillOperation {
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ void UpdateCache::validate(const std::vector<Tensor>& input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> UpdateCache::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
std::vector<TensorSpec> UpdateCache::compute_output_specs(const std::vector<Tensor>&) const {
// Do nothing because it's an in-place operation
return {};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct UpdateCache {
UpdateCacheOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor>& input_tensors) const;

void validate(const std::vector<Tensor>& input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;

operation::ProgramWithCallbacks create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,23 @@ void GroupNorm::validate(
TT_FATAL(input_mask.value().get_legacy_shape()[3] % TILE_WIDTH == 0, "Error");
}
}
std::vector<ttnn::SimpleShape> GroupNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
return {input_tensors.at(0).get_logical_shape()};
std::vector<TensorSpec> GroupNorm::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
if (this->program_config.inplace) {
return {input_tensor.get_tensor_spec()};
}
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec();
return {TensorSpec(
input_tensor.get_logical_shape(),
TensorLayout(program_config.out_data_format, PageConfig(program_config.output_layout), mem_config))};
}
std::vector<Tensor> GroupNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
if (this->program_config.inplace) {
return {input_tensors.at(0)};
} else {
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec();
return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
program_config.out_data_format,
this->program_config.output_layout,
input_tensor.device(),
mem_config)};
return {input_tensor};
}
return {create_device_tensor(this->compute_output_specs(input_tensors).at(0), input_tensor.device())};
}
operation::ProgramWithCallbacks GroupNorm::create_program(
const std::vector<Tensor>& input_tensors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct GroupNorm {
void validate(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors,
Expand Down
Loading
Loading