[tt-train] Add RMSNorm module (#16991)
### Problem description
We need RMSNorm to train Llama 3 and some other exciting open source
models.

### What's changed
- Added RMSNorm op
- Added RMSNorm module (see the usage sketch below)
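
A minimal usage sketch of the new module (hypothetical example, not part of this
change; it only uses APIs added in this diff and the tensor helpers already used
in the accompanying tests, and assumes the autograd context's device has been
opened as in the test fixture):

```cpp
#include "autograd/auto_context.hpp"
#include "core/tt_tensor_utils.hpp"
#include "modules/rms_norm_module.hpp"

using namespace ttml;

void rms_norm_example() {
    uint32_t features = 64U;
    modules::RMSNormLayer rms_norm(features, /*epsilon=*/1e-5F);

    // [batch=2, 1, sequence=32, features] activations; the op normalizes over the last dim.
    auto input = autograd::create_tensor(
        core::ones(core::create_shape({2, 1, 32, features}), &autograd::ctx().get_device()));

    auto output = rms_norm(input);  // same shape as the input
    // Backpropagating a downstream loss populates gradients for both the input
    // tensor and the module's gamma parameter.
}
```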

### Checklist
- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [ ] New/Existing tests provide coverage for changes
jaykru-tt authored Feb 19, 2025
1 parent 686a4f0 commit c17e35a
Showing 6 changed files with 333 additions and 0 deletions.
1 change: 1 addition & 0 deletions tt-train/sources/ttml/core/ttnn_all_includes.hpp
@@ -38,6 +38,7 @@
#include <ttnn/operations/data_movement/untilize/untilize.hpp> // NOLINT
#include <ttnn/operations/eltwise/binary/binary.hpp> // NOLINT
#include <ttnn/operations/eltwise/binary_backward/binary_backward.hpp> // NOLINT
#include <ttnn/operations/eltwise/binary_ng/binary_ng.hpp> // NOLINT
#include <ttnn/operations/eltwise/unary/unary.hpp> // NOLINT
#include <ttnn/operations/eltwise/unary/unary_composite.hpp> // NOLINT
#include <ttnn/operations/eltwise/unary_backward/unary_backward.hpp> // NOLINT
28 changes: 28 additions & 0 deletions tt-train/sources/ttml/modules/rms_norm_module.cpp
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rms_norm_module.hpp"

#include "core/tt_tensor_utils.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

void RMSNormLayer::initialize_tensors(uint32_t features) {
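// gamma holds one trainable gain per feature; initializing it to ones makes the
// layer start out as plain (unscaled) RMS normalization.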
m_gamma =
autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, features}), &autograd::ctx().get_device()));
}

RMSNormLayer::RMSNormLayer(uint32_t features, float epsilon) : m_epsilon(epsilon) {
initialize_tensors(features);

create_name("rmsnorm");
register_tensor(m_gamma, "gamma");
}

autograd::TensorPtr RMSNormLayer::operator()(const autograd::TensorPtr& tensor) {
return ops::rmsnorm(tensor, m_gamma, m_epsilon);
}

} // namespace ttml::modules
27 changes: 27 additions & 0 deletions tt-train/sources/ttml/modules/rms_norm_module.hpp
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

class RMSNormLayer : public autograd::ModuleBase {
private:
float m_epsilon = 1e-5F;
autograd::TensorPtr m_gamma = nullptr;

public:
void initialize_tensors(uint32_t features);
explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
};

} // namespace ttml::modules
116 changes: 116 additions & 0 deletions tt-train/sources/ttml/ops/rmsnorm_op.cpp
@@ -0,0 +1,116 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rmsnorm_op.hpp"

#include <cassert>
#include <core/ttnn_all_includes.hpp>
#include <cstdint>
#include <optional>
#include <stdexcept>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/compute_kernel_config.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr &tensor, const autograd::TensorPtr &gamma, float epsilon) {
auto a_shape = tensor->get_value().logical_shape();
if (a_shape.rank() != 4) {
throw std::runtime_error("rmsnorm only supports rank-4 input tensors.");
}

auto ashape_arr = a_shape.to_array_4D();
auto [B, N, S, C] = ashape_arr;
assert((N == 1)); // one sequence per batch

// one gain parameter per channel
assert((gamma->get_value().logical_shape().to_array_4D() == std::array<uint32_t, 4>{1, 1, 1, C}));

auto device = &autograd::ctx().get_device();

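// forward: out = gamma * a / rms(a) (eltwise, with broadcasting over [B,1,S,C]),
// where rms(a) = sqrt(mean over C of a^2 + epsilon)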
ttnn::Tensor squares = ttnn::square(tensor->get_value()); // [B,1,S,C] -> [B,1,S,C]

ttnn::Tensor seq_means_of_squares = ttnn::mean(squares, /*dim_arg=*/-1, /*keep_dim=*/true); // [B,1,S,1]

ttnn::Tensor seq_means_of_squares_plus_epsilon =
ttnn::experimental::add(seq_means_of_squares, epsilon); // [B,1,S,1] x [1] -> [B,1,S,1] (bcast)

ttnn::Tensor rms_a = ttnn::sqrt(seq_means_of_squares_plus_epsilon); // [B,1,S,1] -> [B,1,S,1]

ttnn::Tensor gamma_times_activations =
ttnn::experimental::mul(gamma->get_value(), tensor->get_value()); // [1,1,1,C] x [B,1,S,C] -> [B,1,S,C] (bcast)

ttnn::Tensor out_tensor =
ttnn::experimental::div(gamma_times_activations, rms_a); // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]

auto out = autograd::create_tensor(out_tensor);

autograd::GradFunction grad = [B, S, C, tensor, gamma, out, rms_a, device]() {
auto a = tensor->get_value(); // [B,1,S,C]
auto g = gamma->get_value(); // [1,1,1,C]

// c is the number of activations per position; in the RMSNorm paper this is
// called "n". It is renamed here to avoid confusion with the shape dimension N.
auto c = static_cast<float>(a.logical_shape()[-1]);

auto dL_dout = out->get_grad(); // Grad w.r.t. normalized activations, hence [B,1,S,C]

auto scaled_gain = ttnn::experimental::div(g, rms_a); // [1,1,1,C] x [B,1,S,1] -> [B,1,S,C] (bcast)
auto gained_dL_dout = ttnn::experimental::mul(scaled_gain, dL_dout); // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]

// notation:
// _ · _ <- usual dot product
// _ @ _ <- matrix multiplication
// _ *. _ <- Hadamard product/eltwise multiplication with broadcasting
// _ /. _ <- eltwise division with broadcasting

// have a : [B,1,S,C]

// want to obtain rhs = gained_dL_dout @ ((a @ a^T) / (n * rms(a)^2))

// to avoid computing the large outer product matrix explicitly, we
// instead compute
// scale = (a^T · gained_dL_dout) : [B,1,S,C] x [B,1,S,C] -> [B,1,S,1]
// scaled_outer = scale *. a : [B,1,S,1] x [B,1,S,C] -> [B,1,S,C]
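
// putting it together, the code below computes
// dL/da     = gained_dL_dout  -  scaled_outer /. (n * rms(a)^2)
// dL/dgamma = sum over B and S of (a /. rms(a)) *. dL/dout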

auto scale = ttml::ttnn_fixed::sum_over_dim(
ttnn::experimental::mul(a, gained_dL_dout), 3); // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C] -> [B,1,S,1]

auto scaled_outer = ttnn::experimental::mul(scale, a); // [B,1,S,1] x [B,1,S,C] -> [B,1,S,C] (bcast)

auto ms_a = ttnn::square(rms_a); // [B,1,S,1] -> [B,1,S,1]

auto c_by_ms_a = ttnn::experimental::mul(ms_a, c); // [B,1,S,1] x [1] -> [B,1,S,1] (bcast)

auto rhs = ttnn::experimental::div(scaled_outer, c_by_ms_a); // [B,1,S,C] x [B,1,S,1] -> [B,1,S,C] (bcast)

auto dL_da =
ttnn::experimental::sub(gained_dL_dout, rhs); // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]; checked by add_grad
tensor->add_grad(dL_da);

// dL_dgamma = (a / rms(a)) * dL_dout -> requires sum over batch due to broadcasting
auto dL_dg_components = ttnn::experimental::mul(
dL_dout,
ttnn::experimental::div(a, rms_a)); // [B,1,S,C] x [B,1,S,1] -> [B,1,S,C] (bcast); checked by add_grad
auto dL_dg = ttnn::sum(
dL_dg_components,
/* dim_arg */ ttnn::SmallVector<int>{0, 1, 2},
/* keep_dim */ true,
/* output_mem_config */ std::nullopt,
/*compute_kernel_config */ core::ComputeKernelConfig::precise()); // [B,1,S,C] -> [1,1,1,C]
gamma->add_grad(dL_dg);
};

auto links = autograd::get_links(tensor, gamma);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));

return out;
}

} // namespace ttml::ops
12 changes: 12 additions & 0 deletions tt-train/sources/ttml/ops/rmsnorm_op.hpp
@@ -0,0 +1,12 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "autograd/tensor.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon);

} // namespace ttml::ops
149 changes: 149 additions & 0 deletions tt-train/tests/ops/rmsnorm_op_test.cpp
@@ -0,0 +1,149 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ops/rmsnorm_op.hpp"

#include <gtest/gtest.h>

#include <cassert>
#include <core/ttnn_all_includes.hpp>

#include "autograd/auto_context.hpp"
#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "ops/losses.hpp"

class RMSNormOpTest : public ::testing::Test {
protected:
void SetUp() override {
ttml::autograd::ctx().open_device();
}

void TearDown() override {
ttml::autograd::ctx().close_device();
}
};

// Forward and backward expectations were generated by comparison with PyTorch:
// for a test tensor `x` of shape [N,C,H,W] we set `x.requires_grad = True`,
// compute the RMSNorm as `x_norm_sum = torch.nn.functional.rms_norm(x).sum()`,
// and compute its gradient with respect to `x` as `x_grad = torch.autograd.grad(x_norm_sum, x)[0]`.
// The outputs and gradients produced by the TTML RMSNorm op are then compared against these references.
TEST_F(RMSNormOpTest, RMSNorm_Small_Forward) {
using namespace ttml;
float eps = 0.0078125F; // default in PyTorch for bf16

uint32_t N = 1, C = 1, H = 1, W = 8;

xt::xarray<float> example_xtensor = {{{{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}}}};
auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));

auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
auto result_xtensor = core::to_xtensor(result->get_value());
xt::xarray<float> expected_result = {{0.3652F, 0.7305F, 1.0938F, 1.4609F, 0.3652F, 0.7305F, 1.0938F, 1.4609F}};
EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 1e-2F));
}

TEST_F(RMSNormOpTest, RMSNorm_Small_Backward) {
using namespace ttml;
float eps = 0.0078125F; // default in PyTorch for bf16

uint32_t N = 1, C = 1, H = 1, W = 8;

xt::xarray<float> example_xtensor = {{{{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}}}};
auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));

auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
auto result_xtensor = core::to_xtensor(result->get_value());

auto target = autograd::create_tensor(core::zeros_like(result->get_value()));
auto mse_result = ttml::ops::mse_loss(result, target);
mse_result->backward();
auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
auto expected_example_tensor_grad = xt::xarray<float>(
{{{{5.2452e-05F,
1.0490e-04F,
-2.0742e-05F,
2.0981e-04F,
5.2452e-05F,
1.0490e-04F,
-2.0742e-05F,
2.0981e-04F}}}});
EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 1.0e-3F, 1e-2F));

auto gamma_grad = core::to_xtensor(gamma->get_grad());
auto expected_gamma_grad =
xt::xarray<float>({{{{0.0334F, 0.1338F, 0.2988F, 0.5352F, 0.0334F, 0.1338F, 0.2988F, 0.5352F}}}});
EXPECT_TRUE(xt::allclose(gamma_grad, expected_gamma_grad, 1.0e-3F, 1e-2F));
}

TEST_F(RMSNormOpTest, RMSNorm_Forward_Batch) {
using namespace ttml;
float eps = 0.0078125F; // default in PyTorch for bf16

// 2 batches, 1 sequence, 20 tokens, 5-dim'l embedding space.
std::array<uint32_t, 4> a_shape = {2, 1, 20, 5};
xt::xarray<float> a_xarray = xt::xarray<float>::from_shape(a_shape);
std::generate(a_xarray.begin(), a_xarray.end(), [cur = 0.0F]() mutable { return (cur++); });

auto example_tensor = autograd::create_tensor(core::from_xtensor(a_xarray, &autograd::ctx().get_device()));
auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, 5}), &autograd::ctx().get_device()));

auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
auto result_xtensor = core::to_xtensor(result->get_value());
xt::xarray<float> expected_result = {
{{{0.00000F, 0.40820F, 0.81641F, 1.22656F, 1.63281F}, {0.69922F, 0.83984F, 0.98047F, 1.11719F, 1.25781F},
{0.82812F, 0.91016F, 0.99219F, 1.07812F, 1.15625F}, {0.87891F, 0.93750F, 0.99609F, 1.05469F, 1.11719F},
{0.90625F, 0.95312F, 0.99609F, 1.04688F, 1.08594F}, {0.92578F, 0.96094F, 1.00000F, 1.03906F, 1.07031F},
{0.93750F, 0.96875F, 1.00000F, 1.03125F, 1.06250F}, {0.94531F, 0.97266F, 1.00000F, 1.02344F, 1.05469F},
{0.95312F, 0.97656F, 1.00000F, 1.02344F, 1.04688F}, {0.95703F, 0.97656F, 1.00000F, 1.02344F, 1.03906F},
{0.96094F, 0.98047F, 1.00000F, 1.01562F, 1.03906F}, {0.96484F, 0.98047F, 1.00000F, 1.01562F, 1.03125F},
{0.96875F, 0.98438F, 1.00000F, 1.01562F, 1.03125F}, {0.96875F, 0.98438F, 1.00000F, 1.01562F, 1.03125F},
{0.97266F, 0.98438F, 1.00000F, 1.01562F, 1.03125F}, {0.97266F, 0.98828F, 1.00000F, 1.01562F, 1.02344F},
{0.97656F, 0.98828F, 1.00000F, 1.01562F, 1.02344F}, {0.97656F, 0.98828F, 1.00000F, 1.00781F, 1.02344F},
{0.97656F, 0.98828F, 1.00000F, 1.00781F, 1.02344F}, {0.98047F, 0.98828F, 1.00000F, 1.00781F, 1.02344F}}},
{{{0.98047F, 0.98828F, 1.00000F, 1.00781F, 1.01562F}, {0.98047F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98047F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
{0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F},
{0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F},
{0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}}}};
assert((expected_result.shape() == result_xtensor.shape()));
EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 6e-2F, 1e-8F));
}

TEST_F(RMSNormOpTest, RMSNorm_Backward_Batch) {
using namespace ttml;
float eps = 0.0078125F; // default in PyTorch for bf16

// 2 batches, 1 sequence, 20 tokens, 5-dim'l embedding space.
std::array<uint32_t, 4> a_shape = {2, 1, 20, 5};
xt::xarray<float> a_xarray = xt::xarray<float>::from_shape(a_shape);
std::generate(a_xarray.begin(), a_xarray.end(), [cur = 0.0F]() mutable { return (cur++); });

auto example_tensor = autograd::create_tensor(core::from_xtensor(a_xarray, &autograd::ctx().get_device()));
auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, 5}), &autograd::ctx().get_device()));

auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
auto result_xtensor = core::to_xtensor(result->get_value());

auto target = autograd::create_tensor(core::zeros_like(result->get_value()));
auto mse_result = ttml::ops::mse_loss(result, target);
mse_result->backward();

auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
xt::xarray<float> expected_example_tensor_grad = xt::zeros_like(a_xarray);
EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 5e-2F, 1e-3F));

auto gamma_grad = core::to_xtensor(gamma->get_grad());
xt::xarray<float> expected_gamma_grad = {{{{0.36111F, 0.37644F, 0.39589F, 0.41945F, 0.44712F}}}};
EXPECT_TRUE(xt::allclose(gamma_grad, expected_gamma_grad, 5e-2F));
}
