Commit 262a943 (1 parent: 5fc7ff8) · Showing 6 changed files with 299 additions and 8 deletions.
ferret/model/make_delta.py
@@ -0,0 +1,74 @@
""" | ||
Usage: | ||
# 7B | ||
python3 -m ferret.model.make_delta \ | ||
--base ./model/vicuna-7b-v1-3 \ | ||
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta | ||
# 13B | ||
python3 -m ferret.model.make_delta \ | ||
--base ./model/vicuna-13b-v1-3 \ | ||
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta | ||
""" | ||
import argparse | ||
|
||
import torch | ||
from tqdm import tqdm | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from ferret.model.utils import auto_upgrade | ||
|
||
# all the parameters inside the geosampler and mm projector | ||
exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias', | ||
'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight', | ||
'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight', | ||
'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight', | ||
'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight', | ||
'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight', | ||
'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight', | ||
'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight', | ||
'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight' | ||
] | ||
|
||
|
||
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): | ||
print("Loading base model") | ||
base = AutoModelForCausalLM.from_pretrained( | ||
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) | ||
|
||
print("Loading target model") | ||
auto_upgrade(target_model_path) | ||
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) | ||
|
||
print("Calculating delta") | ||
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): | ||
if name not in base.state_dict(): | ||
assert name in exclude_name_lists, f'{name} not in base model' | ||
continue | ||
if param.data.shape == base.state_dict()[name].shape: | ||
param.data -= base.state_dict()[name] | ||
else: | ||
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' | ||
bparam = base.state_dict()[name] | ||
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam | ||
|
||
print("Saving delta") | ||
if hub_repo_id: | ||
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} | ||
else: | ||
kwargs = {} | ||
target.save_pretrained(delta_path, **kwargs) | ||
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) | ||
target_tokenizer.save_pretrained(delta_path, **kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--base-model-path", type=str, required=True) | ||
parser.add_argument("--target-model-path", type=str, required=True) | ||
parser.add_argument("--delta-path", type=str, required=True) | ||
parser.add_argument("--hub-repo-id", type=str, default=None) | ||
args = parser.parse_args() | ||
|
||
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) |
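For context, applying such a delta back onto the base weights is the mirror image of the loop above. The sketch below is an assumption, not code from this commit: the apply_delta name and its handling of the Ferret-only modules (stored verbatim in the delta) and of the padded embedding rows are inferred from make_delta.

# A minimal sketch of the inverse operation (hypothetical; not part of this commit).
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def apply_delta(base_model_path, delta_path, target_path):
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta = AutoModelForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    base_state_dict = base.state_dict()
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base_state_dict:
            continue  # Ferret-only modules were stored as-is in the delta.
        bparam = base_state_dict[name]
        if param.data.shape == bparam.shape:
            param.data += bparam
        else:
            # Rows appended for new tokens have no base counterpart; add the
            # base weights back only into the overlapping slice.
            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

    delta.save_pretrained(target_path)
    AutoTokenizer.from_pretrained(delta_path).save_pretrained(target_path)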
misc/extract_geosampler_and_mm_projector.py
@@ -0,0 +1,77 @@
""" | ||
Usage: | ||
# 7B | ||
To extract region_geo_sampler: | ||
python misc/extract_geosampler_and_mm_projector.py \ | ||
--keys_to_match=region_geo_sampler \ | ||
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin | ||
To extract mm_projector: | ||
python misc/extract_geosampler_and_mm_projector.py \ | ||
--keys_to_match=mm_projector \ | ||
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin | ||
# 13B | ||
To extract region_geo_sampler: | ||
python misc/extract_geosampler_and_mm_projector.py \ | ||
--keys_to_match=region_geo_sampler \ | ||
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin | ||
To extract mm_projector: | ||
python misc/extract_geosampler_and_mm_projector.py \ | ||
--keys_to_match=mm_projector \ | ||
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin | ||
""" | ||
|
||
|
||
import os | ||
import argparse | ||
import torch | ||
import json | ||
from collections import defaultdict | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Extract MMProjector or GeoSampler weights') | ||
parser.add_argument('--model-path', type=str, help='model folder') | ||
parser.add_argument('--output', type=str, help='output file') | ||
parser.add_argument('--keys_to_match', type=str, default="region_geo_sampler", choices=["mm_projector", "region_geo_sampler"], help='keys to be matched') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
|
||
keys_to_match = [args.keys_to_match] | ||
ckpt_to_key = defaultdict(list) | ||
print('----indexing keys_to_match...----') | ||
try: | ||
model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) | ||
for k, v in model_indices['weight_map'].items(): | ||
if any(key_match in k for key_match in keys_to_match): | ||
ckpt_to_key[v].append(k) | ||
except FileNotFoundError: | ||
# Smaller models or model checkpoints saved by DeepSpeed. | ||
v = 'pytorch_model.bin' | ||
for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): | ||
if any(key_match in k for key_match in keys_to_match): | ||
ckpt_to_key[v].append(k) | ||
|
||
loaded_weights = {} | ||
|
||
print('----loading weights...----') | ||
for ckpt_name, weight_keys in ckpt_to_key.items(): | ||
ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') | ||
for k in weight_keys: | ||
loaded_weights[k] = ckpt[k] | ||
|
||
print('----saving weights...----') | ||
print(f'the keys of saved weights: {loaded_weights.keys()}') | ||
print(f'----saved to {args.output}----') | ||
torch.save(loaded_weights, args.output) |
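For reference, a minimal sketch of consuming the extracted file: since it holds only the matched submodule's weights, it can be loaded non-strictly into a full model. The model class and paths below are taken from the other scripts in this commit, but this usage itself is an assumption, not code from the repository.

# Hypothetical usage of the extracted weights (a sketch, not part of this commit).
import torch
from ferret import FERRETLlamaForCausalLM

model = FERRETLlamaForCausalLM.from_pretrained('./model/ferret-7b-v1-3')
weights = torch.load('./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin',
                     map_location='cpu')
# strict=False: every parameter absent from the file is reported as missing
# and left untouched; only the geo sampler weights are overwritten.
missing, unexpected = model.load_state_dict(weights, strict=False)
assert not unexpected, unexpected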
misc/verify_equal.py
@@ -0,0 +1,51 @@
""" | ||
Usage: | ||
python3 misc/verify_equal.py \ | ||
--orig-model-path ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \ | ||
--new-model-path ./model/ferret-7b-v1-3 | ||
""" | ||
import argparse | ||
|
||
import torch | ||
from tqdm import tqdm | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from ferret import FERRETLlamaForCausalLM | ||
|
||
def verify_equal(old_model_path, new_model_path): | ||
print("Loading old model") | ||
old = FERRETLlamaForCausalLM.from_pretrained(old_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) | ||
|
||
print("Loading saved model") | ||
new = FERRETLlamaForCausalLM.from_pretrained(new_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) | ||
|
||
# Get state dictionaries of both models | ||
state_dict1 = old.state_dict() | ||
state_dict2 = new.state_dict() | ||
|
||
# Compare each parameter | ||
for name, param in tqdm(state_dict1.items(), desc="Traverse all params"): | ||
# Check if the parameter name exists in the second model | ||
if name not in state_dict2: | ||
print(f"Parameter {name} found in the first model but not in the second.") | ||
return False | ||
|
||
# Check if the parameter weights are the same, bf16 vs. f32 | ||
if not torch.allclose(param, state_dict2[name], atol=1e-4): | ||
print(param.shape) | ||
print(state_dict2[name].shape) | ||
print(f"Parameter weights for {name} are different.") | ||
return False | ||
|
||
print("All parameter names and weights are the same.") | ||
return True | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--orig-model-path", type=str, required=True) | ||
parser.add_argument("--new-model-path", type=str, required=True) | ||
|
||
args = parser.parse_args() | ||
|
||
print(verify_equal(args.orig_model_path, args.new_model_path)) |
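A side note on the atol=1e-4 in verify_equal: bitwise equality (torch.equal) is usually too strict once weights have passed through a reduced-precision dtype, which is why an absolute tolerance is used. A small self-contained illustration:

# Why allclose with a tolerance rather than torch.equal: a round trip
# through float16 drops the low mantissa bits of a float32 tensor.
import torch

x = torch.randn(1000, dtype=torch.float32)
y = x.to(torch.float16).to(torch.float32)  # simulate an fp16 round trip
print(torch.equal(x, y))           # False: exact comparison fails
print((x - y).abs().max().item())  # small but nonzero difference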