Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scatter(torch.scatter) #1444

Closed
wants to merge 2 commits into from

Conversation

mokeyish
Copy link
Contributor

@mokeyish
Copy link
Contributor Author

@LaurentMazare do you have time to review this PR.

@LaurentMazare
Copy link
Collaborator

Couldn't we use scatter_add with some zeros tensors to achieve this? It would be less efficient but the idea is to keep the number of basic op as limited as possible, especially as we introduce new backends having more ops result in additional work and maintenance so would prefer avoiding it unless it's really necessary.

@mokeyish
Copy link
Contributor Author

I understand your concerns. It does work by adding a negative number first using scatter_add

But when the index is non-unique, it cannot be completely equivalent to torch.scatter.
image

I tried it and found that scatter can be implemented in candle-ext.

Just calling copy_strided_src_ (which is pub crate) requires a twist.

image

@mokeyish mokeyish closed this Dec 23, 2023
@mokeyish mokeyish deleted the scatter branch December 23, 2023 04:22
@mokeyish mokeyish restored the scatter branch December 23, 2023 04:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants