diff --git a/megatron/arguments.py b/megatron/arguments.py index 6297cb16bb..cdbe49f803 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -83,7 +83,7 @@ def validate_args(args): if args.no_pipeline_parallel: assert args.pipeline_model_parallel_size == 1, \ "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled" - + if args.ds_sequence_parallel_size > 1: assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+" @@ -432,6 +432,10 @@ def validate_args(args): assert not args.mos, 'GQA currently does not support args.mos' assert not args.kd, 'GQA currently does not support args.kd' + # entmax loss + if args.loss_function != "cross_entropy": + assert not args.fp16_lm_cross_entropy, "entmax loss only supports fp32" + # Print arguments. _print_args("arguments", args) retro_args = get_retro_args() @@ -477,7 +481,7 @@ def core_transformer_config_from_args(args): kw_args['bias_gelu_fusion'] = False if args.init_method_xavier_uniform: kw_args['init_method'] = torch.nn.init.xavier_uniform_ - kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ + kw_args['output_layer_init_method'] = torch.nn.init.xavier_uniform_ return TransformerConfig(**kw_args) @@ -632,7 +636,7 @@ def _add_network_size_args(parser): group.add_argument('--apply-layernorm-1p', action='store_true', help='Adjust LayerNorm weights such that they are centered ' 'around zero. This improves numerical stability.') - group.add_argument('--disable-mem-efficient-ln', action='store_false', + group.add_argument('--disable-mem-efficient-ln', action='store_false', help='Disable the memory-efficient fused LayerNorm optimization ' 'introduced in https://github.com/NVIDIA/apex/pull/1715') group.add_argument('--apply-residual-connection-post-layernorm', @@ -848,7 +852,7 @@ def _add_training_args(parser): 'training runs.') group.add_argument('--random-ltd', action='store_true', - help='enable random layer token drop') + help='enable random layer token drop') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, @@ -940,6 +944,15 @@ def _add_training_args(parser): dest='gradient_accumulation_fusion') group.add_argument('--use-dataset-only', type=bool, required=False, default=False, help='If set to True, only use the megatron dataset for external trainer ') + group.add_argument('--loss-function', default='cross_entropy', + choices=['cross_entropy', 'entmax15', 'sparsemax', 'entmax_bisect'], + help='Loss function for model training') + group.add_argument('--entmax-alpha', type=float, default=1.5, + help='Entmax alpha for entmax_bisect (unused otherwise)') + group.add_argument('--entmax-topk', type=int, default=512, + help='Top k for computation of exact entmax loss (for entmax15 and sparsemax)') + group.add_argument('--entmax-n-iter', type=int, default=30, + help='Number of bisection interations for entmax_bisect') return parser @@ -1034,7 +1047,7 @@ def _add_checkpointing_args(parser): group.add_argument('--no-load-rng', action='store_true', default=None, help='Do not load rng state when loading checkpoint.') group.add_argument('--no-load-lr-state', action='store_true', - help='Do not load lr state when loading checkpoint.') + help='Do not load lr state when loading checkpoint.') group.add_argument('--finetune', action='store_true', help='Load model for finetuning. 
Do not load optimizer ' 'or rng state from checkpoint and set iteration to 0. ' @@ -1210,7 +1223,7 @@ def _add_data_args(parser): 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') group.add_argument('--multiple-valid-sets', action='store_true', - help='multiple separated validation steps') + help='multiple separated validation steps') group.add_argument('--test-data-path', nargs='*', default=None, help='Path to the test dataset. Accepted format:' '1) a single data path, 2) multiple datasets in the' @@ -1490,15 +1503,15 @@ def _add_activation_checkpoint_args(parser): def _add_distillation_args(parser): group = parser.add_argument_group('Knowledge distillation', 'Distillation Configurations') - + group.add_argument('--num-layers-teacher', type=int, default=None, - help='Number of the teacher transformer layers.') + help='Number of the teacher transformer layers.') group.add_argument('--num-experts-teacher', type=int, nargs='+', default=[1,], help='number of teacher experts list, MoE related.') group.add_argument('--hidden-size-teacher', type=int, default=None, help='Tansformer teacher hidden size.') group.add_argument('--num-attention-heads-teacher', type=int, default=None, - help='Number of teacher transformer attention heads.') + help='Number of teacher transformer attention heads.') group.add_argument('--mos', action='store_true', help='Enable Mixture-of-Students via knolwedge distillation.') @@ -1509,7 +1522,7 @@ def _add_distillation_args(parser): group.add_argument('--kd-temp', default=1.0, type=float) group.add_argument('--reset-iteration', action='store_true', help='Reset the iteration count.') - + group.add_argument('--load-teacher', type=str, default=None, help='Directory containing a teacher model checkpoint.') diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index f2eabe341e..e905825c36 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -2,7 +2,10 @@ """GPT-2 model.""" +from functools import partial + import torch +import entmax from megatron import get_args from megatron.core import mpu, tensor_parallel, sequence_parallel @@ -24,21 +27,22 @@ except ImportError: MixedFusedRMSNorm = None -try: +try: from deepspeed.checkpoint import ( VOCABULARY_PARAMETER_PATTERNS, PIPELINE_REPLICATED_PARAMETER_PATTERNS, TP_REPLICATED_PARAMETER_PATTERNS, PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, ) - DS_UNIVERSAL_CHECKPOINT_INFO = True + DS_UNIVERSAL_CHECKPOINT_INFO = True except ImportError: - DS_UNIVERSAL_CHECKPOINT_INFO = False + DS_UNIVERSAL_CHECKPOINT_INFO = False def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, - fp16_lm_cross_entropy): + fp16_lm_cross_entropy, + loss_function, alpha, topk, n_iter, return_support_size=False): # Output. Format [s b h] output = parallel_lm_logits( @@ -46,10 +50,13 @@ def post_language_model_processing(lm_output, labels, logit_weights, logit_weights, parallel_output) + # should it return a None support size in this case? 
let's say no for now
     if labels is None:
         # [s b h] => [b s h]
         return output.transpose(0,1).contiguous()
-    else:
+
+    if loss_function == "cross_entropy":
+        # cross entropy
         # [b s] => [s b]
         labels = labels.transpose(0,1).contiguous()
         cross_entropy = sequence_parallel.vocab_sequence_parallel_cross_entropy if mpu.get_sequence_parallel_world_size() > 1 \
             else tensor_parallel.vocab_parallel_cross_entropy
@@ -62,6 +69,42 @@ def post_language_model_processing(lm_output, labels, logit_weights,
 
         # [s b] => [b, s]
         loss = loss.transpose(0,1).contiguous()
+        support = None
+    else:
+        # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect"
+        loss_funcs = {
+            "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True),
+            "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True),
+            "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter)
+        }
+        f = loss_funcs[loss_function]
+        b, s = labels.size()
+        output = output.transpose(0, 1).contiguous()
+        vocab_size = output.size(-1)
+
+        # currently entmax_bisect_loss always returns a None support size,
+        # which is not ideal. This is a stopgap until entmax_bisect_loss is
+        # fixed
+        loss = f(output.float().view(-1, vocab_size), labels.view(-1))
+        if isinstance(loss, tuple):
+            loss, support = loss
+        else:
+            support = None
+
+        # old version which breaks because entmax_bisect unexpectedly returned
+        # a tuple:
+        '''
+        if loss_function != "entmax_bisect":
+            loss, support = f(output.float().view(-1, vocab_size), labels.view(-1))
+        else:
+            loss = f(output.float().view(-1, vocab_size), labels.view(-1))
+            support = None
+        '''
+        loss = loss.view(b, s)
+
+    if return_support_size:
+        return loss, support
+    else:
         return loss
 
 
@@ -84,6 +127,10 @@ def __init__(self,
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.return_moe_loss = return_moe_loss
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
+        self.loss_function = args.loss_function
+        self.entmax_alpha = args.entmax_alpha
+        self.entmax_topk = args.entmax_topk
+        self.entmax_n_iter = args.entmax_n_iter
 
         self.language_model, self._language_model_key = get_language_model(
             config=config,
@@ -106,7 +153,8 @@ def forward(self, input_ids, position_ids, attention_mask,
                 retriever_position_ids=None,
                 retriever_attn_mask=None,
                 labels=None, tokentype_ids=None, inference_params=None,
-                curriculum_seqlen=None):
+                curriculum_seqlen=None,
+                return_support_size=False):
         args = get_args()
         if curriculum_seqlen is not None:
             args.curriculum_seqlen = curriculum_seqlen
@@ -135,13 +183,25 @@ def forward(self, input_ids, position_ids, attention_mask,
             inference_params=inference_params)
 
         if self.post_process:
-            lm_output = post_language_model_processing(
+            post_lm_out = post_language_model_processing(
                 lm_output, labels,
                 self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
                 self.parallel_output,
-                self.fp16_lm_cross_entropy)
-
-            return lm_output, moe_losses if self.return_moe_loss else lm_output
+                self.fp16_lm_cross_entropy,
+                self.loss_function,
+                self.entmax_alpha,
+                self.entmax_topk,
+                self.entmax_n_iter,
+                return_support_size=True)
+            if labels is not None:
+                lm_output, support_size = post_lm_out
+            else:
+                lm_output = post_lm_out
+            # now...what to do about support_size?
+ if return_support_size: + return lm_output, moe_losses if self.return_moe_loss else lm_output, support_size + else: + return lm_output, moe_losses if self.return_moe_loss else lm_output def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): @@ -210,7 +270,7 @@ def universal_checkpoint_info(self): ] return info - + def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 1ff671f120..5129129923 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -211,7 +211,7 @@ def get_batch_pipe(data): return (tokens, position_ids, attention_mask), (labels, loss_mask) -def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): +def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): args = get_args() losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() @@ -229,11 +229,68 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): print_rank_0('>>> total loss: {}, lm loss {}, kd loss {}'.format(loss, averaged_loss[0], mos_loss)) else: if max(args.num_experts) <= 1: - return loss, {'lm loss': averaged_loss[0]} + if support_size is not None: + support_size = support_size.float() + n_tokens = support_size.size(0) + + # how often is support size 1? + fully_peaked = support_size.eq(1).sum() + + # what is the max support size in the batch? + max_support = support_size.max() + min_support = support_size.min() + + # sum support stats across groups + current_group = mpu.get_data_parallel_group() + torch.distributed.all_reduce(support_size, group=current_group) + torch.distributed.all_reduce(max_support, group=current_group) + torch.distributed.all_reduce(min_support, group=current_group) + torch.distributed.all_reduce(fully_peaked, group=current_group) + + # find number of groups + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group() + ) + + # compute mean support size + support_size = support_size / world_size + + # compute mean max support + max_support = max_support / world_size + min_support = min_support / world_size + + # compute how often support is fully peaked + fully_peaked = fully_peaked / (world_size * n_tokens) + + # the stats for support size might be slightly wrong, but still + # illustrative + support_mean = support_size.mean() + support_std = support_size.std() + quantiles = torch.linspace( + 0, 1, 5, dtype=support_size.dtype, device=support_size.device + ) + support_quantiles = torch.quantile(support_size, quantiles) + + loss_dict = { + 'lm loss': averaged_loss[0], + 'support size': support_mean, + "support std": support_std, + "support max": max_support, + "support min": min_support, + "support 25%": support_quantiles[1], + "support 50%": support_quantiles[2], + "support 75%": support_quantiles[3], + "fully peaked": fully_peaked + } + + return loss, loss_dict + else: + return loss, {'lm loss': averaged_loss[0]} else: loss = loss + moe_loss return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + def calculate_mos_loss(args, stu_output, teacher_model, tokens, position_ids, attention_mask): mos_loss = 0 alpha = args.kd_alpha_ce @@ -277,6 +334,7 @@ def forward_step(data_iterator, model): args.data_efficiency_curriculum_learning_seqlen_type == 'seqlen_reshape': args.data_efficiency_curriculum_learning_numel = torch.numel(tokens) + support_size = None if args.mos or args.kd: # The forward func can return either the loss or the logits, depending on whether passing in the labels or not. 
         stu_output, other_losses = model(tokens, position_ids, attention_mask)
@@ -285,8 +343,9 @@
             labels = labels[:, :args.curriculum_seqlen].contiguous()
         output_tensor = tensor_parallel.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels)
     else:
-        output_tensor, other_losses = model(tokens, position_ids, attention_mask,
-                                            labels=labels)
+        output_tensor, other_losses, support_size = model(
+            tokens, position_ids, attention_mask, labels=labels, return_support_size=True
+        )
     if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length:
         loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
 
@@ -304,7 +363,7 @@
                                       args.teacher_model[0], tokens, position_ids, attention_mask)
 
     # Output_tensor stores the standard loss, loos_func calculates the total loss.
-    return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss)
+    return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss, support_size)
 
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
diff --git a/tasks/main.py b/tasks/main.py
index 9bc38f5fd2..913e519b0f 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -70,7 +70,11 @@ def get_tasks_args(parser):
     group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                        help='Av.rank validation: how many other negatives to'
                        ' take from each question pool')
-
+    group.add_argument('--eval-metric', default=None,
+                       help='Eval metric to use other than a task-specific '
+                       'default')
+    group.add_argument('--acc-k', default=5, type=int,
+                       help='k for force-decoded accuracy at k')
     return parser
 
 
@@ -79,7 +83,7 @@ def get_tasks_args(parser):
 
     initialize_megatron(extra_args_provider=get_tasks_args)
 
-    args = get_args()
+    args = get_args() # will the task args be included here?
if args.num_layers_per_virtual_pipeline_stage is not None: print("Interleaved pipeline schedule is not yet supported for downstream tasks.") diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index a9e27fc49c..6d33c4c20c 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -3,9 +3,13 @@ """GPT zero-shot evaluation.""" import math +from functools import partial +from collections import defaultdict import torch +import entmax + from megatron import get_args from megatron import print_rank_0, is_last_rank from megatron import get_tokenizer @@ -44,7 +48,7 @@ def model_provider(pre_process=True, post_process=True): config = core_transformer_config_from_args(get_args()) - if eval_metric == 'loss': + if eval_metric in {"loss", "force_decoded_accuracy", "force_decoded_accuracy_at_k", "sparsemax_score"}: parallel_output = True elif eval_metric == 'accuracy': parallel_output = False @@ -53,7 +57,7 @@ def model_provider(pre_process=True, post_process=True): 'is not supported.'.format(eval_metric)) print_rank_0('building GPT model ...') - + args = get_args() config = core_transformer_config_from_args(args) if args.deepspeed: @@ -62,7 +66,7 @@ def model_provider(pre_process=True, post_process=True): config_dict_or_path=args.deepspeed_config, enabled=args.zero_stage == 3, mpu=mpu): - + model = GPTModel( config=config, num_tokentypes=0, @@ -73,7 +77,7 @@ def model_provider(pre_process=True, post_process=True): else: model = GPTModel(config=config, num_tokentypes=0, parallel_output=parallel_output, pre_process=pre_process, post_process=post_process) - + return model @@ -100,9 +104,189 @@ def process_batch(batch): return tokens, labels, attention_mask, position_ids, loss_mask +""" +if loss_function == "cross_entropy": + # cross entropy + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + cross_entropy = sequence_parallel.vocab_sequence_parallel_cross_entropy if mpu.get_sequence_parallel_world_size() > 1 \ + else tensor_parallel.vocab_parallel_cross_entropy + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = cross_entropy(output, labels) + else: + loss = cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + support = None +else: + # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" + loss_funcs = { + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) + } + f = loss_funcs[loss_function] + b, s = labels.size() + output = output.transpose(0, 1).contiguous() + vocab_size = output.size(-1) + if loss_function != "entmax_bisect": + loss, support = f(output.float().view(-1, vocab_size), labels.view(-1)) + else: + loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + support = None + loss = loss.view(b, s) + +""" + + +def _compute_loss(output, labels, loss_mask, loss_function="cross_entropy", topk=512, alpha=1.5, n_iter=30): + """ + Dimensions are confusing but I think I've figured it out. Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). 
+ + Based on the documentation of tensor_parallel.vocab_parallel_cross_entropy, + we can expect output (or rather, output[0], which should be the decoder + output based on TransformerLanguageModel.forward) to be [s b V] and labels + to be [s b]. + """ + print("output type before _compute_loss", type(output)) + if isinstance(output, torch.Tensor): + print("size before indexing", output.size()) + output = output[0] # based on how loss was previously computed + print("size after indexing", output.size()) + + if loss_function == "cross_entropy": + # I believe (based on the commented-out block above) that this + # function takes [s b] as its input. + # But I'm not certain + losses = tensor_parallel.vocab_parallel_cross_entropy( + output.contiguous().float(), labels.contiguous()) + else: + # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" + loss_funcs = { + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) + } + f = loss_funcs[loss_function] + print("labels size: ", labels.size()) + print("output size", output.size()) + vocab_size = output[0].size(-1) + if loss_function != "entmax_bisect": + losses, _ = f(output.float().view(-1, vocab_size), labels.view(-1)) + else: + losses = f(output.float().view(-1, vocab_size), labels.view(-1)) + # losses = losses.view(b, s) + + loss = torch.sum(losses.view(-1) * loss_mask.contiguous().view(-1).float()) + + return loss + + +def _force_decoded_accuracy(output, labels, loss_mask): + """ + This is different from LAMBADA accuracy, which is only about getting the + final word right based on a long context. + + Dimensions are confusing but I think I've figured it out. Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). + """ + + # the raw output is a tuple (same as for eval_metric=="loss") + output = output[0] + + predictions = output.argmax(dim=-1).view(-1) + correct = predictions.eq(labels.view(-1)).float() + + correct_sum = torch.sum(correct * loss_mask.contiguous().view(-1).float()) + return correct_sum + + +def _force_decoded_accuracy_at_k(output, labels, loss_mask, k): + """ + Accuracy at k -- do any of the top k outputs match? + + This is different from LAMBADA accuracy, which is only about getting the + final word right based on a long context. + + Dimensions are confusing but I think I've figured it out. Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). 
+ """ + + # the raw output is a tuple (same as for eval_metric=="loss") + output = output[0] + + _, predictions = torch.topk(output, k, dim=-1) + predictions = predictions.view(-1, k) + + correct = predictions.eq(labels.view(-1, 1)).any(dim=-1).float() + + correct_sum = torch.sum(correct * loss_mask.contiguous().view(-1).float()) + return correct_sum + + +def _gini_entropy(probs): + return (probs * (1 - probs)).sum(dim=-1) / 2 + + +def _sparsemax_score(output, labels, loss_mask, loss_function="cross_entropy", topk=512, alpha=1.5, n_iter=30): + # loss_function is really the generator function here + + if isinstance(output, torch.Tensor): + print("size before indexing", output.size()) + output = output[0] # based on how loss was previously computed + vocab_size = output.size(-1) + + output = output.view(-1, vocab_size) + labels = labels.view(-1) + loss_mask = loss_mask.contiguous().view(-1).float() + + # you can get the accuracy almost for free, so you might as well + predictions = output.argmax(dim=-1) + correct = predictions.eq(labels).float() + correct_sum = torch.sum(correct * loss_mask) + + gen_funcs = { + "cross_entropy": torch.softmax, + "entmax15": partial(entmax.entmax15, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect, alpha=alpha, n_iter=n_iter) + } + + f = gen_funcs[loss_function] + + if loss_function not in {"cross_entropy", "entmax_bisect"}: + probs, support_size = f(output.float(), dim=-1) + else: + probs = f(output.float(), dim=-1) + support_size = None + + # now...p_theta(x) + # sp = p_theta(x) + H_2(p_theta) + + gold_probs = probs.gather(1, labels.unsqueeze(1)).view(-1) + entropy = _gini_entropy(probs) + sp = ((gold_probs + entropy) * loss_mask).sum() + + return sp, correct_sum + def forward_step(batch, model, eval_metric): """Forward step.""" + # TODO: return dict + eval_metrics = {"loss", "accuracy", "force_decoded_accuracy", + "force_decoded_accuracy_at_k", "sparsemax_score"} + if eval_metric not in eval_metrics: + raise NotImplementedError('forward method for evaluation metric {} ' + 'is not implemented.'.format(eval_metric)) # Get the batch. tokens, labels, attention_mask, position_ids, loss_mask = process_batch( @@ -125,23 +309,54 @@ def forward_step(batch, model, eval_metric): if parallel_state.is_pipeline_last_stage(): # For loss, return the unreduced loss. 
+ + scores = dict() + if eval_metric == 'loss': + ''' losses = tensor_parallel.vocab_parallel_cross_entropy( output[0].contiguous().float(), labels.contiguous()) loss = torch.sum( losses.view(-1) * loss_mask.contiguous().view(-1).float()) return loss + ''' + + loss = _compute_loss( + output, labels, loss_mask, + loss_function=args.loss_function, topk=args.entmax_topk, n_iter=args.entmax_n_iter, alpha=args.entmax_alpha + ) + scores["loss"] = loss + + if eval_metric == "force_decoded_accuracy": + correct_sum = _force_decoded_accuracy(output, labels, loss_mask) + scores["force_decoded_accuracy"] = correct_sum + + if eval_metric == "force_decoded_accuracy_at_k": + correct_sum = _force_decoded_accuracy_at_k(output, labels, loss_mask, args.acc_k) + scores["force_decoded_accuracy_at_k"] = correct_sum + + if eval_metric == "sparsemax_score": + sp_sum, correct_sum = _sparsemax_score( + output, labels, loss_mask, + loss_function=args.loss_function, topk=args.entmax_topk, n_iter=args.entmax_n_iter, alpha=args.entmax_alpha + ) + scores["sparsemax_score"] = sp_sum + scores["force_decoded_accuracy"] = correct_sum + # currently computes the accuracy but doesn't return it. annoying. + # return {"sparsemax_score": sp_sum, "force_decoded_accuracy": correct_sum} # For accuracy, return the number of correctly predicted samples. if eval_metric == 'accuracy': + if isinstance(output, tuple): + output = output[0] # not sure why this was necessary outputs = torch.argmax(output, -1) correct = (outputs == labels).float() correct[(1 - loss_mask).bool()] = 1 correct = correct.prod(-1) - return correct.sum() + scores["accuracy"] = correct.sum() + + return scores - raise NotImplementedError('forward method for evaluation metric {} ' - 'is not implemented.'.format(eval_metric)) return None @@ -152,21 +367,24 @@ def evaluate(data_loader, model, eval_metric): # Turn on evaluation mode which disables dropout. model.eval() - total_output = 0.0 + total_output = defaultdict(float) + with torch.no_grad(): # For all the batches in the dataset. for iteration, batch in enumerate(data_loader): if iteration % args.log_interval == 0: print_rank_0('> working on iteration: {}'.format(iteration)) # Forward evaluation. - output = forward_step(batch, model, eval_metric) + output_dict = forward_step(batch, model, eval_metric) # problem if this doesn't return a tensor # Reduce across processes. if parallel_state.is_pipeline_last_stage(): - torch.distributed.all_reduce(output, - group=parallel_state.get_data_parallel_group()) - - total_output += output + for metric_name, output in output_dict.items(): + torch.distributed.all_reduce( + output, + group=parallel_state.get_data_parallel_group() + ) + total_output[metric_name] += output return total_output @@ -175,42 +393,61 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): """Evaluate and print results on screen.""" # Evaluate and get results. 
-    output = evaluate(data_loader, model, eval_metric)
+    output_dict = evaluate(data_loader, model, eval_metric) # this is a dict
 
     string = ' validation results on {} | '.format(task)
     if is_last_rank():
-        if eval_metric == 'loss':
-            num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
-            num_original_tokens = data_loader.dataset.num_original_tokens
-            val_loss = output / (num_tokenized_tokens - 1)
-            ppl = math.exp(min(20, val_loss))
-            token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
-            adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
-            string += 'avg loss: {:.4E} | '.format(val_loss)
-            string += 'ppl: {:.4E} | '.format(ppl)
-            string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
-            string += 'token ratio: {} |'.format(token_ratio)
-
-            results = {
-                "loss": val_loss.item(),
-                "ppl": ppl,
-                "ajusted_ppl": adjusted_ppl,
-                "token_ratio": token_ratio
-            }
-
-            with open('./eval_results', 'w') as json_file:
-                json.dump(results, json_file)
-
-        elif eval_metric == 'accuracy':
-            num_examples = len(data_loader.dataset)
-            acc = output / num_examples
-            string += 'number correct: {:.4E} | '.format(output)
-            string += 'total examples: {:.4E} | '.format(num_examples)
-            string += 'avg accuracy: {:.4E}'.format(acc)
-
-        else:
-            raise NotImplementedError('evaluation method for {} metric is not '
-                                      'implemented yet.'.format(eval_metric))
+        results = dict()
+
+        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
+        num_original_tokens = data_loader.dataset.num_original_tokens
+        num_examples = len(data_loader.dataset)
+
+        results["n_tokens"] = num_tokenized_tokens
+
+        for eval_metric, output in output_dict.items():
+            if eval_metric == 'loss':
+
+                val_loss = output / (num_tokenized_tokens - 1)
+                ppl = math.exp(min(20, val_loss))
+                token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
+                adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
+                string += 'avg loss: {:.4E} | '.format(val_loss)
+                string += 'ppl: {:.4E} | '.format(ppl)
+                string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
+                string += 'token ratio: {} |'.format(token_ratio)
+
+                results["loss"] = val_loss.item()
+                results["ppl"] = ppl
+                results["adjusted_ppl"] = adjusted_ppl
+                results["token_ratio"] = token_ratio
+
+            elif eval_metric == 'accuracy':
+                # remember this is Lambada accuracy
+                acc = output / num_examples
+                string += 'number correct: {:.4E} | '.format(output)
+                string += 'total examples: {:.4E} | '.format(num_examples)
+                string += 'avg accuracy: {:.4E}'.format(acc)
+                results["accuracy"] = acc.item()
+
+            elif eval_metric == "force_decoded_accuracy" or eval_metric == "force_decoded_accuracy_at_k":
+                acc = output / (num_tokenized_tokens - 1)
+                string += 'number correct: {:.4E} | '.format(output)
+                string += 'total tokens: {:.4E} | '.format(num_tokenized_tokens)
+                string += 'avg accuracy: {:.4E}'.format(acc)
+
+                results["accuracy"] = acc.item()
+                results["n_correct"] = output.item()
+            elif eval_metric == "sparsemax_score":
+                avg_sparsemax_score = output / (num_tokenized_tokens - 1)
+                string += 'sparsemax score: {:.4E} | '.format(avg_sparsemax_score)
+                results["sparsemax_score"] = avg_sparsemax_score.item()
+            else:
+                raise NotImplementedError('evaluation method for {} metric is not '
+                                          'implemented yet.'.format(eval_metric))
+
+        with open('./eval_results', 'w') as json_file:
+            json.dump(results, json_file)
 
         length = len(string) + 1
         print('-' * length)
@@ -225,7 +462,9 @@ def main():
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
 
-    if args.task == 'LAMBADA':
+    if args.eval_metric is not None:
+        eval_metric = args.eval_metric
+    elif args.task == 'LAMBADA':
         eval_metric = 'accuracy'
     elif args.task == 'WIKITEXT103':
         eval_metric = 'loss'
@@ -248,9 +487,12 @@ def main():
         mpu=mpu if args.no_pipeline_parallel else None
     )
     model = [model]
-    
+
     if args.load is not None:
-        _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration)
+        if args.task == "LAMBADA":
+            _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration, load_only_weights=True)
+        else:
+            _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration)
     assert len(model) == 1, "Above condition should have caught this"
     model = model[0]
 
@@ -262,7 +504,6 @@ def main():
 
     # Run evaluation.
     evaluate_and_print_results(args.task, dataloader, model, eval_metric)
-
-    print_rank_0('done :-)')
+    print_rank_0('done :-)')
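
A note on the sparsemax score computed by _sparsemax_score in tasks/zeroshot_gpt/evaluate.py: per token it is sp = p_theta(gold) + H_2(p_theta), where H_2 is the Gini entropy, so it lies in [0, 1] and equals 1 only when all probability mass sits on the gold token. A minimal standalone sketch (not part of the patch), assuming only the deep-spin entmax package's functional sparsemax; the fork-specific k / return_support_size arguments and the loss-mask weighting used in the patch are omitted:

import torch
from entmax import sparsemax

def gini_entropy(probs):
    # H_2(p) = 0.5 * sum_i p_i * (1 - p_i); zero for a one-hot distribution
    return (probs * (1 - probs)).sum(dim=-1) / 2

def sparsemax_score(logits, gold):
    # logits: [n_tokens, vocab_size], gold: [n_tokens] gold token ids
    probs = sparsemax(logits.float(), dim=-1)      # sparse probability vectors
    gold_probs = probs.gather(1, gold.unsqueeze(1)).squeeze(1)
    return gold_probs + gini_entropy(probs)        # per-token scores in [0, 1]

logits = torch.randn(4, 10)
gold = torch.randint(0, 10, (4,))
print(sparsemax_score(logits, gold))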