From 718c03d11e7984829b1d9ac8c86d6404823ce42f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 11:38:22 +0000 Subject: [PATCH] fixed logic to remove FSDP sharding Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index bf80941f85..d54009e60b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -49,25 +49,44 @@ def mirror_dim(dim, ndims): def remove_fsdp_specs(pspecs): fsdp_resource = global_mesh_resource().fsdp_resource + if fsdp_resource is None: + return list(pspecs).copy() + new_pspecs = [] for spec in pspecs: if spec is None: new_pspecs.append(None) - elif fsdp_resource not in spec: - new_pspecs.append(spec) + elif isinstance(spec, Iterable) and not isinstance(spec, str): new_spec = [] for s in spec: - if s != fsdp_resource: + if s == fsdp_resource: + new_spec.append(None) + else: new_spec.append(s) + if len(new_spec) > 1: new_pspecs.append(new_spec) elif len(new_spec) == 1: new_pspecs.append(new_spec[0]) else: new_pspecs.append(None) + + elif isinstance(spec, str): + if spec == fsdp_resource: + new_pspecs.append(None) + else: + new_pspecs.append(spec) + else: - new_pspecs.append(None) + new_pspecs.append(spec) + + assert len(new_pspecs) == len(pspecs), ( + "Length of partition specs changed when removing FSDP sharding!\n" + + f"Original: {pspecs}\n" + + f"Filtered: {new_pspecs}\n" + ) + return new_pspecs