TensorFlow or PyTorch? Both!
- GraphGallery
- What's important
- Installation
- Implementations
- Quick Start
- How to add your datasets
- How to define your models
- More Examples
- Road Map
- Acknowledgement
GraphGallery is a gallery for benchmarking Graph Neural Networks (GNNs) with TensorFlow 2.x and PyTorch backends. GraphGallery 0.6.x is a complete rewrite of previous versions, and some things have changed.
NEWS:
- PyG and DGL backends are now available in GraphGallery
- GraphGallery now supports multiple graphs for different tasks
Differences between GraphGallery and PyTorch Geometric (PyG), Deep Graph Library (DGL), etc.:
- PyG and DGL are like TensorFlow, while GraphGallery is more like Keras
- GraphGallery is a plug-and-play, user-friendly toolbox
- GraphGallery offers high scalability for researchers
- Build from source (latest version)
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
python setup.py install
- Or using pip (stable version)
pip install -U graphgallery
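To check which version ended up installed (for example after switching between the source build and the pip release), a minimal standard-library sketch using `importlib.metadata` (Python 3.8+) could look like the following. `installed_version` is a hypothetical helper, not part of GraphGallery.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints a version string, or None when graphgallery is not installed.
    print("graphgallery:", installed_version("graphgallery"))
```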
Please refer to the examples.
For more details, please refer to GraphData.
Fixed datasets
from graphgallery.datasets import Planetoid
# set `verbose=False` to avoid additional outputs
data = Planetoid('cora', verbose=False)
graph = data.graph
# here `splits` is a dict-like instance
splits = data.split_nodes()
# splits.train_nodes: training indices: 1D Numpy array
# splits.val_nodes: validation indices: 1D Numpy array
# splits.nodes: testing indices: 1D Numpy array
>>> graph
Graph(adj_matrix(2708, 2708),
node_attr(2708, 1433),
node_label(2708,),
metadata=None, multiple=False)
Currently available datasets:
>>> data.available_datasets()
('citeseer', 'cora', 'pubmed')
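The comment above describes `splits` as a dict-like instance whose entries can also be read as attributes (`splits.train_nodes`). A minimal sketch of such a container, assuming attribute access simply forwards to dictionary keys; `AttrDict` is a hypothetical name, not GraphGallery's own class.

```python
class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes."""

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError:
            # AttributeError keeps hasattr() and getattr(obj, name, default) working.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value
```

For example, `AttrDict(train_nodes=[0, 1]).train_nodes` and `AttrDict(train_nodes=[0, 1])["train_nodes"]` return the same list.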
More scalable datasets (stored as .npz files)
from graphgallery.datasets import NPZDataset
# set `verbose=False` to avoid additional outputs
data = NPZDataset('cora', verbose=False, standardize=False)
graph = data.graph
# here `splits` is a dict-like instance
splits = data.split_nodes(random_state=42)
>>> graph
Graph(adj_matrix(2708, 2708),
node_attr(2708, 1433),
node_label(2708,),
metadata=None, multiple=False)
Currently available datasets:
>>> data.available_datasets()
('citeseer','citeseer_full','cora','cora_ml','cora_full',
'amazon_cs','amazon_photo','coauthor_cs','coauthor_phy',
'polblogs', 'pubmed', 'flickr','blogcatalog','dblp')
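Passing `random_state=42` to `split_nodes` fixes the random seed so the same split is produced on every run. The sketch below shows one way a seeded train/validation/test split of node indices can work with NumPy. `random_split` and its `train_size`/`val_size` parameters are hypothetical and do not describe GraphGallery's actual splitting strategy.

```python
import numpy as np


def random_split(num_nodes, train_size, val_size, random_state=None):
    """Split node indices 0..num_nodes-1 into disjoint train/val/test arrays.

    train_size and val_size are absolute counts; all remaining nodes form the
    test set. The same random_state always yields the same split.
    """
    if train_size + val_size > num_nodes:
        raise ValueError("train_size + val_size exceeds num_nodes")
    rng = np.random.RandomState(random_state)
    perm = rng.permutation(num_nodes)  # shuffled copy of 0..num_nodes-1
    train = perm[:train_size]
    val = perm[train_size:train_size + val_size]
    test = perm[train_size + val_size:]
    return {"train": np.sort(train), "val": np.sort(val), "test": np.sort(test)}
```

Calling `random_split(2708, 100, 200, random_state=42)` twice returns identical arrays, while `random_state=None` draws a fresh split each time.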
- Strided (dense) Tensor
>>> from graphgallery import backend
>>> backend()
TensorFlow 2.1.2 Backend
>>> from graphgallery import functional as gf
>>> arr = [1, 2, 3]
>>> gf.astensor(arr)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
- Sparse Tensor
>>> import scipy.sparse as sp
>>> sp_matrix = sp.eye(3)
>>> gf.astensor(sp_matrix)
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f1bbc205dd8>
- It also works for PyTorch:
>>> from graphgallery import set_backend
>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend
>>> gf.astensor(arr)
tensor([1, 2, 3])
>>> gf.astensor(sp_matrix)
tensor(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1., 1., 1.]),
size=(3, 3), nnz=3, layout=torch.sparse_coo)
- Convert back to a NumPy array or SciPy sparse matrix
>>> tensor = gf.astensor(arr)
>>> gf.tensoras(tensor)
array([1, 2, 3])
>>> sp_tensor = gf.astensor(sp_matrix)
>>> gf.tensoras(sp_tensor)
<3x3 sparse matrix of type '<class 'numpy.float32'>'
with 3 stored elements in Compressed Sparse Row format>
- Or even convert a tensor from one backend to another
>>> tensor = gf.astensor(arr, backend="tensorflow") # or "tf" in short
>>> tensor
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>
>>> gf.tensor2tensor(tensor)
tensor([1, 2, 3])
>>> sp_tensor = gf.astensor(sp_matrix, backend="tensorflow") # set backend="tensorflow" to convert to tensorflow tensor
>>> sp_tensor
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7efb6836a898>
>>> gf.tensor2tensor(sp_tensor)
tensor(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1., 1., 1.]),
size=(3, 3), nnz=3, layout=torch.sparse_coo)
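The sparse conversions above all revolve around the same three pieces of a COO (coordinate) matrix: a 2 x nnz array of (row, column) indices, a vector of values, and the dense shape. A minimal NumPy/SciPy sketch of extracting those pieces and rebuilding the CSR matrix shown earlier; this is an assumption about the data a converter like `gf.astensor` or `gf.tensoras` needs, not their actual code, and `to_coo_parts`/`coo_parts_to_csr` are hypothetical helpers.

```python
import numpy as np
import scipy.sparse as sp


def to_coo_parts(matrix):
    """Return (indices, values, shape) for a SciPy sparse matrix.

    indices has shape (2, nnz): row indices first, column indices second. This
    is the layout torch.sparse_coo_tensor expects; tf.SparseTensor expects the
    transpose, shape (nnz, 2).
    """
    coo = sp.coo_matrix(matrix)
    indices = np.vstack([coo.row, coo.col]).astype(np.int64)
    return indices, coo.data, coo.shape


def coo_parts_to_csr(indices, values, shape, dtype=np.float32):
    """Rebuild a Compressed Sparse Row matrix from (indices, values, shape)."""
    row, col = np.asarray(indices)
    return sp.csr_matrix((np.asarray(values, dtype=dtype), (row, col)), shape=shape)


if __name__ == "__main__":
    indices, values, shape = to_coo_parts(sp.eye(3))
    restored = coo_parts_to_csr(indices, values, shape)
    print(restored.format, restored.nnz)  # csr 3
```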
from graphgallery.gallery import GCN
model = GCN(graph, attr_transform="normalize_attr", device="CPU", seed=123)
# build your GCN model with default hyper-parameters
model.build()
# train your model. here splits.train_nodes and splits.val_nodes are numpy arrays
# verbose takes 0, 1, 2, 3, 4
history = model.train(splits.train_nodes, splits.val_nodes, verbose=1, epochs=100)
# test your model
# verbose takes 0, 1, 2
results = model.test(splits.nodes, verbose=1)
print(f'Test loss {results.loss:.5}, Test accuracy {results.accuracy:.2%}')
On the Cora dataset:
Training...
100/100 [==============================] - 1s 14ms/step - loss: 1.0161 - acc: 0.9500 - val_loss: 1.4101 - val_acc: 0.7740 - time: 1.4180
Testing...
1/1 [==============================] - 0s 62ms/step - loss: 1.4123 - acc: 0.8120 - time: 0.0620
Test loss 1.4123, Test accuracy 81.20%
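The README does not spell out what `attr_transform="normalize_attr"` does. A common choice in GCN pipelines is row normalization, which divides each node's attribute vector by its sum; the sketch below assumes that interpretation and uses a hypothetical `row_normalize` helper, so treat it as an illustration rather than GraphGallery's implementation.

```python
import numpy as np
import scipy.sparse as sp


def row_normalize(x):
    """Scale each row of x so it sums to 1; all-zero rows stay zero.

    Accepts a dense NumPy array (returns a float64 array) or a SciPy sparse
    matrix (returns a CSR matrix).
    """
    if sp.issparse(x):
        row_sum = np.asarray(x.sum(axis=1)).ravel()
        inv = np.zeros(row_sum.shape, dtype=np.float64)
        nonzero = row_sum != 0
        inv[nonzero] = 1.0 / row_sum[nonzero]
        return (sp.diags(inv) @ x).tocsr()  # left-multiplying by diag(1/sum) scales rows
    x = np.asarray(x, dtype=np.float64)
    row_sum = x.sum(axis=1, keepdims=True)
    return np.divide(x, row_sum, out=np.zeros_like(x), where=row_sum != 0)
```

The zero-row guard matters for graphs with isolated or featureless nodes, where a plain division would produce NaNs.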
- Build your model: you can use the following statements to build your model
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')
# two hidden layers with 32 and 64 hidden units, both with ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')
# two hidden layers with 32 and 64 hidden units, with ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
- Train your model
# train with validation
>>> history = model.train(splits.train_nodes, splits.val_nodes, verbose=1, epochs=100)
# train without validation
>>> history = model.train(splits.train_nodes, verbose=1, epochs=100)
Here history is a TensorFlow History instance.
- Test your model
>>> results = model.test(splits.nodes, verbose=1)
Testing...
1/1 [==============================] - 0s 62ms/step - loss: 1.4123 - acc: 0.8120 - time: 0.0620
>>> print(f'Test loss {results.loss:.5}, Test accuracy {results.accuracy:.2%}')
Test loss 1.4123, Test accuracy 81.20%
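The `hiddens` and `activations` arguments used with `model.build` describe the hidden layers of the network. As an illustration of what such layers compute, here is a minimal NumPy sketch of the standard GCN propagation rule, where each layer applies activation(Â H W) with Â = D^-1/2 (A + I) D^-1/2. The names `normalize_adj` and `gcn_forward`, the random untrained weights, and the absence of dropout and bias terms are simplifying assumptions, not GraphGallery's actual layers.

```python
import numpy as np


def normalize_adj(adj):
    """Return D^-1/2 (A + I) D^-1/2 for a dense adjacency matrix A."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))  # degrees are >= 1 due to self-loops
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def relu(x):
    return np.maximum(x, 0.0)


def elu(x, alpha=1.0):
    # exp is evaluated on min(x, 0) only, which avoids overflow for large positive x
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))


ACTIVATIONS = {"relu": relu, "elu": elu}


def gcn_forward(adj, x, num_classes, hiddens=32, activations="relu", seed=123):
    """Run one GCN layer per entry of `hiddens`, then a linear GCN output layer.

    `hiddens` is an int or a list of ints; `activations` is one name applied to
    every hidden layer or a list with one name per hidden layer. Weights are
    random (untrained), so the output is only useful for checking shapes.
    """
    hiddens = [hiddens] if isinstance(hiddens, int) else list(hiddens)
    if isinstance(activations, str):
        activations = [activations] * len(hiddens)
    if len(activations) != len(hiddens):
        raise ValueError("need one activation per hidden layer")
    rng = np.random.default_rng(seed)
    a_norm = normalize_adj(np.asarray(adj, dtype=np.float64))
    h = np.asarray(x, dtype=np.float64)
    for units, name in zip(hiddens, activations):
        w = rng.normal(scale=0.1, size=(h.shape[1], units))
        h = ACTIVATIONS[name](a_norm @ h @ w)
    w_out = rng.normal(scale=0.1, size=(h.shape[1], num_classes))
    return a_norm @ h @ w_out  # class scores (logits), one row per node
```

For instance, `gcn_forward(adj, x, num_classes, hiddens=[32, 64], activations=['relu', 'elu'])` mirrors the last `build` call and returns one row of class scores per node.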
NOTE: you need to install the SciencePlots package to use the 'science' plotting style below.
import matplotlib.pyplot as plt
# plot the training curves stored in the History object returned by model.train
with plt.style.context(['science', 'no-latex']):
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    axes[0].plot(history.history['accuracy'], label='Training accuracy', linewidth=3)
    axes[0].plot(history.history['val_accuracy'], label='Validation accuracy', linewidth=3)
    axes[0].legend(fontsize=20)
    axes[0].set_title('Accuracy', fontsize=20)
    axes[0].set_xlabel('Epochs', fontsize=20)
    axes[0].set_ylabel('Accuracy', fontsize=20)
    axes[1].plot(history.history['loss'], label='Training loss', linewidth=3)
    axes[1].plot(history.history['val_loss'], label='Validation loss', linewidth=3)
    axes[1].legend(fontsize=20)
    axes[1].set_title('Loss', fontsize=20)
    axes[1].set_xlabel('Epochs', fontsize=20)
    axes[1].set_ylabel('Loss', fontsize=20)
    plt.autoscale(tight=True)
    plt.show()
>>> import graphgallery
>>> graphgallery.backend()
TensorFlow 2.1.0 Backend
>>> graphgallery.set_backend("pytorch")
PyTorch 1.6.0+cu101 Backend
# The following code is the same as with the TensorFlow backend
>>> from graphgallery.gallery import GCN
>>> model = GCN(graph, attr_transform="normalize_attr", device="GPU", seed=123)
>>> model.build()
>>> history = model.train(splits.train_nodes, splits.val_nodes, verbose=1, epochs=100)
Training...
100/100 [==============================] - 0s 5ms/step - loss: 0.6813 - acc: 0.9214 - val_loss: 1.0506 - val_acc: 0.7820 - time: 0.4734
>>> results = model.test(splits.nodes, verbose=1)
Testing...
1/1 [==============================] - 0s 1ms/step - loss: 1.0131 - acc: 0.8220 - time: 0.0013
>>> print(f'Test loss {results.loss:.5}, Test accuracy {results.accuracy:.2%}')
Test loss 1.0131, Test accuracy 82.20%
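Whichever backend is active, the reported test accuracy is the fraction of test nodes whose predicted class matches the true label. A minimal NumPy sketch of that metric restricted to a set of node indices such as the test split; `masked_accuracy` is a hypothetical helper and assumes predictions are taken as the arg-max over class scores.

```python
import numpy as np


def masked_accuracy(logits, labels, nodes):
    """Accuracy of argmax(logits) against labels, computed only on `nodes`."""
    nodes = np.asarray(nodes)
    pred = np.argmax(np.asarray(logits)[nodes], axis=1)  # predicted class per node
    return float(np.mean(pred == np.asarray(labels)[nodes]))
```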
This is motivated by gnn-benchmark
from graphgallery.data import Graph
# Load the adjacency matrix A, attribute matrix X and labels vector y
# A - scipy.sparse.csr_matrix of shape [n_nodes, n_nodes]
# X - scipy.sparse.csr_matrix or np.ndarray of shape [n_nodes, n_atts]
# y - np.ndarray of shape [n_nodes]
mydataset = Graph(adj_matrix=A, node_attr=X, node_label=y)
# save dataset
mydataset.to_npz('path/to/mydataset.npz')
# load dataset
mydataset = Graph.from_npz('path/to/mydataset.npz')
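This dataset format is motivated by gnn-benchmark, and `.npz` files in that style typically keep each sparse matrix as the separate arrays of its CSR representation. The sketch below illustrates that storage idea with NumPy and SciPy only. The key names (`adj_data`, `attr_indptr`, `labels` and so on) and the helpers `save_graph_npz`/`load_graph_npz` are assumptions for illustration, not necessarily the layout written by `Graph.to_npz`.

```python
import io

import numpy as np
import scipy.sparse as sp


def save_graph_npz(file, adj, attr, labels):
    """Write adjacency matrix, attribute matrix and labels to an .npz file.

    `file` may be a path or a binary file-like object. Dense attribute
    matrices are stored in CSR form as well.
    """
    adj = sp.csr_matrix(adj)
    attr = sp.csr_matrix(attr)
    np.savez(
        file,
        adj_data=adj.data, adj_indices=adj.indices,
        adj_indptr=adj.indptr, adj_shape=adj.shape,
        attr_data=attr.data, attr_indices=attr.indices,
        attr_indptr=attr.indptr, attr_shape=attr.shape,
        labels=np.asarray(labels),
    )


def load_graph_npz(file):
    """Inverse of save_graph_npz: return (adj, attr, labels), both matrices as CSR."""
    with np.load(file) as f:
        adj = sp.csr_matrix((f["adj_data"], f["adj_indices"], f["adj_indptr"]),
                            shape=tuple(f["adj_shape"]))
        attr = sp.csr_matrix((f["attr_data"], f["attr_indices"], f["attr_indptr"]),
                             shape=tuple(f["attr_shape"]))
        labels = f["labels"]
    return adj, attr, labels


if __name__ == "__main__":
    buf = io.BytesIO()  # in-memory file; a path such as "mydataset.npz" also works
    save_graph_npz(buf, sp.eye(3), np.ones((3, 2)), [0, 1, 0])
    buf.seek(0)
    adj, attr, labels = load_graph_npz(buf)
    print(adj.shape, attr.shape, labels.tolist())
```

Storing only the non-zero entries keeps files small for large, sparse graphs, which is the main reason to prefer this over saving dense arrays.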
You can follow the code in the graphgallery.gallery folder and write your models based on:
- TensorFlow
- PyTorch
- PyTorch Geometric (PyG)
- Deep Graph Library (DGL)
NOTE: PyG and DGL backends are now supported in GraphGallery!
>>> import graphgallery
>>> graphgallery.set_backend("pyg")
PyTorch Geometric 1.6.1 (PyTorch 1.6.0+cu101) Backend
# The following code is the same as with the TensorFlow or PyTorch backend
>>> from graphgallery.gallery import GCN
>>> model = GCN(graph, attr_transform="normalize_attr", device="GPU", seed=123)
>>> model.build()
>>> history = model.train(splits.train_nodes, splits.val_nodes, verbose=1, epochs=100)
Training...
100/100 [==============================] - 0s 3ms/step - loss: 0.5325 - acc: 0.9643 - val_loss: 1.0034 - val_acc: 0.7980 - time: 0.3101
>>> results = model.test(splits.nodes, verbose=1)
Testing...
1/1 [==============================] - 0s 834us/step - loss: 0.9733 - acc: 0.8130 - time: 8.2737e-04
>>> print(f'Test loss {results.loss:.5}, Test accuracy {results.accuracy:.2%}')
Test loss 0.97332, Test accuracy 81.30%
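The backend examples in this README accept several spellings for the same backend (`torch`, `pytorch` or `th`; `tensorflow` or `tf`; `pyg`). A minimal sketch of normalizing such aliases before dispatching; the alias table mirrors only the names shown in this README, and `normalize_backend` is hypothetical rather than GraphGallery's internal function.

```python
# Backend aliases as they appear in this README's examples; the table is illustrative.
_BACKEND_ALIASES = {
    "tensorflow": "tensorflow",
    "tf": "tensorflow",
    "torch": "torch",
    "pytorch": "torch",
    "th": "torch",
    "pyg": "pyg",
}


def normalize_backend(name: str) -> str:
    """Map a user-supplied backend name to a canonical name (case-insensitive)."""
    key = name.strip().lower()
    try:
        return _BACKEND_ALIASES[key]
    except KeyError:
        choices = ", ".join(sorted(_BACKEND_ALIASES))
        raise ValueError(f"unknown backend {name!r}; expected one of: {choices}") from None
```

Resolving aliases in one place means every later lookup (for example, choosing a tensor conversion routine) only has to handle the canonical names.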
Please refer to the examples directory.
- Add PyTorch models support
- Add other frameworks (PyG and DGL) support
- Add more GNN models (TF and Torch backend)
- Support for more tasks, e.g., graph classification and link prediction
- Support for more types of graphs, e.g., Heterogeneous graph
- Add Docstrings and Documentation (Building)
This project is inspired by PyTorch Geometric, TensorFlow Geometric, StellarGraph, DGL and others, as well as the authors' original implementations. Thanks for their excellent work!