From 9cdd4a49270d935b7329516526259ed9bfcd7447 Mon Sep 17 00:00:00 2001
From: Zhe Chen
Date: Sun, 18 Feb 2024 13:49:26 +0800
Subject: [PATCH] Bump version to v1.2.1 (#46)

1. Support dynamic sequence length during training
2. Update README.md
3. Update evaluation code
---
 BLOG.md                                        |  1 -
 README.md                                      |  4 +-
 internvl_chat/README.md                        |  8 +--
 .../eval/scienceqa/evaluate_scienceqa.py       |  2 +-
 internvl_chat/internvl/patch/__init__.py       |  6 ++-
 .../internvl/patch/pad_data_collator.py        | 49 +++++++++++++++++++
 .../internvl/patch/train_sampler_patch.py      | 31 ++++++++++++
 .../internvl/train/internvl_chat_finetune.py   | 45 ++++++++++++-----
 internvl_chat/pyproject.toml                   |  2 +-
 internvl_chat/tools/json2jsonl.py              |  3 +-
 internvl_chat/tools/resize_pos_embed.py        | 25 ++++++++++
 internvl_chat/zero_stage3_config.json          | 12 +++++
 12 files changed, 167 insertions(+), 21 deletions(-)
 create mode 100644 internvl_chat/internvl/patch/pad_data_collator.py
 create mode 100644 internvl_chat/internvl/patch/train_sampler_patch.py
 create mode 100644 internvl_chat/tools/resize_pos_embed.py

diff --git a/BLOG.md b/BLOG.md
index d32734f1..22d33052 100644
--- a/BLOG.md
+++ b/BLOG.md
@@ -58,7 +58,6 @@ We released [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-
 image
-
 ## InternVL
 > Date: 2023/12/12
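The main functional change in this release is item 1 of the commit message: instead of padding every training sample to `tokenizer.model_max_length`, samples keep their natural length and are padded per batch (see the new `pad_data_collator`, the train-sampler patch, and the `internvl_chat_finetune.py` changes below). A rough, illustrative sketch of the padding cost this avoids (the numbers are made up, not from the repo):

```python
# Illustration only: token budget of static vs. dynamic padding for one batch.
sample_lengths = [180, 210, 512, 96]   # hypothetical per-sample token counts
model_max_length = 4096                # hypothetical tokenizer.model_max_length

static_tokens = model_max_length * len(sample_lengths)      # pad everything to max_length
dynamic_tokens = max(sample_lengths) * len(sample_lengths)  # pad only to the batch maximum
print(static_tokens, dynamic_tokens)   # 16384 vs. 2048
```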
diff --git a/README.md b/README.md
index b726731b..bca7eef6 100644
--- a/README.md
+++ b/README.md
@@ -332,7 +332,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
Multimodal Dialogue (click to expand)
- Compared with SOTA VLLMs
-
+
| name | image size | MMMU<br>(val) | MMMU<br>(test) | MathVista<br>(testmini) | MMB<br>(test) | MMB-CN<br>(test) | MMVP | MME | ScienceQA<br>(image) | POPE | TextVQA | SEEDv1<br>(image) | VizWiz<br>(test) | GQA<br>(test) |
| ------------------ | ---------- | ------------- | -------------- | ----------------------- | ------------- | ---------------- | ---- | -------- | -------------------- | ---- | ------- | ----------------- | ---------------- | ------------- |
| GPT-4V\* | unknown | 56.8 | 55.7 | 49.9 | 77.0 | 74.4 | 38.7 | 1409/517 | - | - | 78.0 | 71.6 | - | - |
@@ -343,7 +343,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
| | | | | | | | | | | | | | | |
| LLaVA-NEXT-34B | 672x672 | 51.1 | 44.7 | 46.5 | 79.3 | 79.0 | - | 1631/397 | 81.8 | 87.7 | 69.5 | 75.9 | 63.8 | 67.1 |
| InternVL-Chat-V1.2 | 448x448 | 51.6 | 46.2 | 47.7 | 82.2 | 81.2 | 56.7 | 1672/509 | 83.3 | 88.0 | 69.7 | 75.6 | 60.0 | 64.0 |
-
+
\* denotes proprietary models. MMBench results are collected from the [leaderboard](https://mmbench.opencompass.org.cn/leaderboard). In most benchmarks, InternVL-Chat-V1.2 achieves better performance than LLaVA-NeXT-34B.
- Zero-Shot Image Captioning [\[see details\]](./internvl_g#zero-shot-image-captioning)
diff --git a/internvl_chat/README.md b/internvl_chat/README.md
index bd6cb1f8..85d94577 100644
--- a/internvl_chat/README.md
+++ b/internvl_chat/README.md
@@ -141,6 +141,8 @@ The hyperparameters used for finetuning are listed in the following table.
## 📊 Evaluation
+\* Training set observed.
+
**MultiModal Benchmark**
| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP | MathVista |
@@ -151,14 +153,14 @@ The hyperparameters used for finetuning are listed in the following table.
| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | TinyLVLM | LLaVA<sub>bench</sub> | MM-Vet |
| --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ------------------------ | ------------------- | --------------------- | ------ |
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 39.1 / 35.3 | 34.8 / 34.0 | 344.5 | 76.3 | 45.0 |
-| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | TODO | 350.3 | - | 48.9 |
+| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | - | 350.3 | - | 48.9 |
**Visual Question Answering**
| model | VQAv2<sub>test</sub> | OKVQA<sub>val</sub> | TextVQA<sub>val</sub> | VizWiz<sub>val/test</sub> | AI2D<sub>test</sub> | GQA<sub>test</sub> | SQA<sub>test</sub> |
| --------------------------------------------------------------------------------- | -------------------- | ------------------- | --------------------- | ------------------------- | ------------------- | ------------------ | ------------------ |
-| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9 | 64.2 | 65.8 | 58.3 / 57.3 | 70.2 | 62.4 | 91.2 |
-| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5 | 69.7 | 61.9 / 60.0 | 71.6 | 64.0 | 83.3 |
+| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9\* | 64.2\* | 65.8 | 58.3 / 57.3 | 70.2\* | 62.4\* | 91.2\* |
+| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5\* | 69.7 | 61.9 / 60.0 | 71.6\* | 64.0\* | 83.3 |
**Image Captioning**
diff --git a/internvl_chat/eval/scienceqa/evaluate_scienceqa.py b/internvl_chat/eval/scienceqa/evaluate_scienceqa.py
index 257b7fc4..76b4860c 100644
--- a/internvl_chat/eval/scienceqa/evaluate_scienceqa.py
+++ b/internvl_chat/eval/scienceqa/evaluate_scienceqa.py
@@ -114,7 +114,7 @@ def post_process(pred, option):
        if v in pred:
            return k
-    return random.choice(option_candidate)
+    return pred
def evaluate_chat_model():
diff --git a/internvl_chat/internvl/patch/__init__.py b/internvl_chat/internvl/patch/__init__.py
index 68312ef7..02557a54 100644
--- a/internvl_chat/internvl/patch/__init__.py
+++ b/internvl_chat/internvl/patch/__init__.py
@@ -2,7 +2,11 @@ from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from .llama_rmsnorm_monkey_patch import \
    replace_llama_rmsnorm_with_fused_rmsnorm
+from .pad_data_collator import pad_data_collator
+from .train_sampler_patch import replace_train_sampler
__all__ = ['replace_llama_attn_with_flash_attn',
           'replace_llama_rmsnorm_with_fused_rmsnorm',
-           'replace_llama2_attn_with_flash_attn']
+           'replace_llama2_attn_with_flash_attn',
+           'replace_train_sampler',
+           'pad_data_collator']
diff --git a/internvl_chat/internvl/patch/pad_data_collator.py b/internvl_chat/internvl/patch/pad_data_collator.py
new file mode 100644
index 00000000..03305298
--- /dev/null
+++ b/internvl_chat/internvl/patch/pad_data_collator.py
@@ -0,0 +1,49 @@
+import numpy as np
+import torch
+
+IGNORE_INDEX = -100
+
+
+def pad_data_collator(features, pad_id=0):
+
+    first = features[0]
+    batch = {}
+
+    batch_lens = [feat['input_ids'].shape for feat in features]
+    max_item_length = max(batch_lens)[0]
+    for idx in range(len(features)):
+        feat = features[idx]
+        temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+        temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
+        feat['input_ids'] = temp_input_ids
+        temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+        temp_labels[:feat['labels'].shape[0]] = feat['labels']
+        feat['labels'] = temp_labels
+        feat['attention_mask'] = feat['input_ids'].ne(pad_id)
+
+    # Special handling for labels.
+    # Ensure that tensor is created with the correct type
+    # (it should be automatically the case, but let's make sure of it.)
+    if 'label' in first and first['label'] is not None:
+        label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
+        dtype = torch.long if isinstance(label, int) else torch.float
+        batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
+    elif 'label_ids' in first and first['label_ids'] is not None:
+        if isinstance(first['label_ids'], torch.Tensor):
+            batch['labels'] = torch.stack([f['label_ids'] for f in features])
+        else:
+            dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
+            batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
+
+    # Handling of all other possible keys.
+    # Again, we will use the first element to figure out which key/values are not None for this model.
+    for k, v in first.items():
+        if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.stack([f[k] for f in features])
+            elif isinstance(v, np.ndarray):
+                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+            else:
+                batch[k] = torch.tensor([f[k] for f in features])
+
+    return batch
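A usage sketch of the collator above with toy tensors (not from the repo): each field is padded to the longest sample in the batch, `input_ids` with `pad_id`, `labels` with `-100`, and an `attention_mask` is derived from the padded `input_ids`.

```python
# Toy example: pad_data_collator pads to the batch maximum rather than a fixed length.
import torch
from internvl.patch import pad_data_collator  # exported by the patched __init__.py above

features = [
    {'input_ids': torch.tensor([1, 2, 3, 4, 5]), 'labels': torch.tensor([-100, -100, 3, 4, 5])},
    {'input_ids': torch.tensor([1, 2, 3]), 'labels': torch.tensor([-100, 2, 3])},
]
batch = pad_data_collator(features, pad_id=0)
print(batch['input_ids'].shape)    # torch.Size([2, 5]) -- padded to 5, not model_max_length
print(batch['input_ids'][1])       # tensor([1, 2, 3, 0, 0])
print(batch['labels'][1])          # tensor([-100, 2, 3, -100, -100])
print(batch['attention_mask'][1])  # tensor([ True,  True,  True, False, False])
```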
diff --git a/internvl_chat/internvl/patch/train_sampler_patch.py b/internvl_chat/internvl/patch/train_sampler_patch.py
new file mode 100644
index 00000000..aadbde59
--- /dev/null
+++ b/internvl_chat/internvl/patch/train_sampler_patch.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+import torch
+import transformers
+from transformers.trainer import (LengthGroupedSampler, RandomSampler,
+                                  has_length)
+
+
+# patch trainer
+def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+    if self.train_dataset is None or not has_length(self.train_dataset):
+        return None
+    # Build the sampler.
+    if self.args.group_by_length:
+        lengths = []
+        for dataset in self.train_dataset.datasets:
+            lengths = lengths + dataset.length
+        model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+        return LengthGroupedSampler(
+            self.args.train_batch_size * self.args.gradient_accumulation_steps,
+            dataset=self.train_dataset,
+            lengths=lengths,
+            model_input_name=model_input_name,
+        )
+    else:
+        return RandomSampler(self.train_dataset)
+
+
+def replace_train_sampler():
+    transformers.Trainer._get_train_sampler = _get_train_sampler
+    print('Replace train sampler!!')
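With `--group_by_length`, the patched `_get_train_sampler` concatenates the `length` lists that each `LazySupervisedDataset` precomputes (see the training-script diff below) and hands them to `transformers`' `LengthGroupedSampler`, so each batch holds samples of similar length. The standalone sketch below (made-up lengths, with plain sorting as a crude stand-in for the sampler's grouping) shows why this matters once padding is per batch:

```python
# Illustration only: grouping by length reduces padding when padding per batch.
lengths = [700, 60, 650, 80, 620, 90, 710, 70]   # hypothetical per-sample token counts
batch_size = 2

def padded_tokens(batch):
    # with dynamic padding, every sample is padded to the batch's longest sample
    return max(batch) * len(batch)

arrival_batches = [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]
grouped = sorted(lengths)   # crude stand-in for LengthGroupedSampler's length grouping
grouped_batches = [grouped[i:i + batch_size] for i in range(0, len(grouped), batch_size)]

print(sum(map(padded_tokens, arrival_batches)))   # 5360 tokens processed
print(sum(map(padded_tokens, grouped_batches)))   # 3040 tokens processed
```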
diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py
index b72ad4a8..a7a8c756 100644
--- a/internvl_chat/internvl/train/internvl_chat_finetune.py
+++ b/internvl_chat/internvl/train/internvl_chat_finetune.py
@@ -23,8 +23,10 @@ InternVisionModel, InternVLChatConfig, InternVLChatModel)
-from internvl.patch import (replace_llama2_attn_with_flash_attn,
-                            replace_llama_rmsnorm_with_fused_rmsnorm)
+from internvl.patch import (pad_data_collator,
+                            replace_llama2_attn_with_flash_attn,
+                            replace_llama_rmsnorm_with_fused_rmsnorm,
+                            replace_train_sampler)
from internvl.train.dataset import (TCSLoader, WeightedConcatDataset,
                                    build_transform)
from PIL import Image, ImageFile, PngImagePlugin
@@ -39,6 +41,7 @@ # Upgrade transformers to v4.36.2, we don't need it anymore
# replace_llama2_attn_with_flash_attn()
replace_llama_rmsnorm_with_fused_rmsnorm()
+replace_train_sampler()
try:
    from petrel_client.client import Client
@@ -182,6 +185,7 @@ def preprocess(
    tokenizer: transformers.PreTrainedTokenizer,
    num_image_token: int,
    text_only: bool = False,
+    group_by_length: bool = False,
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -213,7 +217,7 @@ def preprocess(
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
-        padding='max_length',
+        padding=False if group_by_length else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
@@ -283,6 +287,7 @@ def preprocess_mpt(
    tokenizer: transformers.PreTrainedTokenizer,
    num_image_token: int,
    text_only: bool = False,
+    group_by_length: bool = False,
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -314,7 +319,7 @@ def preprocess_mpt(
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
-        padding='max_length',
+        padding=False if group_by_length else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
@@ -368,7 +373,7 @@ class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
-                 image_size=224, is_train=True, pad2square=False):
+                 image_size=224, is_train=True, pad2square=False, group_by_length=False):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer
        self.template_name = template_name
@@ -384,6 +389,21 @@ def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
        self.root = meta['root']
        self.cached_data_dict = {}
        self.tcs_loader = tcs_loader
+        self.group_by_length = group_by_length
+        if self.group_by_length:
+            self.conv2length = {}
+            self.length = []
+            for data_item in self.raw_data:
+                conversations = ''.join(data_item.split('conversations')[1:])
+                str_length = len(conversations)
+                if str_length not in self.conv2length:
+                    token_length = tokenizer(
+                        conversations, return_tensors='pt', padding=False, truncation=False,
+                    ).input_ids.size(1)
+                    self.conv2length[str_length] = token_length
+                else:
+                    token_length = self.conv2length[str_length]
+                self.length.append(token_length)

    def __len__(self):
        return len(self.raw_data)
@@ -405,7 +425,7 @@ def multi_modal_get_item(self, data_item):
        else:
            preprocess_function = preprocess
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
-                                  self.tokenizer, self.num_image_token)
+                                  self.tokenizer, self.num_image_token, group_by_length=self.group_by_length)
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
@@ -425,7 +445,8 @@ def pure_text_get_item(self, data_item):
        else:
            preprocess_function = preprocess
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
-                                  self.tokenizer, self.num_image_token, text_only=True)
+                                  self.tokenizer, self.num_image_token, text_only=True,
+                                  group_by_length=self.group_by_length)
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
@@ -455,7 +476,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return ret

-def build_datasets(data_args, tokenizer, tcs_loader, model):
+def build_datasets(data_args, tokenizer, tcs_loader, model, group_by_length=False):
    datasets = []
    lengths = []
    ds_collections = json.loads(open(data_args.meta_path).read())
@@ -469,7 +490,8 @@ def build_datasets(data_args, tokenizer, tcs_loader, model):
                num_image_token=model.num_image_token,
                image_size=data_args.force_image_size,
                is_train=ds_collections[ds_name]['data_augment'],
-                pad2square=data_args.pad2square
+                pad2square=data_args.pad2square,
+                group_by_length=group_by_length
            )
        except Exception:
            logger.info(f'Error in loading dataset: {ds_name}')
@@ -623,7 +645,8 @@ def main():
    if model_args.grad_checkpoint:
        model.language_model._set_gradient_checkpointing()
-    train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model)
+    train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model,
+                                   group_by_length=training_args.group_by_length)
    def _freeze_params(module):
        for param in module.parameters():
@@ -672,7 +695,7 @@ def _freeze_params(module):
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
-        data_collator=default_data_collator,
+        data_collator=default_data_collator if not training_args.group_by_length else pad_data_collator,
    )
    # Training
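Taken together, `training_args.group_by_length` now drives three coordinated choices in the script above: `preprocess`/`preprocess_mpt` tokenize without padding, the dataset precomputes token lengths for the patched sampler, and the `Trainer` receives `pad_data_collator` instead of `default_data_collator`. A condensed restatement of that selection logic (for clarity only, not additional repo code):

```python
# Flag-driven choices made in internvl_chat_finetune.py when group_by_length is set.
from transformers import default_data_collator
from internvl.patch import pad_data_collator

def padding_strategy(group_by_length: bool):
    # preprocess()/preprocess_mpt(): skip padding at tokenization time when grouping by
    # length; otherwise keep padding every sample to tokenizer.model_max_length as before.
    return False if group_by_length else 'max_length'

def pick_collator(group_by_length: bool):
    # Trainer: variable-length samples need the pad-to-batch-max collator;
    # fixed-length samples keep transformers' default_data_collator.
    return pad_data_collator if group_by_length else default_data_collator

assert padding_strategy(True) is False and pick_collator(True) is pad_data_collator
assert padding_strategy(False) == 'max_length' and pick_collator(False) is default_data_collator
```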
diff --git a/internvl_chat/pyproject.toml b/internvl_chat/pyproject.toml
index f612679e..f0491d13 100644
--- a/internvl_chat/pyproject.toml
+++ b/internvl_chat/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "internvl_chat"
-version = "1.2.0"
+version = "1.2.1"
description = "Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks."
readme = "README.md"
requires-python = ">=3.8"
diff --git a/internvl_chat/tools/json2jsonl.py b/internvl_chat/tools/json2jsonl.py
index 4aa3078d..c57838c8 100644
--- a/internvl_chat/tools/json2jsonl.py
+++ b/internvl_chat/tools/json2jsonl.py
@@ -11,10 +11,11 @@ data = json.load(open(args.path))
writer = open(args.path.replace('.json', '.jsonl'), 'w')
-for item in data:
+for idx, item in enumerate(data):
    conversations = item['conversations']
    if conversations[0]['from'] == 'system':
        item['conversations'] = item['conversations'][1:]
+    item['id'] = idx
    writer.write(json.dumps(item, ensure_ascii=False) + '\n')
writer.close()
diff --git a/internvl_chat/tools/resize_pos_embed.py b/internvl_chat/tools/resize_pos_embed.py
new file mode 100644
index 00000000..8c89cebb
--- /dev/null
+++ b/internvl_chat/tools/resize_pos_embed.py
@@ -0,0 +1,25 @@
+import argparse
+
+import torch
+from internvl.model.internvl_chat import InternVLChatModel
+from transformers import AutoTokenizer
+
+argparse = argparse.ArgumentParser()
+argparse.add_argument('model_path', type=str, default='')
+argparse.add_argument('output_path', type=str, default='')
+argparse.add_argument('force_image_size', type=int, default=448)
+
+args = argparse.parse_args()
+
+model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
+model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
+                                         new_size=args.force_image_size,
+                                         patch_size=14)
+model.config.vision_config.image_size = args.force_image_size
+model.config.force_image_size = args.force_image_size
+
+model.save_pretrained(args.output_path)
+
+tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+tokenizer.save_pretrained(args.output_path)
+print('finished')
diff --git a/internvl_chat/zero_stage3_config.json b/internvl_chat/zero_stage3_config.json
index 2d5f0968..5fc009d6 100644
--- a/internvl_chat/zero_stage3_config.json
+++ b/internvl_chat/zero_stage3_config.json
@@ -23,6 +23,18 @@ "bf16": {
    "enabled": "auto"
  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": [
+        0.9,
+        0.999
+      ],
+      "eps": 1e-8,
+      "weight_decay": "auto"
+    }
+  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
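The new `tools/resize_pos_embed.py` above adapts a checkpoint trained at one input resolution (e.g. 224) to another (e.g. 448) by resizing the vision encoder's position embeddings; judging from its positional arguments, it is run as `python tools/resize_pos_embed.py <model_path> <output_path> <force_image_size>`. The snippet below is a generic sketch of what such a resize involves, not InternVL's own `resize_pos_embeddings` implementation (which may differ, e.g. in interpolation mode):

```python
# Generic position-embedding resize for a ViT with 14x14 patches (illustrative only).
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, old_size: int, new_size: int, patch_size: int = 14) -> torch.Tensor:
    # pos_embed: [1, 1 + old_grid**2, dim], with one leading CLS-token embedding
    old_grid, new_grid = old_size // patch_size, new_size // patch_size
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = patch_pos.shape[-1]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)

# 224 -> 448: 1 + 16*16 = 257 position embeddings become 1 + 32*32 = 1025
print(resize_pos_embed(torch.randn(1, 257, 1024), 224, 448).shape)  # torch.Size([1, 1025, 1024])
```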