Skip to content

Commit

Permalink
Fix shape in outer (#17492)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue
#16882

### Problem description
 ttnn::outer fails after tilizing the inputs

### What's changed
outer op is checking the padded size of the inputs which is causing the
error. This PR changes the shape used in outer

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/13167635235
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
nardoTT authored Feb 18, 2025
1 parent d277980 commit 2958cac
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ Tensor _scatter(const Tensor& input_a, const Tensor& input_b, const std::optiona
* by running reshape.
*/
Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
const ttnn::Shape s_a = input_a.padded_shape();
const ttnn::Shape s_b = input_b.padded_shape();
const ttnn::Shape s_a = input_a.get_logical_shape();
const ttnn::Shape s_b = input_b.get_logical_shape();
auto num_ones = [](const ttnn::Shape& s) -> uint32_t {
uint32_t num1s = 0;
for (uint32_t idx = 0; idx < 4; idx++) {
Expand All @@ -497,10 +497,12 @@ Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional<
Tensor b_slim = input_b;

if (!skip_reshape_a) {
a_slim = ttnn::reshape(input_a, ttnn::Shape{std::array<uint32_t, 4>{1, 1, input_a.volume(), 1}});
uint32_t a_volume = s_a[0] * s_a[1] * s_a[2] * s_a[3];
a_slim = ttnn::reshape(input_a, ttnn::Shape{std::array<uint32_t, 4>{1, 1, a_volume, 1}});
}
if (!skip_reshape_b) {
b_slim = ttnn::reshape(input_b, ttnn::Shape{std::array<uint32_t, 4>{1, 1, 1, input_b.volume()}});
uint32_t b_volume = s_b[0] * s_b[1] * s_b[2] * s_b[3];
b_slim = ttnn::reshape(input_b, ttnn::Shape{std::array<uint32_t, 4>{1, 1, 1, b_volume}});
}
a_slim = ttnn::to_layout(a_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (IDevice*)nullptr);
b_slim = ttnn::to_layout(b_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (IDevice*)nullptr);
Expand Down

0 comments on commit 2958cac

Please sign in to comment.