Adding video LLM fine-tuning example (#2336)
* adding video example
* exposing more parameters
* fixing formatting

---------

Co-authored-by: Kashif Rasul <[email protected]>
Showing 1 changed file with 257 additions and 0 deletions.
# Copyright 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example usage:
accelerate launch \
    --config_file=deepspeed_zero2.yaml \
    sft_video_llm.py \
    --dataset_name=mfarre/simplevideoshorts \
    --video_cache_dir="/optional/path/to/cache/" \
    --model_name_or_path=Qwen/Qwen2-VL-7B-Instruct \
    --per_device_train_batch_size=1 \
    --output_dir=video-llm-output \
    --bf16=True \
    --tf32=True \
    --gradient_accumulation_steps=4 \
    --num_train_epochs=4 \
    --optim="adamw_torch_fused" \
    --logging_steps=1 \
    --log_level="debug" \
    --log_level_replica="debug" \
    --save_strategy="steps" \
    --save_steps=300 \
    --learning_rate=8e-5 \
    --max_grad_norm=0.3 \
    --warmup_ratio=0.1 \
    --lr_scheduler_type="cosine" \
    --report_to="wandb" \
    --push_to_hub=False \
    --torch_dtype=bfloat16 \
    --gradient_checkpointing=True
"""

import json
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List

import requests
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLProcessor,
)

from trl import (
    SFTConfig,
    SFTTrainer,
    get_kbit_device_map,
)
from trl.commands.cli_utils import SFTScriptArguments, TrlParser
from trl.trainer import ModelConfig


def download_video(url: str, cache_dir: str) -> str:
    """Download video if not already present locally."""
    os.makedirs(cache_dir, exist_ok=True)  # Create cache dir if it doesn't exist
    filename = url.split("/")[-1]
    local_path = os.path.join(cache_dir, filename)

    if os.path.exists(local_path):
        return local_path

    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return local_path
    except requests.RequestException as e:
        raise Exception(f"Failed to download video: {e}") from e


def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
    """Prepare dataset example for training."""
    video_url = example["video_url"]
    timecoded_cc = example["timecoded_cc"]
    qa_pairs = json.loads(example["qa"])

    system_message = "You are an expert in movie narrative analysis."
    base_prompt = f"""Analyze the video and consider the following timecoded subtitles:
{timecoded_cc}
Based on this information, please answer the following questions:"""

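    # Pick one question/answer pair at random for this example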
    selected_qa = random.sample(qa_pairs, 1)[0]

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {
            "role": "user",
            "content": [
                {"type": "video", "video": download_video(video_url, cache_dir), "max_pixels": 360 * 420, "fps": 1.0},
                {"type": "text", "text": f"{base_prompt}\n\nQuestion: {selected_qa['question']}"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": selected_qa["answer"]}]},
    ]

    return {"messages": messages}


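# Note: this collator relies on the module-level `processor` created in the __main__ block below.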
def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Collate batch of examples for training."""
    texts = []
    video_inputs = []

    for i, example in enumerate(examples):
        try:
            video_path = next(
                content["video"]
                for message in example["messages"]
                for content in message["content"]
                if content.get("type") == "video"
            )
            print(f"Processing video: {os.path.basename(video_path)}")

            texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
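            # process_vision_info returns (image_inputs, video_inputs); keep the single video entry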
            video_input = process_vision_info(example["messages"])[1][0]
            video_inputs.append(video_input)
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}") from e

    inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True)

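    # Mask padding tokens in the labels; -100 is ignored by the loss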
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Handle visual tokens based on processor type
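    # (151652, 151653 and 151656 are Qwen2-VL's vision start/end and video pad special-token IDs)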
    visual_tokens = (
        [151652, 151653, 151656]
        if isinstance(processor, Qwen2VLProcessor)
        else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    )

    for visual_token_id in visual_tokens:
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs


@dataclass
class CustomScriptArguments(SFTScriptArguments):
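    # Local directory where downloaded videos are cached (see download_video above)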
    video_cache_dir: str = "/tmp/videos/"


if __name__ == "__main__":
    # Parse arguments
    parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Configure training args
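    # use_reentrant=False helps avoid issues when combining gradient checkpointing with PEFT,
    # and skip_prepare_dataset stops SFTTrainer from re-processing the already-formatted messages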
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # Load dataset
    dataset = load_dataset(script_args.dataset_name, split="train")

    # Setup model
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    # Quantization configuration for 4-bit training
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
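    # QLoRA-style setup: the 4-bit quantized base model stays frozen; only the LoRA adapters below are trained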

    # Model initialization
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map(),
        quantization_config=bnb_config,
    )

    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    # Configure model modules for gradients
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_reentrant = False
        model.enable_input_require_grads()

    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )

    # Prepare dataset
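    # (note: this eagerly downloads every video in the split into video_cache_dir)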
    prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset]

    # Initialize wandb if specified
    # Note: TrainingArguments normalizes `report_to` to a list, so check membership rather than equality
    if "wandb" in (training_args.report_to or []):
        wandb.init(project="video-llm-training")

    # Initialize trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        tokenizer=processor.tokenizer,
    )

    # Train model
    trainer.train()

    # Save final model
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)

    # Cleanup
    del model
    del trainer
    torch.cuda.empty_cache()
    wandb.finish()