Commit
Showing 8 changed files with 351 additions and 188 deletions.
Submodule FasterTransformer deleted from 70eb11
@@ -0,0 +1,13 @@
MODEL_TYPE="blocklm-130B"
MODEL_ARGS="--num-layers 70 \
            --hidden-size 12288 \
            --inner-hidden-size 32768 \
            --vocab-size 150528 \
            --num-attention-heads 96 \
            --max-sequence-length 1025 \
            --tokenizer-type icetk-glm-130B \
            --layernorm-order post \
            --skip-init \
            --task-mask \
            --load ${CHECKPOINT_PATH}/iter_0020000"
#
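For reference, a minimal Python sketch (not part of this commit) of the per-head geometry these flags imply:

# Minimal sketch (not part of this commit): shapes implied by MODEL_ARGS above.
num_layers = 70
hidden_size = 12288
inner_hidden_size = 32768
num_attention_heads = 96

assert hidden_size % num_attention_heads == 0
head_dim = hidden_size // num_attention_heads
print(f"{num_layers} layers, {num_attention_heads} heads of dim {head_dim}, FFN width {inner_hidden_size}")
# -> 70 layers, 96 heads of dim 128, FFN width 32768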
@@ -0,0 +1,204 @@
# -*- encoding: utf-8 -*-
'''
@File    :   inference_large_scale.py
@Time    :   2021/10/22 19:41:58
@Author  :   Ming Ding
@Contact :   [email protected]
'''

# here put the import lib
from functools import partial
import os
import bminf
import torch
import argparse
import stat
import re

from SwissArmyTransformer import mpu, get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed, set_random_seed
from SwissArmyTransformer.training import load_checkpoint

from SwissArmyTransformer.model import GLM130B
from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity
from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually

from SwissArmyTransformer.model.official.glm130B_model import RotaryEmbeddingMixin


def get_masks_and_position_ids_gmask(seq, mask_position, context_length):
    # [gMASK] (task mask): the context prefix is fully visible to every position,
    # tokens after it attend causally; position ids simply count up.
    tokens = seq.unsqueeze(0)

    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., :context_length - 1] = 1
    attention_mask.unsqueeze_(1)

    position_ids = torch.arange(len(seq), dtype=torch.long,
                                device=tokens.device)
    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids


def get_masks_and_position_ids_mask(seq, mask_position, context_length):
    # [MASK] (blank infilling): same prefix mask, but all generated tokens share
    # the position of the [MASK] token they fill in.
    tokens = seq.unsqueeze(0)

    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., :context_length - 1] = 1
    attention_mask.unsqueeze_(1)

    position_ids = torch.arange(len(seq), dtype=torch.long,
                                device=tokens.device)
    position_ids[context_length - 1:] = mask_position

    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids


def main(args):
    args.do_train = False
    initialize_distributed(args)
    tokenizer = get_tokenizer(args)
    # build model
    model = GLM130B(args)

    if args.fp16:
        model = model.half()

    load_checkpoint(model, args)

    with torch.cuda.device(args.device):
        model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)  # 20GB

    model.eval()

    end_tokens = [tokenizer.get_command('eop'), tokenizer.get_command('eos')]
    # define function for each query

    if args.sampling_strategy == 'BaseStrategy':
        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
    elif args.sampling_strategy == 'BeamSearchStrategy':
        strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True,
                                      end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size,
                                      min_tgt_length=args.min_tgt_length)
    else:
        raise ValueError(f'unknown strategy {args.sampling_strategy}')

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')

        # add MASK
        generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
        if args.task_mask:
            assert '[MASK]' not in raw_text, 'should not mix [MASK] and [gMASK]'
        else:
            assert '[gMASK]' not in raw_text, 'should not mix [MASK] and [gMASK]'

        mask_pattern = r'\[g?MASK\]'
        text_list = re.split(mask_pattern, raw_text)
        pattern_list = re.compile(mask_pattern).findall(raw_text)
        seq = []
        for i in range(len(pattern_list)):
            pattern = pattern_list[i]
            sub_text = text_list[i]
            seq.extend(tokenizer.tokenize(sub_text))
            seq.append(tokenizer.get_command(pattern))

        seq.extend(tokenizer.tokenize(text_list[-1]))

        if 'MASK]' not in raw_text:
            seq += [tokenizer.get_command(generation_mask)]
            raw_text += ' ' + generation_mask
        if not raw_text.endswith('MASK]'):
            seq = seq + [tokenizer.get_command('eos')]
        if mpu.get_model_parallel_rank() == 0:
            print('raw text: {}\n'.format(raw_text))
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')

        # generation
        mbz = args.max_inference_batch_size
        assert args.batch_size < mbz or args.batch_size % mbz == 0
        output_list = [seq]
        # continually detect the first mask position
        while True:
            seq = output_list[0]  # TODO find the best one
            # detect
            mask_tokens = tokenizer.get_command(generation_mask)
            mask_position = len(seq)
            try:
                mask_position = min(mask_position, seq.index(mask_tokens))
            except ValueError:
                pass
            if mask_position == len(seq):
                break

            if args.task_mask:
                get_func = partial(get_masks_and_position_ids_gmask, mask_position=mask_position, context_length=len(seq) + 1)
            else:
                get_func = partial(get_masks_and_position_ids_mask, mask_position=mask_position, context_length=len(seq) + 1)

            output_list = []

            for tim in range(max(args.batch_size // mbz, 1)):
                input_seq = torch.cuda.LongTensor(
                    seq + [tokenizer.get_command('sop')] + [-1] * (args.out_seq_length - len(seq) - 1),
                    device=args.device)
                output = filling_sequence(model, input_seq,
                                          batch_size=min(args.batch_size, mbz),
                                          strategy=strategy,
                                          log_attention_weights=None,
                                          get_masks_and_position_ids=get_func,
                                          )[0]  # we don't use mems, fill back
                if isinstance(output, torch.Tensor):  # different strategies
                    output = list(output)

                output_list.extend(output)

            # clip -1s and fill back generated things into seq
            for i in range(len(output_list)):
                output = output_list[i].tolist()
                try:
                    unfinished = output.index(-1)
                except ValueError:
                    unfinished = len(output)
                if output[unfinished - 1] in end_tokens:
                    unfinished -= 1
                bog = output.index(tokenizer.get_command('sop'))
                output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]

        # decoding
        txts = []
        for seq in output_list:
            decode_tokens = tokenizer.detokenize(seq)
            txts.append(decode_tokens)
        if args.device == 0:
            print(torch.cuda.memory_summary())
        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + '.txt')
        else:
            prefix = raw_text.replace('/', '')[:20]
            full_path = timed_name(prefix, '.txt', args.output_path)
        if mpu.get_model_parallel_rank() == 0:
            print("answer", txts)  # print the first.
            with open(full_path, 'w', encoding='utf-8') as fout:
                for txt in txts:
                    fout.write(txt + '\n')
            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)


if __name__ == "__main__":
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', help='type name of sampling strategy')
    GLM130B.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))

    with torch.no_grad():
        main(args)
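To make the two mask helpers above concrete, here is a minimal sketch (not part of this commit) of the attention mask that get_masks_and_position_ids_gmask builds for a toy sequence: the context prefix is visible to every position, while tokens after it attend causally.

# Minimal sketch (not part of this commit) of the [gMASK] attention mask.
import torch

seq = torch.arange(6)      # six toy token ids
context_length = 4         # prefix = positions 0 .. context_length - 2

attention_mask = torch.ones((1, len(seq), len(seq)))
attention_mask.tril_()                          # causal (lower-triangular) mask
attention_mask[..., :context_length - 1] = 1    # prefix columns visible to everyone
attention_mask.unsqueeze_(1)

print(attention_mask[0, 0])
# rows 0-2: [1, 1, 1, 0, 0, 0]  -- bidirectional within the prefix
# rows 3-5: causal, e.g. row 4 is [1, 1, 1, 1, 1, 0]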
@@ -0,0 +1,34 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum

class LayerType(enum.Enum):
    encoder = 1
    decoder = 2

class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    prefix = 3

class PositionEmbeddingType(enum.Enum):
    rotary = 1
    absolute = 2
    alibi = 3
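A minimal, hypothetical usage sketch (not part of this commit) of how these enums would typically be consulted; GLM-130B itself pairs a prefix attention mask with rotary position embeddings (cf. the RotaryEmbeddingMixin import above):

# Hypothetical usage sketch (not part of this commit).
attn_mask_type = AttnMaskType.prefix                     # [gMASK] generation: prefix-LM mask
position_embedding_type = PositionEmbeddingType.rotary   # GLM-130B uses rotary embeddings

if attn_mask_type is AttnMaskType.prefix:
    print("context attends bidirectionally; generated tokens attend causally")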
@@ -0,0 +1,70 @@
import os
import sys
import torch


def megatron_to_sat(checkpoint_name, target):
    # Rename the parameters of one Megatron-LM model-parallel shard to the
    # SwissArmyTransformer naming scheme and re-wrap them under a 'module' key.
    num_layer = 70
    # checkpoint_name = '/thudm/workspace/hanyu/SwissArmyTransformer-old/data/global_step10400/iter_0010400/10400/mp_rank_00_model_states.pt'
    sd = torch.load(checkpoint_name, map_location='cpu')
    new_sd = {}
    new_sd['transformer.word_embeddings.weight'] = sd['model']['word_embeddings_for_head']['weight']

    encoder = sd['model']['language_model']['encoder']

    for i in range(num_layer):
        new_sd['transformer.layers.' + str(i) + '.input_layernorm.weight'] = encoder['layers.' + str(i) + '.input_layernorm.weight']
        new_sd['transformer.layers.' + str(i) + '.input_layernorm.bias'] = encoder['layers.' + str(i) + '.input_layernorm.bias']

        new_sd['transformer.layers.' + str(i) + '.attention.query_key_value.weight'] = encoder['layers.' + str(i) + '.self_attention.query_key_value.weight']
        new_sd['transformer.layers.' + str(i) + '.attention.query_key_value.bias'] = encoder['layers.' + str(i) + '.self_attention.query_key_value.bias']

        new_sd['transformer.layers.' + str(i) + '.attention.dense.weight'] = encoder['layers.' + str(i) + '.self_attention.dense.weight']
        new_sd['transformer.layers.' + str(i) + '.attention.dense.bias'] = encoder['layers.' + str(i) + '.self_attention.dense.bias']

        new_sd['transformer.layers.' + str(i) + '.post_attention_layernorm.weight'] = encoder['layers.' + str(i) + '.post_attention_layernorm.weight']
        new_sd['transformer.layers.' + str(i) + '.post_attention_layernorm.bias'] = encoder['layers.' + str(i) + '.post_attention_layernorm.bias']

        new_sd['transformer.layers.' + str(i) + '.mlp.dense_h_to_4h.weight'] = encoder['layers.' + str(i) + '.mlp.dense_h_to_4h.weight']
        new_sd['transformer.layers.' + str(i) + '.mlp.dense_h_to_4h.bias'] = encoder['layers.' + str(i) + '.mlp.dense_h_to_4h.bias']

        new_sd['transformer.layers.' + str(i) + '.mlp.dense_4h_to_h.weight'] = encoder['layers.' + str(i) + '.mlp.dense_4h_to_h.weight']
        new_sd['transformer.layers.' + str(i) + '.mlp.dense_4h_to_h.bias'] = encoder['layers.' + str(i) + '.mlp.dense_4h_to_h.bias']

    new_sd['transformer.final_layernorm.weight'] = encoder['final_layernorm.weight']
    new_sd['transformer.final_layernorm.bias'] = encoder['final_layernorm.bias']
    new_sd = {'module': new_sd}
    # target = open('/thudm/workspace/hanyu/SwissArmyTransformer/data/global_step10400/iter_0010400/10400/mp_rank_00_model_states.pt', 'w')
    torch.save(new_sd, target)


def main():
    # Usage: python <this script> <megatron checkpoint dir> <target dir>
    dir_path = str(sys.argv[1])
    target_dir = str(sys.argv[2])

    iter_path = os.path.join(dir_path, 'latest_checkpointed_iteration.txt')

    iteration = open(iter_path).read()

    print(iteration)

    new_iter_dir = os.path.join(target_dir, 'iter_00' + iteration)
    iter_dir = os.path.join(dir_path, 'iter_00' + iteration)

    os.mkdir(new_iter_dir)
    new_iter_path = os.path.join(new_iter_dir, 'latest')

    print(iter_path, new_iter_path)
    os.system('cp {} {}'.format(iter_path, new_iter_path))

    new_model_dir = os.path.join(new_iter_dir, iteration)
    os.mkdir(new_model_dir)

    # convert each of the 8 model-parallel shards
    for i in range(8):
        model_dir = os.path.join(iter_dir, 'mp_rank_0' + str(i))
        model_path = os.path.join(model_dir, 'model_optim_rng.pt')
        new_model_path = os.path.join(new_model_dir, 'mp_rank_0' + str(i) + '_model_states.pt')
        print(model_path, new_model_path)
        megatron_to_sat(model_path, new_model_path)


if __name__ == "__main__":
    main()
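The converter takes the Megatron checkpoint directory and a target directory as its two positional arguments. A minimal sketch (not part of this commit) of how one converted shard could be sanity-checked; the path below is a hypothetical example following the directory layout main() creates:

# Hypothetical sanity check (not part of this commit) for a converted shard.
import torch

converted = '/path/to/sat/iter_0020000/20000/mp_rank_00_model_states.pt'  # example path
sd = torch.load(converted, map_location='cpu')

module = sd['module']  # megatron_to_sat wraps all weights under 'module'
assert 'transformer.word_embeddings.weight' in module
assert 'transformer.layers.0.attention.query_key_value.weight' in module
assert 'transformer.final_layernorm.weight' in module
print(f"{len(module)} tensors in this model-parallel shard")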
@@ -0,0 +1,30 @@
#!/bin/bash
CHECKPOINT_PATH=/path/to/GLM130B/checkpoint

source $1
MPSIZE=8
MAXSEQLEN=512
MASTER_PORT=$(shuf -n 1 -i 10000-65535)

# SAMPLING ARGS
TEMP=0.9
# If TOPK/TOPP are 0, sampling defaults to greedy; top-k also overrides top-p
TOPK=40
TOPP=0

# Note: --out-seq-length is passed twice below; the later value (100) takes effect.
python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_glm130B.py \
       --mode inference \
       --model-parallel-size $MPSIZE \
       $MODEL_ARGS \
       --no-repeat-ngram-size 3 \
       --length-penalty 0.7 \
       --fp16 \
       --out-seq-length $MAXSEQLEN \
       --temperature $TEMP \
       --top_k $TOPK \
       --output-path samples_glm \
       --batch-size 1 \
       --out-seq-length 100 \
       --input-source ./input.txt \
       --sampling-strategy BeamSearchStrategy
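Assuming generate_continually treats each line of --input-source as one query (the tab-splitting in process() for --with-id suggests a line-oriented format), a hypothetical input.txt could be produced as follows; the prompts are illustrative only. Since MODEL_ARGS sets --task-mask, prompts either carry an explicit [gMASK] or get one appended automatically by process().

# Hypothetical example input (not part of this commit) for --input-source ./input.txt.
prompts = [
    "Tsinghua University is located in [gMASK]",           # explicit task mask
    "Explain the term model parallelism in one sentence",   # [gMASK] will be appended
]
with open('input.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(prompts) + '\n')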
This file was deleted.