additional features for NPSE #1370
base: main
Conversation
Force-pushed from 53bd4f9 to 0b49bf3, then from 9f24294 to bcea468.
I've requested review now. While batched sampling for score-based posteriors is now possible and tested for, IID sampling is still not possible; having talked to @manuelgloeckler about this, maybe it can be done in a new PR. Other than that, I've also noticed while testing that sampling from the posterior via the ODE can be much less accurate than via diffusion. So the test …
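For reference, a sketch of how the two sampling routes could be compared (the `sample_with` argument and the surrounding names `inference` and `x_o` are assumptions about the sbi API, not code from this PR):

```python
# Hedged sketch: build one posterior that samples via the diffusion (SDE)
# route and one via the probability-flow ODE, then compare the samples.
posterior_sde = inference.build_posterior(sample_with="sde")
posterior_ode = inference.build_posterior(sample_with="ode")

samples_sde = posterior_sde.sample((1_000,), x=x_o)
samples_ode = posterior_ode.sample((1_000,), x=x_o)

# Comparing summary statistics makes the accuracy gap visible,
# e.g. against a known ground-truth posterior mean.
print(samples_sde.mean(dim=0), samples_ode.mean(dim=0))
```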
Force-pushed from e59d6d0 to 0d29c8a.
Codecov Report. Attention: Patch coverage is …

```
@@            Coverage Diff             @@
##             main    #1370       +/-   ##
===========================================
- Coverage   89.31%   78.24%   -11.07%
===========================================
  Files         119      119
  Lines        8779     8850       +71
===========================================
- Hits         7841     6925      -916
- Misses        938     1925      +987
```

Flags with carried forward coverage won't be shown.
I started integrating the IID stuff into the current version of this branch and created a new PR for it (#1381). So let's get this one merged first; the IID PR still requires some work on my side.
@janfb:
Great! Thanks a lot for addressing all these issues with the current NPSE. 🚀
I left a couple of suggestions and questions for my understanding. Happy to discuss in person if needed.
```diff
@@ -136,28 +139,47 @@ def sample(
         x = self._x_else_default_x(x)
         x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
-        self.potential_fn.set_x(x)
+        self.potential_fn.set_x(x, x_is_iid=True)
```
@janfb:
if IID is not working yet, why is this set to `True` here?
Author:
We are in the `sample` function, as opposed to `sample_batched`. When `x` has a batch size of 1, the `x_is_iid` flag currently has no effect. Specifying it here makes sure that if `x` has batch size > 1, the potential raises an error that IID sampling is not yet implemented, instead of trying to sample on a batch of conditions (in which case the user should call `posterior.sample_batched` instead).
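A rough sketch of the guard being described; the actual `set_x` internals in the PR may differ:

```python
# Hypothetical sketch: with x_is_iid=True, a condition with batch size > 1
# fails loudly instead of being silently treated as a batch of independent
# conditions.
def set_x(self, x, x_is_iid=False):
    if x_is_iid and x.shape[0] > 1:
        raise NotImplementedError(
            "IID sampling is not yet implemented for score-based posteriors. "
            "To sample for a batch of conditions, call `posterior.sample_batched`."
        )
    self._x = x
```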
```diff
             max_sampling_batch_size=max_sampling_batch_size,
             proposal_sampling_kwargs={"x": x},
         )[0]
+        samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
```
@janfb:
move this below into the return statement to reduce repetition?
```diff
@@ -222,12 +244,12 @@ def _sample_via_diffusion(
         )
         samples = torch.cat(samples, dim=0)[:num_samples]

-        return samples.reshape(sample_shape + self.score_estimator.input_shape)
+        return samples
```
@janfb:
why is this not needed anymore? Because it's handled after accept-reject sampling in the public `sample` method?
Author:
I changed the sampling shape in `Diffuser`; see `sbi/samplers/score/diffuser.py`, line 103 in 0d29c8a:

```python
init_shape = (num_samples, num_batch) + self.input_shape
```

I did not see a good reason why the sample and batch dimensions were treated differently there, so this was changed to match our usual shape conventions, and so we don't need to do this additional reshaping here.
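A minimal, self-contained illustration of the shape convention in question (example values only):

```python
import torch

# Initial noise is drawn with a leading (num_samples, num_batch) shape,
# matching the usual sbi convention, so the public `sample` method can
# reshape directly to sample_shape + input_shape without transposing the
# batch and sample dimensions.
num_samples, num_batch, input_shape = 100, 1, torch.Size((3,))
init_shape = (num_samples, num_batch) + input_shape  # (100, 1, 3)
noise = torch.randn(init_shape)

sample_shape = torch.Size((100,))
samples = noise.reshape(sample_shape + input_shape)  # (100, 3)
```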
```python
        return samples

    def _sample_via_diffusion(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
```
@janfb:
why is this needed if it's deprecated as stated in the docstring?
Author:
Hmmm, I think this is a leftover, so it should indeed be removed. I will also remove the dependence on `x` from `_sample_with_zuko`, as it should also use the default `x`.
```diff
     def map(
         self,
         x: Optional[Tensor] = None,
         num_iter: int = 1000,
         num_to_optimize: int = 1000,
-        learning_rate: float = 1e-5,
+        learning_rate: float = 0.01,
```
@janfb:
is this specific to score posteriors?
Author:
No, `learning_rate = 0.01` is what we use in all the other `posterior.map()` methods; I am guessing `1e-5` was a leftover from some previous debugging attempts. As far as I can tell, there is no reason for it to be so drastically different for score posteriors.
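For illustration, a usage sketch with the aligned default; `posterior` and `x_o` are assumed to exist, and the keyword names follow the diff above:

```python
# MAP optimization with the same default learning rate as the other
# posterior classes.
map_estimate = posterior.map(
    x=x_o,
    num_iter=1000,
    num_to_optimize=1000,
    learning_rate=0.01,
)
```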
```diff
             val_loss_sum += val_losses.sum().item()

         # Take mean over all validation samples.
         val_loss = val_loss_sum / (
-            len(val_loader) * val_loader.batch_size  # type: ignore
+            len(val_loader) * val_loader.batch_size * times_batch  # type: ignore
         )

         # NOTE: Due to the inherently noisy nature we do instead log a exponential
```
@janfb:
this comment is difficult to follow. Can we give more details here, or move it to the docstring of `train`?
```
    best model. We noticed that this improves performance. Deleting this method
    will make C2ST tests fail. This is because the loss is very stochastic, so
    resetting might reset to an underfitted model. Ideally, we would write a
    custom `._converged()` method which checks whether the loss is still going
    down **for all t**.
```
@janfb:
is this resolved now with the new EMA check?
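For context, an EMA-based convergence check might look roughly like the following sketch (names and thresholds are assumptions, not this PR's implementation):

```python
# Smoothing the noisy per-epoch validation loss with an exponential moving
# average avoids treating a single lucky epoch as the best model.
def has_converged(val_losses, alpha=0.9, patience=20, min_delta=1e-4):
    ema, best, since_best = None, float("inf"), 0
    for loss in val_losses:
        ema = loss if ema is None else alpha * ema + (1 - alpha) * loss
        if ema < best - min_delta:
            best, since_best = ema, 0
        else:
            since_best += 1
        if since_best >= patience:
            return True
    return False
```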
```diff
@@ -219,6 +219,8 @@ def accept_reject_sample(
         rejected. Must take a batch of parameters and return a boolean tensor which
```
@janfb:
the docstring above is not correct, it should be `proposal`, no? Can you please correct this, and add that this is now a callable that takes a sample shape as an argument plus additional kwargs?
Author:
Nice catch, yes I will update this!
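For reference, a much-simplified sketch of the contract the corrected docstring would describe (the acceptance loop of the real function is omitted, and names besides `proposal` and `proposal_sampling_kwargs` are assumptions):

```python
# `proposal` is now a callable taking a sample shape plus additional kwargs
# (e.g. the condition x), not a distribution object with a fixed .sample().
def accept_reject_sample(proposal, accept_reject_fn, num_samples,
                         proposal_sampling_kwargs=None):
    kwargs = proposal_sampling_kwargs or {}
    candidates = proposal((num_samples,), **kwargs)
    mask = accept_reject_fn(candidates)  # boolean tensor over the batch
    return candidates[mask]
```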
```python
        num_batch = self.batch_shape.numel()
        init_shape = (num_samples, num_batch) + self.input_shape
        # init_shape = (
        #     num_samples,
        # ) + self.input_shape  # just use num_samples, not num_batch
```
@janfb:
what does this change imply? Aren't we still raising an error for the IID setting and fixing it in the other PR by Manuel?
Author:
As far as I can tell, the former just matches our normal shape conventions, whereas the (now commented-out) version does not. But if there is some reason to keep the lines that are now commented out, we can revert this. @manuelgloeckler, do you have any thoughts on this?
```diff
@@ -234,4 +230,4 @@ def test_npse_map():

     map_ = posterior.map(show_progress_bars=True)

-    assert torch.allclose(map_, gt_posterior.mean, atol=0.2), "MAP is not close to GT."
+    assert torch.allclose(map_, gt_posterior.mean, atol=0.4), "MAP is not close to GT."
```
@janfb:
is this increase in tolerance still reasonable?
Author:
I am not sure. I ran this test in a notebook to see why it was failing with `atol=0.2`, and saw that the MAP of the approximate posterior was estimated correctly (judging by the pairplot of samples from the approximate posterior), but the estimated posterior itself was a bit off; I guess not enough to fail our posterior tests, but enough to shift the MAP and make this test fail with `atol=0.2`.
Thanks for the comments @janfb! I've tried to answer some of your questions, and will make the appropriate changes soon!
**What does this implement/fix? Explain your changes**

This introduces some additional features for score estimation named in #1226, namely:

- `enable_transform = True` for score-based potentials
- a `converged()` method for NPSE

**Does this close any currently open issues?**

#1226

**Any relevant code examples, logs, error output, etc?**

**Any other comments?**

- `score_based_posterior.map()` is still quite slow. We get the gradient of the log-probs with respect to `theta` by using the score estimator, but still compute the log-probs explicitly in `gradient_ascent`, which is more expensive. To get around this, we save a low-accuracy ODE flow to calculate the log-probs more quickly. Ideally, we might want to write a custom `gradient_ascent` function for calculating the MAP for score estimators to avoid doing this altogether.
- `linearGaussian_npse_test.py::test_npse_map`: as far as I can tell, the reason this failed with the lower tolerance is not the MAP calculation, but that score-based posteriors are currently slightly less accurate (at least for our test tasks).
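To make the last point concrete, a hedged sketch of what such a custom gradient ascent could look like; `score_fn` and its signature are assumptions, not code from this PR:

```python
# For score-based posteriors, the score estimator already outputs
# grad_theta log p(theta | x) near t = 0, so the MAP search does not need
# explicit log-probs at all. `score_fn(theta, x)` is a hypothetical handle
# to the trained estimator.
def map_via_score(score_fn, x_o, theta_init, num_iter=1000, lr=0.01):
    theta = theta_init.clone()  # torch.Tensor assumed
    for _ in range(num_iter):
        theta = theta + lr * score_fn(theta, x_o)  # ascend the log-posterior
    return theta
```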