    def unsqueeze(self, dim: int) -> MultivariateNormal:
        r"""
        Constructs a new MultivariateNormal with the batch shape unsqueezed
        by the given dimension.

        For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
        the returned MultivariateNormal will have `batch_shape = torch.Size([1, 2, 3])`.
        If `dim = -1`, then the returned MultivariateNormal will have
        `batch_shape = torch.Size([2, 3, 1])`.

        :param dim: The index (within the batch shape) at which to insert the
            singleton dimension. May be negative, counted from the end of the
            batch shape.
        :return: A new MultivariateNormal with the unsqueezed batch shape.
        """
        # If dim is negative, get the positive equivalent.
        # The "+ 1" makes dim=-1 insert *after* the last batch dim, mirroring
        # Tensor.unsqueeze semantics applied to the batch shape only.
        if dim < 0:
            dim = len(self.batch_shape) + dim + 1

        # loc has shape batch_shape + event_shape, so unsqueezing at a batch
        # index leaves the event dimension untouched.
        new_loc = self.loc.unsqueeze(dim)
        if self.islazy:
            new_covar = self._covar.unsqueeze(dim)
            new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
            # NOTE: `self.__unbroadcasted_scale_tril` is name-mangled; this code
            # must live inside the MultivariateNormal class body for the
            # attribute lookups below to resolve correctly.
            if self.__unbroadcasted_scale_tril is not None:
                # Reuse the scale tril if available.
                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
        else:
            # Non-lazy MVN is represented using scale_tril in PyTorch.
            # Constructing it from scale_tril will avoid unnecessary computation.
            # Initialize using __new__, so that we can skip __init__ and use scale_tril.
            new = self.__new__(type(self))
            new._islazy = False
            new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
            super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
            # Set the covar matrix, since it is always available for GPyTorch MVN.
            new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
        return new
standard Normal samples to be used with diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py index ed54b50ec..f0de1ebdf 100644 --- a/test/distributions/test_multivariate_normal.py +++ b/test/distributions/test_multivariate_normal.py @@ -346,6 +346,30 @@ def test_multivariate_normal_expand(self, cuda=False): self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1))) self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1))) + def test_multivariate_normal_unsqueeze(self, cuda=False): + device = torch.device("cuda") if cuda else torch.device("cpu") + for dtype, lazy in product((torch.float, torch.double), (True, False)): + batch_shape = torch.Size([2, 3]) + mean = torch.tensor([0, 1, 2], device=device, dtype=dtype).expand(*batch_shape, -1) + covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)).expand(*batch_shape, -1, -1) + if lazy: + mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True) + # Initialize scale tril so we can test that it was unsqueezed. + mvn.scale_tril + else: + mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True) + self.assertEqual(mvn.batch_shape, batch_shape) + self.assertEqual(mvn.islazy, lazy) + for dim, positive_dim, expected_batch in ((1, 1, torch.Size([2, 1, 3])), (-1, 2, torch.Size([2, 3, 1]))): + new = mvn.unsqueeze(dim) + self.assertIsInstance(new, MultivariateNormal) + self.assertEqual(new.islazy, lazy) + self.assertEqual(new.batch_shape, expected_batch) + self.assertEqual(new.event_shape, mvn.event_shape) + self.assertTrue(torch.equal(new.mean, mean.unsqueeze(positive_dim))) + self.assertTrue(torch.allclose(new.covariance_matrix, covmat.unsqueeze(positive_dim))) + self.assertTrue(torch.allclose(new.scale_tril, mvn.scale_tril.unsqueeze(positive_dim))) + if __name__ == "__main__": unittest.main()