#2065: Updated all reduce code to handle 0 or 1 mesh cluster axis #2215
Conversation
@tapspatel Can you wait for #2149 to be merged? I had to fix some of the all_reduce import code and it is conflicting now. I believe you can apply this change on top of that PR.
Thumbs up for runtime
Comment inline; otherwise the runtime changes look good.
Couple of comments inline.
Force-pushed from 3002451 to a846738
Great follow-up! A few comments, mostly nits, but a few might be bugs, so make sure to check them before merging.
Force-pushed from a846738 to 26fddac
Clang-Tidy found issue(s) with the introduced code (1/1).
Force-pushed from 26fddac to ac11675
Perhaps the commit title should become the commit message and the title could be something a bit more concise?
I think the TTNN interfaces need to change; some other comments are inline, but otherwise this looks good!
```cpp
auto firstElementIt = replicaGroups.begin();
auto secondElementIt = firstElementIt + 1;
clusterAxis = (((*firstElementIt) + 1) == *secondElementIt);
```
Doesn't this assume that device IDs are consecutive and in a particular topology?
From docs: https://docs.jax.dev/en/latest/_autosummary/jax.make_mesh.html
Essentially, JAX determines the mesh device ordering internally from the TPU structure and optimizes it for the mesh shape. When simulating on CPUs, however, it orders devices consecutively. To simplify things, we assume the mesh provided by tt-xla is monotonically increasing, because the compiler doesn't yet have a way to propagate an efficient mesh mapping based on the hardware structure (though this can be done).
I can add this assumption as a comment for now.
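For reference, a minimal sketch of what documenting that assumption could look like, assuming a row-major 2D mesh with device IDs monotonically increasing from 0 to N-1; the helper name inferClusterAxis and its std::vector signature are hypothetical, not the actual tt-mlir code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the assumption discussed above.
// Assumption: the 2D mesh is laid out row-major with device IDs
// monotonically increasing from 0 to N-1 (what tt-xla currently provides).
// A replica group whose first two IDs are consecutive therefore spans the
// fastest-varying mesh dimension (cluster axis 1); otherwise it spans axis 0.
uint32_t inferClusterAxis(const std::vector<int64_t> &replicaGroup) {
  assert(replicaGroup.size() >= 2 && "replica group needs at least two devices");
  assert(replicaGroup[0] < replicaGroup[1] && "device IDs assumed ascending");
  return (replicaGroup[0] + 1 == replicaGroup[1]) ? 1 : 0;
}
```

For example, on a 2x4 mesh with row-major IDs 0-7, the row group {0, 1, 2, 3} maps to cluster axis 1, while the column group {0, 4} maps to cluster axis 0.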
As for topologies: per a conversation with Wooseok, we are only supporting 2D grid topologies in the compiler.
I have observed ascending device IDs from 0 to N-1 so far as well, so I guess it's fine. In the future, we may leverage the device IDs to best distribute data for our hardware config, but we haven't gotten there yet.
Clang-Tidy found issue(s) with the introduced code (1/1).
Looks good, thank you!
Force-pushed from ac11675 to d5d9ebd
Can you address some of my comments/requests?
Force-pushed from d5d9ebd to 959ef8e
Clang-Tidy found issue(s) with the introduced code (1/1).
Force-pushed from 959ef8e to ca39af3
@wooseokTT let me know if the responses to all the comments are sufficient.
Force-pushed from ca39af3 to 0cec5f6
Force-pushed from 0cec5f6 to f02b2fa
#2065: Updated all reduce code to handle 0 or 1 cluster axis and cleaned up dialect representations of all reduce in ttir and ttnn. Updated algorithms for calculating gather and scatter dimensions. Migrated all workaround code into the TTNN workaround pass so that we don't clog up the ttir or ttnn definitions of all_reduce.