Skip to content

Commit

Permalink
implement MVN.unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak committed Jan 17, 2025
1 parent 9eaecdc commit 6287b0b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,38 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
new.covariance_matrix = self.covariance_matrix
return new

def unsqueeze(self, dim: int) -> MultivariateNormal:
    r"""
    Constructs a new MultivariateNormal with the batch shape unsqueezed
    by the given dimension.

    For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
    the returned MultivariateNormal will have `batch_shape = torch.Size([1, 2, 3])`.
    If `dim = -1`, then the returned MultivariateNormal will have
    `batch_shape = torch.Size([2, 3, 1])`.

    :param dim: The index (within the batch shape) at which to insert the new
        singleton dimension. May be negative, counted from the end of the
        batch shape.
    :return: A new MultivariateNormal whose batch shape has an extra size-1
        dimension at position `dim`. Event shape is unchanged.
    """
    # If dim is negative, get the positive equivalent.
    # `+ 1` because the unsqueezed batch shape has one more dimension than
    # the current one (e.g. dim=-1 on batch_shape [2, 3] maps to index 2).
    if dim < 0:
        dim = len(self.batch_shape) + dim + 1

    # The loc tensor's trailing dimension is the event dimension, so
    # unsqueezing at a batch index leaves the event shape untouched.
    new_loc = self.loc.unsqueeze(dim)
    if self.islazy:
        new_covar = self._covar.unsqueeze(dim)
        new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
        # NOTE: `__unbroadcasted_scale_tril` is name-mangled to the enclosing
        # class; this method must live on MultivariateNormal for it to resolve.
        if self.__unbroadcasted_scale_tril is not None:
            # Reuse the scale tril if available.
            new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
    else:
        # Non-lazy MVN is represented using scale_tril in PyTorch.
        # Constructing it from scale_tril will avoid unnecessary computation.
        # Initialize using __new__, so that we can skip __init__ and use scale_tril.
        new = self.__new__(type(self))
        new._islazy = False
        # Presumably the non-lazy path guarantees __unbroadcasted_scale_tril is
        # populated by __init__ — TODO confirm against the class constructor.
        new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
        super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
        # Set the covar matrix, since it is always available for GPyTorch MVN.
        new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
    return new

def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
r"""
Returns i.i.d. standard Normal samples to be used with
Expand Down
24 changes: 24 additions & 0 deletions test/distributions/test_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,30 @@ def test_multivariate_normal_expand(self, cuda=False):
self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1)))
self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1)))

def test_multivariate_normal_unsqueeze(self, cuda=False):
    """Verify MultivariateNormal.unsqueeze inserts a singleton batch dimension
    (for positive and negative `dim`) across lazy and non-lazy covariances."""
    device = torch.device("cuda" if cuda else "cpu")
    base_batch = torch.Size([2, 3])
    for dtype, lazy in product((torch.float, torch.double), (True, False)):
        loc = torch.tensor([0, 1, 2], device=device, dtype=dtype).expand(*base_batch, -1)
        diag = torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)
        cov = torch.diag(diag).expand(*base_batch, -1, -1)
        if lazy:
            mvn = MultivariateNormal(mean=loc, covariance_matrix=DenseLinearOperator(cov), validate_args=True)
            # Touch scale_tril so the lazily-computed factor exists and its
            # unsqueezing can be checked below.
            mvn.scale_tril
        else:
            mvn = MultivariateNormal(mean=loc, covariance_matrix=cov, validate_args=True)
        self.assertEqual(mvn.batch_shape, base_batch)
        self.assertEqual(mvn.islazy, lazy)
        # Each case: requested dim, its positive equivalent, resulting batch shape.
        cases = (
            (1, 1, torch.Size([2, 1, 3])),
            (-1, 2, torch.Size([2, 3, 1])),
        )
        for dim, positive_dim, expected_batch in cases:
            unsqueezed = mvn.unsqueeze(dim)
            self.assertIsInstance(unsqueezed, MultivariateNormal)
            self.assertEqual(unsqueezed.islazy, lazy)
            self.assertEqual(unsqueezed.batch_shape, expected_batch)
            self.assertEqual(unsqueezed.event_shape, mvn.event_shape)
            self.assertTrue(torch.equal(unsqueezed.mean, loc.unsqueeze(positive_dim)))
            self.assertTrue(torch.allclose(unsqueezed.covariance_matrix, cov.unsqueeze(positive_dim)))
            self.assertTrue(torch.allclose(unsqueezed.scale_tril, mvn.scale_tril.unsqueeze(positive_dim)))


# Allow running this test module directly (e.g. `python test_multivariate_normal.py`).
if __name__ == "__main__":
    unittest.main()

0 comments on commit 6287b0b

Please sign in to comment.