TTNN pad does not support last dimension #12896

Comments
@esmalTT Can you give an example test? I have been working on this for a bit and I'm actually now confused about how you obtained a sharded tensor of that shape in the first place :)
I'm away from my computer until next week (so sadly I can't run this), but this is what I need for UNet:

```python
import pytest
import torch
import ttnn


def test_unet_pad(device, use_program_cache):
    x = ttnn.from_torch(torch.rand([1, 1, 337920, 4]), dtype=ttnn.bfloat16)
    sharded_memory_config = ttnn.create_sharded_memory_config(
        [1, 1, 337920, 4], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
    )
    x = ttnn.to_device(x, device, sharded_memory_config)
    x = ttnn.pad(x, ((0, 0), (0, 0), (0, 0), (0, 12)))  # pad up to 16
```
This is perfect, thanks @esmalTT :)
Hey @esmalTT, @bbradelTT ran into this problem too. @jaykru-tt is actively looking at this and should have something for next week.
@ntarafdar @bbradelTT I was never able to come up with a workaround; we ended up avoiding the need for this entirely by moving to channels-first instead of channels-last.
Thanks @esmalTT, that is a workaround. We can just transpose back afterwards.
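For reference, a minimal sketch of that transpose-based workaround. The helper name is hypothetical, the shapes are illustrative, the pad value is assumed to be passed explicitly as in `ttnn.pad(tensor, padding, value)`, and this has not been validated against the sharded config above:

```python
import torch
import ttnn


def pad_channels_via_permute(device, x_torch):
    # Hypothetical helper, not from the codebase.
    # x_torch: [1, 1, H*W, C] with the channels C on the last dim.
    x = ttnn.from_torch(x_torch, dtype=ttnn.bfloat16, device=device)
    x = ttnn.permute(x, (0, 3, 1, 2))                      # -> [1, C, 1, H*W]
    x = ttnn.pad(x, ((0, 0), (0, 12), (0, 0), (0, 0)), 0)  # C: 4 -> 16, now a non-last dim
    x = ttnn.permute(x, (0, 2, 3, 1))                      # -> [1, 1, H*W, 16]
    return x
```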
Update: it turns out the last dimension needs to be a multiple of 2, so the workaround does not work in many cases.
A conversation with @yugaoTT yesterday revealed that we will need new kernels to support last-dim padding. I'm starting on those today and hope to be done by end of day on Friday.
Update on progress from yesterday: I've got the new stick-wise sharded reader written out on paper; I just need to transcribe it to code this morning. I'll also be adding a new program factory for these cases. Fortunately, the writer I've already written should work just fine. After that, the rest of the week will hopefully be devoted to debugging. Thanks all for your patience! :)
Update on current progress: the kernels and the new op path are complete and tested working on a few cases. I need to add a few more tests and do a few low-hanging optimizations, and will start the PR process tomorrow. cc @bbradelTT
### Tickets

- #15511
- #15603 (90% resolved with these changes; to be fully resolved in a future PR)
- #12896

### Problem description

ttnn.pad's RM sharded implementation only supports padding along the non-width dimensions. The row-major implementation is additionally not fully general with respect to the width dimension, so until now there have been no great options for padding along width. In a future PR coming tomorrow, I'll add input massaging code to convert to row-major and shard as needed for input configurations that aren't currently supported by pad.

### What's changed

- Adds new kernels to support padding along the width dimension.
- For pad operations requiring both NCH and width padding, we use a fused op built from the original height-padding kernels and the new width kernels.
- The previous point required extensive refactoring of the host code. I would like eyes on pad.cpp please, @yugaoTT @sminakov-tt.
- Also adds a bunch of common utility functions for working with sharded tensors:
  - A function for easily creating sharded memory configs from C++ (analogous to the Python `create_sharded_memory_config` utility function created by @ntarafdar).
  - A function for locating elements of a shard by their coordinates within the tensor. I've tested this one in the context of this PR, but it didn't end up being necessary in the final implementation.

### Checklist

- [~] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12327681570)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308045581)
- [x] [Device performance regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308046347)
- [ ] Blackhole Post commit (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes

---------

Co-authored-by: tarafdarTT <[email protected]>
@esmalTT this is fixed in main, with tests for a reduced version of your input. Going to close this now :)
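For illustration, a minimal sketch of exercising the now-supported last-dim pad on a height-sharded row-major tensor, modeled on the test earlier in the thread. This is not the exact test merged to main; the reduced shape, function name, and explicit zero pad value (assuming the `ttnn.pad(tensor, padding, value)` form) are assumptions:

```python
import pytest
import torch
import ttnn


def test_pad_last_dim_height_sharded(device):
    # Reduced-height variant of the UNet input above: pad channels 4 -> 16.
    shape = [1, 1, 1056 * 160, 4]  # assumed reduction, not the exact shape used in main
    x = ttnn.from_torch(torch.rand(shape), dtype=ttnn.bfloat16)
    sharded_memory_config = ttnn.create_sharded_memory_config(
        shape, ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
    )
    x = ttnn.to_device(x, device, sharded_memory_config)
    # Pad only the last (channel) dimension with zeros.
    x = ttnn.pad(x, ((0, 0), (0, 0), (0, 0), (0, 12)), 0)
    assert x.shape[3] == 16
```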
Summary
Pad on device is required to improve end-to-end performance of UNet Shallow. The sharded input tensor needs to be padded from 4 -> 16 channels.
Running a pad on a device tensor of shape `{1, 1, 2 * 1056 * 160, 4}` to `{1, 1, 2 * 1056 * 160, 16}` throws an error.