-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update docker image and the run file
- Loading branch information
1 parent
9740893
commit 0d22389
Showing
3 changed files
with
328 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,328 @@ | ||
import sys | ||
sys.path.insert(0,'src') | ||
|
||
#@title Run to install MuJoCo and `dm_control` | ||
import distutils.util | ||
import subprocess | ||
if __name__ == "main": | ||
subprocess.call("src/train_agent.py --num_steps=10") | ||
# if subprocess.run('nvidia-smi').returncode: | ||
# raise RuntimeError( | ||
# 'Cannot communicate with GPU. ' | ||
# 'Make sure you are using a GPU Colab runtime. ' | ||
# 'Go to the Runtime menu and select Choose runtime type.') | ||
|
||
print('Installing dm_control...') | ||
# !pip install -q dm_control==1.0.8 | ||
|
||
# # Configure dm_control to use the EGL rendering backend (requires GPU) | ||
# %env MUJOCO_GL=osmesa | ||
# %env PYOPENGL_PLATFORM= | ||
# %env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python | ||
|
||
print('Checking that the dm_control installation succeeded...') | ||
try: | ||
from dm_control import suite | ||
env = suite.load('cartpole', 'swingup') | ||
pixels = env.physics.render() | ||
except Exception as e: | ||
raise e from RuntimeError( | ||
'Something went wrong during installation. Check the shell output above ' | ||
'for more information.\n' | ||
'If using a hosted Colab runtime, make sure you enable GPU acceleration ' | ||
'by going to the Runtime menu and selecting "Choose runtime type".') | ||
else: | ||
del pixels, suite | ||
|
||
# !echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+") | ||
# !rm -r "=1.0.8" | ||
|
||
#@title All `dm_control` imports required for this tutorial | ||
|
||
# The basic mujoco wrapper. | ||
from dm_control import mujoco | ||
|
||
# Access to enums and MuJoCo library functions. | ||
from dm_control.mujoco.wrapper.mjbindings import enums | ||
from dm_control.mujoco.wrapper.mjbindings import mjlib | ||
|
||
# PyMJCF | ||
from dm_control import mjcf | ||
|
||
# Composer high level imports | ||
from dm_control import composer | ||
from dm_control.composer.observation import observable | ||
from dm_control.composer import variation | ||
|
||
# Imports for Composer tutorial example | ||
from dm_control.composer.variation import distributions | ||
from dm_control.composer.variation import noises | ||
from dm_control.locomotion.arenas import floors | ||
|
||
# Control Suite | ||
from dm_control import suite | ||
|
||
# Run through corridor example | ||
from dm_control.locomotion.walkers import cmu_humanoid | ||
from dm_control.locomotion.arenas import corridors as corridor_arenas | ||
from dm_control.locomotion.tasks import corridors as corridor_tasks | ||
|
||
# # Soccer | ||
# from dm_control.locomotion import soccer | ||
|
||
# Manipulation | ||
from dm_control import manipulation | ||
|
||
#@title Other imports and helper functions | ||
|
||
# General | ||
import copy | ||
import os | ||
import itertools | ||
from IPython.display import clear_output | ||
import numpy as np | ||
|
||
# Graphics-related | ||
import matplotlib | ||
import matplotlib.animation as animation | ||
import matplotlib.pyplot as plt | ||
from IPython.display import HTML | ||
import PIL.Image | ||
# Internal loading of video libraries. | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.distributions import Normal | ||
from torch.optim import Adam | ||
# from torch.utils.tensorboard import SummaryWriter | ||
|
||
# import this first to resolve the issue. | ||
from acme import wrappers | ||
from model import * | ||
from utils import * | ||
# Soft-Actor-Critic Model | ||
from sac import * | ||
from replay_memory import * | ||
import argparse | ||
import datetime | ||
import itertools | ||
import os | ||
import random | ||
import math | ||
import pickle | ||
|
||
|
||
# try out the wrappers | ||
from acme import wrappers | ||
from dm_control import suite | ||
|
||
# # Use svg backend for figure rendering | ||
# %config InlineBackend.figure_format = 'svg' | ||
|
||
# Font sizes | ||
SMALL_SIZE = 8 | ||
MEDIUM_SIZE = 10 | ||
BIGGER_SIZE = 12 | ||
plt.rc('font', size=SMALL_SIZE) # controls default text sizes | ||
plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title | ||
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels | ||
plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels | ||
plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels | ||
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize | ||
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title | ||
|
||
# Inline video helper function | ||
if os.environ.get('COLAB_NOTEBOOK_TEST', False): | ||
# We skip video generation during tests, as it is quite expensive. | ||
display_video = lambda *args, **kwargs: None | ||
else: | ||
def display_video(frames, framerate=30): | ||
height, width, _ = frames[0].shape | ||
dpi = 70 | ||
orig_backend = matplotlib.get_backend() | ||
matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering. | ||
fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi) | ||
matplotlib.use(orig_backend) # Switch back to the original backend. | ||
ax.set_axis_off() | ||
ax.set_aspect('equal') | ||
ax.set_position([0, 0, 1, 1]) | ||
im = ax.imshow(frames[0]) | ||
def update(frame): | ||
im.set_data(frame) | ||
return [im] | ||
interval = 1000/framerate | ||
anim = animation.FuncAnimation(fig=fig, func=update, frames=frames, | ||
interval=interval, blit=True, repeat=False) | ||
return HTML(anim.to_html5_video()) | ||
|
||
# Seed numpy's global RNG so that cell outputs are deterministic. We also try to | ||
# use RandomState instances that are local to a single cell wherever possible. | ||
np.random.seed(42) | ||
|
||
|
||
###### Environment wrappers #### | ||
from dm_env import specs | ||
|
||
|
||
# environment wrappers | ||
class NormilizeActionSpecWrapper(wrappers.EnvironmentWrapper): | ||
"""Turn each dimension of the actions into the range of [-1, 1].""" | ||
|
||
def __init__(self, environment): | ||
super().__init__(environment) | ||
|
||
action_spec = environment.action_spec() | ||
self._scale = action_spec.maximum - action_spec.minimum | ||
self._offset = action_spec.minimum | ||
|
||
minimum = action_spec.minimum * 0 - 1. | ||
maximum = action_spec.minimum * 0 + 1. | ||
self._action_spec = specs.BoundedArray( | ||
action_spec.shape, | ||
action_spec.dtype, | ||
minimum, | ||
maximum, | ||
name=action_spec.name) | ||
|
||
def _from_normal_actions(self, actions): | ||
actions = 0.5 * (actions + 1.0) # a_t is now in the range [0, 1] | ||
# scale range to [minimum, maximum] | ||
return actions * self._scale + self._offset | ||
|
||
def step(self, action): | ||
action = self._from_normal_actions(action) | ||
return self._environment.step(action) | ||
|
||
def action_spec(self): | ||
return self._action_spec | ||
|
||
|
||
class MujocoActionNormalizer(wrappers.EnvironmentWrapper): | ||
"""Rescale actions to [-1, 1] range for mujoco physics engine. | ||
For control environments whose actions have bounded range in [-1, 1], this | ||
adaptor rescale actions to the desired range. This allows actor network to | ||
output unscaled actions for better gradient dynamics. | ||
""" | ||
|
||
def __init__(self, environment, rescale='clip'): | ||
super().__init__(environment) | ||
self._rescale = rescale | ||
|
||
def step(self, action): | ||
"""Rescale actions to [-1, 1] range before stepping wrapped environment.""" | ||
if self._rescale == 'tanh': | ||
scaled_actions = tree.map_structure(np.tanh, action) | ||
elif self._rescale == 'clip': | ||
scaled_actions = tree.map_structure(lambda a: np.clip(a, -1., 1.), action) | ||
else: | ||
raise ValueError('Unrecognized scaling option: %s' % self._rescale) | ||
return self._environment.step(scaled_actions) | ||
|
||
from IPython.display import display, HTML | ||
|
||
#@title Loading and simulating a `suite` task{vertical-output: true} | ||
|
||
# Load the environment | ||
# random_state = np.random.RandomState(42) | ||
# env = suite.load('hopper', 'stand', task_kwargs={'random': random_state}) | ||
|
||
# Simulate episode with random actions | ||
def visualize(duration=10, save=False, name="vids.mp4"): | ||
frames = [] | ||
ticks = [] | ||
rewards = [] | ||
observations = [] | ||
|
||
spec = env.action_spec() | ||
time_step = env.reset() | ||
|
||
while env.physics.data.time < duration: | ||
state = get_flat_obs(time_step) | ||
action = agent.select_action(state) | ||
time_step = env.step(action) | ||
|
||
camera0 = env.physics.render(camera_id=0, height=400, width=400) | ||
camera1 = env.physics.render(camera_id=1, height=400, width=400) | ||
frames.append(np.hstack((camera0, camera1))) | ||
rewards.append(time_step.reward) | ||
observations.append(copy.deepcopy(time_step.observation)) | ||
ticks.append(env.physics.data.time) | ||
|
||
html_video = display_video(frames, framerate=1./env.control_timestep()) | ||
|
||
# Show video and plot reward and observations | ||
num_sensors = len(time_step.observation) | ||
|
||
_, ax = plt.subplots(1 + num_sensors, 1, sharex=True, figsize=(4, 8)) | ||
ax[0].plot(ticks, rewards) | ||
ax[0].set_ylabel('reward') | ||
ax[-1].set_xlabel('time') | ||
|
||
for i, key in enumerate(time_step.observation): | ||
data = np.asarray([observations[j][key] for j in range(len(observations))]) | ||
ax[i+1].plot(ticks, data, label=key) | ||
ax[i+1].set_ylabel(key) | ||
if save: | ||
save_video(frames, video_name=name) | ||
|
||
return html_video | ||
|
||
|
||
# load the environment | ||
env = suite.load(domain_name="cheetah", task_name="run") | ||
# add wrappers onto the environment | ||
env = NormilizeActionSpecWrapper(env) | ||
env = MujocoActionNormalizer(environment=env, rescale='clip') | ||
env = wrappers.SinglePrecisionWrapper(env) | ||
|
||
|
||
class Args: | ||
env_name = 'whatever' | ||
policy = 'Gaussian' | ||
eval = True | ||
gamma = 0.99 | ||
tau = 0.005 | ||
lr = 0.0003 | ||
alpha = 0.2 | ||
automatic_entropy_tuning = True | ||
seed = 42 | ||
batch_size = 256 | ||
num_steps = 1000000 | ||
hidden_size = 256 | ||
updates_per_step = 1 | ||
start_steps = 10000 | ||
target_update_interval = 1 | ||
replay_size = 1000000 | ||
# use the cuda to speedup | ||
cuda = True | ||
|
||
|
||
args = Args() | ||
|
||
# get the dimensionality of the observation_spec after flattening | ||
flat_obs = tree.flatten(env.observation_spec()) | ||
# combine all the shapes | ||
# obs_dim = sum([item.shape[0] for item in flat_obs]) | ||
obs_dim = 0 | ||
for i in flat_obs: | ||
try: | ||
obs_dim += i.shape[0] | ||
except IndexError: | ||
obs_dim += 1 | ||
|
||
# setup agent, using Soft-Actor-Critic Model | ||
agent = SAC(obs_dim, env.action_spec(), args) | ||
# load checkpoint - UPLOAD YOUR FILE HERE! | ||
model_path = 'src/models/sac_checkpoint_cheetah_123456_10000' | ||
agent.load_checkpoint(model_path, evaluate=True) | ||
|
||
# pull out model | ||
model = agent.policy | ||
# setup hook dict | ||
hook_dict = init_hook_dict(model) | ||
# add hooks | ||
for name, module in model.named_modules(): | ||
if isinstance(module, torch.nn.Linear): | ||
module.register_forward_hook(recordtodict_hook(name=name, hook_dict=hook_dict)) | ||
print(model) | ||
print(f"Successfully Loaded the checkpoint from {model_path}") |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
{ | ||
"dockerhub-id": "scottyang17/dm:initial", | ||
"dockerhub-id": "scottyang17/dm:latest", | ||
"build-script": "python run.py" | ||
} |