[TT-Train] Training infra update (#18167)
### Problem description
It is not easy to customize different neural networks.

### What's changed
Now all created NNs are immutable by design.
Added operator() overloads to the base class; all modules should derive from it and override the overloads they support.
Where possible, you should always store std::shared_ptr<ModuleBase>
instead of your concrete class.
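
A minimal sketch of the intended pattern (the class and variable names here are hypothetical; ModuleBase, TensorPtr, and the operator() overloads are the ones added in this PR):

```cpp
#include <memory>

#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"

// Hypothetical module: derive from ModuleBase and override the forward overload you support.
class MyBlock : public ttml::autograd::ModuleBase {
public:
    MyBlock() {
        create_name("my_block");
    }

    [[nodiscard]] ttml::autograd::TensorPtr operator()(const ttml::autograd::TensorPtr& tensor) override {
        // ... run the forward pass and return the result ...
        return tensor;
    }
};

// Callers hold modules through the base class and invoke them polymorphically:
//   std::shared_ptr<ttml::autograd::ModuleBase> block = std::make_shared<MyBlock>();
//   auto out = (*block)(input);
```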

### Checklist
- [x] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml)
CI passes
- [x] New/Existing tests provide coverage for changes
dmakoviichuk-tt authored Feb 28, 2025
1 parent e05b927 commit f05457a
Showing 21 changed files with 111 additions and 82 deletions.
8 changes: 8 additions & 0 deletions tt-train/sources/ttml/autograd/module_base.cpp
@@ -106,5 +106,13 @@ void ModuleBase::train() {
void ModuleBase::eval() {
set_run_mode(RunMode::EVAL);
}
autograd::TensorPtr ModuleBase::operator()(const autograd::TensorPtr& tensor) {
throw std::logic_error("ModuleBase::operator()(const autograd::TensorPtr& tensor) is Not implemented");
}
autograd::TensorPtr ModuleBase::operator()(const autograd::TensorPtr& tensor, const autograd::TensorPtr& other) {
throw std::logic_error(
"ModuleBase::operator()(const autograd::TensorPtr& tensor, const autograd::TensorPtr& other) is Not "
"implemented");
}

} // namespace ttml::autograd
5 changes: 5 additions & 0 deletions tt-train/sources/ttml/autograd/module_base.hpp
@@ -51,6 +51,11 @@ class ModuleBase {
void eval();
void set_run_mode(RunMode mode);
[[nodiscard]] RunMode get_run_mode() const;

// Forward pass for the module. All posible overloads
[[nodiscard]] virtual autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
[[nodiscard]] virtual autograd::TensorPtr operator()(
const autograd::TensorPtr& tensor, const autograd::TensorPtr& other);
};

} // namespace ttml::autograd
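
One consequence of these virtual defaults (a sketch, not part of the diff): a module that only overrides the single-tensor overload now fails at runtime with std::logic_error, rather than at compile time, if it is invoked through the base class with the wrong arity.

```cpp
#include <memory>

#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "modules/layer_norm_module.hpp"

void forward_dispatch_sketch(const ttml::autograd::TensorPtr& x) {
    std::shared_ptr<ttml::autograd::ModuleBase> layer =
        std::make_shared<ttml::modules::LayerNormLayer>(/* features */ 256);

    auto y = (*layer)(x);  // dispatches to LayerNormLayer::operator()(tensor)
    // (*layer)(x, x);     // would throw std::logic_error from the ModuleBase default
    (void)y;
}
```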
3 changes: 2 additions & 1 deletion tt-train/sources/ttml/models/distributed/gpt2.cpp
@@ -5,6 +5,7 @@
#include "gpt2.hpp"

#include "autograd/graph_utils.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "core/distributed_mapping.hpp"
#include "core/scoped.hpp"
@@ -129,7 +130,7 @@ DistributedTransformer::DistributedTransformer(const TransformerConfig& config)
auto create_positional_embedding = [position_embedding_type,
max_sequence_length,
embedding_dim,
dropout_prob]() -> std::shared_ptr<modules::PositionalEmbeddingBase> {
dropout_prob]() -> std::shared_ptr<autograd::ModuleBase> {
if (position_embedding_type == PositionalEmbeddingType::Trainable) {
return std::make_shared<ttml::modules::TrainablePositionalEmbedding>(
ttml::modules::PositionalEmbeddingConfig{
10 changes: 5 additions & 5 deletions tt-train/sources/ttml/models/distributed/gpt2.hpp
@@ -24,11 +24,11 @@ using models::gpt2::WeightTyingType;
class DistributedTransformer : public ttml::autograd::ModuleBase {
private:
RunnerType runner_type = RunnerType::Default;
std::shared_ptr<ttml::modules::Embedding> tok_emb;
std::shared_ptr<ttml::modules::PositionalEmbeddingBase> pos_emb;
std::vector<std::shared_ptr<ttml::modules::distributed::DistributedGPTBlock>> blocks;
std::shared_ptr<ttml::modules::LayerNormLayer> ln_fc;
std::shared_ptr<ttml::modules::distributed::ColumnParallelLinear> fc;
std::shared_ptr<ttml::autograd::ModuleBase> tok_emb;
std::shared_ptr<ttml::autograd::ModuleBase> pos_emb;
std::vector<std::shared_ptr<ttml::autograd::ModuleBase>> blocks;
std::shared_ptr<ttml::autograd::ModuleBase> ln_fc;
std::shared_ptr<ttml::autograd::ModuleBase> fc;

public:
explicit DistributedTransformer(const TransformerConfig& config);
20 changes: 11 additions & 9 deletions tt-train/sources/ttml/models/gpt2.cpp
@@ -8,10 +8,12 @@
#include "autograd/tensor.hpp"
#include "core/scoped.hpp"
#include "init/tensor_initializers.hpp"
#include "modules/embedding_module.hpp"
#include "modules/gpt_block.hpp"
#include "modules/layer_norm_module.hpp"
#include "modules/positional_embeddings.hpp"
#include "ops/binary_ops.hpp"
#include "ops/unary_ops.hpp"

namespace ttml::models::gpt2 {

namespace {
@@ -112,12 +114,17 @@ Transformer::Transformer(const TransformerConfig& config) {
"embedding_dim={}",
embedding_dim));
}
tok_emb = std::make_shared<ttml::modules::Embedding>(vocab_size_divisible_by_32, embedding_dim);
auto last_fc = std::make_shared<ttml::modules::LinearLayer>(embedding_dim, vocab_size, /* bias */ false);
if (config.weight_tying == WeightTyingType::Enabled) {
tok_emb = std::make_shared<ttml::modules::Embedding>(last_fc->get_weight());
} else {
tok_emb = std::make_shared<ttml::modules::Embedding>(vocab_size_divisible_by_32, embedding_dim);
}

auto create_positional_embedding = [position_embedding_type,
max_sequence_length,
embedding_dim,
dropout_prob]() -> std::shared_ptr<modules::PositionalEmbeddingBase> {
dropout_prob]() -> std::shared_ptr<autograd::ModuleBase> {
if (position_embedding_type == PositionalEmbeddingType::Trainable) {
return std::make_shared<ttml::modules::TrainablePositionalEmbedding>(
ttml::modules::PositionalEmbeddingConfig{
@@ -140,7 +147,7 @@ Transformer::Transformer(const TransformerConfig& config) {
std::make_shared<ttml::modules::GPTBlock>(embedding_dim, num_heads, dropout_prob, use_composite_layernorm));
}
ln_fc = std::make_shared<ttml::modules::LayerNormLayer>(embedding_dim, use_composite_layernorm);
fc = std::make_shared<ttml::modules::LinearLayer>(embedding_dim, vocab_size, /* bias */ false);
fc = last_fc;

create_name("transformer");
register_module(tok_emb, "tok_emb");
@@ -151,11 +158,6 @@ Transformer::Transformer(const TransformerConfig& config) {
register_module(ln_fc, "ln_fc");
register_module(fc, "fc");

if (config.weight_tying == WeightTyingType::Enabled) {
// tie weights between embedding and fc
tok_emb->set_weight(fc->get_weight());
}

weights_initialization(*this);
}

15 changes: 6 additions & 9 deletions tt-train/sources/ttml/models/gpt2.hpp
@@ -6,10 +6,7 @@

#include <yaml-cpp/yaml.h>

#include "modules/embedding_module.hpp"
#include "modules/gpt_block.hpp"
#include "modules/layer_norm_module.hpp"
#include "modules/positional_embeddings.hpp"
#include "autograd/module_base.hpp"

namespace ttml::models::gpt2 {

@@ -48,11 +45,11 @@ struct TransformerConfig {
class Transformer : public ttml::autograd::ModuleBase {
private:
RunnerType runner_type = RunnerType::Default;
std::shared_ptr<ttml::modules::Embedding> tok_emb;
std::shared_ptr<ttml::modules::PositionalEmbeddingBase> pos_emb;
std::vector<std::shared_ptr<ttml::modules::GPTBlock>> blocks;
std::shared_ptr<ttml::modules::LayerNormLayer> ln_fc;
std::shared_ptr<ttml::modules::LinearLayer> fc;
std::shared_ptr<ttml::autograd::ModuleBase> tok_emb;
std::shared_ptr<ttml::autograd::ModuleBase> pos_emb;
std::vector<std::shared_ptr<ttml::autograd::ModuleBase>> blocks;
std::shared_ptr<ttml::autograd::ModuleBase> ln_fc;
std::shared_ptr<ttml::autograd::ModuleBase> fc;

public:
explicit Transformer(const TransformerConfig& config);
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/modules/dropout_module.hpp
@@ -17,7 +17,7 @@ class DropoutLayer : public autograd::ModuleBase {
public:
explicit DropoutLayer(float probability, bool use_per_device_seed = true);

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override;
};

} // namespace ttml::modules
11 changes: 6 additions & 5 deletions tt-train/sources/ttml/modules/embedding_module.cpp
@@ -45,13 +45,14 @@ autograd::TensorPtr Embedding::operator()(const autograd::TensorPtr& tensor) {
return ops::embedding_op(tensor, m_weight);
}

void Embedding::set_weight(const autograd::TensorPtr& weight) {
m_weight = weight;
override_tensor(m_weight, "weight");
}

autograd::TensorPtr Embedding::get_weight() const {
return m_weight;
}

Embedding::Embedding(const autograd::TensorPtr& weight) {
m_weight = weight;
create_name("embedding");
register_tensor(m_weight, "weight");
}

} // namespace ttml::modules
4 changes: 2 additions & 2 deletions tt-train/sources/ttml/modules/embedding_module.hpp
@@ -16,10 +16,10 @@ class Embedding : public autograd::ModuleBase {

public:
Embedding(uint32_t num_embeddings, uint32_t embedding_dim);
void set_weight(const autograd::TensorPtr& weight);
Embedding(const autograd::TensorPtr& weight);
[[nodiscard]] autograd::TensorPtr get_weight() const;

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override;
};

} // namespace ttml::modules
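
With set_weight removed, weight sharing between the embedding and the output projection is now expressed at construction time (a sketch; the dimensions are illustrative, and the flow mirrors the gpt2.cpp hunk above):

```cpp
#include <cstdint>
#include <memory>

#include "modules/embedding_module.hpp"
#include "modules/linear_module.hpp"

void weight_tying_sketch() {
    const uint32_t embedding_dim = 768U;  // illustrative values only
    const uint32_t vocab_size = 50304U;

    // The projection owns the weight; the embedding is built from (and shares) that tensor.
    auto fc = std::make_shared<ttml::modules::LinearLayer>(embedding_dim, vocab_size, /* bias */ false);
    auto tok_emb = std::make_shared<ttml::modules::Embedding>(fc->get_weight());
}
```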
5 changes: 3 additions & 2 deletions tt-train/sources/ttml/modules/gpt_block.hpp
@@ -21,7 +21,7 @@ class GPTMLP : public autograd::ModuleBase {
public:
GPTMLP(uint32_t embedding_size, float dropout_prob);

autograd::TensorPtr operator()(const autograd::TensorPtr& input);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& input) override;
};

class GPTBlock : public autograd::ModuleBase {
@@ -34,7 +34,8 @@ GPTBlock : public autograd::ModuleBase {
explicit GPTBlock(
uint32_t embedding_size, uint32_t num_heads, float dropout_prob, bool use_composite_layernorm = false);

autograd::TensorPtr operator()(const autograd::TensorPtr& input, const autograd::TensorPtr& mask);
[[nodiscard]] autograd::TensorPtr operator()(
const autograd::TensorPtr& input, const autograd::TensorPtr& mask) override;
};

} // namespace ttml::modules
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/modules/layer_norm_module.hpp
@@ -22,7 +22,7 @@ class LayerNormLayer : public autograd::ModuleBase {
void initialize_tensors(uint32_t features);
explicit LayerNormLayer(uint32_t features, bool use_composite_op = false);

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override;
};

} // namespace ttml::modules
56 changes: 40 additions & 16 deletions tt-train/sources/ttml/modules/linear_module.cpp
@@ -6,42 +6,66 @@

#include <core/ttnn_all_includes.hpp>

#include "autograd/auto_context.hpp"
#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "init/cpu_initializers.hpp"
#include "init/tensor_initializers.hpp"
#include "ops/linear_op.hpp"

namespace ttml::modules {

void LinearLayer::initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias) {
namespace {
ttml::autograd::TensorPtr create_weight(uint32_t in_features, uint32_t out_features) {
auto* device = &autograd::ctx().get_device();
auto weight_shape = core::create_shape({1, 1, out_features, in_features});
m_weight = ttml::autograd::create_tensor();
auto weight = ttml::autograd::create_tensor();
const float init_k = std::sqrtf(1.F / static_cast<float>(in_features));
init::uniform_init(m_weight, weight_shape, init::UniformRange{-init_k, init_k});
if (has_bias) {
auto bias_shape = core::create_shape({1, 1, 1, out_features});
m_bias = ttml::autograd::create_tensor();
init::uniform_init(m_bias, bias_shape, init::UniformRange{-init_k, init_k});
}
init::uniform_init(weight, weight_shape, init::UniformRange{-init_k, init_k});
return weight;
}
ttml::autograd::TensorPtr create_bias(uint32_t in_features, uint32_t out_features) {
const float init_k = std::sqrtf(1.F / static_cast<float>(in_features));
auto* device = &ttml::autograd::ctx().get_device();
auto bias_shape = ttml::core::create_shape({1, 1, 1, out_features});
auto bias = ttml::autograd::create_tensor();
ttml::init::uniform_init(bias, bias_shape, ttml::init::UniformRange{-init_k, init_k});
return bias;
}
} // namespace

LinearLayer::LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias) {
initialize_tensors(in_features, out_features, has_bias);

void LinearLayer::register_tensors() {
create_name("linear");
register_tensor(m_weight, "weight");
if (m_bias != nullptr) {
register_tensor(m_bias, "bias");
}
}

autograd::TensorPtr LinearLayer::get_weight() const {
return m_weight;
LinearLayer::LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias) {
m_weight = create_weight(in_features, out_features);
if (has_bias) {
m_bias = create_bias(in_features, out_features);
}
register_tensors();
}

LinearLayer::LinearLayer(const autograd::TensorPtr& weight, bool has_bias) : m_weight(weight) {
if (has_bias) {
int in_features = m_weight->get_value().get_logical_shape()[3];
int out_features = m_weight->get_value().get_logical_shape()[2];
m_bias = create_bias(in_features, out_features);
}
register_tensors();
}

void LinearLayer::set_weight(const autograd::TensorPtr& weight) {
m_weight = weight;
override_tensor(m_weight, "weight");
LinearLayer::LinearLayer(const autograd::TensorPtr& weight, const autograd::TensorPtr& bias) :
m_weight(weight), m_bias(bias) {
register_tensors();
}

autograd::TensorPtr LinearLayer::get_weight() const {
return m_weight;
}

autograd::TensorPtr LinearLayer::operator()(const autograd::TensorPtr& tensor) {
13 changes: 4 additions & 9 deletions tt-train/sources/ttml/modules/linear_module.hpp
@@ -3,14 +3,8 @@
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <memory>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "ops/linear_op.hpp"

namespace ttml::modules {

@@ -20,14 +14,15 @@ class LinearLayer : public autograd::ModuleBase {
autograd::TensorPtr m_bias;

void initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias = true);
void register_tensors();

public:
LinearLayer(uint32_t in_features, uint32_t out_features, bool has_bias = true);

LinearLayer(const autograd::TensorPtr& weight, const autograd::TensorPtr& bias);
LinearLayer(const autograd::TensorPtr& weight, bool has_bias = true);
autograd::TensorPtr get_weight() const;
void set_weight(const autograd::TensorPtr& weight);

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override;
};

} // namespace ttml::modules
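
A short usage sketch of the constructor set after this change (dimensions are illustrative; the LinearLayer(weight, bias) constructor takes both tensors explicitly and is not exercised here):

```cpp
#include <cstdint>
#include <memory>

#include "modules/linear_module.hpp"

void linear_layer_sketch() {
    // Freshly initialized layer, as before.
    auto fresh = std::make_shared<ttml::modules::LinearLayer>(/* in_features */ 128U, /* out_features */ 256U);

    // Reuse an existing weight; the bias is created internally from the weight's shape when requested.
    auto tied = std::make_shared<ttml::modules::LinearLayer>(fresh->get_weight(), /* has_bias */ true);
}
```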
3 changes: 2 additions & 1 deletion tt-train/sources/ttml/modules/multi_head_attention.hpp
@@ -24,7 +24,7 @@ class MultiHeadAttention : public ttml::autograd::ModuleBase {
public:
explicit MultiHeadAttention(uint32_t embedding_dim, uint32_t num_heads, float dropout_prob);

autograd::TensorPtr operator()(const autograd::TensorPtr& x, const autograd::TensorPtr& mask);
[[nodiscard]] autograd::TensorPtr operator()(
const autograd::TensorPtr& x, const autograd::TensorPtr& mask) override;
};

} // namespace ttml::modules
9 changes: 5 additions & 4 deletions tt-train/sources/ttml/modules/multi_layer_perceptron.cpp
@@ -28,15 +28,16 @@ MultiLayerPerceptron::MultiLayerPerceptron(const MultiLayerPerceptronParameters&
register_module(m_layers[idx], "layer_" + std::to_string(idx));
}
}
autograd::TensorPtr MultiLayerPerceptron::operator()(autograd::TensorPtr tensor) {
autograd::TensorPtr MultiLayerPerceptron::operator()(const autograd::TensorPtr& tensor) {
auto x = tensor;
for (size_t index = 0; index < m_layers.size(); ++index) {
tensor = (*m_layers[index])(tensor);
x = (*m_layers[index])(x);
if (index + 1 != m_layers.size()) {
tensor = ops::relu(tensor);
x = ops::relu(x);
}
}

return tensor;
return x;
}

} // namespace ttml::modules
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/modules/multi_layer_perceptron.hpp
@@ -25,7 +25,7 @@ class MultiLayerPerceptron : public autograd::ModuleBase {
public:
explicit MultiLayerPerceptron(const MultiLayerPerceptronParameters& params);

[[nodiscard]] autograd::TensorPtr operator()(autograd::TensorPtr tensor);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor) override;
};

} // namespace ttml::modules
9 changes: 2 additions & 7 deletions tt-train/sources/ttml/modules/positional_embeddings.hpp
@@ -19,12 +19,7 @@ struct PositionalEmbeddingConfig {
bool use_dropout_seed_per_device{true};
};

class PositionalEmbeddingBase : public autograd::ModuleBase {
public:
virtual autograd::TensorPtr operator()(const autograd::TensorPtr& input) = 0;
};

class PositionalEmbedding : public PositionalEmbeddingBase {
class PositionalEmbedding : public autograd::ModuleBase {
private:
uint32_t m_sequence_length{};
std::shared_ptr<DropoutLayer> m_dropout;
@@ -35,7 +30,7 @@ class PositionalEmbedding : public PositionalEmbeddingBase {
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& input) override;
};

class TrainablePositionalEmbedding : public PositionalEmbeddingBase {
class TrainablePositionalEmbedding : public autograd::ModuleBase {
uint32_t m_sequence_length{};
autograd::TensorPtr m_weight;
std::shared_ptr<DropoutLayer> m_dropout;
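
With PositionalEmbeddingBase removed, both embedding variants are now held as the common ModuleBase type, matching the factory lambdas in gpt2.cpp above (a sketch; it assumes both classes accept a PositionalEmbeddingConfig, as the trainable variant does in the diff):

```cpp
#include <memory>

#include "autograd/module_base.hpp"
#include "modules/positional_embeddings.hpp"

std::shared_ptr<ttml::autograd::ModuleBase> make_positional_embedding(
    const ttml::modules::PositionalEmbeddingConfig& config, bool trainable) {
    if (trainable) {
        return std::make_shared<ttml::modules::TrainablePositionalEmbedding>(config);
    }
    return std::make_shared<ttml::modules::PositionalEmbedding>(config);
}
```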
