Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#17477: Introduce ND coordinate system for TT-distributed #17745

Merged
merged 10 commits into from
Feb 16, 2025
Merged

Conversation

omilyutin-tt
Copy link
Contributor

@omilyutin-tt omilyutin-tt commented Feb 7, 2025

Ticket

#17477

Problem description

Existing mesh infra assumes 2D. This assumption won't hold in the future.

What's changed

Introduce a new SimpleMeshShape that will gradually replace the existing MeshShape, after which it will be renamed to MeshShape.

Introduce MeshCoordinate, MeshCoordinateRange, and MeshContainer - primitives designed to work with the new ND coordinate system.

MeshContainer allows efficient flat representation of various metadata that matches the mesh shape. Iterators are available to make it easy to use. MeshCoordinate along with strides that are precomputed on SimpleMeshShape allows for an easy point access. The integration with MeshBuffer demonstrates the use case.

Next steps:

  • Replace the existing MeshShape, MeshOffset, and the related aliases with the new SimpleMeshShape, and MeshCoordinate.
  • No plans to generalize with CoreCoord, for now. Cores are fundamentally in 2D, so a more specialized system can be used for efficiency. Also it is not desired to make CoreCoord to interop with MeshCoordinate - the 2 sets of coordinates mean entirely different concepts.
  • More functionality might be added, as we continue working on TT-distributed.

Checklist

bool eq_spans(const ArrayType& a, const ArrayType& b) {
return std::equal(a.begin(), a.end(), b.begin(), b.end());
}
bool eq_spans(const auto a, const auto b) { return std::equal(a.begin(), a.end(), b.begin(), b.end()); }
Copy link
Contributor Author

@omilyutin-tt omilyutin-tt Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to do this because a is now tt::stl::Span while b is std::span. Annoying, but this keeps Metal at cpp17.

Copy link
Contributor

@tt-asaigal tt-asaigal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great overall. One thing to keep in mind is that our existing 2D coordinate system in Metal (exposed through CoreCoord, CoreRange and CoreRangeSet) provides a bunch of utility functions, allowing users to compute set/range intersections, adjacency, etc.
It would be very useful for us to expose similar APIs for this new ND coordinate system as well. Especially as we start introducing more heterogeneity in our workloads.

mesh_device_->num_cols());
return buffers_[device_coord.row][device_coord.col];
return get_device_buffer(MeshCoordinate(device_coord.row, device_coord.col));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a comment saying that this overload is kept around to be compatible with existing infra and will be removed once everything is migrated to use MeshCoordinate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline that the overload itself is useful to keep as a shorthand (in case user just wants to use 2D cluster). But we will drop the "num rows" "num cols" terminology, and the infra will be generic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for MeshCoordinate overload - the std::shared_ptr<Buffer> MeshBuffer::get_device_buffer(const Coordinate& device_coord) will go away entirely.

@@ -218,7 +218,11 @@ std::vector<IDevice*> MeshDevice::get_devices() const { return view_->get_device

// TODO: Remove this function once we have a proper view interface
IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const {
return this->get_device_index(row_idx * num_cols() + col_idx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I don't see why we need this overload to expose physical devices using [row, col] in the long term. This doesn't make sense for an ND mesh anyway

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep all of this will go away.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we should remove this in another PR otherwise it'll be confusing with (y, x) vs. (x,y). Actually - is this a bug? Should it be MeshCoordinate{col_idx, row_idx}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of rows/cols terminology, this logic should be consistent with itself -- we end up computing:

linear_index += coord[dim] * shape.get_stride(dim);

Over all dims. Since I initialize the shape with [num_rows, num_cols] and coord with [row_idx, col_idx], I am effectively getting the same thing.

Unless I am misreading something :)

}

MeshCoordinate::MeshCoordinate(uint32_t coord) : value_({coord}) {}
MeshCoordinate::MeshCoordinate(uint32_t row, uint32_t col) : value_({row, col}) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of [row, col] should be removed eventually. I expect that users will be extremely confused by two different notations being exposed by the same data structure.
Fundamentally, removing this requires MeshDevice to start using a Cartesian scheme.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use x y as discussed offline.

}

MeshCoordinateRange::MeshCoordinateRange(const SimpleMeshShape& shape) :
MeshCoordinateRange(zero_coordinate(shape.dims()), shape_back(shape)) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Contributor

@jvegaTT jvegaTT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving on pad file changes

Copy link
Contributor

@cfjchu cfjchu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great -- left a few comments to resolve before you merge

@@ -218,7 +218,11 @@ std::vector<IDevice*> MeshDevice::get_devices() const { return view_->get_device

// TODO: Remove this function once we have a proper view interface
IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const {
return this->get_device_index(row_idx * num_cols() + col_idx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we should remove this in another PR otherwise it'll be confusing with (y, x) vs. (x,y). Actually - is this a bug? Should it be MeshCoordinate{col_idx, row_idx}?

size_t get_stride(size_t dim) const;

// Returns the total number of elements in the mesh.
size_t mesh_size() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just size()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my thinking initially. There is a size() member from ShapeBase that refers to the dimensionality of the shape, so I wanted to be explicit here.

// Throws if `coord` is out of bounds of `shape`.
size_t to_linear_index(const SimpleMeshShape& shape, const MeshCoordinate& coord);

// Represents a range of MeshCoordinates. Requires that mesh coordinates have the same dimensionality.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good just to add a comment saying it's inclusive start/end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have it further down, where I mention specific "start" and "end" (without these a bit hard to describe what exactly does range "include").

Also added end_coord() documentation.

@omilyutin-tt omilyutin-tt merged commit 52c53d5 into main Feb 16, 2025
236 of 247 checks passed
@omilyutin-tt omilyutin-tt deleted the omilyutin/nd branch February 16, 2025 01:55
hschoi4448 pushed a commit that referenced this pull request Feb 20, 2025
### Ticket
#17477 

### Problem description
Existing mesh infra assumes 2D. This assumption won't hold in the
future.

### What's changed
Introduce a new `SimpleMeshShape` that will gradually replace the
existing `MeshShape`, after which it will be renamed to `MeshShape`.

Introduce `MeshCoordinate`, `MeshCoordinateRange`, and `MeshContainer` -
primitives designed to work with the new ND coordinate system.

`MeshContainer` allows efficient flat representation of various metadata
that matches the mesh shape. Iterators are available to make it easy to
use. `MeshCoordinate` along with strides that are precomputed on
`SimpleMeshShape` allows for an easy point access. The integration with
`MeshBuffer` demonstrates the use case.

Next steps:
* Replace the existing `MeshShape`, `MeshOffset`, and the related
aliases with the new `SimpleMeshShape`, and `MeshCoordinate`.
* No plans to generalize with `CoreCoord`, for now. Cores are
fundamentally in 2D, so a more specialized system can be used for
efficiency. Also it is not desired to make `CoreCoord` to interop with
`MeshCoordinate` - the 2 sets of coordinates mean entirely different
concepts.
* More functionality might be added, as we continue working on
TT-distributed.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13347753550)
- [X] New/Existing tests provide coverage for changes
TT-billteng pushed a commit that referenced this pull request Feb 21, 2025
### Ticket
#17477 

### Problem description
Existing mesh infra assumes 2D. This assumption won't hold in the
future.

### What's changed
Introduce a new `SimpleMeshShape` that will gradually replace the
existing `MeshShape`, after which it will be renamed to `MeshShape`.

Introduce `MeshCoordinate`, `MeshCoordinateRange`, and `MeshContainer` -
primitives designed to work with the new ND coordinate system.

`MeshContainer` allows efficient flat representation of various metadata
that matches the mesh shape. Iterators are available to make it easy to
use. `MeshCoordinate` along with strides that are precomputed on
`SimpleMeshShape` allows for an easy point access. The integration with
`MeshBuffer` demonstrates the use case.

Next steps:
* Replace the existing `MeshShape`, `MeshOffset`, and the related
aliases with the new `SimpleMeshShape`, and `MeshCoordinate`.
* No plans to generalize with `CoreCoord`, for now. Cores are
fundamentally in 2D, so a more specialized system can be used for
efficiency. Also it is not desired to make `CoreCoord` to interop with
`MeshCoordinate` - the 2 sets of coordinates mean entirely different
concepts.
* More functionality might be added, as we continue working on
TT-distributed.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13347753550)
- [X] New/Existing tests provide coverage for changes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants