Skip to content

Commit

Permalink
fixed static args
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Nov 21, 2024
1 parent 8a610d0 commit e44d5cf
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# See LICENSE for license information.
import warnings
import operator
from functools import reduce
from functools import reduce, partial
from typing import Optional, Tuple
from collections.abc import Iterable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -69,7 +68,7 @@ class CollectiveGemmPrimitive(BasePrimitive):
"""

name = "te_gemm"
impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15)
impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16)
multiple_results = True
inner_primitive = None
outer_primitive = None
Expand Down

0 comments on commit e44d5cf

Please sign in to comment.