-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Yang Gao <[email protected]>
- Loading branch information
1 parent
04e2281
commit 5ae0f33
Showing
22 changed files
with
2,437 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Ignore Python virtual environment | ||
venv/ | ||
|
||
# Ignore compiled Python files | ||
*.pyc | ||
|
||
# Ignore cache and temporary files | ||
__pycache__/ | ||
*.pyo | ||
*.swp | ||
*.swo | ||
|
||
# Ignore IDE-specific files | ||
.vscode/ | ||
.idea/ | ||
|
||
# Ignore environment-specific files | ||
.env | ||
.env.local | ||
.env.*.local | ||
|
||
# Ignore log files | ||
*.log | ||
|
||
# Ignore package lock files | ||
pip-lock.txt | ||
poetry.lock | ||
|
||
# Ignore generated documentation | ||
docs/_build/ | ||
|
||
# Ignore test coverage reports | ||
htmlcov/ | ||
|
||
# Ignore compiled binaries | ||
*.exe | ||
*.dll | ||
*.so | ||
*.dylib | ||
|
||
# Ignore database files | ||
*.db | ||
|
||
# Ignore generated files | ||
*.pyc | ||
*.pyo | ||
*.pyd | ||
__pycache__/ | ||
|
||
# Ignore cache files | ||
*.cache | ||
|
||
# Ignore system-specific files | ||
.DS_Store | ||
Thumbs.db | ||
*.ndjson | ||
experiments/* | ||
data/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
<div align="center"> | ||
<h1> Social-Transmotion:<br> Promptable Human Trajectory Prediction </h1> | ||
<h3>Saeed Saadatnejad*, Yang Gao*, Kaouther Messaoud, Alexandre Alahi | ||
</h3> | ||
|
||
|
||
<image src="docs/social-transmotion.png" width="600"> | ||
</div> | ||
|
||
<div align="center"> <h3> Abstract </h3> </div> | ||
<div align="justify"> | ||
|
||
Accurate human trajectory prediction is crucial for applications such as autonomous vehicles, robotics, and surveillance systems. Yet, existing models often fail to fully leverage the non-verbal social cues human subconsciously communicate when navigating the space. | ||
To address this, we introduce Social-Transmotion, a generic model that exploits the power of transformers to handle diverse and numerous visual cues, capturing the multi-modal nature of human behavior. We translate the idea of a prompt from Natural Language Processing (NLP) to the task of human trajectory prediction, where a prompt can be a sequence of x-y coordinates on the ground, bounding boxes or body poses. This, in turn, augments trajectory data, leading to enhanced human trajectory prediction. | ||
Our model exhibits flexibility and adaptability by capturing spatiotemporal interactions between pedestrians based on the available visual cues, whether they are poses, bounding boxes, or a combination thereof. | ||
By the masking technique, we ensure our model's effectiveness even when certain visual cues are unavailable, although performance is further boosted with the presence of comprehensive visual data. | ||
</br> | ||
|
||
|
||
# Getting Started | ||
|
||
Install the requirements using `pip`: | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
We have conveniently added the preprocessed data to the release section of the repository. | ||
Place the data subdirectory of JTA under `data/jta_all_visual_cues` and the data subdirectory of JRDB under `data/jrdb_2dbox` of the repository. | ||
|
||
# Training and Testing | ||
|
||
## JTA dataset | ||
You can train the Social-Transmotion model on this dataset using the following command: | ||
``` | ||
python train_jta.py --cfg configs/jta_all_visual_cues.yaml --exp_name jta | ||
``` | ||
|
||
|
||
To evaluate the trained model, use the following command: | ||
``` | ||
python evaluate_jta.py --ckpt ./experiments/jta/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj+all | ||
``` | ||
Please note that the evaluation modality can be any of `[traj, traj+2dbox, traj+3dpose, traj+2dpose, traj+3dpose+3dbox, traj+all]`. | ||
For the ease of use, we have also provided the trained model in the release section of this repo. In order to use that, you should pass the address of the saved checkpoint via `--ckpt`. | ||
|
||
## JRDB dataset | ||
You can train the Social-Transmotion model on this dataset using the following command: | ||
``` | ||
python train_jrdb.py --cfg configs/jrdb_2dbox.yaml --exp_name jrdb | ||
``` | ||
|
||
To evaluate the trained model, use the following command: | ||
``` | ||
python evaluate_jrdb.py --ckpt ./experiments/jrdb/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj+2dbox | ||
``` | ||
Please note that the evaluation modality can be one any of `[traj, traj+2dbox]`. | ||
For the ease of use, we have also provided the trained model in the release section of this repo. In order to use that, you should pass the address of the saved checkpoint via `--ckpt`. | ||
|
||
# Work in Progress | ||
|
||
This repository is work-in-progress and will continue to get updated and improved over the coming months. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
SEED: 0 | ||
TRAIN: | ||
batch_size: 16 | ||
epochs: 100 | ||
num_workers: 0 | ||
input_track_size: 9 | ||
output_track_size: 12 | ||
lr: 0.0001 | ||
lr_decay: 1 | ||
lr_drop: true | ||
aux_weight: 0.2 | ||
val_frequency: 5 | ||
optimizer: "adam" | ||
max_grad_norm: 1.0 | ||
DATA: | ||
train_datasets: | ||
- jrdb_2dbox | ||
MODEL: | ||
seq_len: 30 | ||
token_num: 2 | ||
num_layers_local: 6 | ||
num_layers_global: 3 | ||
num_heads: 4 | ||
dim_hidden: 128 | ||
dim_feedforward: 1024 | ||
type: "transmotion" | ||
eval_single: false | ||
checkpoint: "" | ||
output_scale: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
SEED: 0 | ||
TRAIN: | ||
batch_size: 4 | ||
epochs: 50 | ||
num_workers: 0 | ||
input_track_size: 9 | ||
output_track_size: 12 | ||
lr: 0.0001 | ||
lr_decay: 1 | ||
lr_drop: true | ||
aux_weight: 0.2 | ||
val_frequency: 5 | ||
optimizer: "adam" | ||
max_grad_norm: 1.0 | ||
DATA: | ||
train_datasets: | ||
- jta_all_visual_cues | ||
MODEL: | ||
seq_len: 435 # 1*21 + (token_num-1)*9 ,seq length for local-former, 219 for 2d/3d pose, 30 for 2d/3d bb, 21 for baseline, 228 for 3dbox+3dpose | ||
token_num: 47 # number of tokens for local-former, 23 or 2d/3d pose, 2 for 2d/3d bb, 1 for baseline | ||
num_layers_local: 6 | ||
num_layers_global: 3 | ||
num_heads: 4 | ||
dim_hidden: 128 | ||
dim_feedforward: 1024 | ||
type: "transmotion" | ||
eval_single: false | ||
checkpoint: "" ##checkpoint.pth.tar | ||
output_scale: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import torch | ||
from torch.nn.utils.rnn import pad_sequence | ||
|
||
from utils.data import load_data_jta_all_visual_cues, load_data_jrdb_2dbox | ||
from torchvision import transforms | ||
|
||
def collate_batch(batch): | ||
joints_list = [] | ||
masks_list = [] | ||
num_people_list = [] | ||
for joints, masks in batch: | ||
|
||
joints_list.append(joints) | ||
masks_list.append(masks) | ||
num_people_list.append(torch.zeros(joints.shape[0])) | ||
|
||
joints = pad_sequence(joints_list, batch_first=True) | ||
masks = pad_sequence(masks_list, batch_first=True) | ||
padding_mask = pad_sequence(num_people_list, batch_first=True, padding_value=1).bool() | ||
|
||
return joints, masks, padding_mask | ||
|
||
|
||
def batch_process_coords(coords, masks, padding_mask, config, modality_selection='traj+2dbox', training=False, multiperson=True): | ||
joints = coords.to(config["DEVICE"]) | ||
masks = masks.to(config["DEVICE"]) | ||
in_F = config["TRAIN"]["input_track_size"] | ||
|
||
in_joints_pelvis = joints[:,:, (in_F-1):in_F, 0:1, :].clone() | ||
in_joints_pelvis_last = joints[:,:, (in_F-2):(in_F-1), 0:1, :].clone() | ||
|
||
joints[:,:,:,0] = joints[:,:,:,0] - joints[:,0:1, (in_F-1):in_F, 0] | ||
joints[:,:,:,1:] = (joints[:,:,:,1:] - joints[:,:,(in_F-1):in_F,1:])*0.25 #rescale for BB | ||
|
||
B, N, F, J, K = joints.shape | ||
if not training: | ||
if modality_selection=='traj': | ||
joints[:,:,:,1:]=0 | ||
elif modality_selection=='traj+2dbox': | ||
pass | ||
else: | ||
print('modality error') | ||
exit() | ||
else: | ||
# augment JRDB traj | ||
joints[:,:,:,0,:3] = getRandomRotatePoseTransform(config)(joints[:,:,:,0,:3]) | ||
joints = joints.transpose(1, 2).reshape(B, F, N*J, K) | ||
in_joints_pelvis = in_joints_pelvis.reshape(B, 1, N, K) | ||
in_joints_pelvis_last = in_joints_pelvis_last.reshape(B, 1, N, K) | ||
masks = masks.transpose(1, 2).reshape(B, F, N*J) | ||
|
||
in_F, out_F = config["TRAIN"]["input_track_size"], config["TRAIN"]["output_track_size"] | ||
in_joints = joints[:,:in_F].float() | ||
out_joints = joints[:,in_F:in_F+out_F].float() | ||
in_masks = masks[:,:in_F].float() | ||
out_masks = masks[:,in_F:in_F+out_F].float() | ||
|
||
|
||
return in_joints, in_masks, out_joints, out_masks, padding_mask.float() | ||
|
||
def getRandomRotatePoseTransform(config): | ||
""" | ||
Performs a random rotation about the origin (0, 0, 0) | ||
""" | ||
|
||
def do_rotate(pose_seq): | ||
B, F, J, K = pose_seq.shape | ||
|
||
angles = torch.deg2rad(torch.rand(B)*360) | ||
|
||
rotation_matrix = torch.zeros(B, 3, 3).to(pose_seq.device) | ||
|
||
## rotate around z axis (vertical axis) | ||
rotation_matrix[:,0,0] = torch.cos(angles) | ||
rotation_matrix[:,0,1] = -torch.sin(angles) | ||
rotation_matrix[:,1,0] = torch.sin(angles) | ||
rotation_matrix[:,1,1] = torch.cos(angles) | ||
rotation_matrix[:,2,2] = 1 | ||
|
||
rot_pose = torch.bmm(pose_seq.reshape(B, -1, 3).float(), rotation_matrix) | ||
rot_pose = rot_pose.reshape(pose_seq.shape) | ||
return rot_pose | ||
|
||
return transforms.Lambda(lambda x: do_rotate(x)) | ||
|
||
|
||
|
||
class MultiPersonTrajPoseDataset(torch.utils.data.Dataset): | ||
|
||
|
||
|
||
def __init__(self, name, split="train", track_size=21, track_cutoff=9, segmented=True, | ||
add_flips=False, frequency=1): | ||
|
||
self.name = name | ||
self.split = split | ||
self.track_size = track_size | ||
self.track_cutoff = track_cutoff | ||
self.frequency = frequency | ||
|
||
self.initialize() | ||
|
||
def load_data(self): | ||
raise NotImplementedError("Dataset load_data() method is not implemented.") | ||
|
||
def initialize(self): | ||
self.load_data() | ||
|
||
tracks = [] | ||
for scene in self.datalist: | ||
for seg, j in enumerate(range(0, len(scene[0][0]) - self.track_size * self.frequency + 1, self.track_size)): | ||
people = [] | ||
for person in scene: | ||
start_idx = j | ||
end_idx = start_idx + self.track_size * self.frequency | ||
J_3D_real, J_3D_mask = person[0][start_idx:end_idx:self.frequency], person[1][ | ||
start_idx:end_idx:self.frequency] | ||
people.append((J_3D_real, J_3D_mask)) | ||
tracks.append(people) | ||
self.datalist = tracks | ||
|
||
|
||
def __len__(self): | ||
return len(self.datalist) | ||
|
||
def __getitem__(self, idx): | ||
scene = self.datalist[idx] | ||
|
||
J_3D_real = torch.stack([s[0] for s in scene]) | ||
J_3D_mask = torch.stack([s[1] for s in scene]) | ||
|
||
return J_3D_real, J_3D_mask | ||
|
||
|
||
class JtaAllVisualCuesDataset(MultiPersonTrajPoseDataset): | ||
def __init__(self, **args): | ||
super(JtaAllVisualCuesDataset, self).__init__("jta_all_visual_cues", frequency=1, **args) | ||
|
||
def load_data(self): | ||
|
||
self.data = load_data_jta_all_visual_cues(split=self.split) | ||
self.datalist = [] | ||
for scene in self.data: | ||
joints, mask = scene | ||
people=[] | ||
for n in range(len(joints)): | ||
people.append((torch.from_numpy(joints[n]),torch.from_numpy(mask[n]))) | ||
|
||
self.datalist.append(people) | ||
|
||
class Jrdb2dboxDataset(MultiPersonTrajPoseDataset): | ||
def __init__(self, **args): | ||
super(Jrdb2dboxDataset, self).__init__("jrdb_2dbox", frequency=1, **args) | ||
|
||
def load_data(self): | ||
|
||
self.data = load_data_jrdb_2dbox(split=self.split) | ||
self.datalist = [] | ||
for scene in self.data: | ||
joints, mask = scene | ||
people=[] | ||
for n in range(len(joints)): | ||
people.append((torch.from_numpy(joints[n]),torch.from_numpy(mask[n]))) | ||
|
||
self.datalist.append(people) | ||
|
||
|
||
def create_dataset(dataset_name, logger, **args): | ||
logger.info("Loading dataset " + dataset_name) | ||
|
||
if dataset_name == 'jta_all_visual_cues': | ||
dataset = JtaAllVisualCuesDataset(**args) | ||
elif dataset_name == 'jrdb_2dbox': | ||
dataset = Jrdb2dboxDataset(**args) | ||
else: | ||
raise ValueError(f"Dataset with name '{dataset_name}' not found.") | ||
|
||
return dataset | ||
|
||
|
||
def get_datasets(datasets_list, config, logger): | ||
|
||
in_F, out_F = config['TRAIN']['input_track_size'], config['TRAIN']['output_track_size'] | ||
datasets = [] | ||
for dataset_name in datasets_list: | ||
datasets.append(create_dataset(dataset_name, logger, split="train", track_size=(in_F+out_F), track_cutoff=in_F)) | ||
return datasets | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.