[REVIEW] Fix Padding Related Bugs: Crossfit #66

Merged

Changes from 12 commits
1 change: 1 addition & 0 deletions crossfit/__init__.py
@@ -84,6 +84,7 @@ def __call__(self, *args, **kwargs):
 load_dataset = LazyLoader("crossfit.dataset.load.load_dataset")
 embed = LazyLoader("crossfit.report.beir.embed.embed")
 beir_report = LazyLoader("crossfit.report.beir.report.beir_report")
+utils = LazyLoader("crossfit.utils")

 __all__.extend(
     [
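Note: crossfit.utils is exposed through the same LazyLoader mechanism as the neighboring entries, deferring the import until first use. A minimal sketch of such a lazy-import wrapper, assuming it resolves the dotted path on demand (the actual crossfit LazyLoader implementation is not shown in this diff and may differ):

import importlib


class LazyLoader:
    """Defer importing a dotted path until it is first used (sketch only)."""

    def __init__(self, dotted_path: str):
        self._dotted_path = dotted_path
        self._target = None

    def _resolve(self):
        if self._target is None:
            module_path, _, attr = self._dotted_path.rpartition(".")
            try:
                # "crossfit.utils" resolves to a module...
                self._target = importlib.import_module(self._dotted_path)
            except ModuleNotFoundError:
                # ...while "pkg.module.func" resolves to an attribute of a module.
                self._target = getattr(importlib.import_module(module_path), attr)
        return self._target

    def __getattr__(self, name):
        return getattr(self._resolve(), name)

    def __call__(self, *args, **kwargs):
        return self._resolve()(*args, **kwargs)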
30 changes: 26 additions & 4 deletions crossfit/backend/torch/hf/model.py
@@ -29,8 +29,23 @@


 class HFModel(Model):
-    def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False):
+    def __init__(
+        self,
+        path_or_name: str,
+        max_mem_gb: int = 16,
+        training: bool = False,
+        start_batch_size: int = 1,
+        end_batch_size: int = 2048,
+        batch_size_increment: int = 256,
+        start_seq_len: int = 1,
+        seq_len_increment: int = 64,
+    ):
         super().__init__(path_or_name, max_mem_gb)
+        self.start_batch_size = start_batch_size
+        self.end_batch_size = end_batch_size
+        self.batch_size_increment = batch_size_increment
+        self.start_seq_len = start_seq_len
+        self.seq_len_increment = seq_len_increment

         if not training:
             with torch.no_grad():
@@ -81,11 +96,13 @@ def fit_memory_estimate_curve(self, model=None):
         y = []

         max_seq = self.max_seq_length()
-        for batch_size in tqdm(range(2048, 0, -256)):
+        for batch_size in tqdm(
+            range(self.end_batch_size, self.start_batch_size, -self.batch_size_increment)
+        ):
             if batch_size <= 0:
                 continue

-            for seq_len in range(max_seq, 0, -64):
+            for seq_len in range(max_seq, self.start_seq_len, -self.seq_len_increment):
                 if seq_len <= 0:
                     continue

@@ -130,7 +147,12 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
         return predicted_memory[0] / 1024  # Convert from MB to GB

     def max_seq_length(self) -> int:
-        return self.load_cfg().max_position_embeddings
+        max_seq_length = self.load_tokenizer().model_max_length
+        # Guard against the HF bug
+        # which sets max_seq_length to max(int) for some models
+        if max_seq_length > 1e5:
+            max_seq_length = AutoConfig.from_pretrained(self.path_or_name).max_position_embeddings
+        return max_seq_length


 class SentenceTransformerModel(HFModel):
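Reviewer note on the max_seq_length change above: tokenizer.model_max_length is unreliable for some Hugging Face models, which report a huge sentinel value (on the order of 1e30) instead of the real limit, so the new code falls back to the config's max_position_embeddings in that case. The same guard as a standalone sketch (uses transformers' AutoTokenizer/AutoConfig; the model name is only an example):

from transformers import AutoConfig, AutoTokenizer


def resolve_max_seq_length(path_or_name: str) -> int:
    # model_max_length is sometimes a sentinel like 1000000000000000019884624838656
    # rather than the model's true limit.
    max_seq_len = AutoTokenizer.from_pretrained(path_or_name).model_max_length
    if max_seq_len > 1e5:
        # Fall back to the architecture limit from the model config.
        max_seq_len = AutoConfig.from_pretrained(path_or_name).max_position_embeddings
    return max_seq_len


print(resolve_max_seq_length("sentence-transformers/all-MiniLM-L6-v2"))  # e.g. 512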
46 changes: 38 additions & 8 deletions crossfit/backend/torch/loader.py
@@ -1,4 +1,4 @@
-# Copyright 2023 NVIDIA CORPORATION
+# Copyright 2024 NVIDIA CORPORATION
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
 from crossfit.data.array.conversion import convert_array
 from crossfit.data.array.dispatch import crossarray
 from crossfit.data.dataframe.dispatch import CrossFrame
+from crossfit.op.tokenize import clip_tokens
 from crossfit.utils.model_adapter import adapt_model_input

 DEFAULT_BATCH_SIZE = 512
@@ -36,7 +37,14 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar=
     def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None):
         ...

-    def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
+    def __init__(
+        self,
+        data,
+        batch_size: int,
+        progress_bar=None,
+        max_seq_len=None,
+        padding_side: str = "right",
+    ):
         self.data = CrossFrame(data).cast(torch.Tensor)
         self.tensor_dict = self.data.to_dict()
         self._batch_size = batch_size
@@ -45,6 +53,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
         self._to_map = []
         self.progress_bar = progress_bar
         self.max_seq_len = max_seq_len
+        self.padding_side = padding_side

     def map(self, fn):
         self._to_map.append(fn)
@@ -66,7 +75,11 @@ def __next__(self):

         batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()}
         if self.max_seq_len is not None:
-            batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
+            # TODO: Check this
+            if self.padding_side == "right":
+                batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
+            else:
+                batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()}

         self.current_idx += self.batch_size

@@ -97,9 +110,19 @@ def __init__(
         self.to_ignore = to_ignore or []
         self.to_ignore.append("seq_length")
         self.model = model
+        tokenizer = self.model.load_tokenizer()
+        pad_token_id = tokenizer.pad_token_id
+        padding_side = tokenizer.padding_side
+
+        if padding_side not in ["right", "left"]:
+            raise ValueError("padding_side must be either 'right' or 'left'")
+
+        self.pad_token_id = pad_token_id
+        self.padding_side = padding_side
+        self.max_seq_len = self.model.max_seq_length()

         frame = CrossFrame(data).cast(torch.Tensor)
-        seq_length = (frame[sort_key] != 0).sum(axis=1)
+        seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1)
         self.sorted_indices = seq_length.argsort(descending=True)
         frame = frame.apply(lambda x: x[self.sorted_indices])
         frame = frame.assign(seq_length=seq_length[self.sorted_indices])
@@ -128,8 +151,6 @@ def __next__(self):
         else:
             start = self.splits[self.current_idx - 1]

-        _tokens = self.tensor_dict["seq_length"]
-
         end = min(self.splits[self.current_idx], self.num_rows)
         while end > start:
             try:
@@ -138,8 +159,17 @@ def __next__(self):
                     for key, val in self.tensor_dict.items()
                     if key not in self.to_ignore
                 }
-                clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length())
-                batch = {key: val[:, :clip_len] for key, val in batch.items()}
+                # TODO: Fix max_length
+                if self.max_seq_len is None:
+                    self.max_seq_len = self.model.max_seq_length()
+
+                batch = clip_tokens(
+                    token_o=batch,
+                    max_length=self.max_seq_len,
+                    padding_side=self.padding_side,
+                    pad_token_id=self.pad_token_id,
+                    return_type="pt",
+                )

                 for fn in self._to_map:
                     batch = adapt_model_input(fn, batch)
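The padding_side branches in InMemoryLoader.__next__ exist because truncation must drop pad tokens, not content: with right padding the real tokens sit at the start of each row, with left padding at the end. A small self-contained illustration (the token values are made up):

import torch

max_seq_len = 4
right_padded = torch.tensor([[101, 7592, 102, 0, 0, 0]])  # tokens first, pads last
left_padded = torch.tensor([[0, 0, 0, 101, 7592, 102]])   # pads first, tokens last

# Right padding: keep the head of each row.
print(right_padded[:, :max_seq_len])  # tensor([[ 101, 7592,  102,    0]])
# Left padding: keep the tail of each row.
print(left_padded[:, -max_seq_len:])  # tensor([[   0,  101, 7592,  102]])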
8 changes: 7 additions & 1 deletion crossfit/backend/torch/op/base.py
@@ -26,6 +26,7 @@
 from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
 from crossfit.backend.torch.model import Model
 from crossfit.op.base import Op
+from crossfit.utils.torch_utils import concat_and_pad_tensors


 class Predictor(Op):
@@ -66,6 +67,7 @@ def call(self, data, partition_info=None):
         loader = InMemoryLoader(
             data[["input_ids", "attention_mask"]],
             batch_size=self.batch_size,
+            padding_side=self.model.load_tokenizer().padding_side,
             progress_bar=self.create_progress_bar(len(data), partition_info),
             max_seq_len=self.model.max_seq_length(),
         )
@@ -83,7 +85,11 @@ def call(self, data, partition_info=None):
             all_outputs_ls.append(output)

         out = cudf.DataFrame(index=index)
-        outputs = cp.asarray(torch.cat(all_outputs_ls, dim=0))
+        outputs = cp.asarray(
+            concat_and_pad_tensors(
+                all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side
+            )
+        )
         _index = loader.sort_column(index.values) if self.sorted_data_loader else index
         if len(outputs.shape) <= 2:
             out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
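torch.cat requires every chunk to share the same trailing dimension, which no longer holds once each batch is clipped to its own length. concat_and_pad_tensors is added in crossfit/utils/torch_utils.py (not shown in this diff); presumably it pads the chunks to a common width before concatenating. A sketch under that assumption:

from typing import List

import torch
import torch.nn.functional as F


def concat_and_pad_tensors(
    all_outputs_ls: List[torch.Tensor], pad_token_id: int = 0, padding_side: str = "right"
) -> torch.Tensor:
    # Pad every chunk to the widest sequence length, then concatenate rows.
    max_len = max(t.shape[1] for t in all_outputs_ls)
    padded = []
    for t in all_outputs_ls:
        missing = max_len - t.shape[1]
        # F.pad's (left, right) tuple pads the last dimension.
        widths = (0, missing) if padding_side == "right" else (missing, 0)
        padded.append(F.pad(t, widths, value=pad_token_id))
    return torch.cat(padded, dim=0)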
78 changes: 66 additions & 12 deletions crossfit/op/tokenize.py
@@ -1,4 +1,4 @@
-# Copyright 2023 NVIDIA CORPORATION
+# Copyright 2023-2024 NVIDIA CORPORATION
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import cudf
 import cupy as cp
@@ -31,6 +31,7 @@
 class TokenizerType(Enum):
     SUBWORD = 1
     SENTENCE_PIECE = 2
+    DEFAULT = 3


 class Tokenizer(Op):
@@ -55,8 +56,10 @@ def __init__(
             GPUTokenizer.from_pretrained(self.model)

     def tokenize_strings(self, sentences, max_length=None):
-        if self.tokenizer_type == TokenizerType.SENTENCE_PIECE:
+        if self.tokenizer_type in [TokenizerType.SENTENCE_PIECE, TokenizerType.DEFAULT]:
             tokenizer = self.model.load_tokenizer()
+            self.padding_side = tokenizer.padding_side
+            self.pad_token_id = tokenizer.pad_token_id

             if isinstance(sentences, cudf.Series):
                 sentences = sentences.to_arrow().to_pylist()
@@ -81,6 +84,8 @@ def tokenize_strings(self, sentences, max_length=None):
                 tokenizer = GPUTokenizer.from_pretrained(self.model)
                 worker.tokenizer = tokenizer

+            self.padding_side = tokenizer.padding_side
+            self.pad_token_id = tokenizer.pad_token_id
             return worker.tokenizer(
                 sentences,
                 max_length=max_length or self.max_length,
@@ -110,7 +115,13 @@ def call_column(self, data):
             text = text.str.slice(0, self.max_chars)

         tokenized_data = self.tokenize_strings(text).copy()
-        tokenized_data = clip_tokens(tokenized_data, max_length=self.max_length, return_type="cp")
+        tokenized_data = clip_tokens(
+            tokenized_data,
+            max_length=self.max_length,
+            padding_side=self.padding_side,
+            pad_token_id=self.pad_token_id,
+            return_type="cp",
+        )

         input_ids = create_list_series_from_1d_or_2d_ar(
             tokenized_data["input_ids"].astype("int32"), data.index
@@ -173,13 +184,25 @@ def _convert_to_tokenizer_type(
         tokenizer_type = TokenizerType.SENTENCE_PIECE
     elif tokenizer_type in ["subword", "bert", TokenizerType.SUBWORD]:
         tokenizer_type = TokenizerType.SUBWORD
+    elif tokenizer_type in ["default", TokenizerType.DEFAULT]:
+        tokenizer_type = TokenizerType.DEFAULT
     return tokenizer_type


 class GPUTokenizer(SubwordTokenizer):
     def __init__(self, hash_file: str, do_lower_case: bool = True, config=None):
         super().__init__(str(hash_file), do_lower_case=do_lower_case)
         self.config = config or {"_name_or_path": hash_file}
+        self.padding_side = self.config.get("padding_side", "right")
+        self.pad_token_id = self.config.get("pad_token_id", 0)
+        if self.padding_side != "right":
+            raise ValueError(
+                f"Only right padding is supported for GPUTokenizer, got {self.padding_side}"
+            )
+        if self.pad_token_id != 0:
+            raise ValueError(
+                f"Only pad_token_id=0 is supported for GPUTokenizer, got {self.pad_token_id}"
+            )

     @classmethod
     def get_tokenizer_config(cls, name):
@@ -224,17 +247,48 @@ def from_pretrained(cls, name, cache_dir=None):
         return cls(hashed_vocab_path, config=config)


-def clip_tokens(token_o, max_length, return_type="pt"):
+def clip_tokens(
+    token_o: Dict[str, Union[cp.ndarray, torch.Tensor]],
+    max_length: int,
+    padding_side: str,
+    pad_token_id: int,
+    return_type: str = "pt",
+) -> Dict[str, Union[cp.ndarray, torch.Tensor]]:
+    # Verify non-empty max_length, padding_side, and pad_token_id
+    if not max_length:
+        raise ValueError("max_length cannot be empty or zero.")
+    if not padding_side:
+        raise ValueError("padding_side cannot be empty.")
+    if pad_token_id is None:
+        raise ValueError("pad_token_id cannot be None.")
+
+    # Check if input_ids is a cupy array, if not convert to cupy array
     if not isinstance(token_o["input_ids"], cp.ndarray):
         token_o = {k: cp.asarray(v) for k, v in token_o.items()}

-    clip_len = max_length - int((token_o["input_ids"][:, ::-1] != 0).argmax(1).min())
-    token_o["input_ids"] = _cast_to_appropriate_type(
-        token_o["input_ids"][:, :clip_len], return_type
-    )
-    token_o["attention_mask"] = _cast_to_appropriate_type(
-        token_o["attention_mask"][:, :clip_len], return_type
-    )
+    # Clip the input_ids and attention_mask based on the padding side
+    # max_length = min(max_length, token_o["input_ids"].shape[1])
+    total_indices = token_o["input_ids"].shape[1]
+    if padding_side == "right":
+        clip_len = total_indices - int(
+            (token_o["input_ids"][:, ::-1] != pad_token_id).argmax(1).min()
+        )
+        clip_len = min(clip_len, max_length)
+        token_o["input_ids"] = _cast_to_appropriate_type(
+            token_o["input_ids"][:, :clip_len], return_type
+        )
+        token_o["attention_mask"] = _cast_to_appropriate_type(
+            token_o["attention_mask"][:, :clip_len], return_type
+        )
+    else:
+        clip_len = total_indices - int((token_o["input_ids"] != pad_token_id).argmax(1).min())
+        clip_len = min(clip_len, max_length)
+        token_o["input_ids"] = _cast_to_appropriate_type(
+            token_o["input_ids"][:, -clip_len:], return_type
+        )
+        token_o["attention_mask"] = _cast_to_appropriate_type(
+            token_o["attention_mask"][:, -clip_len:], return_type
+        )

     if "metadata" in token_o:
         del token_o["metadata"]
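A quick usage check for the reworked clip_tokens: it drops only the padding columns shared by every row in the batch (then caps the width at max_length), so the longest real sequence is preserved. Illustrative values, run on the right-padding path:

import cupy as cp

from crossfit.op.tokenize import clip_tokens

batch = {
    "input_ids": cp.asarray([[101, 2023, 102, 0, 0], [101, 102, 0, 0, 0]]),
    "attention_mask": cp.asarray([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]),
}
clipped = clip_tokens(
    batch, max_length=512, padding_side="right", pad_token_id=0, return_type="cp"
)
# The two all-pad trailing columns are removed; shape goes from (2, 5) to (2, 3).
print(clipped["input_ids"].shape)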