diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py
index b82c0915e4..e33ff39385 100644
--- a/transformer_engine/jax/praxis/module.py
+++ b/transformer_engine/jax/praxis/module.py
@@ -6,6 +6,7 @@
 """
 from functools import partial
 from typing import Callable, Iterable, Sequence, Tuple, Union
+from dataclasses import field, dataclass
 
 from praxis import pax_fiddle
 from praxis.base_layer import init_var
@@ -27,6 +28,7 @@ def _generate_ln_scale_init(scale_init):
     return scale_init
 
 
+@dataclass
 class TransformerEngineBaseLayer(BaseLayer):
     """TransformerEngineBaseLayer"""
 
@@ -66,6 +68,7 @@ def create_layer(self, name, flax_module_cls):
         self.create_child(name, flax_module_p.clone())
 
 
+@dataclass
 class LayerNorm(TransformerEngineBaseLayer):
     """LayerNorm"""
 
@@ -74,7 +77,7 @@ class LayerNorm(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     bias_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False
 
@@ -102,6 +105,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.layer_norm(x)
 
 
+@dataclass
 class FusedSoftmax(TransformerEngineBaseLayer):
     """FusedSoftmax"""
 
@@ -123,13 +127,14 @@ def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JT
         return self.fused_softmax(x, mask, bias)
 
 
+@dataclass
 class Linear(TransformerEngineBaseLayer):
     """Linear"""
 
     out_features: int = 512
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = True
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -164,6 +169,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.linear(x)
 
 
+@dataclass
 class LayerNormLinear(TransformerEngineBaseLayer):
     """LayerNormLinear"""
 
@@ -174,11 +180,11 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 1.0))
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -227,6 +233,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.ln_linear(x)
 
 
+@dataclass
 class LayerNormMLP(TransformerEngineBaseLayer):
     """LayerNormMLP"""
 
@@ -237,12 +244,12 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 1.0))
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes_1: Tuple[str, ...] = ()
     kernel_axes_2: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py
index f2ac802f10..405362f4f8 100644
--- a/transformer_engine/jax/praxis/transformer.py
+++ b/transformer_engine/jax/praxis/transformer.py
@@ -6,6 +6,7 @@
 """
 from functools import partial
 from typing import Optional, Sequence, Tuple
+from dataclasses import field, dataclass
 import warnings
 
 from praxis import pax_fiddle
@@ -21,6 +22,7 @@
 from ..attention import AttnBiasType, AttnMaskType
 
 
+@dataclass
 class RelativePositionBiases(TransformerEngineBaseLayer):
     """RelativePositionBiases"""
 
@@ -66,6 +68,7 @@ def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = T
         return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
 
 
+@dataclass
 class DotProductAttention(TransformerEngineBaseLayer):
     """DotProductAttention"""
 
@@ -124,6 +127,7 @@ def __call__(
         )
 
 
+@dataclass
 class MultiHeadAttention(TransformerEngineBaseLayer):
     """MultiHeadAttention"""
 
@@ -138,7 +142,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     return_layernorm_output: bool = False
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     attn_mask_type: str = "causal"
     attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
@@ -257,6 +261,7 @@ def __call__(
         )
 
 
+@dataclass
 class TransformerLayer(TransformerEngineBaseLayer):
     """TransformerLayer"""
 
@@ -275,7 +280,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
     dropout_rng_name: str = "dropout"
     mlp_activations: Sequence[str] = ("relu",)
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, 0.0))
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
     float32_attention_logits: bool = False
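Note (not part of the patch): the conversion above relies on two dataclass details that are easy to get wrong. A class-instance default such as WeightInit.Constant(0.0) is rejected as a "mutable default" on Python 3.11+ (the check now flags any unhashable default, and instances of a regular non-frozen dataclass are unhashable), and default_factory must be a zero-argument callable, so the instance has to be wrapped (e.g. with functools.partial) rather than passed in directly. The sketch below is a minimal illustration under those assumptions; `Init` is a hypothetical stand-in for praxis's WeightInit, and `Layer` is an illustrative class, not TransformerEngine code.

# Minimal sketch, not TransformerEngine code: `Init` is a hypothetical stand-in
# for praxis's WeightInit, assumed to be a regular (non-frozen) dataclass.
from dataclasses import dataclass, field
from functools import partial


@dataclass
class Init:
    method: str = "constant"
    scale: float = 0.0

    @classmethod
    def Constant(cls, scale):
        """Factory classmethod mirroring the WeightInit.Constant-style helper."""
        return cls(method="constant", scale=scale)


@dataclass
class Layer:
    # bias_init: Init = Init.Constant(0.0)
    #   -> ValueError on Python 3.11+ ("mutable default ... use default_factory"),
    #      because a non-frozen dataclass instance is unhashable.
    # bias_init: Init = field(default_factory=Init.Constant(0.0))
    #   -> TypeError at Layer(): default_factory must be a zero-argument callable,
    #      but this passes an Init *instance* as the factory.
    bias_init: Init = field(default_factory=partial(Init.Constant, 0.0))


if __name__ == "__main__":
    print(Layer().bias_init)  # Init(method='constant', scale=0.0)

functools.partial is used rather than a lambda because it is picklable and is already imported in both modified files; either form satisfies the zero-argument-callable requirement.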