Skip to content

Commit

Permalink
enforce num_classes to be an int (#1728)
Browse files Browse the repository at this point in the history
Co-authored-by: Max Balandat <[email protected]>
  • Loading branch information
wjmaddox and Balandat authored Aug 31, 2021
1 parent e4579ed commit 41de1f4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gpytorch/likelihoods/gaussian_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood):
"""

def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float):
num_classes = targets.max() + 1
num_classes = int(targets.max() + 1)
# set alpha = \alpha_\epsilon
alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)

Expand Down

0 comments on commit 41de1f4

Please sign in to comment.