Skip to content

Commit

Permalink
checkpoints release
Browse files Browse the repository at this point in the history
  • Loading branch information
Haotian-Zhang committed Dec 15, 2023
1 parent 5fc7ff8 commit 262a943
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 8 deletions.
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ Key Contributions:


## Release
- [10/30] 🔥 We released the code of **FERRET** model.
- [12/14] 🔥 We released the [checkpoints(7B, 13B)](#checkpoints).
- [10/30] 🔥 We released the code of **FERRET** model and [Ferret-Bench](ferret/eval/ferret_gpt4_data).



**Usage and License Notices**: The data, and code is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
**Usage and License Notices**: The data, and code is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.

## Contents
- [Install](#install)
Expand Down Expand Up @@ -95,6 +96,25 @@ The scripts are provided ([7B](experiments/ferret_7b_train.sh), [13B](experiment

Please see this [doc](EVAL.md) for the details.

## Checkpoints
We extracted the `delta` between our pre-trained model and Vicuna. Please first download weights of Vicuna following the [previous instruction](#prepare-vicuna-checkpoint-and-llavas-projector). Then download our prepared offsets of weights: [7B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-7b/ferret-7b-delta.zip), [13B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-13b/ferret-13b-delta.zip) using `wget` or `curl`, and unzip the downloaded offsets. Lastly, apply the offset to the Vicuna's weight by running the following script:
```Shell
# 7B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./model/ferret-7b-v1-3 \
--delta path/to/ferret-7b-delta
# 13B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./model/ferret-13b-v1-3 \
--delta path/to/ferret-13b-delta
```

**Notices**: Apple's rights in the attached weight differentials are hereby licensed under the CC-BY-NC license. Apple makes no representations with regards to LLaMa or any other third party software, which are subject to their own terms.

Please refer to the next section about how to set up a local demo with pre-trained weight.

## Demo

To run our demo, you need to train FERRET and use the checkpoints locally. Gradio web UI is used. Please run the following commands one by one.
Expand Down
26 changes: 24 additions & 2 deletions ferret/model/apply_delta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
# 7B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./model/ferret-7b-v1-3 \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
# 13B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./model/ferret-13b-v1-3 \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
"""
import argparse

Expand All @@ -10,6 +20,18 @@
from ferret import FERRETLlamaForCausalLM


exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
]


def apply_delta(base_model_path, target_model_path, delta_path):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
Expand All @@ -22,7 +44,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
print("Applying delta")
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
if name not in base.state_dict():
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
assert name in exclude_name_lists, f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data += base.state_dict()[name]
Expand Down
74 changes: 74 additions & 0 deletions ferret/model/make_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Usage:
# 7B
python3 -m ferret.model.make_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
# 13B
python3 -m ferret.model.make_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret.model.utils import auto_upgrade

# all the parameters inside the geosampler and mm projector
exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
]


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

print("Loading target model")
auto_upgrade(target_model_path)
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

print("Calculating delta")
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
if name not in base.state_dict():
assert name in exclude_name_lists, f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data -= base.state_dict()[name]
else:
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
bparam = base.state_dict()[name]
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

print("Saving delta")
if hub_repo_id:
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
else:
kwargs = {}
target.save_pretrained(delta_path, **kwargs)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, default=None)
args = parser.parse_args()

make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
55 changes: 51 additions & 4 deletions ferret/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
(None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'masks':[]}, [], None)
(None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)


def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
Expand Down Expand Up @@ -321,7 +321,26 @@ def post_process_code(code):
return code


def find_indices_in_order(str_list, STR):
indices = []
i = 0
while i < len(STR):
for element in str_list:
if STR[i:i+len(element)] == element:
indices.append(str_list.index(element))
i += len(element) - 1
break
i += 1
return indices


def format_region_prompt(prompt, refer_input_state):
# Find regions in prompts and assign corresponding region masks
refer_input_state['region_masks_in_prompts'] = []
indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]

# Find regions in prompts and replace with real coordinates and region feature token.
for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
return prompt
Expand All @@ -341,6 +360,32 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
if len(state.messages) == state.offset + 2:
# First round of conversation
template_name = 'ferret_v1'
# Below is LLaVA's original templates.
# if "llava" in model_name.lower():
# if 'llama-2' in model_name.lower():
# template_name = "llava_llama_2"
# elif "v1" in model_name.lower():
# if 'mmtag' in model_name.lower():
# template_name = "v1_mmtag"
# elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
# template_name = "v1_mmtag"
# else:
# template_name = "llava_v1"
# elif "mpt" in model_name.lower():
# template_name = "mpt"
# else:
# if 'mmtag' in model_name.lower():
# template_name = "v0_mmtag"
# elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
# template_name = "v0_mmtag"
# else:
# template_name = "llava_v0"
# elif "mpt" in model_name:
# template_name = "mpt_text"
# elif "llama-2" in model_name:
# template_name = "llama_2"
# else:
# template_name = "vicuna_v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
Expand Down Expand Up @@ -386,8 +431,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
}
logger.info(f"==== request ====\n{pload}")
if args.add_region_feature:
pload['region_masks'] = refer_input_state['region_masks']
logger.info(f"==== add region_masks to request ====\n")
pload['region_masks'] = refer_input_state['region_masks_in_prompts']
logger.info(f"==== add region_masks_in_prompts to request ====\n")

pload['images'] = state.get_images()
print(f'Input Prompt: {prompt}')
Expand Down Expand Up @@ -439,8 +484,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in

title_markdown = ("""
# 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
[[Code](https://github.com/apple/ml-ferret)] [[Paper](https://arxiv.org/abs/2310.07704)]
""")
# [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)

tos_markdown = ("""
### Terms of use
Expand Down Expand Up @@ -554,6 +599,7 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
refer_input_state['region_placeholder_tokens'].append(cur_region_token)
refer_input_state['region_coordinates'].append(cur_region_coordinates)
refer_input_state['region_masks'].append(cur_region_masks)
assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
refer_text_show.append((cur_region_token, ''))

# Show Parsed Referring.
Expand Down Expand Up @@ -597,6 +643,7 @@ def build_demo(embed_mode):
refer_input_state = gr.State({'region_placeholder_tokens':[],
'region_coordinates':[],
'region_masks':[],
'region_masks_in_prompts':[],
'masks':[],
})
refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
Expand Down
77 changes: 77 additions & 0 deletions scripts/extract_geosampler_and_mm_projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Usage:
# 7B
To extract region_geo_sampler:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=region_geo_sampler \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
To extract mm_projector:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=mm_projector \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
# 13B
To extract region_geo_sampler:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=region_geo_sampler \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
To extract mm_projector:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=mm_projector \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
"""


import os
import argparse
import torch
import json
from collections import defaultdict


def parse_args():
parser = argparse.ArgumentParser(description='Extract MMProjector or GeoSampler weights')
parser.add_argument('--model-path', type=str, help='model folder')
parser.add_argument('--output', type=str, help='output file')
parser.add_argument('--keys_to_match', type=str, default="region_geo_sampler", choices=["mm_projector", "region_geo_sampler"], help='keys to be matched')
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_args()

keys_to_match = [args.keys_to_match]
ckpt_to_key = defaultdict(list)
print('----indexing keys_to_match...----')
try:
model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
for k, v in model_indices['weight_map'].items():
if any(key_match in k for key_match in keys_to_match):
ckpt_to_key[v].append(k)
except FileNotFoundError:
# Smaller models or model checkpoints saved by DeepSpeed.
v = 'pytorch_model.bin'
for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
if any(key_match in k for key_match in keys_to_match):
ckpt_to_key[v].append(k)

loaded_weights = {}

print('----loading weights...----')
for ckpt_name, weight_keys in ckpt_to_key.items():
ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
for k in weight_keys:
loaded_weights[k] = ckpt[k]

print('----saving weights...----')
print(f'the keys of saved weights: {loaded_weights.keys()}')
print(f'----saved to {args.output}----')
torch.save(loaded_weights, args.output)
51 changes: 51 additions & 0 deletions scripts/verify_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Usage:
python3 misc/verify_equal.py \
--orig-model-path ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--new-model-path ./model/ferret-7b-v1-3
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret import FERRETLlamaForCausalLM

def verify_equal(old_model_path, new_model_path):
print("Loading old model")
old = FERRETLlamaForCausalLM.from_pretrained(old_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

print("Loading saved model")
new = FERRETLlamaForCausalLM.from_pretrained(new_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

# Get state dictionaries of both models
state_dict1 = old.state_dict()
state_dict2 = new.state_dict()

# Compare each parameter
for name, param in tqdm(state_dict1.items(), desc="Traverse all params"):
# Check if the parameter name exists in the second model
if name not in state_dict2:
print(f"Parameter {name} found in the first model but not in the second.")
return False

# Check if the parameter weights are the same, bf16 vs. f32
if not torch.allclose(param, state_dict2[name], atol=1e-4):
print(param.shape)
print(state_dict2[name].shape)
print(f"Parameter weights for {name} are different.")
return False

print("All parameter names and weights are the same.")
return True


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--orig-model-path", type=str, required=True)
parser.add_argument("--new-model-path", type=str, required=True)

args = parser.parse_args()

print(verify_equal(args.orig_model_path, args.new_model_path))

0 comments on commit 262a943

Please sign in to comment.