refactored some deduplication code; added reference/index mode to sentence dedup; fixed tests
guipenedo committed Nov 1, 2023
1 parent e21a372 commit 5b1c4b4
Showing 4 changed files with 139 additions and 95 deletions.
29 changes: 10 additions & 19 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -11,7 +11,7 @@
from datatrove.data import DocumentsPipeline
from datatrove.io import BaseInputDataFolder, BaseOutputDataFolder, InputDataFile
from datatrove.pipeline.base import PipelineStep
from datatrove.pipeline.dedup.utils import sha1_hash32, simplify_content
from datatrove.pipeline.dedup.utils import read_tuples_from_file, sha1_hash32, simplify_content
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.typeshelper import StatHints

@@ -54,16 +54,11 @@ def __lt__(self, other):

def read_sigs(file: InputDataFile, reader_id: int, hashes_per_bucket: int, index_file: bool = False) -> Generator:
n = hashes_per_bucket + 1 if not index_file else hashes_per_bucket
with file.open(binary=True) as f:
while True:
data = f.read(n * struct.calcsize("I"))
if not data:
return
data = struct.unpack("<%sI" % n, data)
if index_file:
yield HashSig(sig=data, doc_id=-1, file_id=-1, reader_id=reader_id)
else:
yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=reader_id, reader_id=reader_id)
for data in read_tuples_from_file(file, f"{n}I"):
if index_file:
yield HashSig(sig=data, doc_id=-1, file_id=-1, reader_id=reader_id)
else:
yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=reader_id, reader_id=reader_id)


class MinhashDedupSignature(PipelineStep):
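
Note: the refactored read_sigs above now streams fixed-size records via read_tuples_from_file, where each record is n little-endian uint32 values: the bucket's hashes plus, for non-index files, a trailing doc_id. A minimal sketch of that record layout (made-up values, plain struct calls rather than datatrove's file abstractions):

    import struct

    hashes_per_bucket = 3
    record = struct.pack("<4I", 11, 22, 33, 7)  # 3 bucket hashes + doc_id 7

    n = hashes_per_bucket + 1  # non-index signature file: hashes + doc_id
    values = struct.unpack(f"<{n}I", record)
    sig, doc_id = values[:-1], values[-1]
    print(sig, doc_id)  # (11, 22, 33) 7
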
@@ -241,11 +236,9 @@ def parent(x):
return union_set[x]

for dup_file in dup_files:
with dup_file.open(binary=True) as df:
while data := df.read(4 * struct.calcsize("I")):
f1, d1, f2, d2 = struct.unpack("<4I", data)
a, b = (f1, d1), (f2, d2)
union_set[parent(b)] = parent(a)
for f1, d1, f2, d2 in read_tuples_from_file(dup_file, "4I"):
a, b = (f1, d1), (f2, d2)
union_set[parent(b)] = parent(a)

ci = 0
cluster_ids = {}
@@ -307,9 +300,7 @@ def get_next():

def load_clusters():
if clusters_data:
with clusters_data[0].open_binary() as cf:
while data := cf.read(struct.calcsize("I") * 2):
yield struct.unpack("<2I", data)
yield from read_tuples_from_file(clusters_data[0], "2I")

if self.load_cluster_ids:
cluster_loader = load_clusters()
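
Note: the clustering hunk above feeds (file_id, doc_id) duplicate pairs into a small union-find, so every pair merges two documents into one cluster. A standalone sketch of that pattern with made-up pairs (not the repository's parent implementation, which lives inside the surrounding method):

    # Union-find over (file_id, doc_id) pairs, as used for minhash clustering.
    union_set = {}

    def parent(x):
        # find the representative of x's cluster, compressing the path as we go
        if x not in union_set or union_set[x] == x:
            union_set[x] = x
            return x
        union_set[x] = parent(union_set[x])
        return union_set[x]

    duplicate_pairs = [((0, 1), (0, 7)), ((0, 7), (2, 3))]  # hypothetical
    for a, b in duplicate_pairs:
        union_set[parent(b)] = parent(a)

    # all three documents end up in the same cluster
    assert parent((0, 1)) == parent((0, 7)) == parent((2, 3))
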
144 changes: 95 additions & 49 deletions src/datatrove/pipeline/dedup/sentence_dedup.py
@@ -14,7 +14,8 @@
from dataclasses import dataclass
from typing import Generator

from nltk import load
from loguru import logger
from nltk import load, ngrams
from nltk.tokenize import sent_tokenize, word_tokenize

from datatrove.data import Document, DocumentsPipeline
@@ -23,7 +24,7 @@
from datatrove.utils.typeshelper import StatHints

from ..writers.disk_base import DiskWriter
from .utils import ExtensionHelperSD, merge_docs, simplify_content, str_hash
from .utils import ExtensionHelperSD, merge_docs, read_tuples_from_file, simplify_content, str_hash


@dataclass
@@ -33,14 +34,19 @@ class HashSig:
sent_id: int
file_id: int = None

def to_tuple(self):
# this also determines the sorting order
# hash_value needs to come first as that's what we match on
# file_id should come after doc_id so that hashes from the index (sent_id=doc_id=-1) come up first
# return self.hash_value, self.sent_id, self.doc_id, self.file_id
return self.hash_value, self.doc_id, self.file_id, self.sent_id

# priority queue accepts anything that is sortable
def __lt__(self, other) -> bool:
return (self.hash_value, self.file_id, self.doc_id, self.sent_id) < (
other.hash_value,
other.file_id,
other.doc_id,
other.sent_id,
)
return self.to_tuple() < other.to_tuple()

def is_from_index(self):
return self.doc_id == self.sent_id == -1


class SentenceDedupSignature(PipelineStep):
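
Note: the to_tuple ordering above is what lets index hashes win ties: for equal hash values, an index entry (doc_id = sent_id = -1) sorts before any corpus entry and is popped from the heap first. A tiny illustration with plain tuples mirroring that sort key (made-up values):

    # sort key order from to_tuple(): (hash_value, doc_id, file_id, sent_id)
    index_entry = (42, -1, 3, -1)   # from an index file: doc_id = sent_id = -1
    corpus_entry = (42, 17, 0, 2)   # same hash, from the dataset being deduplicated

    assert sorted([corpus_entry, index_entry])[0] == index_entry
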
@@ -67,9 +73,9 @@ def save_hashes(self, rank: int):

f = self.output_folder.open(f"{rank:05d}{ExtensionHelperSD.stage_1_signature}", mode="wb")
for hs in self.signatures:
f.file_handler.write(struct.pack("<Q", hs.hash_value))
f.file_handler.write(struct.pack("<I", hs.doc_id))
f.file_handler.write(struct.pack("<H", hs.sent_id))
f.write(struct.pack("<Q", hs.hash_value))
f.write(struct.pack("<I", hs.doc_id))
f.write(struct.pack("<H", hs.sent_id))

def get_hashes(self, doc: Document, doc_idx: int) -> list[None] | list[HashSig]:
# todo use language id metadata in sent_tokenize
@@ -78,10 +84,7 @@ def get_hashes(self, doc: Document, doc_idx: int) -> list[None] | list[HashSig]:
return []

sentences_tokens = [simplify_content(sent) for sent in sentences]
n_sent_grams: list = [
" ".join(sentences_tokens[i : i + self.n_sentences])
for i in range(len(sentences_tokens) - self.n_sentences + 1)
]
n_sent_grams: list = [" ".join(x) for x in ngrams(sentences_tokens, self.n_sentences)]
hashes = [
HashSig(
hash_value=str_hash(n_sent_gram),
@@ -114,30 +117,33 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
self.output_folder.close()


def read_sigs(file: InputDataFile, file_id: int) -> Generator[HashSig, None, None]:
with file.open(binary=True) as f:
while True:
x = {}
for t, b, k in [
("Q", struct.calcsize("Q"), "hash_value"),
("I", struct.calcsize("I"), "doc_id"),
("H", struct.calcsize("H"), "sent_id"),
]:
by = f.read(b)
if not by:
return
x[k] = struct.unpack(f"<{t}", by)[0]
yield HashSig(file_id=file_id, **x)
def read_sigs(file: InputDataFile, file_id: int, index_file: bool = False) -> Generator[HashSig, None, None]:
if index_file:
# only read hashes
for (hash,) in read_tuples_from_file(file, "Q"):
yield HashSig(hash_value=hash, doc_id=-1, file_id=file_id, sent_id=-1)
else:
for hash, doc_id, sent_id in read_tuples_from_file(file, "Q", "I", "H"):
yield HashSig(file_id=file_id, hash_value=hash, doc_id=doc_id, sent_id=sent_id)


class SentenceFindDedups(PipelineStep):
type = "🫂 - DEDUPS"
name = "💥 sentence-deduplication stage 2"

def __init__(self, data_folder: BaseInputDataFolder, output_folder: BaseOutputDataFolder, **kwargs):
def __init__(
self,
data_folder: BaseInputDataFolder,
output_folder: BaseOutputDataFolder,
index_folder: BaseInputDataFolder = None,
only_dedup_in_index: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.data_folder = data_folder
self.output_folder = output_folder
self.index_folder = index_folder
self.only_dedup_in_index = only_dedup_in_index

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
"""
@@ -153,21 +159,36 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
assert world_size == 1, "SentenceFindDedups can only run on a single worker."
files_with_duplicates = set()
with self.stats.time_manager:
sig_files = self.data_folder.list_files(ExtensionHelperSD.stage_1_signature)
sig_files = self.data_folder.list_files(extension=ExtensionHelperSD.stage_1_signature)
sig_readers = [read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]
index_files = self.index_folder.list_files() if self.index_folder else None
if index_files:
logger.info(f"Found index file(s): {', '.join([file.relative_path for file in index_files])}")
sig_readers.extend(
[
read_sigs(file, len(sig_readers) + file_i, index_file=True)
for file_i, file in enumerate(index_files)
]
)

pq = [next(sig_reader) for sig_reader in sig_readers]
heapq.heapify(pq)

last = None
last: HashSig | None = None
while pq:
v: HashSig = heapq.heappop(pq)
if last == v.hash_value:
if (
last and last.hash_value == v.hash_value and not v.is_from_index()
): # we never want to match samples from the index itself
f = self.output_folder.open(f"{v.file_id:05d}{ExtensionHelperSD.stage_2_duplicates}", mode="wb")
f.file_handler.write(struct.pack("<I", v.doc_id))
f.file_handler.write(struct.pack("<H", v.sent_id))
files_with_duplicates.add(v.file_id)
last = v.hash_value
# the previous one we are matching against is part of the index
# OR there are no index files
# OR we are also matching within the main dataset
if last.is_from_index() or not index_files or not self.only_dedup_in_index:
f.write(struct.pack("<I", v.doc_id))
f.write(struct.pack("<H", v.sent_id))
files_with_duplicates.add(v.file_id)
last = v
new_v = next(sig_readers[v.file_id], None)

if new_v:
Expand All @@ -181,18 +202,7 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):


def read_duplicates(file: InputDataFile) -> Generator[tuple, None, None]:
with file.open(binary=True) as f:
while True:
x = []
for (
t,
b,
) in [("I", struct.calcsize("I")), ("H", struct.calcsize("H"))]:
by = f.read(b)
if not by:
return
x.append(struct.unpack(f"<{t}", by)[0])
yield tuple(x) # (doc_id, sent_id) pairs
yield from read_tuples_from_file(file, "I", "H") # (doc_id, sent_id) pairs


class SentenceDedupFilter(PipelineStep):
@@ -279,3 +289,39 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1)
elif writer:
doc.content = original_formatted
writer.write(doc, rank=rank)


class SentenceDedupBuildIndex(PipelineStep):
type = "🫂 - DEDUP"
name = "💥 sentence-deduplication build index"

def __init__(
self, data_folder: BaseInputDataFolder, output_folder: BaseOutputDataFolder, index_name: str, **kwargs
):
super().__init__(**kwargs)
self.data_folder = data_folder
self.output_folder = output_folder
self.index_name = index_name

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
assert world_size == 1, "SentenceDedupBuildIndex can only run on a single worker."
with self.stats.time_manager:
sig_files = self.data_folder.list_files(extension=ExtensionHelperSD.stage_1_signature)
sig_readers = [read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]

pq = [next(sig_reader) for sig_reader in sig_readers]
heapq.heapify(pq)

out_f = self.output_folder.open(f"{self.index_name}.{ExtensionHelperSD.index}", mode="wb")

last = None
while pq:
v: HashSig = heapq.heappop(pq)
if last != v.hash_value:
out_f.write(struct.pack("<Q", v.hash_value))
last = v.hash_value
new_v = next(sig_readers[v.file_id], None)

if new_v:
heapq.heappush(pq, new_v)
self.output_folder.close()
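
Note: both SentenceFindDedups and SentenceDedupBuildIndex above merge many already-sorted signature streams through a heap and act whenever consecutive entries share a hash. A minimal sketch of that k-way merge, using plain integers in place of HashSig objects and made-up streams:

    import heapq

    # each stream is pre-sorted, like the per-rank signature files
    streams = [iter([1, 4, 4, 9]), iter([2, 4, 7]), iter([4, 9])]
    pq = [(next(s), i) for i, s in enumerate(streams)]
    heapq.heapify(pq)

    last = None
    duplicates = []
    while pq:
        value, stream_id = heapq.heappop(pq)
        if last == value:
            duplicates.append((value, stream_id))  # same hash as the previous entry
        last = value
        nxt = next(streams[stream_id], None)
        if nxt is not None:
            heapq.heappush(pq, (nxt, stream_id))

    print(duplicates)  # every occurrence after the first of each repeated value
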
25 changes: 13 additions & 12 deletions src/datatrove/pipeline/dedup/utils.py
@@ -5,15 +5,14 @@
from collections import defaultdict

import numpy as np
from nltk.tokenize import sent_tokenize


ESCAPE_PLACEHOLDER = " @_ESC_APE_@ "
from datatrove.io import InputDataFile


class ExtensionHelperSD:
stage_1_signature = ".c4_sig"
stage_2_duplicates = ".c4_dup"
index = ".c4_index"


class ExtensionHelperES:
@@ -29,6 +28,17 @@ class ExtensionHelperES:
)


def read_tuples_from_file(file: InputDataFile, *formats):
with file.open(binary=True) as f:
while True:
line = []
for F in formats:
if not (data := f.read(struct.calcsize(F))):
return
line.extend(struct.unpack(f"<{F}", data))
yield tuple(line)


def simplify_content(text: str):
# lower case
text = text.lower()
@@ -51,7 +61,6 @@ def str_hash(s: str) -> int:


def merge_docs(sen_list, n_sentences: int = 3) -> dict:
# TODO IMPROVE!
def to_sentences(idx: int):
return {idx + i for i in range(n_sentences)}

@@ -72,11 +81,3 @@ def sha1_hash32(data):
int: an integer hash value that can be encoded using 32 bits.
"""
return struct.unpack("<I", hashlib.sha1(data).digest()[:4])[0]


def tokenize_with_escapes(text):
# TODO improve
escaped_text = text.replace("\n", ESCAPE_PLACEHOLDER)
sentences = sent_tokenize(escaped_text)
sentences_with_escapes = [s.replace(ESCAPE_PLACEHOLDER, "\n") for s in sentences]
return sentences_with_escapes
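
Note: the new read_tuples_from_file helper above unpacks one struct format per field and yields a tuple per record until the file is exhausted. A self-contained sketch of the same round trip for the sentence-dedup signature record (an 8-byte hash, a 4-byte doc_id, and a 2-byte sent_id), using a local temporary file and a plain file object instead of InputDataFile:

    import struct
    import tempfile

    def read_tuples(path, *formats):
        # simplified stand-in for read_tuples_from_file, reading from a local path
        with open(path, "rb") as f:
            while True:
                record = []
                for fmt in formats:
                    data = f.read(struct.calcsize(fmt))
                    if not data:
                        return
                    record.extend(struct.unpack(f"<{fmt}", data))
                yield tuple(record)

    with tempfile.NamedTemporaryFile(suffix=".c4_sig", delete=False) as f:
        for hash_value, doc_id, sent_id in [(123456789, 0, 2), (987654321, 1, 0)]:
            f.write(struct.pack("<Q", hash_value))
            f.write(struct.pack("<I", doc_id))
            f.write(struct.pack("<H", sent_id))
        path = f.name

    print(list(read_tuples(path, "Q", "I", "H")))
    # [(123456789, 0, 2), (987654321, 1, 0)]
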