Try lowering aten.nll_loss_forward to ttnn.moreh_nll_loss #676

Open
wants to merge 3 commits into
base: main
@jdh8 jdh8 commented Dec 24, 2024

Ticket

Problem description

I forgot how to dig into this kind of error:

FAILED tests/lowering/misc/test_nll_loss.py::test_nll_loss[input_shape0-True-mean--100] - torch._dynamo.exc.BackendCompilerFailed: backend='ttnn_backend' raised:
FAILED tests/lowering/misc/test_nll_loss.py::test_nll_loss[input_shape1-False-mean--100] - torch._dynamo.exc.BackendCompilerFailed: backend='ttnn_backend' raised:

What's changed

Describe the approach used to solve the problem.
Summarize the changes made and their impact

@jdh8 jdh8 self-assigned this Dec 24, 2024
jdh8 commented Dec 26, 2024

I followed the troubleshooting steps in https://github.com/tenstorrent/pytorch2.0_ttnn/blob/main/docs/ProblemSolving.md, but I don't know what the result implies.

(python_env) jdh8@tt-loudbox:~/pytorch2.0_ttnn$ pytest tests/lowering/misc/test_nll_loss.py --trace
=========================================================== test session starts ============================================================
platform linux -- Python 3.8.10, pytest-7.2.2, pluggy-1.5.0
rootdir: /home/jdh8/pytorch2.0_ttnn/tests, configfile: pytest.ini
plugins: timeout-2.2.0, split-0.8.2, dash-2.15.0, xdist-3.6.1, anyio-4.5.2
collected 2 items                                                                                                                          

tests/lowering/misc/test_nll_loss.py 
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB runcall (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /home/jdh8/pytorch2.0_ttnn/tests/lowering/misc/test_nll_loss.py(25)test_nll_loss()
-> module = NllLossModule()
(Pdb) b ../tt-metal/python_env/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py:296
Breakpoint 1 at /home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py:296
(Pdb) r
> /home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py(296)__call__()
-> raise Exception(msg) from e
(Pdb) interact
*interactive*
>>> b ../tt-metal/python_env/lib/python3.8/site-packages/torch/fx/passes/infra/pass_manager.py:296
KeyboardInterrupt
>>> import pdb
>>> pdb.post_mortem(e.__traceback__)
> /home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/graph.py(1259)_target_to_str()
-> op = target.__name__
(Pdb)
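
The breakpoint in the session above sits on `raise Exception(msg) from e`, so the original error stays reachable as `__cause__` of the wrapper exception. Below is a minimal sketch of walking that chain to its root using only the standard library. `exception_chain` and `run_pass` are hypothetical names, and the `AttributeError` raised here is a stand-in, not the error actually produced in this PR.

```python
def exception_chain(exc):
    """Return [exc, its cause, the cause's cause, ...], outermost first."""
    chain = []
    seen = set()
    while exc is not None and id(exc) not in seen:
        seen.add(id(exc))
        chain.append(exc)
        # Prefer the explicit cause (`raise ... from e`), else the implicit context.
        exc = exc.__cause__ if exc.__cause__ is not None else exc.__context__
    return chain


def run_pass():
    # Hypothetical stand-in for a failing pass; mimics `raise Exception(msg) from e`.
    try:
        raise AttributeError("stand-in for the original failure")
    except AttributeError as e:
        raise Exception("An error occurred when running the pass") from e


try:
    run_pass()
except Exception as wrapped:
    chain = exception_chain(wrapped)
    root = chain[-1]
    for exc in chain:
        print(type(exc).__name__, "-", exc)
    # In a live session: import pdb; pdb.post_mortem(root.__traceback__)
```

Calling `pdb.post_mortem` on the root's traceback lands in the frame that raised first, which is what the `interact` step in the session does by hand.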

kevinwuTT commented Dec 26, 2024

This issue looks identical to tenstorrent/tt-metal#9681, even down to the same line in PyTorch. This was fixed in tt-metal. We might need the same for moreh ops.
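
The post-mortem above stops at `op = target.__name__`. One way that line can fail is when the target is a callable object rather than a plain function, because instances do not carry `__name__` by default. The sketch below reproduces that situation in plain Python. `RawOp`, `NamedOp`, and `target_to_str` are hypothetical names. Whether this is the exact cause here, and how tenstorrent/tt-metal#9681 was resolved, is not shown in this thread.

```python
class RawOp:
    """Hypothetical callable object standing in for an op without `__name__`."""

    def __call__(self, x):
        return x + 1


class NamedOp:
    """Hypothetical wrapper that forwards calls and exposes `__name__`."""

    def __init__(self, op, name):
        self._op = op
        self.__name__ = name

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs)


def target_to_str(target):
    # Same attribute access as the failing line in the traceback.
    return target.__name__


raw = RawOp()
try:
    target_to_str(raw)
    raw_failed = False
except AttributeError:
    # Instances of ordinary classes have no `__name__` attribute.
    raw_failed = True

wrapped = NamedOp(raw, "moreh_nll_loss")
print(raw_failed, target_to_str(wrapped), wrapped(1))
```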

jdh8 added 2 commits December 27, 2024 03:02
1. Isn't the default divisor 1?
2. We arrive at a segfault
```
tests/lowering/misc/test_nll_loss.py Fatal Python error: Segmentation fault

Thread 0x00007f1a8a7e4700 (most recent call first):
  File "/usr/lib/python3.8/threading.py", line 306 in wait
  File "/usr/lib/python3.8/threading.py", line 558 in wait
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

Current thread 0x00007f1ba3a19740 (most recent call first):
  File "/home/jdh8/tt-metal/ttnn/ttnn/decorators.py", line 329 in __call__
  File "/home/jdh8/tt-metal/ttnn/ttnn/operations/core.py", line 233 in from_torch
  File "/home/jdh8/tt-metal/ttnn/ttnn/decorators.py", line 329 in __call__
  File "<eval_with_key>.15", line 8 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/graph_module.py", line 304 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738 in call_wrapped
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118 in rng_functionalization_wrapper
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105 in call_func_at_runtime_with_args
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94 in runtime_wrapper
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 901 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17 in inner
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/misc/test_nll_loss.py", line 13 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/misc/test_nll_loss.py", line 39 in test_nll_loss
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/python.py", line 1789 in runtest
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 339 in from_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 324 in _main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/home/jdh8/tt-metal/python_env/bin/pytest", line 8 in <module>
Segmentation fault (core dumped)
```
@@ -192,12 +192,12 @@ def is_tt_compute(node) -> bool:
ttnn.zeros_like,
ttnn.mean,
ttnn.moreh_cumsum,
ttnn.moreh_nll_loss,
Contributor Author


Thanks! Using the ttnn.moreh_* wrappers works!

input, target, weight, reduction, ignore_index = args
args = input, target, ("none", "mean", "sum")[reduction]
kwargs = {
"divisor_tensor": torch.tensor([0], dtype=get_dtype(input)),
@jdh8 jdh8 Dec 27, 2024

  1. Isn't the default divisor 1?
  2. We arrive at a segfault (even after I tried 1 instead of 0)
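
On the divisor question: in PyTorch's own semantics, the `mean` reduction of NLL loss divides by the summed weights of the non-ignored targets, so the divisor is in general neither 0 nor 1. The pure-Python reference below is a sketch of those semantics for a 2-D input, using the same integer-to-string reduction mapping as the diff. `nll_loss_reference` and `mean_divisor` are hypothetical helper names, and how `ttnn.moreh_nll_loss` interprets `divisor_tensor` is not established by this thread.

```python
import math

# Same index-to-name mapping as in the diff: reduction 0, 1, 2.
REDUCTIONS = ("none", "mean", "sum")


def mean_divisor(target, weight, ignore_index=-100):
    """Sum of class weights over targets that are not ignored."""
    return sum(weight[t] for t in target if t != ignore_index)


def nll_loss_reference(log_probs, target, weight=None, reduction=1, ignore_index=-100):
    """Reference NLL loss for log_probs[n][c] and integer target[n]."""
    if weight is None:
        weight = [1.0] * len(log_probs[0])
    # Ignored targets contribute zero loss and zero weight.
    losses = [
        0.0 if t == ignore_index else -weight[t] * row[t]
        for row, t in zip(log_probs, target)
    ]
    mode = REDUCTIONS[reduction]
    if mode == "none":
        return losses
    if mode == "sum":
        return sum(losses)
    divisor = mean_divisor(target, weight, ignore_index)
    # With every target ignored the mean is undefined; return NaN.
    return sum(losses) / divisor if divisor else math.nan


log_probs = [[-0.5, -1.0], [-2.0, -0.25], [-1.0, -1.0]]
target = [0, 1, -100]  # the last sample is ignored
print(nll_loss_reference(log_probs, target, reduction=1))
```

Comparing the lowered op against a reference like this for each reduction mode may help separate a wrong divisor from the unrelated segfault in `from_torch`.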

@jdh8 jdh8 changed the title Try lowering aten.nll_loss_forward to ttnn.operations.moreh.nll_loss Try lowering aten.nll_loss_forward to ttnn.moreh_nll_loss Dec 30, 2024
Successfully merging this pull request may close these issues.

aten.nll_loss_forward.default