MinSR regularisation #1696
I was discussing it this morning as well. I think their algorithm is summarised well in this figure. The major difference is line 8 of the algorithm, which requires replacing the local energies with something that depends on the previous gradient. I think the regularisation you talk about is at line 9 where they also add a full matrix of Then, they add momentum and gradient clipping, which is easy to do with optax.
Yeah, I was mostly talking about this - it should improve MinSR too, and seems like a quick change there to someone who's worked on that code.
EDIT: I've just noticed the second half of your comment. I think it could be an improvement - this is guaranteed to be a zero eigenvector, so if you eliminate it, MinSR might be stable with lower
EDIT 2: I also suppose we'd recycle the MinSR code if we ever implement the new algorithm, so it makes sense to add it anyway.
I think it's as simple as replacing this line with
    matrix = matrix + diag_shift * jnp.eye(matrix_side) + jnp.full(matrix.shape, 1 / N_mc)
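A minimal sketch of what this one-line change does, assuming a real-valued, centred toy Jacobian rather than NetKet's actual MinSR internals. The function `regularised_solve`, the parameter `ones_shift` and the array sizes are hypothetical names chosen for illustration. Because the centred local energies and the columns of `O` are orthogonal to the all-ones vector, the resulting parameter update should be the same with and without the all-ones term, up to rounding.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, so the comparison below is meaningful


def regularised_solve(O, eps, diag_shift, ones_shift):
    """Solve (O O^T + diag_shift * I + ones_shift * J) x = eps, J being the all-ones matrix."""
    matrix = O @ O.T
    matrix_side = matrix.shape[0]
    matrix = matrix + diag_shift * jnp.eye(matrix_side) + jnp.full(matrix.shape, ones_shift)
    return jnp.linalg.solve(matrix, eps)


key_o, key_e = jax.random.split(jax.random.PRNGKey(0))
n_samples, n_params = 16, 64  # hypothetical sizes

# Toy stand-ins: centred Jacobian (columns sum to zero) and centred local energies.
O = jax.random.normal(key_o, (n_samples, n_params))
O = (O - O.mean(axis=0)) / jnp.sqrt(n_samples)
eps = jax.random.normal(key_e, (n_samples,))
eps = (eps - eps.mean()) / jnp.sqrt(n_samples)

x_plain = regularised_solve(O, eps, diag_shift=1e-4, ones_shift=0.0)
x_ones = regularised_solve(O, eps, diag_shift=1e-4, ones_shift=1.0 / n_samples)

# Largest difference between the two parameter updates O^T x.
max_diff = jnp.max(jnp.abs(O.T @ x_plain - O.T @ x_ones))
```

In this toy setting the all-ones term only lifts the eigenvalue of the (1,1,...,1) direction from `diag_shift` to `diag_shift + 1`; the update itself is untouched.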
Probably the easiest thing is to implement it in another driver starting from
Is this something (especially momentum) that we can encapsulate into a new driver? The updates would be wrong without adding the momentum, so it wouldn't make sense to leave it to the user to add it with an external momentum driver in optax (especially since conventions differ, so it would be an invitation for errors). Personally, I also don't see the point in using a separate library for this particular line of the algorithm, which is basically a single FMA. Gradient clipping is a different story; I think it can be left up to users whether they like it and do it themselves in optax.
The added projector P = (1/N) one_vec * one_vec^T should be scaled: OO^T -> OO^T + cP moves the eigenvalue of one_vec from 0 to c, but for c=1, if the other eigenvalues of OO^T are all much less than 1, it becomes the dominant eigenvalue and can leave the matrix still ill-conditioned (going from one extreme to the other). I instead rescaled by c = tr(OO^T)/N = avg(eig(OO^T)), which ensures that it becomes neither the largest nor the smallest eigenvalue. This worked well for me; the Kaczmarz/momentum parts also made a noticeable improvement over MinSR for 1D Heisenberg with RBMs.
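The scaling argument can be checked numerically. The following is a sketch under stated assumptions, not the commenter's code: `add_scaled_projector` is a hypothetical helper, and the Jacobian is a random real matrix deliberately scaled so that all eigenvalues of OO^T are far below 1.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def add_scaled_projector(T):
    """Return (T + c * P, c) with P = ones ones^T / N and c = tr(T) / N."""
    n = T.shape[0]
    c = jnp.trace(T) / n           # mean eigenvalue of T
    P = jnp.full((n, n), 1.0 / n)  # projector onto the all-ones vector
    return T + c * P, c


n_samples, n_params = 16, 64  # hypothetical sizes
O = jax.random.normal(jax.random.PRNGKey(1), (n_samples, n_params))
O = 1e-3 * (O - O.mean(axis=0))  # centred and small: eigenvalues of T are << 1

T = O @ O.T
T_reg, c = add_scaled_projector(T)
ones = jnp.ones(n_samples)

eig_T = jnp.linalg.eigvalsh(T)  # ascending; eig_T[0] is the numerically zero mode
eig_scaled = jnp.linalg.eigvalsh(T_reg)
eig_unit = jnp.linalg.eigvalsh(T + jnp.full(T.shape, 1.0 / n_samples))  # c = 1

cond_scaled = eig_scaled[-1] / eig_scaled[0]
cond_unit = eig_unit[-1] / eig_unit[0]
```

Here `T_reg @ ones` equals `c * ones`, and `c` lies between the smallest nonzero and the largest eigenvalue of `T`, so `cond_scaled` is set by the physical spectrum, whereas `cond_unit` is dominated by the artificial eigenvalue 1.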
@attila-i-szabo this is why it's better to regularize MinSR by ignoring the modes where the eigenvalues are below some threshold rather than adding a diag_shift. Just throwing away this zero mode works completely fine. Here's some pseudocode
I find that
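For illustration only, and not the pseudocode from the comment above: one way to regularise by discarding small-eigenvalue modes is a truncated pseudo-inverse built from an eigendecomposition. The function name `solve_with_cutoff` and the relative threshold `rtol` are assumptions, and the sketch handles real symmetric matrices only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def solve_with_cutoff(T, eps, rtol=1e-10):
    """Apply the pseudo-inverse of the real symmetric matrix T to eps, discarding
    eigenmodes whose eigenvalue is below rtol * (largest eigenvalue)."""
    evals, evecs = jnp.linalg.eigh(T)  # eigenvalues in ascending order
    keep = evals > rtol * evals[-1]
    # Invert only the kept eigenvalues; the inner where avoids dividing by ~0.
    inv_evals = jnp.where(keep, 1.0 / jnp.where(keep, evals, 1.0), 0.0)
    return evecs @ (inv_evals * (evecs.T @ eps))


key_o, key_e = jax.random.split(jax.random.PRNGKey(2))
n_samples, n_params = 16, 64  # hypothetical sizes

# Toy centred Jacobian and centred local energies.
O = jax.random.normal(key_o, (n_samples, n_params))
O = (O - O.mean(axis=0)) / jnp.sqrt(n_samples)
eps = jax.random.normal(key_e, (n_samples,))
eps = (eps - eps.mean()) / jnp.sqrt(n_samples)

T = O @ O.T
x = solve_with_cutoff(T, eps)
residual = jnp.max(jnp.abs(T @ x - eps))
```

In this toy case the only discarded mode is the all-ones direction, so no `diag_shift` is needed and `x` sums to zero up to rounding.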
I've just seen this paper. The algorithm they describe is interesting too, and we might want to consider supporting it, but there's a more immediate point. In Sec. 3.2, they explain that the MinSR matrix is guaranteed to have (1,1,...,1) as a zero eigenvector, so regularising it (a lot) is a good idea to protect from numerical instability (in exact arithmetic, adding any multiple of the all-ones matrix makes no difference).
Does NetKet already do something similar?
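The zero mode is easy to reproduce on toy data. This sketch assumes a random real matrix as a stand-in for the per-sample log-derivatives (the name `log_derivs` and the sizes are hypothetical); once the sample mean is subtracted, every column of `O` sums to zero, so (1,1,...,1) is annihilated by `O @ O.T`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_samples, n_params = 16, 64  # hypothetical sizes

# Stand-in for the per-sample log-derivatives, one row per Monte Carlo sample.
log_derivs = jax.random.normal(jax.random.PRNGKey(3), (n_samples, n_params))

# Subtract the sample mean, so each column of O sums to zero.
O = (log_derivs - log_derivs.mean(axis=0)) / jnp.sqrt(n_samples)
T = O @ O.T

ones = jnp.ones(n_samples)
zero_mode_residual = jnp.max(jnp.abs(T @ ones))  # rounding-level

evals, evecs = jnp.linalg.eigh(T)  # ascending eigenvalues
# Overlap of the lowest eigenvector with the normalised all-ones vector.
overlap = jnp.abs(evecs[:, 0] @ ones) / jnp.sqrt(n_samples)
```

The smallest eigenvalue in `evals` is zero only up to rounding, which is why inverting `T` without some regularisation of this direction is numerically fragile.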