Skip to content

Commit

Permalink
update docker image and the run file
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-yj-yang committed Dec 6, 2022
1 parent 9740893 commit 0d22389
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 3 deletions.
329 changes: 327 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,328 @@
import sys
sys.path.insert(0,'src')

#@title Run to install MuJoCo and `dm_control`
import distutils.util
import subprocess
if __name__ == "main":
subprocess.call("src/train_agent.py --num_steps=10")
# if subprocess.run('nvidia-smi').returncode:
# raise RuntimeError(
# 'Cannot communicate with GPU. '
# 'Make sure you are using a GPU Colab runtime. '
# 'Go to the Runtime menu and select Choose runtime type.')

print('Installing dm_control...')
# !pip install -q dm_control==1.0.8

# # Configure dm_control to use the EGL rendering backend (requires GPU)
# %env MUJOCO_GL=osmesa
# %env PYOPENGL_PLATFORM=
# %env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

print('Checking that the dm_control installation succeeded...')
try:
from dm_control import suite
env = suite.load('cartpole', 'swingup')
pixels = env.physics.render()
except Exception as e:
raise e from RuntimeError(
'Something went wrong during installation. Check the shell output above '
'for more information.\n'
'If using a hosted Colab runtime, make sure you enable GPU acceleration '
'by going to the Runtime menu and selecting "Choose runtime type".')
else:
del pixels, suite

# !echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+")
# !rm -r "=1.0.8"

#@title All `dm_control` imports required for this tutorial

# The basic mujoco wrapper.
from dm_control import mujoco

# Access to enums and MuJoCo library functions.
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import mjlib

# PyMJCF
from dm_control import mjcf

# Composer high level imports
from dm_control import composer
from dm_control.composer.observation import observable
from dm_control.composer import variation

# Imports for Composer tutorial example
from dm_control.composer.variation import distributions
from dm_control.composer.variation import noises
from dm_control.locomotion.arenas import floors

# Control Suite
from dm_control import suite

# Run through corridor example
from dm_control.locomotion.walkers import cmu_humanoid
from dm_control.locomotion.arenas import corridors as corridor_arenas
from dm_control.locomotion.tasks import corridors as corridor_tasks

# # Soccer
# from dm_control.locomotion import soccer

# Manipulation
from dm_control import manipulation

#@title Other imports and helper functions

# General
import copy
import os
import itertools
from IPython.display import clear_output
import numpy as np

# Graphics-related
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
import PIL.Image
# Internal loading of video libraries.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam
# from torch.utils.tensorboard import SummaryWriter

# import this first to resolve the issue.
from acme import wrappers
from model import *
from utils import *
# Soft-Actor-Critic Model
from sac import *
from replay_memory import *
import argparse
import datetime
import itertools
import os
import random
import math
import pickle


# try out the wrappers
from acme import wrappers
from dm_control import suite

# # Use svg backend for figure rendering
# %config InlineBackend.figure_format = 'svg'

# Font sizes
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rc('font', size=SMALL_SIZE) # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title

# Inline video helper function
if os.environ.get('COLAB_NOTEBOOK_TEST', False):
# We skip video generation during tests, as it is quite expensive.
display_video = lambda *args, **kwargs: None
else:
def display_video(frames, framerate=30):
height, width, _ = frames[0].shape
dpi = 70
orig_backend = matplotlib.get_backend()
matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering.
fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
matplotlib.use(orig_backend) # Switch back to the original backend.
ax.set_axis_off()
ax.set_aspect('equal')
ax.set_position([0, 0, 1, 1])
im = ax.imshow(frames[0])
def update(frame):
im.set_data(frame)
return [im]
interval = 1000/framerate
anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
interval=interval, blit=True, repeat=False)
return HTML(anim.to_html5_video())

# Seed numpy's global RNG so that cell outputs are deterministic. We also try to
# use RandomState instances that are local to a single cell wherever possible.
np.random.seed(42)


###### Environment wrappers ####
from dm_env import specs


# environment wrappers
class NormilizeActionSpecWrapper(wrappers.EnvironmentWrapper):
"""Turn each dimension of the actions into the range of [-1, 1]."""

def __init__(self, environment):
super().__init__(environment)

action_spec = environment.action_spec()
self._scale = action_spec.maximum - action_spec.minimum
self._offset = action_spec.minimum

minimum = action_spec.minimum * 0 - 1.
maximum = action_spec.minimum * 0 + 1.
self._action_spec = specs.BoundedArray(
action_spec.shape,
action_spec.dtype,
minimum,
maximum,
name=action_spec.name)

def _from_normal_actions(self, actions):
actions = 0.5 * (actions + 1.0) # a_t is now in the range [0, 1]
# scale range to [minimum, maximum]
return actions * self._scale + self._offset

def step(self, action):
action = self._from_normal_actions(action)
return self._environment.step(action)

def action_spec(self):
return self._action_spec


class MujocoActionNormalizer(wrappers.EnvironmentWrapper):
"""Rescale actions to [-1, 1] range for mujoco physics engine.
For control environments whose actions have bounded range in [-1, 1], this
adaptor rescale actions to the desired range. This allows actor network to
output unscaled actions for better gradient dynamics.
"""

def __init__(self, environment, rescale='clip'):
super().__init__(environment)
self._rescale = rescale

def step(self, action):
"""Rescale actions to [-1, 1] range before stepping wrapped environment."""
if self._rescale == 'tanh':
scaled_actions = tree.map_structure(np.tanh, action)
elif self._rescale == 'clip':
scaled_actions = tree.map_structure(lambda a: np.clip(a, -1., 1.), action)
else:
raise ValueError('Unrecognized scaling option: %s' % self._rescale)
return self._environment.step(scaled_actions)

from IPython.display import display, HTML

#@title Loading and simulating a `suite` task{vertical-output: true}

# Load the environment
# random_state = np.random.RandomState(42)
# env = suite.load('hopper', 'stand', task_kwargs={'random': random_state})

# Simulate episode with random actions
def visualize(duration=10, save=False, name="vids.mp4"):
frames = []
ticks = []
rewards = []
observations = []

spec = env.action_spec()
time_step = env.reset()

while env.physics.data.time < duration:
state = get_flat_obs(time_step)
action = agent.select_action(state)
time_step = env.step(action)

camera0 = env.physics.render(camera_id=0, height=400, width=400)
camera1 = env.physics.render(camera_id=1, height=400, width=400)
frames.append(np.hstack((camera0, camera1)))
rewards.append(time_step.reward)
observations.append(copy.deepcopy(time_step.observation))
ticks.append(env.physics.data.time)

html_video = display_video(frames, framerate=1./env.control_timestep())

# Show video and plot reward and observations
num_sensors = len(time_step.observation)

_, ax = plt.subplots(1 + num_sensors, 1, sharex=True, figsize=(4, 8))
ax[0].plot(ticks, rewards)
ax[0].set_ylabel('reward')
ax[-1].set_xlabel('time')

for i, key in enumerate(time_step.observation):
data = np.asarray([observations[j][key] for j in range(len(observations))])
ax[i+1].plot(ticks, data, label=key)
ax[i+1].set_ylabel(key)
if save:
save_video(frames, video_name=name)

return html_video


# load the environment
env = suite.load(domain_name="cheetah", task_name="run")
# add wrappers onto the environment
env = NormilizeActionSpecWrapper(env)
env = MujocoActionNormalizer(environment=env, rescale='clip')
env = wrappers.SinglePrecisionWrapper(env)


class Args:
env_name = 'whatever'
policy = 'Gaussian'
eval = True
gamma = 0.99
tau = 0.005
lr = 0.0003
alpha = 0.2
automatic_entropy_tuning = True
seed = 42
batch_size = 256
num_steps = 1000000
hidden_size = 256
updates_per_step = 1
start_steps = 10000
target_update_interval = 1
replay_size = 1000000
# use the cuda to speedup
cuda = True


args = Args()

# get the dimensionality of the observation_spec after flattening
flat_obs = tree.flatten(env.observation_spec())
# combine all the shapes
# obs_dim = sum([item.shape[0] for item in flat_obs])
obs_dim = 0
for i in flat_obs:
try:
obs_dim += i.shape[0]
except IndexError:
obs_dim += 1

# setup agent, using Soft-Actor-Critic Model
agent = SAC(obs_dim, env.action_spec(), args)
# load checkpoint - UPLOAD YOUR FILE HERE!
model_path = 'src/models/sac_checkpoint_cheetah_123456_10000'
agent.load_checkpoint(model_path, evaluate=True)

# pull out model
model = agent.policy
# setup hook dict
hook_dict = init_hook_dict(model)
# add hooks
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.register_forward_hook(recordtodict_hook(name=name, hook_dict=hook_dict))
print(model)
print(f"Successfully Loaded the checkpoint from {model_path}")
Binary file added src/models/sac_checkpoint_cheetah_123456_10000
Binary file not shown.
2 changes: 1 addition & 1 deletion submission.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"dockerhub-id": "scottyang17/dm:initial",
"dockerhub-id": "scottyang17/dm:latest",
"build-script": "python run.py"
}

0 comments on commit 0d22389

Please sign in to comment.