[DRAFT] [WIP] Path exploration for TT-NN x TT-Mesh Integration #18067
base: main
Conversation
@@ -614,7 +606,7 @@ Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) {

    mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, /*blocking=*/true);

-   MultiDeviceHostStorage host_storage(storage.strategy, std::move(buffers), std::move(specs));
+   MultiDeviceHostStorage host_storage(AllGatherTensor{}, std::move(buffers), std::move(specs));
Fyi, on another branch I am attempting to get rid of "strategy". I think we only need a shape to track which devices the tensor was uploaded to. It can be a full mesh shape or a submesh.
that would be great!
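A rough sketch of the direction suggested above: record only the (sub)mesh shape the tensor was uploaded to instead of a "strategy". The `MeshShape`, `OwnedBuffer`, and `TensorSpec` definitions below are stand-ins so the sketch compiles on its own, and the constructor and member names are hypothetical, not the actual tt-metal API.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-ins so the sketch is self-contained; the branch would use the
// existing tt-metal types instead.
struct OwnedBuffer {};
struct TensorSpec {};

// A (sub)mesh shape is enough to record which devices the tensor was
// uploaded to: either the full mesh shape or a submesh shape.
struct MeshShape {
    uint32_t num_rows = 0;
    uint32_t num_cols = 0;
};

class MultiDeviceHostStorage {
public:
    MultiDeviceHostStorage(MeshShape distribution_shape,
                           std::vector<OwnedBuffer> buffers,
                           std::vector<TensorSpec> specs)
        : distribution_shape_(distribution_shape),
          buffers_(std::move(buffers)),
          specs_(std::move(specs)) {}

    const MeshShape& distribution_shape() const { return distribution_shape_; }

private:
    MeshShape distribution_shape_;  // replaces the per-tensor "strategy"
    std::vector<OwnedBuffer> buffers_;
    std::vector<TensorSpec> specs_;
};
```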
    std::vector<T> host_buffer;
-   const auto& shard_tensor_spec = storage.specs.at(id);
+   const auto& shard_tensor_spec = tensor.get_tensor_spec();
We will need the per-shard information, won't we? What is your plan here?
Being pragmatic for now: I'll add support for it if it's really needed. Let's see how far I can get first.
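If per-shard specs do turn out to be needed, one hypothetical way to reintroduce them is a small fallback helper. This is only an illustration, not part of the branch; it assumes the `storage.specs` container and `Tensor::get_tensor_spec()` accessor seen in the diff above.

```cpp
#include <cstddef>

// Illustrative helper (not in the branch): prefer a per-shard spec when the
// storage still carries one, otherwise fall back to the tensor-level spec.
const TensorSpec& get_shard_spec(const MultiDeviceHostStorage& storage,
                                 const Tensor& tensor,
                                 std::size_t shard_id) {
    if (shard_id < storage.specs.size()) {
        // Per-shard information, e.g. for uneven sharding across devices.
        return storage.specs.at(shard_id);
    }
    // Uniform shards: the tensor-level spec applies to every shard.
    return tensor.get_tensor_spec();
}
```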
Force-pushed from 7441d1b to 560ad32
@@ -108,6 +80,7 @@ class Tensor {
    explicit Tensor(
        uint32_t num_buffers, std::optional<DistributedTensorConfig> distributed_tensor_config = std::nullopt);
    explicit Tensor(const std::vector<IDevice*>& workers);
can be removed?
Force-pushed from 71a3899 to bae493e
This reverts commit bae493e.
### Ticket
#18360

### Problem description
Recently we disabled async mode for single devices by ignoring the enable_async call for them, assuming multi-device customers call MeshDevice::enable_async. However, in some places, including our tests, we iterate over each individual device in the mesh and call enable_async on it, and that call is now ignored.

### What's changed
Make a single call to MeshDevice::enable_async instead of iterating over individual devices and calling Device::enable_async for each of them (see the sketch after this description).

### Checklist
- [ ] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553947437)
- [x] [T3K demo tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553950838)
- [x] New/Existing tests provide coverage for changes

(cherry picked from commit 69a36b8)
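A minimal before/after sketch of the change described above. It assumes a `mesh_device` handle and the MeshDevice::enable_async / Device::enable_async calls named in the description; the per-device loop is illustrative of the pattern being replaced.

```cpp
// Before (illustrative): enabling async mode per device in the mesh.
// Device::enable_async is now ignored for single devices, so this is a no-op.
// for (auto* device : mesh_device->get_devices()) {
//     device->enable_async(true);
// }

// After: a single call on the mesh device itself.
mesh_device->enable_async(true);
```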
lots of fun TMP
…-with-mesh-rebase
### Ticket
Link to Github Issue

### Problem description
Provide context for the problem.

### What's changed

### Checklist