Skip to content

Commit

Permalink
add dim validation
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak committed Jan 21, 2025
1 parent 0df1c55 commit 312b0f1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
7 changes: 6 additions & 1 deletion gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,13 @@ def unsqueeze(self, dim: int) -> MultivariateNormal:
If `dim = -1`, then the returned MultivariateNormal will have
`batch_shape = torch.Size([2, 3, 1])`.
"""
# If dim is negative, get the positive equivalent.
if dim > len(self.batch_shape) or dim < -len(self.batch_shape) - 1:
raise IndexError(
"Dimension out of range (expected to be in range of "
f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
)
if dim < 0:
# If dim is negative, get the positive equivalent.
dim = len(self.batch_shape) + dim + 1

new_loc = self.loc.unsqueeze(dim)
Expand Down
12 changes: 12 additions & 0 deletions test/distributions/test_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,20 @@ def test_multivariate_normal_unsqueeze(self, cuda=False):
self.assertEqual(new.batch_shape, expected_batch)
self.assertEqual(new.event_shape, mvn.event_shape)
self.assertTrue(torch.equal(new.mean, mean.unsqueeze(positive_dim)))
self.assertEqual(new.mean.shape, expected_batch + torch.Size([3]))
self.assertTrue(torch.allclose(new.covariance_matrix, covmat.unsqueeze(positive_dim)))
self.assertEqual(new.covariance_matrix.shape, expected_batch + torch.Size([3, 3]))
self.assertTrue(torch.allclose(new.scale_tril, mvn.scale_tril.unsqueeze(positive_dim)))
self.assertEqual(new.scale_tril.shape, expected_batch + torch.Size([3, 3]))

# Check for dim validation.
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
mvn.unsqueeze(3)
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
mvn.unsqueeze(-4)
# Should not raise error up to 2 or -3.
mvn.unsqueeze(2)
mvn.unsqueeze(-3)


if __name__ == "__main__":
Expand Down

0 comments on commit 312b0f1

Please sign in to comment.