Skip to content

Commit

Permalink
Add type hints (#59)
Browse files Browse the repository at this point in the history
* Start typing env files

* Update RL script + fix lint warnings

* Finish typing core

* Reformat

author: @araffin
  • Loading branch information
araffin authored Mar 24, 2022
1 parent 21fc99e commit 6d9496e
Show file tree
Hide file tree
Showing 17 changed files with 275 additions and 263 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v2
Expand Down
7 changes: 7 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
History
=======

1.2.0 (WIP)
------------------
* Added type hints to most core methods
* Added ``send_lidar_config()`` method to configure LIDAR
* Added car roll, pitch and yaw angles
* Renamed lidar config to use snake case instead of CamelCase (for instance ``degPerSweepInc`` was renamed to ``deg_per_sweep_inc``)

1.1.1 (2021-02-28)
------------------
* Fix type checking error
Expand Down
4 changes: 3 additions & 1 deletion examples/genetic_alg/simple_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def make_new(self, parent1, parent2):


class GeneticAlg:
def __init__(self, population, conf={}):
def __init__(self, population, conf=None):
self.population = population
if conf is None:
conf = {}
self.conf = conf

def finished(self):
Expand Down
17 changes: 8 additions & 9 deletions examples/gym_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
submitting random input for 3 episodes.
"""
import argparse
import uuid

import gym

import gym_donkeycar
import gym_donkeycar # noqa: F401

NUM_EPISODES = 3
MAX_TIME_STEPS = 1000
Expand All @@ -35,12 +34,12 @@ def select_action(env):

def simulate(env):

for episode in range(NUM_EPISODES):
for _ in range(NUM_EPISODES):

# Reset the environment
obv = env.reset()

for t in range(MAX_TIME_STEPS):
for _ in range(MAX_TIME_STEPS):

# Select an action
action = select_action(env)
Expand Down Expand Up @@ -100,11 +99,11 @@ def exit_scene(env):
"start_delay": 1,
"max_cte": 5,
"lidar_config": {
"degPerSweepInc": 2.0,
"degAngDown": 0.0,
"degAngDelta": -1.0,
"numSweepsLevels": 1,
"maxRange": 50.0,
"deg_per_sweep_inc": 2.0,
"deg_ang_down": 0.0,
"deg_ang_delta": -1.0,
"num_sweeps_levels": 1,
"max_range": 50.0,
"noise": 0.4,
"offset_x": 0.0,
"offset_y": 0.5,
Expand Down
6 changes: 3 additions & 3 deletions examples/reinforcement_learning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

## ppo_train.py

An example using stable-baselines to train a ppo2 agent using the gym-donkeycar environment
An example using stable-baselines3 to train a PPO agent using the gym-donkeycar environment

* follow [stable-baselines](https://github.com/hill-a/stable-baselines) install
* follow the [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) installation instructions
* ```python gym-donkeycar/examples/reinforcement_learning/ppo_train.py --sim <path to simulator>```

## ddqn.py

An example training a [deep double Q-learning](https://arxiv.org/abs/1509.06461) agent using the gym-donkeycar environment

* ```python gym-donkeycar/examples/reinforcement_learning/ddqn.py --sim <path to simulator>```
* ```python gym-donkeycar/examples/reinforcement_learning/ddqn.py --sim <path to simulator>```
49 changes: 13 additions & 36 deletions examples/reinforcement_learning/ppo_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,7 @@
import uuid

import gym
from stable_baselines import PPO2
from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv


def make_env(env_id, rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environment you wish to have in subprocesses
:param seed: (int) the inital seed for RNG
:param rank: (int) index of the subprocess
"""

def _init():
env = gym.make(env_id)
env.seed(seed + rank)
env.reset()
return env

set_global_seeds(seed)
return _init

from stable_baselines3 import PPO

if __name__ == "__main__":

Expand Down Expand Up @@ -93,15 +69,16 @@ def _init():

# Make an environment to test our trained policy
env = gym.make(args.env_name, conf=conf)
env = DummyVecEnv([lambda: env])

model = PPO2.load("ppo_donkey")
model = PPO.load("ppo_donkey")

obs = env.reset()
for i in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
for _ in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
if done:
obs = env.reset()

print("done testing")

Expand All @@ -110,11 +87,8 @@ def _init():
# make gym env
env = gym.make(args.env_name, conf=conf)

# Create the vectorized environment
env = DummyVecEnv([lambda: env])

# create cnn policy
model = PPO2(CnnPolicy, env, verbose=1)
model = PPO("CnnPolicy", env, verbose=1)

# set up model in learning mode with goal number of timesteps to complete
model.learn(total_timesteps=10000)
Expand All @@ -123,16 +97,19 @@ def _init():

for i in range(1000):

action, _states = model.predict(obs)
action, _states = model.predict(obs, deterministic=True)

obs, rewards, dones, info = env.step(action)
obs, reward, done, info = env.step(action)

try:
env.render()
except Exception as e:
print(e)
print("failure in render, continuing...")

if done:
obs = env.reset()

if i % 100 == 0:
print("saving...")
model.save("ppo_donkey")
Expand Down
6 changes: 3 additions & 3 deletions examples/supervised_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def generator(samples, batch_size=32, perc_to_augment=0.5):

try:
image = Image.open(fullpath)
except: # noqa: E722
except: # noqa: E722, B001
image = None

if image is None:
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_files(filemask):
path, mask = os.path.split(filemask)

matches = []
for root, dirnames, filenames in os.walk(path):
for root, _, filenames in os.walk(path):
for filename in fnmatch.filter(filenames, mask):
matches.append(os.path.join(root, filename))
return matches
Expand Down Expand Up @@ -244,7 +244,7 @@ def go(model_name, epochs=50, inputs="./log/*.jpg", limit=None):
plt.legend(["train", "test"], loc="upper left")
plt.savefig(model_name + "loss.png")
plt.show()
except: # noqa: E722
except: # noqa: E722, B001
print("problems with loss graph")


Expand Down
18 changes: 10 additions & 8 deletions gym_donkeycar/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import socket
import time
from threading import Thread
from typing import Any, Dict

from .util import replace_float_notation

logger = logging.getLogger(__name__)


class SDClient:
def __init__(self, host, port, poll_socket_sleep_time=0.001):
def __init__(self, host: str, port: int, poll_socket_sleep_time: float = 0.001):
self.msg = None
self.host = host
self.port = port
Expand All @@ -32,9 +33,10 @@ def __init__(self, host, port, poll_socket_sleep_time=0.001):
# the aborted flag will be set when we have detected a problem with the socket
# that we can't recover from.
self.aborted = False
self.s = None
self.connect()

def connect(self):
def connect(self) -> None:
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# connecting to the server
Expand All @@ -54,17 +56,17 @@ def connect(self):
self.th = Thread(target=self.proc_msg, args=(self.s,), daemon=True)
self.th.start()

def send(self, m):
def send(self, m: str) -> None:
self.msg = m

def send_now(self, msg):
def send_now(self, msg: str) -> None:
logger.debug("send_now:" + msg)
self.s.sendall(msg.encode("utf-8"))

def on_msg_recv(self, j):
def on_msg_recv(self, j: Dict[str, Any]) -> None:
logger.debug("got:" + j["msg_type"])

def stop(self):
def stop(self) -> None:
# signal proc_msg loop to stop, then wait for thread to finish
# close socket
self.do_process_msgs = False
Expand All @@ -73,15 +75,15 @@ def stop(self):
if self.s is not None:
self.s.close()

def proc_msg(self, sock): # noqa: C901
def proc_msg(self, sock: socket.socket) -> None: # noqa: C901
"""
This is the thread message loop to process messages.
We will send any message that is queued via the self.msg variable
when our socket is in a writable state.
And we will read any messages when it's in a readable state and then
call self.on_msg_recv with the json object message.
"""
sock.setblocking(0)
sock.setblocking(False)
inputs = [sock]
outputs = [sock]
localbuffer = ""
Expand Down
16 changes: 8 additions & 8 deletions gym_donkeycar/core/fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ class FPSTimer(object):
Every N on_frame events, give the average iterations per interval.
"""

def __init__(self, N=100):
self.t = time.time()
def __init__(self, N: int = 100):
self.last_time = time.time()
self.iter = 0
self.N = N

def reset(self):
self.t = time.time()
def reset(self) -> None:
self.last_time = time.time()
self.iter = 0

def on_frame(self):
def on_frame(self) -> None:
self.iter += 1
if self.iter == self.N:
e = time.time()
print("fps", float(self.N) / (e - self.t))
self.t = time.time()
current_time = time.time()
print(f"fps {float(self.N) / (current_time - self.last_time):.2f}")
self.last_time = time.time()
self.iter = 0
11 changes: 7 additions & 4 deletions gym_donkeycar/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
Base class for a handler expected by SimClient
"""
from typing import Any, Dict

from gym_donkeycar.core.client import SDClient


class IMesgHandler(object):
def on_connect(self, client):
def on_connect(self, client: SDClient) -> None:
pass

def on_recv_message(self, message):
def on_recv_message(self, message: Dict[str, Any]) -> None:
pass

def on_close(self):
def on_close(self) -> None:
pass

def on_disconnect(self):
def on_disconnect(self) -> None:
pass
19 changes: 11 additions & 8 deletions gym_donkeycar/core/sim_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
notes: wraps a tcp socket client with a handler to talk to the unity donkey simulator
"""
import json
from typing import Any, Dict, Tuple

from gym_donkeycar.core.message import IMesgHandler

from .client import SDClient

Expand All @@ -14,7 +17,7 @@ class SimClient(SDClient):
Handles messages from a single TCP client.
"""

def __init__(self, address, msg_handler):
def __init__(self, address: Tuple[str, int], msg_handler: IMesgHandler):
# we expect an IMesgHandler derived handler
# assert issubclass(msg_handler, IMesgHandler)

Expand All @@ -27,29 +30,29 @@ def __init__(self, address, msg_handler):
# we connect right away
msg_handler.on_connect(self)

def send_now(self, msg):
def send_now(self, msg: Dict[str, Any]) -> None:
# takes a dict input msg, converts to json string
# and sends immediately. right now, no queue.
json_msg = json.dumps(msg)
super().send_now(json_msg)

def queue_message(self, msg):
def queue_message(self, msg: Dict[str, Any]) -> None:
# takes a dict input msg, converts to json string
# and adds to a lossy queue that sends only the last msg
json_msg = json.dumps(msg)
self.send(json_msg)

def on_msg_recv(self, jsonObj):
def on_msg_recv(self, json_obj: Dict[str, Any]) -> None:
# pass message on to handler
self.msg_handler.on_recv_message(jsonObj)
self.msg_handler.on_recv_message(json_obj)

def is_connected(self):
def is_connected(self) -> bool:
return not self.aborted

def __del__(self):
def __del__(self) -> None:
self.close()

def close(self):
def close(self) -> None:
# Called to close client connection
self.stop()

Expand Down
Loading

0 comments on commit 6d9496e

Please sign in to comment.