
Commit

4M-21 release
Co-authored-by: Oğuzhan Fatih Kar <[email protected]>
Co-authored-by: David Mizrahi <[email protected]>
Co-authored-by: Ali Garjani <[email protected]>
4 people committed Jun 13, 2024
1 parent 43558d1 commit 4600165
Showing 61 changed files with 15,348 additions and 8,822 deletions.
332 changes: 314 additions & 18 deletions ACKNOWLEDGEMENTS.md

Large diffs are not rendered by default.

206 changes: 136 additions & 70 deletions README.md

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions README_DATA.md
@@ -71,16 +71,20 @@ Starting from text-image pairs, we use pseudo labeling to create an aligned mult
| Modality | Model | Homepage |
|-----------------------|--------------------------------------|------------------------------------------------------------------------------------------------------|
| Depth | Omnidata DPT-B-Hybrid (v2) | [link](https://docs.omnidata.vision/pretrained.html#Pretrained-Models) |
| Surface normals | Omnidata DPT-B-Hybrid (v2) | [link](https://docs.omnidata.vision/pretrained.html#Pretrained-Models) |
| Semantic segmentation | Mask2Former Swin-B | [link](https://github.com/facebookresearch/Mask2Former/blob/main/MODEL_ZOO.md#panoptic-segmentation) |
| Bounding boxes | ViTDet ViT-H with Cascade Mask-RCNN | [link](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet#cascade-mask-r-cnn) |
| CLIP features | CLIP ViT-B/16 | [link](https://github.com/OpenAI/CLIP#clip) |
| DINOv2 features | DINOv2 ViT-B/14 | [link](https://github.com/facebookresearch/dinov2?tab=readme-ov-file#pretrained-models) |
| ImageBind features | ImageBind ViT-H/14 | [link](https://github.com/facebookresearch/ImageBind?tab=readme-ov-file#imagebind-model) |
| SAM instances | SAM ViT-H | [link](https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints) |
| 3D human poses & shape| HMR2.0 | [link](https://github.com/shubham-goel/4D-Humans) |
| Color palette | PyPalette | [link](https://github.com/adamgrieger/pypalette) |

## Pre-tokenization

During training, all modalities are mapped to sets or sequences of discrete tokens using modality-specific tokenizers. Please refer to [README_TOKENIZATION.md](README_TOKENIZATION.md) for more information. To prevent data loading and tokenization from becoming a training bottleneck, we pre-compute the tokens of all image-like modalities once before training (i.e., pre-tokenization) and then directly load the tokens.

To pre-tokenize any modality, run the provided `save_vq_tokens.py` script with the appropriate arguments (a conceptual sketch of what pre-tokenization produces is shown below).

:information_source: For non-square images or if `--n_crops` is > 1, pre-tokenization requires cropping the original image. Therefore, to ensure that the tokens from all modalities are aligned, we automatically create a `crop_settings` directory with the crop information for all samples the first time that a dataset is tokenized. This information is then used when tokenizing the same dataset with a different modality.
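
To make concrete what pre-tokenization produces, below is a minimal, self-contained sketch: each image of one image-like modality is encoded into discrete VQ token indices and the indices are saved to disk, so the training dataloader only has to read integer arrays. This is an illustration of the idea only — the tokenizer object and its `encode_to_indices` method are placeholders, not the actual interface of `save_vq_tokens.py` or the 4M tokenizers.

```python
# Conceptual pre-tokenization sketch (placeholder tokenizer API, not the repo's save_vq_tokens.py).
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

@torch.no_grad()
def pretokenize_folder(tokenizer, in_dir: str, out_dir: str, image_size: int = 224) -> None:
    """Encode every image in `in_dir` into discrete token indices and save them as .npy files."""
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # square resize; see the crop_settings note above for non-square images
        transforms.ToTensor(),
    ])
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for path in sorted(Path(in_dir).glob("*.png")):
        img = preprocess(Image.open(path).convert("RGB")).unsqueeze(0)  # [1, 3, H, W]
        tokens = tokenizer.encode_to_indices(img)  # placeholder call returning a [1, num_tokens] LongTensor
        np.save(Path(out_dir) / f"{path.stem}.npy", tokens.squeeze(0).cpu().numpy())
```

At training time, the dataloader then reads these saved token arrays directly instead of running the tokenizer on the fly.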
9 changes: 5 additions & 4 deletions README_GENERATION.md
@@ -17,7 +17,7 @@ We provide information and example scripts on how to perform multimodal chained

#### Scripts and notebooks:
- `run_generation.py`: Script that automatically performs chained X→Y→etc... generation on a given dataset.
- `notebooks/`: Jupyter notebooks that give some examples for performing generation with 4M-7 and 4M-21 models.


#### Configs:
@@ -75,7 +75,8 @@ Performing generation with 4M can be complex due to the large range of possibili
- Classifier-free guidance can have a large impact on generation fidelity, and matters most for input/output pairs that lacked cleanly aligned training data, such as images and captions. We found, however, that slightly increasing the guidance scale for RGB→X inference (e.g. surface normal or segmentation prediction) can also improve how well the generated modality matches the given input (see the guidance sketch after this list).
- Multi-modal guidance can be an effective tool to balance the influence of different input modalities during the generation process.
- The generation samplers and decoding functions support batching for faster inference.

- The default generation script can only generate a limited number of SAM instances due to limits on the number of tokens the model can handle. To get a denser estimate of SAM instances, use the `generate_sam_dense` method (as shown in `notebooks/generation_4M-21.ipynb`), which performs multiple independent SAM instance predictions and aggregates them into one dense estimate.
- Avoid using the output of `generate_sam_dense` as a condition for generation: it can contain a large number of tokens, and using it as the conditioning input can cause memory issues.
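
For reference, the classifier-free guidance scale mentioned above interpolates between unconditional and conditional predictions. The generic sketch below shows the usual formulation on next-token logits; it is not 4M's sampler code, and the function name and tensor shapes are illustrative only.

```python
# Generic classifier-free guidance on next-token logits (illustrative, not 4M's sampler implementation).
import torch

def cfg_logits(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """guidance_scale = 1.0 recovers the conditional model; larger values push samples toward the condition."""
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)
```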

## Generation script usage

@@ -88,6 +89,6 @@ OMP_NUM_THREADS=1 torchrun --nproc_per_node=8 run_generation.py -c cfgs/default/
This generates and saves three variants for each prompt in the dataset. Before running it, make sure that you have either downloaded the 4M and tokenizer checkpoints and pointed the config entries to the right paths, or that you load the models via Hugging Face Hub.


## Generation notebooks

Please see the provided Jupyter notebooks in `notebooks/` for more examples on how to use 4M-7 and 4M-21 models for inference/generation. We recommend running them on an A100 GPU, with `xformers` installed.
7 changes: 5 additions & 2 deletions README_TRAINING.md
@@ -53,10 +53,13 @@ OMP_NUM_THREADS=1 torchrun --nproc_per_node=8 run_training_4m_fsdp.py \
```


The training configurations for the 4M models in our papers are:

| Model | # Modalities | # Parameters | # GPUs | Config |
|-|-|-|-|-|
| 4M-B | 7 | 198M | 64x A100 | [link](cfgs/default/4m/models/main/4m-b_mod7_500b.yaml) |
| 4M-L | 7 | 705M | 64x A100 | [link](cfgs/default/4m/models/main/4m-l_mod7_500b.yaml) |
| 4M-XL | 7 | 2.8B | 128x A100 | [link](cfgs/default/4m/models/main/4m-xl_mod7_500b.yaml) |
| 4M-B | 21 | 198M | 64x A100 | [link](cfgs/default/4m/models/main/4m-b_mod21_500b.yaml) |
| 4M-L | 21 | 705M | 64x A100 | [link](cfgs/default/4m/models/main/4m-l_mod21_500b.yaml) |
| 4M-XL | 21 | 2.8B | 128x A100 | [link](cfgs/default/4m/models/main/4m-xl_mod21_500b.yaml) |
Binary file added assets/4M_main_fig_darkmode.png
Binary file added assets/4M_main_fig_lightmode.png
Binary file removed assets/4m_main_fig.jpg
@@ -0,0 +1,72 @@
# Mixture of alphas:
# - all2all with input and target alphas 0.01, 0.1, 1.0, 10.0
# - rgb2all with target alpha 0.5
# - caption and T5 embedding bias (each weighted half)

sampling_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5]

alphas_mixture:
rgb@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 1000.0, 0.05, 0.05]
target_alphas: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
caption:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 5.0, 0.0]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.0, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'all', 'random']
t5_caption:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.0, 5.0]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.0]
keep: ['random', 'random', 'random', 'random', 'random', 'random', 'all']
det:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'random', 'random']
tok_rgb@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_normal@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_depth@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_semseg@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_clip@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
human_poses:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'random', 'random']
tok_dinov2@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_dinov2_global:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_imagebind@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_imagebind_global:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_canny_edge@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
tok_sam_edge@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
color_palette:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
keep: ['binary', 'binary', 'binary', 'binary', 'binary', 'binary', 'binary']
metadata:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'random', 'random']
sam_instance:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'random', 'random']
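
To make the structure of such an alpha-mixture config concrete, the sketch below shows one way to read it: first pick a mixture component according to `sampling_weights`, then draw per-modality input and target budget fractions from Dirichlet distributions parameterised by that component's alphas (a very large alpha like 1000 gives a modality essentially its full budget, an alpha of 0 excludes it). This is an illustrative reading only, not the repository's sampling code, and the helper makes simplifying assumptions (e.g. it ignores the `keep` field).

```python
# Illustrative sampling from an alphas-mixture config (simplified; not the repo's implementation).
import numpy as np

def sample_budgets(alphas_mixture: dict, sampling_weights: list, rng=None):
    """Pick a mixture component, then Dirichlet-sample input/target budget fractions per modality."""
    rng = rng or np.random.default_rng()
    weights = np.asarray(sampling_weights, dtype=float)
    k = rng.choice(len(weights), p=weights / weights.sum())  # index of the chosen mixture component

    modalities = list(alphas_mixture)

    def dirichlet_fractions(key: str) -> dict:
        alphas = np.array([alphas_mixture[m][key][k] for m in modalities], dtype=float)
        fractions = np.zeros_like(alphas)
        active = alphas > 0  # an alpha of 0 means the modality is excluded from this side
        if active.any():
            fractions[active] = rng.dirichlet(alphas[active])
        return dict(zip(modalities, fractions))

    return k, dirichlet_fractions("input_alphas"), dirichlet_fractions("target_alphas")

# Tiny example with two modalities and two mixture components:
cfg = {
    "rgb@224": {"input_alphas": [1000.0, 0.01], "target_alphas": [0.0, 1.0]},
    "caption": {"input_alphas": [0.0, 0.01], "target_alphas": [1.0, 1.0]},
}
print(sample_budgets(cfg, sampling_weights=[1.0, 1.0]))
```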
@@ -0,0 +1,34 @@
# Mixture of alphas:
# - all2all with input and target alphas 0.01, 0.1, 1.0, 10.0
# - rgb2all with target alpha 0.5
# - caption bias

sampling_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

alphas_mixture:
rgb@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 1000.0, 0.05]
target_alphas: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
caption:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 5.0]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'all']
det:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
keep: ['random', 'random', 'random', 'random', 'random', 'random']
tok_rgb@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
tok_normal@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
tok_depth@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
tok_semseg@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
tok_clip@224:
input_alphas: [0.01, 0.1, 1.0, 10.0, 0.0, 0.05]
target_alphas: [0.01, 0.1, 1.0, 10.0, 0.5, 0.5]
@@ -0,0 +1,105 @@
train:
datasets:
cc12m:
type: multimodal

# Input and output domain names, separated by hyphen
in_domains: caption-t5_caption-det-metadata-rgb@224-tok_rgb@224-tok_normal@224-tok_depth@224-tok_semseg@224-tok_clip@224-human_poses-tok_dinov2@224-tok_dinov2_global-tok_imagebind@224-tok_imagebind_global-tok_sam_edge@224-tok_canny_edge@224-color_palette-sam_instance
out_domains: caption-det-metadata-tok_rgb@224-tok_normal@224-tok_depth@224-tok_semseg@224-tok_clip@224-human_poses-tok_dinov2@224-tok_dinov2_global-tok_imagebind@224-tok_imagebind_global-tok_sam_edge@224-tok_canny_edge@224-color_palette-sam_instance

# Dirichlet alphas concentration parameter for input and output.
# Can be either one value, or one value per input modality separated by hyphen.
input_alphas: null
target_alphas: null
# Path to specific alphas configuration to enable mixture of Dirichlets.
# If provided, overrides input_alphas and target_alphas
alphas_config: "cfgs/default/4m/alphas_mixture/main/mix_mod21_all2allmix_rgb2all_capT5bias.yaml"

# Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here
# If so, they will override the values provided in the main config
min_input_tokens: null
min_target_tokens: null
num_input_tokens: 256
num_target_tokens: 256

# Data can either be local or on cloud storage (e.g. S3), see data docs for more info
# Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar)
# Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3])
data_path: 'path/to/training/data/[modality1,modality2,modality3]/shard-{00000..9999}.tar'
use_wds: True # Use webdataset
wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency
wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files
wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples

main_augment_domain: rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes)
aligned_captions: True # Align captions to crop_settings
tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available)

# modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name
# tok_rgb_folder_name: tok_rgb@224
# tok_depth_folder_name: tok_depth@224
# ...

coyo700m:
type: multimodal

# Input and output domain names, separated by hyphen
in_domains: caption-det-rgb@224-tok_rgb@224-tok_normal@224-tok_depth@224-tok_semseg@224-tok_clip@224
out_domains: caption-det-tok_rgb@224-tok_normal@224-tok_depth@224-tok_semseg@224-tok_clip@224

# Dirichlet alphas concentration parameter for input and output.
# Can be either one value, or one value per input modality separated by hyphen.
input_alphas: null
target_alphas: null
# Path to specific alphas configuration to enable mixture of Dirichlets.
# If provided, overrides input_alphas and target_alphas
alphas_config: "cfgs/bolt/pretrain/4m/alphas_mixture/all2allmix-oldmod_rgb2all_capbias_v0.yaml" # TODO

# Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here
# If so, they will override the values provided in the main config
min_input_tokens: null
min_target_tokens: null
num_input_tokens: 256
num_target_tokens: 256

# Data can either be local or on cloud storage (e.g. S3), see data docs for more info
# Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar)
# Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3])
data_path: 'path/to/training/data/[modality1,modality2,modality3]/shard-{00000..9999}.tar'
use_wds: True # Use webdataset
wds_n_repeats: 1 # Number of repeats for webdataset loader to improve efficiency
wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files
wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples

main_augment_domain: rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes)
aligned_captions: True # Align captions to crop_settings
tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available)

# modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name
# tok_rgb_folder_name: tok_rgb@224
# tok_depth_folder_name: tok_depth@224
# ...

c4:
type: huggingface

in_domains: caption
out_domains: caption

input_alphas: "1.0"
target_alphas: "1.0"
alphas_config: null

data_path: '/path/to/c4/en'
shuffle_buffer_load: 1_000

weights: [0.6, 0.2, 0.2] # Sampling weights for the training datasets

val:
datasets:
cc12m:
data_path: 'path/to/val/data'
coyo700m:
data_path: 'path/to/val/data'
c4:
data_path: 'path/to/val/data'
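
For reference, the brace-and-bracket notation used in `data_path` above can be expanded as sketched below: braces enumerate zero-padded shard indices (the `braceexpand` package implements this bash-style notation), and brackets enumerate per-modality subfolders, giving one tar shard path per (modality, shard) pair. The helper is only an illustration of the notation, not the loader's actual parsing code.

```python
# Illustration of the data_path notation (not the repo's parser). Requires `pip install braceexpand`.
from braceexpand import braceexpand

def expand_data_path(data_path: str) -> list:
    """Expand '[mod1,mod2,...]' and '{00000..09999}' placeholders into concrete shard paths."""
    start, end = data_path.index("["), data_path.index("]")
    modalities = [m.strip() for m in data_path[start + 1:end].split(",")]
    paths = []
    for mod in modalities:
        per_modality_path = data_path[:start] + mod + data_path[end + 1:]
        paths.extend(braceexpand(per_modality_path))  # expands e.g. shard-{00000..00002}.tar
    return paths

# Example:
urls = expand_data_path("data/[tok_rgb@224,tok_depth@224]/shard-{00000..00002}.tar")
# -> 6 paths: data/tok_rgb@224/shard-00000.tar ... data/tok_depth@224/shard-00002.tar
```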
49 changes: 49 additions & 0 deletions cfgs/default/4m/models/main/4m-b_mod21_500b.yaml
@@ -0,0 +1,49 @@
# Config for DDP

# Arch: SwiGLU No Bias
# Modalities: Mix of rgb2all, all2all, caption/T5-biased2all, and C4 text-only
# Datasets: Mix of COYO700M, CC12M, and C4
# To be run on 64 GPUs for batch size = 4096
run_name: auto

# Input & output
num_input_tokens: 256
num_target_tokens: 256
loss_type: mod

# Architecture
model: fm_base_12e_12d_swiglu_nobias
patch_size: 16
input_size: 224
dtype: bfloat16
tokenizer_path: "fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"

# Initialization
finetune: '/path/to/4M_checkpoint.pth' # Change me. Initialize 4M-21 training from 4M-7 checkpoint

# Train
epochs: -1
total_tokens: 500 # in billions
opt: adamw
blr: 0.0001 # this is base_lr = 1e-4, lr = base_lr * batch_size / 256
min_blr: 0.
warmup_epochs: -1
warmup_tokens: 10 # in billions
batch_size: 64 # 64 x 64 = 4096

# Data
data_config: "cfgs/default/4m/data/cc12m+coyo+c4/main/mix_mod21_all2allmix_rgb2all_capT5bias_C4.yaml"
s3_data_endpoint: "/path/to/endpoint" # Change me
eval_freq: 1
fixed_eval: True
epoch_size: 10_000_000 # Number of samples per "epoch"

# Saving
save_ckpt_freq: 1
output_dir: 'output/auto'

# Wandb
log_wandb: False # Set to True to log to Weights & Biases
wandb_project: '4m-train'
wandb_entity: null # Change if needed
wandb_run_name: auto
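
As a quick sanity check of the learning-rate comment in this config: with `blr = 1e-4`, a per-GPU batch size of 64, and the 64 GPUs this config targets, the linearly scaled learning rate works out as follows.

```python
# Linear LR scaling as stated in the config comment: lr = base_lr * batch_size / 256
base_lr = 1e-4
effective_batch_size = 64 * 64        # per-GPU batch size x number of GPUs = 4096
lr = base_lr * effective_batch_size / 256
print(effective_batch_size, lr)       # 4096, ~1.6e-3
```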
