From 1016ab5935c4f40f5381ffefec3c6a3de0451eac Mon Sep 17 00:00:00 2001 From: Sam Triest Date: Thu, 2 Jun 2022 11:26:09 -0400 Subject: [PATCH] images in dataset, ros costmapping node, resnetCNN-based costmaps --- maxent_irl_costmaps/algos/mppi_irl.py | 27 +++- .../dataset/maxent_irl_dataset.py | 20 ++- maxent_irl_costmaps/networks/resnet.py | 2 +- maxent_irl_costmaps/preprocess.py | 67 +++++++- scripts/ros/gridmap_to_costmap_node.py | 60 ++++++- scripts/ros/gridmap_to_height_costmap.py | 74 +++++++++ scripts/synthetic_data/make_fake_trail.py | 147 ++++++++++++++++++ scripts/training/train_ackermann_steer.py | 60 +++++++ scripts/training/train_skid_steer.py | 30 ++-- 9 files changed, 455 insertions(+), 32 deletions(-) create mode 100644 scripts/ros/gridmap_to_height_costmap.py create mode 100644 scripts/synthetic_data/make_fake_trail.py create mode 100644 scripts/training/train_ackermann_steer.py diff --git a/maxent_irl_costmaps/algos/mppi_irl.py b/maxent_irl_costmaps/algos/mppi_irl.py index 59db95a..a1b71b6 100644 --- a/maxent_irl_costmaps/algos/mppi_irl.py +++ b/maxent_irl_costmaps/algos/mppi_irl.py @@ -40,12 +40,13 @@ def __init__(self, expert_dataset, mppi, batch_size=64): """ self.expert_dataset = expert_dataset self.mppi = mppi - self.mppi_itrs = 5 + self.mppi_itrs = 10 # mlp_hiddens = [128, 128] # self.network = MLP(insize = len(expert_dataset.feature_keys), outsize=1, hiddens=mlp_hiddens) - hiddens = [] + hiddens = [128,] +# hiddens = [] self.network = ResnetCostmapCNN(in_channels=len(expert_dataset.feature_keys), out_channels=1, hidden_channels=hiddens) print(sum([x.numel() for x in self.network.parameters()])) @@ -97,6 +98,14 @@ def gradient_step(self, batch): HACK = {"state":initial_state, "steer_angle":torch.zeros(1)} x = self.mppi.model.get_observations(HACK) + map_params = { + 'resolution': map_metadata['resolution'].item(), + 'height': map_metadata['height'].item(), + 'width': map_metadata['width'].item(), + 'origin': map_metadata['origin'] + } + self.mppi.reset() + self.mppi.cost_fn.update_map_params(map_params) self.mppi.cost_fn.update_costmap(costmap) self.mppi.cost_fn.update_goal(expert_traj[-1, :2]) @@ -148,6 +157,11 @@ def visualize(self): map_features = data['map_features'][0] map_metadata = {k:v[0] for k,v in data['metadata'].items()} + metadata = data['metadata'] + xmin = metadata['origin'][0, 0] + ymin = metadata['origin'][0, 1] + xmax = xmin + metadata['width'][0] + ymax = ymin + metadata['height'][0] expert_traj = data['traj'][0] #compute costmap @@ -162,6 +176,15 @@ def visualize(self): initial_state = expert_traj[0] HACK = {"state":initial_state, "steer_angle":torch.zeros(1)} x = self.mppi.model.get_observations(HACK) + + map_params = { + 'resolution': map_metadata['resolution'].item(), + 'height': map_metadata['height'].item(), + 'width': map_metadata['width'].item(), + 'origin': map_metadata['origin'] + } + self.mppi.reset() + self.mppi.cost_fn.update_map_params(map_params) self.mppi.cost_fn.update_costmap(costmap) self.mppi.cost_fn.update_goal(expert_traj[-1, :2]) diff --git a/maxent_irl_costmaps/dataset/maxent_irl_dataset.py b/maxent_irl_costmaps/dataset/maxent_irl_dataset.py index 2230e41..296ed55 100644 --- a/maxent_irl_costmaps/dataset/maxent_irl_dataset.py +++ b/maxent_irl_costmaps/dataset/maxent_irl_dataset.py @@ -18,7 +18,7 @@ class MaxEntIRLDataset(Dataset): Ok, the ony diff now is that there are multiple bag files and we save the trajdata to a temporary pt file. 
""" - def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', odom_topic='/integrated_to_init', horizon=70, dt=0.1, fill_value=0.): + def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', odom_topic='/integrated_to_init', image_topic='/multisense/left/image_rect_color', horizon=70, dt=0.1, fill_value=0.): """ Args: bag_fp: The bag to get data from @@ -33,6 +33,7 @@ def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', o self.preprocess_fp = preprocess_fp self.map_features_topic = map_features_topic self.odom_topic = odom_topic + self.image_topic = image_topic self.horizon = horizon self.dt = dt self.fill_value = fill_value #I don't know if this is the best way to do this, but setting the fill value to 0 implies that missing features contribute nothing to the cost. @@ -40,9 +41,6 @@ def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', o self.initialize_dataset() -# self.dataset, self.feature_keys = load_data(bag_fp, self.map_features_topic, self.odom_topic, self.horizon, self.dt, self.fill_value) -# self.normalize_map_features() - def initialize_dataset(self): """ Profile the trajectories in the bag to: @@ -64,7 +62,7 @@ def initialize_dataset(self): if preprocess: for tfp in os.listdir(self.bag_fp): raw_fp = os.path.join(self.bag_fp, tfp) - data, feature_keys = load_data(raw_fp, self.map_features_topic, self.odom_topic, self.horizon, self.dt, self.fill_value) + data, feature_keys = load_data(raw_fp, self.map_features_topic, self.odom_topic, self.image_topic, self.horizon, self.dt, self.fill_value) if data is None: continue for i in range(len(data['traj'])): @@ -80,7 +78,6 @@ def initialize_dataset(self): torch.save(subdata, pp_fp) self.N += 1 - #Actually read all the data to get statistics. #need number of trajs, and the mean/std of all the map features. 
@@ -111,14 +108,14 @@ def initialize_dataset(self): self.feature_var = var_new K += k - self.feature_std = self.feature_var.sqrt() + self.feature_std = self.feature_var.sqrt() + 1e-6 self.feature_std[~torch.isfinite(self.feature_std)] = 1e-6 def visualize(self): """ Get a rough sense of features """ - n_panes = len(self.feature_keys) + n_panes = len(self.feature_keys) + 1 nx = int(np.sqrt(n_panes)) ny = int(n_panes / nx) + 1 @@ -134,12 +131,19 @@ def visualize(self): ymin = metadata['origin'][1] xmax = xmin + metadata['width'] ymax = ymin + metadata['height'] + for ax, feat, feat_key in zip(axs, feats, self.feature_keys): ax.imshow(feat, origin='lower', cmap='gray', extent=(xmin, xmax, ymin, ymax)) ax.plot(traj[:, 0], traj[:, 1], c='y') ax.set_title(feat_key) + if 'image' in data.keys(): + image = data['image'] + ax = axs[len(self.feature_keys)] + ax.imshow(image.permute(1, 2, 0)) + ax.set_title('Image') + return fig, axs def __len__(self): diff --git a/maxent_irl_costmaps/networks/resnet.py b/maxent_irl_costmaps/networks/resnet.py index 515ba98..7545df8 100644 --- a/maxent_irl_costmaps/networks/resnet.py +++ b/maxent_irl_costmaps/networks/resnet.py @@ -31,7 +31,7 @@ def __init__(self, in_channels, out_channels, hidden_channels, hidden_activation def forward(self, x): cnn_out = self.cnn.forward(x) - return cnn_out + return cnn_out.sigmoid() class ResnetCostmapBlock(nn.Module): """ diff --git a/maxent_irl_costmaps/preprocess.py b/maxent_irl_costmaps/preprocess.py index 24309e1..6c644f6 100644 --- a/maxent_irl_costmaps/preprocess.py +++ b/maxent_irl_costmaps/preprocess.py @@ -2,13 +2,16 @@ import torch import numpy as np import scipy.interpolate, scipy.spatial - -def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): +import cv2 + +def load_data(bag_fp, map_features_topic, odom_topic, image_topic, horizon, dt, fill_value): """ Extract map features and trajectory data from the bag. """ + print(bag_fp) map_features_list = [] traj = [] + vels = [] timestamps = [] dataset = [] @@ -26,12 +29,24 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): pose.orientation.w, ]) + twist = msg.twist.twist + v = np.array([ + twist.linear.x, + twist.linear.y, + twist.linear.z, + twist.angular.x, + twist.angular.y, + twist.angular.z, + ]) + traj.append(p) + vels.append(v) timestamps.append(msg.header.stamp.to_sec()) elif topic == map_features_topic: map_features_list.append(msg) traj = np.stack(traj, axis=0) + vels = np.stack(vels, axis=0) timestamps = np.array(timestamps) #edge case check @@ -51,6 +66,16 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): rots = scipy.spatial.transform.Rotation.from_quat(traj[:, 3:]) interp_q = scipy.spatial.transform.Slerp(timestamps[idxs], rots[idxs]) + interp_vx = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 0]) + interp_vy = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 1]) + interp_vz = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 2]) + + interp_wx = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 3]) + interp_wy = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 4]) + interp_wz = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 5]) + + map_target_times = [] + #get a registered trajectory for each map. 
for i, map_features in enumerate(map_features_list): print('{}/{}'.format(i+1, len(map_features_list)), end='\r') @@ -82,14 +107,26 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): start_time = map_features.info.header.stamp.to_sec() targets = start_time + np.arange(horizon) * dt + map_target_times.append(start_time) xs = interp_x(targets) ys = interp_y(targets) zs = interp_z(targets) qs = interp_q(targets).as_quat() + vxs = interp_vx(targets) + vys = interp_vy(targets) + vzs = interp_vz(targets) + wxs = interp_wx(targets) + wys = interp_wy(targets) + wzs = interp_wz(targets) + #handle transforms to deserialize map/costmap - traj = np.concatenate([np.stack([xs, ys, zs], axis=-1), qs], axis=-1) + traj = np.concatenate([ + np.stack([xs, ys, zs], axis=-1), + qs, + np.stack([vxs, vys, vzs, wxs, wys, wzs], axis=-1) + ], axis=-1) map_metadata = map_features.info xmin = map_metadata.pose.position.x - 0.5 * (map_metadata.length_x) @@ -110,6 +147,7 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): 'map_features': torch.tensor(map_feature_data).float(), 'metadata': metadata_out } + dataset.append(data) #convert from gridmap to occgrid metadata @@ -117,7 +155,28 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value): dataset = { 'map_features':[x['map_features'] for x in dataset], 'traj':[x['traj'] for x in dataset], - 'metadata':[x['metadata'] for x in dataset] + 'metadata':[x['metadata'] for x in dataset], } + #If image topic exists, add to bag + if image_topic is not None: + image_timestamps = [] + for topic, msg, t in bag.read_messages(topics=[image_topic]): + image_timestamps.append(t.to_sec()) + image_timestamps = np.array(image_timestamps) + #get closest image to targets + dists = np.abs(np.expand_dims(image_timestamps, axis=0) - np.expand_dims(map_target_times, axis=1)) + image_targets = np.argmin(dists, axis=1) + + images = [] + for i, (topic, msg, t) in enumerate(bag.read_messages(topics=[image_topic])): + n_hits = np.sum(image_targets == i) + for j in range(n_hits): + img = np.frombuffer(msg.data, dtype=np.uint8) + img = cv2.imdecode(img, cv2.IMREAD_UNCHANGED) + img = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_AREA) + images.append(torch.tensor(img).permute(2, 0, 1)[[2, 1, 0]] / 255.) + + dataset['image'] = images + return dataset, feature_keys diff --git a/scripts/ros/gridmap_to_costmap_node.py b/scripts/ros/gridmap_to_costmap_node.py index 5d1506e..5db2ec7 100644 --- a/scripts/ros/gridmap_to_costmap_node.py +++ b/scripts/ros/gridmap_to_costmap_node.py @@ -4,7 +4,7 @@ import numpy as np import torch -from nav_msgs.msg import OccupancyGrid +from nav_msgs.msg import OccupancyGrid, Odometry from grid_map_msgs.msg import GridMap from rosbag_to_dataset.dtypes.gridmap import GridMapConvert @@ -13,16 +13,58 @@ class CostmapperNode: """ Node that listens to gridmaps from perception and uses IRL nets to make them into costmaps """ - def __init__(self, grid_map_topic, cost_map_topic): + def __init__(self, grid_map_topic, cost_map_topic, odom_topic, dataset, network): + """ + Args: + grid_map_topic: the topic to get map features from + cost_map_topic: The topic to publish costmaps to + odom_topic: The topic to get height from + dataset: The dataset that the network was trained on. (Need to get feature mean/var) + network: the network to produce costmaps. 
+ """ + self.feature_keys = dataset.feature_keys + self.feature_mean = dataset.feature_mean + self.feature_std = dataset.feature_std + self.map_metadata = dataset.metadata + self.network = network + self.current_height = 0. + + #we will set the output resolution dynamically + self.grid_map_cvt = GridMapConvert(channels=self.feature_keys, size=[1, 1]) + self.grid_map_sub = rospy.Subscriber(grid_map_topic, GridMap, self.handle_grid_map, queue_size=1) + self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.handle_odom, queue_size=1) self.cost_map_pub = rospy.Publisher(cost_map_topic, OccupancyGrid, queue_size=1) - self.grid_map_cvt = GridMapConvert(channels=['diff'], output_resolution=[120, 120]) + + def handle_odom(self, msg): + self.current_height = msg.pose.pose.position.z def handle_grid_map(self, msg): rospy.loginfo('handling gridmap...') + nx = int(msg.info.length_x / msg.info.resolution) + ny = int(msg.info.length_y / msg.info.resolution) + self.grid_map_cvt.size = [nx, ny] gridmap = self.grid_map_cvt.ros_to_numpy(msg) - gridmap[~np.isfinite(gridmap)] = 0. - costmap = (gridmap[0] > 1.3).astype(np.uint8) * 100 + + rospy.loginfo_throttle(1.0, "output shape: {}".format(gridmap['data'].shape)) + + map_feats = torch.from_numpy(gridmap['data']).float() + for k in self.feature_keys: + if 'height' in k or 'terrain' in k: + idx = self.feature_keys.index(k) + map_feats[idx] -= self.current_height + + map_feats[~torch.isfinite(map_feats)] = 0. + map_feats[map_feats.abs() > 100.] = 0. + + map_feats_norm = (map_feats - self.feature_mean.view(-1, 1, 1)) / self.feature_std.view(-1, 1, 1) + with torch.no_grad(): + costmap = self.network.forward(map_feats_norm.view(1, *map_feats_norm.shape))[0] + + #experiment w/ normalizing + rospy.loginfo_throttle(1.0, "min = {}, max = {}".format(costmap.min(), costmap.max())) + costmap = (costmap - costmap.min()) / (costmap.max() - costmap.min()) + costmap = (costmap * 100.).long().numpy() costmap_msg = OccupancyGrid() costmap_msg.header.stamp = msg.info.header.stamp @@ -33,7 +75,7 @@ def handle_grid_map(self, msg): costmap_msg.info.origin.position.x = msg.info.pose.position.x - msg.info.length_x/2. costmap_msg.info.origin.position.y = msg.info.pose.position.y - msg.info.length_y/2. - costmap_msg.data = costmap[::-1, ::-1].flatten() + costmap_msg.data = costmap.flatten() self.cost_map_pub.publish(costmap_msg) @@ -42,7 +84,11 @@ def handle_grid_map(self, msg): grid_map_topic = '/local_gridmap' cost_map_topic = '/local_cost_map_final_occupancy_grid' + odom_topic = '/integrated_to_init' + mppi_irl = torch.load('../training/ackermann_costmaps/baseline2.pt') + +# mppi_irl.visualize() - costmapper = CostmapperNode(grid_map_topic, cost_map_topic) + costmapper = CostmapperNode(grid_map_topic, cost_map_topic, odom_topic, mppi_irl.expert_dataset, mppi_irl.network) rospy.spin() diff --git a/scripts/ros/gridmap_to_height_costmap.py b/scripts/ros/gridmap_to_height_costmap.py new file mode 100644 index 0000000..b274131 --- /dev/null +++ b/scripts/ros/gridmap_to_height_costmap.py @@ -0,0 +1,74 @@ +#! 
/usr/bin/python3 + +import rospy +import numpy as np +import torch + +from nav_msgs.msg import OccupancyGrid, Odometry +from grid_map_msgs.msg import GridMap + +from rosbag_to_dataset.dtypes.gridmap import GridMapConvert + +class CostmapperNode: + """ + Node that listens to gridmaps from perception and uses IRL nets to make them into costmaps + """ + def __init__(self, grid_map_topic, cost_map_topic, odom_topic): + """ + Args: + grid_map_topic: the topic to get map features from + cost_map_topic: The topic to publish costmaps to + odom_topic: The topic to get height from + dataset: The dataset that the network was trained on. (Need to get feature mean/var) + network: the network to produce costmaps. + """ + self.current_height = 0. + + #we will set the output resolution dynamically + self.grid_map_cvt = GridMapConvert(channels=['diff'], size=[1, 1]) + + self.grid_map_sub = rospy.Subscriber(grid_map_topic, GridMap, self.handle_grid_map, queue_size=1) + self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.handle_odom, queue_size=1) + self.cost_map_pub = rospy.Publisher(cost_map_topic, OccupancyGrid, queue_size=1) + + def handle_odom(self, msg): + self.current_height = msg.pose.pose.position.z + + def handle_grid_map(self, msg): + rospy.loginfo('handling gridmap...') + nx = int(msg.info.length_x / msg.info.resolution) + ny = int(msg.info.length_y / msg.info.resolution) + self.grid_map_cvt.size = [nx, ny] + gridmap = self.grid_map_cvt.ros_to_numpy(msg) + + rospy.loginfo_throttle(1.0, "output shape: {}".format(gridmap['data'].shape)) + + map_feats = torch.from_numpy(gridmap['data']).float() + + #geometrically determine cost + costmap = map_feats[0] > 1.0 + costmap = (costmap * 100.).long().numpy() + + costmap_msg = OccupancyGrid() + costmap_msg.header.stamp = msg.info.header.stamp + costmap_msg.header.frame_id = msg.info.header.frame_id + costmap_msg.info.resolution = msg.info.resolution + costmap_msg.info.width = int(msg.info.length_x / msg.info.resolution) + costmap_msg.info.height = int(msg.info.length_y / msg.info.resolution) + costmap_msg.info.origin.position.x = msg.info.pose.position.x - msg.info.length_x/2. + costmap_msg.info.origin.position.y = msg.info.pose.position.y - msg.info.length_y/2. + + costmap_msg.data = costmap.flatten() + + self.cost_map_pub.publish(costmap_msg) + +if __name__ == '__main__': + rospy.init_node('costmapper_node') + + grid_map_topic = '/local_gridmap' + cost_map_topic = '/local_cost_map_final_occupancy_grid' + odom_topic = '/warty/odom' + + costmapper = CostmapperNode(grid_map_topic, cost_map_topic, odom_topic) + + rospy.spin() diff --git a/scripts/synthetic_data/make_fake_trail.py b/scripts/synthetic_data/make_fake_trail.py new file mode 100644 index 0000000..22af9f3 --- /dev/null +++ b/scripts/synthetic_data/make_fake_trail.py @@ -0,0 +1,147 @@ +import torch +import matplotlib.pyplot as plt +import argparse +import os + +from torch_mpc.models.skid_steer import SkidSteer + +""" +Simple code for generating fake 'trail' features and trajectories. +""" + +def expand_trajs(trajs): + """ + go from x,y,th to + x,y,z,qx,qy,qz,qw + """ + x, y, th = trajs.moveaxis(-1, 0) + z = torch.zeros_like(x) + qw = torch.cos(th/2.) + qx = torch.zeros_like(x) + qy = torch.zeros_like(y) + qz = torch.sin(th/2.) 
+ return torch.stack([x, y, z, qx, qy, qz, qw], axis=-1) + +def get_feature_maps_from_trajs(trajs, metadata, obstacle_radius = 1.0): + """ + Get obstacle and position maps from trajs assuming that everything more than + obstacle radius away from the trajs is obstacle + + Args: + trajs: The trajectories to infer obstacles from + metadata: The metadata of the maps to produce + obstacle_radius: Cells within this distance of any traj will be considered free + """ + feature_keys = [ + 'height_high' + ] + + res = metadata['resolution'] + width = metadata['width'] + height = metadata['height'] + ox = metadata['origin'][0] + oy = metadata['origin'][1] + + xs = ox + torch.arange(start=0., end=width, step=res) + ys = oy + torch.arange(start=0., end=height, step=res) + + positions = torch.meshgrid(xs, ys, indexing='ij') + positions = torch.stack(positions, axis=-1) #[x by y] + + distances = torch.linalg.norm(positions, axis=-1) + pos_x = positions[..., 0] + pos_y = positions[..., 1] + + obstacle_acc = torch.ones_like(distances).bool() + + #iterate through trajectories to save memory + for traj in trajs: + traj_poses = traj[:, :2].unsqueeze(1) #[T x G x 2] + map_poses = positions.view(-1, 2).unsqueeze(0) #[T x G x 2] + dists = (traj_poses - map_poses).pow(2).sum(dim=-1).sqrt().min(dim=0)[0] + dists_map = dists.view(obstacle_acc.shape) + mask = (dists_map > obstacle_radius) + obstacle_acc = obstacle_acc & mask + + obstacles = obstacle_acc.float() + + features = torch.stack([ + obstacles.T + ], axis=0) + + return features, feature_keys + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--both', action='store_true') + parser.add_argument('--save_dir', required=True, help='location to save the temp data') + args = parser.parse_args() + + model = SkidSteer() + + x0 = torch.zeros(model.observation_space().shape[0]) + U1 = torch.zeros(100, 2) + U1[:, 0] = 1.5 + U1[30:50, 1] = 0.2 + + U2 = U1.clone() + U2[:, 1] *= -1. + + + metadata = { + 'width': torch.tensor(30.), + 'height': torch.tensor(30.), + 'resolution': torch.tensor(0.25), + 'origin': torch.tensor([-15., -15.]) + } + + left_trajs = [] + right_trajs = [] + + for i in range(10): + z = torch.randn_like(U1) * 0.05 + traj = model.rollout(x0, U1 + z) + traj2 = model.rollout(x0, U2 + z) + left_trajs.append(traj) + right_trajs.append(traj2) + + + left_trajs = expand_trajs(torch.stack(left_trajs, dim=0)) + right_trajs = expand_trajs(torch.stack(right_trajs, dim=0)) + all_trajs = torch.cat([left_trajs, right_trajs], dim=0) + + feature_maps, feature_keys = get_feature_maps_from_trajs(all_trajs, metadata) + left_feature_maps, feature_keys = get_feature_maps_from_trajs(left_trajs, metadata) + right_feature_maps, feature_keys = get_feature_maps_from_trajs(right_trajs, metadata) + + feature_keys = ['height_high', 'height_high_left', 'height_high_right'] + feature_maps = torch.cat([feature_maps, left_feature_maps, right_feature_maps], axis=0) + + dataset = [] + for traj in left_trajs: + datapt = { + 'metadata': metadata, + 'feature_keys': feature_keys, + 'map_features': feature_maps, + 'traj': traj + } + dataset.append(datapt) + + if args.both: + for traj in right_trajs: + datapt = { + 'metadata': metadata, + 'feature_keys': feature_keys, + 'map_features': feature_maps, + 'traj': traj + } + dataset.append(datapt) + + if os.path.exists(args.save_dir): + x = input('{} already exists. Preprocess again? 
[y/N]'.format(args.save_dir)) + else: + os.mkdir(args.save_dir) + + for i, x in enumerate(dataset): + fp = os.path.join(args.save_dir, 'traj_{}.pt'.format(i)) + torch.save(x, fp) diff --git a/scripts/training/train_ackermann_steer.py b/scripts/training/train_ackermann_steer.py new file mode 100644 index 0000000..ba4a85b --- /dev/null +++ b/scripts/training/train_ackermann_steer.py @@ -0,0 +1,60 @@ +import torch +import os +import argparse +import matplotlib.pyplot as plt + +from torch_mpc.models.steer_setpoint_kbm import SteerSetpointKBM +from torch_mpc.algos.mppi import MPPI +from torch_mpc.cost_functions.waypoint_costmap import WaypointCostMapCostFunction + +from maxent_irl_costmaps.dataset.maxent_irl_dataset import MaxEntIRLDataset +from maxent_irl_costmaps.algos.mppi_irl import MPPIIRL + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_fp', type=str, required=True, help='dir to save experiment results') + parser.add_argument('--rosbag_dir', type=str, required=True, help='dir for rosbags to train from') + parser.add_argument('--preprocess_dir', type=str, required=True, help='dir to save preprocessed data to') + parser.add_argument('--map_topic', type=str, required=False, default='/local_gridmap', help='topic to extract map features from') + parser.add_argument('--odom_topic', type=str, required=False, default='/integrated_to_init', help='topic to extract odom from') + parser.add_argument('--image_topic', type=str, required=False, default='/multisense/left/image_rect_color', help='topic to extract images from') + parser.add_argument('--epochs', type=int, required=False, default=10, help='number of epochs to run') + parser.add_argument('--batch_size', type=int, required=False, default=64, help='batch size') + parser.add_argument('--horizon', type=int, required=False, default=70, help='number of mppi steps to optimize over') + parser.add_argument('--model_fp', type=str, required=False, default=None, help='fp to old experiment if testing/fine tuning') + parser.add_argument('--test', action='store_true') + args = parser.parse_args() + + dataset = MaxEntIRLDataset(bag_fp=args.rosbag_dir, preprocess_fp=args.preprocess_dir, map_features_topic=args.map_topic, odom_topic=args.odom_topic, image_topic=args.image_topic, horizon=int(args.horizon) * 1.0) + +# for i in range(10): +# dataset.visualize() +# plt.show() + + vmin = 1.0 + vmax = 8.0 + wmax = 0.5 + model = SteerSetpointKBM(L=3.0, v_target_lim=[vmin, vmax], steer_lim=[-wmax, wmax], steer_rate_lim=0.5) + parameters = { + 'log_K_delta':torch.tensor(10.0), + 'log_K_v':torch.tensor(1.0) + } + model.update_parameters(parameters) + + cfn = WaypointCostMapCostFunction(unknown_cost=0., goal_cost=5.0, map_params=dataset.metadata) + mppi = MPPI(model=model, cost_fn=cfn, num_samples=2048, num_timesteps=args.horizon, control_params={'sys_noise':torch.tensor([0.2 * vmax, 0.2 * wmax]), 'temperature':0.2}) + + if args.model_fp: + mppi_irl = torch.load(args.model_fp) + else: + mppi_irl = MPPIIRL(dataset, mppi, args.batch_size) + mppi_irl.mppi_itrs = 20 + +# mppi_irl.visualize() + + if not args.test: + for ei in range(args.epochs): + mppi_irl.update() + torch.save(mppi_irl, 'ackermann_costmaps/baseline3.pt') + + mppi_irl.visualize() diff --git a/scripts/training/train_skid_steer.py b/scripts/training/train_skid_steer.py index 1079a6d..6b30b3e 100644 --- a/scripts/training/train_skid_steer.py +++ b/scripts/training/train_skid_steer.py @@ -17,29 +17,39 @@ parser.add_argument('--preprocess_dir', type=str, 
required=True, help='dir to save preprocessed data to') parser.add_argument('--map_topic', type=str, required=False, default='/local_gridmap', help='topic to extract map features from') parser.add_argument('--odom_topic', type=str, required=False, default='/warty/odom', help='topic to extract odom from') + parser.add_argument('--image_topic', type=str, required=False, default='/warty/stereo_left/image_rect_color/compressed', help='topic to extract image from') parser.add_argument('--epochs', type=int, required=False, default=10, help='number of epochs to run') parser.add_argument('--batch_size', type=int, required=False, default=64, help='batch size') parser.add_argument('--horizon', type=int, required=False, default=70, help='number of mppi steps to optimize over') + parser.add_argument('--model_fp', type=str, required=False, default=None, help='fp to old experiment if testing/fine tuning') + parser.add_argument('--test', action='store_true') args = parser.parse_args() - dataset = MaxEntIRLDataset(bag_fp=args.rosbag_dir, preprocess_fp=args.preprocess_dir, map_features_topic=args.map_topic, odom_topic=args.odom_topic, horizon=int(args.horizon) * 1.5) + dataset = MaxEntIRLDataset(bag_fp=args.rosbag_dir, preprocess_fp=args.preprocess_dir, map_features_topic=args.map_topic, odom_topic=args.odom_topic, image_topic=args.image_topic, horizon=int(args.horizon) * 1.0) -# for i in range(10): -# dataset.visualize() -# plt.show() + for i in range(10): + dataset.visualize() + plt.show() - model = SkidSteer(v_lim=[0.5, 3.0], w_lim=[-1.5, 1.5]) + vmax = 2.0 + wmax = 1.5 + model = SkidSteer(v_lim=[-vmax, vmax], w_lim=[-wmax, wmax]) - cfn = WaypointCostMapCostFunction(unknown_cost=0., goal_cost=0., map_params=dataset.metadata) - mppi = MPPI(model=model, cost_fn=cfn, num_samples=2048, num_timesteps=args.horizon, control_params={'sys_noise':torch.tensor([1.0, 0.5]), 'temperature':0.05}) + cfn = WaypointCostMapCostFunction(unknown_cost=0., goal_cost=10.0, map_params=dataset.metadata) + mppi = MPPI(model=model, cost_fn=cfn, num_samples=2048, num_timesteps=args.horizon, control_params={'sys_noise':torch.tensor([0.2 * vmax, 0.2 * wmax]), 'temperature':0.2}) - mppi_irl = MPPIIRL(dataset, mppi, args.batch_size) + if args.model_fp: + mppi_irl = torch.load(args.model_fp) + else: + mppi_irl = MPPIIRL(dataset, mppi, args.batch_size) + mppi_irl.mppi_itrs = 20 + + mppi_irl.visualize() for ei in range(args.epochs): mppi_irl.update() - - torch.save(mppi_irl.network, 'baseline.pt') + torch.save(mppi_irl, 'baseline.pt') # if ((ei+1) % 5) == 0: # mppi_irl.visualize() mppi_irl.visualize()
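
Note (not part of the patch): below is a minimal standalone sketch of the costmap inference step that gridmap_to_costmap_node.py now performs in handle_grid_map: zero out non-finite GridMap cells, normalize each feature channel with the dataset's per-channel mean/std, run the sigmoid-output costmap network, and rescale the result to the 0-100 integer range expected by nav_msgs/OccupancyGrid. The function name gridmap_to_costmap and the dummy 1x1-conv network are hypothetical stand-ins introduced here for illustration; in the node, the network and feature statistics come from the trained MPPIIRL checkpoint (mppi_irl.network, dataset.feature_mean, dataset.feature_std).

import torch

def gridmap_to_costmap(map_feats, feature_mean, feature_std, network):
    """map_feats: [C x H x W] tensor of GridMap layers; NaN/inf cells allowed."""
    map_feats = map_feats.clone()
    map_feats[~torch.isfinite(map_feats)] = 0.    # missing cells contribute nothing
    map_feats[map_feats.abs() > 100.] = 0.        # drop implausible feature values
    feats_norm = (map_feats - feature_mean.view(-1, 1, 1)) / feature_std.view(-1, 1, 1)
    with torch.no_grad():
        costmap = network(feats_norm.unsqueeze(0))[0, 0]   # [H x W], in (0, 1) after sigmoid
    costmap = (costmap - costmap.min()) / (costmap.max() - costmap.min() + 1e-6)
    return (costmap * 100.).long()                # OccupancyGrid expects integers in [0, 100]

if __name__ == '__main__':
    # random stand-ins for the trained ResnetCostmapCNN and dataset statistics
    C, H, W = 5, 120, 120
    dummy_net = torch.nn.Sequential(torch.nn.Conv2d(C, 1, kernel_size=1), torch.nn.Sigmoid())
    out = gridmap_to_costmap(torch.randn(C, H, W), torch.zeros(C), torch.ones(C), dummy_net)
    print(out.shape, out.min().item(), out.max().item())

The per-message min-max rescaling mirrors the "experiment w/ normalizing" step in the node: it keeps the published grid within the OccupancyGrid range, at the cost of making cell costs relative to the current map rather than globally comparable across messages.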