
#2065: Updated all reduce code to handle 0 or 1 mesh cluster axis #2215

Merged 1 commit into main on Feb 27, 2025

Conversation

@tapspatel (Collaborator) commented Feb 19, 2025

#2065: Updated the all_reduce code to handle a cluster axis of 0 or 1 and cleaned up the dialect representations of all_reduce in TTIR and TTNN. Updated the algorithms for calculating gather and scatter dimensions. Migrated all workaround code into the TTNN workaround pass so that the TTIR and TTNN definitions of all_reduce are not cluttered.
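
As a rough illustration of what a cluster axis of 0 or 1 means on a 2D device mesh (this is not code from this PR; the helper name and the 2x4 mesh shape are assumptions for the example): the axis selects which mesh dimension the all_reduce spans, so each reduce group is either a column or a row of the mesh, assuming row-major, consecutive device IDs.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: enumerate the device groups an all_reduce would span
// on a 2D mesh, given clusterAxis (1 = reduce along mesh dim 1, so each group
// is a row; 0 = reduce along mesh dim 0, so each group is a column).
// Device IDs are assumed row-major and consecutive, e.g. 0..7 on a 2x4 mesh.
std::vector<std::vector<int64_t>> makeReduceGroups(int64_t meshRows,
                                                   int64_t meshCols,
                                                   int64_t clusterAxis) {
  std::vector<std::vector<int64_t>> groups;
  if (clusterAxis == 1) {
    // Groups are rows: {0,1,2,3}, {4,5,6,7} for a 2x4 mesh.
    for (int64_t r = 0; r < meshRows; ++r) {
      std::vector<int64_t> group;
      for (int64_t c = 0; c < meshCols; ++c)
        group.push_back(r * meshCols + c);
      groups.push_back(std::move(group));
    }
  } else {
    // Groups are columns: {0,4}, {1,5}, {2,6}, {3,7} for a 2x4 mesh.
    for (int64_t c = 0; c < meshCols; ++c) {
      std::vector<int64_t> group;
      for (int64_t r = 0; r < meshRows; ++r)
        group.push_back(r * meshCols + c);
      groups.push_back(std::move(group));
    }
  }
  return groups;
}

int main() {
  // Print the reduce groups for clusterAxis = 1 on a 2x4 mesh.
  for (const auto &group : makeReduceGroups(2, 4, 1)) {
    for (int64_t id : group)
      std::cout << id << ' ';
    std::cout << '\n';
  }
  return 0;
}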

@wooseokTT (Contributor) commented

@tapspatel Can you wait for #2149 to be merged? I had to fix some of the code for importing all_reduce, and it conflicts with this change now. I believe you can apply this change on top of that PR.

@kmabeeTT (Contributor) left a comment

Thumbs up for runtime

@jnie-TT (Contributor) left a comment

Comment inline; otherwise, the runtime changes look good.

@sdjordjevicTT (Contributor) left a comment

A couple of comments inline.

@azecevicTT (Contributor) left a comment

Great follow-up! A few comments, mostly nits, but a few might be bugs, so make sure to check them before merging.

@github-actions (bot) left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

@nsmithtt (Contributor) commented

Perhaps the commit title should become the commit message, and the title could be something a bit more concise?

@tapspatel changed the title from "#2065: Updated all reduce code to handle 0 or 1 cluster axis and cleaned up dialect representations of all reduce in ttir and ttnn. Update algorithms for calculating gather and scatter dimensions" to "#2065: Updated all reduce code to handle 0 or 1 mesh cluster axis" on Feb 21, 2025
@nsmithtt (Contributor) left a comment

I think the TTNN interfaces need to change; some other comments inline, but otherwise this looks good!

auto firstElementIt = replicaGroups.begin();
auto secondElementIt = firstElementIt + 1;

clusterAxis = (((*firstElementIt) + 1) == *secondElementIt);
A contributor commented

Doesn't this assume that device IDs are consecutive and in a particular topology?

@tapspatel (Collaborator, author) replied

From docs: https://docs.jax.dev/en/latest/_autosummary/jax.make_mesh.html

Essentially, JAX determines the mesh device ordering internally from the TPU structure and optimizes the ordering based on the mesh shape. However, when simulating CPUs, it orders devices by consecutive numbers. To simplify this, we assume the mesh provided by tt-xla is monotonically increasing, because we don't yet have a way to propagate an efficient mesh mapping based on the hardware structure in the compiler (though this can be done).

I can add this assumption as a comment for now.

@tapspatel (Collaborator, author) added

As for topologies, we are only supporting 2D grid topologies in the compiler, per a conversation with Wooseok.

A contributor commented

I have also observed ascending device IDs from 0 to N-1 so far, so I guess it's fine. In the future, we may leverage the device IDs to best distribute data for our hardware configuration, but we haven't gotten there yet.
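
To make the heuristic under discussion concrete: with row-major, monotonically increasing device IDs, two consecutive IDs at the start of a replica group imply the group runs along mesh dimension 1; otherwise it runs along dimension 0. Below is a minimal standalone sketch of that check (identifiers are illustrative, not the PR's actual code).

#include <cassert>
#include <cstdint>
#include <vector>

// Infer the cluster axis from the first replica group, relying on the
// assumption discussed above: device IDs are row-major and monotonically
// increasing across the mesh.
int64_t inferClusterAxis(const std::vector<int64_t> &firstReplicaGroup) {
  assert(firstReplicaGroup.size() >= 2 && "need at least two devices per group");
  // Consecutive IDs ({0, 1, ...}) mean the group runs along mesh dim 1;
  // strided IDs ({0, 4, ...} on a 2x4 mesh) mean it runs along mesh dim 0.
  return (firstReplicaGroup[0] + 1 == firstReplicaGroup[1]) ? 1 : 0;
}

int main() {
  assert(inferClusterAxis({0, 1, 2, 3}) == 1); // groups are rows of a 2x4 mesh
  assert(inferClusterAxis({0, 4}) == 0);       // groups are columns of a 2x4 mesh
  return 0;
}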

@github-actions (bot) left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

@tapspatel requested a review from nsmithtt on February 21, 2025 at 21:49
@nsmithtt (Contributor) left a comment

Looks good, thank you!

@wooseokTT (Contributor) left a comment

Can you address some of my comments/requests?

@github-actions (bot) left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

@tapspatel (Collaborator, author) commented

@wooseokTT let me know if the responses to all the comments are sufficient.

@tapspatel merged commit 834fc78 into main on Feb 27, 2025
31 checks passed
@tapspatel deleted the tpatel/issue-2065 branch on February 27, 2025 at 19:29
jserbedzijaTT pushed a commit that referenced this pull request Mar 1, 2025

Successfully merging this pull request may close these issues.

Update all_reduce shlo conversion to ttnn with dynamic cluster_axis
7 participants