feat: new shuffler (#2404)
Signed-off-by: BubbleCal <[email protected]>
BubbleCal authored May 30, 2024
1 parent 69aa511 commit 32cd740
Showing 2 changed files with 227 additions and 5 deletions.
11 changes: 11 additions & 0 deletions rust/lance-file/src/v2/writer.rs
@@ -197,6 +197,17 @@ impl FileWriter {
Ok(())
}

/// Schedule batches of data to be written to the file
pub async fn write_batches(
&mut self,
batches: impl Iterator<Item = &RecordBatch>,
) -> Result<()> {
for batch in batches {
self.write_batch(batch).await?;
}
Ok(())
}

/// Schedule a batch of data to be written to the file
///
/// Note: the future returned by this method may complete before the data has been fully
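For orientation, a minimal sketch of how the new `write_batches` helper could be driven. The in-memory store, the file name, and the single-column schema are assumptions for illustration, not taken from this commit; the `FileWriter::try_new` call mirrors the one used in the shuffler below.

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_file::v2::writer::FileWriter;
use lance_io::object_store::ObjectStore;
use object_store::path::Path;

async fn write_two_batches() -> lance_core::Result<()> {
    let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "x",
        DataType::Int32,
        false,
    )]));
    let batch1 = RecordBatch::try_new(
        arrow_schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let batch2 = RecordBatch::try_new(
        arrow_schema.clone(),
        vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
    )?;

    // Assumed: an in-memory object store as a stand-in for real storage.
    let object_store = ObjectStore::memory();
    let path = Path::from("example.lance");
    let object_writer = object_store.create(&path).await?;

    let mut writer = FileWriter::try_new(
        object_writer,
        path.to_string(),
        lance_core::datatypes::Schema::try_from(arrow_schema.as_ref())?,
        Default::default(),
    )?;

    // The new helper schedules all batches with a single call.
    writer.write_batches([batch1, batch2].iter()).await?;
    writer.finish().await?;
    Ok(())
}
```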
221 changes: 216 additions & 5 deletions rust/lance-index/src/vector/v3/shuffler.rs
@@ -4,12 +4,27 @@
//! Shuffler is a component that takes a stream of record batches and shuffles them into
//! the corresponding IVF partitions.
-use lance_core::Result;
-use lance_io::stream::RecordBatchStream;
use std::sync::Arc;

use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array};
use futures::future::join_all;
use futures::prelude::*;
use lance_arrow::RecordBatchExt;
use lance_core::{utils::tokio::spawn_cpu, Error, Result};
use lance_file::v2::{reader::FileReader, writer::FileWriter};
use lance_io::{
object_store::ObjectStore,
scheduler::ScanScheduler,
stream::{RecordBatchStream, RecordBatchStreamAdapter},
};
use object_store::path::Path;

use crate::vector::PART_ID_COLUMN;

#[async_trait::async_trait]
/// A reader that can read the shuffled partitions.
-pub trait IvfShuffleReader {
+pub trait ShuffleReader {
/// Read a partition by partition_id.
/// Returns Ok(None) if the partition size is 0;
/// check reader.partiton_size(partition_id) before calling this function.
@@ -25,11 +40,207 @@ pub trait IvfShuffleReader {
#[async_trait::async_trait]
/// A shuffler that can shuffle the incoming stream of record batches into IVF partitions.
/// Returns a ShuffleReader that can be used to read the shuffled partitions.
-pub trait IvfShuffler {
+pub trait Shuffler {
/// Shuffle the incoming stream of record batches into IVF partitions.
/// Returns a ShuffleReader that can be used to read the shuffled partitions.
async fn shuffle(
mut self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
-) -> Result<Box<dyn IvfShuffleReader>>;
+) -> Result<Box<dyn ShuffleReader>>;
}

pub struct IvfShuffler {
object_store: ObjectStore,
output_dir: Path,
num_partitions: usize,

// options
buffer_size: usize,
}

impl IvfShuffler {
pub fn new(object_store: ObjectStore, output_dir: Path, num_partitions: usize) -> Self {
Self {
object_store,
output_dir,
num_partitions,
buffer_size: 4096,
}
}

pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
self.buffer_size = buffer_size;
self
}
}
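A short construction sketch for the builder above; the in-memory store, scratch path, and partition count are illustrative assumptions, with `IvfShuffler` in scope from this module.

```rust
use lance_io::object_store::ObjectStore;
use object_store::path::Path;

fn make_shuffler() -> IvfShuffler {
    // Assumed for illustration: an in-memory store and 256 IVF partitions.
    IvfShuffler::new(ObjectStore::memory(), Path::from("shuffle_tmp"), 256)
        // Flush the per-partition buffers every 8192 input batches
        // instead of the default 4096.
        .with_buffer_size(8192)
}
```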

#[async_trait::async_trait]
impl Shuffler for IvfShuffler {
async fn shuffle(
mut self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<Box<dyn ShuffleReader>> {
let mut writers: Vec<FileWriter> = vec![];
let mut partition_sizes = vec![0; self.num_partitions];
let mut first_pass = true;

let mut counter = 0;

let num_partitions = self.num_partitions;
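// Sort each incoming batch by partition id on a blocking CPU task and
// slice it into per-partition runs; `buffered` keeps up to one sort task
// per core in flight while preserving input order.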
let mut parallel_sort_stream = data
.map(|batch| {
spawn_cpu(move || {
let batch = batch?;

let part_ids: &UInt32Array = batch
.column_by_name(PART_ID_COLUMN)
.expect("Partition ID column not found")
.as_primitive();

let indices = sort_to_indices(&part_ids, None, None)?;
let batch = batch.take(&indices)?;

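// Re-read the partition ids from the sorted batch: `take` above
// reordered the rows.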
let part_ids: &UInt32Array = batch
.column_by_name(PART_ID_COLUMN)
.expect("Partition ID column not found")
.as_primitive();

let mut partition_buffers =
(0..num_partitions).map(|_| Vec::new()).collect::<Vec<_>>();

let mut start = 0;
while start < batch.num_rows() {
let part_id: u32 = part_ids.value(start);
let mut end = start + 1;
while end < batch.num_rows() && part_ids.value(end) == part_id {
end += 1;
}

let part_batches = &mut partition_buffers[part_id as usize];
part_batches.push(batch.slice(start, end - start));
start = end;
}

Ok::<Vec<Vec<RecordBatch>>, Error>(partition_buffers)
})
})
.buffered(num_cpus::get());

// part_id: | 0 | 1 | 3 |
// partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]|
let mut partition_buffers = (0..self.num_partitions)
.map(|_| Vec::new())
.collect::<Vec<_>>();

while let Some(shuffled) = parallel_sort_stream.next().await {
log::info!("shuffle batch: {}", counter);
let shuffled = shuffled?;

for (part_id, batches) in shuffled.into_iter().enumerate() {
let part_batches = &mut partition_buffers[part_id];
part_batches.extend(batches);
}

counter += 1;

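// On the first batch, once the schema is known, lazily create one
// file writer per partition.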
if first_pass {
let schema = partition_buffers
.iter()
.flatten()
.next()
.map(|batch| batch.schema())
.expect("there should be at least one batch");
writers = stream::iter(0..self.num_partitions)
.map(|partition_id| {
let path = self.output_dir.clone();
let object_store = self.object_store.clone();
let schema = schema.clone();
async move {
let part_path = path.child(format!("ivf_{}.lance", partition_id));
let writer = object_store.create(&part_path).await?;
FileWriter::try_new(
writer,
part_path.to_string(),
lance_core::datatypes::Schema::try_from(schema.as_ref())?,
Default::default(),
)
}
})
.buffered(10)
.try_collect::<Vec<_>>()
.await?;

first_pass = false;
}

// flush the buffered partitions to disk every `buffer_size` input batches
if counter % self.buffer_size == 0 {
let mut futs = vec![];
for (part_id, writer) in writers.iter_mut().enumerate() {
let batches = &partition_buffers[part_id];
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
futs.push(writer.write_batches(batches.iter()));
}
join_all(futs)
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;

partition_buffers.iter_mut().for_each(|b| b.clear());
}
}

// final flush
for (part_id, batches) in partition_buffers.into_iter().enumerate() {
let writer = &mut writers[part_id];
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
for batch in batches.iter() {
writer.write_batch(batch).await?;
}
}

// finish all writers
for (writer, &size) in writers.iter_mut().zip(partition_sizes.iter()) {
if size == 0 {
continue;
}
writer.finish().await?;
}

Ok(Box::new(IvfShufflerReader {
object_store: self.object_store.clone(),
output_dir: self.output_dir.clone(),
partition_sizes,
}))
}
}

pub struct IvfShufflerReader {
object_store: ObjectStore,
output_dir: Path,
partition_sizes: Vec<usize>,
}

#[async_trait::async_trait]
impl ShuffleReader for IvfShufflerReader {
async fn read_partition(
&self,
partition_id: usize,
) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
let scheduler = ScanScheduler::new(Arc::new(self.object_store.clone()), 32);
let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));

let reader =
FileReader::try_open(scheduler.open_file(&partition_path).await?, None).await?;
let schema = reader.schema().as_ref().into();

Ok(Some(Box::new(RecordBatchStreamAdapter::new(
Arc::new(schema),
reader.read_stream(lance_io::ReadBatchParams::RangeFull, 4096, 16)?,
))))
}

fn partiton_size(&self, partition_id: usize) -> Result<usize> {
Ok(self.partition_sizes[partition_id])
}
}
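Putting the pieces together, a hedged end-to-end sketch: shuffle a stream whose batches already carry the UInt32 `PART_ID_COLUMN`, then read one partition back. Constructing such a stream is elided; the tagging is assumed to happen in an upstream IVF transform, and `IvfShuffler` is in scope from this module.

```rust
use futures::StreamExt;
use lance_core::Result;
use lance_io::stream::RecordBatchStream;

async fn shuffle_and_scan(
    shuffler: IvfShuffler,
    data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<()> {
    // Consumes the shuffler and the input stream; yields a reader over
    // the partition files written under `output_dir`.
    let reader = shuffler.shuffle(data).await?;

    // Per the trait docs, check the partition size first: writers for
    // empty partitions are never finished on disk.
    if reader.partiton_size(0)? > 0 {
        if let Some(mut stream) = reader.read_partition(0).await? {
            while let Some(batch) = stream.next().await {
                println!("partition 0: {} rows", batch?.num_rows());
            }
        }
    }
    Ok(())
}
```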
