Training Utils, Clean up (#3)
* 🐛

Signed-off-by: Peter Szemraj <[email protected]>

* 🎨

Signed-off-by: Peter Szemraj <[email protected]>

* suppress fla FutureWarning

Signed-off-by: Peter Szemraj <[email protected]>

* 🔥 remove fused rotary emb code

Signed-off-by: Peter Szemraj <[email protected]>

* ✨ model summary, auto tf32

Signed-off-by: Peter Szemraj <[email protected]>

* 📝 document config params

Signed-off-by: Peter Szemraj <[email protected]>

* 🔊 more detailed model summary

Signed-off-by: Peter Szemraj <[email protected]>

* more closely match samba421m cfg

Signed-off-by: Peter Szemraj <[email protected]>

---------

Signed-off-by: Peter Szemraj <[email protected]>
Co-authored-by: Peter Szemraj <[email protected]>
pszemraj and Peter Szemraj authored Nov 23, 2024
1 parent 19a0ab6 commit 58adcdc
Showing 7 changed files with 187 additions and 130 deletions.
55 changes: 54 additions & 1 deletion samba_pytorch/config.py
@@ -15,7 +15,60 @@

@dataclass
class Config:
org: str = "Lightning-AI"
"""Configuration class for SAMBA (Simple Hybrid State Space Models) architecture.
The SAMBA architecture combines Mamba (selective state space model) with
Sliding Window Attention (SWA) and Multi-Layer Perceptrons (MLP) in a layer-wise fashion.
Attributes:
org (str): Organization name, defaults to "samba-pytorch"
name (str): Model name, defaults to "lit-GPT"
block_size (int): Maximum sequence length for the model, defaults to 4096
vocab_size (int): Size of the vocabulary, defaults to 50254
padding_multiple (int): Padding factor for vocab size optimization, defaults to 512
padded_vocab_size (Optional[int]): Actual padded vocabulary size after adjustment
n_layer (int): Number of transformer layers, defaults to 16
n_head (int): Number of attention heads, defaults to 32
n_embd (int): Embedding dimension / hidden state size, defaults to 4096
rotary_percentage (float): Fraction of dimensions to apply rotary embeddings to, defaults to 0.25
parallel_residual (bool): Whether to use parallel residual connections, defaults to True
bias (bool): Whether to include bias terms in linear layers, defaults to True
# SAMBA-specific parameters
local_window (int): Size of sliding window for attention, -1 means full attention
mlp (bool): Whether to include MLP layers, defaults to True
full_per_layer (int): Use full attention every full_per_layer layers; other layers use the sliding window
mb_per_layer (int): Insert a Mamba layer every mb_per_layer layers
ret_per_layer (int): Insert a RetNet (multi-scale retention) layer every ret_per_layer layers
gla_per_layer (int): Insert a GLA (Gated Linear Attention) layer every gla_per_layer layers
nope (bool): Disable rotary position embeddings (NoPE) when True
mamba (bool): Whether to use Mamba layers, defaults to False
sc_attn (bool): Whether to use short convolution in attention, defaults to False
rms_norm (bool): Use RMSNorm instead of LayerNorm, defaults to True
# Performance optimizations
residual_in_fp32 (bool): Keep residual connections in fp32, defaults to True
fused_add_norm (bool): Use fused add+norm operations, defaults to True
mamba_init (bool): Use specialized Mamba initialization, defaults to False
attn_layer_pos (str): Position of attention layers in architecture
n_query_groups (Optional[int]): Number of query groups for grouped-query attention
shared_attention_norm (bool): Share one normalization layer between the attention and MLP branches in parallel-residual blocks, defaults to False
_norm_class (str): Normalization layer class to use ("LayerNorm" or "RMSNorm")
norm_eps (float): Epsilon for normalization layers, defaults to 1e-5
_mlp_class (str): MLP implementation class ("GptNeoxMLP" or "LLaMAMLP")
intermediate_size (Optional[int]): Size of intermediate MLP layers
condense_ratio (int): Ratio used to condense rotary position indices for context extension, defaults to 1
Key Implementation Details from Paper:
- SAMBA combines Mamba, SWA and MLP through layer-wise interleaving
- Default sliding window size is 2048 tokens
- Uses PreNorm and skip connections for each intermediate layer
- Mamba layers capture time-dependent semantics and provide efficient decoding
- SWA handles complex non-Markovian dependencies
- MLPs handle factual knowledge recall
"""
org: str = "samba-pytorch"
name: str = "lit-GPT"
block_size: int = 4096
vocab_size: int = 50254
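For orientation, a minimal hand-written sketch of how the documented fields might be used to build a config. It assumes samba_pytorch is importable and that Config keeps the attribute names listed in the docstring above; the values are illustrative placeholders (loosely in the spirit of the "more closely match samba421m cfg" commit), not the preset checked in here.

```python
# Illustrative only: construct a Config from the fields documented above.
# The specific values are assumptions, not the values committed in this change.
from samba_pytorch.config import Config

cfg = Config(
    name="samba-421m-ish",    # hypothetical preset name
    block_size=4096,          # maximum sequence length
    vocab_size=50254,
    n_layer=16,
    n_head=32,
    n_embd=1024,              # assumed hidden size for a ~400M-parameter model
    rotary_percentage=0.25,   # fraction of head dims given rotary embeddings
    local_window=2048,        # sliding-window size mentioned in the docstring
    mamba=True,               # enable Mamba layers
    mb_per_layer=2,           # assumed: interleave a Mamba layer every 2 layers
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=True,
)
print(cfg)
```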
6 changes: 0 additions & 6 deletions samba_pytorch/modules/__init__.py
@@ -1,18 +1,12 @@
"""Core model component modules."""

from samba_pytorch.modules.fused_rotary_embedding import (
ApplyRotaryEmb,
apply_rotary_emb_func,
)
from samba_pytorch.modules.gla import GatedLinearAttention
from samba_pytorch.modules.mamba_simple import Mamba
from samba_pytorch.modules.multiscale_retention import MultiScaleRetention
from samba_pytorch.modules.rmsnorm import RMSNorm, rms_norm
from samba_pytorch.modules.rotary import RotaryEmbedding, apply_rotary_emb

__all__ = [
"apply_rotary_emb_func",
"ApplyRotaryEmb",
"GatedLinearAttention",
"Mamba",
"MultiScaleRetention",
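With the fused exports removed, the remaining public components are still re-exported from samba_pytorch.modules (the imports visible above). A quick sanity-check sketch, assuming the truncated __all__ continues to match those imports:

```python
# Sketch: import the components still re-exported after the cleanup.
# Names mirror the imports kept at the top of samba_pytorch/modules/__init__.py.
from samba_pytorch.modules import (
    GatedLinearAttention,
    Mamba,
    MultiScaleRetention,
    RMSNorm,
    RotaryEmbedding,
    apply_rotary_emb,
    rms_norm,
)

print([obj.__name__ for obj in (GatedLinearAttention, Mamba, MultiScaleRetention, RMSNorm, RotaryEmbedding)])
```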
99 changes: 0 additions & 99 deletions samba_pytorch/modules/fused_rotary_embedding.py

This file was deleted.

5 changes: 3 additions & 2 deletions samba_pytorch/modules/rmsnorm.py
@@ -1,7 +1,8 @@
from typing import Optional, Tuple, Union

import torch
from torch import nn
from einops import rearrange
from typing import Optional, Tuple, Union
from torch import nn


def maybe_align(x: torch.Tensor, alignment_in_bytes: int = 16) -> torch.Tensor:
19 changes: 9 additions & 10 deletions samba_pytorch/samba.py
@@ -5,15 +5,16 @@
# see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE

import math
import warnings
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
from rotary_embedding_torch import RotaryEmbedding
from torch import Tensor
from typing_extensions import Self
from xformers.ops import SwiGLU
from rotary_embedding_torch import RotaryEmbedding

try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
@@ -22,12 +23,12 @@
from causal_conv1d import causal_conv1d_fn
from einops import rearrange

from samba_pytorch.config import Config

from samba_pytorch.modules.gla import GatedLinearAttention
from samba_pytorch.modules.mamba_simple import Mamba
from samba_pytorch.modules.multiscale_retention import MultiScaleRetention
warnings.filterwarnings("ignore", category=FutureWarning, module="fla.ops")

from samba_pytorch.config import Config # noqa
from samba_pytorch.modules.gla import GatedLinearAttention # noqa
from samba_pytorch.modules.mamba_simple import Mamba # noqa
from samba_pytorch.modules.multiscale_retention import MultiScaleRetention # noqa

RoPECache = Tuple[torch.Tensor, torch.Tensor]
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -74,7 +75,7 @@ def __init__(self, config: Config) -> None:
self.config = config

self.rotary_emb = RotaryEmbedding(
dim=int(config.rotary_percentage * config.head_size), # TODO: validate
dim=int(config.rotary_percentage * config.head_size), # TODO: validate
use_xpos=getattr(config, "use_xpos", False),
interpolate_factor=getattr(config, "interpolate_factor", 1.0),
)
@@ -243,7 +244,7 @@ def forward(

# Initialize rotary embedding variables
if self.config.nope:
rope = None # Set rope to None if config.nope
rope = None # Set rope to None if config.nope
else:
# Using rotary_emb to rotate queries and keys in attention modules
rope = self.rotary_emb
@@ -668,5 +669,3 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.swiglu(x)
return x


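With the fused kernel deleted, rotary embeddings now come from the third-party rotary_embedding_torch package constructed in __init__ above. A minimal usage sketch, assuming that library's RotaryEmbedding / rotate_queries_or_keys API; the shapes are arbitrary and this is not the repository's actual attention code:

```python
# Standalone sketch of the rotary embedding the model now relies on.
import torch
from rotary_embedding_torch import RotaryEmbedding

head_size = 64
rotary_percentage = 0.25
rotary_emb = RotaryEmbedding(dim=int(rotary_percentage * head_size))  # rotate 1/4 of head dims

# (batch, heads, seq_len, head_dim) query/key tensors
q = torch.randn(1, 8, 128, head_size)
k = torch.randn(1, 8, 128, head_size)
q_rot = rotary_emb.rotate_queries_or_keys(q)
k_rot = rotary_emb.rotate_queries_or_keys(k)
print(q_rot.shape, k_rot.shape)  # shapes unchanged; only the first 16 dims are rotated
```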
108 changes: 107 additions & 1 deletion samba_pytorch/utils.py
@@ -16,7 +16,7 @@
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
import torch.nn as nn
@@ -669,3 +669,109 @@ def get_default_supported_precision(training: bool, tpu: bool = False) -> str:
if not torch.cuda.is_available() or torch.cuda.is_bf16_supported():
return "bf16-mixed" if training else "bf16-true"
return "16-mixed" if training else "16-true"


def activate_tf32_if_available():
"""
Check if the GPU is NVIDIA Ampere or newer and, if so, enable TF32 matmul/cuDNN kernels in PyTorch.
"""
# Check if CUDA is available
if not torch.cuda.is_available():
warnings.warn("No GPU detected, running on CPU.")
return

try:
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
major, minor = capability

# Check if the GPU is Ampere or newer
if major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
gpu_name = torch.cuda.get_device_name(device)
print(
f"{gpu_name} (compute capability {major}.{minor}) supports NVIDIA Ampere or later, enabled TF32 in PyTorch."
)
else:
gpu_name = torch.cuda.get_device_name(device)
print(
f"{gpu_name} (compute capability {major}.{minor}) does not support NVIDIA Ampere or later."
)

except Exception as e:
warnings.warn(f"Error occurred while checking GPU: {e}")


def model_summary(
model: nn.Module, max_depth: int = 4, show_input_size: bool = False
) -> None:
"""
Prints an accurate summary of the model, avoiding double-counting of parameters.
:param nn.Module model: torch model to summarize
:param int max_depth: maximum depth of the model to print, defaults to 4
:param bool show_input_size: whether to show input size for each layer, defaults to False
"""

def format_params(num_params: int) -> str:
return f"{num_params:,}" if num_params > 0 else "--"

def format_size(size: Optional[List[int]]) -> str:
return "x".join(str(x) for x in size) if size else "N/A"

def count_parameters(module: nn.Module) -> Tuple[int, int]:
total_params = sum(p.numel() for p in module.parameters())
trainable_params = sum(
p.numel() for p in module.parameters() if p.requires_grad
)
return total_params, trainable_params

def recursive_summarize(
module: nn.Module, depth: int, idx: List[int], prefix: str = ""
) -> List[Tuple[str, int, int, int, Optional[List[int]], nn.Module]]:
summary = []

total_params, trainable_params = count_parameters(module)

if depth <= max_depth:
layer_name = f"{prefix}{type(module).__name__}"
layer_index = ".".join(map(str, idx))
param_shape = next(
(p.shape for p in module.parameters(recurse=False) if p.requires_grad),
None,
)
summary.append(
(layer_name, depth, total_params, trainable_params, param_shape, module)
)

for i, (name, child) in enumerate(module.named_children(), 1):
child_summary = recursive_summarize(
child, depth + 1, idx + [i], prefix + " "
)
summary.extend(child_summary)

return summary

summary = recursive_summarize(model, 1, [1])

max_name_length = max(len(name) for name, _, _, _, _, _ in summary)
max_shape_length = max(len(format_size(shape)) for _, _, _, _, shape, _ in summary)

print("=" * (max_name_length + 50))
header = f"{'Layer (type:depth-idx)':<{max_name_length}} {'Output Shape':>{max_shape_length}} {'Param #':>12} {'Trainable':>10}"
print(header)
print("=" * (max_name_length + 50))

for name, depth, num_params, trainable_params, shape, _ in summary:
shape_str = format_size(shape) if show_input_size else ""
print(
f"{name:<{max_name_length}} {shape_str:>{max_shape_length}} {format_params(num_params):>12} {str(trainable_params > 0):>10}"
)

total_params, trainable_params = count_parameters(model)
print("=" * (max_name_length + 50))
print(f"Total params: {format_params(total_params)}")
print(f"Trainable params: {format_params(trainable_params)}")
print(f"Non-trainable params: {format_params(total_params - trainable_params)}")
print("=" * (max_name_length + 50))