Skip to content

Commit

Permalink
fix sequence packing context truncation (#856)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Jan 15, 2025
1 parent ab7dc65 commit 5581062
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
11 changes: 11 additions & 0 deletions src/levanter/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ class PromptCompletion:
prompt_length: int
segment_id: int | None = None

def __post_init__(self):
if len(self.ids) == 0:
raise ValueError("PromptCompletion must have at least one token")

# check that there is at least one token in the response
if len(self.ids) <= self.prompt_length:
raise ValueError(
f"PromptCompletion must have strictly more tokens than the prompt length. Got {len(self.ids)} tokens"
f" and prompt length {self.prompt_length}"
)


def pack_prompt_completions(
Pos: hax.Axis,
Expand Down
10 changes: 5 additions & 5 deletions src/levanter/eval_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,20 +766,20 @@ def _iterate_tokenized_requests(
for off in range(len(batch_indices)):
i = batch_indices[off]
context_enc = context_encodings["input_ids"][off]
whole_ids = combined_encodings["input_ids"][off]
all_enc = combined_encodings["input_ids"][off]

context_enc_len = len(context_enc)

if len(whole_ids) > max_len:
if len(all_enc) > max_len:
logger.warning(f"Request {i} is too long. Truncating.")
# Truncate from the left
whole_ids = whole_ids[-max_len:]
context_enc_len = max_len - len(whole_ids) + context_enc_len
context_enc_len = len(context_enc) - (len(all_enc) - max_len)
all_enc = all_enc[-max_len:]
if context_enc_len < 0:
context_enc_len = 0
logger.warning("Prompt length is negative after truncation. Setting to 0.")

yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i)
yield PromptCompletion(ids=all_enc, prompt_length=context_enc_len, segment_id=i)


def _pack_requests(
Expand Down

0 comments on commit 5581062

Please sign in to comment.