[JAX] Bug Fix: WeightInit with field #1361

Closed · wants to merge 3 commits
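The diff below swaps bare `WeightInit` defaults on the Praxis layer dataclasses for `field(default_factory=...)`. As context, here is a minimal stdlib-only sketch of the constraint at play, using a hypothetical `Init` stand-in rather than the real `praxis` `WeightInit` (assumed here to behave like an ordinary non-frozen dataclass): Python 3.11+ rejects a dataclass instance as a field default, and `default_factory` must be given a zero-argument callable rather than an already-constructed instance.

```python
# Minimal sketch, stdlib only; Init is a hypothetical stand-in, not praxis WeightInit.
from dataclasses import dataclass, field
from functools import partial


@dataclass
class Init:
    method: str
    scale: float

    @classmethod
    def Constant(cls, scale):
        return cls(method="constant", scale=scale)


@dataclass
class Layer:
    # A bare `Init.Constant(0.0)` default raises on Python 3.11+
    # ("mutable default ... is not allowed: use default_factory") because the
    # default would be a shared, non-frozen dataclass instance.
    # default_factory needs a zero-argument callable, so wrap the constructor
    # instead of passing an instance that has already been built.
    bias_init: Init = field(default_factory=partial(Init.Constant, 0.0))


layer = Layer()
assert layer.bias_init.scale == 0.0  # the factory ran for this instance
```

Note that `field(default_factory=WeightInit.Constant(0.0))`, as written in the diff, hands `default_factory` an already-built instance; assuming `WeightInit` instances are not callable, the generated `__init__` would fail when it tries to invoke the factory, so a `partial` or lambda wrapper (as sketched above) is the usual form.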
transformer_engine/jax/praxis/module.py (19 changes: 13 additions & 6 deletions)
@@ -6,6 +6,7 @@
"""
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
+from dataclasses import field, dataclass

from praxis import pax_fiddle
from praxis.base_layer import init_var
@@ -27,6 +28,7 @@ def _generate_ln_scale_init(scale_init):
return scale_init


+@dataclass
class TransformerEngineBaseLayer(BaseLayer):
"""TransformerEngineBaseLayer"""

@@ -66,6 +68,7 @@ def create_layer(self, name, flax_module_cls):
self.create_child(name, flax_module_p.clone())


+@dataclass
class LayerNorm(TransformerEngineBaseLayer):
"""LayerNorm"""

@@ -74,7 +77,7 @@ class LayerNorm(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False

@@ -102,6 +105,7 @@ def __call__(self, x: JTensor) -> JTensor:
return self.layer_norm(x)


+@dataclass
class FusedSoftmax(TransformerEngineBaseLayer):
"""FusedSoftmax"""

@@ -123,13 +127,14 @@ def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JT
return self.fused_softmax(x, mask, bias)


+@dataclass
class Linear(TransformerEngineBaseLayer):
"""Linear"""

out_features: int = 512
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = True
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
@@ -164,6 +169,7 @@ def __call__(self, x: JTensor) -> JTensor:
return self.linear(x)


+@dataclass
class LayerNormLinear(TransformerEngineBaseLayer):
"""LayerNormLinear"""

@@ -174,11 +180,11 @@ class LayerNormLinear(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
-ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+ln_bias_init: WeightInit = field(default_factory=WeightInit.Constant(1.0))
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
@@ -227,6 +233,7 @@ def __call__(self, x: JTensor) -> JTensor:
return self.ln_linear(x)


+@dataclass
class LayerNormMLP(TransformerEngineBaseLayer):
"""LayerNormMLP"""

@@ -237,12 +244,12 @@ class LayerNormMLP(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
-ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+ln_bias_init: WeightInit = field(default_factory=WeightInit.Constant(1.0))
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes_1: Tuple[str, ...] = ()
kernel_axes_2: Tuple[str, ...] = ()
use_bias: bool = False
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
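A hypothetical usage sketch for one of the templates in this file, showing that the `field(default_factory=...)` defaults only come into play when a config leaves the field unset. `pax_fiddle.Config` is the standard Praxis template mechanism; the `WeightInit` import path and the `name=` argument follow common Praxis conventions and may vary by version.

```python
# Illustrative only; the WeightInit import path may differ across praxis versions.
from praxis import pax_fiddle
from praxis.base_layer import WeightInit

from transformer_engine.jax.praxis.module import LayerNorm

ln_tpl = pax_fiddle.Config(
    LayerNorm,
    name="ln",
    zero_centered_gamma=True,
    bias_init=WeightInit.Constant(0.1),  # explicit override of the declared default
)
# Leaving bias_init unset would fall back to the dataclass default declared above.
```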
transformer_engine/jax/praxis/transformer.py (11 changes: 8 additions & 3 deletions)
@@ -6,6 +6,7 @@
"""
from functools import partial
from typing import Optional, Sequence, Tuple
+from dataclasses import field, dataclass
import warnings

from praxis import pax_fiddle
@@ -21,6 +22,7 @@
from ..attention import AttnBiasType, AttnMaskType


+@dataclass
class RelativePositionBiases(TransformerEngineBaseLayer):
"""RelativePositionBiases"""

@@ -36,7 +38,7 @@ def generate_embedding_init(init, num_attention_heads, num_buckets):
embedding_init = init
if embedding_init is None:
rb_stddev = (num_attention_heads * num_buckets) ** -0.5
-embedding_init = WeightInit.Gaussian(rb_stddev)
+embedding_init = field(default_factory=WeightInit.Gaussian(rb_stddev))
return embedding_init

def setup(self) -> None:
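A side note on the `generate_embedding_init` change above: `dataclasses.field()` is only meaningful as a class-level field specifier consumed by the `@dataclass` machinery. Called as a plain expression inside a helper, it returns a `dataclasses.Field` marker rather than a value, so a caller expecting a `WeightInit`-style object here would receive the marker instead. A stdlib-only check (no Praxis types involved):

```python
import dataclasses

# field() called at runtime just builds a Field specifier; nothing invokes
# the factory unless a @dataclass class body is being processed.
marker = dataclasses.field(default_factory=lambda: 0.5)
print(type(marker))              # <class 'dataclasses.Field'>
print(marker.default_factory())  # 0.5 -- the wrapped factory is still callable
```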
@@ -66,6 +68,7 @@ def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = T
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)


+@dataclass
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""

@@ -124,6 +127,7 @@ def __call__(
)


+@dataclass
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""

@@ -138,7 +142,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
@@ -257,6 +261,7 @@ def __call__(
)


+@dataclass
class TransformerLayer(TransformerEngineBaseLayer):
"""TransformerLayer"""

@@ -275,7 +280,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
dropout_rng_name: str = "dropout"
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
-bias_init: WeightInit = WeightInit.Constant(0.0)
+bias_init: WeightInit = field(default_factory=WeightInit.Constant(0.0))
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
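More broadly, the reason dataclasses route instance-valued defaults through a factory instead of one shared default object: the generated `__init__` calls the factory once per instance, so layer objects never share (and accidentally mutate) the same initializer. A self-contained sketch with stand-in types, not Praxis code:

```python
from dataclasses import dataclass, field


@dataclass
class Init:  # stand-in; not praxis WeightInit
    scale: float


@dataclass
class Layer:
    bias_init: Init = field(default_factory=lambda: Init(scale=0.0))


a, b = Layer(), Layer()
assert a.bias_init is not b.bias_init  # each instance gets its own Init object
a.bias_init.scale = 1.0
assert b.bias_init.scale == 0.0        # no cross-instance leakage
```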