From e44d5cfe96e9aa97c082ed1142ccd99a882a3c51 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 20:38:19 +0000 Subject: [PATCH] fixed static args Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cb62189d5a..0a198ccdf2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,9 +3,8 @@ # See LICENSE for license information. import warnings import operator -from functools import reduce +from functools import reduce, partial from typing import Optional, Tuple -from collections.abc import Iterable import jax import jax.numpy as jnp @@ -69,7 +68,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16) multiple_results = True inner_primitive = None outer_primitive = None