Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check and update return sharding #2298

Merged
merged 1 commit into from
Feb 28, 2025

Conversation

wooseokTT
Copy link
Contributor

Ticket

#2297

Problem description

PJRT expects the return values to be sharded instead of being merged into a tensor with a single buffer if there is sharding attribute in return value.

What's changed

This PR provides special care for the return values just like we do for input tensors. It checks the return value attribute and if the sharding attribute is there, mark the returning mesh_shard op as manual such that runtime will not perform any concat operation and returns multiple buffers.

Checklist

  • [V] New/Existing tests provide coverage for changes

@wooseokTT wooseokTT requested a review from nsmithtt February 26, 2025 15:47
@wooseokTT wooseokTT force-pushed the wooseok/fix_mesh_shard_return_duplicate branch from 4220f21 to 0d487fa Compare February 26, 2025 16:00
@wooseokTT wooseokTT linked an issue Feb 26, 2025 that may be closed by this pull request
@wooseokTT wooseokTT force-pushed the wooseok/fix_mesh_shard_return_duplicate branch 2 times, most recently from 1e7e4c9 to a4b8507 Compare February 27, 2025 17:25
@wooseokTT wooseokTT force-pushed the wooseok/fix_mesh_shard_return_duplicate branch from a4b8507 to 83a6bc0 Compare February 28, 2025 15:04
Copy link
Contributor

@nsmithtt nsmithtt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just make sure to fix the ifdef because that can lead to confusing compiler errors. Lmk what you think about the enum naming otherwise lgtm! Thanks

@wooseokTT wooseokTT force-pushed the wooseok/fix_mesh_shard_return_duplicate branch from 83a6bc0 to 77a210e Compare February 28, 2025 20:08
@wooseokTT wooseokTT merged commit a8c29e0 into main Feb 28, 2025
31 checks passed
@wooseokTT wooseokTT deleted the wooseok/fix_mesh_shard_return_duplicate branch February 28, 2025 22:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Multi-Device] Handle automatic input/result sharding in JAX
3 participants