Skip to content

Commit

Permalink
Add batch and event shape to repr. (#1946)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Jan 12, 2025
1 parent 6ae76ea commit d1ca868
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
7 changes: 7 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ def icdf(self, q: ArrayLike) -> ArrayLike:
def is_discrete(self):
return self.support.is_discrete

def __repr__(self) -> str:
cls = self.__class__
return (
f"<{cls.__module__}.{cls.__name__} object at {id(self):#x} with batch "
f"shape {self.batch_shape} and event shape {self.event_shape}>"
)


@runtime_checkable
class DistributionLike(Protocol):
Expand Down
6 changes: 6 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3681,3 +3681,9 @@ def icdf(self, q):

assert my_dist.mean.shape == (2, 3, 4)
assert my_dist.variance.shape == (2, 3, 4)


def test_distribution_repr():
result = repr(dist.Wishart(7, jnp.eye(5)).expand([3, 4]).to_event(1))
assert "batch shape (3,)" in result
assert "event shape (4, 5, 5)"

0 comments on commit d1ca868

Please sign in to comment.