Code associated with the ICML 2024 publication *Winner-takes-all learners are geometry-aware conditional density estimators* by Victor Letzelter, David Perera, Cédric Rommel, Mathieu Fontaine, Slim Essid, Gaël Richard and Patrick Pérez.
Winner-takes-all training is a simple learning paradigm that handles ambiguous tasks by predicting a set of plausible hypotheses. Recently, a connection was established between Winner-takes-all training and centroidal Voronoi tessellations, showing that, once trained, the hypotheses should optimally quantize the shape of the conditional distribution to predict. However, the best use of these hypotheses for uncertainty quantification is still an open question. In this work, we show how to leverage the appealing geometric properties of Winner-takes-all learners for conditional density estimation, without modifying the original training scheme. We theoretically establish the advantages of our novel estimator both in terms of quantization and density estimation, and we demonstrate its competitiveness on synthetic and real-world datasets, including audio data.
The file structure is organized as follows:
```
.
├── README.md
├── toy    # Code base for the synthetic-data and UCI-dataset experiments.
│   └── ...
└── audio  # Code base for the audio experiments.
    └── ...
```
Winner-takes-all training applies to any machine learning task for which the distribution of the target given the input is ambiguous, producing several plausible hypotheses instead of a single prediction.
For a quick overview of score-based Winner-takes-all models, please check the README.md of this repository (see the figure below) [A].
[A] Letzelter, V., Fontaine, M., Chen, M., Pérez, P., Essid, S., & Richard, G. (2024). Resilient Multiple Choice Learning: A learned scoring scheme with application to audio scene analysis. Advances in Neural Information Processing Systems, 36.
We are concerned here with turning trained Winner-takes-all models into continuous density estimators, which allows a better quantification of prediction uncertainty (e.g., using the NLL). The estimators studied include Kernel-WTA:

$$\hat{p}_{\mathrm{K}}(y \mid x) = \sum_{k=1}^{K} \gamma_{k}(x)\, K_{h}\big(y - f_{k}(x)\big),$$

and Voronoi-WTA:

$$\hat{p}_{\mathrm{V}}(y \mid x) = \sum_{k=1}^{K} \gamma_{k}(x)\, \frac{K_{h}\big(y - f_{k}(x)\big)}{\int_{\mathcal{Y}_{k}(x)} K_{h}\big(u - f_{k}(x)\big)\, \mathrm{d}u}\, \mathbb{1}\big[y \in \mathcal{Y}_{k}(x)\big],$$

where $f_{1}(x), \dots, f_{K}(x)$ are the predicted hypotheses, $\gamma_{1}(x), \dots, \gamma_{K}(x)$ the associated scores (non-negative and summing to one), $K_{h}$ a kernel with scaling factor $h$, and $\mathcal{Y}_{k}(x)$ the Voronoi cell of hypothesis $f_{k}(x)$.

Say you are given a trained score-based WTA model and you want to evaluate the quality of its probabilistic predictions. In real-world tasks, only samples from the target distribution are available, so the quality of the estimated density is assessed through sample-based metrics such as the negative log-likelihood of held-out samples.
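To make these formulas concrete, here is a minimal NumPy sketch of both estimators, assuming isotropic Gaussian kernels and estimating the Voronoi cell mass by Monte Carlo; function names and details are illustrative assumptions, not the repository's implementation.

```python
import numpy as np

def kernel_wta_density(y, hypotheses, scores, h):
    """Kernel-WTA: isotropic Gaussian mixture centered at the hypotheses.

    y: (d,) query point; hypotheses: (K, d); scores: (K,), summing to one."""
    d = hypotheses.shape[1]
    sq_dists = np.sum((hypotheses - y) ** 2, axis=1)  # (K,)
    kernels = np.exp(-sq_dists / (2 * h**2)) / (2 * np.pi * h**2) ** (d / 2)
    return float(scores @ kernels)

def voronoi_wta_density(y, hypotheses, scores, h, n_mc=10_000, seed=None):
    """Voronoi-WTA: the kernel of the cell containing y is truncated to
    (and renormalized on) that cell; the cell mass is estimated by
    Monte Carlo as P(Z in cell), with Z ~ N(f_k, h^2 I)."""
    rng = np.random.default_rng(seed)
    K, d = hypotheses.shape
    k_star = np.argmin(np.sum((hypotheses - y) ** 2, axis=1))  # cell of y
    z = hypotheses[k_star] + h * rng.standard_normal((n_mc, d))
    owner = np.argmin(((z[:, None, :] - hypotheses[None]) ** 2).sum(-1), axis=1)
    cell_mass = max(np.mean(owner == k_star), 1e-12)
    sq_dist = np.sum((hypotheses[k_star] - y) ** 2)
    kernel = np.exp(-sq_dist / (2 * h**2)) / (2 * np.pi * h**2) ** (d / 2)
    return float(scores[k_star] * kernel / cell_mass)
```

Truncating and renormalizing each kernel on its Voronoi cell is what makes Voronoi-WTA robust to the choice of $h$ (compare with Kernel-WTA in Figure 4).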
- Install the needed requirements:

```bash
pip install -r requirements.txt
```
- Select a dataset of interest (see `toy/config/experiment`, e.g., `gauss_not_centered.yaml`).
- Train and test your score-based WTA model:

```bash
cd toy/src ; python train.py experiment=<my_dataset.yaml> model.hparams.kde_mode=<kernel_wta_enabled> --model.hparams.scaling_factor=1.0 test=True
```
where:
- `test` (boolean) indicates whether or not to perform the test at the end of the training.
- `model.hparams.kde_mode` (boolean): if True, the Kernel-WTA estimator is used, otherwise Voronoi-WTA.
- `model.hparams.scaling_factor` (float) indicates the value of $h$ used for evaluation (if `model.hparams.h_optimization=False`).

Note that by default, the NLL, risk and MSE are computed, but not the EMD. This can be adjusted with additional flags, e.g., if only the NLL is needed: `--compute_nll=True --compute_emd=False --compute_risk=False --compute_mse=False`.
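For intuition, here is a minimal sketch of how such sample-based metrics can be computed from the hypotheses and scores. It is an illustration under stated assumptions (names are hypothetical): the risk is taken to be the mean squared distance to the closest hypothesis, the NLL follows from the density sketch above, and the EMD is computed with the POT package between the score-weighted hypotheses and ground-truth samples.

```python
import numpy as np
import ot  # POT (Python Optimal Transport), for the EMD

def quantization_risk(targets, hypotheses):
    """WTA distortion: mean squared distance from each target to the
    closest of its K hypotheses.

    targets: (N, d); hypotheses: (N, K, d), one hypothesis set per input."""
    sq_dists = ((hypotheses - targets[:, None, :]) ** 2).sum(-1)  # (N, K)
    return float(sq_dists.min(axis=1).mean())

def emd_to_samples(samples, hypotheses, scores):
    """EMD between the discrete distribution (hypotheses, scores) and the
    empirical distribution of ground-truth samples, for a single input x.

    samples: (M, d); hypotheses: (K, d); scores: (K,), summing to one."""
    uniform = np.full(len(samples), 1.0 / len(samples))
    costs = ot.dist(samples, hypotheses)  # (M, K) squared Euclidean costs
    return float(ot.emd2(uniform, scores, costs))
```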
If you have a checkpoint located at `ckpt_path` (str), the same setup applies in evaluation mode:

```bash
cd toy/src ; python eval.py data=<my_dataset.yaml> model.hparams.h_optimization=True model.hparams.hmin=0.1 model.hparams.hmax=2.0 model.hparams.h_tol=0.1 model.hparams.batch_size_h_opt=256 model.hparams.limit_val_batches_h_opt=None model.hparams.kde_mode=<kernel_wta_enabled> --model.hparams.scaling_factor=1.0 --compute_nll=True --compute_emd=False --compute_risk=True --compute_mse=True ckpt_path=<your_ckpt_path>
```
Note that the optimization of the scaling factor $h$ at validation time can be controlled through the following parameters:
- `model.hparams.h_optimization` (boolean, default False): whether the scaling factor $h$ is optimized at validation time (if `test=True`).
- `model.hparams.hmin` (float, default 0.1): lower bound on the optimal scaling factor.
- `model.hparams.hmax` (float, default 2.0): upper bound on the optimal scaling factor.
- `model.hparams.h_tol` (float, default 0.1): tolerance used in the golden-section search algorithm (see the sketch after this list).
- `model.hparams.batch_size_h_opt` (integer, default 256): batch size used for the optimization of the scaling factor.
- `model.hparams.limit_val_batches_h_opt` (float, default None): if not None, indicates the proportion of validation batches to use for the optimization of the scaling factor.
- `model.hparams.kde_mode` (boolean): if True, the Kernel-WTA estimator is used, otherwise Voronoi-WTA.
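For reference, golden-section search over $h$ can be sketched as follows. This is a generic implementation of the algorithm, assuming the validation NLL is unimodal on `[hmin, hmax]`; `validation_nll` is a hypothetical objective, not a function of this code base.

```python
import math

def golden_section_search(objective, hmin=0.1, hmax=2.0, tol=0.1):
    """Minimize a 1-D objective (e.g., validation NLL as a function of h),
    assumed unimodal on [hmin, hmax], by golden-section search."""
    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0  # ~0.618
    a, b = hmin, hmax
    c, d = b - inv_phi * (b - a), a + inv_phi * (b - a)
    fc, fd = objective(c), objective(d)
    while b - a > tol:
        if fc < fd:  # minimum lies in [a, d]: reuse c as the new d
            b, d, fd = d, c, fc
            c = b - inv_phi * (b - a)
            fc = objective(c)
        else:        # minimum lies in [c, b]: reuse d as the new c
            a, c, fc = c, d, fd
            d = a + inv_phi * (b - a)
            fd = objective(d)
    return 0.5 * (a + b)

# Hypothetical usage:
# h_star = golden_section_search(lambda h: validation_nll(model, h))
```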
A visualisation of the training dynamics is given in Figure A.

*Figure A: Training dynamics of score-based WTA models.*
Note that the following commands involve the execution of shell scripts, which may require execution permissions. These permissions can be set for all the scripts in the `VoronoiWTA` folder with:

```bash
find . -type f -iname "*.sh" -exec chmod +x {} +
```
First, the evaluation of our proposed estimators was performed on synthetic data, for which the conditional distribution to predict is known, so that the estimated density can be compared against the ground truth.
A visualisation of the synthetic datasets is given in Figure 3.

*Synthetic dataset predictions (Figure 2 of the main paper).*
To reproduce the paper experiments, first define your absolute home path as an environment variable:

```bash
export MY_HOME=<YOUR_PATH>/VoronoiWTA/toy
```

where `<YOUR_PATH>` is the absolute path under which the `VoronoiWTA` folder is stored.
Figure 1 can be reproduced with:

```bash
cd toy/scripts
./-1script_plot_fig1.sh
```

The plot will be saved in `toy/logs/train/plot_figure1/<LAUNCH_TIME>_.../rmcl_preds.png`.
To reproduce the quantitative results, launch the following scripts, which are located in `toy/scripts`:

```bash
cd toy/scripts ; # Change directory
./1scripts_synthetic_train.sh ; # Launch the trainings
```

Once the trainings have finished, the results can be extracted with the following command:

```bash
./2scripts_synthetic_extract_ckpt.sh ; # Extract the checkpoint paths in a JSON file
```

Finally, launch the evaluations and plot the results:

```bash
./3scripts_synthetic_eval_.sh ; # Launch the evaluations
./4scripts_synthetic_extract_results.sh ; # Extract the metrics in CSV files (in toy/results_saved_csv)
./5scripts_zador_computation.sh ; # Compute the Zador equivalent
./6scripts_plot_figures_std.sh ; # Plot the figures and save them in toy/figures
```
This should provide in particular Figures 3 and 4 of the main paper (as well as Figures 8-11 of the appendix).
*Quantitative results on the synthetic datasets (Figure 3 of the main paper).*

*Comparison of Kernel-WTA and Voronoi-WTA with respect to their sensitivity to the scaling factor (Figure 4).*
Then, the evaluation of our proposed estimators was performed on UCI datasets, for which the conditional distribution is unknown. The benchmark follows the experimental protocol of previous works [A,B]. The UCI regression datasets can be downloaded from this drive [C].
Once the datasets are placed in the `toy/data/uci` folder, the results can be reproduced with the following commands:
```bash
cd toy/scripts_uci ; # Change directory
./1scripts_uci_train_and_eval_loop_protein.sh ; # Perform training and evaluation on the Protein dataset
./1scripts_uci_train_and_eval_loop_year.sh ; # Perform training and evaluation on the Year dataset
./1scripts_uci_train_and_eval_loop.sh ; # Perform training and evaluation on the UCI datasets: "boston" "concrete" "energy" "kin8nm" "naval" "power_plant" "wine" "yacht"
./2scripts_uci_extract_results.sh ; # Extract the metrics in CSV files (in toy/results_uci)
```
[A] Hernandez-Lobato, J. M., & Adams, R. (2015). Probabilistic backpropagation for scalable learning of Bayesian neural networks. In ICML, pp. 1861–1869. PMLR.
[B] Lakshminarayanan, B., Pritzel, A., & Blundell, C. (2017). Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, volume 30.
[C] Han, X., Zheng, H., & Zhou, M. (2022). CARD: Classification and regression diffusion models. In NeurIPS, volume 35.
To reproduce the experiments on audio data, the reader is invited to check the `audio` folder, where the following scripts can be executed. The evaluation was performed on the ANSYN dataset proposed by [A].
To reproduce the results, first define the environment variable:

```bash
export MY_HOME=<YOUR_PATH>/VoronoiWTA/audio
```
The following shell scripts can be executed to download and pre-process the data:

```bash
cd audio/scripts/official_pipeline ;
./1download_data.sh ; # Download the dataset
./2preprocessing.sh ; # Preprocess the raw audio data by computing spectrograms
```
The checkpoints of the pre-trained models are available at this link. Once downloaded, place the content of the `checkpoints` folder at the root of the repo and run the following commands:

```bash
mkdir -p audio/logs/train/ansim
mv checkpoints/* audio/logs/train/ansim/
```
To re-train the models, please launch the following script:

```bash
./3train.sh ; # Launch the trainings of the different baselines
```

When launching the trainings, the logs will be written in the directory `audio/logs/train/ansim`.
To extract the checkpoint paths into a JSON file:

```bash
PATH_HISTOGRAM=${MY_HOME}/logs/train/ansim/Histogram # Path to the Histogram results folders.
PATH_WTA=${MY_HOME}/logs/train/ansim/Score_based_WTA # Path to the Score-based WTA results folders.
PATH_MDN=${MY_HOME}/logs/train/ansim/Von_Mises_Fisher # Path to the MDN results folder.
./4extract_ckpt.sh histogram $PATH_HISTOGRAM score_based_wta $PATH_WTA mdn $PATH_MDN
```

The JSON file will be saved at `$PATH_HISTOGRAM/../checkpoints.json`.
For the next steps, make sure `jq` is installed, or install it with `sudo apt-get install jq`.
Perform the evaluation:

```bash
./5eval.sh ; # Launch the evaluation of the different baselines
```

Then, the results can be extracted with the following commands:

```bash
./6extract_results.sh ; # Extract the results from the different baselines
./7plot_results.sh # Plot the curves
```
[A] Adavanne, S., Politis, A., Nikunen, J., & Virtanen, T. (2018). Sound event localization and detection of overlapping sources using convolutional recurrent neural networks. IEEE Journal of Selected Topics in Signal Processing, 13(1), 34-48.
Note that we used the Lightning-Hydra-Template for our experiments, which we strongly encourage you to use as well.
Our code was also inspired by the following repositories:
[A] Resilient Multiple Choice Learning
[B] Pilot
[C] Multimodal Future Prediction
[D] Compactified Voronoi Density Estimator
[E] CARD
[F] Seld Net
This code could be better organized and made more computationally efficient. Feel free to contribute to this repository or report any difficulties or bugs.
If our work helped in your research, please consider citing us with the following BibTeX entry:

```bibtex
@inproceedings{letzelter2024winner,
  title={Winner-takes-all learners are geometry-aware conditional density estimators},
  author={Letzelter, Victor and Perera, David and Rommel, C{\'e}dric and Fontaine, Mathieu and Essid, Slim and Richard, Gael and P{\'e}rez, Patrick},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
```