Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein committed Jan 9, 2025
1 parent fe8880a commit b52fefd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
5 changes: 3 additions & 2 deletions jumanji/environments/logic/rubiks_cube/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def test_rubiks_cube__done(time_limit: int) -> None:
assert episode_length == time_limit


def test_rubiks_cube__animate(rubiks_cube: RubiksCube, mocker: pytest_mock.MockerFixture) -> None:
def test_rubiks_cube_animate(rubiks_cube: RubiksCube, mocker: pytest_mock.MockerFixture) -> None:
"""Test that the `animate` method creates the animation correctly (but does not display it)."""
states = mocker.MagicMock()
state, _ = rubiks_cube.reset(jax.random.PRNGKey(0))
states = [state] * 5
animation = rubiks_cube.animate(states)
assert isinstance(animation, matplotlib.animation.Animation)
10 changes: 5 additions & 5 deletions jumanji/environments/logic/rubiks_cube/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def animate(
ax = ax.flatten()
plt.close(fig)

images = self._draw(ax, states[0])
faces = self._draw(ax, states[0])

def make_frame(state: State) -> Sequence[Artist]:
for i, image in enumerate(images):
image.set_data(state.cube[i])
return images
for i, face in enumerate(faces):
face.set_data(state.cube[i])
return faces

# Create the animation object.
self._animation = matplotlib.animation.FuncAnimation(
Expand Down Expand Up @@ -111,7 +111,7 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, List[plt.Axes]]:
return fig, ax

def _draw(self, ax: List[plt.Axes], state: State) -> List[AxesImage]:
images = list()
images = []

for i, face in enumerate(Face):
ax[i].clear()
Expand Down
14 changes: 4 additions & 10 deletions jumanji/environments/packing/tetris/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,11 @@ def animate(
fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=TetrisViewer.FIGURE_SIZE)
plt.close(fig)

def make_frame(grid_index: int) -> Tuple[Artist]:
"""creates a frames for each state
Args:
grid_index: `int`
"""
def make_frame(frame_data: Tuple[chex.Array, chex.Numeric]) -> Tuple[Artist]:
grid, score = frame_data
ax.clear()
ax.invert_yaxis()
fig.suptitle(f"Tetris Score: {int(scores[grid_index])}", size=20)
grid = grids[grid_index]

fig.suptitle(f"Tetris Score: {int(score)}", size=20)
self._add_grid_image(ax, grid)
return (ax,)

Expand All @@ -233,7 +227,7 @@ def make_frame(grid_index: int) -> Tuple[Artist]:
self._animation = matplotlib.animation.FuncAnimation(
fig,
make_frame,
frames=range(len(grids)),
frames=zip(grids, scores, strict=False),
interval=interval,
)

Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ dm-env>=1.5
gymnasium>=1.0
huggingface-hub
jax>=0.2.26,<0.4.36
matplotlib>3.8.0
matplotlib>=3.8.0
numpy>=1.19.5
Pillow>=9.0.0
typing-extensions>=4.0.0

0 comments on commit b52fefd

Please sign in to comment.