[tt-train] Add RMSNorm module (#16991)
### Problem description

We need RMSNorm to train Llama 3 and some other exciting open source models.

### What's changed

- Added RMS op
- Added RMS module

### Checklist

- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
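For reference, the op added here computes RMS normalization over the last (feature) dimension: for activations $a \in \mathbb{R}^C$ at each position and a learnable per-channel gain $\gamma$ (initialized to ones in the module below),

$$\mathrm{rms}(a) = \sqrt{\tfrac{1}{C}\sum_{j=1}^{C} a_j^2 + \epsilon}, \qquad y_i = \frac{\gamma_i\, a_i}{\mathrm{rms}(a)}.$$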
Showing 6 changed files with 333 additions and 0 deletions.
rms_norm_module.cpp
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rms_norm_module.hpp"

#include "core/tt_tensor_utils.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

void RMSNormLayer::initialize_tensors(uint32_t features) {
    m_gamma =
        autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, features}), &autograd::ctx().get_device()));
}

RMSNormLayer::RMSNormLayer(uint32_t features, float epsilon) : m_epsilon(epsilon) {
    initialize_tensors(features);

    create_name("rmsnorm");
    register_tensor(m_gamma, "gamma");
}

autograd::TensorPtr RMSNormLayer::operator()(const autograd::TensorPtr& tensor) {
    return ops::rmsnorm(tensor, m_gamma, m_epsilon);
}

}  // namespace ttml::modules
rms_norm_module.hpp
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

class RMSNormLayer : public autograd::ModuleBase {
private:
    float m_epsilon = 1e-5F;
    autograd::TensorPtr m_gamma = nullptr;

public:
    void initialize_tensors(uint32_t features);
    explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);

    [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
};

}  // namespace ttml::modules
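A minimal usage sketch of the new layer (the variable names and the feature size 512 are illustrative, not from this commit; per the op's assertions the input must be a rank-4 tensor shaped [batch, 1, sequence, features], and a device must already be open via ttml::autograd::ctx()):

// `activations` is a hypothetical ttml::autograd::TensorPtr of shape [B, 1, S, 512].
ttml::modules::RMSNormLayer rms_norm(/*features=*/512);  // gamma starts as ones, epsilon defaults to 1e-5F
auto normalized = rms_norm(activations);                  // calls ops::rmsnorm(activations, gamma, epsilon)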
rmsnorm_op.cpp
@@ -0,0 +1,116 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rmsnorm_op.hpp"

#include <cassert>
#include <core/ttnn_all_includes.hpp>
#include <cstdint>
#include <optional>
#include <stdexcept>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/compute_kernel_config.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr &tensor, const autograd::TensorPtr &gamma, float epsilon) {
    auto a_shape = tensor->get_value().logical_shape();
    if (a_shape.rank() != 4) {
        throw std::runtime_error("rmsnorm only supports rank-4 input tensors.");
    }

    auto ashape_arr = a_shape.to_array_4D();
    auto [B, N, S, C] = ashape_arr;
    assert((N == 1));  // one sequence per batch

    // one gain parameter per channel
    assert((gamma->get_value().logical_shape().to_array_4D() == std::array<uint32_t, 4>{1, 1, 1, C}));

    auto device = &autograd::ctx().get_device();

    ttnn::Tensor squares = ttnn::square(tensor->get_value());  // [B,1,S,C] -> [B,1,S,C]

    ttnn::Tensor seq_means_of_squares = ttnn::mean(squares, /*dim_arg=*/-1, /*keep_dim=*/true);  // [B,1,S,1]

    ttnn::Tensor seq_means_of_squares_plus_epsilon =
        ttnn::experimental::add(seq_means_of_squares, epsilon);  // [B,1,S,1] x [1] -> [B,1,S,1] (bcast)

    ttnn::Tensor rms_a = ttnn::sqrt(seq_means_of_squares_plus_epsilon);  // [B,1,S,1] -> [B,1,S,1]

    ttnn::Tensor gamma_times_activations =
        ttnn::experimental::mul(gamma->get_value(), tensor->get_value());  // [1,1,1,C] x [B,1,S,C] -> [B,1,S,C] (bcast)

    ttnn::Tensor out_tensor =
        ttnn::experimental::div(gamma_times_activations, rms_a);  // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]

    auto out = autograd::create_tensor(out_tensor);

    autograd::GradFunction grad = [B, S, C, tensor, gamma, out, rms_a, device]() {
        auto a = tensor->get_value();  // [B,1,S,C]
        auto g = gamma->get_value();   // [1,1,1,C]

        // c is the number of activations; in the RMSNorm paper they call this
        // "n". it is renamed here to avoid confusion with N.
        auto c = static_cast<float>(a.logical_shape()[-1]);

        auto dL_dout = out->get_grad();  // Grad w.r.t. normalized activations, hence [B,1,S,C]

        auto scaled_gain = ttnn::experimental::div(g, rms_a);  // [1,1,1,C] x [B,1,S,1] -> [B,1,S,C] (bcast)
        auto gained_dL_dout = ttnn::experimental::mul(scaled_gain, dL_dout);  // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]

        // notation:
        // _ · _  <- usual dot product
        // _ @ _  <- matrix multiplication
        // _ *. _ <- Hadamard product/eltwise multiplication with broadcasting
        // _ /. _ <- eltwise division with broadcasting

        // have a : [B,1,S,C]

        // want to obtain scaled_outer = gained_dL_dout @ ((a@a^T)/(n*rms(a)^2))

        // to avoid computing the large outer product matrix explicitly, we
        // instead compute
        //   scale = (a^T · gained_dL_dout) : [B,1,S,C] x [B,1,S,C] -> [1]
        //   scaled_outer = scale *. a      : [1] x [B,1,S,C] -> [B,1,S,C]

        auto scale = ttml::ttnn_fixed::sum_over_dim(
            ttnn::experimental::mul(a, gained_dL_dout), 3);  // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C] -> [B,1,S,1]

        auto scaled_outer = ttnn::experimental::mul(scale, a);  // [B,1,S,1] x [B,1,S,C] -> [B,1,S,C] (bcast)

        auto ms_a = ttnn::square(rms_a);  // [B,1,S,1] -> [B,1,S,1]

        auto c_by_ms_a = ttnn::experimental::mul(ms_a, c);  // [B,1,S,1] x [1] -> [B,1,S,1] (bcast)

        auto rhs = ttnn::experimental::div(scaled_outer, c_by_ms_a);  // [B,1,S,C] x [B,1,S,1] -> [B,1,S,C] (bcast)

        auto dL_da =
            ttnn::experimental::sub(gained_dL_dout, rhs);  // [B,1,S,C] x [B,1,S,C] -> [B,1,S,C]; checked by add_grad
        tensor->add_grad(dL_da);

        // dL_dgamma = (a / rms(a)) * dL_dout -> requires sum over batch due to broadcasting
        auto dL_dg_components = ttnn::experimental::mul(
            dL_dout,
            ttnn::experimental::div(a, rms_a));  // [B,1,S,C] x [B,1,S,1] -> [B,1,S,C] (bcast); checked by add_grad
        auto dL_dg = ttnn::sum(
            dL_dg_components,
            /* dim_arg */ ttnn::SmallVector<int>{0, 1, 2},
            /* keep_dim */ true,
            /* output_mem_config */ std::nullopt,
            /* compute_kernel_config */ core::ComputeKernelConfig::precise());  // [B,1,S,C] -> [1,1,1,C]
        gamma->add_grad(dL_dg);
    };

    auto links = autograd::get_links(tensor, gamma);
    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));

    return out;
}

}  // namespace ttml::ops
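For reference, the backward pass above corresponds to the following gradients (a sketch derived from the code; $\mathrm{rms}(a) = \sqrt{\tfrac{1}{C}\sum_j a_j^2 + \epsilon}$ as in the forward pass, per batch/sequence position with $C$ channels):

$$\frac{\partial L}{\partial a_i} = \frac{\gamma_i}{\mathrm{rms}(a)}\,\frac{\partial L}{\partial y_i} \;-\; \frac{a_i}{C\,\mathrm{rms}(a)^2}\sum_{j=1}^{C}\frac{\gamma_j\,a_j}{\mathrm{rms}(a)}\,\frac{\partial L}{\partial y_j}, \qquad \frac{\partial L}{\partial \gamma_i} = \sum_{b,\,s}\frac{a_{b,s,i}}{\mathrm{rms}(a_{b,s})}\,\frac{\partial L}{\partial y_{b,s,i}}.$$

The first term is `gained_dL_dout`, the subtracted term is `rhs` (`scaled_outer` divided by `c_by_ms_a`), and the gain gradient is `dL_dg` after the sum over the batch and sequence dimensions.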
rmsnorm_op.hpp
@@ -0,0 +1,12 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "autograd/tensor.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon);

}  // namespace ttml::ops
RMSNorm op unit tests
@@ -0,0 +1,149 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ops/rmsnorm_op.hpp"

#include <gtest/gtest.h>

#include <cassert>
#include <core/ttnn_all_includes.hpp>

#include "autograd/auto_context.hpp"
#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "ops/losses.hpp"

class RMSNormOpTest : public ::testing::Test {
protected:
    void SetUp() override {
        ttml::autograd::ctx().open_device();
    }

    void TearDown() override {
        ttml::autograd::ctx().close_device();
    }
};

// Forward and backward tests are given by comparing with results from PyTorch:
// For test tensor `x` of shape [N,C,H,W] we set x.requires_grad = True
// and compute the RMSNorm as `x_norm_sum = torch.nn.functional.rms_norm(x).sum()`
// and compute its gradient with respect to `x` as `x_grad = torch.autograd.grad(x_norm_sum, x)[0]`
// We then compare the results of the RMSNorm and its gradient with the results of the RMSNorm and its gradient
// computed by the RMSNorm op in TTML.
TEST_F(RMSNormOpTest, RMSNorm_Small_Forward) {
    using namespace ttml;
    float eps = 0.0078125F;  // default in PyTorch for bf16

    uint32_t N = 1, C = 1, H = 1, W = 8;

    xt::xarray<float> example_xtensor = {{{{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}}}};
    auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));

    auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
    auto result_xtensor = core::to_xtensor(result->get_value());
    xt::xarray<float> expected_result = {{0.3652F, 0.7305F, 1.0938F, 1.4609F, 0.3652F, 0.7305F, 1.0938F, 1.4609F}};
    EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 1e-2F));
}
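// Sanity check on the expected values above: mean(x^2) = (1 + 4 + 9 + 16) * 2 / 8 = 7.5,
// so rms(x) = sqrt(7.5 + 0.0078125) ≈ 2.7400 and x / rms(x) ≈ {0.3650, 0.7299, 1.0949, 1.4598},
// which matches the bf16 expected values {0.3652F, 0.7305F, 1.0938F, 1.4609F} within rounding.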

TEST_F(RMSNormOpTest, RMSNorm_Small_Backward) {
    using namespace ttml;
    float eps = 0.0078125F;  // default in PyTorch for bf16

    uint32_t N = 1, C = 1, H = 1, W = 8;

    xt::xarray<float> example_xtensor = {{{{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}}}};
    auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));

    auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
    auto result_xtensor = core::to_xtensor(result->get_value());

    auto target = autograd::create_tensor(core::zeros_like(result->get_value()));
    auto mse_result = ttml::ops::mse_loss(result, target);
    mse_result->backward();
    auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
    auto expected_example_tensor_grad = xt::xarray<float>(
        {{{{5.2452e-05F,
            1.0490e-04F,
            -2.0742e-05F,
            2.0981e-04F,
            5.2452e-05F,
            1.0490e-04F,
            -2.0742e-05F,
            2.0981e-04F}}}});
    EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 1.0e-3F, 1e-2F));

    auto gamma_grad = core::to_xtensor(gamma->get_grad());
    auto expected_gamma_grad =
        xt::xarray<float>({{{{0.0334F, 0.1338F, 0.2988F, 0.5352F, 0.0334F, 0.1338F, 0.2988F, 0.5352F}}}});
    EXPECT_TRUE(xt::allclose(gamma_grad, expected_gamma_grad, 1.0e-3F, 1e-2F));
}
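// Note on the expected gamma gradient above: with gamma = 1 and a mean-reduced MSE loss
// against zero targets over 8 elements, dL/dy_i = y_i / 4 and so dL/dgamma_i = y_i^2 / 4,
// e.g. 0.3652^2 / 4 ≈ 0.0333 and 1.4609^2 / 4 ≈ 0.5336, matching expected_gamma_grad
// within the test tolerances.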

TEST_F(RMSNormOpTest, RMSNorm_Forward_Batch) {
    using namespace ttml;
    float eps = 0.0078125F;  // default in PyTorch for bf16

    // 2 batches, 1 sequence, 20 tokens, 5-dim'l embedding space.
    std::array<uint32_t, 4> a_shape = {2, 1, 20, 5};
    xt::xarray<float> a_xarray = xt::xarray<float>::from_shape(a_shape);
    std::generate(a_xarray.begin(), a_xarray.end(), [cur = 0.0F]() mutable { return (cur++); });

    auto example_tensor = autograd::create_tensor(core::from_xtensor(a_xarray, &autograd::ctx().get_device()));
    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, 5}), &autograd::ctx().get_device()));

    auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
    auto result_xtensor = core::to_xtensor(result->get_value());
    xt::xarray<float> expected_result = {
        {{{0.00000F, 0.40820F, 0.81641F, 1.22656F, 1.63281F}, {0.69922F, 0.83984F, 0.98047F, 1.11719F, 1.25781F},
          {0.82812F, 0.91016F, 0.99219F, 1.07812F, 1.15625F}, {0.87891F, 0.93750F, 0.99609F, 1.05469F, 1.11719F},
          {0.90625F, 0.95312F, 0.99609F, 1.04688F, 1.08594F}, {0.92578F, 0.96094F, 1.00000F, 1.03906F, 1.07031F},
          {0.93750F, 0.96875F, 1.00000F, 1.03125F, 1.06250F}, {0.94531F, 0.97266F, 1.00000F, 1.02344F, 1.05469F},
          {0.95312F, 0.97656F, 1.00000F, 1.02344F, 1.04688F}, {0.95703F, 0.97656F, 1.00000F, 1.02344F, 1.03906F},
          {0.96094F, 0.98047F, 1.00000F, 1.01562F, 1.03906F}, {0.96484F, 0.98047F, 1.00000F, 1.01562F, 1.03125F},
          {0.96875F, 0.98438F, 1.00000F, 1.01562F, 1.03125F}, {0.96875F, 0.98438F, 1.00000F, 1.01562F, 1.03125F},
          {0.97266F, 0.98438F, 1.00000F, 1.01562F, 1.03125F}, {0.97266F, 0.98828F, 1.00000F, 1.01562F, 1.02344F},
          {0.97656F, 0.98828F, 1.00000F, 1.01562F, 1.02344F}, {0.97656F, 0.98828F, 1.00000F, 1.00781F, 1.02344F},
          {0.97656F, 0.98828F, 1.00000F, 1.00781F, 1.02344F}, {0.98047F, 0.98828F, 1.00000F, 1.00781F, 1.02344F}}},
        {{{0.98047F, 0.98828F, 1.00000F, 1.00781F, 1.01562F}, {0.98047F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98047F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98438F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F}, {0.98828F, 0.99219F, 1.00000F, 1.00781F, 1.01562F},
          {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F},
          {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F},
          {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}, {0.98828F, 0.99609F, 1.00000F, 1.00781F, 1.00781F}}}};
    assert((expected_result.shape() == result_xtensor.shape()));
    EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 6e-2F, 1e-8F));
}

TEST_F(RMSNormOpTest, RMSNorm_Backward_Batch) {
    using namespace ttml;
    float eps = 0.0078125F;  // default in PyTorch for bf16

    // 2 batches, 1 sequence, 20 tokens, 5-dim'l embedding space.
    std::array<uint32_t, 4> a_shape = {2, 1, 20, 5};
    xt::xarray<float> a_xarray = xt::xarray<float>::from_shape(a_shape);
    std::generate(a_xarray.begin(), a_xarray.end(), [cur = 0.0F]() mutable { return (cur++); });

    auto example_tensor = autograd::create_tensor(core::from_xtensor(a_xarray, &autograd::ctx().get_device()));
    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, 5}), &autograd::ctx().get_device()));

    auto result = ops::rmsnorm(example_tensor, gamma, 0.0078125F);
    auto result_xtensor = core::to_xtensor(result->get_value());

    auto target = autograd::create_tensor(core::zeros_like(result->get_value()));
    auto mse_result = ttml::ops::mse_loss(result, target);
    mse_result->backward();

    auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
    xt::xarray<float> expected_example_tensor_grad = xt::zeros_like(a_xarray);
    EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 5e-2F, 1e-3F));

    auto gamma_grad = core::to_xtensor(gamma->get_grad());
    xt::xarray<float> expected_gamma_grad = {{{{0.36111F, 0.37644F, 0.39589F, 0.41945F, 0.44712F}}}};
    EXPECT_TRUE(xt::allclose(gamma_grad, expected_gamma_grad, 5e-2F));
}
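A brief check of why `RMSNorm_Backward_Batch` expects an (approximately) all-zero input gradient: with $\gamma = 1$ and zero targets, each position contributes

$$\sum_{i=1}^{C} y_i^2 = \frac{\sum_i a_i^2}{\tfrac{1}{C}\sum_j a_j^2 + \epsilon} = \frac{C\,\overline{a^2}}{\overline{a^2} + \epsilon} \approx C,$$

so the MSE loss depends on the activations only through the small $\epsilon$ term, and $\partial L/\partial a$ is of order $\epsilon/\overline{a^2}$, well below the test tolerances.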