diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cb62189d5a..0a198ccdf2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,9 +3,8 @@ # See LICENSE for license information. import warnings import operator -from functools import reduce +from functools import reduce, partial from typing import Optional, Tuple -from collections.abc import Iterable import jax import jax.numpy as jnp @@ -69,7 +68,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16) multiple_results = True inner_primitive = None outer_primitive = None