From 919995d5c69cf8544b3c66344771865240a3791d Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Thu, 19 Dec 2024 01:03:52 -0500 Subject: [PATCH 1/2] support for more vlms Signed-off-by: n1ck-guo --- auto_round/mllm/autoround_mllm.py | 10 ++++-- auto_round/mllm/mllm_dataset.py | 3 +- auto_round/mllm/processor.py | 45 ++++++++++++++++++++++++++ auto_round/mllm/templates/default.json | 2 +- auto_round/script/mllm.py | 13 +++++++- auto_round/utils.py | 13 +++++--- 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py index ccd1bc7a..0979be06 100644 --- a/auto_round/mllm/autoround_mllm.py +++ b/auto_round/mllm/autoround_mllm.py @@ -37,8 +37,8 @@ def _only_text_test(model, tokenizer, device): text = ["only text", "test"] tokenizer.padding_side = 'left' if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - if device.split(':')[0] != model.device.type: + tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.unk_token + if device.split(':')[0] != model.device.type: # TODO: OOM model = model.to(device) inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device) model(**inputs) @@ -158,6 +158,9 @@ def __init__( self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor) dataset = self.template.default_dataset if dataset is None else dataset + if model.config.model_type == "deepseek_vl_v2": + model.forward = model.language.forward + from ..calib_dataset import CALIB_DATASETS from .mllm_dataset import MLLM_DATASET if isinstance(dataset, str): @@ -256,6 +259,7 @@ def calib(self, nsamples, bs): template=self.template, model=self.model, tokenizer=self.tokenizer, + processor=self.processor, image_processor=self.image_processor, dataset=dataset, extra_data_dir=self.extra_data_dir, @@ -324,7 +328,7 @@ def calib(self, nsamples, bs): data_new = {} for key in data.keys(): data_new[key] = to_device(data[key], self.model.device) - if key == 'images': + if key in ['images', 'pixel_values']: data_new[key] = to_dtype(data_new[key], self.model.dtype) input_ids = data_new["input_ids"] diff --git a/auto_round/mllm/mllm_dataset.py b/auto_round/mllm/mllm_dataset.py index 400f5773..69b27803 100644 --- a/auto_round/mllm/mllm_dataset.py +++ b/auto_round/mllm/mllm_dataset.py @@ -190,6 +190,7 @@ def get_mllm_dataloader( template, model, tokenizer, + processor, image_processor=None, dataset="liuhaotian/llava_conv_58k", extra_data_dir=None, @@ -222,7 +223,7 @@ def get_mllm_dataloader( """ if isinstance(template, str): from .template import get_template - template = get_template(template, model=model, tokenizer=tokenizer, image_processor=image_processor) + template = get_template(template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor) if os.path.isfile(dataset) or dataset in MLLM_DATASET.keys(): dataset = MLLM_DATASET['liuhaotian/llava']( diff --git a/auto_round/mllm/processor.py b/auto_round/mllm/processor.py index 302d0120..78dfe2a9 100644 --- a/auto_round/mllm/processor.py +++ b/auto_round/mllm/processor.py @@ -111,6 +111,51 @@ def squeeze_result(ret): return ret +@regist_processor("hf") +class HFProcessor(BasicProcessor): + IMAGE_TOKEN = '' + def __init__(self): + pass + + def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs): + self.model = model + self.tokenizer = tokenizer + self.processor = processor + if image_processor is not None: + 
self.image_processor = image_processor + else: + self.image_processor = self.default_image_processor + + def get_input( + self, + text, + images, + return_tensors="pt", + squeeze=True, + max_length=None, + truncation=False, + truncation_strategy="text", + **kwargs): + + messages = [] + for content in text: + messages.append({ + "role": content['role'], + "content": [ + {"text": content["content"].replace(self.IMAGE_TOKEN, ""), "type": "text"} + ] + }) + if self.IMAGE_TOKEN in content['content']: + messages[-1]["content"].append({"text": None, "type": "image"}) + text = self.processor.apply_chat_template(messages, add_generation_prompt=True) + if images is not None: + images = self.image_processor(images) + ret = self.processor(text=text, images=images, return_tensors="pt") + if squeeze: + ret = self.squeeze_result(ret) + return ret + + @regist_processor("qwen2_vl") class Qwen2VLProcessor(BasicProcessor): @staticmethod diff --git a/auto_round/mllm/templates/default.json b/auto_round/mllm/templates/default.json index 321a0c48..a64017c7 100644 --- a/auto_round/mllm/templates/default.json +++ b/auto_round/mllm/templates/default.json @@ -10,5 +10,5 @@ "replace_tokens": null, "extra_encode" : false, "default_dataset": "NeelNanda/pile-10k", - "processor": "basic" + "processor": "hf" } \ No newline at end of file diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py index 89e5b23a..c55dd997 100644 --- a/auto_round/script/mllm.py +++ b/auto_round/script/mllm.py @@ -270,7 +270,8 @@ def tune(args): os.environ["CUDA_VISIBLE_DEVICES"] = args.device args.device = ",".join(map(str, range(len(devices)))) devices = args.device.replace(" ", "").split(',') - use_auto_mapping = True + if len(devices) > 1: ##for 70B model on single card, use auto will cause some layer offload to cpu + use_auto_mapping = True elif args.device == "auto": use_auto_mapping == True @@ -288,6 +289,13 @@ def tune(args): model_name, model_base=None, model_name=model_name, torch_dtype=torch_dtype) model_type = "llava" + elif "deepseek" in model_name.lower(): + from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM + processor = DeepseekVLV2Processor.from_pretrained(model_name) + tokenizer = processor.tokenizer + model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained( + model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype, + device_map="auto" if use_auto_mapping else None) else: config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -299,6 +307,9 @@ def tune(args): elif "mllama" in model_type: from transformers import MllamaForConditionalGeneration cls = MllamaForConditionalGeneration + elif "idefics3" in model_type: + from transformers import AutoModelForVision2Seq + cls = AutoModelForVision2Seq else: cls = AutoModelForCausalLM diff --git a/auto_round/utils.py b/auto_round/utils.py index 92f0c0a1..d5f66b16 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -402,11 +402,16 @@ def get_multimodal_block_names(model, quant_vision=False): """ block_names = [] target_modules = [] - vison_blocks_tuple = ("vision", "visual",) + vison_blocks_tuple = ("vision", "visual", "projector") + module_list_type = ("ModuleList", "Sequential") + last_module_list = None for n, m in model.named_modules(): - if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: - if quant_vision or all(key not in n.lower() for key in (vison_blocks_tuple)): - 
target_modules.append((n, m)) + # if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + if hasattr(type(m), "__name__") and any(key in type(m).__name__ for key in module_list_type): + if quant_vision or all(key not in n.lower() for key in vison_blocks_tuple): + if last_module_list is None or last_module_list not in n: + last_module_list = n + target_modules.append((n, m)) validate_modules(target_modules, quant_vision, vison_blocks_tuple) for i, target_m in enumerate(target_modules): block_names.append([]) From 550450955fdd759c71ab820d00e39f3b53ae9dd4 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Thu, 2 Jan 2025 20:42:56 -0500 Subject: [PATCH 2/2] update Signed-off-by: n1ck-guo --- auto_round/mllm/mllm_dataset.py | 4 +- auto_round/mllm/processor.py | 13 +++++- auto_round/mllm/template.py | 2 + auto_round/mllm/templates/mllama.json | 10 ---- auto_round/mllm/templates/qwen2_vl.json | 13 ------ auto_round/script/mllm.py | 62 +++++++++++++------------ 6 files changed, 49 insertions(+), 55 deletions(-) delete mode 100644 auto_round/mllm/templates/mllama.json delete mode 100644 auto_round/mllm/templates/qwen2_vl.json diff --git a/auto_round/mllm/mllm_dataset.py b/auto_round/mllm/mllm_dataset.py index 69b27803..18b7d0fd 100644 --- a/auto_round/mllm/mllm_dataset.py +++ b/auto_round/mllm/mllm_dataset.py @@ -223,7 +223,9 @@ def get_mllm_dataloader( """ if isinstance(template, str): from .template import get_template - template = get_template(template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor) + template = get_template( + template, model=model, tokenizer=tokenizer, + processor=processor, image_processor=image_processor) if os.path.isfile(dataset) or dataset in MLLM_DATASET.keys(): dataset = MLLM_DATASET['liuhaotian/llava']( diff --git a/auto_round/mllm/processor.py b/auto_round/mllm/processor.py index 78dfe2a9..a36acc1a 100644 --- a/auto_round/mllm/processor.py +++ b/auto_round/mllm/processor.py @@ -157,7 +157,7 @@ def get_input( @regist_processor("qwen2_vl") -class Qwen2VLProcessor(BasicProcessor): +class Qwen2VLProcessor(HFProcessor): @staticmethod def squeeze_result(ret): for key in ret: @@ -290,3 +290,14 @@ class DataArgs: def data_collator(self, batch): return self.collator_func(batch) + + +@regist_processor("deepseek_vl_v2") +class DeepseekVL2Processor(BasicProcessor): + def get_input( + self, + text, + images, + return_tensors="pt", + squeeze=True, max_length=None, truncation=False, truncation_strategy="text", **kwargs): + breakpoint() \ No newline at end of file diff --git a/auto_round/mllm/template.py b/auto_round/mllm/template.py index 08b4d9eb..5a97a9b9 100644 --- a/auto_round/mllm/template.py +++ b/auto_round/mllm/template.py @@ -115,6 +115,8 @@ def _register_template( ) return TEMPLATES[model_type] +_register_template("qwen2_vl", default_dataset="NeelNanda/pile-10k",processor=PROCESSORS["qwen2_vl"]) +_register_template("mllama", default_dataset="liuhaotian/llava", processor=PROCESSORS["hf"]) def load_template(path: str): """Load template information from a json file.""" diff --git a/auto_round/mllm/templates/mllama.json b/auto_round/mllm/templates/mllama.json deleted file mode 100644 index 2434dc4b..00000000 --- a/auto_round/mllm/templates/mllama.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "model_type": "mllama", - "format_user": "<|start_header_id|>user<|end_header_id|>\n{{content}}<|eot_id|>", - "format_assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{content}}<|eot_id|>", - "format_system": 
"<|begin_of_text|><|start_header_id|>system|end_header_id|>\n{{content}}<", - "default_system": "You are a helpful assistant.", - "replace_tokens": ["", "<|image|>"], - "extra_encode" : true, - "default_dataset": "liuhaotian/llava_conv_58k" -} \ No newline at end of file diff --git a/auto_round/mllm/templates/qwen2_vl.json b/auto_round/mllm/templates/qwen2_vl.json deleted file mode 100644 index 6b72a063..00000000 --- a/auto_round/mllm/templates/qwen2_vl.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "model_type": "qwen2_vl", - "format_user": "<|im_start|>user\n{{content}}<|im_end|>\n", - "format_assistant": "<|im_start|>assistant\n{{content}}<|im_end|>\n", - "format_system": "<|im_start|>system\n{{content}}<|im_end|>\n", - "format_observation": "<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n", - "format_separator": "\n", - "default_system": "You are a helpful assistant.", - "replace_tokens": ["", "<|vision_start|><|image_pad|><|vision_end|>"], - "processor": "qwen2_vl", - "extra_encode" : true, - "default_dataset": "NeelNanda/pile-10k" -} \ No newline at end of file diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py index 95473dec..63450f33 100644 --- a/auto_round/script/mllm.py +++ b/auto_round/script/mllm.py @@ -283,42 +283,44 @@ def tune(args): # load_model processor, image_processor = None, None - config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) - if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration": - from llava.model.builder import load_pretrained_model # pylint: disable=E0401 - tokenizer, model, image_processor, _ = load_pretrained_model( - model_name, model_base=None, model_name=model_name, - torch_dtype=torch_dtype) - model_type = "llava" - elif "deepseek" in model_name.lower(): - from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM - processor = DeepseekVLV2Processor.from_pretrained(model_name) + if "deepseek" in model_name.lower(): + from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM # pylint: disable=E0401 + processor = DeepseekVLV2Processor.from_pretrained(model_name) tokenizer = processor.tokenizer model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype, device_map="auto" if use_auto_mapping else None) + model_type = "deepseek_vl_v2" else: - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) - model_type = config.model_type - if "llava" in model_type: - from transformers import LlavaForConditionalGeneration - cls = LlavaForConditionalGeneration - elif "qwen2_vl" in model_type: - from transformers import Qwen2VLForConditionalGeneration - cls = Qwen2VLForConditionalGeneration - elif "mllama" in model_type: - from transformers import MllamaForConditionalGeneration - cls = MllamaForConditionalGeneration - elif "idefics3" in model_type: - from transformers import AutoModelForVision2Seq - cls = AutoModelForVision2Seq + config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) + if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration": + from llava.model.builder import load_pretrained_model # pylint: disable=E0401 + tokenizer, model, image_processor, _ = 
load_pretrained_model( + model_name, model_base=None, model_name=model_name, + torch_dtype=torch_dtype) + model_type = "llava" else: - cls = AutoModelForCausalLM - - model = cls.from_pretrained( - model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype, - device_map="auto" if use_auto_mapping else None) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) + model_type = config.model_type + if "llava" in model_type: + from transformers import LlavaForConditionalGeneration + cls = LlavaForConditionalGeneration + elif "qwen2_vl" in model_type: + from transformers import Qwen2VLForConditionalGeneration + cls = Qwen2VLForConditionalGeneration + elif "mllama" in model_type: + from transformers import MllamaForConditionalGeneration + cls = MllamaForConditionalGeneration + elif "idefics3" in model_type: + from transformers import AutoModelForVision2Seq + cls = AutoModelForVision2Seq + else: + cls = AutoModelForCausalLM + + model = cls.from_pretrained( + model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype, + device_map="auto" if use_auto_mapping else None) if "cogvlm2" in model_name: model.config.model_type = "cogvlm2"
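
The following is a minimal, self-contained sketch (not part of the patch) of the input flow the new "hf" processor implements in auto_round/mllm/processor.py: split the calibration prompt into a text part plus an optional image slot, render it with the model's own chat template, and let the Hugging Face processor tensorize both. It assumes a processor that ships a chat template; the checkpoint name and image path are only illustrative.

# Sketch of the HFProcessor.get_input flow, under the assumptions above.
from PIL import Image
from transformers import AutoProcessor

IMAGE_TOKEN = "<image>"  # generic placeholder used in the calibration prompts

def build_calib_inputs(processor, prompt: str, image_path: str = None):
    # Text content with the placeholder stripped, plus an image slot if one was requested.
    content = [{"type": "text", "text": prompt.replace(IMAGE_TOKEN, "")}]
    if IMAGE_TOKEN in prompt:
        content.append({"type": "image"})
    messages = [{"role": "user", "content": content}]
    # Let the model's own chat template format the prompt, then tensorize text + image together.
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    images = Image.open(image_path) if image_path is not None else None
    return processor(text=text, images=images, return_tensors="pt")

# Illustrative usage:
# processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# batch = build_calib_inputs(processor, "<image>\nDescribe the image.", "sample.jpg")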
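
Likewise, a toy illustration of the updated block discovery in auto_round/utils.py: both nn.ModuleList and nn.Sequential now count as candidate block containers, vision/projector submodules are skipped unless quant_vision is set, and only the outermost container is kept when containers are nested. The model class below is hypothetical.

# Sketch only; mirrors the container/name filtering logic of get_multimodal_block_names.
import torch.nn as nn

VISION_KEYS = ("vision", "visual", "projector")
CONTAINER_TYPES = ("ModuleList", "Sequential")

def find_block_containers(model: nn.Module, quant_vision: bool = False):
    found, last = [], None
    for name, module in model.named_modules():
        if any(key in type(module).__name__ for key in CONTAINER_TYPES):
            if quant_vision or all(key not in name.lower() for key in VISION_KEYS):
                # Skip containers nested inside one already recorded (same substring check as the patch).
                if last is None or last not in name:
                    last = name
                    found.append(name)
    return found

class ToyVLM(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.ModuleList([nn.Linear(8, 8)])
        self.projector = nn.Sequential(nn.Linear(8, 8))
        self.language_layers = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])

print(find_block_containers(ToyVLM()))        # ['language_layers']
print(find_block_containers(ToyVLM(), True))  # ['vision_tower', 'projector', 'language_layers']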