fixed tests
guipenedo committed Jan 11, 2024
1 parent 4dc7af6 commit 3576bcd
Showing 8 changed files with 39 additions and 41 deletions.
8 changes: 6 additions & 2 deletions src/datatrove/datafolder.py
@@ -53,7 +53,7 @@ def __init__(
         """
         super().__init__(path=path, fs=fs if fs else url_to_fs(path, **storage_options)[0])
         self.recursive = recursive
-        self.glob = glob
+        self.pattern = glob
         if not self.isdir("/"):
             self.mkdirs("/", exist_ok=True)

@@ -62,7 +62,11 @@ def list_files(self, suffix: str = "", extension: str | list[str] = None) -> list[str]:
             extension = [extension]
         return [
             f
-            for f in self.find(suffix, maxdepth=0 if not self.recursive else None)
+            for f in (
+                self.find(suffix, maxdepth=0 if not self.recursive else None)
+                if not self.pattern
+                else self.glob(self.fs.sep.join([self.pattern, suffix]), maxdepth=0 if not self.recursive else None)
+            )
             if not extension or any(f.endswith(ext) for ext in extension)
         ]
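
For context on the list_files change above: with no glob pattern the folder is walked with fsspec's find(), and with a pattern set the listing goes through glob() instead. A minimal standalone sketch of that behaviour (not datatrove code; the in-memory filesystem and file names are illustrative):

# Illustrative only: mimics the find-vs-glob branch in list_files with plain fsspec.
import fsspec

fs = fsspec.filesystem("memory")
for name in ("data/a.feather", "data/b.arrow", "data/sub/c.feather"):
    fs.pipe(name, b"")  # create empty placeholder files

# no pattern: recursive find() returns every file under the folder
print(fs.find("data"))  # ['/data/a.feather', '/data/b.arrow', '/data/sub/c.feather']

# with a pattern (what DataFolder now stores as self.pattern): glob() narrows the listing
print(fs.glob("data/*.feather"))  # ['/data/a.feather']
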
20 changes: 9 additions & 11 deletions src/datatrove/pipeline/readers/ipc.py
@@ -2,7 +2,7 @@

 import pyarrow as pa

-from datatrove.io import BaseInputDataFile, BaseInputDataFolder
+from datatrove.datafolder import ParsableDataFolder
 from datatrove.pipeline.readers.base import BaseReader


@@ -11,7 +11,7 @@ class IpcReader(BaseReader):

     def __init__(
         self,
-        data_folder: BaseInputDataFolder,
+        data_folder: ParsableDataFolder,
         limit: int = -1,
         stream: bool = False,
         progress: bool = False,
@@ -24,28 +24,26 @@ def __init__(
         self.stream = stream
         # TODO: add option to disable reading metadata (https://github.com/apache/arrow/issues/13827 needs to be addressed first)

-    @staticmethod
-    def _iter_file_batches(datafile: BaseInputDataFile):
-        with datafile.open(binary=True) as f:
+    def _iter_file_batches(self, filepath: str):
+        with self.data_folder.open(filepath, "rb") as f:
             with pa.ipc.open_file(f) as ipc_reader:
                 for i in range(ipc_reader.num_record_batches):
                     yield ipc_reader.get_batch(i)

-    @staticmethod
-    def _iter_stream_batches(datafile: BaseInputDataFile):
-        with datafile.open(binary=True) as f:
+    def _iter_stream_batches(self, filepath: str):
+        with self.data_folder.open(filepath, "rb") as f:
             with pa.ipc.open_stream(f) as ipc_stream_reader:
                 for batch in ipc_stream_reader:
                     yield batch

-    def read_file(self, datafile: BaseInputDataFile):
-        batch_iter = self._iter_file_batches(datafile) if not self.stream else self._iter_stream_batches(datafile)
+    def read_file(self, filepath: str):
+        batch_iter = self._iter_file_batches(filepath) if not self.stream else self._iter_stream_batches(filepath)
         li = 0
         for batch in batch_iter:
             documents = []
             with self.track_time("batch"):
                 for line in batch.to_pylist():
-                    document = self.get_document_from_dict(line, datafile, li)
+                    document = self.get_document_from_dict(line, filepath, li)
                     if not document:
                         continue
                     documents.append(document)
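
The refactor keeps the two pyarrow read paths intact: pa.ipc.open_file for the random-access IPC file format (batches fetched by index) and pa.ipc.open_stream for the sequential stream format. A self-contained sketch of the difference, assuming only pyarrow; the file names are illustrative, not datatrove's:

# Illustrative sketch of the two Arrow IPC formats that IpcReader consumes.
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"text": ["hello", "world"]})

# file format: written with a footer, so batches are addressable by index
with pa.OSFile("batches.arrow", "wb") as sink, pa.ipc.new_file(sink, batch.schema) as writer:
    writer.write_batch(batch)
with pa.OSFile("batches.arrow", "rb") as f, pa.ipc.open_file(f) as reader:
    for i in range(reader.num_record_batches):  # mirrors _iter_file_batches
        print(reader.get_batch(i).to_pylist())

# stream format: no footer, forward iteration only
with pa.OSFile("batches.arrows", "wb") as sink, pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
with pa.OSFile("batches.arrows", "rb") as f, pa.ipc.open_stream(f) as reader:
    for b in reader:  # mirrors _iter_stream_batches
        print(b.to_pylist())
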
4 changes: 2 additions & 2 deletions tests/pipeline/test_bloom_filter.py
@@ -3,7 +3,7 @@
 import unittest

 from datatrove.data import Document
-from datatrove.pipeline.dedup.bloom_filter import SingleBloomFilter, get_false_positive_prob, get_optimal_k
+from datatrove.pipeline.dedup.bloom_filter import SingleBloomFilter


 TEXT_0 = (
@@ -91,7 +91,7 @@ def setUp(self):
         self.addCleanup(shutil.rmtree, self.tmp_dir)

     def test_sd(self):
-        bloom_filter = SingleBloomFilter(output_folder=self.test_dir, m_bytes=2**10 - 1, k=7, expected_elements=866)
+        bloom_filter = SingleBloomFilter(output_folder=self.tmp_dir, m_bytes=2**10 - 1, k=7, expected_elements=866)

         for doc_idx, doc in enumerate(DOCS):
             is_unique = bloom_filter.step(doc)
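
As background for the constructor arguments in the test above (m_bytes, k, expected_elements): a Bloom filter sets k hash-derived bits per item in an m-bit array, and an item whose bits are all already set is flagged as a probable duplicate. A generic sketch of the idea, not datatrove's SingleBloomFilter, though step() here deliberately mirrors the test's method name:

# Generic Bloom-filter sketch (illustrative; operates on strings, not Documents).
import hashlib

class TinyBloom:
    def __init__(self, m_bytes: int, k: int):
        self.bits = bytearray(m_bytes)  # m_bytes * 8 addressable bits
        self.m = m_bytes * 8
        self.k = k

    def _positions(self, item: str):
        # derive k bit positions from k salted hashes
        for seed in range(self.k):
            digest = hashlib.sha256(f"{seed}:{item}".encode()).digest()
            yield int.from_bytes(digest[:8], "big") % self.m

    def step(self, item: str) -> bool:
        """Insert item; True means it was (probably) never seen before."""
        unseen = False
        for pos in self._positions(item):
            byte, bit = divmod(pos, 8)
            if not (self.bits[byte] >> bit) & 1:
                unseen = True
                self.bits[byte] |= 1 << bit
        return unseen

bf = TinyBloom(m_bytes=2**10 - 1, k=7)  # same sizing as the test above
assert bf.step("some sentence") is True   # first sighting: unique
assert bf.step("some sentence") is False  # second sighting: duplicate
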
20 changes: 10 additions & 10 deletions tests/pipeline/test_exact_substrings.py
@@ -183,16 +183,16 @@ def test_signature_1_worker(self):
             f.write(bytearange_file)
         data = copy.deepcopy(DATA)

-        dataset_to_sequence = DatasetToSequence(output_folder=self.test_dir)
+        dataset_to_sequence = DatasetToSequence(output_folder=self.tmp_dir)
         merge_sequence = MergeSequences(
-            input_folder=self.test_dir,
-            output_folder=self.test_dir,
+            input_folder=self.tmp_dir,
+            output_folder=self.tmp_dir,
             tasks_stage_1=1,
         )

         dedup_reader = DedupReader(
-            data_folder=self.test_dir,
-            sequence_folder=self.test_dir,
+            data_folder=self.tmp_dir,
+            sequence_folder=self.tmp_dir,
             min_doc_words=0,
         )

@@ -237,16 +237,16 @@ def test_signature_2_worker(self):
         with open(self.tmp_dir + "/test" + ExtensionHelperES.stage_3_bytes_ranges, "w") as f:
             f.write(bytearange_file_2)

-        dataset_to_sequence = DatasetToSequence(output_folder=self.test_dir)
+        dataset_to_sequence = DatasetToSequence(output_folder=self.tmp_dir)
         merge_sequence = MergeSequences(
-            input_folder=self.test_dir,
-            output_folder=self.test_dir,
+            input_folder=self.tmp_dir,
+            output_folder=self.tmp_dir,
             tasks_stage_1=2,
         )

         dedup_reader = DedupReader(
-            data_folder=self.test_dir,
-            sequence_folder=self.test_dir,
+            data_folder=self.tmp_dir,
+            sequence_folder=self.tmp_dir,
             min_doc_words=0,
         )
7 changes: 2 additions & 5 deletions tests/pipeline/test_ipc_reader.py
@@ -6,7 +6,6 @@

 import pyarrow as pa
 import pyarrow.feather as feather

-from datatrove.io import LocalInputDataFolder
 from datatrove.pipeline.readers.ipc import IpcReader


@@ -49,13 +48,11 @@ def check_same_data(self, documents, stream=False, check_metadata=True):
                 self.assertEqual(document.metadata[key], row[key])

     def test_ipc_reader(self):
-        reader = IpcReader(LocalInputDataFolder(self.tmp_dir, extension=os.path.splitext(self.ipc_file)[1]))
+        reader = IpcReader((self.tmp_dir, {"glob": "*.feather"}))
         documents = list(reader.run())
         self.check_same_data(documents)

     def test_ipc_stream_reader(self):
-        reader = IpcReader(
-            LocalInputDataFolder(self.tmp_dir, extension=os.path.splitext(self.ipc_stream_file)[1]), stream=True
-        )
+        reader = IpcReader((self.tmp_dir, {"glob": "*.arrow"}), stream=True)
         documents = list(reader.run())
         self.check_same_data(documents)
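
These tests also document the new constructor convention: instead of a LocalInputDataFolder instance, readers now accept the folder directly, here as a (path, kwargs) tuple whose "glob" entry selects files by pattern (matching the ParsableDataFolder annotation in ipc.py above). A hedged usage sketch; "my_data" is a placeholder path:

# Usage sketch based on the tests above, not a documented API example.
from datatrove.pipeline.readers.ipc import IpcReader

reader = IpcReader(("my_data", {"glob": "*.feather"}))  # IPC file format
stream_reader = IpcReader(("my_data", {"glob": "*.arrow"}), stream=True)  # IPC stream format

for document in reader.run():
    print(document.content)
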
4 changes: 2 additions & 2 deletions tests/pipeline/test_minhash.py
@@ -40,10 +40,10 @@ def setUp(self):
     def test_signatures(self):
         for use_64bit_hashes in (True, False):
             config = MinhashConfig(use_64bit_hashes=use_64bit_hashes)
-            minhash = MinhashDedupSignature(output_folder=os.path.join(self.test_dir, "signatures1"), config=config)
+            minhash = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures1"), config=config)
             shingles = minhash.get_shingles(lorem_ipsum)
             sig = minhash.get_signature(shingles)
-            minhash2 = MinhashDedupSignature(output_folder=os.path.join(self.test_dir, "signatures2"), config=config)
+            minhash2 = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures2"), config=config)
             # check consistency
             assert sig == minhash2.get_signature(shingles)
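
The consistency check in this test (the same shingles must always yield the same signature) is the defining property of MinHash. A generic sketch of that property, not datatrove's MinhashDedupSignature:

# Generic MinHash sketch: one minimum per salted hash function over the shingle set.
import hashlib

def get_shingles(text: str, n: int = 3) -> set[str]:
    words = text.lower().split()
    return {" ".join(words[i : i + n]) for i in range(len(words) - n + 1)}

def get_signature(shingles: set[str], num_hashes: int = 16) -> list[int]:
    return [
        min(int.from_bytes(hashlib.sha256(f"{seed}:{s}".encode()).digest()[:8], "big") for s in shingles)
        for seed in range(num_hashes)
    ]

sig_a = get_signature(get_shingles("the quick brown fox jumps over the lazy dog"))
sig_b = get_signature(get_shingles("the quick brown fox jumps over the lazy dog"))
assert sig_a == sig_b  # identical shingle sets always produce identical signatures
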
5 changes: 2 additions & 3 deletions tests/pipeline/test_parquet_reader.py
@@ -6,7 +6,6 @@

 import pyarrow as pa
 import pyarrow.parquet as pq

-from datatrove.io import LocalInputDataFolder
 from datatrove.pipeline.readers.parquet import ParquetReader


@@ -37,11 +36,11 @@ def check_same_data(self, documents, check_metadata=True):
                 self.assertEqual(document.metadata[key], row[key])

     def test_read(self):
-        reader = ParquetReader(LocalInputDataFolder(self.tmp_dir))
+        reader = ParquetReader(self.tmp_dir)
         documents = list(reader.run())
         self.check_same_data(documents)

     def test_read_no_metadata(self):
-        reader = ParquetReader(LocalInputDataFolder(self.tmp_dir), read_metadata=False)
+        reader = ParquetReader(self.tmp_dir, read_metadata=False)
         documents = list(reader.run())
         self.check_same_data(documents, check_metadata=False)
12 changes: 6 additions & 6 deletions tests/pipeline/test_sentence_deduplication.py
@@ -133,20 +133,20 @@ def setUp(self):
         self.addCleanup(shutil.rmtree, self.tmp_dir)

     def test_sd(self):
-        signature_creation = SentenceDedupSignature(output_folder=self.test_dir)
-        find_duplicates = SentenceFindDedups(data_folder=self.test_dir, output_folder=self.test_dir)
-        dedup_filter = SentenceDedupFilter(data_folder=self.test_dir, min_doc_words=0)
+        signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir)
+        find_duplicates = SentenceFindDedups(data_folder=self.tmp_dir, output_folder=self.tmp_dir)
+        dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir, min_doc_words=0)

         signature_creation(data=DOCS)
         find_duplicates(data=[])
         for i, doc in enumerate(dedup_filter(data=copy.deepcopy(DOCS))):
             self.assertEqual(doc.content, TARGETS[i])

     def test_sd_worker(self):
-        signature_creation = SentenceDedupSignature(output_folder=self.test_dir)
+        signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir)

-        find_duplicates = SentenceFindDedups(data_folder=self.test_dir, output_folder=self.test_dir)
-        dedup_filter = SentenceDedupFilter(data_folder=self.test_dir, min_doc_words=0)
+        find_duplicates = SentenceFindDedups(data_folder=self.tmp_dir, output_folder=self.tmp_dir)
+        dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir, min_doc_words=0)

         signature_creation(data=DOCS, rank=0, world_size=2)
         signature_creation(data=DOCS_2, rank=1, world_size=2)
