Adding video llm fine-tuning example (#2336)
* adding video example

* exposing more parameters

* fixing formatting

---------

Co-authored-by: Kashif Rasul <[email protected]>
mfarre and kashif authored Nov 12, 2024
1 parent dde20b2 commit 2d24d35
Showing 1 changed file with 257 additions and 0 deletions.
examples/scripts/sft_video_llm.py (+257, −0)
@@ -0,0 +1,257 @@
# Copyright 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example usage:
accelerate launch \
--config_file=deepspeed_zero2.yaml \
sft_video_llm.py \
--dataset_name=mfarre/simplevideoshorts \
--video_cache_dir="/optional/path/to/cache/" \
--model_name_or_path=Qwen/Qwen2-VL-7B-Instruct \
--per_device_train_batch_size=1 \
--output_dir=video-llm-output \
--bf16=True \
--tf32=True \
--gradient_accumulation_steps=4 \
--num_train_epochs=4 \
--optim="adamw_torch_fused" \
--logging_steps=1 \
--log_level="debug" \
--log_level_replica="debug" \
--save_strategy="steps" \
--save_steps=300 \
--learning_rate=8e-5 \
--max_grad_norm=0.3 \
--warmup_ratio=0.1 \
--lr_scheduler_type="cosine" \
--report_to="wandb" \
--push_to_hub=False \
--torch_dtype=bfloat16 \
--gradient_checkpointing=True
"""

import json
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List

import requests
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLProcessor,
)

from trl import (
    SFTConfig,
    SFTTrainer,
    get_kbit_device_map,
)
from trl.commands.cli_utils import SFTScriptArguments, TrlParser
from trl.trainer import ModelConfig


def download_video(url: str, cache_dir: str) -> str:
    """Download video if not already present locally."""
    os.makedirs(cache_dir, exist_ok=True)  # Create cache dir if it doesn't exist
    filename = url.split("/")[-1]
    local_path = os.path.join(cache_dir, filename)

    if os.path.exists(local_path):
        return local_path

    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return local_path
    except requests.RequestException as e:
        raise Exception(f"Failed to download video: {e}") from e


def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
    """Prepare dataset example for training."""
    video_url = example["video_url"]
    timecoded_cc = example["timecoded_cc"]
    qa_pairs = json.loads(example["qa"])

    system_message = "You are an expert in movie narrative analysis."
    base_prompt = f"""Analyze the video and consider the following timecoded subtitles:
{timecoded_cc}
Based on this information, please answer the following questions:"""

    # Randomly select one question-answer pair for this example
    selected_qa = random.sample(qa_pairs, 1)[0]

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {
            "role": "user",
            "content": [
                {"type": "video", "video": download_video(video_url, cache_dir), "max_pixels": 360 * 420, "fps": 1.0},
                {"type": "text", "text": f"{base_prompt}\n\nQuestion: {selected_qa['question']}"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": selected_qa["answer"]}]},
    ]

    return {"messages": messages}
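
# Quick sanity check (an illustrative sketch, assuming `dataset` and `processor` are
# already loaded as in the __main__ block below): render one prepared example with
# the chat template before launching a full run.
#
#   sample = prepare_dataset(dataset[0], "/tmp/videos/")
#   print(processor.apply_chat_template(sample["messages"], tokenize=False))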


def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Collate batch of examples for training."""
    texts = []
    video_inputs = []

    for i, example in enumerate(examples):
        try:
            video_path = next(
                content["video"]
                for message in example["messages"]
                for content in message["content"]
                if content.get("type") == "video"
            )
            print(f"Processing video: {os.path.basename(video_path)}")

            texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
            # process_vision_info returns (image_inputs, video_inputs); keep the single video input
            video_input = process_vision_info(example["messages"])[1][0]
            video_inputs.append(video_input)
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}") from e

    inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True)

    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Handle visual tokens based on processor type
    visual_tokens = (
        [151652, 151653, 151656]
        if isinstance(processor, Qwen2VLProcessor)
        else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    )

    # Mask visual placeholder tokens so they do not contribute to the loss
    for visual_token_id in visual_tokens:
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs


@dataclass
class CustomScriptArguments(SFTScriptArguments):
    video_cache_dir: str = "/tmp/videos/"


if __name__ == "__main__":
    # Parse arguments
    parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Configure training args
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # Load dataset
    dataset = load_dataset(script_args.dataset_name, split="train")

    # Setup model
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    # Quantization configuration for 4-bit training
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Model initialization
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map(),
        quantization_config=bnb_config,
    )

    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    # Configure model modules for gradients
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_reentrant = False
        model.enable_input_require_grads()

    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )

    # Prepare dataset
    prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset]

    # Initialize wandb if specified (report_to is normalized to a list by TrainingArguments)
    if "wandb" in training_args.report_to:
        wandb.init(project="video-llm-training")

    # Initialize trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        tokenizer=processor.tokenizer,
    )

    # Train model
    trainer.train()

    # Save final model
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)

    # Cleanup
    del model
    del trainer
    torch.cuda.empty_cache()
    wandb.finish()
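
# After training, the saved LoRA adapter can be loaded back for inference, roughly as
# sketched below (an assumption for illustration; "video-llm-output" is the
# --output_dir from the example command in the module docstring):
#
#   from peft import PeftModel
#
#   base_model = AutoModelForVision2Seq.from_pretrained(
#       "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base_model, "video-llm-output")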
