From 9923b740d16633033efd38e15a8ac272b50ef1d7 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 22 Jun 2023 11:03:01 -0700 Subject: [PATCH] adding python content --- .flake8 | 25 ++ .github/workflows/CI.yml | 2 +- .gitignore | 1 - Cargo.toml | 2 + python/dolma/core/__init__.py | 11 + python/dolma/core/data_types.py | 205 +++++++++++++ python/dolma/core/ft_dataset.py | 243 +++++++++++++++ python/dolma/core/ft_tagger.py | 155 ++++++++++ python/dolma/core/parallel.py | 344 +++++++++++++++++++++ python/dolma/core/registry.py | 39 +++ python/dolma/core/runtime.py | 308 +++++++++++++++++++ python/dolma/core/taggers.py | 35 +++ python/dolma/core/trainer.py | 7 + python/dolma/core/utils.py | 43 +++ python/dolma/core/vizualizer.py | 252 ++++++++++++++++ python/dolma/data/naughty_words_en.txt | 403 +++++++++++++++++++++++++ python/dolma/taggers/__init__.py | 8 + python/dolma/taggers/c4.py | 87 ++++++ python/dolma/taggers/code.py | 199 ++++++++++++ python/dolma/taggers/gopher.py | 172 +++++++++++ python/dolma/taggers/jigsaw.py | 45 +++ python/dolma/taggers/language.py | 156 ++++++++++ python/dolma/taggers/length.py | 119 ++++++++ python/dolma/taggers/pii.py | 291 ++++++++++++++++++ python/dolma/taggers/sampling.py | 20 ++ 25 files changed, 3170 insertions(+), 2 deletions(-) create mode 100644 .flake8 create mode 100644 python/dolma/core/__init__.py create mode 100644 python/dolma/core/data_types.py create mode 100644 python/dolma/core/ft_dataset.py create mode 100644 python/dolma/core/ft_tagger.py create mode 100644 python/dolma/core/parallel.py create mode 100644 python/dolma/core/registry.py create mode 100644 python/dolma/core/runtime.py create mode 100644 python/dolma/core/taggers.py create mode 100644 python/dolma/core/trainer.py create mode 100644 python/dolma/core/utils.py create mode 100644 python/dolma/core/vizualizer.py create mode 100644 python/dolma/data/naughty_words_en.txt create mode 100644 python/dolma/taggers/__init__.py create mode 100644 python/dolma/taggers/c4.py create mode 100644 python/dolma/taggers/code.py create mode 100644 python/dolma/taggers/gopher.py create mode 100644 python/dolma/taggers/jigsaw.py create mode 100644 python/dolma/taggers/language.py create mode 100644 python/dolma/taggers/length.py create mode 100644 python/dolma/taggers/pii.py create mode 100644 python/dolma/taggers/sampling.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..b22f3a3a --- /dev/null +++ b/.flake8 @@ -0,0 +1,25 @@ +[flake8] +max-line-length = 115 + +ignore = + # these rules don't play well with black + # whitespace before : + E203 + # line break before binary operator + W503 + # line too long, who cares? 
+ E501 + +exclude = + .venv + .git + __pycache__ + docs/build + dist + .mypy_cache + pretrain_data + +per-file-ignores = + # __init__.py files are allowed to have unused imports and lines-too-long + */__init__.py:F401,F403 + */**/**/__init__.py:F401,E501 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8268eaf6..d24f6c52 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] + target: [x86_64, x86, aarch64, armv7, s390x] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 diff --git a/.gitignore b/.gitignore index 4cb4e066..0e8d0042 100644 --- a/.gitignore +++ b/.gitignore @@ -53,7 +53,6 @@ site/ /runs/ /wandb/ /scratch/ -core # compiled rust extension target/ diff --git a/Cargo.toml b/Cargo.toml index fbd1bb19..5d9604fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,8 @@ name = "dolma" version = "0.6.0" edition = "2021" +license = "Apache-2.0" + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] diff --git a/python/dolma/core/__init__.py b/python/dolma/core/__init__.py new file mode 100644 index 00000000..9219cd50 --- /dev/null +++ b/python/dolma/core/__init__.py @@ -0,0 +1,11 @@ +from .data_types import DocResult, Document, Span +from .registry import TaggerRegistry +from .taggers import BaseTagger + +__all__ = [ + "BaseTagger", + "DocResult", + "Document", + "Span", + "TaggerRegistry", +] diff --git a/python/dolma/core/data_types.py b/python/dolma/core/data_types.py new file mode 100644 index 00000000..00a5f6a1 --- /dev/null +++ b/python/dolma/core/data_types.py @@ -0,0 +1,205 @@ +""" + +Data types assumed by Filters. + +@kylel, @soldni + +""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +from msgspec import Struct + + +class Ai2LlmFilterError(Exception): + """Base class for all errors""" + + +class Ai2LlmFatalError(Ai2LlmFilterError): + """Fatal error. 
Abort the entire process""" + + +class Ai2LlmShardError(Ai2LlmFilterError): + """Fail the shard and continue""" + + +class Ai2LlmRetryableFailure(Ai2LlmFilterError): + """Retry if a shard throws this error""" + + +class InputSpec(Struct): + id: str + text: str + source: str + version: Optional[str] = None + + +class OutputSpec(Struct): + id: str + attributes: Dict[str, List[List[Union[int, float]]]] + source: Optional[str] = None + + +class Document: + __slots__ = "source", "version", "id", "text" + + def __init__(self, source: str, id: str, text: str, version: Optional[str] = None) -> None: + self.source = source + self.version = version + self.id = id + self.text = text + + @classmethod + def from_spec(cls, spec: InputSpec) -> "Document": + return Document(source=spec.source, version=spec.version, id=spec.id, text=spec.text) + + def to_spec(self) -> InputSpec: + return InputSpec(source=self.source, version=self.version, id=self.id, text=self.text) + + @classmethod + def from_json(cls, d: Dict) -> "Document": + return Document(source=d["source"], version=d["version"], id=d["id"], text=d["text"]) + + def to_json(self) -> Dict: + return {"source": self.source, "version": self.version, "id": self.id, "text": self.text} + + def __str__(self) -> str: + return ( + str(self.__class__.__name__) + + f"(source={repr(self.source)},version={repr(self.version)},id={repr(self.id)},text={repr(self.text)})" + ) + + +class Span: + __slots__ = "start", "end", "type", "score", "experiment", "tagger" + + def __init__( + self, + start: int, + end: int, + type: str, + score: float = 1.0, + experiment: Optional[str] = None, + tagger: Optional[str] = None, + ): + self.start = start + self.end = end + self.type = type + self.score = float(score) + self.experiment = experiment + self.tagger = tagger + + def mention(self, text: str, window: int = 0) -> str: + return text[max(0, self.start - window) : min(len(text), self.end + window)] + + def select(self, doc: Document) -> str: + return doc.text[self.start : self.end] + + @classmethod + def from_spec(cls, attribute_name: str, attribute_value: List[Union[int, float]]) -> "Span": + if "__" in attribute_name: + # bff tagger has different name + exp_name, tgr_name, attr_type = attribute_name.split("__", 2) + else: + exp_name = tgr_name = attr_type = attribute_name + + start, end, score = attribute_value + return Span( + start=int(start), + end=int(end), + type=attr_type, + score=float(score), + experiment=exp_name, + tagger=tgr_name, + ) + + def to_spec(self) -> Tuple[str, List[Union[int, float]]]: + assert self.experiment is not None, "Experiment name must be set to convert to spec" + assert self.tagger is not None, "Tagger name must be set to convert to spec" + return ( + f"{self.experiment}__{self.tagger}__{self.type}", + [self.start, self.end, self.score], + ) + + def __len__(self) -> int: + return self.end - self.start + + @classmethod + def from_json(cls, di: Dict) -> "Span": + return Span(start=di["start"], end=di["end"], type=di["type"], score=di["score"]) + + def to_json(self, text: Optional[str] = None, window: int = 0) -> dict: + span_repr = {"start": self.start, "end": self.end, "type": self.type, "score": self.score} + if text is not None: + span_repr["mention"] = self.mention(text=text, window=window) + return span_repr + + def __str__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}(start={self.start},end={self.end},type={repr(self.type)},score={self.score:.5f})" + + +class DocResult: + __slots__ = "doc", "spans" + + def 
__init__(self, doc: Document, spans: List[Span]) -> None: + self.doc = doc + self.spans = spans + + @classmethod + def from_spec(cls, doc: InputSpec, *attrs_groups: OutputSpec) -> "DocResult": + spans: List[Span] = [] + for attrs in attrs_groups: + assert doc.id == attrs.id, f"doc.id={doc.id} != attrs.id={attrs.id}" + spans.extend( + [ + Span.from_spec(attribute_name=attr_name, attribute_value=attr_value) + for attr_name, attr_values in attrs.attributes.items() + for attr_value in attr_values + ] + ) + return DocResult(doc=Document.from_spec(doc), spans=spans) + + def to_spec(self) -> Tuple[InputSpec, OutputSpec]: + doc_spec = self.doc.to_spec() + attributes: Dict[str, List[List[Union[int, float]]]] = {} + + for span in self.spans: + attr_name, attr_value = span.to_spec() + attributes.setdefault(attr_name, []).append(attr_value) + + return doc_spec, OutputSpec(source=self.doc.source, id=self.doc.id, attributes=attributes) + + @classmethod + def from_json(cls, d: Dict[str, Any]) -> "DocResult": + return DocResult( + doc=Document.from_json(d["doc"]), + spans=[Span.from_json(span) for span in d["spans"]], + ) + + def to_json(self, with_doc: bool = False, window: int = 0) -> Dict[str, Any]: + d: Dict[str, Any] = {"spans": [span.to_json(text=self.doc.text, window=window) for span in self.spans]} + if with_doc: + d["doc"] = self.doc.to_json() + return d + + def __str__(self) -> str: + return f"{self.__class__.__name__}(doc={self.doc},spans=[{','.join(str(s) for s in self.spans)}])" + + +class TextSlice: + """A slice of text from a document.""" + + __slots__ = "doc", "start", "end" + + def __init__(self, doc: str, start: int, end: int): + self.doc = doc + self.start = start + self.end = end + + @property + def text(self) -> str: + return self.doc[self.start : self.end] + + def __str__(self) -> str: + return f"{self.__class__.__name__}(text={repr(self.text)},start={self.start},end={self.end})" diff --git a/python/dolma/core/ft_dataset.py b/python/dolma/core/ft_dataset.py new file mode 100644 index 00000000..43385948 --- /dev/null +++ b/python/dolma/core/ft_dataset.py @@ -0,0 +1,243 @@ +""" + +Builds a dataset for training a FastText model from 2 or more pretraining datasets. 
+
+@rauthur
+
+"""
+
+import argparse
+import gzip
+import json
+import os
+import random
+from dataclasses import dataclass
+from functools import partial
+from multiprocessing import Event, Manager, Pool, Process, Queue
+from typing import Generator, Optional
+
+from smashed.utils.io_utils import (
+    open_file_for_read,
+    open_file_for_write,
+    recursively_list_files,
+)
+
+from .data_types import TextSlice
+from .ft_tagger import BaseFastTextTagger
+from .utils import split_paragraphs, split_sentences
+
+_WRITER_EXIT_MSG = "__WRITE__EXIT__"
+
+
+@dataclass
+class Config:
+    target_path: str
+    sample_paths: list[str]
+    out_path: str
+    mode: str
+    newlines: str
+    n_proc: int
+    n_segments: Optional[int]
+    pos_label: str
+    neg_label: str
+
+
+def gzip_open(file, mode, **open_kwargs):
+    return gzip.open(filename=file, mode=mode, **open_kwargs)
+
+
+def _split(text: str, config: Config) -> Generator[TextSlice, None, None]:
+    if config.mode == BaseFastTextTagger.SENTENCE_LEVEL_TAGGER:
+        for sentence in split_sentences(text):
+            yield sentence
+
+    elif config.mode == BaseFastTextTagger.PARAGRAPH_LEVEL_TAGGER:
+        for paragraph in split_paragraphs(text):
+            yield paragraph
+
+    elif config.mode == BaseFastTextTagger.DOCUMENT_LEVEL_TAGGER:
+        yield TextSlice(doc=text, start=0, end=len(text))
+
+    else:
+        raise RuntimeError(f"Unknown data split mode: {config.mode}")
+
+
+@dataclass
+class ReadResult:
+    examples: list[str]
+
+
+def process_file(config: Config, q: Queue, flag: Event, label: str, fn):
+    # Check the global exit flag and stop before processing the file
+    if flag.is_set():
+        return
+
+    print(f"Processing {fn}")
+
+    with open_file_for_read(fn, "rb", open_fn=gzip_open) as f:
+        for line in f:
+            # Abort partway through processing this file if the flag is set
+            if flag.is_set():
+                return
+
+            # Expected JSONL format following OLMo data spec
+            data = json.loads(line.decode("utf-8"))
+            line_text = data["text"]
+
+            if len(line_text) == 0:
+                continue
+
+            for slice in _split(line_text, config):
+                final_text = slice.text
+
+                if "\n" in final_text:
+                    if config.newlines == "replace":
+                        final_text = final_text.replace("\n", " ")
+                    elif config.newlines == "skip":
+                        continue
+
+                q.put(f"__label__{label} {final_text}")
+
+
+def write_results(config: Config, q: Queue, flag: Event):
+    written = 0
+
+    with open_file_for_write(config.out_path) as o:
+        while True:
+            msg = q.get()
+
+            if msg == _WRITER_EXIT_MSG:
+                break
+
+            if not flag.is_set():
+                o.write(msg)
+                o.write("\n")
+                written += 1
+
+                if config.n_segments is not None and written >= config.n_segments:
+                    flag.set()
+                    written = 0
+
+
+def process_paths(paths: list[str], config: Config, q: Queue, flag: Event, label: str):
+    fns = [fn for p in paths for fn in recursively_list_files(p)]
+    random.shuffle(fns)
+
+    work_fn = partial(process_file, config, q, flag, label)
+
+    with Pool(processes=max(1, config.n_proc - 1)) as pool:
+        pool.map(work_fn, fns)
+        pool.close()
+        pool.join()
+
+
+def main(config: Config):
+    random.seed(117)
+
+    with Manager() as manager:
+        q = manager.Queue()
+        flag = manager.Event()
+
+        writer = Process(target=write_results, args=(config, q, flag))
+        writer.start()
+
+        # Generate expected number of positive examples
+        process_paths([config.target_path], config, q, flag, config.pos_label)
+
+        # Reset early exit flag as all positive examples processed
+        flag.clear()
+
+        # Generate expected number of negative examples
+        process_paths(config.sample_paths, config, q, flag, config.neg_label)
+
+        q.put(_WRITER_EXIT_MSG)
+        writer.join()
+
+
+if __name__ == "__main__":
parser = argparse.ArgumentParser() + parser.add_argument( + "-t", + "--target", + required=True, + type=str, + help="Local or remote path including OLMo formatted TARGET dataset", + ) + parser.add_argument( + "-s", + "--sample", + required=True, + type=str, + nargs="+", + help="Sample these paths to create negative examples", + ) + parser.add_argument( + "-p", + "--processes", + type=int, + required=False, + default=os.cpu_count(), + help="Number of processes to launch", + ) + parser.add_argument( + "-m", + "--mode", + type=str, + required=True, + choices=[ + BaseFastTextTagger.SENTENCE_LEVEL_TAGGER, + BaseFastTextTagger.PARAGRAPH_LEVEL_TAGGER, + BaseFastTextTagger.DOCUMENT_LEVEL_TAGGER, + ], + help="Output examples at this level", + ) + parser.add_argument( + "-o", + "--output-filename", + type=str, + required=True, + help="Path to write the processed result (can be on S3)", + ) + parser.add_argument( + "--n-segments", + type=int, + required=False, + help="Stop after generating this many segments (e.g., sentences)", + ) + parser.add_argument( + "--newlines", + type=str, + required=False, + choices=["skip", "keep", "replace"], + default="skip", + help="Skip, keep or replace with ' ' examples with newlines after splitting", + ) + parser.add_argument( + "--pos-label", + type=str, + required=False, + default="pos", + help="Use this class label for positive instances", + ) + parser.add_argument( + "--neg-label", + type=str, + required=False, + default="neg", + help="Use this class label for negative instances", + ) + + args = parser.parse_args() + config = Config( + target_path=args.target, + sample_paths=args.sample, + out_path=args.output_filename, + mode=args.mode, + newlines=args.newlines, + n_proc=args.processes, + n_segments=args.n_segments, + pos_label=args.pos_label, + neg_label=args.neg_label, + ) + + main(config) diff --git a/python/dolma/core/ft_tagger.py b/python/dolma/core/ft_tagger.py new file mode 100644 index 00000000..d127438a --- /dev/null +++ b/python/dolma/core/ft_tagger.py @@ -0,0 +1,155 @@ +""" + +Base implementation for a fasttext tagger; all fasttext taggers should inherit from this class. + +@kylel, @soldni + +""" +import os +from tempfile import NamedTemporaryFile +from typing import Iterable, Literal, NamedTuple, Optional + +from cached_path import cached_path +from fasttext import train_supervised +from fasttext.FastText import _FastText +from smashed.utils.io_utils import open_file_for_write + +from .data_types import DocResult, Document, Span, TextSlice +from .taggers import BaseTagger +from .utils import split_paragraphs, split_sentences + + +class Prediction(NamedTuple): + label: str + score: float + + +class BaseFastTextTagger(BaseTagger): + SENTENCE_LEVEL_TAGGER = "sentence" + PARAGRAPH_LEVEL_TAGGER = "paragraph" + DOCUMENT_LEVEL_TAGGER = "document" + + def __init__(self, model_path: str, model_mode: str) -> None: + # we use this private attribute to avoid a warning from the fasttext library. 
See this comment: + # https://github.com/facebookresearch/fastText/issues/1056#issuecomment-1278058705 + self.classifier = _FastText(str(cached_path(model_path))) + self.mode = model_mode + + @classmethod + def train( + cls, + train_file: str, + save_path: str, + learning_rate: float = 0.1, + word_vectors_dim: int = 100, + context_window_size: int = 5, + max_epochs: int = 100, # non-default + min_word_count: int = 1, + min_label_count: int = 1, + min_char_ngram: int = 0, + max_char_ngram: int = 0, + num_negative_samples: int = 5, + max_word_ngram: int = 2, # non-default + loss_function: Literal["ns", "hs", "softmax", "ova"] = "softmax", + num_buckets: int = 2_000_000, + num_threads: int = 0, + learning_rate_update_rate: int = 100, + sampling_threshold: float = 0.0001, + label_prefix: str = "__label__", + verbose: int = 2, + pretrained_vectors: Optional[str] = None, + ) -> _FastText: + # download potentially remote files + local_train_file = cached_path(train_file) + local_pretrained_vectors = cached_path(pretrained_vectors) if pretrained_vectors else None + + # base checks on file format + with open(local_train_file, "r") as f: + # check a few lines to see if the file is in the right format + i = 0 + for ln in f: + if label_prefix not in ln: + raise ValueError(f"{train_file} not the fasttext format, no labels found!") + if (i := i + 1) > 5: + break + if i == 0: + raise ValueError(f"{train_file} is empty!") + + # train the fasttext model + classifier = train_supervised( + input=local_train_file, + lr=learning_rate, + dim=word_vectors_dim, + ws=context_window_size, + epoch=max_epochs, + minCount=min_word_count, + minCountLabel=min_label_count, + minn=min_char_ngram, + maxn=max_char_ngram, + neg=num_negative_samples, + wordNgrams=max_word_ngram, + loss=loss_function, + bucket=num_buckets, + thread=num_threads, + lrUpdateRate=learning_rate_update_rate, + t=sampling_threshold, + label=label_prefix, + verbose=verbose, + pretrainedVectors=local_pretrained_vectors, + ) + + local_save_path = None + try: + # create a local temp file where we save the model + with NamedTemporaryFile("w", delete=False) as f: + local_save_path = f.name + + # save the model + classifier.save_model(local_save_path) + + # upload to remote if save_path is s3 path + with open_file_for_write(save_path, "wb") as fo, open(local_save_path, "rb") as fi: + fo.write(fi.read()) + finally: + # regardless to what happened, remove the local temp file if it + # exists + if local_save_path is not None and os.path.exists(local_save_path): + os.remove(local_save_path) + + return classifier + + @classmethod + def test( + cls, + test_file: str, + model_path: Optional[str] = None, + classifier: Optional[_FastText] = None, + ): + # load the model if one is not provided + if classifier is None: + assert model_path is not None, "Please provide either a model path or a model" + classifier = _FastText(str(cached_path(model_path))) + + local_test_file = cached_path(test_file) + model_performance = classifier.test(local_test_file) + print(model_performance) + + def predict(self, doc: Document) -> DocResult: + if self.mode == self.SENTENCE_LEVEL_TAGGER: + units = split_sentences(doc.text) + elif self.mode == self.PARAGRAPH_LEVEL_TAGGER: + units = split_paragraphs(doc.text) + elif self.mode == self.DOCUMENT_LEVEL_TAGGER: + units = [TextSlice(doc=doc.text, start=0, end=len(doc.text))] + else: + raise ValueError(f"Unknown mode {self.mode}") + + spans = [] + for unit in units: + for prediction in self.predict_slice(unit): + 
spans.append(Span(start=unit.start, end=unit.end, type=prediction.label, score=prediction.score)) + + return DocResult(doc=doc, spans=spans) + + def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: + raise NotImplementedError("Please implement the predict slice method") diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py new file mode 100644 index 00000000..ac32c187 --- /dev/null +++ b/python/dolma/core/parallel.py @@ -0,0 +1,344 @@ +import inspect +import multiprocessing +import pickle +import random +import re +import time +from contextlib import ExitStack +from datetime import datetime +from functools import partial +from queue import Queue +from threading import Thread +from typing import Any, Dict, List, Optional, Tuple, Union + +import tqdm +from smashed.utils.io_utils import ( + MultiPath, + open_file_for_write, + recursively_list_files, +) + +from .data_types import Ai2LlmFilterError, Ai2LlmRetryableFailure + +METADATA_SUFFIX = ".done.txt" + + +class BaseParallelProcessor: + """A base parallel processor that supports applying the same process_single method to a list of files. + + This class is meant to be subclassed. The subclass must implement: + - `process_single` method, which takes a source path file to transform, and a destination path where + to save the transformed file. + - `increment_progressbar` method, which defines which units to keep track of in the progress bar. + + See documentation of both methods for more details on how to implement them correctly. + """ + + def __init__( + self, + source_prefix: str, + destination_prefix: str, + metadata_prefix: str, + num_processes: int = 1, + debug: bool = False, + seed: int = 0, + pbar_timeout: float = 0.01, + ignore_existing: bool = False, + include_paths: Optional[List[str]] = None, + exclude_paths: Optional[List[str]] = None, + files_regex_pattern: Optional[str] = None, + ): + """Initialize the parallel processor. + + Args: + source_prefix (str): The location where source files are stored. This can be a local directory or a + prefix to an S3 location. + destination_prefix (str): The location where to save the transformed files. This can be a local + directory or a prefix to an S3 location. Local directories will be created if they do not exist. + The directory structure from the source prefix will be replicated in the destination prefix; + file names will also be the same. + metadata_prefix (str): The prefix of the metadata files to save. This can be a local path or an + S3 path. Metadata output will be created for each file after it is processed. Filenames are + checked to verify if a file has been processed and can be skipped unless `ignore_existing` is + set to true. + num_processes (int, optional): The number of processes to use. Defaults to 1. + debug (bool, optional): Whether to run in debug mode; if true, no multiprocessing will be used. + Defaults to False. + seed (int, optional): The random seed to use when shuffling input files. Defaults to 0. + pbar_timeout (float, optional): How often to update progress bars in seconds. + Defaults to 0.01 seconds. + ignore_existing (bool, optional): Whether to ignore files that have been already processed and + re-run the processor on all files from scratch. Defaults to False. + include_paths (Optional[List[str]], optional): A list of paths to include. If provided, only files + that match one of the paths will be processed. Defaults to None. + exclude_paths (Optional[List[str]], optional): A list of paths to exclude. 
If provided, files that + match one of the paths will be skipped. Defaults to None. + """ + + self.source_prefix = MultiPath.parse(source_prefix) + self.destination_prefix = MultiPath.parse(destination_prefix) + self.metadata_prefix = MultiPath.parse(metadata_prefix) + self.num_processes = num_processes + self.debug = debug + self.seed = seed + self.pbar_timeout = pbar_timeout + self.ignore_existing = ignore_existing + + self.include_paths = set(include_paths) if include_paths is not None else None + self.exclude_paths = set(exclude_paths) if exclude_paths is not None else None + self.files_regex_pattern = re.compile(files_regex_pattern) if files_regex_pattern else None + + # checking that the increment_progressbar method is subclassed + # correctly + sig = inspect.signature(self.increment_progressbar) + if "queue" not in sig.parameters or sig.parameters["queue"].kind != inspect.Parameter.POSITIONAL_ONLY: + raise AttributeError( + "increment_progressbar must have a positional-only argument named 'queue'; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + if "kwargs" in sig.parameters and sig.parameters["kwargs"].kind == inspect.Parameter.VAR_KEYWORD: + raise AttributeError( + "increment_progressbar must not have a **kwargs argument; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + if any(p.name != "queue" and p.default != 0 for p in sig.parameters.values()): + raise AttributeError( + "increment_progressbar must have a default value of 0 for all arguments except 'queue'; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + + @classmethod + def process_single( + cls, + source_path: str, + destination_path: str, + queue: "Queue[Union[None, Tuple[int, ...]]]", + **kwargs: Any, + ): + """Process a single file. + + This method must be implemented by the subclass. It takes a source path file to transform, and a + destination path where to save the transformed file. It also takes a queue to increment the progress + bars. The queue should be passed to the `increment_progressbar` method. + + Args: + source_path (str): The path to the source file to transform. Can be an S3 path or a local path. + destination_path (str): The path to the destination file to save. Can be an S3 path or a local path. + queue (Queue[Union[None, Tuple[int, ...]]]): The queue to increment the progress bars. + """ + raise NotImplementedError() + + @classmethod + def _process_single_and_save_status( + cls, + source_path: str, + destination_path: str, + metadata_path: str, + queue: "Queue[Union[None, Tuple[int, ...]]]", + serialized_kwargs: bytes, + ): + """A wrapper around process single that saves a metadata file if processing is successful.""" + + kwargs = pickle.loads(serialized_kwargs) + tries_remaining = kwargs.get("retry_on_read_error", 0) + 1 + while True: + try: + cls.process_single( + source_path=source_path, destination_path=destination_path, queue=queue, **kwargs + ) + break + except Ai2LlmRetryableFailure as e: + tries_remaining -= 1 + if tries_remaining == 0: + raise Ai2LlmFilterError from e + with open_file_for_write(metadata_path) as f: + f.write(datetime.now().isoformat()) + + @classmethod + def increment_progressbar( + cls, queue: "Queue[Union[None, Tuple[int, ...]]]", /, **kwargs: int + ) -> Dict[str, int]: + """Increment the progress bar by putting a tuple in the queue. + + When subclassing, we recommend defining which units to keep track of in the progress bar by + defining keyword arguments. 
Then you can call the base class via `super()` and pass the keyword. + Example: + + ```python + class MyProcessor(BaseParallelProcessor): + def increment_progressbar(self, queue, /, files = 0, documents = 0): # we use two progress bars + return super().increment_progressbar(queue, files=files, documents=documents) + ``` + """ + queue.put(tuple(kwargs.get(k, 0) for k in kwargs)) + return kwargs + + @classmethod + def _run_threaded_progressbar( + cls, + queue: "Queue[Union[None, Tuple[int, ...]]]", + timeout: float, + ): + """Run a progress bar in a separate thread. + + Args: + queue (Queue[Union[None, Tuple[int, ...]]]): The queue to increment the progress bars. + timeout (float): How often to update the progress bars in seconds. + """ + + sample_queue_output = cls.increment_progressbar(queue) + + with ExitStack() as stack: + pbars = [ + stack.enter_context( + tqdm.tqdm(desc=str(k), unit=str(k)[:1], position=i, unit_scale=True) # pyright: ignore + ) + for i, k in enumerate(sample_queue_output) + ] + + while True: + item = queue.get() + if item is None: + break + + for pbar, value in zip(pbars, item): + pbar.update(value) + + time.sleep(timeout) + + def _debug_run_all( + self, + all_source_paths: List[MultiPath], + all_destination_paths: List[MultiPath], + all_metadata_paths: List[MultiPath], + **process_single_kwargs: Any, + ): + """Run files one by one on the main process + + Args: + all_source_paths (List[MultiPath]): The list of source paths to process. + all_destination_paths (List[MultiPath]): The list of destination paths to save. + all_metadata_paths (List[MultiPath]): The locations where to save metadata. + """ + + it = zip(all_source_paths, all_destination_paths, all_metadata_paths) + pbar_queue: "Queue[Union[None, Tuple[int, ...]]]" = Queue() + thread = Thread(target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True) + thread.start() + + for source_prefix, destination_prefix, metadata_prefix in it: + self._process_single_and_save_status( + source_path=source_prefix.as_str, + destination_path=destination_prefix.as_str, + metadata_path=metadata_prefix.as_str, + queue=pbar_queue, + serialized_kwargs=pickle.dumps(process_single_kwargs), + ) + + pbar_queue.put(None) + thread.join() + + def _multiprocessing_run_all( + self, + all_source_paths: List[MultiPath], + all_destination_paths: List[MultiPath], + all_metadata_paths: List[MultiPath], + **process_single_kwargs: Any, + ): + """Run files in parallel using multiprocessing. + + Args: + all_source_paths (List[MultiPath]): The list of source paths to process. + all_destination_paths (List[MultiPath]): The list of destination paths to save. + all_metadata_paths (List[MultiPath]): The locations where to save metadata. 
+ """ + try: + multiprocessing.set_start_method("spawn") + except RuntimeError: + assert multiprocessing.get_start_method() == "spawn", "Multiprocessing start method must be spawn" + + with multiprocessing.Pool(processes=self.num_processes) as pool: + pbar_queue: "Queue[Union[None, Tuple[int, ...]]]" = (manager := multiprocessing.Manager()).Queue() + thread = Thread( + target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True + ) + thread.start() + + process_single_fn = partial(self.process_single, queue=pbar_queue) + results = [] + + for s, d, m in zip(all_source_paths, all_destination_paths, all_metadata_paths): + process_single_fn = partial( + self._process_single_and_save_status, + queue=pbar_queue, + source_path=s.as_str, + destination_path=d.as_str, + metadata_path=m.as_str, + serialized_kwargs=pickle.dumps(process_single_kwargs), + ) + result = pool.apply_async(process_single_fn) + results.append(result) + + for result in results: + result.get() + + pool.close() + pool.join() + + pbar_queue.put(None) + thread.join() + manager.shutdown() + + def _valid_path(self, path: str) -> bool: + if self.include_paths is not None and path not in self.include_paths: + return False + if self.exclude_paths is not None and path in self.exclude_paths: + return False + if self.files_regex_pattern is not None and not self.files_regex_pattern.search(path): + return False + return True + + def _get_all_paths(self) -> Tuple[List[MultiPath], List[MultiPath], List[MultiPath]]: + """Get all paths to process using prefixes provided""" + all_source_paths, all_destination_paths, all_metadata_paths = [], [], [] + + existing_metadata_names = set( + (MultiPath.parse(path) - self.metadata_prefix).as_str.rstrip(METADATA_SUFFIX) + for path in recursively_list_files(self.metadata_prefix) + ) + + paths = list(recursively_list_files(self.source_prefix)) + random.shuffle(paths) + + for path in paths: + source_path = MultiPath.parse(path) + if not self.ignore_existing and (source_path - self.source_prefix).as_str in existing_metadata_names: + continue + + if not self._valid_path(source_path.as_str): + continue + + all_source_paths.append(source_path) + all_destination_paths.append(self.destination_prefix / (source_path - self.source_prefix)) + + metadata_path = MultiPath.parse(source_path.as_str + METADATA_SUFFIX) + all_metadata_paths.append(self.metadata_prefix / (metadata_path - self.source_prefix)) + + return all_source_paths, all_destination_paths, all_metadata_paths + + def __call__(self, **process_single_kwargs: Any): + """Run the processor.""" + random.seed(self.seed) + + all_source_paths, all_destination_paths, all_metadata_paths = self._get_all_paths() + + print(f"Found {len(all_source_paths):,} files to process") + + fn = self._debug_run_all if self.debug else self._multiprocessing_run_all + + fn( + all_source_paths=all_source_paths, + all_destination_paths=all_destination_paths, + all_metadata_paths=all_metadata_paths, + **process_single_kwargs, + ) diff --git a/python/dolma/core/registry.py b/python/dolma/core/registry.py new file mode 100644 index 00000000..9e638810 --- /dev/null +++ b/python/dolma/core/registry.py @@ -0,0 +1,39 @@ +from typing import Callable, Dict, Generator, Tuple, Type, TypeVar + +from .taggers import BaseTagger + +T = TypeVar("T", bound=BaseTagger) + + +class TaggerRegistry: + __taggers: Dict[str, Type[BaseTagger]] = {} + + @classmethod + def taggers(cls) -> Generator[Tuple[str, Type[BaseTagger]], None, None]: + yield from cls.__taggers.items() + + 
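+    # Illustrative usage sketch (the tagger name "my_tagger_v1" and class MyTagger are
+    # hypothetical; the decorator and lookup below are defined in this class):
+    #
+    #     @TaggerRegistry.add("my_tagger_v1")
+    #     class MyTagger(BaseTagger):
+    #         def predict(self, doc: Document) -> DocResult:
+    #             return DocResult(doc=doc, spans=[])
+    #
+    #     tagger_cls = TaggerRegistry.get("my_tagger_v1")
+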
@classmethod + def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: + def _add( + tagger_cls: Type[T], + tagger_name: str = name, + taggers_dict: Dict[str, Type[BaseTagger]] = cls.__taggers, + ) -> Type[T]: + if tagger_name in taggers_dict and taggers_dict[tagger_name] != tagger_cls: + if tagger_cls.__module__ == "__main__": + return tagger_cls + + raise ValueError(f"Tagger {tagger_name} already exists") + + taggers_dict[tagger_name] = tagger_cls + return tagger_cls + + return _add + + @classmethod + def get(cls, name: str) -> Type[BaseTagger]: + if name not in cls.__taggers: + raise ValueError( + f"Unknown tagger {name}; available taggers: " + ", ".join([tn for tn, _ in cls.taggers()]) + ) + return cls.__taggers[name] diff --git a/python/dolma/core/runtime.py b/python/dolma/core/runtime.py new file mode 100644 index 00000000..edb42451 --- /dev/null +++ b/python/dolma/core/runtime.py @@ -0,0 +1,308 @@ +import argparse +import gzip +import logging +import multiprocessing +import os +import tempfile +from contextlib import ExitStack +from queue import Queue +from typing import Dict + +import msgspec +from cached_path import cached_path +from smashed.utils.io_utils import ( + compress_stream, + decompress_stream, + open_file_for_write, + stream_file_for_read, +) + +from .data_types import ( + Ai2LlmFatalError, + Ai2LlmRetryableFailure, + Ai2LlmShardError, + InputSpec, + OutputSpec, +) +from .parallel import BaseParallelProcessor +from .registry import TaggerRegistry +from .utils import make_variable_name + + +class TaggerProcessor(BaseParallelProcessor): + BASE_S3_PREFIX = "s3://ai2-llm/pretraining-data/sources" + + @classmethod + def get_logger(cls) -> logging.Logger: + return logging.getLogger(cls.__name__) + + @classmethod + def increment_progressbar( # type: ignore + cls, + queue, # queue must be the first argument, and it should be a positional-only argument + /, + files: int = 0, + documents: int = 0, + ) -> Dict[str, int]: + """We override this method to specify which units we want to keep track of in a progress bar. + Specifically, we keep track of files and documents in this example. Their default value must be zero.""" + + # we call the super method to increment the progress bar + return super().increment_progressbar(queue, files=files, documents=documents) + + @classmethod + def process_single( + cls, + source_path: str, + destination_path: str, + queue: "Queue", + **kwargs, + ): + """Lets count run the taggers! We will use the destination path to save each tagger output.""" + + # get names of taggers + taggers_names = kwargs.get("taggers_names", None) + if taggers_names is None: + raise RuntimeError("Taggers not in kwargs, this is a bug! Please report it.") + elif not isinstance(taggers_names, list) or not all(isinstance(t, str) for t in taggers_names): + raise RuntimeError("Taggers are in the wrong format, this is a bug! Please report it.") + taggers = {make_variable_name(t): TaggerRegistry.get(t)() for t in taggers_names} + + # get name of experiment + experiment_name = kwargs.get("experiment_name", None) + if experiment_name is None: + raise RuntimeError("Experiment name not in kwargs, this is a bug! 
Please report it.") + experiment_name = make_variable_name(experiment_name) + + # skip on failure + skip_on_failure = kwargs.get("skip_on_failure", False) + + # local read cache + local_read_cache = kwargs.get("local_read_cache", None) + + # interval at which to update the progress bar; will double if it gets + # too full + update_interval = 1 + + # running document count; gets reset every time we update the progress + # bar + docs_cnt = 0 + + # creating dedicated encoders/decoders speeds up the process + encoder = msgspec.json.Encoder() + decoder = msgspec.json.Decoder(InputSpec) + + # this will be used to cache the file locally if needed + caching_path = source_path + + with ExitStack() as stack: + try: + # open each file for reading and writing. We use open_file_for_read to handle s3 paths and + # download the file locally if needed, while gzip.open is used to + # read and write gzipped files. + + if local_read_cache is not None: + caching_path = cached_path(source_path, cache_dir=local_read_cache) + in_stream = stack.enter_context(gzip.open(caching_path, mode="rt")) + else: + input_file = stack.enter_context(stream_file_for_read(source_path, mode="rb")) + in_stream = stack.enter_context(decompress_stream(input_file, mode="rt")) + + out_file = stack.enter_context(open_file_for_write(destination_path, "wb")) + out_stream = stack.enter_context(compress_stream(out_file, "wt")) + + for raw in in_stream: + # row = json.loads(raw) + row = decoder.decode(raw) + + # running the taggers and merging them flat + attributes = {} + for tagger_name, tagger in taggers.items(): + for key_name, key_value in tagger.tag(row).items(): + key_name = f"{experiment_name}__{tagger_name}__{make_variable_name(key_name)}" + attributes[key_name] = key_value + + # make output file + output = OutputSpec(source=row.source, id=row.id, attributes=attributes) + + # write the output to the output file + out_stream.write(encoder.encode(output).decode("utf-8") + "\n") # pyright: ignore + + # increment the number of documents processed so far + docs_cnt += 1 + + if docs_cnt % update_interval == 0: + # update the progress bar every 1000 documents to prevent + # buffering + cls.increment_progressbar(queue, documents=docs_cnt) + docs_cnt = 0 + + if queue.qsize() >= multiprocessing.cpu_count(): + # double the update interval if the queue is full + update_interval *= 2 + + except Exception as e: + # handle any exception that might have occurred + msg = f"Failed to process {source_path} due to {e.__class__.__name__}: {' '.join(e.args)}" + if e.__class__.__name__ == "IncompleteReadError": + # Intermittent error that occurs when reading from S3 + raise Ai2LlmRetryableFailure(msg) from e + else: + if skip_on_failure: + raise Ai2LlmShardError(msg) from e + else: + raise Ai2LlmFatalError(msg) from e + finally: + if caching_path != source_path and os.path.exists(caching_path): + os.remove(caching_path) + + # increment the files progress bar + cls.increment_progressbar(queue, files=1, documents=docs_cnt) + + @classmethod + def main(cls): + ap = argparse.ArgumentParser() + ap.add_argument( + "-d", + "--dataset", + default=None, + help=f"Dataset to process; this should be relative path from {TaggerProcessor.BASE_S3_PREFIX}.", + ) + ap.add_argument( + "-n", + "--experiment-name", + default=None, + help=( + "Name of for this sequence of taggers to be grouped under; " + "it could be 'experiment_n' or a more descriptive name." 
+ ), + ) + ap.add_argument( + "-t", + "--taggers", + default=[], + nargs="+", + help="One or more taggers to run; use -l to list available taggers.", + ) + ap.add_argument("-l", "--list-taggers", action="store_true", help="List available taggers.") + ap.add_argument("-p", "--parallel", type=int, default=1, help="Number of parallel processes to use.") + ap.add_argument( + "-u", "--debug", action="store_true", help="Run in debug mode; parallelism will be disabled." + ) + ap.add_argument( + "--skip-on-failure", + action="store_true", + help="Skip documents that fail to process instead of raising an error.", + ) + ap.add_argument( + "--retry-on-read-error", + default=0, + type=int, + help="Number of retries to attempt after an error from reading inputs.", + ) + ap.add_argument( + "--reuse-existing", + default=None, + type=str, + help="If provided, keeps track of which files have already been processed and skips them. ", + ) + ap.add_argument( + "--local-read-cache", + default=None, + type=str, + help="If provided, will cache the files locally before processing them.", + ) + ap.add_argument( + "--manually-included-paths", + default=None, + nargs="+", + help="If provided, only these paths will be processed. If points to an existing file, read s3:// URLs from it.", + ) + ap.add_argument( + "--files-regex-pattern", + default=None, + type=str, + help="If provided, only files matching this regex pattern will be processed.", + ) + ap.add_argument( + "--manually-excluded-paths", + default=None, + nargs="+", + help="If provided, these paths will be skipped. If points to an existing file, read s3:// URLs from it.", + ) + ap.add_argument( + "--safe-mode", action="store_true", help="Run in safe mode; will download locally before processing." + ) + opts = ap.parse_args() + + if opts.list_taggers: + print("\nAvailable Taggers:") + for tagger_name, tagger_cls in sorted(TaggerRegistry.taggers()): + print(f" {tagger_name} ({tagger_cls.__name__})") + print() + return + + assert opts.dataset is not None, "Dataset must be specified." + assert opts.experiment_name is not None, "Experiment name must be specified." + assert len(opts.taggers) > 0, "At least one tagger must be specified." 
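+        # Illustrative sketch of a single attributes row written by this run (the id,
+        # source, offsets, and experiment name are hypothetical; the key layout follows
+        # the experiment__tagger__type convention built in process_single, with values
+        # encoded as [start, end, score] per the OutputSpec format):
+        #
+        #     {"id": "doc-0", "attributes":
+        #         {"my_experiment__c4_v1__has_naughty_word": [[0, 1024, 1.0]]},
+        #      "source": "common-crawl"}
+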
+ + source_prefix = f"{cls.BASE_S3_PREFIX}/{opts.dataset}/documents" + destination_prefix = f"{cls.BASE_S3_PREFIX}/{opts.dataset}/attributes/{opts.experiment_name}" + + # use a local read cache if we are in safe mode or if a local read cache is provided + local_read_cache = opts.local_read_cache or (tempfile.gettempdir() if opts.safe_mode else None) + + with tempfile.TemporaryDirectory() as tempdir: + metadata_workdir = opts.reuse_existing or tempdir + ignore_existing = opts.reuse_existing is None + manually_included_paths = opts.manually_included_paths + if ( + manually_included_paths + and len(manually_included_paths) == 1 + and os.path.exists(manually_included_paths[0]) + ): + manually_included_paths = [l.strip() for l in open(manually_included_paths[0])] + manually_excluded_paths = opts.manually_excluded_paths + if ( + manually_excluded_paths + and len(manually_excluded_paths) == 1 + and os.path.exists(manually_excluded_paths[0]) + ): + manually_excluded_paths = [l.strip() for l in open(manually_excluded_paths[0])] + + msg = ( + "----- TaggerProcessor -----\n" + f"source: {source_prefix}\n" + f"destination: {destination_prefix}\n" + f"scratch: {tempdir}\n" + f"taggers: {', '.join(opts.taggers)}\n" + f"parallel: {opts.parallel}\n" + f"debug: {opts.debug}\n" + f"skip on fail: {opts.skip_on_failure}\n" + f"reuse prev: {not ignore_existing}\n" + f"workdir: {metadata_workdir}\n" + f"safe mode: {opts.safe_mode}\n" + f"local cache: {local_read_cache}\n" + f"file regex: {opts.files_regex_pattern}\n" + "---------------------------\n" + ) + print(msg) + + parallel_compute = cls( + source_prefix=source_prefix, + destination_prefix=destination_prefix, + metadata_prefix=metadata_workdir, + num_processes=opts.parallel, + ignore_existing=ignore_existing, + debug=opts.debug, + include_paths=opts.manually_included_paths, + exclude_paths=opts.manually_excluded_paths, + files_regex_pattern=opts.files_regex_pattern, + ) + parallel_compute( + taggers_names=opts.taggers, + experiment_name=opts.experiment_name, + skip_on_failure=opts.skip_on_failure, + local_read_cache=local_read_cache, + retry_on_read_error=opts.retry_on_read_error, + ) diff --git a/python/dolma/core/taggers.py b/python/dolma/core/taggers.py new file mode 100644 index 00000000..1790191a --- /dev/null +++ b/python/dolma/core/taggers.py @@ -0,0 +1,35 @@ +""" + +Filters. + +@kylel, @soldni + +""" +from abc import abstractmethod +from typing import Dict, List, Union + +from .data_types import DocResult, Document, InputSpec + + +class BaseTagger: + @classmethod + def train(cls, *args, **kwargs): + raise RuntimeError("This tagger does not support training") + + @classmethod + def test(cls, *args, **kwargs): + raise RuntimeError("This tagger does not support testing") + + @abstractmethod + def predict(self, doc: Document) -> DocResult: + raise NotImplementedError + + def tag(self, row: InputSpec) -> Dict[str, List[List[Union[int, float]]]]: + """Internal function that is used by the tagger to get data""" + doc = Document(source=row.source, version=row.version, id=row.id, text=row.text) + doc_result = self.predict(doc) + + tagger_output: Dict[str, list] = {} + for span in doc_result.spans: + tagger_output.setdefault(span.type, []).append([span.start, span.end, round(float(span.score), 5)]) + return tagger_output diff --git a/python/dolma/core/trainer.py b/python/dolma/core/trainer.py new file mode 100644 index 00000000..b1658be7 --- /dev/null +++ b/python/dolma/core/trainer.py @@ -0,0 +1,7 @@ +""" + +Code to train a Filter. 
+ +@kylel + +""" diff --git a/python/dolma/core/utils.py b/python/dolma/core/utils.py new file mode 100644 index 00000000..35278fc6 --- /dev/null +++ b/python/dolma/core/utils.py @@ -0,0 +1,43 @@ +import re +import string +from typing import List + +import blingfire + +from .data_types import TextSlice + + +def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str: + # use underscores for any non-valid characters in variable name + name = re.sub(r"[^a-zA-Z0-9_]", "_", name) + + if remove_multiple_underscores: + # replace multiple underscores with a single underscore + name = re.sub(r"__+", "_", name) + + if name[0] in string.digits: + raise ValueError(f"Invalid variable name {name}") + + return name + + +def split_paragraphs(text: str) -> List[TextSlice]: + """ + Split a string into paragraphs. A paragraph is defined as a sequence of zero or more characters, followed + by a newline character, or a sequence of one or more characters, followed by the end of the string. + """ + return [ + TextSlice(doc=text, start=match.start(), end=match.end()) + for match in re.finditer(r"([^\n]*\n|[^\n]+$)", text) + ] + + +def split_sentences(text: str) -> List[TextSlice]: + """ + Split a string into sentences. + """ + if text: + _, offsets = blingfire.text_to_sentences_and_offsets(text) + else: + offsets = [] + return [TextSlice(doc=text, start=start, end=end) for (start, end) in offsets] diff --git a/python/dolma/core/vizualizer.py b/python/dolma/core/vizualizer.py new file mode 100644 index 00000000..9afc00b6 --- /dev/null +++ b/python/dolma/core/vizualizer.py @@ -0,0 +1,252 @@ +import argparse +import json +import os +from contextlib import ExitStack +from typing import Dict, List, Optional + +import msgspec +import yaml +from smashed.utils.io_utils import ( + decompress_stream, + recursively_list_files, + stream_file_for_read, +) +from termcolor import colored + +from .data_types import DocResult, InputSpec, OutputSpec + + +class Visualizer: + BASE_S3_PREFIX = "s3://ai2-llm/pretraining-data/sources" + + def __init__( + self, + dataset: str, + experiment: Optional[str] = None, + tagger: Optional[str] = None, + type: Optional[str] = None, + ): + self.dataset = dataset + self.experiment = experiment + self.tagger = tagger + self.type = type + + def list_tags(self, path: str): + prefix, doc_path = path.split("/documents/") + + attrs_decoder = msgspec.json.Decoder(OutputSpec) + doc_decoder = msgspec.json.Decoder(InputSpec) + + with ExitStack() as stack: + doc_file = stack.enter_context(stream_file_for_read(path, "rb")) + doc_stream = stack.enter_context(decompress_stream(doc_file, "rt")) + exp_path = f"{prefix}/attributes/{self.experiment}/{doc_path}" + exp_file = stack.enter_context(stream_file_for_read(exp_path, "rb")) + exp_stream = stack.enter_context(decompress_stream(exp_file, "rt")) + + tags: Dict[str, List[str]] = {} + for doc_line, exp_line in zip(doc_stream, exp_stream): + # parse out data from the line + input_doc = doc_decoder.decode(doc_line) + input_exp = attrs_decoder.decode(exp_line) + doc_result = DocResult.from_spec(input_doc, input_exp) + + for span in doc_result.spans: + tags.setdefault(str(span.tagger), []).append(span.type) + + break + + print(colored(f"from {self.short_path(path)}:", color="yellow")) + for tagger, types in sorted(tags.items()): + print(colored(f"{tagger}:", color="magenta")) + for type in sorted(set(types)): + print(colored(f" {type}", color="cyan")) + print() + + def short_path(self, path: str, slack: int = 20) -> str: + return 
f"...{path[-s:]}" if (s := round(os.get_terminal_size().columns - slack)) < len(path) else path + + def visualize_single(self, path: str): + prefix, doc_path = path.split("/documents/") + + attrs_decoder = msgspec.json.Decoder(OutputSpec) + doc_decoder = msgspec.json.Decoder(InputSpec) + + with ExitStack() as stack: + doc_file = stack.enter_context(stream_file_for_read(path, "rb")) + doc_stream = stack.enter_context(decompress_stream(doc_file, "rt")) + exp_path = f"{prefix}/attributes/{self.experiment}/{doc_path}" + exp_file = stack.enter_context(stream_file_for_read(exp_path, "rb")) + exp_stream = stack.enter_context(decompress_stream(exp_file, "rt")) + + i = 0 + file_header = colored(f"file: {self.short_path(path)}\n", color="magenta") + + for doc_line, exp_line in zip(doc_stream, exp_stream): + # parse out data from the line + input_doc = doc_decoder.decode(doc_line) + input_exp = attrs_decoder.decode(exp_line) + doc_result = DocResult.from_spec(input_doc, input_exp) + + example_header = colored(f"example: {i:,}\n", color="yellow") + dt = doc_result.doc.text + + spans = sorted( + (s for s in doc_result.spans if s.tagger == self.tagger and s.type == self.type), + key=lambda s: s.start, + ) + if not spans: + continue + + prev_start = 0 + text_fragments = [] + for span in spans: + text_fragments.append(colored(dt[prev_start : span.start].replace("\n", "\\n"), color="black")) + text_fragments.append(colored(dt[span.start : span.end].replace("\n", "\\n"), color="green")) + text_fragments.append(colored(f"{{{span.type}: {span.score}}}", color="red")) + prev_start = span.end + text_fragments.append(colored(dt[prev_start:].replace("\n", "\\n"), color="black")) + + tagger_header = colored(f"tagger: {self.tagger}\n", color="cyan") + print("".join(text_fragments) + "\n" + file_header + example_header + tagger_header + "\n") + while True: + out = input("next? [l/f] ").lower().strip() + if out == "l": + i += 1 + break + elif out == "f": + return + else: + print(f"invalid input: {out}; choose between next (l)ine or next (f)ile.") + + def __call__(self): + try: + source_prefix = f"{self.BASE_S3_PREFIX}/{self.dataset}/documents" + for path in recursively_list_files(source_prefix): + # just list tags if no tagger or type is specified + if self.tagger is None or self.type is None: + self.list_tags(path) + return + + # visualize the specified tagger and type + self.visualize_single(path) + except KeyboardInterrupt: + print("\nExiting... 
bye!") + + @classmethod + def main(cls): + ap = argparse.ArgumentParser() + ap.add_argument( + "-d", + "--dataset", + required=True, + help="Dataset to visualize", + ) + ap.add_argument( + "-e", + "--experiment-name", + required=True, + help="Experiment name to visualize", + ) + ap.add_argument( + "-t", + "--tagger", + type=str, + default=None, + help="Tagger to visualize", + ) + ap.add_argument( + "-y", + "--type", + type=str, + default=None, + help="Type to visualize", + ) + opts = ap.parse_args() + cls(dataset=opts.dataset, experiment=opts.experiment_name, tagger=opts.tagger, type=opts.type)() + + +class RawPreviewer: + BASE_S3_PREFIX = "s3://ai2-llm/pretraining-data/sources" + + def __init__(self, dataset: str, type: str, file: str, pretty: bool = False, experiment: Optional[str] = None): + self.dataset = dataset + self.experiment = experiment + self.type = type + self.file = file + self.pretty = pretty + + assert type == "documents" or experiment is not None, "Must specify experiment for attributes" + + def preview_file(self): + if self.type == "documents": + path = f"{self.BASE_S3_PREFIX}/{self.dataset}/documents/{self.file}" + else: + path = f"{self.BASE_S3_PREFIX}/{self.dataset}/attributes/{self.experiment}/{self.file}" + + with ExitStack() as stack: + file = stack.enter_context(stream_file_for_read(path, "rb")) + stream = stack.enter_context(decompress_stream(file, "rt")) + + list_colors = ["red", "green", "blue", "magenta", "cyan"] + + for line in stream: + row = json.loads(line) + if self.pretty: + out = yaml.dump(row, width=float("inf")) + for ln in out.split("\n"): + if not ln.startswith(" "): + key, *rest = ln.split(":") + rest = (":" + ":".join(rest) if rest else "").strip() + print(colored(key, color=list_colors[0]) + rest) + list_colors = list_colors[1:] + list_colors[:1] + else: + print(ln) + else: + print(row) + input(colored("\n[press enter for next line]", color="yellow")) + + def list_files(self): + prefix = f"{self.BASE_S3_PREFIX}/{self.dataset}/documents" + for path in recursively_list_files(prefix): + print(path[len(prefix) + 1 :]) + + def __call__(self): + try: + if self.file is not None: + self.preview_file() + else: + self.list_files() + except KeyboardInterrupt: + print("\nExiting... bye!") + + @classmethod + def main(cls): + ap = argparse.ArgumentParser() + ap.add_argument("-d", "--dataset", required=True, help="Dataset to preview, e.g. `wikipedia/v0`") + ap.add_argument( + "-t", + "--type", + choices=["documents", "attributes"], + required=True, + help="Type of data to preview; it can be either `documents` or `attributes`.", + ) + ap.add_argument( + "-e", + "--experiment", + default=None, + help="Experiment to preview; this is only used for previewing `attributes` assigned by tagger.", + ) + ap.add_argument( + "-f", + "--file", + default=None, + type=str, + help="File to preview; if not sure which file to preview, skip this argument to list all files.", + ) + ap.add_argument( + "-p", "--pretty", action="store_true", help="Whether to use pretty print for previewing JSON lines." 
+ ) + opts = ap.parse_args() + + cls(dataset=opts.dataset, type=opts.type, file=opts.file, experiment=opts.experiment, pretty=opts.pretty)() diff --git a/python/dolma/data/naughty_words_en.txt b/python/dolma/data/naughty_words_en.txt new file mode 100644 index 00000000..a438b9ca --- /dev/null +++ b/python/dolma/data/naughty_words_en.txt @@ -0,0 +1,403 @@ +2g1c +2 girls 1 cup +acrotomophilia +alabama hot pocket +alaskan pipeline +anal +anilingus +anus +apeshit +arsehole +ass +asshole +assmunch +auto erotic +autoerotic +babeland +baby batter +baby juice +ball gag +ball gravy +ball kicking +ball licking +ball sack +ball sucking +bangbros +bangbus +bareback +barely legal +barenaked +bastard +bastardo +bastinado +bbw +bdsm +beaner +beaners +beaver cleaver +beaver lips +beastiality +bestiality +big black +big breasts +big knockers +big tits +bimbos +birdlock +bitch +bitches +black cock +blonde action +blonde on blonde action +blowjob +blow job +blow your load +blue waffle +blumpkin +bollocks +bondage +boner +boob +boobs +booty call +brown showers +brunette action +bukkake +bulldyke +bullet vibe +bullshit +bung hole +bunghole +busty +butt +buttcheeks +butthole +camel toe +camgirl +camslut +camwhore +carpet muncher +carpetmuncher +chocolate rosebuds +cialis +circlejerk +cleveland steamer +clit +clitoris +clover clamps +clusterfuck +cock +cocks +coprolagnia +coprophilia +cornhole +coon +coons +creampie +cum +cumming +cumshot +cumshots +cunnilingus +cunt +darkie +date rape +daterape +deep throat +deepthroat +dendrophilia +dick +dildo +dingleberry +dingleberries +dirty pillows +dirty sanchez +doggie style +doggiestyle +doggy style +doggystyle +dog style +dolcett +domination +dominatrix +dommes +donkey punch +double dong +double penetration +dp action +dry hump +dvda +eat my ass +ecchi +ejaculation +erotic +erotism +escort +eunuch +fag +faggot +fecal +felch +fellatio +feltch +female squirting +femdom +figging +fingerbang +fingering +fisting +foot fetish +footjob +frotting +fuck +fuck buttons +fuckin +fucking +fucktards +fudge packer +fudgepacker +futanari +gangbang +gang bang +gay sex +genitals +giant cock +girl on +girl on top +girls gone wild +goatcx +goatse +god damn +gokkun +golden shower +goodpoop +goo girl +goregasm +grope +group sex +g-spot +guro +hand job +handjob +hard core +hardcore +hentai +homoerotic +honkey +hooker +horny +hot carl +hot chick +how to kill +how to murder +huge fat +humping +incest +intercourse +jack off +jail bait +jailbait +jelly donut +jerk off +jigaboo +jiggaboo +jiggerboo +jizz +juggs +kike +kinbaku +kinkster +kinky +knobbing +leather restraint +leather straight jacket +lemon party +livesex +lolita +lovemaking +make me come +male squirting +masturbate +masturbating +masturbation +menage a trois +milf +missionary position +mong +motherfucker +mound of venus +mr hands +muff diver +muffdiving +nambla +nawashi +negro +neonazi +nigga +nigger +nig nog +nimphomania +nipple +nipples +nsfw +nsfw images +nude +nudity +nutten +nympho +nymphomania +octopussy +omorashi +one cup two girls +one guy one jar +orgasm +orgy +paedophile +paki +panties +panty +pedobear +pedophile +pegging +penis +phone sex +piece of shit +pikey +pissing +piss pig +pisspig +playboy +pleasure chest +pole smoker +ponyplay +poof +poon +poontang +punany +poop chute +poopchute +porn +porno +pornography +prince albert piercing +pthc +pubes +pussy +queaf +queef +quim +raghead +raging boner +rape +raping +rapist +rectum +reverse cowgirl +rimjob +rimming +rosy palm +rosy palm and her 5 sisters +rusty trombone 
+sadism +santorum +scat +schlong +scissoring +semen +sex +sexcam +sexo +sexy +sexual +sexually +sexuality +shaved beaver +shaved pussy +shemale +shibari +shit +shitblimp +shitty +shota +shrimping +skeet +slanteye +slut +s&m +smut +snatch +snowballing +sodomize +sodomy +spastic +spic +splooge +splooge moose +spooge +spread legs +spunk +strap on +strapon +strappado +strip club +style doggy +suck +sucks +suicide girls +sultry women +swastika +swinger +tainted love +taste my +tea bagging +threesome +throating +thumbzilla +tied up +tight white +tit +tits +titties +titty +tongue in a +topless +tosser +towelhead +tranny +tribadism +tub girl +tubgirl +tushy +twat +twink +twinkie +two girls one cup +undressing +upskirt +urethra play +urophilia +vagina +venus mound +viagra +vibrator +violet wand +vorarephilia +voyeur +voyeurweb +voyuer +vulva +wank +wetback +wet dream +white power +whore +worldsex +wrapping men +wrinkled starfish +xx +xxx +yaoi +yellow showers +yiffy +zoophilia +๐Ÿ–• diff --git a/python/dolma/taggers/__init__.py b/python/dolma/taggers/__init__.py new file mode 100644 index 00000000..506fc1d3 --- /dev/null +++ b/python/dolma/taggers/__init__.py @@ -0,0 +1,8 @@ +from .c4 import * # noqa: F403 +from .code import * # noqa: F403 +from .gopher import * # noqa: F403 +from .jigsaw import * # noqa: F403 +from .language import * # noqa: F403 +from .length import * # noqa: F403 +from .pii import * # noqa: F403 +from .sampling import * # noqa: F403 diff --git a/python/dolma/taggers/c4.py b/python/dolma/taggers/c4.py new file mode 100644 index 00000000..b8f484ab --- /dev/null +++ b/python/dolma/taggers/c4.py @@ -0,0 +1,87 @@ +import logging +import os +from dataclasses import dataclass +from typing import List, Set + +from ..core.data_types import DocResult, Document, Span +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger + +MIN_WORDS_PER_LINE = 3 +naughty_lines = ( + open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "naughty_words_en.txt")) + .read() + .splitlines() +) +NAUGHTY_WORDS: Set[str] = set(w for w in naughty_lines if " " not in w) +NAUGHTY_PHRASES: Set[str] = set(w for w in naughty_lines if " " in w) + + +@dataclass +class C4Attributes: + lines_with_no_ending_punctuation: List[Span] + lines_with_too_few_words: List[Span] + has_naughty_word: bool = False + has_javascript: bool = False + has_lorem_ipsum: bool = False + has_curly_brace: bool = False + line_count: int = 0 + character_count: int = 0 + + def as_spans(self) -> List[Span]: + spans = [] + spans.extend(self.lines_with_no_ending_punctuation) + spans.extend(self.lines_with_too_few_words) + if self.has_naughty_word: + spans.append(Span(0, self.character_count, type="has_naughty_word")) + if self.has_javascript: + spans.append(Span(0, self.character_count, type="has_javascript")) + if self.has_lorem_ipsum: + spans.append(Span(0, self.character_count, type="has_lorem_ipsum")) + if self.has_curly_brace: + spans.append(Span(0, self.character_count, type="has_curly_brace")) + spans.append(Span(0, self.character_count, type="line_count", score=self.line_count)) + return spans + + +def get_attributes(text: str) -> C4Attributes: + attrs = C4Attributes([], []) + attrs.character_count = len(text) + try: + lines = text.split("\n") + attrs.line_count = len(lines) + offset = 0 + for line_no in range(0, len(lines)): + original_line = lines[line_no] + end_offset = offset + len(original_line) + if line_no < len(lines) - 1: + end_offset += 1 + line = original_line.lower().strip() + if not 
line.endswith((".", "?", "!", '"')): + attrs.lines_with_no_ending_punctuation.append( + Span(offset, end_offset, type="lines_with_no_ending_punctuation") + ) + words = line.split() + if len(words) < MIN_WORDS_PER_LINE: + attrs.lines_with_too_few_words.append(Span(offset, end_offset, type="lines_with_too_few_words")) + if any(word in NAUGHTY_WORDS for word in words) or any(phrase in line for phrase in NAUGHTY_PHRASES): + attrs.has_naughty_word = True + if any(word == "javascript" for word in words): + attrs.has_javascript = True + if "lorem ipsum" in line: + attrs.has_lorem_ipsum = True + if "{" in line: + attrs.has_curly_brace = True + offset = end_offset + except Exception: + logging.exception(f"Error parsing text: {text[:200]}") + + return attrs + + +@TaggerRegistry.add("c4_v1") +class C4Tagger(BaseTagger): + def predict(self, doc: Document) -> DocResult: + attrs = get_attributes(doc.text) + result = DocResult(doc=doc, spans=attrs.as_spans()) + return result diff --git a/python/dolma/taggers/code.py b/python/dolma/taggers/code.py new file mode 100644 index 00000000..714ca33a --- /dev/null +++ b/python/dolma/taggers/code.py @@ -0,0 +1,199 @@ +""" + +Code-related taggers. + +@akshitab + +""" +import logging +import re +from typing import Generator, List + +import numpy as np +import regex +from detect_secrets import SecretsCollection +from detect_secrets.core.scan import ( + PotentialSecret, + _process_line_based_plugins, + get_plugins, +) +from detect_secrets.settings import default_settings + +from ..core.data_types import DocResult, Document, Span +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger + +logger = logging.getLogger(__name__) + + +def scan_code(code: str) -> Generator[PotentialSecret, None, None]: + if not get_plugins(): + logger.error("No plugins to scan with!") + return + + has_secret = False + for lines in [code.splitlines()]: + for secret in _process_line_based_plugins( + lines=list(enumerate(lines, start=1)), + filename="code_str.yml", + ): + has_secret = True + yield secret + + if has_secret: + break + + +class SecretsCollectionForStringInput(SecretsCollection): + def scan_str(self, code_str: str): + for secret in scan_code(code_str): + self["code_str.yml"].add(secret) + + +def get_secrets(code: str): + secrets = SecretsCollectionForStringInput() + with default_settings(): + secrets.scan_str(code) + + return secrets + + +@TaggerRegistry.add("code_secrets_v1") +class CodeSecretsTagger(BaseTagger): + @classmethod + def _extract_code_secrets(cls, text: str) -> List[Span]: + secrets_spans: List[Span] = [] + + text_lines = text.splitlines() + secrets = get_secrets(text) + for _, secret in secrets: + line_number = secret.line_number - 1 + span = secret.secret_value + span_line = text_lines[line_number] + line_start = text.find(span_line) + start = line_start + span_line.find(span) + end = start + len(span) + assert text[start:end] == span + secret_type = secret.type.replace(" ", "_") + secrets_spans.append(Span(start=start, end=end, type=f"SECRET_{secret_type}")) # , span]) + + return secrets_spans + + def predict(self, doc: Document) -> DocResult: + """Main runner.""" + spans = self._extract_code_secrets(doc.text) + + # document-level score + score = self._score(text=doc.text, secrets_spans=spans) + spans.append(Span(start=0, end=len(doc.text), type="doc", score=score)) + return DocResult(doc=doc, spans=spans) + + def _score(self, text: str, secrets_spans: List[Span]) -> float: + try: + score = len(secrets_spans) * 1.0 / len(text.split()) + 
except ZeroDivisionError: + score = -1.0 + return score + + +@TaggerRegistry.add("code_copyright_comments_v1") +class CodeCopyrightTagger(BaseTagger): + """ + Based on RedPajama code filtering. + """ + + def __init__(self): + self.cpat = re.compile("copyright", re.IGNORECASE) + self.pat = re.compile("/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/") + + def _extract_copyright_spans(self, text: str) -> List[Span]: + copyright_spans: List[Span] = [] + + reg = self.pat.search(text) + + if reg: + # found one, now see if it contains "copyright", if so strip it + span = reg.span() + sub = text[span[0] : span[1]] + if self.cpat.search(sub): + copyright_spans.append(Span(start=span[0], end=span[1], type="copyright_notice", score=1.0)) + return copyright_spans + + lines = text.split("\n") + skip = 0 + # Greedy replace any file that begins with comment block, most + # are copyright headers + end = 0 + for k in range(len(lines)): + if lines[k].startswith("//") or lines[k].startswith("#") or lines[k].startswith("--") or not lines[k]: + skip = skip + 1 + if not lines[k]: + end += 1 + else: + end += len(lines[k]) + else: + break + + if skip: + copyright_spans.append(Span(start=0, end=end, type="comment_block", score=1.0)) + return copyright_spans + + def predict(self, doc: Document) -> DocResult: + """Main runner.""" + spans = self._extract_copyright_spans(doc.text) + + # document-level score + score = self._score(text=doc.text, copyright_spans=spans) + spans.append(Span(start=0, end=len(doc.text), type="doc", score=score)) + return DocResult(doc=doc, spans=spans) + + def _score(self, text: str, copyright_spans: List[Span]) -> float: + try: + if len(copyright_spans) == 0: + score = 0.0 + else: + span = copyright_spans[0] + # percentage of content affected + score = (span.end - span.start + 1) * 1.0 / len(text) + except ZeroDivisionError: + score = -1.0 + return score + + +@TaggerRegistry.add("code_redpajama_taggers_v1") +class CodeRedPajamaTaggers(BaseTagger): + """ + Based on RedPajama code filtering. 
+ """ + + WHITESPACE_REGEX = regex.compile(r"\w+|[^\w\s]+") + + def _get_num_tokens(self, text: str): + return len(self.WHITESPACE_REGEX.split(text)) + + def predict(self, doc: Document) -> DocResult: + """Main runner.""" + + spans: List[Span] = [] + + doc_length = len(doc.text) + + line_lengths = list(map(len, doc.text.splitlines())) + + max_line_length = max(line_lengths, default=0.0) + avg_line_length = np.mean(line_lengths) if len(line_lengths) > 0 else 0.0 + + alnum_count = sum(map(lambda char: 1 if char.isalnum() else 0, doc.text)) + alnum_prop = (alnum_count / doc_length) if doc_length > 0 else 0.0 + + num_tokens = self._get_num_tokens(doc.text) + num_alpha = len([c for c in doc.text if c.isalpha()]) + alpha_token_prop = (num_alpha / num_tokens) if num_tokens > 0 else 0.0 + + # document-level scores + spans.append(Span(start=0, end=doc_length, type="max_line_length_doc", score=max_line_length)) + spans.append(Span(start=0, end=doc_length, type="avg_line_length_doc", score=avg_line_length)) + spans.append(Span(start=0, end=doc_length, type="alnum_prop_doc", score=alnum_prop)) + spans.append(Span(start=0, end=doc_length, type="alpha_token_prop_doc", score=alpha_token_prop)) + + return DocResult(doc=doc, spans=spans) diff --git a/python/dolma/taggers/gopher.py b/python/dolma/taggers/gopher.py new file mode 100644 index 00000000..223cd878 --- /dev/null +++ b/python/dolma/taggers/gopher.py @@ -0,0 +1,172 @@ +import logging +from collections import Counter +from dataclasses import dataclass +from statistics import median +from typing import Counter as CounterType +from typing import List, Tuple + +from ..core.data_types import DocResult, Document, Span +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger + +REQUIRED_ENGLISH_WORDS = {"the", "be", "to", "of", "and", "that", "have", "with"} +SYMBOLS = {"#", "\u2026"} +BULLET_POINTS = {"*", "-"} + + +@dataclass +class GopherAttributes: + fraction_of_characters_in_most_common_ngram: List[Tuple[int, float]] + fraction_of_characters_in_duplicate_ngrams: List[Tuple[int, float]] + character_count: int = 0 + word_count: int = 0 + median_word_length: float = False + symbol_to_word_ratio: float = 0.0 + fraction_of_words_with_alpha_character: float = 0.0 + required_word_count: int = 0 + fraction_of_lines_starting_with_bullet_point: float = 0.0 + fraction_of_lines_ending_with_ellipsis: float = 0.0 + fraction_of_duplicate_lines: float = 0.0 + fraction_of_characters_in_duplicate_lines: float = 0.0 + + def as_spans(self) -> List[Span]: + spans = [] + spans.extend( + [ + Span(0, self.character_count, f"fraction_of_characters_in_most_common_{n}grams", v) + for n, v in self.fraction_of_characters_in_most_common_ngram + ] + ) + spans.extend( + [ + Span(0, self.character_count, f"fraction_of_characters_in_duplicate_{n}grams", v) + for n, v in self.fraction_of_characters_in_duplicate_ngrams + ] + ) + spans.append(Span(0, self.character_count, type="character_count", score=self.character_count)) + spans.append(Span(0, self.character_count, type="word_count", score=self.word_count)) + spans.append(Span(0, self.character_count, type="median_word_length", score=self.median_word_length)) + spans.append(Span(0, self.character_count, type="symbol_to_word_ratio", score=self.symbol_to_word_ratio)) + spans.append( + Span( + 0, + self.character_count, + type="fraction_of_words_with_alpha_character", + score=self.fraction_of_words_with_alpha_character, + ) + ) + spans.append(Span(0, self.character_count, type="required_word_count", 
score=self.required_word_count)) + spans.append( + Span( + 0, + self.character_count, + type="fraction_of_lines_starting_with_bullet_point", + score=self.fraction_of_lines_starting_with_bullet_point, + ) + ) + spans.append( + Span( + 0, + self.character_count, + type="fraction_of_lines_ending_with_ellipsis", + score=self.fraction_of_lines_ending_with_ellipsis, + ) + ) + spans.append( + Span( + 0, self.character_count, type="fraction_of_duplicate_lines", score=self.fraction_of_duplicate_lines + ) + ) + spans.append( + Span( + 0, + self.character_count, + type="fraction_of_characters_in_duplicate_lines", + score=self.fraction_of_characters_in_duplicate_lines, + ) + ) + return spans + + +def get_attributes(text: str) -> GopherAttributes: + attrs = GopherAttributes([], []) + attrs.character_count = len(text) + if attrs.character_count == 0: + return attrs + + try: + words = text.split() + word_count = len(words) + character_count = sum(len(word) for word in words) + + attrs.word_count = word_count + attrs.median_word_length = median([len(word) for word in words]) + attrs.symbol_to_word_ratio = sum(1 for word in words if any(s in word for s in SYMBOLS)) / word_count + attrs.fraction_of_words_with_alpha_character = ( + sum(1 for word in words if any(c.isalpha() for c in word)) / word_count + ) + attrs.required_word_count = sum(1 for word in words if word in REQUIRED_ENGLISH_WORDS) + + all_counts = all_ngram_counts(words) + + count_most_common_ngrams = {2, 3, 4} + for n, ngram_counts in all_counts: + if not ngram_counts: + continue + if n in count_most_common_ngrams: + most_common_ngram, count = ngram_counts.most_common(1)[0] + value = count * sum(len(w) for w in most_common_ngram) / character_count + attrs.fraction_of_characters_in_most_common_ngram.append((n, value)) + else: + ng_char_count = sum(count * sum(len(w) for w in ng) for ng, count in ngram_counts.items()) + value = ( + sum(count * sum(len(w) for w in ng) for ng, count in ngram_counts.items() if count > 1) + / ng_char_count + ) + attrs.fraction_of_characters_in_duplicate_ngrams.append((n, value)) + + lines = text.split("\n") + line_count = len(lines) + for line in lines: + if any(line.startswith(s) for s in BULLET_POINTS): + attrs.fraction_of_lines_starting_with_bullet_point += 1 + if line.endswith("\u2026"): + attrs.fraction_of_lines_ending_with_ellipsis += 1 + attrs.fraction_of_lines_starting_with_bullet_point /= line_count + attrs.fraction_of_lines_ending_with_ellipsis /= line_count + + line_counts = Counter(lines) + attrs.fraction_of_duplicate_lines = ( + sum(count for line, count in line_counts.items() if count > 1) / line_count + ) + attrs.fraction_of_characters_in_duplicate_lines = ( + sum(len(line) * count for line, count in line_counts.items() if count > 1) / character_count + ) + except Exception as e: + logging.exception(f"Error processing text {e}: {text[:200]}") + + return attrs + + +def all_ngram_counts(words) -> List[Tuple[int, CounterType[Tuple[str, ...]]]]: + return [(n, Counter(list(zip(*[words[i:] for i in range(n)])))) for n in range(2, 11)] + + +def all_ngram_counts_alt(words: List[str]) -> List[Tuple[int, CounterType[Tuple[str, ...]]]]: + """Seems like it should be faster, but isn't""" + ngram: List[Tuple[str, ...]] = list(zip(words, words[1:])) + all_counts = [(2, Counter(ngram))] + + for n in range(3, 11): + ngram = list(a + (b,) for a, b in zip(ngram, words[n - 1 :])) + all_counts.append((n, Counter(ngram))) + + return all_counts + + +@TaggerRegistry.add("gopher_v1") +class GopherTagger(BaseTagger): + def 
predict(self, doc: Document) -> DocResult: + attrs = get_attributes(doc.text) + result = DocResult(doc=doc, spans=attrs.as_spans()) + return result diff --git a/python/dolma/taggers/jigsaw.py b/python/dolma/taggers/jigsaw.py new file mode 100644 index 00000000..037b3c22 --- /dev/null +++ b/python/dolma/taggers/jigsaw.py @@ -0,0 +1,45 @@ +""" + +Filters. + +@kylel, @soldni + +""" + +from typing import Iterable + +from ..core.data_types import TextSlice +from ..core.ft_tagger import BaseFastTextTagger, Prediction +from ..core.registry import TaggerRegistry + + +@TaggerRegistry.add("jigsaw_hatespeech_document_v2") +class FastTextJigsawHatespeechDocumentTagger(BaseFastTextTagger): + MODEL_PATH = "https://ai2-s2-research-public.s3.us-west-2.amazonaws.com/aakankshan/olmo-data-filters/jigsaw_fasttext_bigrams_hatespeech_final.bin" # noqa: E501 + + def __init__(self): + super().__init__(model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER) + + def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: + labels, probs = self.classifier.predict(text_slice.text.replace("\n", " ").strip(), k=-1) + label_index = 1 if "non" in labels[0] else 0 # pyright: ignore + return ( + Prediction(label=labels[label_index], score=probs[label_index]), + Prediction(label=labels[1 - label_index], score=probs[1 - label_index]), + ) + + +@TaggerRegistry.add("jigsaw_hatespeech_sentence_v2") +class FastTextJigsawHatespeechSentenceTagger(FastTextJigsawHatespeechDocumentTagger): + def __init__(self): + BaseFastTextTagger.__init__(self, model_path=self.MODEL_PATH, model_mode=self.SENTENCE_LEVEL_TAGGER) + + +@TaggerRegistry.add("jigsaw_nsfw_document_v1") +class FastTextJigsawNsfwDocumentTagger(FastTextJigsawHatespeechDocumentTagger): + MODEL_PATH = "https://ai2-s2-research-public.s3.us-west-2.amazonaws.com/aakankshan/olmo-data-filters/jigsaw_fasttext_bigrams_nsfw_final.bin" # noqa: E501 + + +@TaggerRegistry.add("jigsaw_nsfw_sencence_v2") +class FastTextJigsawNsfwSentenceTagger(FastTextJigsawHatespeechSentenceTagger): + MODEL_PATH = "https://ai2-s2-research-public.s3.us-west-2.amazonaws.com/aakankshan/olmo-data-filters/jigsaw_fasttext_bigrams_nsfw_final.bin" # noqa: E501 diff --git a/python/dolma/taggers/language.py b/python/dolma/taggers/language.py new file mode 100644 index 00000000..f8e3b7a5 --- /dev/null +++ b/python/dolma/taggers/language.py @@ -0,0 +1,156 @@ +""" + +Filters. 
+ +@kylel, @soldni + +""" +from typing import Iterable, List, Tuple + +import cld3 +import pycld2 as cld2 +import regex +from unidecode import unidecode + +from ..core.data_types import DocResult, Document, Span, TextSlice +from ..core.ft_tagger import BaseFastTextTagger, Prediction +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger +from ..core.utils import split_paragraphs + + +@TaggerRegistry.add("cld3_en_doc_v2") +class Cld3LanguageTagger(BaseTagger): + def _predict_text(self, text: str) -> Tuple[str, float]: + pred = cld3.get_language(text) # pyright: ignore + score = pred.probability if pred.language == "en" else 0.0 + return "en", score + + def predict(self, doc: Document) -> DocResult: + lang, score = self._predict_text(doc.text) + positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) + negative_span = Span(start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score) + return DocResult(doc=doc, spans=[positive_span, negative_span]) + + +@TaggerRegistry.add("cld3_en_paragraph_v2") +class Cld3LanguageTaggerParagraph(Cld3LanguageTagger): + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: List[Span] = [] + for paragraph in paragraphs: + lang, score = self._predict_text(paragraph.text) # pyright: ignore + positive_span = Span(start=paragraph.start, end=paragraph.end, type=lang, score=score) + negative_span = Span(start=paragraph.start, end=paragraph.end, type=f"not_{lang}", score=1.0 - score) + spans.extend((positive_span, negative_span)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("cld2_en_doc_v2") +class Cld2LanguageFilter(BaseTagger): + RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") + + def _sanitize_input(self, text: str) -> str: + return self.RE_BAD_CHARS.sub("", text) + + def _to_ascii_input(self, text: str) -> str: + return unidecode(text) + + def _identity_fn(self, text: str) -> str: + return text + + def _predict_text(self, text: str) -> Tuple[str, float]: + details = [] + is_reliable = False + for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): + try: + is_reliable, _, details = cld2.detect(fn(text)) + break + except cld2.error: + ... 
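+
+        # Descriptive note: pycld2.detect returns (is_reliable, text_bytes_found, details),
+        # where each entry of `details` is a (language_name, language_code, percent, score)
+        # tuple. The loop above retries detection on progressively sanitized copies of the
+        # text (raw, ASCII-transliterated, control/surrogate chars stripped) and stops at
+        # the first variant pycld2 accepts without raising, so d[2] below is the percentage
+        # of the text attributed to English when detection was reliable.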
+ + score = max([d[2] for d in details if d[0] == "ENGLISH" and is_reliable] or [0]) + return "en", score / 100.0 + + def predict(self, doc: Document) -> DocResult: + lang, score = self._predict_text(doc.text) + positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) + negative_span = Span(start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score) + return DocResult(doc=doc, spans=[positive_span, negative_span]) + + +@TaggerRegistry.add("cld2_en_paragraph_v2") +class Cld2LanguageFilterParagraph(Cld2LanguageFilter): + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: List[Span] = [] + for paragraph in paragraphs: + lang, score = self._predict_text(paragraph.text) # pyright: ignore + positive_span = Span(start=paragraph.start, end=paragraph.end, type=lang, score=score) + negative_span = Span(start=paragraph.start, end=paragraph.end, type=f"not_{lang}", score=1.0 - score) + spans.extend((positive_span, negative_span)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("ft_lang_id_en_doc_v2") +class FastTextEnglishLanguageDocumentTagger(BaseFastTextTagger): + MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" + + def __init__(self): + super().__init__(model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER) + + def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: + pred = self.classifier.predict(text_slice.text.lower().replace("\n", " ").strip(), k=-1) + for label, score in zip(*pred): + if label == "__label__en": + return Prediction(label="en", score=score), Prediction(label="not_en", score=1.0 - score) + return Prediction(label="en", score=0.0), Prediction(label="not_en", score=1.0) + + +@TaggerRegistry.add("ft_lang_id_en_paragraph_v2") +class FastTextEnglishLanguageParagraphTagger(FastTextEnglishLanguageDocumentTagger): + def __init__(self): + BaseFastTextTagger.__init__(self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER) + + +def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: + # the total document score is # of characters in each "english" span multiplied by the likelihood + # of said span being english + try: + doc_en_score = sum((s.end - s.start) * s.score for s in result.spans if s.type == "en") / len( + result.doc.text + ) + doc_not_en_score = 1 - doc_en_score + except ZeroDivisionError: + doc_en_score = doc_not_en_score = 0.0 + + doc_level = ( + Span(start=0, end=len(result.doc.text), type="doc_en", score=doc_en_score), + Span(start=0, end=len(result.doc.text), type="doc_not_en", score=doc_not_en_score), + ) + result.spans.extend(doc_level) + return result + + +@TaggerRegistry.add("cld2_en_paragraph_with_doc_score_v2") +class Cld2LanguageFilterParagraphWithDocScoreTagger(Cld2LanguageFilterParagraph): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score(doc_result) + return doc_result + + +@TaggerRegistry.add("cld3_en_paragraph_with_doc_score_v2") +class Cld3LanguageFilterParagraphWithDocScoreTagger(Cld3LanguageTaggerParagraph): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score(doc_result) + return doc_result + + +@TaggerRegistry.add("ft_lang_id_en_paragraph_with_doc_score_v2") +class FastTextEnglishLanguageParagraphWithDocScoreTagger(FastTextEnglishLanguageParagraphTagger): + def predict(self, doc: 
Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score(doc_result) + return doc_result diff --git a/python/dolma/taggers/length.py b/python/dolma/taggers/length.py new file mode 100644 index 00000000..903d8bfe --- /dev/null +++ b/python/dolma/taggers/length.py @@ -0,0 +1,119 @@ +""" + +Filters. + +@kylel, @soldni + +""" + +import regex +import uniseg.wordbreak +from tokenizers import Regex, pre_tokenizers + +from ..core.data_types import DocResult, Document, Span +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger +from ..core.utils import split_paragraphs + + +@TaggerRegistry.add("char_length_v1") +class CharLengthV1(BaseTagger): + def predict(self, doc: Document) -> DocResult: + score = len(doc.text) + return DocResult(doc=doc, spans=[Span(start=0, end=len(doc.text), type="length", score=score)]) + + +@TaggerRegistry.add("char_length_with_paragraphs_v1") +class CharLengthWithParagraphsV1(BaseTagger): + def predict(self, doc: Document) -> DocResult: + spans = [ + Span(start=p.start, end=p.end, type="paragraph", score=len(p.text)) for p in split_paragraphs(doc.text) + ] + spans.append(Span(start=0, end=len(doc.text), type="document", score=len(doc.text))) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("whitespace_tokenizer_v1") +class WhitespaceLengthV1(BaseTagger): + WHITESPACE_REGEX = regex.compile(r"\w+|[^\w\s]+") + + def predict(self, doc: Document) -> DocResult: + score = len(self.WHITESPACE_REGEX.split(doc.text)) + return DocResult(doc=doc, spans=[Span(start=0, end=len(doc.text), type="length", score=score)]) + + +@TaggerRegistry.add("whitespace_tokenizer_with_paragraphs_v1") +class WhitespaceLengthParagraphsV1(WhitespaceLengthV1): + def predict(self, doc: Document) -> DocResult: + spans = [ + Span(start=p.start, end=p.end, type="paragraph", score=len(self.WHITESPACE_REGEX.split(p.text))) + for p in split_paragraphs(doc.text) + ] + spans.append(Span(start=0, end=len(doc.text), type="document", score=sum(s.score for s in spans))) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("uniseg_length_paragraphs_v1") +class UnisegParagraphsV1(BaseTagger): + def predict(self, doc: Document) -> DocResult: + spans = [] + for para in split_paragraphs(doc.text): + # we ignore whitespace-only tokens when counting words + para_length = sum(1 for w in uniseg.wordbreak.words(para.text.strip()) if w.strip()) + spans.append(Span(start=para.start, end=para.end, type="paragraph", score=para_length)) + + # we have to record negative length because mixer can only work on greater than operations, + # and we might want to drop paragraphs that are shorter than a certain length n, so we need + # to filter on >= n + spans.append(Span(start=para.start, end=para.end, type="negative_paragraph", score=-para_length)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("uniseg_length_paragraphs_with_doc_length_v1") +class UnisegParagraphsWithDocLengthV1(UnisegParagraphsV1): + def predict(self, doc: Document) -> DocResult: + doc_results = super().predict(doc) + pos_len = sum(s.score for s in doc_results.spans if s.type == "paragraph") + neg_len = sum(s.score for s in doc_results.spans if s.type == "negative_paragraph") + doc_results.spans.append(Span(start=0, end=len(doc.text), type="document", score=pos_len)) + doc_results.spans.append(Span(start=0, end=len(doc.text), type="negative_document", score=neg_len)) + return doc_results + + +@TaggerRegistry.add("olmo_pretokenizer_v1") 
+class OlmoPreTokenizerV1(BaseTagger): + def __init__(self) -> None: + self.pre_tokenizer = pre_tokenizers.Sequence( # type: ignore + [ + # Split on all punctuation. + pre_tokenizers.Split( + pattern=Regex(" ?[[:punct:]]"), + behavior="isolated", + invert=False, + ), + # Split up digits. + pre_tokenizers.Split( + pattern=Regex(" ?\\d"), + behavior="isolated", + invert=False, + ), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True), + ] + ) + + def predict(self, doc: Document) -> DocResult: + score = len(self.pre_tokenizer.pre_tokenize_str(doc.text)) + return DocResult(doc=doc, spans=[Span(start=0, end=len(doc.text), type="length", score=score)]) + + +@TaggerRegistry.add("olmo_pretokenizer_with_paragraphs_v1") +class OlmoPreTokenizerParagraphsV1(OlmoPreTokenizerV1): + def predict(self, doc: Document) -> DocResult: + spans = [ + Span( + start=p.start, end=p.end, type="paragraph", score=len(self.pre_tokenizer.pre_tokenize_str(p.text)) + ) + for p in split_paragraphs(doc.text) + ] + spans.append(Span(start=0, end=len(doc.text), type="document", score=sum(s.score for s in spans))) + return DocResult(doc=doc, spans=spans) diff --git a/python/dolma/taggers/pii.py b/python/dolma/taggers/pii.py new file mode 100644 index 00000000..a256e093 --- /dev/null +++ b/python/dolma/taggers/pii.py @@ -0,0 +1,291 @@ +""" + +Filters. + +@kylel, @soldni + +""" + +try: + import re2 as re +except ImportError: + import re +else: + re.set_fallback_notification(re.FALLBACK_WARNING) + + +from typing import List +from warnings import warn + +from presidio_analyzer import AnalyzerEngine + +from ..core.data_types import DocResult, Document, Span, TextSlice +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger +from ..core.utils import split_paragraphs + + +class BasePiiFilter(BaseTagger): + EMAIL = "EMAIL_ADDRESS" + PHONE = "PHONE_NUMBER" + IP = "IP_ADDRESS" + + PRESIDIO = "presidio" + REGEX = "regex" + + ENGLISH = "en" + WINDOW = 100 + + def __init__( + self, + method: str, + postprocess: bool, + window: int, + ) -> None: + assert method in [ + self.PRESIDIO, + self.REGEX, + ], f"Please provide a valid method for filtering ({self.PRESIDIO} or {self.REGEX})" + + # configs + self.method = method + self.postprocess = postprocess + self.window = window + + # Regular expressions for different types of PII + self.pii_type_to_regex = { + self.EMAIL: re.compile("[.\\s@,?!;:)(]*([^\\s@]+@[^\\s@,?!;:)(]+?)[.\\s@,?!;:)(]?[\\s\n\r]"), + self.PHONE: re.compile("\\s+\\(?(\\d{3})\\)?[-\\. ]*(\\d{3})[-. 
]?(\\d{4})"), + self.IP: re.compile( + "(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)" + ), + } + self.url_regex = re.compile( + "(?i)\b((?:https?://|www\\d{0,3}[.]|[a-z0-9.\\-]+[.][a-z]{2,4}/)(?:[^\\s()<>]+|\\(([^\\s()<>]+|" + "(\\([^\\s()<>]+\\)))*\\))+(?:\\(([^\\s()<>]+|(\\([^\\s()<>]+\\)))*\\)|[^\\s`!()\\[\\]" + "{};:'\".,<>?ยซยปโ€œโ€โ€˜โ€™]))" + ) + + # presidio + if self.method == self.PRESIDIO: + self.analyzer = AnalyzerEngine() + + def predict(self, doc: Document) -> DocResult: + """Main runner.""" + # extract + if self.method == self.PRESIDIO: + pii_spans = self._extract_pii_presidio(text=doc.text) + elif self.method == self.REGEX: + pii_spans = self._extract_pii_regex(text=doc.text) + else: + raise NotImplementedError + # post process + if self.postprocess: + new_pii_spans = self._postprocess(text=doc.text, pii_spans=pii_spans, window=self.window) + else: + new_pii_spans = pii_spans + + # document-level score + score = self._score(text=doc.text, pii_spans=new_pii_spans) + new_pii_spans.append(Span(start=0, end=len(doc.text), type="doc", score=score)) + return DocResult(doc=doc, spans=new_pii_spans) + + def _score(self, text: str, pii_spans: List[Span]) -> float: + return len(pii_spans) * 1.0 / len(text.split()) + + def _extract_pii_regex(self, text: str) -> List[Span]: + pii_spans: List[Span] = [] + for pii_type, regex in self.pii_type_to_regex.items(): + for match in regex.finditer(text): + start, end = match.span() + pii_spans.append(Span(start=start, end=end, type=pii_type)) + return pii_spans + + def _extract_pii_presidio(self, text: str) -> List[Span]: + analyzer_results = self.analyzer.analyze( + text=text, + entities=[self.EMAIL, self.PHONE, self.IP], + language=self.ENGLISH, + ) + pii_spans: List[Span] = [] + for res in analyzer_results: + pii_spans.append(Span(start=res.start, end=res.end, type=res.entity_type)) + return pii_spans + + def _postprocess(self, text: str, pii_spans: List[Span], window: int) -> List[Span]: + """Applies some rules to remove over-prediction of PII types.""" + new_pii_spans = [] + for pii_span in pii_spans: + if pii_span.type == self.EMAIL: + if self._is_email(text, pii_span): + new_pii_spans.append(pii_span) + else: + pass + + elif pii_span.type == self.PHONE or pii_span.type == self.IP: + context = pii_span.mention(text=text, window=window) + # for both phone numbers & IP addresses, context shouldnt + # contain these strings + if "isbn" in context or "doi" in context or "#" in context: + pass + elif pii_span.type == self.IP: + new_pii_spans.append(pii_span) + elif pii_span.type == self.PHONE: + # for phone numbers, additionally shouldnt be URL + if self._contains_url(text=text): + pass + else: + new_pii_spans.append(pii_span) + + else: + raise NotImplementedError(f"Unsupported PII type for Postprocess: {pii_span.type}") + return new_pii_spans + + def _contains_url(self, text: str) -> bool: + return len(self.url_regex.findall(text)) > 0 + + def _is_email(self, text: str, pii_span: Span) -> bool: + """ + Rules: + (1) The email address besides the domain, cannot be only "(" + (2) There must be a "." in the domain + """ + mention = pii_span.mention(text=text) + addressee = mention.split("@")[0] + domain = mention.split("@")[1] + if addressee.strip() == "(" or "." 
not in domain:
+            return False
+        return True
+
+
+@TaggerRegistry.add("pii_presidio_v1")
+class PiiPresidioV1(BasePiiFilter):
+    def __init__(self):
+        super().__init__(method=self.PRESIDIO, postprocess=True, window=self.WINDOW)
+
+
+@TaggerRegistry.add("pii_regex_v1")
+class PiiRegexV1(BasePiiFilter):
+    def __init__(self):
+        super().__init__(method=self.REGEX, postprocess=True, window=self.WINDOW)
+
+
+@TaggerRegistry.add("pii_regex_v2")
+class PiiRegexV2(PiiRegexV1):
+    def _score(self, text: str, pii_spans: List[Span]) -> float:
+        try:
+            score = len(pii_spans) * 1.0 / len(text.split())
+        except ZeroDivisionError:
+            score = -1.0
+        return score
+
+
+@TaggerRegistry.add("pii_regex_with_counts_fast_v2")
+class FastPiiRegex(BaseTagger):
+    EMAIL_KEY = "EMAIL_ADDRESS"
+    PHONE_KEY = "PHONE_NUMBER"
+    IP_KEY = "IP_ADDRESS"
+
+    EMAIL_REGEX = "[.\\s@,?!;:)(]*([^\\s@]+@[^\\s@,?!;:)(]+?)[.\\s@,?!;:)(]?[\\s\n\r]"
+    PHONE_REGEX = "\\s+\\(?(\\d{3})\\)?[-\\. ]*(\\d{3})[-. ]?(\\d{4})"
+    IP_REGEX = "(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
+    URL_REGEX = "(?i)\b((?:https?://|www\\d{0,3}[.]|[a-z0-9.\\-]+[.][a-z]{2,4}/)(?:[^\\s()<>]+|\\(([^\\s()<>]+|(\\([^\\s()<>]+\\)))*\\))+(?:\\(([^\\s()<>]+|(\\([^\\s()<>]+\\)))*\\)|[^\\s`!()\\[\\]{};:'\".,<>?«»“”‘’]))"  # noqa: E501
+
+    def __init__(
+        self,
+        email_regex: str = EMAIL_REGEX,
+        phone_regex: str = PHONE_REGEX,
+        ip_regex: str = IP_REGEX,
+        url_regex: str = URL_REGEX,
+    ) -> None:
+        self.email_regex = re.compile(email_regex)
+        self.phone_regex = re.compile(phone_regex)
+        self.ip_regex = re.compile(ip_regex)
+        self.url_regex = re.compile(url_regex)
+
+        self.pre_ip_regex = re.compile(r"\.[^\s]")
+        self.pre_phone_regex = re.compile(r"\d")
+
+    def _false_positive_identifiers(self, text: str) -> bool:
+        return "isbn" in text or "doi" in text or "#" in text
+
+    def _predict_email(self, slice: TextSlice) -> List[Span]:
+        if "@" not in slice.text:
+            return []
+
+        spans = []
+        for match in self.email_regex.finditer(slice.text):
+            addressee, domain = match.group(1).split("@", 1)
+            if addressee.strip() == "(" or "."
not in domain: + continue + + start, end = match.span() + spans.append(Span(start=start + slice.start, end=end + slice.start, type=self.EMAIL_KEY)) + + return spans + + def _predict_phone(self, slice: TextSlice) -> List[Span]: + if not self.pre_phone_regex.search(slice.text): + return [] + + spans = [] + for match in self.phone_regex.finditer(slice.text): + start, end = match.span() + spans.append(Span(start=start + slice.start, end=end + slice.start, type=self.PHONE_KEY)) + + return spans + + def _predict_ip(self, slice: TextSlice) -> List[Span]: + if not self.pre_ip_regex.search(slice.text): + return [] + + spans = [] + for match in self.ip_regex.finditer(slice.text): + if self._contains_url(match.group(0)): + continue + start, end = match.span() + spans.append(Span(start=start + slice.start, end=end + slice.start, type=self.IP_KEY)) + + return spans + + def _contains_url(self, text: str) -> bool: + return self.url_regex.search(text) is not None + + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: List[Span] = [] + + if doc.text.count('?') > 10_000: + warn("Skipping regex PII detection for doc with >10k question marks") + paragraphs = [] + + for paragraph in paragraphs: + spans.extend(self._predict_email(paragraph)) + spans.extend(self._predict_phone(paragraph)) + spans.extend(self._predict_ip(paragraph)) + + # doc level score is the count of spans matching any of the PII types + score = sum(1.0 for s in spans if s.type != "doc") + spans.append(Span(start=0, end=len(doc.text), type="doc_count", score=score)) + + try: + # fraction of words that are PII + score = sum(len(s) for s in spans) / len(doc.text) + except ZeroDivisionError: + # empty doc + score = -1.0 + + spans.append(Span(start=0, end=len(doc.text), type="doc_frac", score=score)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("pii_regex_with_counts_v2") +class PiiRegexWithCountV2(BasePiiFilter): + def __init__(self): + super().__init__(method=self.REGEX, postprocess=True, window=self.WINDOW) + + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc=doc) + count = sum(1 for s in doc_result.spans if s.type != "doc") + doc_result.spans.append(Span(start=0, end=len(doc.text), type="doc_count", score=count)) + return doc_result diff --git a/python/dolma/taggers/sampling.py b/python/dolma/taggers/sampling.py new file mode 100644 index 00000000..ea9e13a8 --- /dev/null +++ b/python/dolma/taggers/sampling.py @@ -0,0 +1,20 @@ +import random +from multiprocessing import current_process + +from ..core.data_types import DocResult, Document, Span +from ..core.registry import TaggerRegistry +from ..core.taggers import BaseTagger + + +@TaggerRegistry.add("random_number_v1") +class RandomNumberTagger(BaseTagger): + def __init__(self, seed: int = 1) -> None: + assert seed > 0 + # we multiply the seed by the current process id to ensure that each + # process has a different seed + self.seed = ((current_process().pid or 0) + 1) * seed + random.seed(self.seed) + + def predict(self, doc: Document) -> DocResult: + score = random.random() + return DocResult(doc=doc, spans=[Span(start=0, end=len(doc.text), type="random", score=score)])
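+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, runnable via `python -m dolma.taggers.sampling` assuming the
+    # package layout in this patch. The Document constructor arguments shown here are an
+    # assumption based on how documents are built elsewhere in the codebase; adjust if the
+    # actual signature differs.
+    example = Document(source="example", id="0", text="hello world")
+    result = RandomNumberTagger(seed=42).predict(example)
+    print(result.spans[0].type, result.spans[0].score)  # "random", float in [0.0, 1.0)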