Official implementation of Annealed Multiple Choice Learning: Overcoming limitations of Winner-takes-all with annealing (NeurIPS 2024).
This repository contains the code for reproducing the experiments from our paper, including synthetic data experiments, UCI dataset experiments, and source separation tasks.
├── synthetic_experiments.ipynb # Quick start notebook with synthetic data examples
├── synthetic_and_uci_datasets/ # Code for synthetic and UCI dataset experiments
├── source_separation/ # Source separation experiment implementation
└── images/ # Visualizations and animations
We provide two ways to explore the aMCL framework:
- Quick Start: use `synthetic_experiments.ipynb` for interactive experimentation and fast prototyping.
- Full Pipeline: reproduce the paper results using the following steps:
First, create and activate a conda environment with:
conda create -n synth_env -y python=3.9.20
conda init
Then, re-launch the shell and run:
conda activate synth_env
pip install -r synthetic_and_uci_datasets/requirements.txt
Then, define your absolute home path as an environment variable:
export MY_HOME=<YOUR_PATH>/annealed_mcl/synthetic_and_uci_datasets
Then, run the following commands to perform the training needed to reproduce Figures 1, 2 and 4 of the main paper:
cd ${MY_HOME}/scripts_synthetic ;
./scripts_synthetic_train_three_gaussians_fast.sh ; # Run this to reproduce results with three fixed Gaussians
./scripts_synthetic_train_three_gaussians_changedist_fast.sh ; # Run this to reproduce results with three moving Gaussians
Note: the `_fast` suffix enables faster training than the original paper's setup while producing visually similar plots. Remove `_fast` to train with the exact setup described in the paper.
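For instance, assuming the non-`_fast` variants follow the same naming convention, the full-length training would be launched with:

```bash
cd ${MY_HOME}/scripts_synthetic ;
./scripts_synthetic_train_three_gaussians.sh ;
./scripts_synthetic_train_three_gaussians_changedist.sh ;
```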
Below are animations comparing WTA and aMCL training dynamics:
Winner-takes-all training dynamics with stochastic gradient descent (see Fig. 1)
Annealed Multiple Choice Learning training dynamics with stochastic gradient descent (see Fig. 1)
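For intuition about these dynamics, below is a minimal illustrative sketch, not the repository's exact implementation, contrasting the hard winner-takes-all assignment with the temperature-weighted (annealed) assignment underlying aMCL. The tensor shapes, the squared-error distance, and the choice to detach the assignment weights are assumptions made for this sketch.

```python
import torch

def wta_loss(preds, target):
    # preds: (batch, n_hypotheses, dim), target: (batch, dim)
    # Hard winner-takes-all: only the best hypothesis receives a gradient.
    errs = ((preds - target.unsqueeze(1)) ** 2).sum(dim=-1)  # (batch, n_hypotheses)
    return errs.min(dim=1).values.mean()

def annealed_wta_loss(preds, target, temperature):
    # Soft assignment: a Boltzmann distribution over hypotheses spreads the
    # gradient across all of them and collapses to hard WTA as temperature -> 0.
    errs = ((preds - target.unsqueeze(1)) ** 2).sum(dim=-1)
    weights = torch.softmax(-errs / temperature, dim=1).detach()  # treated as constants
    return (weights * errs).sum(dim=1).mean()
```

During training, the temperature follows a decreasing schedule, so the objective smoothly transitions from an averaged loss towards the winner-takes-all loss; see the training scripts for the schedule actually used in the experiments.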
- Setup Environment
For reproducing the results on the UCI datasets, first create and activate a conda environment in which you install the required dependencies:
conda create -n uci_env -y python=3.9.20
conda init
Then, after reloading the shell:
conda activate uci_env
pip install -r synthetic_and_uci_datasets/requirements.txt
Then, define your absolute home path as an environment variable:
export MY_HOME=<YOUR_PATH>/annealed_mcl/synthetic_and_uci_datasets
- Data Preparation:
- Download the UCI datasets from this drive [C].
- Place the datasets in `synthetic_and_uci_datasets/data/uci/`
- Run Experiments:
The benchmark follows the experimental protocol of previous works [A,B].
Evaluation pipeline:
- Download the checkpoints from this drive.
- Place the checkpoints folder (named `checkpoints_uci_amcl`) in the `synthetic_and_uci_datasets/` folder.
- Run the Python script `synthetic_and_uci_datasets/scripts_uci_eval/extract_ckpt_json.py` to extract the checkpoint paths into a JSON file (a rough sketch of this extraction step is shown after the evaluation commands below):
python synthetic_and_uci_datasets/scripts_uci_eval/extract_ckpt_json.py
- Run the following commands to train and evaluate the models:
cd ${MY_HOME}/scripts_uci_eval ;
./1scripts_uci_train_and_eval_loop_protein.sh ;
./1scripts_uci_train_and_eval_loop_year.sh ;
./1scripts_uci_train_and_eval_loop.sh ;
./2scripts_uci_extract_results.sh
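For reference, here is a rough sketch of what the checkpoint-path extraction step amounts to. It is not the actual script (see `extract_ckpt_json.py` for that); the output filename and the `.ckpt` extension are assumptions.

```python
import json
from pathlib import Path

# Illustrative sketch of extracting checkpoint paths into a JSON file.
ckpt_dir = Path("synthetic_and_uci_datasets/checkpoints_uci_amcl")
ckpt_paths = sorted(str(p) for p in ckpt_dir.rglob("*.ckpt"))  # assumed extension

with open("checkpoints_paths.json", "w") as f:  # hypothetical output name
    json.dump(ckpt_paths, f, indent=2)
```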
Training pipeline:
- To train the models, run the following commands:
cd ${MY_HOME}/scripts_uci ;
./1scripts_uci_train_and_eval_loop_protein.sh ;
./1scripts_uci_train_and_eval_loop_year.sh ;
./1scripts_uci_train_and_eval_loop.sh ;
./2scripts_uci_extract_results.sh
Note that the results of the MCL baseline should match those presented in [D].
[A] Hernández-Lobato, J. M. and Adams, R. Probabilistic back-propagation for scalable learning of Bayesian neural networks. In ICML, pp. 1861–1869. PMLR, 2015.
[B] Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, volume 30, 2017.
[C] Han, X., Zheng, H., and Zhou, M. CARD: Classification and regression diffusion models. In NeurIPS, volume 35, 2022.
[D] Letzelter, V., Perera, D., Rommel, C., Fontaine, M., Essid, S., Richard, G., and Pérez, P. Winner-takes-all learners are geometry-aware conditional density estimators. In ICML, 2024.
The code to train source separation models on the WSJ0-mix dataset is available in the `source_separation` folder. It is structured as follows:
- `data.py`: PyTorch Dataset to load every version of the Wall Street Journal mix (WSJ0-mix) dataset. It can load versions with a fixed number of sources, as well as versions with a variable number of speakers.
- `losses.py`: every loss presented in the paper, plus some additional ones (a generic sketch of the permutation-invariant objective is given after this list):
  - PIT
  - MCL
  - aMCL
  - EMD
- `models.py`: implements the separation models. Note that our systems have an additional scoring head, which is only useful when separating a variable number of sources.
  - DPRNN, used in the paper
  - Sudo Rm-Rf: state-of-the-art separation architecture for low-resource source separation, used for additional experiments not presented in the paper
- `system.py`: PyTorch Lightning system to train and evaluate the separation models
- `train.py`: entry point for training the systems; it manages the configurations and the experiment saving
- `utils.py`: various utilities
- `config_dprnn.py`: configuration file for DPRNN, containing the configurations for the experiments conducted in the paper
- `config_sudo.py`: same for Sudo Rm-Rf
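As a point of reference for the loss implementations, here is a generic, self-contained sketch of the permutation-invariant training (PIT) objective. It uses a plain MSE for readability, whereas separation systems are usually trained with signal-level metrics such as SI-SDR; it is not the repository's exact code.

```python
import itertools
import torch

def pit_mse(est, ref):
    # est, ref: (batch, n_sources, time).
    # Evaluate the loss under every permutation of the estimated sources
    # and keep the best-matching one for each element of the batch.
    n_src = est.shape[1]
    per_perm = []
    for perm in itertools.permutations(range(n_src)):
        per_perm.append(((est[:, list(perm)] - ref) ** 2).mean(dim=(1, 2)))
    return torch.stack(per_perm, dim=1).min(dim=1).values.mean()
```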
To train a model:
python train.py --conf_id 001 --backbone dprnn
The main package versions used for the source separation experiments are listed below:

| Name | Version |
|------|---------|
| asteroid | 0.6.1.dev0 |
| numpy | 1.24.4 |
| pandas | 2.0.3 |
| pot | 0.9.3 |
| python | 3.8.18 |
| pytorch-lightning | 1.7.7 |
| scikit-learn | 1.3.2 |
| scipy | 1.10.1 |
| torch | 1.13.1 |
| torchaudio | 0.13.1 |
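As a starting point, this environment can be approximated with the commands below. The environment name is arbitrary, and installing asteroid from its GitHub repository is one way to obtain a `0.6.1.dev0` development version.

```bash
conda create -n separation_env -y python=3.8.18
conda activate separation_env
pip install torch==1.13.1 torchaudio==0.13.1 pytorch-lightning==1.7.7 \
    numpy==1.24.4 pandas==2.0.3 scikit-learn==1.3.2 scipy==1.10.1 pot==0.9.3
# asteroid 0.6.1.dev0 is a development version, typically installed from source:
pip install git+https://github.com/asteroid-team/asteroid
```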
We welcome contributions! Please feel free to:
- Submit issues for bugs or difficulties
- Create pull requests for improvements
- Suggest better organization or efficiency improvements
If our work helped in your research, feel free to give us a star ⭐ or to cite us with the following BibTeX entry:
@article{amcl,
title={Annealed Multiple Choice Learning: Overcoming limitations of Winner-takes-all with annealing},
author={Perera, David and Letzelter, Victor and Mariotte, Th{\'e}o and Cort{\'e}s, Adrien and Chen, Mickael and Essid, Slim and Richard, Ga{\"e}l},
journal={Advances in Neural Information Processing Systems},
year={2024}
}