From f05457a1eb5b45fddbd60b35835439994e06ae9d Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Fri, 28 Feb 2025 14:43:53 -0800 Subject: [PATCH] [TT-Train]Training infra update (#18167) ### Problem description It is not easy to customize different neural networks. ### What's changed Now all created NNs are immutable by design. Added operator () to the base class. All should derive. If it is possible you should always store std::shared_ptr instead of your class. ### Checklist - [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes - [x] New/Existing tests provide coverage for changes --- .../sources/ttml/autograd/module_base.cpp | 8 +++ .../sources/ttml/autograd/module_base.hpp | 5 ++ .../sources/ttml/models/distributed/gpt2.cpp | 3 +- .../sources/ttml/models/distributed/gpt2.hpp | 10 ++-- tt-train/sources/ttml/models/gpt2.cpp | 20 ++++--- tt-train/sources/ttml/models/gpt2.hpp | 15 ++--- .../sources/ttml/modules/dropout_module.hpp | 2 +- .../sources/ttml/modules/embedding_module.cpp | 11 ++-- .../sources/ttml/modules/embedding_module.hpp | 4 +- tt-train/sources/ttml/modules/gpt_block.hpp | 5 +- .../ttml/modules/layer_norm_module.hpp | 2 +- .../sources/ttml/modules/linear_module.cpp | 56 +++++++++++++------ .../sources/ttml/modules/linear_module.hpp | 13 ++--- .../ttml/modules/multi_head_attention.hpp | 3 +- .../ttml/modules/multi_layer_perceptron.cpp | 9 +-- .../ttml/modules/multi_layer_perceptron.hpp | 2 +- .../ttml/modules/positional_embeddings.hpp | 9 +-- .../sources/ttml/modules/rms_norm_module.hpp | 2 +- .../ttml/modules/single_head_attention.hpp | 3 +- tt-train/tests/model/weight_tying_test.cpp | 9 +-- .../serialization/model_serializer_test.cpp | 2 +- 21 files changed, 111 insertions(+), 82 deletions(-) diff --git a/tt-train/sources/ttml/autograd/module_base.cpp b/tt-train/sources/ttml/autograd/module_base.cpp index 569e4294f6b..b05bb89e06d 100644 --- 
a/tt-train/sources/ttml/autograd/module_base.cpp +++ b/tt-train/sources/ttml/autograd/module_base.cpp @@ -106,5 +106,13 @@ void ModuleBase::train() { void ModuleBase::eval() { set_run_mode(RunMode::EVAL); } +autograd::TensorPtr ModuleBase::operator()(const autograd::TensorPtr& tensor) { + throw std::logic_error("ModuleBase::operator()(const autograd::TensorPtr& tensor) is Not implemented"); +} +autograd::TensorPtr ModuleBase::operator()(const autograd::TensorPtr& tensor, const autograd::TensorPtr& other) { + throw std::logic_error( + "ModuleBase::operator()(const autograd::TensorPtr& tensor, const autograd::TensorPtr& other) is Not " + "implemented"); +} } // namespace ttml::autograd diff --git a/tt-train/sources/ttml/autograd/module_base.hpp b/tt-train/sources/ttml/autograd/module_base.hpp index 5cf53f5334f..f6176e39824 100644 --- a/tt-train/sources/ttml/autograd/module_base.hpp +++ b/tt-train/sources/ttml/autograd/module_base.hpp @@ -51,6 +51,11 @@ class ModuleBase { void eval(); void set_run_mode(RunMode mode); [[nodiscard]] RunMode get_run_mode() const; + + // Forward pass for the module. 
All possible overloads + [[nodiscard]] virtual autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] virtual autograd::TensorPtr operator()( + const autograd::TensorPtr& tensor, const autograd::TensorPtr& other); }; } // namespace ttml::autograd diff --git a/tt-train/sources/ttml/models/distributed/gpt2.cpp b/tt-train/sources/ttml/models/distributed/gpt2.cpp index 50530fcd96c..5207cf730d6 100644 --- a/tt-train/sources/ttml/models/distributed/gpt2.cpp +++ b/tt-train/sources/ttml/models/distributed/gpt2.cpp @@ -5,6 +5,7 @@ #include "gpt2.hpp" #include "autograd/graph_utils.hpp" +#include "autograd/module_base.hpp" #include "autograd/tensor.hpp" #include "core/distributed_mapping.hpp" #include "core/scoped.hpp" @@ -129,7 +130,7 @@ DistributedTransformer::DistributedTransformer(const TransformerConfig& config) auto create_positional_embedding = [position_embedding_type, max_sequence_length, embedding_dim, - dropout_prob]() -> std::shared_ptr { + dropout_prob]() -> std::shared_ptr { if (position_embedding_type == PositionalEmbeddingType::Trainable) { return std::make_shared( ttml::modules::PositionalEmbeddingConfig{ diff --git a/tt-train/sources/ttml/models/distributed/gpt2.hpp b/tt-train/sources/ttml/models/distributed/gpt2.hpp index 3d1695342b5..c0bdcc30ac1 100644 --- a/tt-train/sources/ttml/models/distributed/gpt2.hpp +++ b/tt-train/sources/ttml/models/distributed/gpt2.hpp @@ -24,11 +24,11 @@ using models::gpt2::WeightTyingType; class DistributedTransformer : public ttml::autograd::ModuleBase { private: RunnerType runner_type = RunnerType::Default; - std::shared_ptr tok_emb; - std::shared_ptr pos_emb; - std::vector> blocks; - std::shared_ptr ln_fc; - std::shared_ptr fc; + std::shared_ptr tok_emb; + std::shared_ptr pos_emb; + std::vector> blocks; + std::shared_ptr ln_fc; + std::shared_ptr fc; public: explicit DistributedTransformer(const TransformerConfig& config); diff --git a/tt-train/sources/ttml/models/gpt2.cpp 
b/tt-train/sources/ttml/models/gpt2.cpp index 8ce2537d257..12cea49100b 100644 --- a/tt-train/sources/ttml/models/gpt2.cpp +++ b/tt-train/sources/ttml/models/gpt2.cpp @@ -8,10 +8,12 @@ #include "autograd/tensor.hpp" #include "core/scoped.hpp" #include "init/tensor_initializers.hpp" +#include "modules/embedding_module.hpp" +#include "modules/gpt_block.hpp" +#include "modules/layer_norm_module.hpp" #include "modules/positional_embeddings.hpp" #include "ops/binary_ops.hpp" #include "ops/unary_ops.hpp" - namespace ttml::models::gpt2 { namespace { @@ -112,12 +114,17 @@ Transformer::Transformer(const TransformerConfig& config) { "embedding_dim={}", embedding_dim)); } - tok_emb = std::make_shared(vocab_size_divisible_by_32, embedding_dim); + auto last_fc = std::make_shared(embedding_dim, vocab_size, /* bias */ false); + if (config.weight_tying == WeightTyingType::Enabled) { + tok_emb = std::make_shared(last_fc->get_weight()); + } else { + tok_emb = std::make_shared(vocab_size_divisible_by_32, embedding_dim); + } auto create_positional_embedding = [position_embedding_type, max_sequence_length, embedding_dim, - dropout_prob]() -> std::shared_ptr { + dropout_prob]() -> std::shared_ptr { if (position_embedding_type == PositionalEmbeddingType::Trainable) { return std::make_shared( ttml::modules::PositionalEmbeddingConfig{ @@ -140,7 +147,7 @@ Transformer::Transformer(const TransformerConfig& config) { std::make_shared(embedding_dim, num_heads, dropout_prob, use_composite_layernorm)); } ln_fc = std::make_shared(embedding_dim, use_composite_layernorm); - fc = std::make_shared(embedding_dim, vocab_size, /* bias */ false); + fc = last_fc; create_name("transformer"); register_module(tok_emb, "tok_emb"); @@ -151,11 +158,6 @@ Transformer::Transformer(const TransformerConfig& config) { register_module(ln_fc, "ln_fc"); register_module(fc, "fc"); - if (config.weight_tying == WeightTyingType::Enabled) { - // tie weights between embedding and fc - tok_emb->set_weight(fc->get_weight()); - } 
- weights_initialization(*this); } diff --git a/tt-train/sources/ttml/models/gpt2.hpp b/tt-train/sources/ttml/models/gpt2.hpp index 2c555888c8c..a41ba3f4f0b 100644 --- a/tt-train/sources/ttml/models/gpt2.hpp +++ b/tt-train/sources/ttml/models/gpt2.hpp @@ -6,10 +6,7 @@ #include -#include "modules/embedding_module.hpp" -#include "modules/gpt_block.hpp" -#include "modules/layer_norm_module.hpp" -#include "modules/positional_embeddings.hpp" +#include "autograd/module_base.hpp" namespace ttml::models::gpt2 { @@ -48,11 +45,11 @@ struct TransformerConfig { class Transformer : public ttml::autograd::ModuleBase { private: RunnerType runner_type = RunnerType::Default; - std::shared_ptr tok_emb; - std::shared_ptr pos_emb; - std::vector> blocks; - std::shared_ptr ln_fc; - std::shared_ptr fc; + std::shared_ptr tok_emb; + std::shared_ptr pos_emb; + std::vector> blocks; + std::shared_ptr ln_fc; + std::shared_ptr fc; public: explicit Transformer(const TransformerConfig& config); diff --git a/tt-train/sources/ttml/modules/dropout_module.hpp b/tt-train/sources/ttml/modules/dropout_module.hpp index e5a4f768e7f..fc1d5e8f7dc 100644 --- a/tt-train/sources/ttml/modules/dropout_module.hpp +++ b/tt-train/sources/ttml/modules/dropout_module.hpp @@ -17,7 +17,7 @@ class DropoutLayer : public autograd::ModuleBase { public: explicit DropoutLayer(float probability, bool use_per_device_seed = true); - [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/embedding_module.cpp b/tt-train/sources/ttml/modules/embedding_module.cpp index 9f669b2f471..35bf699c8a8 100644 --- a/tt-train/sources/ttml/modules/embedding_module.cpp +++ b/tt-train/sources/ttml/modules/embedding_module.cpp @@ -45,13 +45,14 @@ autograd::TensorPtr Embedding::operator()(const autograd::TensorPtr& tensor) { return ops::embedding_op(tensor, 
m_weight); } -void Embedding::set_weight(const autograd::TensorPtr& weight) { - m_weight = weight; - override_tensor(m_weight, "weight"); -} - autograd::TensorPtr Embedding::get_weight() const { return m_weight; } +Embedding::Embedding(const autograd::TensorPtr& weight) { + m_weight = weight; + create_name("embedding"); + register_tensor(m_weight, "weight"); +} + } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/embedding_module.hpp b/tt-train/sources/ttml/modules/embedding_module.hpp index c3359f74b11..7f4df543826 100644 --- a/tt-train/sources/ttml/modules/embedding_module.hpp +++ b/tt-train/sources/ttml/modules/embedding_module.hpp @@ -16,10 +16,10 @@ class Embedding : public autograd::ModuleBase { public: Embedding(uint32_t num_embeddings, uint32_t embedding_dim); - void set_weight(const autograd::TensorPtr& weight); + Embedding(const autograd::TensorPtr& weight); [[nodiscard]] autograd::TensorPtr get_weight() const; - [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/gpt_block.hpp b/tt-train/sources/ttml/modules/gpt_block.hpp index 37806d61757..79d7528dbd0 100644 --- a/tt-train/sources/ttml/modules/gpt_block.hpp +++ b/tt-train/sources/ttml/modules/gpt_block.hpp @@ -21,7 +21,7 @@ class GPTMLP : public autograd::ModuleBase { public: GPTMLP(uint32_t embedding_size, float dropout_prob); - autograd::TensorPtr operator()(const autograd::TensorPtr& input); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& input) override; }; class GPTBlock : public autograd::ModuleBase { @@ -34,7 +34,8 @@ class GPTBlock : public autograd::ModuleBase { explicit GPTBlock( uint32_t embedding_size, uint32_t num_heads, float dropout_prob, bool use_composite_layernorm = false); - autograd::TensorPtr operator()(const autograd::TensorPtr& input, const 
autograd::TensorPtr& mask); + [[nodiscard]] autograd::TensorPtr operator()( + const autograd::TensorPtr& input, const autograd::TensorPtr& mask) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/layer_norm_module.hpp b/tt-train/sources/ttml/modules/layer_norm_module.hpp index 2274ee81d2c..c2aff910db3 100644 --- a/tt-train/sources/ttml/modules/layer_norm_module.hpp +++ b/tt-train/sources/ttml/modules/layer_norm_module.hpp @@ -22,7 +22,7 @@ class LayerNormLayer : public autograd::ModuleBase { void initialize_tensors(uint32_t features); explicit LayerNormLayer(uint32_t features, bool use_composite_op = false); - [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/linear_module.cpp b/tt-train/sources/ttml/modules/linear_module.cpp index e7e8412288d..27de8bcb791 100644 --- a/tt-train/sources/ttml/modules/linear_module.cpp +++ b/tt-train/sources/ttml/modules/linear_module.cpp @@ -6,28 +6,35 @@ #include +#include "autograd/auto_context.hpp" +#include "autograd/tensor.hpp" #include "core/tt_tensor_utils.hpp" #include "init/cpu_initializers.hpp" #include "init/tensor_initializers.hpp" +#include "ops/linear_op.hpp" namespace ttml::modules { -void LinearLayer::initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias) { +namespace { +ttml::autograd::TensorPtr create_weight(uint32_t in_features, uint32_t out_features) { auto* device = &autograd::ctx().get_device(); auto weight_shape = core::create_shape({1, 1, out_features, in_features}); - m_weight = ttml::autograd::create_tensor(); + auto weight = ttml::autograd::create_tensor(); const float init_k = std::sqrtf(1.F / static_cast(in_features)); - init::uniform_init(m_weight, weight_shape, init::UniformRange{-init_k, init_k}); - if (has_bias) { - auto bias_shape = 
core::create_shape({1, 1, 1, out_features}); - m_bias = ttml::autograd::create_tensor(); - init::uniform_init(m_bias, bias_shape, init::UniformRange{-init_k, init_k}); - } + init::uniform_init(weight, weight_shape, init::UniformRange{-init_k, init_k}); + return weight; } +ttml::autograd::TensorPtr create_bias(uint32_t in_features, uint32_t out_features) { + const float init_k = std::sqrtf(1.F / static_cast(in_features)); + auto* device = &ttml::autograd::ctx().get_device(); + auto bias_shape = ttml::core::create_shape({1, 1, 1, out_features}); + auto bias = ttml::autograd::create_tensor(); + ttml::init::uniform_init(bias, bias_shape, ttml::init::UniformRange{-init_k, init_k}); + return bias; +} +} // namespace -LinearLayer::LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias) { - initialize_tensors(in_features, out_features, has_bias); - +void LinearLayer::register_tensors() { create_name("linear"); register_tensor(m_weight, "weight"); if (m_bias != nullptr) { @@ -35,13 +42,30 @@ LinearLayer::LinearLayer(uint32_t in_features, uint32_t out_features, bool has_b } } -autograd::TensorPtr LinearLayer::get_weight() const { - return m_weight; +LinearLayer::LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias) { + m_weight = create_weight(in_features, out_features); + if (has_bias) { + m_bias = create_bias(in_features, out_features); + } + register_tensors(); +} + +LinearLayer::LinearLayer(const autograd::TensorPtr& weight, bool has_bias) : m_weight(weight) { + if (has_bias) { + int in_features = m_weight->get_value().get_logical_shape()[3]; + int out_features = m_weight->get_value().get_logical_shape()[2]; + m_bias = create_bias(in_features, out_features); + } + register_tensors(); } -void LinearLayer::set_weight(const autograd::TensorPtr& weight) { - m_weight = weight; - override_tensor(m_weight, "weight"); +LinearLayer::LinearLayer(const autograd::TensorPtr& weight, const autograd::TensorPtr& bias) : + m_weight(weight), m_bias(bias) 
{ + register_tensors(); +} + +autograd::TensorPtr LinearLayer::get_weight() const { + return m_weight; } autograd::TensorPtr LinearLayer::operator()(const autograd::TensorPtr& tensor) { diff --git a/tt-train/sources/ttml/modules/linear_module.hpp b/tt-train/sources/ttml/modules/linear_module.hpp index 09c92361bc1..ed6646a15f6 100644 --- a/tt-train/sources/ttml/modules/linear_module.hpp +++ b/tt-train/sources/ttml/modules/linear_module.hpp @@ -3,14 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once - -#include - -#include "autograd/auto_context.hpp" -#include "autograd/graph.hpp" #include "autograd/module_base.hpp" #include "autograd/tensor.hpp" -#include "ops/linear_op.hpp" namespace ttml::modules { @@ -20,14 +14,15 @@ class LinearLayer : public autograd::ModuleBase { autograd::TensorPtr m_bias; void initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias = true); + void register_tensors(); public: LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias = true); - + LinearLayer(const autograd::TensorPtr& weight, const autograd::TensorPtr& bias); + LinearLayer(const autograd::TensorPtr& weight, bool has_bias = true); autograd::TensorPtr get_weight() const; - void set_weight(const autograd::TensorPtr& weight); - [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/multi_head_attention.hpp b/tt-train/sources/ttml/modules/multi_head_attention.hpp index 29aa47dce69..bcef1b0c66a 100644 --- a/tt-train/sources/ttml/modules/multi_head_attention.hpp +++ b/tt-train/sources/ttml/modules/multi_head_attention.hpp @@ -24,7 +24,8 @@ class MultiHeadAttention : public ttml::autograd::ModuleBase { public: explicit MultiHeadAttention(uint32_t embedding_dim, uint32_t num_heads, float dropout_prob); - autograd::TensorPtr operator()(const 
autograd::TensorPtr& x, const autograd::TensorPtr& mask); + [[nodiscard]] autograd::TensorPtr operator()( + const autograd::TensorPtr& x, const autograd::TensorPtr& mask) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp b/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp index 72d3ca8e092..b57976467aa 100644 --- a/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp +++ b/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp @@ -28,15 +28,16 @@ MultiLayerPerceptron::MultiLayerPerceptron(const MultiLayerPerceptronParameters& register_module(m_layers[idx], "layer_" + std::to_string(idx)); } } -autograd::TensorPtr MultiLayerPerceptron::operator()(autograd::TensorPtr tensor) { +autograd::TensorPtr MultiLayerPerceptron::operator()(const autograd::TensorPtr& tensor) { + auto x = tensor; for (size_t index = 0; index < m_layers.size(); ++index) { - tensor = (*m_layers[index])(tensor); + x = (*m_layers[index])(x); if (index + 1 != m_layers.size()) { - tensor = ops::relu(tensor); + x = ops::relu(x); } } - return tensor; + return x; } } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/multi_layer_perceptron.hpp b/tt-train/sources/ttml/modules/multi_layer_perceptron.hpp index 27a3b301696..6442ebd11f0 100644 --- a/tt-train/sources/ttml/modules/multi_layer_perceptron.hpp +++ b/tt-train/sources/ttml/modules/multi_layer_perceptron.hpp @@ -25,7 +25,7 @@ class MultiLayerPerceptron : public autograd::ModuleBase { public: explicit MultiLayerPerceptron(const MultiLayerPerceptronParameters& params); - [[nodiscard]] autograd::TensorPtr operator()(autograd::TensorPtr tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/positional_embeddings.hpp b/tt-train/sources/ttml/modules/positional_embeddings.hpp index 58097d5f07e..3b45d02f526 100644 --- 
a/tt-train/sources/ttml/modules/positional_embeddings.hpp +++ b/tt-train/sources/ttml/modules/positional_embeddings.hpp @@ -19,12 +19,7 @@ struct PositionalEmbeddingConfig { bool use_dropout_seed_per_device{true}; }; -class PositionalEmbeddingBase : public autograd::ModuleBase { -public: - virtual autograd::TensorPtr operator()(const autograd::TensorPtr& input) = 0; -}; - -class PositionalEmbedding : public PositionalEmbeddingBase { +class PositionalEmbedding : public autograd::ModuleBase { private: uint32_t m_sequence_length{}; std::shared_ptr m_dropout; @@ -35,7 +30,7 @@ class PositionalEmbedding : public PositionalEmbeddingBase { [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& input) override; }; -class TrainablePositionalEmbedding : public PositionalEmbeddingBase { +class TrainablePositionalEmbedding : public autograd::ModuleBase { uint32_t m_sequence_length{}; autograd::TensorPtr m_weight; std::shared_ptr m_dropout; diff --git a/tt-train/sources/ttml/modules/rms_norm_module.hpp b/tt-train/sources/ttml/modules/rms_norm_module.hpp index 721b3658c07..c7b1eeda909 100644 --- a/tt-train/sources/ttml/modules/rms_norm_module.hpp +++ b/tt-train/sources/ttml/modules/rms_norm_module.hpp @@ -21,7 +21,7 @@ class RMSNormLayer : public autograd::ModuleBase { void initialize_tensors(uint32_t features); explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F); - [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override; }; } // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/single_head_attention.hpp b/tt-train/sources/ttml/modules/single_head_attention.hpp index cddb8df1863..a4cdde7937c 100644 --- a/tt-train/sources/ttml/modules/single_head_attention.hpp +++ b/tt-train/sources/ttml/modules/single_head_attention.hpp @@ -19,7 +19,8 @@ class SingleHeadAttention : public ttml::autograd::ModuleBase { public: explicit 
SingleHeadAttention(uint32_t embedding_dim, float dropout_prob); - autograd::TensorPtr operator()(const autograd::TensorPtr& x, const autograd::TensorPtr& mask); + [[nodiscard]] autograd::TensorPtr operator()( + const autograd::TensorPtr& x, const autograd::TensorPtr& mask) override; }; } // namespace ttml::modules diff --git a/tt-train/tests/model/weight_tying_test.cpp b/tt-train/tests/model/weight_tying_test.cpp index 955defd7b6f..2f44da84522 100644 --- a/tt-train/tests/model/weight_tying_test.cpp +++ b/tt-train/tests/model/weight_tying_test.cpp @@ -21,14 +21,12 @@ class ModelFC : public ttml::autograd::ModuleBase { public: ModelFC() { - m_fc1 = std::make_shared(64, 64); m_fc2 = std::make_shared(64, 64); + m_fc1 = std::make_shared(m_fc2->get_weight(), /* has_bias*/ true); create_name("ModelFC"); register_module(m_fc1, "fc1"); register_module(m_fc2, "fc2"); - - m_fc1->set_weight(m_fc2->get_weight()); } ttml::autograd::TensorPtr operator()(const ttml::autograd::TensorPtr& x) { @@ -53,14 +51,13 @@ class LanguageModel : public ttml::autograd::ModuleBase { public: LanguageModel() { - m_fc1 = std::make_shared(128, 64); m_emb = std::make_shared(64, 128); + m_fc1 = std::make_shared(m_emb->get_weight(), /* has_bias*/ true); + create_name("LanguageModel"); register_module(m_fc1, "fc1"); register_module(m_emb, "emb"); - - m_fc1->set_weight(m_emb->get_weight()); } }; diff --git a/tt-train/tests/serialization/model_serializer_test.cpp b/tt-train/tests/serialization/model_serializer_test.cpp index e831788dbf6..ff7353d97cd 100644 --- a/tt-train/tests/serialization/model_serializer_test.cpp +++ b/tt-train/tests/serialization/model_serializer_test.cpp @@ -6,9 +6,9 @@ #include +#include "autograd/auto_context.hpp" #include "models/gpt2.hpp" #include "models/mlp.hpp" - class MultiLayerPerceptronParametersTest : public ::testing::Test { protected: void SetUp() override {