
TTNN pad does not support last dimension #12896

Closed
esmalTT opened this issue Sep 19, 2024 · 12 comments

@esmalTT
Contributor

esmalTT commented Sep 19, 2024

Summary

Pad on device is required to improve end-to-end performance of UNet Shallow. The sharded input tensor needs to be padded from 4 -> 16 channels.

Running a pad on a device tensor of shape {1, 1, 2 * 1056 * 160, 4} to {1, 1, 2 * 1056 * 160, 16} throws the following error:

sharded pad does not support pad on last dim currently as that will cause perf degradation
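
For reference, a minimal sketch of the padding we need, expressed with plain PyTorch on host (torch.nn.functional.pad here only illustrates the intended result, not the device path):

import torch
import torch.nn.functional as F

# Channels-last activation as described above: N=1, flattened H*W = 2 * 1056 * 160, C=4
x = torch.rand(1, 1, 2 * 1056 * 160, 4)

# Pad the last (channel) dimension from 4 up to 16 with zeros.
# F.pad pads the last dimension first: (left, right).
y = F.pad(x, (0, 12), value=0.0)
assert y.shape == (1, 1, 2 * 1056 * 160, 16)
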
@jaykru-tt
Contributor

@esmalTT Can you give an example test? I have been working on this for a bit and I'm actually now confused how you obtained a sharded tensor of that shape in the first place :)

@esmalTT
Contributor Author

esmalTT commented Sep 26, 2024

> @esmalTT Can you give an example test? I have been working on this for a bit and I'm actually now confused how you obtained a sharded tensor of that shape in the first place :)

I’m away from my computer until next week (so sadly I can’t run this), but this is what I need for UNet:

import pytest
import torch
import ttnn

def test_unet_pad(device, use_program_cache):
    x = ttnn.from_torch(torch.rand([1, 1, 337920, 4]), dtype=ttnn.bfloat16)

    sharded_memory_config = ttnn.create_sharded_memory_config(
        [1, 1, 337920, 4], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
    )

    x = ttnn.to_device(x, device, sharded_memory_config)

    x = ttnn.pad(x, ((0, 0), (0, 0), (0, 0), (0, 12)), 0)  # pad the channel dim up to 16; pad value of 0 assumed

@jaykru-tt
Contributor

This is perfect, thanks @esmalTT :)

@ntarafdar
Contributor

Hey @esmalTT, @bbradelTT ran into this problem too. @jaykru-tt is actively looking at this and should have something for next week.
Do you have a workaround for this? If so, could you comment on it so @bbradelTT can see it too?

@esmalTT
Contributor Author

esmalTT commented Nov 28, 2024

> Hey @esmalTT, @bbradelTT ran into this problem too. @jaykru-tt is actively looking at this and should have something for next week. Do you have a workaround for this? If so, could you comment on it so @bbradelTT can see it too?

@ntarafdar @bbradelTT I was never able to come up with a workaround; we ended up avoiding the need for this entirely by moving to channels-first instead of channels-last.
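
To illustrate the idea (a sketch of the layout change only, not the actual UNet code): once the activation is channels-first, the channel count sits in dimension 1 rather than the last dimension, so growing it from 4 to 16 no longer touches the last dim that sharded pad rejects.

import torch
import torch.nn.functional as F

# Channels-last view as in the original report: [1, 1, H*W, C] with C = 4
x_cl = torch.rand(1, 1, 1056 * 160, 4)

# Channels-first view: [N, C, H, W]; the channel dim is now dim 1, not the last dim.
x_cf = x_cl.reshape(1, 1056, 160, 4).permute(0, 3, 1, 2)

# Growing C from 4 to 16 now pads dim 1 only.
# F.pad pads from the last dim backwards: (W_left, W_right, H_left, H_right, C_left, C_right).
y = F.pad(x_cf, (0, 0, 0, 0, 0, 12), value=0.0)
assert y.shape == (1, 16, 1056, 160)
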

@bbradelTT
Contributor

Thanks @esmalTT, that is a workaround. We can transpose, pad, and then transpose back afterwards:

import torch
import ttnn

def test_pad_transpose(device):
    input_shape = (1, 1, 12, 8)
    a = torch.ones(input_shape)
    b = ttnn.from_torch(a, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    b2 = ttnn.to_layout(b, layout=ttnn.ROW_MAJOR_LAYOUT)
    b3 = ttnn.transpose(b2, 1, 3)                       # move the last dim away from the end
    c = ttnn.pad(b3, [1, 32, 32, 32], [0, 0, 0, 0], 3)  # pad to the target shape with value 3
    c2 = ttnn.transpose(c, 1, 3)                        # transpose back
    c3 = ttnn.slice(c2, [0, 0, 0, 0], [1, 1, 32, 32])   # trim the padding introduced on the other dims
    c4 = ttnn.to_layout(c3, layout=ttnn.TILE_LAYOUT)
    print(f'b2 {b2}\nb3 {b3}\nc {c}\nc2 {c2}\nc3 {c3}\nc4 {c4}')

outputs

...
c4 ttnn.Tensor([[[[ 1.00000,  1.00000,  ...,  3.00391,  3.00391],
               [ 1.00000,  1.00000,  ...,  3.00391,  3.00391],
               ...,
               [ 3.00391,  3.00391,  ...,  3.00391,  3.00391],
               [ 3.00391,  3.00391,  ...,  3.00391,  3.00391]]]], shape=Shape([1, 1, 32, 32]), dtype=DataType::FLOAT32, layout=Layout::TILE)

@bbradelTT
Contributor

Update: it turns out the last dimension needs to be a multiple of 2, so the workaround does not work in many cases.
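
For example (a hypothetical case using the same API as the snippet above; exactly which call enforces the restriction is an assumption), an input whose last dimension is odd cannot go through the transpose path:

import torch
import ttnn

def test_pad_transpose_odd_width(device):
    # Last dimension is 7 (odd), so the transpose-based workaround above is
    # expected to trip over the multiple-of-2 requirement on the last dimension.
    a = torch.ones((1, 1, 12, 7))
    b = ttnn.from_torch(a, dtype=ttnn.float32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    b2 = ttnn.transpose(b, 1, 3)  # assumption: this transpose is where the even-width requirement bites
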

@bbradelTT
Contributor

Since there's no workaround and this is blocking P0s (#13621 and #14530), along with #15511, I'm moving this to a P0 as well.

@bbradelTT bbradelTT added P0 and removed P1 labels Dec 2, 2024
@jaykru-tt
Contributor

A conversation with @yugaoTT yesterday revealed that we will need new kernels to support last-dim padding. I'm starting on those today and hope to be done by end of day on Friday.

@jaykru-tt
Contributor

Update on progress from yesterday: I've got the new stick-wise sharded reader written out on paper. I just need to transcribe it to code this morning. I'll also be adding a new program factory for these cases. Fortunately the writer I've already written should work just fine. After that, the rest of the week will hopefully be devoted to debugging. Thanks all for your patience! :)

@jaykru-tt
Contributor

Update on current progress: the kernels and the new op path are complete and tested on a few cases. I need to add a few more tests and do a few low-hanging optimizations, and then I will start the PR process tomorrow. cc @bbradelTT

jaykru-tt added a commit that referenced this issue Dec 14, 2024
### Tickets
- #15511
- #15603 (90% resolved with these changes and to be fully resolved in a future PR)
- #12896

### Problem description
ttnn.pad's RM sharded implementation only supports padding along the non-width dimensions. The row-major implementation additionally is not fully general with respect to the width dimension, so until now there have been no great options for padding along width. In a future PR coming tomorrow, I'll add input massaging code to convert to row-major and shard as needed for input configurations that aren't currently supported by pad.

### What's changed
- Adds new kernels to support padding along the width dimension.
- For pad operations requiring both NCH and width padding, we use a fused op using the original height-padding kernels and the new width kernels.
- The previous point required extensive refactoring to the host code. I would like eyes on pad.cpp please @yugaoTT @sminakov-tt.
- Also adds a bunch of common utility functions for working with sharded tensors:
  - A function for easily creating sharded memory configs from C++ (analogous to the Python `create_sharded_memory_config` utility function created by @ntarafdar)
  - A function for locating elements of a shard by their coordinates within the tensor. I've tested this one in the context of this PR, but it didn't end up being necessary in the final implementation.

### Checklist
- [~] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12327681570)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308045581)
- [x] [Device performance regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308046347)
- [ ] Blackhole Post commit (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes
- [ ] New/Existing tests provide coverage for changes

---------

Co-authored-by: tarafdarTT <[email protected]>
@jaykru-tt
Contributor

@esmalTT this is fixed with tests for a reduced version of your input in main. Going to close this now :)
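
For reference, a minimal sketch of the now-supported call pattern, adapted from the original reproduction above (the reduced shape and the zero pad value are assumptions, not the actual regression test in main):

import torch
import ttnn

def test_unet_pad_reduced(device):
    # Reduced height (2048 rows, divisible across an 8x8 grid) standing in for
    # the original 337920-row UNet activation.
    x = ttnn.from_torch(torch.rand([1, 1, 2048, 4]), dtype=ttnn.bfloat16)

    sharded_memory_config = ttnn.create_sharded_memory_config(
        [1, 1, 2048, 4], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
    )
    x = ttnn.to_device(x, device, sharded_memory_config)

    # Pad the last (channel) dimension from 4 up to 16 with zeros.
    x = ttnn.pad(x, ((0, 0), (0, 0), (0, 0), (0, 12)), 0)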
