Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert aten.split to ttnn.split #195

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Convert aten.split to ttnn.split #195

wants to merge 2 commits into from

Conversation

jdh8
Copy link
Contributor

@jdh8 jdh8 commented Sep 13, 2024

Ticket

None

Problem description

Convert torch.split to ttnn.split

What's changed

  • Convert aten.split to ttnn.split
  • Test the conversion

@jdh8 jdh8 self-assigned this Sep 13, 2024
Even a symmetric split fails with FPE
```
Current thread 0x00007fda41e6a740 (most recent call first):
  File "/home/jdh8/venv/lib/python3.8/site-packages/ttnn/decorators.py", line 326 in __call__
  File "<eval_with_key>.13", line 6 in forward
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/fx/graph_module.py", line 304 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738 in call_wrapped
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118 in rng_functionalization_wrapper
```
@jdh8
Copy link
Contributor Author

jdh8 commented Sep 13, 2024

Even a symmetric split fails with FPE 🧐

tests/lowering/tensor_manipulation/test_split.py Fatal Python error: Floating point exception

Thread 0x00007fd9d9ffb700 (most recent call first):
  File "/usr/lib/python3.8/threading.py", line 306 in wait
  File "/usr/lib/python3.8/threading.py", line 558 in wait
  File "/home/jdh8/venv/lib/python3.8/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

Current thread 0x00007fda41e6a740 (most recent call first):
  File "/home/jdh8/venv/lib/python3.8/site-packages/ttnn/decorators.py", line 326 in __call__
  File "<eval_with_key>.13", line 6 in forward
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/fx/graph_module.py", line 304 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738 in call_wrapped
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118 in rng_functionalization_wrapper
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105 in call_func_at_runtime_with_args
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94 in runtime_wrapper
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 901 in forward
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17 in inner
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/tensor_manipulation/test_split.py", line 13 in forward
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/tensor_manipulation/test_split.py", line 38 in test_split
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/python.py", line 1789 in runtest
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 339 in from_call
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/main.py", line 324 in _main
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/venv/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/home/jdh8/venv/lib/python3.8/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/home/jdh8/venv/bin/pytest", line 8 in <module>
Floating point exception (core dumped)

@jdh8 jdh8 marked this pull request as ready for review September 13, 2024 18:59
@jdh8 jdh8 added the blocked label Sep 13, 2024
@jdh8
Copy link
Contributor Author

jdh8 commented Sep 13, 2024

For now, ttnn.split only supports splitting to two equally sized tensors. I'm thinking of converting unsupported cases to ttnn.slice:

  • aten.split.Tensor that does not split in halves
  • Any aten.split_with_sizes.default

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants