diff --git a/3rdparty/FasterTransformer b/3rdparty/FasterTransformer
deleted file mode 160000
index 70eb116..0000000
--- a/3rdparty/FasterTransformer
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 70eb1166af63e089f0bdb8a17adf34aa52c464d9
diff --git a/example/glm-130B/config/model_glm_130B.sh b/example/glm-130B/config/model_glm_130B.sh
new file mode 100644
index 0000000..c53c87a
--- /dev/null
+++ b/example/glm-130B/config/model_glm_130B.sh
@@ -0,0 +1,13 @@
+MODEL_TYPE="blocklm-130B"
+MODEL_ARGS="--num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 1025 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --skip-init \
+            --task-mask \
+            --load ${CHECKPOINT_PATH}/iter_0020000"
+ #
diff --git a/example/glm-130B/inference_glm130B.py b/example/glm-130B/inference_glm130B.py
new file mode 100644
index 0000000..14b2e7b
--- /dev/null
+++ b/example/glm-130B/inference_glm130B.py
@@ -0,0 +1,204 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_large_scale.py
+@Time    :   2021/10/22 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+from functools import partial
+import os
+import bminf
+import torch
+import argparse
+import stat
+import re
+
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer
+from SwissArmyTransformer.arguments import initialize_distributed, set_random_seed
+from SwissArmyTransformer.training import load_checkpoint
+
+from SwissArmyTransformer.model import GLM130B
+from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
+from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity
+from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+from SwissArmyTransformer.generation.utils import timed_name, generate_continually
+
+from SwissArmyTransformer.model.official.glm130B_model import RotaryEmbeddingMixin
+
+def get_masks_and_position_ids_gmask(seq, mask_position, context_length):
+    tokens = seq.unsqueeze(0)
+
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask[..., :context_length - 1] = 1
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.arange(len(seq), dtype=torch.long,
+                                device=tokens.device)
+    position_ids = position_ids.unsqueeze(0)
+
+    return tokens, attention_mask, position_ids
+
+def get_masks_and_position_ids_mask(seq, mask_position, context_length):
+    tokens = seq.unsqueeze(0)
+
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask[..., :context_length - 1] = 1
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.arange(len(seq), dtype=torch.long,
+                                device=tokens.device)
+    position_ids[context_length - 1:] = mask_position
+
+    position_ids = position_ids.unsqueeze(0)
+
+    return tokens, attention_mask, position_ids
+
+def main(args):
+    args.do_train = False
+    initialize_distributed(args)
+    tokenizer = get_tokenizer(args)
+    # build model
+    model = GLM130B(args)
+
+    if args.fp16:
+        model = model.half()
+
+    load_checkpoint(model, args)
+
+    with torch.cuda.device(args.device):
+        model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)  # 20GB
+
+    model.eval()
+
+    end_tokens = [tokenizer.get_command('eop'), tokenizer.get_command('eos')]
+    # define function for each query
+
+    if args.sampling_strategy == 'BaseStrategy':
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+    elif args.sampling_strategy == 'BeamSearchStrategy':
+        strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True, end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size, min_tgt_length=args.min_tgt_length)
+    else:
+        raise ValueError(f'unknown strategy {args.sampling_strategy}')
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split('\t')
+
+        # add MASK
+        generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+        if args.task_mask:
+            assert '[MASK]' not in raw_text, 'should not mix [MASK] and [gMASK]'
+        else:
+            assert '[gMASK]' not in raw_text, 'should not mix [MASK] and [gMASK]'
+
+        mask_pattern = r'\[g?MASK\]'
+        text_list = re.split(mask_pattern, raw_text)
+        pattern_list = re.compile(mask_pattern).findall(raw_text)
+        seq = []
+        for i in range(len(pattern_list)):
+            pattern = pattern_list[i]
+            sub_text = text_list[i]
+            seq.extend(tokenizer.tokenize(sub_text))
+            seq.append(tokenizer.get_command(pattern))
+
+        seq.extend(tokenizer.tokenize(text_list[-1]))
+
+        if 'MASK]' not in raw_text:
+            seq += [tokenizer.get_command(generation_mask)]
+            raw_text += ' ' + generation_mask
+        if not raw_text.endswith('MASK]'):
+            seq = seq + [tokenizer.get_command('eos')]
+        if mpu.get_model_parallel_rank() == 0:
+            print('raw text: {}\n'.format(raw_text))
+        if len(seq) > args.max_sequence_length:
+            raise ValueError('text too long.')
+
+        # generation
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = [seq]
+        # continually detect the first mark position
+        while True:
+            seq = output_list[0]  # TODO find the best one
+            # detect
+            mask_tokens = tokenizer.get_command(generation_mask)
+            mask_position = len(seq)
+            try:
+                mask_position = min(mask_position, seq.index(mask_tokens))
+            except ValueError:
+                pass
+            if mask_position == len(seq):
+                break
+
+            if args.task_mask:
+                get_func = partial(get_masks_and_position_ids_gmask, mask_position=mask_position, context_length=len(seq) + 1)
+            else:
+                get_func = partial(get_masks_and_position_ids_mask, mask_position=mask_position, context_length=len(seq) + 1)
+
+            output_list = []
+
+            for tim in range(max(args.batch_size // mbz, 1)):
+                input_seq = torch.cuda.LongTensor(
+                    seq + [tokenizer.get_command('sop')] + [-1] * (args.out_seq_length - len(seq) - 1),
+                    device=args.device)
+                output = filling_sequence(model, input_seq,
+                        batch_size=min(args.batch_size, mbz),
+                        strategy=strategy,
+                        log_attention_weights=None,
+                        get_masks_and_position_ids=get_func,
+                        )[0]  # we don't use mems, fill back
+                if isinstance(output, torch.Tensor):  # different strategies
+                    output = list(output)
+
+                output_list.extend(output)
+
+            # clip -1s and fill back generated things into seq
+            for i in range(len(output_list)):
+                output = output_list[i].tolist()
+                try:
+                    unfinished = output.index(-1)
+                except ValueError:
+                    unfinished = len(output)
+                if output[unfinished - 1] in end_tokens:
+                    unfinished -= 1
+                bog = output.index(tokenizer.get_command('sop'))
+                output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
+
+        # decoding
+        txts = []
+        for seq in output_list:
+            decode_tokens = tokenizer.detokenize(seq)
+            txts.append(decode_tokens)
+        if args.device == 0:
+            print(torch.cuda.memory_summary())
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id + '.txt')
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.txt', args.output_path)
+        if mpu.get_model_parallel_rank() == 0:
+            print("answer", txts)  # print the first.
+        with open(full_path, 'w', encoding='utf-8') as fout:
+            for txt in txts:
+                fout.write(txt + '\n')
+        os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', help='type name of sampling strategy')
+    GLM130B.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
diff --git a/example/glm-130B/model_convert/megatron/enums.py b/example/glm-130B/model_convert/megatron/enums.py
new file mode 100644
index 0000000..d905046
--- /dev/null
+++ b/example/glm-130B/model_convert/megatron/enums.py
@@ -0,0 +1,34 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+
+class LayerType(enum.Enum):
+    encoder = 1
+    decoder = 2
+
+class AttnType(enum.Enum):
+    self_attn = 1
+    cross_attn = 2
+
+class AttnMaskType(enum.Enum):
+    padding = 1
+    causal = 2
+    prefix = 3
+
+class PositionEmbeddingType(enum.Enum):
+    rotary = 1
+    absolute = 2
+    alibi = 3
diff --git a/example/glm-130B/model_convert/megatron_to_sat.py b/example/glm-130B/model_convert/megatron_to_sat.py
new file mode 100644
index 0000000..d4d39a3
--- /dev/null
+++ b/example/glm-130B/model_convert/megatron_to_sat.py
@@ -0,0 +1,70 @@
+import os
+import sys
+import torch
+
+def megatron_to_sat(checkpoint_name, target):
+    num_layer = 70
+    # checkpoint_name = '/thudm/workspace/hanyu/SwissArmyTransformer-old/data/global_step10400/iter_0010400/10400/mp_rank_00_model_states.pt'
+    sd = torch.load(checkpoint_name, map_location='cpu')
+    new_sd = {}
+    new_sd['transformer.word_embeddings.weight'] = sd['model']['word_embeddings_for_head']['weight']
+
+    encoder = sd['model']['language_model']['encoder']
+
+    for i in range(num_layer):
+        new_sd['transformer.layers.' + str(i) + '.input_layernorm.weight'] = encoder['layers.' + str(i) + '.input_layernorm.weight']
+        new_sd['transformer.layers.' + str(i) + '.input_layernorm.bias'] = encoder['layers.' + str(i) + '.input_layernorm.bias']
+
+        new_sd['transformer.layers.' + str(i) + '.attention.query_key_value.weight'] = encoder['layers.' + str(i) + '.self_attention.query_key_value.weight']
+        new_sd['transformer.layers.' + str(i) + '.attention.query_key_value.bias'] = encoder['layers.' + str(i) + '.self_attention.query_key_value.bias']
+
+        new_sd['transformer.layers.' + str(i) + '.attention.dense.weight'] = encoder['layers.' + str(i) + '.self_attention.dense.weight']
+        new_sd['transformer.layers.' + str(i) + '.attention.dense.bias'] = encoder['layers.' + str(i) + '.self_attention.dense.bias']
+
+        new_sd['transformer.layers.' + str(i) + '.post_attention_layernorm.weight'] = encoder['layers.' + str(i) + '.post_attention_layernorm.weight']
+        new_sd['transformer.layers.' + str(i) + '.post_attention_layernorm.bias'] = encoder['layers.' + str(i) + '.post_attention_layernorm.bias']
+
+        new_sd['transformer.layers.' + str(i) + '.mlp.dense_h_to_4h.weight'] = encoder['layers.' + str(i) + '.mlp.dense_h_to_4h.weight']
+        new_sd['transformer.layers.' + str(i) + '.mlp.dense_h_to_4h.bias'] = encoder['layers.' + str(i) + '.mlp.dense_h_to_4h.bias']
+
+        new_sd['transformer.layers.' + str(i) + '.mlp.dense_4h_to_h.weight'] = encoder['layers.' + str(i) + '.mlp.dense_4h_to_h.weight']
+        new_sd['transformer.layers.' + str(i) + '.mlp.dense_4h_to_h.bias'] = encoder['layers.' + str(i) + '.mlp.dense_4h_to_h.bias']
+
+    new_sd['transformer.final_layernorm.weight'] = encoder['final_layernorm.weight']
+    new_sd['transformer.final_layernorm.bias'] = encoder['final_layernorm.bias']
+    new_sd = { 'module': new_sd }
+    # target = open('/thudm/workspace/hanyu/SwissArmyTransformer/data/global_step10400/iter_0010400/10400/mp_rank_00_model_states.pt', 'w')
+    torch.save(new_sd, target)
+
+
+def main():
+    dir_path = str(sys.argv[1])
+    target_dir = str(sys.argv[2])
+
+    iter_path = os.path.join(dir_path, 'latest_checkpointed_iteration.txt')
+
+    iteration = open(iter_path).read()
+
+    print(iteration)
+
+    new_iter_dir = os.path.join(target_dir, 'iter_00' + iteration)
+    iter_dir = os.path.join(dir_path, 'iter_00' + iteration)
+
+    os.mkdir(new_iter_dir)
+    new_iter_path = os.path.join(new_iter_dir, 'latest')
+
+    print(iter_path, new_iter_path)
+    os.system('cp {} {}'.format(iter_path, new_iter_path))
+
+    new_model_dir = os.path.join(new_iter_dir, iteration)
+    os.mkdir(new_model_dir)
+
+    for i in range(8):
+        model_dir = os.path.join(iter_dir, 'mp_rank_0' + str(i))
+        model_path = os.path.join(model_dir, 'model_optim_rng.pt')
+        new_model_path = os.path.join(new_model_dir, 'mp_rank_0' + str(i) + '_model_states.pt')
+        print(model_path, new_model_path)
+        megatron_to_sat(model_path, new_model_path)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/example/glm-130B/scripts/generate_glm.sh b/example/glm-130B/scripts/generate_glm.sh
new file mode 100644
index 0000000..a5aad7f
--- /dev/null
+++ b/example/glm-130B/scripts/generate_glm.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+CHECKPOINT_PATH=/path/to/GLM130B/checkpoint
+
+source $1
+MPSIZE=8
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
+TOPK=40
+TOPP=0
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_glm130B.py \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --no-repeat-ngram-size 3 \
+       --length-penalty 0.7 \
+       --fp16 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --output-path samples_glm \
+       --batch-size 1 \
+       --out-seq-length 100 \
+       --mode inference \
+       --input-source ./input.txt \
+       --sampling-strategy BeamSearchStrategy
diff --git a/example/glm130b/example.py b/example/glm130b/example.py
deleted file mode 100644
index 8fe3754..0000000
--- a/example/glm130b/example.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from model import GLM130B
-import torch
-import bminf
-
-def main():
-    m = bminf.wrapper(GLM130B().cuda(), quantization=False)
-    print(list(m.state_dict().keys()))
-
-    x = torch.LongTensor([[1, 2, 3, 4]]).cuda()
-    position = torch.LongTensor([[0, 1, 2, 3]]).cuda()
-    mask = torch.BoolTensor([[
-        [True, True, True, True],
-        [True, True, True, True],
-        [True, True, True, True],
-        [True, True, True, True],
-    ]]).cuda()
-    print(m(x, position, mask))
-
-if __name__ == "__main__":
-    main()
diff --git a/example/glm130b/model.py b/example/glm130b/model.py
deleted file mode 100644
index 917f166..0000000
--- a/example/glm130b/model.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import torch
-import math
-
-class GLMSelfAttention(torch.nn.Module):
-    def __init__(
-        self,
-        dim_model : int,
-        num_heads : int,
-        dim_head : int,
-    ) -> None:
-        super().__init__()
-
-        self.dim_head = dim_head
-        self.num_heads = num_heads
-
-        self.weight_q = torch.nn.Linear(dim_model, num_heads * dim_head, bias=True).half()
-        self.weight_k = torch.nn.Linear(dim_model, num_heads * dim_head, bias=True).half()
-        self.weight_v = torch.nn.Linear(dim_model, num_heads * dim_head, bias=True).half()
-
-        self.attn_out = torch.nn.Linear(dim_model, num_heads * dim_head, bias=True).half()
-
-        self.softmax = torch.nn.Softmax(dim=-1)
-
-    def forward(
-        self,
-        hidden_state : torch.Tensor,    # (batch, seq_len, dim_model)
-        mask : torch.BoolTensor,        # (batch, seq_len, seq_len)
-        position : torch.LongTensor,
-    ):
-        batch_size = hidden_state.size(0)
-        len_k = len_q = hidden_state.size(1)
-
-        h_q = self.weight_q(hidden_state)
-        h_k = self.weight_k(hidden_state)
-
-        h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)    # (batch, num_heads, len_q, dim_head)
-        h_k = h_k.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)    # (batch, num_heads, len_k, dim_head)
-
-        h_q = self.rotary_embedding(h_q, position)  # (batch, num_heads, len_q, dim_head)
-        h_k = self.rotary_embedding(h_k, position)  # (batch, num_heads, len_k, dim_head)
-
-
-        score = torch.matmul(h_q, h_k.transpose(-1, -2)) / math.sqrt(self.dim_head)    # (batch, num_heads, len_q, len_k)
-        del h_q
-        del h_k
-
-        score = torch.masked_fill(
-            score,
-            mask.view(batch_size, 1, len_q, len_k)==False,
-            torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype)
-        )   # (batch, num_heads, len_q, len_k)
-        score = self.softmax(score)
-
-        # avoid nan in softmax
-        score = torch.masked_fill(
-            score,
-            mask.view(batch_size, 1, len_q, len_k)==False,
-            torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
-        )
-
-        h_v = self.weight_v(hidden_state)
-        h_v = h_v.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)    # (batch, num_heads, len_k, dim_head)
-
-        score = torch.matmul(score, h_v)    # (batch, num_heads, len_q, dim_head)
-        del h_v
-
-        score = score.permute(0, 2, 1, 3).reshape(batch_size, len_q, self.num_heads * self.dim_head)
-
-        return self.attn_out(score)
-
-
-    def rotary_embedding(self,
-        hidden : torch.Tensor,          # (batch, num_heads, seq_len, dim_head)
-        position : torch.LongTensor     # (batch, seq_len)
-    ):
-        dim = hidden.size(-1)
-
-        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2, device=hidden.device).float() / dim))   # (dim_head/2)
-        inv_freq = inv_freq.half()
-        freqs = torch.einsum('bi,j->bij', position.half(), inv_freq)   # (batch, seq_len, dim_head/2)
-        emb = torch.cat((freqs, freqs), dim=-1)    # (batch, seq_len, dim_head)
-        v_cos = emb.cos()   # (batch, seq_len, dim_head)
-        v_sin = emb.sin()   # (batch, seq_len, dim_head)
-
-        def rotate_half(x):
-            x1, x2 = x[..., :x.size(-1) // 2], x[..., x.size(-1) // 2:]
-            return torch.cat((-x2, x1), dim=x1.ndim - 1)
-
-        return (hidden * v_cos[:, None, :, :]) + (rotate_half(hidden) * v_sin[:, None, :, :])
-
-class FeedForward(torch.nn.Module):
-    def __init__(
-        self,
-        dim_model : int,
-        dim_ff : int,
-    ) -> None:
-        super().__init__()
-
-        self.w_in = torch.nn.Linear(dim_model, dim_ff, bias=True).half()
-        self.w_out = torch.nn.Linear(dim_ff, dim_model, bias=True).half()
-        self.w_gate = torch.nn.Linear(dim_model, dim_ff, bias=True).half()
-        self.activation = torch.nn.GELU()
-
-
-    def forward(self, x : torch.Tensor):
-        x_in = self.activation(self.w_in(x))
-        x_in = x_in * self.w_gate(x)
-        return self.w_out(x_in)
-
-
-class GLMBlock(torch.nn.Module):
-    def __init__(
-        self,
-        dim_model : int,
-        num_heads : int,
-        dim_head : int,
-        dim_ff : int,
-        alpha : float,
-        eps : float,
-    ) -> None:
-        super().__init__()
-
-        self.input_layernorm = torch.nn.LayerNorm(dim_model, eps=eps).half()
-        self.attention = GLMSelfAttention(dim_model, num_heads, dim_head)
-
-        self.post_attention_layernorm = torch.nn.LayerNorm(dim_model, eps=eps).half()
-        self.ff = FeedForward(dim_model, dim_ff)
-        self.alpha = alpha
-
-    def forward(
-        self,
-        hidden_state : torch.Tensor,    # (batch, seq_len, dim_model)
-        mask : torch.BoolTensor,        # (batch, seq_len, seq_len)
-        position : torch.LongTensor
-    ):
-        attn_input = self.input_layernorm(hidden_state)
-        attn_output = self.attention(attn_input, mask, position)
-        hidden_state = attn_input * self.alpha + attn_output
-
-        mlp_input = self.post_attention_layernorm(hidden_state)
-        mlp_output = self.ff(mlp_input)
-
-        return mlp_input * self.alpha + mlp_output
-
-class GLM130B(torch.nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.token_embedding = torch.nn.Embedding(150528, 12288).half()
-
-        self.layers = torch.nn.ModuleList([
-            GLMBlock(
-                12288,
-                96,
-                128,
-                32768,
-                (2 * 70) ** 0.5,
-                1e-5
-            )
-            for _ in range(2)
-        ])
-
-    def forward(self, ids : torch.LongTensor, position : torch.LongTensor, mask : torch.BoolTensor):
-        hidden_state = self.token_embedding(ids)
-        for layer in self.layers:
-            hidden_state = layer(hidden_state, mask, position)
-        return hidden_state