Skip to content

Commit

Permalink
add complete codebase and readme
Browse files Browse the repository at this point in the history
  • Loading branch information
lrjconan committed Jan 3, 2019
1 parent 7e5cec7 commit fdfc8fe
Show file tree
Hide file tree
Showing 23 changed files with 710 additions and 387 deletions.
61 changes: 47 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,68 @@
# LanczosNetwork
This is the PyTorch implementation of [Lanczos Network](https://arxiv.org/pdf/1803.06396.pdf) as described in the following paper:
# Lanczos Network
This is the PyTorch implementation of [Lanczos Network](https://openreview.net/pdf?id=BkedznAqKQ) as described in the following paper:

```
@article{liao2018reviving,
title={Reviving and Improving Recurrent Back-Propagation},
author={Liao, Renjie and Xiong, Yuwen and Fetaya, Ethan and Zhang, Lisa and Yoon, KiJung and Pitkow, Xaq and Urtasun, Raquel and Zemel, Richard},
journal={arXiv preprint arXiv:1803.06396},
year={2018}
@inproceedings{liao2018lanczos,
title={LanczosNet: Multi-Scale Deep Graph Convolutional Networks},
author={Liao, Renjie and Zhao, Zhizhen and Urtasun, Raquel and Zemel, Richard},
booktitle={ICLR},
year={2019}
}
```

We also provide our own implementation of 9 recent graph neural networks on the [QM8](https://arxiv.org/pdf/1504.01966.pdf) benchmark:

* [graph convolution networks for fingerprint](https://papers.nips.cc/paper/5954-convolutional-networks-on-graphs-for-learning-molecular-fingerprints.pdf) (GCNFP)
* [gated graph neural networks](https://arxiv.org/pdf/1511.05493.pdf) (GGNN)
* [diffusion convolutional neural networks](https://arxiv.org/pdf/1511.02136.pdf) (DCNN)
* [Chebyshev networks](https://papers.nips.cc/paper/6081-convolutional-neural-networks-on-graphs-with-fast-localized-spectral-filtering.pdf) (ChebyNet)
* [graph convolutional networks](https://arxiv.org/pdf/1609.02907.pdf) (GCN)
* [message passing neural networks](https://arxiv.org/pdf/1704.01212.pdf) (MPNN)
* [graph sample and aggregate](https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) (GraphSAGE)
* [graph partition neural networks](https://arxiv.org/pdf/1803.06272.pdf) (GPNN)
* [graph attention networks](https://arxiv.org/pdf/1710.10903.pdf) (GAT)


## Setup
To set up experiments, we need to build our customized operators by running the following scripts:
To set up experiments, we need to download the [preprocessed QM8 data](http://www.cs.toronto.edu/~rjliao/data/qm8.zip) and build our customized operators by running the following scripts:

```
./setup.sh
```
```

**Note**:
We also provide the script ```dataset/get_qm8_data.py``` to preprocess the [raw QM8](http://quantum-machine.org/datasets/) data which requires the installation of [DeepChem](https://github.com/deepchem/deepchem).
It will produce a different train/dev/test split than what we used in the paper due to the randomness of DeepChem.
Therefore, we suggest using our preprocessed data for a fair comparison.


## Dependencies
Python 3, PyTorch(0.4.0)
Python 3, PyTorch(1.0), scipy, sklearn


## Run Demos
* To run experiments ```X``` where ```X``` is one of {```hopfield```, ```cora```, ```pubmed```, ```hypergrad```}:

### Train
* To run the training of experiment ```X``` where ```X``` is one of {```qm8_lanczos_net```, ```qm8_ada_lanczos_net```, ```qm8_cheby_net```, ...}:

```python run_exp.py -c config/X.yaml```


**Notes**:
**Note**:

* Please check the folder ```config``` for a full list of configuration yaml files.
* Most hyperparameters in the configuration yaml file are self-explanatory.
* To switch between BPTT, TBPTT and RBP variants, you need to specify ```grad_method``` in the config file.
* Conjugate gradient based RBP requires support of forward mode auto-differentiation which we only provided for the experiments of Hopfield networks and graph neural networks (GNNs). You can check the comments in ```model/rbp.py``` for more details.

### Test

* After training, you can specify the ```test_model``` field of the configuration yaml file with the path of your best model snapshot, e.g.,

```test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-11-55-54_25460/model_snapshot_best.pth```

* To run the test of experiments ```X```:

```python run_exp.py -c config/X.yaml -t```


## Cite
Please cite our paper if you use this code in your research work.
Expand Down
19 changes: 10 additions & 9 deletions config/qm8_ada_lanczos_net.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ runner: QM8Runner
use_gpu: true
gpus: [0]
seed: 1234
# seed: 5678
# seed: 9012
dataset:
loader_name: QM8Data
name: chemistry
Expand All @@ -14,15 +16,13 @@ dataset:
num_bond_type: 6
model:
name: AdaLanczosNet
short_diffusion_dist: [3, 5, 7]
long_diffusion_dist: [10, 20, 30]
short_diffusion_dist: [1, 2, 3]
long_diffusion_dist: [5, 7, 10, 20, 30]
num_eig_vec: 20
use_reorthogonalization: false
use_power_iteration_cap: true
# spectral_filter_kind: None
use_power_iteration_cap: false
spectral_filter_kind: MLP
# input_dim: 64
input_dim: 70
input_dim: 64
hidden_dim: [128, 128, 128, 128, 128, 128, 128]
output_dim: 16
num_layer: 7
Expand All @@ -44,8 +44,9 @@ train:
shuffle: true
is_resume: false
resume_model: None
test:
test:
batch_size: 64
num_workers: 0
test_model: exp/qm8_ada_lanczos_net/LanczosNetChem_chemistry_2018-Sep-17-23-29-48_32182/model_snapshot_best.pth

test_model: exp/qm8_ada_lanczos_net/AdaLanczosNet_chemistry_2018-Dec-28-21-46-34_11666/model_snapshot_best.pth
# test_model: exp/qm8_ada_lanczos_net/AdaLanczosNet_chemistry_2018-Dec-30-11-36-25_7857/model_snapshot_best.pth
# test_model: exp/qm8_ada_lanczos_net/AdaLanczosNet_chemistry_2018-Dec-30-11-36-48_8366/model_snapshot_best.pth
52 changes: 52 additions & 0 deletions config/qm8_gpnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
---
exp_name: qm8_gpnn
exp_dir: exp/qm8_gpnn
runner: QM8Runner
use_gpu: true
gpus: [0]
# seed: 1234
# seed: 5678
seed: 9012
dataset:
loader_name: QM8Data
name: chemistry
data_path: data/QM8/preprocess
meta_data_path: data/QM8/QM8_meta.p
num_atom: 70
num_bond_type: 6
model:
name: GPNN
num_partition: 3
num_prop: 10
num_prop_cluster: 1
num_prop_cut: 1
input_dim: 64
hidden_dim: 128
update_func: GRU
output_dim: 16
msg_func: MLP
aggregate_type: avg
num_layer: 1
loss: MSE
train:
optimizer: Adam
lr_decay: 0.1
lr_decay_steps: [10000]
num_workers: 8
max_epoch: 200
batch_size: 64
display_iter: 100
snapshot_epoch: 10000
valid_epoch: 1
lr: 1.0e-4
wd: 0.0e-4
momentum: 0.9
shuffle: true
is_resume: false
resume_model: None
test:
batch_size: 64
num_workers: 0
# test_model: exp/qm8_gpnn/GPNN_chemistry_2019-Jan-02-20-37-00_18321/model_snapshot_best.pth
# test_model: exp/qm8_gpnn/GPNN_chemistry_2019-Jan-02-23-18-03_18654/model_snapshot_best.pth
test_model: exp/qm8_gpnn/GPNN_chemistry_2019-Jan-02-23-18-20_19110/model_snapshot_best.pth
17 changes: 9 additions & 8 deletions config/qm8_lanczos_net.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ exp_dir: exp/qm8_lanczos_net
runner: QM8Runner
use_gpu: true
gpus: [0]
# seed: 1234
seed: 1234
# seed: 5678
seed: 9012
# seed: 9012
dataset:
loader_name: QM8Data
name: chemistry
Expand All @@ -16,11 +16,12 @@ dataset:
num_bond_type: 6
model:
name: LanczosNet
short_diffusion_dist: []
long_diffusion_dist: [1, 2, 3, 5, 7, 10, 20, 30]
# short_diffusion_dist: []
# long_diffusion_dist: [1, 2, 3, 5, 7, 10, 20, 30]
short_diffusion_dist: [3, 5, 7]
long_diffusion_dist: [10, 20, 30]
num_eig_vec: 20
spectral_filter_kind: MLP
# spectral_filter_kind: None
spectral_filter_kind: MLP
input_dim: 64
hidden_dim: [128, 128, 128, 128, 128, 128, 128]
output_dim: 16
Expand All @@ -46,6 +47,6 @@ train:
test:
batch_size: 64
num_workers: 0
# test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-11-55-54_25460/model_snapshot_best.pth
test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-11-55-54_25460/model_snapshot_best.pth
# test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-14-04-53_18123/model_snapshot_best.pth
test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-15-53-14_20776/model_snapshot_best.pth
# test_model: exp/qm8_lanczos_net/LanczosNetFixBasisChem_chemistry_2018-Oct-02-15-53-14_20776/model_snapshot_best.pth
12 changes: 12 additions & 0 deletions data/QM8/readme.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
QM8 Dataset
Abstract
Due to its favorable computational efficiency, time-dependent (TD) density functional theory(DFT) enables the prediction of electronic spectra in a high-throughput manner across chemical space. Its predictions, however, can be quite inaccurate. We resolve this issue with machine learning models trained on deviations of reference second-order approximate coupled-cluster (CC2) singles and doubles spectra from TDDFT counterparts, or even from DFT gap. We applied this approach to low-lying singlet-singlet vertical electronic spectra of over 20 000 synthetically feasible small organic molecules with up to eight CONF atoms. The prediction errors decay monotonously as a function of training set size. For a training set of 10 000 molecules, CC2 excitation energies can be reproduced to within ±0.1 eV for the remaining molecules. Analysis of our spectral database via chromophore counting suggests that even higher accuracies can be achieved. Based on the evidence collected, we discuss open challenges associated with data-driven modeling of high-lying spectra and transition intensities.

Download
ftp://ftp.aip.org/epaps/journ_chem_phys/E-JCPSA6-143-043532/gdb8_22k_elec_spec.txt

How to cite
When using this dataset, please make sure to cite the following two papers:

L. Ruddigkeit, R. van Deursen, L. C. Blum, J.-L. Reymond, Enumeration of 166 billion organic small molecules in the chemical universe database GDB-17, J. Chem. Inf. Model. 52, 2864–2875, 2012.
R. Ramakrishnan, M. Hartmann, E. Tapavicza, O. A. von Lilienfeld, Electronic Spectra from TDDFT and Machine Learning in Chemical Space, J. Chem. Phys. 143 084111, 2015.
76 changes: 68 additions & 8 deletions dataset/qm8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import pickle
import numpy as np
from utils.spectral_graph_partition import *

__all__ = ['QM8Data']

Expand All @@ -22,7 +23,7 @@ def __init__(self, config, split='train'):
if self.use_eigs:
self.num_eigs = config.model.num_eig_vec

if self.model_name == 'GraphSAGEChem':
if self.model_name == 'GraphSAGE':
self.num_sample_neighbors = config.model.num_sample_neighbors

self.train_data_files = glob.glob(
Expand Down Expand Up @@ -54,6 +55,10 @@ def __len__(self):
return self.num_test

def collate_fn(self, batch):
"""
Collate function for mini-batch
N.B.: we pad all samples to the maximum of the mini-batch
"""
assert isinstance(batch, list)

data = {}
Expand Down Expand Up @@ -84,7 +89,55 @@ def collate_fn(self, batch):
data['label'] = torch.cat(
[torch.from_numpy(bb['label']) for bb in batch], dim=0).float()

if self.model_name == 'GraphSAGEChem':
if self.model_name == 'GPNN':
#########################################################################
# GPNN
# N.B.: one can perform graph partition offline to speed up
#########################################################################
# graph Laplacian of multi-graph: shape (B, N, N, E)
L_multi = np.stack(
[
np.pad(
bb['L_multi'], ((0, pad_node_size[ii]),
(0, pad_node_size[ii]), (0, 0)),
'constant',
constant_values=0.0) for ii, bb in enumerate(batch)
],
axis=0)

# graph Laplacian of simple graph: shape (B, N, N, 1)
L_simple = np.stack(
[
np.expand_dims(
np.pad(
bb['L_simple_4'], (0, pad_node_size[ii]),
'constant',
constant_values=0.0),
axis=3) for ii, bb in enumerate(batch)
],
axis=0)

L = np.concatenate([L_simple, L_multi], axis=3)
data['L'] = torch.from_numpy(L).float()

# graph partition
L_cluster, L_cut = [], []

for ii in range(batch_size):
node_label = spectral_clustering(L_simple[ii, :, :, 0], self.config.model.num_partition)

# Laplacian of clusters and cut
L_cluster_tmp, L_cut_tmp = get_L_cluster_cut(L_simple[ii, :, :, 0], node_label)

L_cluster += [L_cluster_tmp]
L_cut += [L_cut_tmp]

data['L_cluster'] = torch.from_numpy(np.stack(L_cluster, axis=0)).float()
data['L_cut'] = torch.from_numpy(np.stack(L_cut, axis=0)).float()
elif self.model_name == 'GraphSAGE':
#########################################################################
# GraphSAGE
#########################################################################
# N.B.: adjacency mat of GraphSAGE is asymmetric
nonempty_mask = np.zeros((batch_size, batch_node_size, 1))
nn_idx = np.zeros((batch_size, batch_node_size, self.num_sample_neighbors,
Expand All @@ -111,8 +164,11 @@ def collate_fn(self, batch):

data['nn_idx'] = torch.from_numpy(nn_idx).long()
data['nonempty_mask'] = torch.from_numpy(nonempty_mask).float()
elif self.model_name == 'GATChem':
# graph Laplacian of multi-graph: shape (B, N, N, C)
elif self.model_name == 'GAT':
#########################################################################
# GAT
#########################################################################
# graph Laplacian of multi-graph: shape (B, N, N, E)
L_multi = np.stack(
[
np.pad(
Expand All @@ -123,6 +179,7 @@ def collate_fn(self, batch):
],
axis=0)

# graph Laplacian of simple graph: shape (B, N, N, 1)
L_simple = np.stack(
[
np.expand_dims(
Expand Down Expand Up @@ -161,7 +218,10 @@ def adj_to_bias(adj, sizes, nhood=1):

data['L'] = torch.from_numpy(np.stack(L_new, axis=0)).float()
else:
# graph Laplacian of multi-graph: shape (B, N, N, C)
#########################################################################
# All other models
#########################################################################
# graph Laplacian of multi-graph: shape (B, N, N, E)
L_multi = torch.stack([
torch.from_numpy(
np.pad(
Expand All @@ -173,12 +233,12 @@ def adj_to_bias(adj, sizes, nhood=1):

# graph Laplacian of simple graph: shape (B, N, N, 1)
L_simple_key = 'L_simple_4'
if self.model_name == 'DCNNChem':
if self.model_name == 'DCNN':
L_simple_key = 'L_simple_7'
elif self.model_name in ['ChebyNetChem']:
elif self.model_name in ['ChebyNet']:
L_simple_key = 'L_simple_6'

if self.model_name == 'ChebyNetChem':
if self.model_name == 'ChebyNet':
L_simple = torch.stack([
torch.from_numpy(
np.expand_dims(
Expand Down
1 change: 1 addition & 0 deletions model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from model.set2set import *
from model.ggnn import *
from model.gpnn import *
from model.gcn import *
from model.gat import *
from model.dcnn import *
Expand Down
Loading

0 comments on commit fdfc8fe

Please sign in to comment.