forked from lrjconan/LanczosNetwork
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
51 changed files
with
4,365 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,38 @@ | ||
# LanczosNetwork | ||
# LanczosNetwork | ||
This is the PyTorch implementation of [Lanczos Network](https://arxiv.org/pdf/1803.06396.pdf) as described in the following paper: | ||
|
||
``` | ||
@article{liao2018reviving, | ||
title={Reviving and Improving Recurrent Back-Propagation}, | ||
author={Liao, Renjie and Xiong, Yuwen and Fetaya, Ethan and Zhang, Lisa and Yoon, KiJung and Pitkow, Xaq and Urtasun, Raquel and Zemel, Richard}, | ||
journal={arXiv preprint arXiv:1803.06396}, | ||
year={2018} | ||
} | ||
``` | ||
|
||
## Setup | ||
To set up experiments, we need to build our customized operators by running the following scripts: | ||
``` | ||
./setup.sh | ||
``` | ||
|
||
## Dependencies | ||
Python 3, PyTorch(0.4.0) | ||
|
||
|
||
## Run Demos | ||
* To run experiments ```X``` where ```X``` is one of {```hopfield```, ```cora```, ```pubmed```, ```hypergrad```}: | ||
|
||
```python run_exp.py -c config/X.yaml``` | ||
|
||
|
||
**Notes**: | ||
* Most hyperparameters in the configuration yaml file are self-explanatory. | ||
* To switch between BPTT, TBPTT and RBP variants, you need to specify ```grad_method``` in the config file. | ||
* Conjugate gradient based RBP requires support of forward mode auto-differentiation which we only provided for the experiments of Hopfield networks and graph neural networks (GNNs). You can check the comments in ```model/rbp.py``` for more details. | ||
|
||
## Cite | ||
Please cite our paper if you use this code in your research work. | ||
|
||
## Questions/Bugs | ||
Please submit a Github issue or contact [email protected] if you have any questions or find any bugs. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
--- | ||
exp_name: qm8_ada_lanczos_net | ||
exp_dir: exp/qm8_ada_lanczos_net | ||
runner: QM8Runner | ||
use_gpu: true | ||
gpus: [0] | ||
seed: 1234 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: AdaLanczosNet | ||
short_diffusion_dist: [3, 5, 7] | ||
long_diffusion_dist: [10, 20, 30] | ||
num_eig_vec: 20 | ||
use_reorthogonalization: false | ||
use_power_iteration_cap: true | ||
# spectral_filter_kind: None | ||
spectral_filter_kind: MLP | ||
# input_dim: 64 | ||
input_dim: 70 | ||
hidden_dim: [128, 128, 128, 128, 128, 128, 128] | ||
output_dim: 16 | ||
num_layer: 7 | ||
loss: MSE | ||
output_func: MLP | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
test_model: exp/qm8_ada_lanczos_net/LanczosNetChem_chemistry_2018-Sep-17-23-29-48_32182/model_snapshot_best.pth | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
--- | ||
exp_name: qm8_cheby_net | ||
exp_dir: exp/qm8_cheby_net | ||
runner: QM8Runner | ||
use_gpu: true | ||
gpus: [0] | ||
# seed: 1234 | ||
# seed: 5678 | ||
seed: 9012 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: ChebyNet | ||
input_dim: 64 | ||
hidden_dim: [128, 128, 128, 128, 128, 128, 128] | ||
output_dim: 16 | ||
polynomial_order: 5 | ||
num_layer: 7 | ||
loss: MSE | ||
output_func: MLP | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
# test_model: exp/qm8_cheby_net/ChebyNetChem_chemistry_2018-Sep-26-15-39-38_2519/model_snapshot_best.pth | ||
# test_model: exp/qm8_cheby_net/ChebyNetChem_chemistry_2018-Sep-28-11-52-19_13447/model_snapshot_best.pth | ||
test_model: exp/qm8_cheby_net/ChebyNetChem_chemistry_2018-Sep-28-12-35-37_5959/model_snapshot_best.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
--- | ||
exp_name: qm8_dcnn | ||
exp_dir: exp/qm8_dcnn | ||
runner: QM8Runner | ||
use_gpu: false | ||
gpus: [0] | ||
seed: 1234 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: DCNN | ||
input_dim: 64 | ||
diffusion_dist: [3, 5, 7, 10, 20, 30] | ||
hidden_dim: [128, 128, 128, 128, 128, 128, 128] | ||
output_dim: 16 | ||
num_layer: 7 | ||
loss: MSE | ||
output_func: MLP | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
test_model: exp/qm8_dcnn/DCNNChem_chemistry_2018-Sep-20-22-33-08_21116/model_snapshot_best.pth | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
--- | ||
exp_name: qm8_gat | ||
exp_dir: exp/qm8_gat | ||
runner: QM8Runner | ||
use_gpu: true | ||
gpus: [0] | ||
# seed: 1234 | ||
# seed: 5678 | ||
seed: 9012 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: GAT | ||
input_dim: 64 | ||
hidden_dim: [16, 16, 16, 16, 16, 16, 16] | ||
num_layer: 7 | ||
num_heads: [8, 8, 8, 8, 8, 8, 8] | ||
output_dim: 16 | ||
dropout: 0.0 | ||
loss: MSE | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
# test_model: exp/qm8_gat/GATChem_chemistry_2018-Sep-26-16-28-53_5748/model_snapshot_best.pth | ||
# test_model: exp/qm8_gat/GATChem_chemistry_2018-Sep-30-17-47-57_21569/model_snapshot_best.pth | ||
test_model: exp/qm8_gat/GATChem_chemistry_2018-Oct-01-15-03-58_5262/model_snapshot_best.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
--- | ||
exp_name: qm8_gcn | ||
exp_dir: exp/qm8_gcn | ||
runner: QM8Runner | ||
use_gpu: true | ||
gpus: [0] | ||
# seed: 1234 | ||
# seed: 5678 | ||
seed: 9012 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: GCN | ||
input_dim: 64 | ||
hidden_dim: [128, 128, 128, 128, 128, 128, 128] | ||
output_dim: 16 | ||
num_layer: 7 | ||
loss: MSE | ||
output_func: MLP | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
# test_model: exp/qm8_gcn/GCNChem_chemistry_2018-Sep-26-15-38-02_455/model_snapshot_best.pth | ||
# test_model: exp/qm8_gcn/GCNChem_chemistry_2018-Sep-28-15-35-42_1367/model_snapshot_best.pth | ||
test_model: exp/qm8_gcn/GCNChem_chemistry_2018-Sep-28-16-26-21_3359/model_snapshot_best.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
--- | ||
exp_name: qm8_gcnfp | ||
exp_dir: exp/qm8_gcnfp | ||
runner: QM8Runner | ||
use_gpu: false | ||
gpus: [0] | ||
seed: 1234 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: GCNFP | ||
input_dim: 64 | ||
hidden_dim: [128, 128, 128] | ||
output_dim: 16 | ||
num_layer: 3 | ||
loss: MSE | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
test_model: exp/qm8_gcnfp/GCNFPChem_chemistry_2018-Sep-22-10-57-35_7511/model_snapshot_best.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
--- | ||
exp_name: qm8_ggnn | ||
exp_dir: exp/qm8_ggnn | ||
runner: QM8Runner | ||
use_gpu: true | ||
gpus: [0] | ||
# seed: 1234 | ||
# seed: 5678 | ||
seed: 9012 | ||
dataset: | ||
loader_name: QM8Data | ||
name: chemistry | ||
data_path: data/QM8/preprocess | ||
meta_data_path: data/QM8/QM8_meta.p | ||
num_atom: 70 | ||
num_bond_type: 6 | ||
model: | ||
name: GGNN | ||
num_prop: 15 | ||
input_dim: 64 | ||
hidden_dim: 128 | ||
update_func: GRU | ||
output_dim: 16 | ||
msg_func: MLP | ||
aggregate_type: avg | ||
num_layer: 1 | ||
loss: MSE | ||
train: | ||
optimizer: Adam | ||
lr_decay: 0.1 | ||
lr_decay_steps: [10000] | ||
num_workers: 4 | ||
max_epoch: 200 | ||
batch_size: 64 | ||
display_iter: 100 | ||
snapshot_epoch: 10000 | ||
valid_epoch: 1 | ||
lr: 1.0e-4 | ||
wd: 0.0e-4 | ||
momentum: 0.9 | ||
shuffle: true | ||
is_resume: false | ||
resume_model: None | ||
test: | ||
batch_size: 64 | ||
num_workers: 0 | ||
# test_model: exp/qm8_ggnn/GGNNChem_chemistry_2018-Sep-21-20-14-02_16607/model_snapshot_best.pth | ||
# test_model: exp/qm8_ggnn/GGNNChem_chemistry_2018-Sep-27-17-58-19_26297/model_snapshot_best.pth | ||
test_model: exp/qm8_ggnn/GGNNChem_chemistry_2018-Sep-27-21-28-43_26955/model_snapshot_best.pth |
Oops, something went wrong.