diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 49ea86e6..ca6c2441 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -165,7 +165,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))
+        self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
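
Note: this change drops the torch.compile wrapper, so split_silu_mul falls back to eager execution of the same computation. For reference, a GLU activation of this kind splits the fused gate/up projection in half along the last dimension and returns act(gate) * up. A minimal sketch of that behavior, assuming a SiLU-style activation; the class name and signature below are illustrative, not copied from the repo:

import torch
from torch import nn
import torch.nn.functional as F

class GLUActivationSketch(nn.Module):
    """Illustrative sketch of a gated-linear-unit activation.

    Takes the fused [..., 2 * intermediate_size] output of gate_up_proj,
    splits it into gate and up halves, and returns act(gate) * up.
    """

    def __init__(self, act_fn=F.silu):
        super().__init__()
        self.act = act_fn  # e.g. SiLU when config.hidden_act == "silu"

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        # Split the fused projection into its gate and up halves.
        gate_states, up_states = torch.split(
            merged_states, merged_states.shape[-1] // 2, dim=-1
        )
        return self.act(gate_states) * up_states

# Before this patch the module was wrapped as
# torch.compile(GLUActivation(config.hidden_act)); removing the wrapper
# keeps the identical split + activation + multiply, just uncompiled.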