AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning

Introduction

This is the official implementation of the paper AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning (CVPR 2023).

This repository provides a PyTorch implementation of the adaptive multi-task learning (MTL) approach proposed in the paper. The method makes MTL inference more efficient by adapting the computation to each input, which reduces computational requirements while improving performance across multiple tasks. The code builds on Swin-Transformer and reuses some modules from Multi-Task-Learning-PyTorch.
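As a rough illustration of what "adapting the computation to each input" can mean, here is a pure-Python toy sketch of input-dependent block skipping: a hypothetical controller scores each block from a cheap summary of the input, and blocks scoring below a threshold are skipped. This is not the AdaMTL controller. `Block`, `controller_scores`, `adaptive_forward`, the mean-absolute-value summary, and the threshold rule are all assumptions made for this example.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple

Vector = List[float]


@dataclass
class Block:
    """Hypothetical stand-in for one backbone block."""
    name: str
    flops: float                    # relative cost of executing this block
    fn: Callable[[Vector], Vector]  # the block's computation


def controller_scores(x: Vector, params: List[Tuple[float, float]]) -> List[float]:
    """Hypothetical controller: one keep-probability per block.

    Summarizes the input by its mean absolute value, then applies a
    per-block (weight, bias) logistic function to that summary.
    """
    summary = sum(abs(v) for v in x) / len(x)
    return [1.0 / (1.0 + math.exp(-(w * summary + b))) for w, b in params]


def adaptive_forward(
    x: Vector,
    blocks: List[Block],
    params: List[Tuple[float, float]],
    threshold: float = 0.5,
) -> Tuple[Vector, float]:
    """Execute only the blocks whose score reaches `threshold`.

    A skipped block passes its input through unchanged (as if only the
    residual path were taken). Returns the output and the fraction of
    the total FLOPs actually spent on this input.
    """
    scores = controller_scores(x, params)
    spent = 0.0
    for block, score in zip(blocks, scores):
        if score >= threshold:
            x = block.fn(x)
            spent += block.flops
    return x, spent / sum(b.flops for b in blocks)


# Toy example: two equally expensive blocks with hand-picked controller parameters.
blocks = [
    Block("block1", flops=1.0, fn=lambda v: [2.0 * t for t in v]),
    Block("block2", flops=1.0, fn=lambda v: [t + 1.0 for t in v]),
]
params = [(1.0, 0.0), (1.0, -2.0)]  # (weight, bias) for each block

for x in ([0.5, -0.5], [3.0, -3.0]):
    out, frac = adaptive_forward(x, blocks, params)
    print(x, "->", out, f"(compute used: {frac:.0%})")
```

With these hand-picked parameters, the low-magnitude input runs only the first block (50% of the compute), while the larger input runs both. Training such a controller would require a differentiable relaxation of the hard keep/skip decision; this sketch covers inference only.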

How to Run

To run the AdaMTL code, follow these steps:

  1. Clone the repository

    git clone https://github.com/scale-lab/AdaMTL.git
    cd AdaMTL
  2. Install the prerequisites

    • Install PyTorch>=1.12.0 and torchvision>=0.13.0 with CUDA>=11.6
    • Install dependencies: pip install -r requirements.txt
  3. Run the code

    Stage 1: Train the backbone

      python main.py --cfg configs/swin/<swin variant>.yaml \
        --pascal <path to pascal database> \
        --tasks semseg,normals,sal,human_parts \
        --batch-size <batch size> --ckpt-freq=20 --epoch=1000 \
        --resume-backbone <path to swin weights>

    Stage 2: Pretrain the controller

      python main.py --cfg configs/ada_swin/<swin variant>_<tag/taw>_pretrain.yaml \
        --pascal <path to pascal database> \
        --tasks semseg,normals,sal,human_parts \
        --batch-size <batch size> --ckpt-freq=20 --epoch=100 \
        --resume <path to the weights generated in Stage 1>

    Stage 3: Train the MTL model

      python main.py --cfg configs/ada_swin/<swin variant>_<tag/taw>.yaml \
        --pascal <path to pascal database> \
        --tasks semseg,normals,sal,human_parts \
        --batch-size <batch size> --ckpt-freq=20 --epoch=300 \
        --resume <path to the weights generated in Stage 2>

    Swin variants and their pretrained weights are available from the official Swin Transformer repository.

    Outputs are saved to the output/ folder unless overridden with the --output argument.
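The three stages form a chain: Stage 2 resumes from the Stage 1 checkpoint, and Stage 3 resumes from the Stage 2 checkpoint. The sketch below assembles the three commands from exactly the flags listed in the steps above, which makes that chain explicit. The helper `build_stage_commands`, the variant name, the batch size, and every path in the example are hypothetical placeholders; the checkpoint file names written under output/ are not documented here, so point `stage1_ckpt` and `stage2_ckpt` at whatever Stage 1 and Stage 2 actually saved.

```python
import shlex
from typing import Dict, List

TASKS = "semseg,normals,sal,human_parts"


def build_stage_commands(
    swin_variant: str,
    gating: str,
    pascal_dir: str,
    batch_size: int,
    swin_weights: str,
    stage1_ckpt: str,
    stage2_ckpt: str,
) -> Dict[str, List[str]]:
    """Return the argv list for each training stage (hypothetical helper).

    `gating` selects the "tag" or "taw" config variant used in Stages 2 and 3.
    """
    if gating not in ("tag", "taw"):
        raise ValueError("gating must be 'tag' or 'taw'")
    # Flags shared by all three stages.
    common = ["--pascal", pascal_dir, "--tasks", TASKS,
              "--batch-size", str(batch_size), "--ckpt-freq=20"]
    ada_cfg = f"configs/ada_swin/{swin_variant}_{gating}"
    return {
        # Stage 1: train the backbone, starting from pretrained Swin weights.
        "stage1": ["python", "main.py", "--cfg", f"configs/swin/{swin_variant}.yaml",
                   *common, "--epoch=1000", "--resume-backbone", swin_weights],
        # Stage 2: pretrain the controller, resuming from the Stage 1 weights.
        "stage2": ["python", "main.py", "--cfg", f"{ada_cfg}_pretrain.yaml",
                   *common, "--epoch=100", "--resume", stage1_ckpt],
        # Stage 3: train the MTL model, resuming from the Stage 2 weights.
        "stage3": ["python", "main.py", "--cfg", f"{ada_cfg}.yaml",
                   *common, "--epoch=300", "--resume", stage2_ckpt],
    }


# Placeholder values; replace them with your own variant name and paths.
commands = build_stage_commands(
    swin_variant="swin_tiny_patch4_window7_224",  # assumed config name
    gating="tag",
    pascal_dir="/path/to/pascal",
    batch_size=8,
    swin_weights="/path/to/swin_weights.pth",
    stage1_ckpt="/path/to/stage1_checkpoint.pth",
    stage2_ckpt="/path/to/stage2_checkpoint.pth",
)
for stage, argv in commands.items():
    # To launch a stage instead of printing it: subprocess.run(argv, check=True)
    print(f"{stage}: {shlex.join(argv)}")
```

Keeping each command as an argv list means it can be passed directly to `subprocess.run` without shell quoting, which avoids problems with paths that contain spaces.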

Authorship

Since the release commit is squashed, the GitHub contributors tab doesn't reflect the authors' contributions. The following authors contributed equally to this codebase:

Citation

If you find AdaMTL helpful in your research, please cite our paper:

@inproceedings{neseem2023adamtl,
  title={AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning},
  author={Neseem, Marina and Agiza, Ahmed and Reda, Sherief},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={4729--4738},
  year={2023}
}

License

MIT License. See the LICENSE file.
