Commit

Fix bug in multi-token Stop Sequences (EleutherAI#1268)
* fix incorrect lookback protections

* bump generate_until task versions
haileyschoelkopf authored Jan 11, 2024
1 parent 818c056 · commit ff73941
Showing 42 changed files with 49 additions and 42 deletions.
2 changes: 1 addition & 1 deletion lm_eval/models/huggingface.py
@@ -696,7 +696,7 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
         generation_kwargs["do_sample"] = False
         # build stopping criteria
         stopping_criteria = stop_sequences_criteria(
-            self.tokenizer, stop, 1, context.shape[0]
+            self.tokenizer, stop, context.shape[1], context.shape[0]
         )
         return self.model.generate(
             input_ids=context,
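The one-line change above is the substance of the fix: stop_sequences_criteria now receives the true prompt length (context.shape[1]) as its third argument instead of a hard-coded 1, so the resulting stopping criteria only inspect tokens generated after the prompt. With the old value of 1, the lookback window could still include the tail of the prompt, and prompts that end with a stop string could halt generation immediately. A minimal usage sketch follows; the gpt2 checkpoint, the prompt, and the import location (assumed to be lm_eval/utils.py at this commit, alongside MultiTokenEOSCriteria) are illustrative assumptions, not part of this commit.

# Hedged sketch of the corrected call, outside the harness (not the repo's own code).
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.utils import stop_sequences_criteria  # assumed module for this helper at this commit

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# context has shape (batch_size, prompt_length)
context = tokenizer("Question: 2 + 2 = ?\nAnswer:", return_tensors="pt").input_ids

stopping_criteria = stop_sequences_criteria(
    tokenizer,
    ["\n\n"],           # a stop string that may be tokenized as several tokens
    context.shape[1],   # prompt length; was hard-coded to 1, defeating the lookback protection
    context.shape[0],   # batch size: one done-tracker slot per row
)

output = model.generate(
    input_ids=context,
    max_new_tokens=64,
    do_sample=False,
    stopping_criteria=stopping_criteria,
)
print(tokenizer.decode(output[0][context.shape[1]:]))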
2 changes: 1 addition & 1 deletion lm_eval/tasks/babi/babi.yaml
@@ -17,4 +17,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/bbh/cot_fewshot/_cot_fewshot_template_yaml
@@ -27,4 +27,4 @@ filter_list:
       - function: "take_first"
 num_fewshot: 0
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/bbh/cot_zeroshot/_cot_zeroshot_template_yaml
@@ -24,4 +24,4 @@ filter_list:
       - function: "take_first"
 num_fewshot: 0
 metadata:
-  version: 0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/bbh/fewshot/_fewshot_template_yaml
@@ -18,4 +18,4 @@ generation_kwargs:
   temperature: 0.0
 num_fewshot: 0
 metadata:
-  version: 0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/bbh/zeroshot/_zeroshot_template_yaml
@@ -18,4 +18,4 @@ generation_kwargs:
   temperature: 0.0
 num_fewshot: 0
 metadata:
-  version: 0
+  version: 1.0
@@ -17,3 +17,5 @@ filter_list:
       - function: "regex"
         regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
       - function: "take_first"
+metadata:
+  version: 1.0

@@ -9,3 +9,5 @@ generation_kwargs:
     - "</s>"
   do_sample: false
   temperature: 0.0
+metadata:
+  version: 1.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/bigbench/generate_until_template_yaml
@@ -15,4 +15,4 @@ metric_list:
     higher_is_better: true
     ignore_punctuation: true
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/go.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/java.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/javascript.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/php.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/python.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/ruby.yaml
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/coqa/default.yaml
@@ -19,4 +19,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 2.0
+  version: 3.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/drop/default.yaml
@@ -21,4 +21,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 2.0
+  version: 3.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/fld/fld_default.yaml
@@ -18,4 +18,4 @@ filter_list:
       - function: remove_whitespace
       - function: take_first
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml
@@ -31,4 +31,4 @@ filter_list:
       - function: "majority_vote"
       - function: "take_first"
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/gsm8k/gsm8k-cot.yaml
@@ -41,4 +41,4 @@ filter_list:
         regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)."
       - function: "take_first"
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/gsm8k/gsm8k.yaml
@@ -34,4 +34,4 @@ filter_list:
         regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
       - function: "take_first"
 metadata:
-  version: 1.0
+  version: 2.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/ifeval/ifeval.yaml
@@ -26,4 +26,4 @@ metric_list:
     aggregation: !function utils.agg_inst_level_acc
     higher_is_better: true
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/mgsm/direct/direct_yaml
@@ -26,4 +26,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/mgsm/en_cot/cot_yaml
@@ -28,4 +28,4 @@ filter_list:
         regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
       - function: "take_first"
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/mgsm/native_cot/cot_yaml
@@ -28,4 +28,4 @@ filter_list:
         regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
      - function: "take_first"
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/minerva_math/minerva_math_algebra.yaml
@@ -21,4 +21,4 @@ metric_list:
     higher_is_better: true
 num_fewshot: 0
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/nq_open/nq_open.yaml
@@ -29,4 +29,4 @@ metric_list:
     regexes_to_ignore:
      - "\ban|a|the\b"
 metadata:
-  version: 0.0
+  version: 1.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/polemo2/polemo2_in.yaml
@@ -42,4 +42,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/qasper/freeform.yaml
@@ -15,4 +15,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 1.0
+  version: 2.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/scrolls/task.py
@@ -108,7 +108,7 @@ def _num_cpu_cores():


 class _SCROLLSTask(Task):
-    VERSION = 1
+    VERSION = 2
     DATASET_PATH = "tau/scrolls"
     DATASET_NAME = None
     PRUNE_TOKENIZERS = None

2 changes: 1 addition & 1 deletion lm_eval/tasks/squadv2/task.py
@@ -49,7 +49,7 @@ def _squad_agg(key, items):

 @register_task("squadv2")
 class SQuAD2(Task):
-    VERSION = 2
+    VERSION = 3
     DATASET_PATH = "squad_v2"
     DATASET_NAME = None

2 changes: 1 addition & 1 deletion lm_eval/tasks/translation/wmt_common_yaml
@@ -14,4 +14,4 @@ generation_kwargs:
   temperature: 0.0
 repeats: 1
 metadata:
-  version: 0.0
+  version: 1.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/triviaqa/default.yaml
@@ -28,4 +28,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  version: 2.0
+  version: 3.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/truthfulqa/truthfulqa_gen.yaml
@@ -76,4 +76,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 2.0
+  version: 3.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/unscramble/anagrams1.yaml
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/unscramble/anagrams2.yaml
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/unscramble/cycle_letters.yaml
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/unscramble/random_insertion.yaml
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/unscramble/reversed_words.yaml
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  version: 1.0
+  version: 2.0

2 changes: 1 addition & 1 deletion lm_eval/tasks/webqs/webqs.yaml
@@ -17,4 +17,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 1.0
+  version: 2.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/wmt2016/ro_en-t5_prompt.yaml
@@ -16,4 +16,4 @@ metric_list:
     aggregation: !function metrics.agg_bleu
     higher_is_better: true
 metadata:
-  version: 0.0
+  version: 1.0
9 changes: 6 additions & 3 deletions lm_eval/utils.py
@@ -636,23 +636,26 @@ def __init__(
         self.done_tracker = [False] * batch_size
         self.sequence = sequence
         self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
         # print(sequence, self.sequence_ids)
         # we look back for 2 more tokens than it takes to encode our stop sequence
         # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
         # and we don't want to mistakenly not stop a generation because our
         # (string) stop sequence was output in a different tokenization

+        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
+        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
+        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
         self.sequence_id_len = len(self.sequence_ids) + 2
         self.tokenizer = tokenizer

     def __call__(self, input_ids, scores, **kwargs) -> bool:
         # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
+
+        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

         lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

         for i, done in enumerate(self.done_tracker):
             if not done:
                 self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
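Read together with the huggingface.py change, the lookback now behaves as the comments describe: first drop everything up to initial_decoder_input_length so the prompt is never inspected, then decode only the last len(sequence_ids) + 2 generated tokens and look for the stop string in that window. Below is a small self-contained sketch of that two-step check on toy data; the tokenizer choice, the prompt, and the variable names are illustrative, not taken from the repository.

# Hedged sketch of the lookback check after the fix, using plain lists instead of tensors.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works for the illustration

stop = "\n\n"
stop_ids = tokenizer.encode(stop, add_special_tokens=False)
lookback_len = len(stop_ids) + 2  # two extra tokens of slack for alternative tokenizations

prompt = "Question: 2 + 2 = ?\nAnswer:"
generated = " 4\n\nQuestion:"  # pretend the model has produced this continuation so far

prompt_ids = tokenizer.encode(prompt)
gen_ids = tokenizer.encode(generated, add_special_tokens=False)
all_ids = prompt_ids + gen_ids  # what one row of input_ids holds during generate()

# Step 1: never look back into the prompt (this is what initial_decoder_input_length protects).
lookback_ids = all_ids[len(prompt_ids):]
# Step 2: decode only the last few generated tokens and string-match the stop sequence.
window = tokenizer.decode(lookback_ids[-lookback_len:])
print(stop in window)  # True -> this row is done generating

Had step 1 effectively used a length of 1, as the HF backend did before this commit, the window could still contain the tail of the prompt, so a prompt ending in the stop string would end generation before the model produced anything.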
