Adds strict filter mechanism to Mixer (#231)
* WIP: adds unshuffled/unmerged mixing operations

* Example config with shuffle false

* Adds comment for sanity

* Try new gh upload-artifact with overwrite

* More ci fixes

* Should panic if no suffix
undfined authored Feb 13, 2025
1 parent d42fa1d commit 625ac1c
Showing 7 changed files with 113 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Generated lockfile; diff not rendered.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "dolma"
-version = "1.1.0"
+version = "1.1.1"
 edition = "2021"
 license = "Apache-2.0"

24 changes: 24 additions & 0 deletions configs/test/test_filtered_mixer.yaml
@@ -0,0 +1,24 @@
+streams:
+  - name: filtered_object_test
+    documents:
+      - s3://ai2-oe-data/pretraining-data/sources/refine/v0/documents/0001/0000_dclm_shard_00000*.jsonl.zstd
+    attributes:
+      - random_number_v1
+      - fineweb-edu-classifier
+      - fineweb-edu-classifier-original
+    compression:
+      input: zst
+      output: zst
+    output:
+      path: s3://ai2-oe-data/tylerm/test/mixer/filtered
+    filter:
+      include:
+        - '((.attributes."HuggingFaceFW_fineweb-edu-classifier_score" != null and .attributes."HuggingFaceFW_fineweb-edu-classifier_score_original" != null) and ((${oc.env:ALPHA} * (.attributes.random_number_v1__random_number_v1__random[0][-1] * 2 - 1)) + (.attributes."HuggingFaceFW_fineweb-edu-classifier_score"[0][-1]) - .attributes."HuggingFaceFW_fineweb-edu-classifier_score_original"[0][-1]) > 0.30)'
+      syntax: jq
+
+work_dir:
+  input: "/tmp/dolma/input"
+  output: "/tmp/dolma/output"
+
+shuffle: false
+processes: 10
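The include rule keeps a document only when both fineweb-edu classifier scores are present and a randomly perturbed score delta clears a 0.30 threshold; ALPHA is injected from the environment via `${oc.env:ALPHA}`, and the `[0][-1]` indexing selects the score value from each attribute span. A minimal Rust sketch of the arithmetic the jq expression performs, with hypothetical values (not part of the dolma codebase):

```rust
// Sketch of the jq include rule: keep a document when
//     alpha * (2r - 1) + score - score_original > 0.30
// where r in [0, 1) comes from the random_number_v1 attribute.
fn keep_document(alpha: f64, r: f64, score: f64, score_original: f64) -> bool {
    (alpha * (r * 2.0 - 1.0)) + score - score_original > 0.30
}

fn main() {
    // With ALPHA = 0.1 the random term shifts the delta by at most ±0.1, so a
    // score improvement of 0.5 clears the threshold even in the worst case...
    assert!(keep_document(0.1, 0.0, 3.2, 2.7));
    // ...while a document whose score did not improve is always dropped.
    assert!(!keep_document(0.1, 1.0, 2.0, 2.0));
}
```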
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dolma"
-version = "1.1.0"
+version = "1.1.1"
 description = "Data filters"
 license = { text = "Apache-2.0" }
 readme = "README.md"
2 changes: 2 additions & 0 deletions python/dolma/cli/mixer.py
@@ -73,6 +73,7 @@ class MixerConfig:
     streams: List[StreamConfig] = field(default=[], help="List configurations of streams to be mixed")
     work_dir: WorkDirConfig = field(default=WorkDirConfig(), help="Configuration for temporary work directories.")
     processes: int = field(default=1, help="Number of processes to use for mixing. By default 1 process is used.")
+    shuffle: bool = field(default=True, help="Whether to shard and shuffle the documents during mixing.")
     dryrun: bool = field(
         default=False,
         help="If true, only print the configuration and exit without running the mixer.",
@@ -92,6 +93,7 @@ def run(cls, parsed_config: MixerConfig):
"work_dir": {"input": str(work_dirs.input), "output": str(work_dirs.output)},
"processes": int(parsed_config.processes),
"streams": [],
"shuffle": bool(parsed_config.shuffle),
}

for stream_config in parsed_config.streams:
15 changes: 13 additions & 2 deletions src/mixer.rs
@@ -9,10 +9,14 @@ use crate::shard::Shard;
 use mixer_config::*;

 pub fn run(config: MixerConfig) -> Result<u32, u32> {
-    let shards = Shard::split_streams(&config.streams).unwrap();
-
+    let shards = if config.shuffle {
+        Shard::split_streams(&config.streams).unwrap()
+    } else {
+        Shard::split_streams_unshuffled(&config.streams).unwrap()
+    };
     let threadpool = ThreadPool::new(config.processes);
     let failed_shard_count_ref = Arc::new(AtomicU32::new(0));

     for shard in shards {
         let output_path = Path::new(&config.work_dir.output.clone()).join(&shard.output);
         if output_path.exists() {
@@ -50,11 +54,18 @@ pub mod mixer_config {

     use crate::shard::shard_config::{StreamConfig, WorkDirConfig};

+    fn shuffle_default() -> bool {
+        true
+    }
+
     #[derive(Serialize, Deserialize, Clone)]
     pub struct MixerConfig {
         pub streams: Vec<StreamConfig>,
         pub processes: usize,
         pub work_dir: WorkDirConfig,
+        // Includes default for backwards compatibility
+        #[serde(default = "shuffle_default")]
+        pub shuffle: bool,
     }

     impl MixerConfig {
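Because `shuffle` is a new field, the `#[serde(default = "shuffle_default")]` attribute above keeps pre-1.1.1 configs parsing unchanged. A self-contained sketch of that behavior using a simplified stand-in struct (not the real MixerConfig, which also carries streams and work_dir):

```rust
use serde::Deserialize;

fn shuffle_default() -> bool {
    true
}

// Simplified stand-in for MixerConfig, just enough to show the default.
#[derive(Deserialize, Debug)]
struct Sketch {
    processes: usize,
    #[serde(default = "shuffle_default")]
    shuffle: bool,
}

fn main() {
    // An older config with no `shuffle` key still parses; the field defaults to true.
    let old: Sketch = serde_yaml::from_str("processes: 10").unwrap();
    assert!(old.shuffle);

    // Newer configs can opt out of sharding/shuffling explicitly.
    let new: Sketch = serde_yaml::from_str("processes: 10\nshuffle: false").unwrap();
    assert!(!new.shuffle);
}
```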
84 changes: 70 additions & 14 deletions src/shard.rs
@@ -63,21 +63,11 @@ impl Shard {
             })
             .collect::<Vec<(DocumentPaths, usize)>>();
         let mut shard_size = inputs_with_sizes[0].1;
+        // Start with the first input and add it to the vector
         let mut shard_inputs: Vec<DocumentPaths> = vec![inputs_with_sizes[0].0.clone()];
-        let output_ext = match stream_config
-            .compression
-            .clone()
-            .unwrap_or(CompressionConfig::infer())
-            .output
-        {
-            // empty string means no compression
-            Some(ext) if ext.is_empty() => "".to_string(),
-            // if there is an extension, add a dot
-            Some(ext) => format!(".{}", ext),
-            // default to .gz
-            None => ".gz".to_string(),
-        };
+        let output_ext = Shard::get_output_extension(stream_config.clone());

+        // We slice from the second position since we already added the first input above
         for (input, size) in inputs_with_sizes[1..].iter() {
             if *size == 0 {
                 log::warn!(
@@ -89,7 +79,7 @@
             shard_size += size;
             if shard_size > stream_config.output.max_size_in_bytes {
                 let output = format!(
-                    "{}/{}-{:04}.json{}",
+                    "{}/{}-{:04}.jsonl{}",
                     stream_config.output.path,
                     stream_config.name,
                     stream_shard_count,
@@ -139,6 +129,56 @@
         Ok(shards)
     }

+    pub fn split_streams_unshuffled(streams: &Vec<StreamConfig>) -> Result<Vec<Shard>, IoError> {
+        // Partitions the input files of a stream into a vector of shards each consisting of a single object
+        // and maintaining the original file structure and naming below */documents/. Useful for "filter"
+        // only operations where the resulting dataset is a strict subset of the original and is intended
+        // to be unshuffled and unsharded.
+        let mut shards: Vec<Shard> = Vec::new();
+        for stream_config in streams {
+            let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
+            let input_count = stream_inputs.len();
+            let inputs = stream_inputs.into_iter().map(|input| {
+                let mut attr_paths = Vec::new();
+                for prefix in stream_config.attributes.iter() {
+                    let attr_prefix = format!("/attributes/{}/", prefix);
+                    let attr_path = input.replace("/documents/", &attr_prefix);
+                    attr_paths.push(attr_path);
+                }
+                DocumentPaths {
+                    doc_path: input,
+                    attribute_paths: attr_paths,
+                }
+            });
+
+            for input in inputs {
+                let doc_path_clone = input.doc_path.clone();
+                let output_suffix = doc_path_clone.split("/documents/").last().unwrap();
+                let output = format!(
+                    "{}/documents/{}",
+                    stream_config.output.path.clone(),
+                    output_suffix
+                );
+                log::info!("Creating shard for {}", output);
+                let shard: Shard = Shard {
+                    inputs: vec![input.clone()],
+                    output,
+                    filter: stream_config.filter.clone(),
+                    span_replacements: stream_config.span_replacement.clone(),
+                    discard_fields: stream_config.output.discard_fields.clone(),
+                    min_text_length: stream_config.output.min_text_length.clone(),
+                    compression: stream_config.compression.clone(),
+                };
+                shards.push(shard);
+            }
+            log::info!(
+                "Created {} shards of file count 1 for {}",
+                input_count,
+                stream_config.name,
+            );
+        }
+        Ok(shards)
+    }
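The attribute lookup in `split_streams_unshuffled` relies on dolma's path convention: each attribute set mirrors the document tree, with `/documents/` swapped for `/attributes/<name>/`. A standalone sketch of that mapping, with a hypothetical bucket and file name:

```rust
// Sketch of the path convention used above: attribute files sit at the same
// relative path as their documents, under /attributes/<name>/ instead.
fn attribute_path(doc_path: &str, attr_name: &str) -> String {
    doc_path.replace("/documents/", &format!("/attributes/{}/", attr_name))
}

fn main() {
    let doc = "s3://example-bucket/corpus/v0/documents/0001/shard_00000.jsonl.zst";
    assert_eq!(
        attribute_path(doc, "random_number_v1"),
        "s3://example-bucket/corpus/v0/attributes/random_number_v1/0001/shard_00000.jsonl.zst"
    );
}
```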

     // Process a shard:
     // Read all input files sequentially,
     // Merge attributes
@@ -465,6 +505,22 @@ impl Shard {
         cache.finalize_output(&self.output)?;
         Ok(())
     }
+
+    fn get_output_extension(stream_config: StreamConfig) -> String {
+        match stream_config
+            .compression
+            .clone()
+            .unwrap_or(CompressionConfig::infer())
+            .output
+        {
+            // empty string means no compression
+            Some(ext) if ext.is_empty() => "".to_string(),
+            // if there is an extension, add a dot
+            Some(ext) => format!(".{}", ext),
+            // default to .gz
+            None => ".gz".to_string(),
+        }
+    }
 }
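Taken together with the `.json` → `.jsonl` fix earlier in this file, this helper determines the full shard name produced by `split_streams` in shuffled mode: `"{path}/{name}-{:04}.jsonl{ext}"`, where the extension is `.gz` when compression is unset, empty when the configured extension is the empty string, and `.{ext}` otherwise. A quick sketch with hypothetical values:

```rust
// Sketch (not dolma code): how the output extension combines with the shard
// naming pattern used by split_streams.
fn shard_name(path: &str, name: &str, n: usize, ext: &str) -> String {
    format!("{}/{}-{:04}.jsonl{}", path, name, n, ext)
}

fn main() {
    // compression.output = "zst" => ".zst"; "" => no suffix; unset => ".gz".
    assert_eq!(
        shard_name("s3://example-bucket/out", "mystream", 7, ".zst"),
        "s3://example-bucket/out/mystream-0007.jsonl.zst"
    );
    assert_eq!(
        shard_name("s3://example-bucket/out", "mystream", 7, ".gz"),
        "s3://example-bucket/out/mystream-0007.jsonl.gz"
    );
}
```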

pub mod shard_config {
