add sentence piece tokenizer #43

Merged
merged 2 commits on Jan 16, 2024
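For context, a minimal usage sketch of the option this PR adds, assembled from the test included in the diff below; the model name and column names are simply the ones used there, not a recommendation.

import cudf
import dask_cudf

import crossfit as cf
from crossfit import op

# "spm" / "sentencepiece" selects the new SentencePiece path;
# the default ("subword" / "bert") keeps the cuDF subword tokenizer.
model = cf.HFModel("microsoft/deberta-v3-base")
tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm")

ddf = dask_cudf.from_cudf(cudf.DataFrame({"text": ["hello world"]}), npartitions=1)
tokens = tokenizer(ddf).compute()  # cudf.DataFrame with an "input_ids" column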
82 changes: 63 additions & 19 deletions crossfit/op/tokenize.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from enum import Enum
from typing import Optional, Union

import cudf
import cupy as cp
import torch
from cudf.core.subword_tokenizer import SubwordTokenizer, _cast_to_appropriate_type
from cudf.utils.hash_vocab_utils import hash_vocab
from transformers import AutoConfig, AutoTokenizer
@@ -26,39 +28,68 @@
from crossfit.op.base import Op


class TokenizerType(Enum):
    SUBWORD = 1
    SENTENCE_PIECE = 2


class Tokenizer(Op):
    def __init__(
        self,
        model: Model,
        tokenizer_type: Union[TokenizerType, str] = TokenizerType.SUBWORD,
        cols=None,
        keep_cols=None,
        pre=None,
        max_length: Optional[int] = None,
    ):
        super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
        self.model = model
        self.tokenizer_type = self._convert_to_tokenizer_type(tokenizer_type)
        self.max_length = max_length or model.max_seq_length()

        self.setup()
        if self.tokenizer_type == TokenizerType.SUBWORD:
            # Make sure we download the tokenizer just once
            GPUTokenizer.from_pretrained(self.model)

    def tokenize_strings(self, sentences, max_length=None):
        if self.tokenizer_type == TokenizerType.SENTENCE_PIECE:
            tokenizer = self.model.load_tokenizer()

            if isinstance(sentences, cudf.Series):
                sentences = sentences.to_arrow().to_pylist()

            with torch.no_grad():
                tokenized_data = tokenizer.batch_encode_plus(
                    sentences,
                    max_length=max_length or self.max_length,
                    return_tensors="pt",
                    add_special_tokens=True,
                    padding="max_length",
                    truncation=True,
                    return_token_type_ids=False,
                )
            return tokenized_data
        elif self.tokenizer_type == TokenizerType.SUBWORD:
            worker = self.get_worker()

            if hasattr(worker, "tokenizer"):
                tokenizer = worker.tokenizer
            else:
                tokenizer = GPUTokenizer.from_pretrained(self.model)
                worker.tokenizer = tokenizer

            return worker.tokenizer(
                sentences,
                max_length=max_length or self.max_length,
                max_num_rows=len(sentences),
                padding="max_length",
                return_tensors="cp",
                truncation=True,
                add_special_tokens=True,
            )
        raise ValueError(f"Unexpected tokenizer type: {self.tokenizer_type}")

    def teardown(self):
        worker = self.get_worker()
@@ -128,6 +159,16 @@ def _construct_name(self, col_name, suffix):

        return f"{col_name}_{suffix}"

    def _convert_to_tokenizer_type(
        self,
        tokenizer_type: Union[TokenizerType, str],
    ) -> TokenizerType:
        if tokenizer_type in ["sentencepiece", "spm", TokenizerType.SENTENCE_PIECE]:
            tokenizer_type = TokenizerType.SENTENCE_PIECE
        elif tokenizer_type in ["subword", "bert", TokenizerType.SUBWORD]:
            tokenizer_type = TokenizerType.SUBWORD
        return tokenizer_type


class GPUTokenizer(SubwordTokenizer):
    def __init__(self, hash_file: str, do_lower_case: bool = True, config=None):
@@ -178,6 +219,9 @@ def from_pretrained(cls, name, cache_dir=None):


def clip_tokens(token_o, max_length, return_type="pt"):
    if not isinstance(token_o["input_ids"], cp.ndarray):
        token_o = {k: cp.asarray(v) for k, v in token_o.items()}

    clip_len = max_length - int((token_o["input_ids"][:, ::-1] != 0).argmax(1).min())
    token_o["input_ids"] = _cast_to_appropriate_type(
        token_o["input_ids"][:, :clip_len], return_type
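A small, hedged illustration of the clipping rule in clip_tokens above, with toy ids made up for this sketch: for right-padded input ids, (ids[:, ::-1] != 0).argmax(1) counts each row's trailing pad tokens, and subtracting the minimum from max_length trims only the padding shared by every row. This is also why the cp.asarray cast added in this PR matters: the SentencePiece path returns tensors rather than CuPy arrays.

import cupy as cp

# Toy, right-padded token ids (values invented for illustration).
ids = cp.asarray([[101, 7592, 2088, 102, 0, 0],
                  [101, 2023, 102, 0, 0, 0]])
max_length = 6

trailing_pads = (ids[:, ::-1] != 0).argmax(1)     # [2, 3] trailing pads per row
clip_len = max_length - int(trailing_pads.min())  # 6 - 2 = 4
print(ids[:, :clip_len])                          # all-zero trailing columns removed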
24 changes: 24 additions & 0 deletions tests/op/test_tokenize.py
@@ -0,0 +1,24 @@
import pytest

cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")
transformers = pytest.importorskip("transformers")

import crossfit as cf # noqa: E402
from crossfit import op # noqa: E402


def test_tokenizer_sentence_piece(model_name="microsoft/deberta-v3-base"):
    model = cf.HFModel(model_name)
    tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm")
    ddf = dask_cudf.from_cudf(
        cudf.DataFrame({"text": ["hello world", "this is a sentence"]}),
        npartitions=2,
    )
    results = tokenizer(ddf)
    results = results.compute()

    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    assert isinstance(results, cudf.DataFrame)
    assert results["input_ids"][0] == hf_tokenizer(["hello world"])["input_ids"][0]
    assert results["input_ids"][1] == hf_tokenizer(["this is a sentence"])["input_ids"][0]
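Finally, a quick sketch of how the string aliases accepted by _convert_to_tokenizer_type in crossfit/op/tokenize.py above resolve. The helper never touches self, so it is exercised unbound here purely for illustration; the direct module import path is assumed from the file location in this diff.

from crossfit.op.tokenize import Tokenizer, TokenizerType

assert Tokenizer._convert_to_tokenizer_type(None, "spm") is TokenizerType.SENTENCE_PIECE
assert Tokenizer._convert_to_tokenizer_type(None, "sentencepiece") is TokenizerType.SENTENCE_PIECE
assert Tokenizer._convert_to_tokenizer_type(None, "bert") is TokenizerType.SUBWORD
# Unknown strings pass through unchanged and are rejected later in tokenize_strings.
assert Tokenizer._convert_to_tokenizer_type(None, "unknown") == "unknown"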