forked from graphdeeplearning/benchmarking-gnns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
42 lines (36 loc) · 1.35 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
Load the dataset selected by name; LoadData is called from the main.py file.
"""
from data.superpixels import SuperPixDataset
from data.molecules import MoleculeDataset
from data.TUs import TUsDataset
from data.SBMs import SBMsDataset
from data.TSP import TSPDataset
from data.CitationGraphs import CitationGraphsDataset
def LoadData(DATASET_NAME):
    """
    Build the dataset object that corresponds to DATASET_NAME.

    The name is matched (case-sensitively) against the known dataset
    families and then passed unchanged to the matching dataset class.

    returns:
    ; dataset object of the matching family, or None when the name
      does not belong to any known family
    """
    superpixel_names = ('MNIST', 'CIFAR10')
    tu_names = ('COLLAB', 'ENZYMES', 'DD', 'PROTEINS_full')
    sbm_names = ('SBM_CLUSTER', 'SBM_PATTERN')
    citation_names = ('CORA', 'CITESEER', 'PUBMED')

    # Select the dataset class first, then construct it once at the end.
    # Each class is only referenced inside the branch that needs it.
    if DATASET_NAME in superpixel_names:
        # MNIST or CIFAR superpixels
        dataset_cls = SuperPixDataset
    elif DATASET_NAME == 'ZINC':
        # (ZINC) molecule dataset
        dataset_cls = MoleculeDataset
    elif DATASET_NAME in tu_names:
        # TU datasets
        dataset_cls = TUsDataset
    elif DATASET_NAME in sbm_names:
        # SBM datasets
        dataset_cls = SBMsDataset
    elif DATASET_NAME == 'TSP':
        # TSP dataset
        dataset_cls = TSPDataset
    elif DATASET_NAME in citation_names:
        # citation graph datasets
        dataset_cls = CitationGraphsDataset
    else:
        # Unknown name: nothing to load.
        return None

    return dataset_cls(DATASET_NAME)