-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
2,772 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
__pycache__/ | ||
src/HallOfFame/ | ||
plots/ | ||
backup/ | ||
|
||
# Ignore subdirectories other than raw/ | ||
data/BAMultiShapesDataset/* | ||
!data/BAMultiShapesDataset/raw/ | ||
!data/BAMultiShapesDataset/raw/* | ||
|
||
data/MUTAG/* | ||
!data/MUTAG/raw/ | ||
!data/MUTAG/raw/* | ||
|
||
data/Mutagenicity/* | ||
!data/Mutagenicity/raw/ | ||
!data/Mutagenicity/raw/* | ||
|
||
data/NCI1/* | ||
!data/NCI1/raw/ | ||
!data/NCI1/raw/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
name: GraphTrail | ||
channels: | ||
- pyg | ||
- pytorch | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- _libgcc_mutex=0.1=conda_forge | ||
- _openmp_mutex=4.5=2_gnu | ||
- _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10 | ||
- aiohttp=3.9.5=py310h5eee18b_0 | ||
- aiosignal=1.2.0=pyhd3eb1b0_0 | ||
- asttokens=2.0.5=pyhd3eb1b0_0 | ||
- async-timeout=4.0.3=py310h06a4308_0 | ||
- attrs=23.1.0=py310h06a4308_0 | ||
- binutils_impl_linux-64=2.38=h2a08ee3_1 | ||
- binutils_linux-64=2.38.0=hc2dff05_0 | ||
- blas=1.0=mkl | ||
- boost=1.85.0=hb7f781d_1 | ||
- brotli=1.0.9=h5eee18b_8 | ||
- brotli-bin=1.0.9=h5eee18b_8 | ||
- brotli-python=1.0.9=py310h6a678d5_8 | ||
- bzip2=1.0.8=h5eee18b_6 | ||
- ca-certificates=2024.3.11=h06a4308_0 | ||
- certifi=2024.2.2=pyhd8ed1ab_0 | ||
- charset-normalizer=2.0.4=pyhd3eb1b0_0 | ||
- cloudpickle=3.0.0=pyhd8ed1ab_0 | ||
- comm=0.2.1=py310h06a4308_0 | ||
- contourpy=1.2.0=py310hdb19cb5_0 | ||
- cpuonly=2.0=0 | ||
- cudatoolkit=11.8.0=h4ba93d1_13 | ||
- cycler=0.11.0=pyhd3eb1b0_0 | ||
- cyrus-sasl=2.1.28=h52b45da_1 | ||
- dbus=1.13.18=hb2f20db_0 | ||
- debugpy=1.6.7=py310h6a678d5_0 | ||
- decorator=5.1.1=pyhd3eb1b0_0 | ||
- dill=0.3.8=pyhd8ed1ab_0 | ||
- exceptiongroup=1.2.0=py310h06a4308_0 | ||
- executing=0.8.3=pyhd3eb1b0_0 | ||
- expat=2.6.2=h6a678d5_0 | ||
- ffmpeg=4.3=hf484d3e_0 | ||
- fontconfig=2.14.1=h4c34cd2_2 | ||
- fonttools=4.51.0=py310h5eee18b_0 | ||
- freetype=2.12.1=h4a9f257_0 | ||
- frozenlist=1.4.0=py310h5eee18b_0 | ||
- fsspec=2024.3.1=py310h06a4308_0 | ||
- gcc_impl_linux-64=11.2.0=h1234567_1 | ||
- gcc_linux-64=11.2.0=h5c386dc_0 | ||
- glib=2.78.4=h6a678d5_0 | ||
- glib-tools=2.78.4=h6a678d5_0 | ||
- gmp=6.2.1=h295c915_3 | ||
- gnutls=3.6.15=he1e5248_0 | ||
- gst-plugins-base=1.14.1=h6a678d5_1 | ||
- gstreamer=1.14.1=h5eee18b_1 | ||
- gxx_impl_linux-64=11.2.0=h1234567_1 | ||
- gxx_linux-64=11.2.0=hc2dff05_0 | ||
- icu=73.2=h59595ed_0 | ||
- idna=3.7=py310h06a4308_0 | ||
- intel-openmp=2023.1.0=hdb19cb5_46306 | ||
- ipykernel=6.28.0=py310h06a4308_0 | ||
- ipython=8.20.0=py310h06a4308_0 | ||
- ipywidgets=8.1.2=py310h06a4308_0 | ||
- jedi=0.18.1=py310h06a4308_1 | ||
- jinja2=3.1.3=py310h06a4308_0 | ||
- joblib=1.4.0=py310h06a4308_0 | ||
- jpeg=9e=h5eee18b_1 | ||
- jupyter_client=8.6.0=py310h06a4308_0 | ||
- jupyter_core=5.5.0=py310h06a4308_0 | ||
- jupyterlab_widgets=3.0.10=py310h06a4308_0 | ||
- kernel-headers_linux-64=3.10.0=h57e8cba_10 | ||
- kiwisolver=1.4.4=py310h6a678d5_0 | ||
- krb5=1.20.1=h143b758_1 | ||
- lame=3.100=h7b6447c_0 | ||
- lcms2=2.12=h3be6417_0 | ||
- ld_impl_linux-64=2.38=h1181459_1 | ||
- lerc=3.0=h295c915_0 | ||
- libboost=1.85.0=hba137d9_1 | ||
- libboost-devel=1.85.0=h00ab1b0_1 | ||
- libboost-headers=1.85.0=ha770c72_1 | ||
- libboost-python=1.85.0=py310he6ccd79_1 | ||
- libboost-python-devel=1.85.0=py310hb7f781d_1 | ||
- libbrotlicommon=1.0.9=h5eee18b_8 | ||
- libbrotlidec=1.0.9=h5eee18b_8 | ||
- libbrotlienc=1.0.9=h5eee18b_8 | ||
- libclang=14.0.6=default_hc6dbbc7_1 | ||
- libclang13=14.0.6=default_he11475f_1 | ||
- libcups=2.4.2=h2d74bed_1 | ||
- libdeflate=1.17=h5eee18b_1 | ||
- libedit=3.1.20230828=h5eee18b_0 | ||
- libffi=3.4.4=h6a678d5_1 | ||
- libgcc-devel_linux-64=11.2.0=h1234567_1 | ||
- libgcc-ng=13.2.0=h77fa898_7 | ||
- libgfortran-ng=11.2.0=h00389a5_1 | ||
- libgfortran5=11.2.0=h1234567_1 | ||
- libglib=2.78.4=hdc74915_0 | ||
- libgomp=13.2.0=h77fa898_7 | ||
- libiconv=1.16=h5eee18b_3 | ||
- libidn2=2.3.4=h5eee18b_0 | ||
- libllvm14=14.0.6=hdb19cb5_3 | ||
- libpng=1.6.39=h5eee18b_0 | ||
- libpq=12.17=hdbd6064_0 | ||
- libsodium=1.0.18=h7b6447c_0 | ||
- libstdcxx-devel_linux-64=11.2.0=h1234567_1 | ||
- libstdcxx-ng=13.2.0=hc0a3c3a_7 | ||
- libtasn1=4.19.0=h5eee18b_0 | ||
- libtiff=4.5.1=h6a678d5_0 | ||
- libunistring=0.9.10=h27cfd23_0 | ||
- libuuid=1.41.5=h5eee18b_0 | ||
- libwebp-base=1.3.2=h5eee18b_0 | ||
- libxcb=1.15=h7f8727e_0 | ||
- libxkbcommon=1.0.1=h5eee18b_1 | ||
- libxml2=2.10.4=hfdd30dd_2 | ||
- libzlib=1.2.13=h4ab18f5_6 | ||
- llvmlite=0.42.0=py310h6a678d5_0 | ||
- lz4-c=1.9.4=h6a678d5_1 | ||
- markupsafe=2.1.3=py310h5eee18b_0 | ||
- matplotlib=3.8.4=py310h06a4308_0 | ||
- matplotlib-base=3.8.4=py310h1128e8f_0 | ||
- matplotlib-inline=0.1.6=py310h06a4308_0 | ||
- mkl=2023.1.0=h213fc3f_46344 | ||
- mkl-service=2.4.0=py310h5eee18b_1 | ||
- mkl_fft=1.3.8=py310h5eee18b_0 | ||
- mkl_random=1.2.4=py310hdb19cb5_0 | ||
- multidict=6.0.4=py310h5eee18b_0 | ||
- multiprocess=0.70.16=py310h2372a71_0 | ||
- mysql=5.7.24=h721c034_2 | ||
- ncurses=6.4=h6a678d5_0 | ||
- nest-asyncio=1.6.0=py310h06a4308_0 | ||
- nettle=3.7.3=hbbd107a_1 | ||
- networkx=3.1=py310h06a4308_0 | ||
- numba=0.59.1=py310h7dc5dd1_0 | ||
- numpy=1.26.4=py310h5f9d8c6_0 | ||
- numpy-base=1.26.4=py310hb5e798b_0 | ||
- openh264=2.1.1=h4ff587b_0 | ||
- openjpeg=2.4.0=h3ad879b_0 | ||
- openssl=3.3.0=h4ab18f5_3 | ||
- packaging=24.0=pyhd8ed1ab_0 | ||
- pandas=2.2.2=py310hf9f9076_1 | ||
- parso=0.8.3=pyhd3eb1b0_0 | ||
- pcre2=10.42=hebb0a14_1 | ||
- pexpect=4.8.0=pyhd3eb1b0_3 | ||
- pillow=10.3.0=py310h5eee18b_0 | ||
- pip=24.0=py310h06a4308_0 | ||
- platformdirs=3.10.0=py310h06a4308_0 | ||
- ply=3.11=py310h06a4308_0 | ||
- prompt-toolkit=3.0.43=py310h06a4308_0 | ||
- prompt_toolkit=3.0.43=hd3eb1b0_0 | ||
- psutil=5.9.0=py310h5eee18b_0 | ||
- ptyprocess=0.7.0=pyhd3eb1b0_2 | ||
- pure_eval=0.2.2=pyhd3eb1b0_0 | ||
- pybind11-abi=4=hd3eb1b0_1 | ||
- pyg=2.5.2=py310_torch_1.13.0_cpu | ||
- pygments=2.15.1=py310h06a4308_1 | ||
- pyparsing=3.0.9=py310h06a4308_0 | ||
- pyqt=5.15.10=py310h6a678d5_0 | ||
- pyqt5-sip=12.13.0=py310h5eee18b_0 | ||
- pysocks=1.7.1=py310h06a4308_0 | ||
- python=3.10.14=h955ad1f_1 | ||
- python-dateutil=2.9.0=pyhd8ed1ab_0 | ||
- python-tzdata=2024.1=pyhd8ed1ab_0 | ||
- python_abi=3.10=2_cp310 | ||
- pytorch=1.13.1=py3.10_cpu_0 | ||
- pytorch-mutex=1.0=cpu | ||
- pytz=2024.1=pyhd8ed1ab_0 | ||
- pyzmq=25.1.2=py310h6a678d5_0 | ||
- qt-main=5.15.2=h53bd1ea_10 | ||
- readline=8.2=h5eee18b_0 | ||
- requests=2.32.2=py310h06a4308_0 | ||
- scikit-learn=1.4.2=py310h1128e8f_1 | ||
- scipy=1.13.0=py310h5f9d8c6_0 | ||
- seaborn=0.12.2=py310h06a4308_0 | ||
- setuptools=69.5.1=py310h06a4308_0 | ||
- shap=0.45.1=cuda118py310hd5e5b8b_0 | ||
- sip=6.7.12=py310h6a678d5_0 | ||
- six=1.16.0=pyh6c4a22f_0 | ||
- slicer=0.0.8=pyhd8ed1ab_0 | ||
- sqlite=3.45.3=h5eee18b_0 | ||
- stack_data=0.2.0=pyhd3eb1b0_0 | ||
- sysroot_linux-64=2.17=h57e8cba_10 | ||
- tbb=2021.8.0=hdb19cb5_0 | ||
- threadpoolctl=2.2.0=pyh0d69192_0 | ||
- tk=8.6.14=h39e8969_0 | ||
- tomli=2.0.1=py310h06a4308_0 | ||
- torchaudio=0.13.1=py310_cpu | ||
- torchvision=0.14.1=py310_cpu | ||
- tornado=6.3.3=py310h5eee18b_0 | ||
- tqdm=4.66.4=py310h2f386ee_0 | ||
- traitlets=5.7.1=py310h06a4308_0 | ||
- typing_extensions=4.11.0=py310h06a4308_0 | ||
- tzdata=2024a=h04d1e81_0 | ||
- unicodedata2=15.1.0=py310h5eee18b_0 | ||
- urllib3=2.2.1=py310h06a4308_0 | ||
- wcwidth=0.2.5=pyhd3eb1b0_0 | ||
- wheel=0.43.0=py310h06a4308_0 | ||
- widgetsnbextension=4.0.10=py310h06a4308_0 | ||
- xz=5.4.6=h5eee18b_1 | ||
- yarl=1.9.3=py310h5eee18b_0 | ||
- zeromq=4.3.5=h6a678d5_0 | ||
- zlib=1.2.13=h4ab18f5_6 | ||
- zstd=1.5.6=ha6fb4c9_0 | ||
- pip: | ||
- click==8.1.7 | ||
- juliacall==0.9.20 | ||
- juliapkg==0.1.13 | ||
- mpmath==1.3.0 | ||
- pysr==0.18.4 | ||
- semver==3.0.2 | ||
- sympy==1.12 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,94 @@ | ||
# GraphTrail | ||
GraphTrail: Translating GNN Predictions into Human-Interpretable Logical Rules | ||
|
||
NOTE: All commands should be run from `src/` | ||
|
||
# Environment | ||
```bash | ||
conda env create -f GraphTrail.yml | ||
``` | ||
|
||
In case you have some issues with the above command, use the following instead: | ||
```bash | ||
cd src/ | ||
|
||
conda create -n GraphTrail | ||
conda activate GraphTrail | ||
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cpuonly -c pytorch | ||
conda install pyg -c pyg | ||
conda install -c conda-forge shap | ||
conda install networkx matplotlib seaborn ipykernel ipywidgets | ||
pip install pysr | ||
conda install conda-forge::boost | ||
conda install gxx_linux-64 | ||
|
||
# Remove stale build artifacts (only needed if these files and folders are present) | ||
rm -r pygcanl/build/ | ||
rm -r pygcanl.egg-info/ | ||
rm pygcanl/*.so | ||
pip install -e pygcanl | ||
|
||
conda install -c conda-forge multiprocess | ||
``` | ||
|
||
# Run the code | ||
```bash | ||
cd src/ | ||
|
||
# Generate training, validation, and test indices for all datasets. | ||
python gen_indices.py | ||
|
||
# Train GNN if not already trained. | ||
python train_gnn.py -h | ||
python train_gnn.py ... | ||
|
||
# Identify the unique computation trees and create the concept vectors. | ||
python gen_ctree.py -h | ||
python gen_ctree.py ... | ||
|
||
# Compute the Shapley values of the computation trees identified in gen_ctree.py | ||
python gen_shap.py -h | ||
python gen_shap.py ... | ||
|
||
# Generate formulae over the ctrees identified by gen_shap.py | ||
# You will see some Julia installation on your first run. | ||
python gen_formulae.py -h | ||
python gen_formulae.py ... | ||
``` | ||
|
||
## Example | ||
```bash | ||
cd src/ | ||
|
||
python gen_indices.py | ||
|
||
python train_gnn.py --name MUTAG --arch GIN | ||
|
||
python gen_ctree.py --name MUTAG --arch GIN | ||
|
||
python gen_shap.py --name MUTAG --arch GIN | ||
|
||
python gen_formulae.py --name MUTAG --arch GIN | ||
``` | ||
|
||
# Data | ||
The code will generate some intermediate files and save them under the following directory structure: | ||
```bash | ||
data | ||
├── BAMultiShapesDataset | ||
│ ├── GAT | ||
│ │ ├── add | ||
│ │ │ ├── 0.05 | ||
│ │ │ │ ├── 357 | ||
│ │ │ │ │ ├── test_indices.pkl | ||
│ │ │ │ │ ├── train_indices.pkl | ||
│ │ │ │ │ └── val_indices.pkl | ||
│ │ │ │ ├── 45 | ||
│ │ │ │ │ ├── test_indices.pkl | ||
│ │ │ │ │ ├── train_indices.pkl | ||
│ │ │ │ │ └── val_indices.pkl | ||
│ │ │ │ └── 796 | ||
│ │ │ │ ├── test_indices.pkl | ||
│ │ │ │ ├── train_indices.pkl | ||
│ │ │ │ └── val_indices.pkl | ||
``` |
Binary file not shown.
Oops, something went wrong.