Derive associative scan algorithm for factorization #28
Comments
Hi @dfm, sorry to be plaguing you 😅. I'm working on a JAX project with GPU acceleration and I'd like to use celerite2. If I use it out of the box, I get a warning that brought me here. Is this still on the to-do list?
There is no GPU support planned for celerite2. It's possible to parallelize some of the algorithms, but in all the tests I've done the parallel version is slower than the CPU one and it scales badly with the number of kernel terms J (J^3 instead of J^2).
Thanks for the quick response! If you have any pointers on alternatives I'd be grateful.
I don't know of any good JAX libraries for GPs, but it's not too hard to implement the math yourself to try it out. If the GP is your bottleneck, I think it's unlikely that you'll get any benefit from using a GPU; but if your computation is dominated by other parts of the model that are improved by GPU acceleration, and you don't have too many data points, then it might be worth it. Here's an example implementation of naive GP computations using JAX + GPU acceleration that could get you started: https://github.com/dfm/tinygp/blob/main/src/tinygp/gp.py
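For context, a naive dense-GP marginal likelihood in JAX boils down to building the full covariance matrix and taking a Cholesky factorization. The sketch below illustrates that pattern; the Matérn-3/2 kernel and all the function names are illustrative choices, not tinygp's API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def matern32_kernel(x1, x2, sigma=1.0, rho=1.0):
    # Dense pairwise Matern-3/2 covariance; the kernel choice is illustrative.
    arg = jnp.sqrt(3.0) * jnp.abs(x1[:, None] - x2[None, :]) / rho
    return sigma**2 * (1.0 + arg) * jnp.exp(-arg)


@jax.jit
def gp_log_likelihood(params, x, y, diag):
    # Naive O(N^3) marginal likelihood via Cholesky of the dense covariance.
    K = matern32_kernel(x, x, params["sigma"], params["rho"]) + jnp.diag(diag)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    return -0.5 * (
        y @ alpha
        + 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
        + y.shape[0] * jnp.log(2.0 * jnp.pi)
    )


# Toy usage; JAX dispatches to the GPU automatically if one is available.
x = jnp.linspace(0.0, 1.0, 100)
y = jnp.sin(2.0 * jnp.pi * x)
print(gp_log_likelihood({"sigma": 1.0, "rho": 0.5}, x, y, 1e-4 * jnp.ones_like(x)))
```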
Thanks so much, as always!
I've derived the algorithms for matrix multiplication and solves, but I haven't been able to work out the factorization algorithm yet. There don't seem to be numerical issues for the ops that I've derived so far, but I haven't tested them extensively. This would be interesting because it would allow a parallel implementation on the GPU.
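The derivations themselves aren't written out here, but as a rough sketch of the underlying building block: the matmul and solve ops reduce to first-order linear recurrences, and those can be evaluated in parallel by scanning over affine maps with jax.lax.associative_scan. The shapes and names below are illustrative, and the factorization step would need a different (not yet derived) combine operator:

```python
import jax
import jax.numpy as jnp


def linear_recurrence_scan(A, b):
    # Evaluate f_n = A_n @ f_{n-1} + b_n (zero initial state) in parallel.
    # A: (N, J, J) transition matrices, b: (N, J) inputs.
    def combine(left, right):
        A_l, b_l = left
        A_r, b_r = right
        # Composing two affine maps x -> A x + b is again affine, so this
        # combine is associative and works as a parallel (prefix) scan.
        return A_r @ A_l, jnp.einsum("...ij,...j->...i", A_r, b_l) + b_r

    _, f = jax.lax.associative_scan(combine, (A, b))
    return f


def linear_recurrence_loop(A, b):
    # Sequential reference implementation of the same recurrence.
    def step(f_prev, inputs):
        A_n, b_n = inputs
        f_n = A_n @ f_prev + b_n
        return f_n, f_n

    _, f = jax.lax.scan(step, jnp.zeros(b.shape[-1]), (A, b))
    return f


# Quick check that the parallel scan matches the sequential recurrence.
key_A, key_b = jax.random.split(jax.random.PRNGKey(0))
A = 0.1 * jax.random.normal(key_A, (50, 3, 3))
b = jax.random.normal(key_b, (50, 3))
assert jnp.allclose(linear_recurrence_scan(A, b), linear_recurrence_loop(A, b), atol=1e-5)
```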