Skip to content

Commit

Permalink
src/ gitignore, yml, readme
Browse files Browse the repository at this point in the history
  • Loading branch information
Armagaan committed May 28, 2024
1 parent 2e7bdb0 commit c7eafc3
Show file tree
Hide file tree
Showing 25 changed files with 2,772 additions and 0 deletions.
21 changes: 21 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
__pycache__/
src/HallOfFame/
plots/
backup/

# Ignore subdirectories other than raw/
data/BAMultiShapesDataset/*
!data/BAMultiShapesDataset/raw/
!data/BAMultiShapesDataset/raw/*

data/MUTAG/*
!data/MUTAG/raw/
!data/MUTAG/raw/*

data/Mutagenicity/*
!data/Mutagenicity/raw/
!data/Mutagenicity/raw/*

data/NCI1/*
!data/NCI1/raw/
!data/NCI1/raw/*
208 changes: 208 additions & 0 deletions GraphTrail.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
name: GraphTrail
channels:
- pyg
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10
- aiohttp=3.9.5=py310h5eee18b_0
- aiosignal=1.2.0=pyhd3eb1b0_0
- asttokens=2.0.5=pyhd3eb1b0_0
- async-timeout=4.0.3=py310h06a4308_0
- attrs=23.1.0=py310h06a4308_0
- binutils_impl_linux-64=2.38=h2a08ee3_1
- binutils_linux-64=2.38.0=hc2dff05_0
- blas=1.0=mkl
- boost=1.85.0=hb7f781d_1
- brotli=1.0.9=h5eee18b_8
- brotli-bin=1.0.9=h5eee18b_8
- brotli-python=1.0.9=py310h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.3.11=h06a4308_0
- certifi=2024.2.2=pyhd8ed1ab_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cloudpickle=3.0.0=pyhd8ed1ab_0
- comm=0.2.1=py310h06a4308_0
- contourpy=1.2.0=py310hdb19cb5_0
- cpuonly=2.0=0
- cudatoolkit=11.8.0=h4ba93d1_13
- cycler=0.11.0=pyhd3eb1b0_0
- cyrus-sasl=2.1.28=h52b45da_1
- dbus=1.13.18=hb2f20db_0
- debugpy=1.6.7=py310h6a678d5_0
- decorator=5.1.1=pyhd3eb1b0_0
- dill=0.3.8=pyhd8ed1ab_0
- exceptiongroup=1.2.0=py310h06a4308_0
- executing=0.8.3=pyhd3eb1b0_0
- expat=2.6.2=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0
- fontconfig=2.14.1=h4c34cd2_2
- fonttools=4.51.0=py310h5eee18b_0
- freetype=2.12.1=h4a9f257_0
- frozenlist=1.4.0=py310h5eee18b_0
- fsspec=2024.3.1=py310h06a4308_0
- gcc_impl_linux-64=11.2.0=h1234567_1
- gcc_linux-64=11.2.0=h5c386dc_0
- glib=2.78.4=h6a678d5_0
- glib-tools=2.78.4=h6a678d5_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- gxx_impl_linux-64=11.2.0=h1234567_1
- gxx_linux-64=11.2.0=hc2dff05_0
- icu=73.2=h59595ed_0
- idna=3.7=py310h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- ipykernel=6.28.0=py310h06a4308_0
- ipython=8.20.0=py310h06a4308_0
- ipywidgets=8.1.2=py310h06a4308_0
- jedi=0.18.1=py310h06a4308_1
- jinja2=3.1.3=py310h06a4308_0
- joblib=1.4.0=py310h06a4308_0
- jpeg=9e=h5eee18b_1
- jupyter_client=8.6.0=py310h06a4308_0
- jupyter_core=5.5.0=py310h06a4308_0
- jupyterlab_widgets=3.0.10=py310h06a4308_0
- kernel-headers_linux-64=3.10.0=h57e8cba_10
- kiwisolver=1.4.4=py310h6a678d5_0
- krb5=1.20.1=h143b758_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libboost=1.85.0=hba137d9_1
- libboost-devel=1.85.0=h00ab1b0_1
- libboost-headers=1.85.0=ha770c72_1
- libboost-python=1.85.0=py310he6ccd79_1
- libboost-python-devel=1.85.0=py310hb7f781d_1
- libbrotlicommon=1.0.9=h5eee18b_8
- libbrotlidec=1.0.9=h5eee18b_8
- libbrotlienc=1.0.9=h5eee18b_8
- libclang=14.0.6=default_hc6dbbc7_1
- libclang13=14.0.6=default_he11475f_1
- libcups=2.4.2=h2d74bed_1
- libdeflate=1.17=h5eee18b_1
- libedit=3.1.20230828=h5eee18b_0
- libffi=3.4.4=h6a678d5_1
- libgcc-devel_linux-64=11.2.0=h1234567_1
- libgcc-ng=13.2.0=h77fa898_7
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libglib=2.78.4=hdc74915_0
- libgomp=13.2.0=h77fa898_7
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libllvm14=14.0.6=hdb19cb5_3
- libpng=1.6.39=h5eee18b_0
- libpq=12.17=hdbd6064_0
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-devel_linux-64=11.2.0=h1234567_1
- libstdcxx-ng=13.2.0=hc0a3c3a_7
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=h5eee18b_1
- libxml2=2.10.4=hfdd30dd_2
- libzlib=1.2.13=h4ab18f5_6
- llvmlite=0.42.0=py310h6a678d5_0
- lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py310h5eee18b_0
- matplotlib=3.8.4=py310h06a4308_0
- matplotlib-base=3.8.4=py310h1128e8f_0
- matplotlib-inline=0.1.6=py310h06a4308_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py310h5eee18b_1
- mkl_fft=1.3.8=py310h5eee18b_0
- mkl_random=1.2.4=py310hdb19cb5_0
- multidict=6.0.4=py310h5eee18b_0
- multiprocess=0.70.16=py310h2372a71_0
- mysql=5.7.24=h721c034_2
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=py310h06a4308_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.1=py310h06a4308_0
- numba=0.59.1=py310h7dc5dd1_0
- numpy=1.26.4=py310h5f9d8c6_0
- numpy-base=1.26.4=py310hb5e798b_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=3.3.0=h4ab18f5_3
- packaging=24.0=pyhd8ed1ab_0
- pandas=2.2.2=py310hf9f9076_1
- parso=0.8.3=pyhd3eb1b0_0
- pcre2=10.42=hebb0a14_1
- pexpect=4.8.0=pyhd3eb1b0_3
- pillow=10.3.0=py310h5eee18b_0
- pip=24.0=py310h06a4308_0
- platformdirs=3.10.0=py310h06a4308_0
- ply=3.11=py310h06a4308_0
- prompt-toolkit=3.0.43=py310h06a4308_0
- prompt_toolkit=3.0.43=hd3eb1b0_0
- psutil=5.9.0=py310h5eee18b_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pure_eval=0.2.2=pyhd3eb1b0_0
- pybind11-abi=4=hd3eb1b0_1
- pyg=2.5.2=py310_torch_1.13.0_cpu
- pygments=2.15.1=py310h06a4308_1
- pyparsing=3.0.9=py310h06a4308_0
- pyqt=5.15.10=py310h6a678d5_0
- pyqt5-sip=12.13.0=py310h5eee18b_0
- pysocks=1.7.1=py310h06a4308_0
- python=3.10.14=h955ad1f_1
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python-tzdata=2024.1=pyhd8ed1ab_0
- python_abi=3.10=2_cp310
- pytorch=1.13.1=py3.10_cpu_0
- pytorch-mutex=1.0=cpu
- pytz=2024.1=pyhd8ed1ab_0
- pyzmq=25.1.2=py310h6a678d5_0
- qt-main=5.15.2=h53bd1ea_10
- readline=8.2=h5eee18b_0
- requests=2.32.2=py310h06a4308_0
- scikit-learn=1.4.2=py310h1128e8f_1
- scipy=1.13.0=py310h5f9d8c6_0
- seaborn=0.12.2=py310h06a4308_0
- setuptools=69.5.1=py310h06a4308_0
- shap=0.45.1=cuda118py310hd5e5b8b_0
- sip=6.7.12=py310h6a678d5_0
- six=1.16.0=pyh6c4a22f_0
- slicer=0.0.8=pyhd8ed1ab_0
- sqlite=3.45.3=h5eee18b_0
- stack_data=0.2.0=pyhd3eb1b0_0
- sysroot_linux-64=2.17=h57e8cba_10
- tbb=2021.8.0=hdb19cb5_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.14=h39e8969_0
- tomli=2.0.1=py310h06a4308_0
- torchaudio=0.13.1=py310_cpu
- torchvision=0.14.1=py310_cpu
- tornado=6.3.3=py310h5eee18b_0
- tqdm=4.66.4=py310h2f386ee_0
- traitlets=5.7.1=py310h06a4308_0
- typing_extensions=4.11.0=py310h06a4308_0
- tzdata=2024a=h04d1e81_0
- unicodedata2=15.1.0=py310h5eee18b_0
- urllib3=2.2.1=py310h06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.43.0=py310h06a4308_0
- widgetsnbextension=4.0.10=py310h06a4308_0
- xz=5.4.6=h5eee18b_1
- yarl=1.9.3=py310h5eee18b_0
- zeromq=4.3.5=h6a678d5_0
- zlib=1.2.13=h4ab18f5_6
- zstd=1.5.6=ha6fb4c9_0
- pip:
  - click==8.1.7
  - juliacall==0.9.20
  - juliapkg==0.1.13
  - mpmath==1.3.0
  - pysr==0.18.4
  - semver==3.0.2
  - sympy==1.12
92 changes: 92 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,94 @@
# GraphTrail
GraphTrail: Translating GNN Predictions into Human-Interpretable Logical Rules

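NOTE: Unless stated otherwise, all commands should be run from `src/`. The one exception is `conda env create -f GraphTrail.yml`, which must be run from the repository root, where `GraphTrail.yml` is located.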

# Environment
```bash
conda env create -f GraphTrail.yml
```

In case you have some issues with the above command, use the following instead:
```bash
cd src/

conda create -n GraphTrail
conda activate GraphTrail
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cpuonly -c pytorch
conda install pyg -c pyg
conda install -c conda-forge shap
conda install networkx matplotlib seaborn ipykernel ipywidgets
pip install pysr
conda install conda-forge::boost
conda install gxx_linux-64

# Remove stale pygcanl build artifacts (only needed if these files and folders are present).
rm -r pygcanl/build/
rm -r pygcanl.egg-info/
rm pygcanl/*.so
pip install -e pygcanl

conda install -c conda-forge multiprocess
```

# Run the code
```bash
cd src/

# Generate training, validation, and test indices for all datasets.
python gen_indices.py

# Train GNN if not already trained.
python train_gnn.py -h
python train_gnn.py ...

# Identify the unique computation trees and create the concept vectors.
python gen_ctree.py -h
python gen_ctree.py ...

# Compute the Shapley values of the computation trees identified in gen_ctree.py
python gen_shap.py -h
python gen_shap.py ...

# Generate formulae over the ctrees identified by gen_shap.py
# You will see some Julia installation on your first run.
python gen_formulae.py -h
python gen_formulae.py ...
```

## Example
```bash
cd src/

python gen_indices.py

python train_gnn.py --name MUTAG --arch GIN

python gen_ctree.py --name MUTAG --arch GIN

python gen_shap.py --name MUTAG --arch GIN

python gen_formulae.py --name MUTAG --arch GIN
```

# Data
The code will generate some intermediate files and save them under the following directory structure:
```bash
data
├── BAMultiShapesDataset
│   ├── GAT
│   │   ├── add
│   │   │   ├── 0.05
│   │   │   │   ├── 357
│   │   │   │   │   ├── test_indices.pkl
│   │   │   │   │   ├── train_indices.pkl
│   │   │   │   │   └── val_indices.pkl
│   │   │   │   ├── 45
│   │   │   │   │   ├── test_indices.pkl
│   │   │   │   │   ├── train_indices.pkl
│   │   │   │   │   └── val_indices.pkl
│   │   │   │   └── 796
│   │   │   │       ├── test_indices.pkl
│   │   │   │       ├── train_indices.pkl
│   │   │   │       └── val_indices.pkl
```
Binary file added src/data.zip
Binary file not shown.
Loading

0 comments on commit c7eafc3

Please sign in to comment.