torch 2.6 #209

Open · wants to merge 3 commits into main

Conversation
samsja (Member) commented Feb 1, 2025

No description provided.

samsja (Member, Author) commented Feb 1, 2025

I get a weird error when running the 1B and 10B models. The error does not occur with the 150M model:

uv run torchrun --nproc_per_node=8 src/zeroband/train.py @ configs/1B/H100.toml

More info:

  • The bug happens on both 2x3090 and 8xH100.
  • torch compile is not the problem: uv run torchrun --nproc_per_node=2 src/zeroband/train.py @ configs/1B/H100.toml --no-train.torch_compile hits the same error.
  • It does not fail when using sdpa instead of flex attention.
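The error below suggests adding ATEN to max_autotune_gemm_backends. A minimal, untested sketch of that workaround (set it before the first training step; whether flex_attention_backward actually honors this config is an assumption on my part):

import torch._inductor.config as inductor_config

# Allow the ATEN fallback when inductor autotunes GEMM-like kernels, as the
# NoValidChoicesError below recommends. Untested whether this unblocks the
# flex attention backward lowering.
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON"

Full output: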
22:36:40 [INFO] [Rank 0] Caught an exception, terminating children
22:36:40 [INFO] [Rank 0] NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
  target: flex_attention_backward
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[3]: TensorBox(StorageBox(
    InputBuffer(name='getitem', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[4]: TensorBox(StorageBox(
    DonatedBuffer(name='getitem_1', layout=FixedLayout('cuda:0', torch.float32, size=[32, 16, 1024], stride=[16384, 1024, 1]))
  ))
  args[5]: TensorBox(StorageBox(
    InputBuffer(name='tangents_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 131072, 128, 1]))
  ))
  args[6]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.constant(0, torch.float32)
          return tmp0
      ,
      ranges=[32, 16, 1024],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    )
  ))
  args[7]: Subgraph(name='fw_graph0', graph_module=<lambda>(), graph=None)
  args[8]: Subgraph(name='joint_graph0', graph_module=<lambda>(), graph=None)
  args[9]: (1024, 1024, TensorBox(StorageBox(
    InputBuffer(name='primals_5', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_4', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_7', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_8', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_10', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_11', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_12', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), 128, 128, Subgraph(name='mask_graph0', graph_module=<lambda>(), graph=None))
  args[10]: 0.08838834764831843
  args[11]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}
  args[12]: ()
  args[13]: (TensorBox(StorageBox(
    InputBuffer(name='primals_6', layout=FixedLayout('cuda:0', torch.int64, size=[32, 1024], stride=[1024, 1]))
  )),)
Traceback (most recent call last):
  File "/root/prime/src/zeroband/train.py", line 589, in <module>
    raise e
  File "/root/prime/src/zeroband/train.py", line 581, in <module>
    train(config)
  File "/root/prime/src/zeroband/train.py", line 352, in train
    loss.backward()
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1710, in backward
    return impl_fn()
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1700, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2037, in _backward_impl
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 489, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 54, in _wrapped_bw_compiler
    return disable(disable(bw_compiler_fn)(*args, **kwargs))
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1808, in bw_compiler
    return inner_compile(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 569, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 675, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 979, in codegen_and_compile
    graph.run(*example_inputs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/graph.py", line 855, in run
    return super().run(*args)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1496, in run_node
    result = super().run_node(n)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 230, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1143, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1133, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 409, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 2361, in flex_attention_backward
    broadcasted_grad_key = autotune_select_algorithm(
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1909, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
  File "/root/prime/.venv/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1379, in __call__
    raise NoValidChoicesError(
torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
  target: flex_attention_backward
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[3]: TensorBox(StorageBox(
    InputBuffer(name='getitem', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 128, 2048, 1]))
  ))
  args[4]: TensorBox(StorageBox(
    DonatedBuffer(name='getitem_1', layout=FixedLayout('cuda:0', torch.float32, size=[32, 16, 1024], stride=[16384, 1024, 1]))
  ))
  args[5]: TensorBox(StorageBox(
    InputBuffer(name='tangents_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32, 16, 1024, 128], stride=[2097152, 131072, 128, 1]))
  ))
  args[6]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.constant(0, torch.float32)
          return tmp0
      ,
      ranges=[32, 16, 1024],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    )
  ))
  args[7]: Subgraph(name='fw_graph0', graph_module=<lambda>(), graph=None)
  args[8]: Subgraph(name='joint_graph0', graph_module=<lambda>(), graph=None)
  args[9]: (1024, 1024, TensorBox(StorageBox(
    InputBuffer(name='primals_5', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_4', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_7', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_8', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_10', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_11', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8], stride=[8, 8, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_12', layout=FixedLayout('cuda:0', torch.int32, size=[32, 1, 8, 8], stride=[64, 64, 8, 1]))
  )), 128, 128, Subgraph(name='mask_graph0', graph_module=<lambda>(), graph=None))
  args[10]: 0.08838834764831843
  args[11]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}
  args[12]: ()
  args[13]: (TensorBox(StorageBox(
    InputBuffer(name='primals_6', layout=FixedLayout('cuda:0', torch.int64, size=[32, 1024], stride=[1024, 1]))
  )),)
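For isolating the bug, a standalone repro along these lines might help (untested sketch; the causal mask_mod is a placeholder, not necessarily the mask prime uses, and the shapes are copied from the args dump above):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # Placeholder mask; prime's actual mask_mod may differ.
    return q_idx >= kv_idx

# Shapes taken from the args dump: [32, 16, 1024, 128] in bfloat16.
q, k, v = (
    torch.randn(32, 16, 1024, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
block_mask = create_block_mask(causal, B=32, H=None, Q_LEN=1024, KV_LEN=1024, device="cuda")
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
out.sum().backward()  # the NoValidChoicesError fires during this backward lowering

If this fails on torch 2.6 but passes on 2.5.x, that would point at the flex attention inductor lowering rather than at anything in prime.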

Signed-off-by: sami jaghouar <[email protected]>