Commit

undo changes made to tokenizer (for number tokenization experiment)
garrethlee committed Dec 22, 2024
1 parent 751ec13 commit 0be58ae
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions src/datatrove/pipeline/tokens/tokenizer.py
@@ -283,7 +283,6 @@ def __init__(
         save_filename: str = None,  # if defined, the final output filename will be this
         tokenizer_name_or_path: str = "gpt2",  # tokenizer to use, from HF or a local
         eos_token: str = "<|endoftext|>",  # whether to add the EOS token after each document
-        r2l_digits: bool = False,  # whether to tokenize digits from right to left
         save_loss_metadata: bool = False,  # save the loss information
         shuffle: bool = True,  # whether to shuffle documents in the dataset,
         batch_size: int = 10000,  # batch size for tokenization
@@ -306,7 +305,6 @@ def __init__(
         self.save_filename = save_filename
         self.tokenizer_name_or_path = tokenizer_name_or_path
         self.eos_token = eos_token
-        self.r2l_digits = r2l_digits
         self.save_loss_metadata = save_loss_metadata
         self.shuffle = shuffle
         self.batch_size = batch_size
@@ -345,29 +343,8 @@ def write_unshuffled(self, data: DocumentsPipeline, filename: str):
             data (DocumentsPipeline): the documents to process
             filename (str): the filename to use for the output file
         """
-        import re
-
         from tokenizers import Encoding
 
-        R2L_COMMA_TOKEN = "<R2L-COMMA>"
-
-        self.tokenizer.add_special_tokens([R2L_COMMA_TOKEN])
-        R2L_COMMA_TOKEN_ID = self.tokenizer.token_to_id(R2L_COMMA_TOKEN)
-
-        def add_commas(match):
-            import sys
-
-            sys.set_int_max_str_digits(1000000)
-
-            num, decimal = match.groups()
-            decimal = decimal if decimal else ""
-            return f"{int(num):,}{decimal}".replace(",", R2L_COMMA_TOKEN)
-
-        def preprocess_text(text: str):
-            number_regex = re.compile(r"([\d]+)(\.\d+)?")
-            processed_text = number_regex.sub(add_commas, text)
-            return processed_text
-
         unshuff = TokenizedFile(
             self.output_folder if not self.shuffle or not self.local_working_dir else self.local_working_dir,
             filename,
@@ -381,13 +358,9 @@ def preprocess_text(text: str):
         # tokenize document's text in batches to go faster – we compute loss values independently if needed
         for batch in batched(data, self.batch_size):
             with self.track_time(unit="batch"):
-                encoded_batch: list[Encoding] = self.tokenizer.encode_batch(
-                    [preprocess_text(document.text) if self.r2l_digits else document.text for document in batch]
-                )
+                encoded_batch: list[Encoding] = self.tokenizer.encode_batch([document.text for document in batch])
                 for document, encoded in zip(batch, encoded_batch):
                     tokens = encoded.ids
-                    if self.r2l_digits:
-                        tokens = [tok for tok in tokens if tok != R2L_COMMA_TOKEN_ID]
                     loss_values = self.get_loss_values(document, encoded)
                     if loss_values is not None and len(loss_values) < len(tokens):
                         # crop final section without loss
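
For reference, the experiment this commit backs out rewrote numbers before tokenization: each run of digits was regrouped into threes from the right by inserting a placeholder special token at the thousands boundaries, and the placeholder's token ids were stripped back out after encode_batch. Below is a minimal, self-contained sketch of that preprocessing step, mirroring the deleted add_commas / preprocess_text helpers; the example string and the __main__ demo are illustrative additions, not part of the datatrove pipeline.

import re

R2L_COMMA_TOKEN = "<R2L-COMMA>"
NUMBER_REGEX = re.compile(r"(\d+)(\.\d+)?")


def add_commas(match: re.Match) -> str:
    # Format the integer part with thousands separators ("1234567" -> "1,234,567"),
    # then swap each "," for the placeholder. This forces the tokenizer to split
    # digit runs into groups of three counted from the right; any decimal part is untouched.
    num, decimal = match.groups()
    decimal = decimal or ""
    return f"{int(num):,}{decimal}".replace(",", R2L_COMMA_TOKEN)


def preprocess_text(text: str) -> str:
    # Apply the right-to-left regrouping to every number in the document text.
    return NUMBER_REGEX.sub(add_commas, text)


if __name__ == "__main__":
    print(preprocess_text("pi times 1000000 is roughly 3141592.65"))
    # -> pi times 1<R2L-COMMA>000<R2L-COMMA>000 is roughly 3<R2L-COMMA>141<R2L-COMMA>592.65
    # In the removed pipeline code, <R2L-COMMA> was registered as a special token,
    # the rewritten text went through tokenizer.encode_batch, and the separator's ids
    # were filtered out of encoded.ids afterwards, leaving only the regrouped digit tokens.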
