Skip to content

Commit

Permalink
chore: remove async transform and use rayon instead (#2351)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored May 18, 2024
1 parent ca56bb8 commit 6845edf
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 36 deletions.
5 changes: 2 additions & 3 deletions rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,14 @@ where
}
}

#[async_trait]
impl<T: ArrowFloatType> Transformer for IvfImpl<T>
where
T::Native: Dot + L2,
{
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let mut batch = batch.clone();
for transform in self.transforms.as_slice() {
batch = transform.transform(&batch).await?;
batch = transform.transform(&batch)?;
}
Ok(batch)
}
Expand Down
4 changes: 2 additions & 2 deletions rust/lance-index/src/vector/ivf/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ pub async fn shuffle_dataset(
}

// Filter out NaNs/Infs
batch = nan_filter.transform(&batch).await?;
batch = nan_filter.transform(&batch)?;

ivf.transform(&batch).await
ivf.transform(&batch)
})
})
.buffer_unordered(num_cpus::get())
Expand Down
6 changes: 2 additions & 4 deletions rust/lance-index/src/vector/ivf/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ where
}
}

#[async_trait::async_trait]
impl<T: ArrowFloatType + ArrowPrimitiveType> Transformer for IvfTransformer<T>
where
<T as ArrowFloatType>::Native: Dot + L2 + Normalize,
{
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
if batch.column_by_name(&self.output_column).is_some() {
// If the partition ID column is already present, we don't need to compute it again.
return Ok(batch.clone());
Expand Down Expand Up @@ -146,9 +145,8 @@ impl PartitionFilter {
}
}

#[async_trait::async_trait]
impl Transformer for PartitionFilter {
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
// TODO: use datafusion execute?
let arr = batch
.column_by_name(&self.column)
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-index/src/vector/pq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl ProductQuantizationStorage {
let num_sub_vectors = quantizer.num_sub_vectors;
let metric_type = quantizer.metric_type;
let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
let batch = transform.transform(batch).await?;
let batch = transform.transform(batch)?;

Self::new(
codebook,
Expand Down
6 changes: 2 additions & 4 deletions rust/lance-index/src/vector/pq/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::sync::Arc;

use arrow_array::{cast::AsArray, Array, RecordBatch};
use arrow_schema::Field;
use async_trait::async_trait;
use lance_arrow::RecordBatchExt;
use lance_core::{Error, Result};
use snafu::{location, Location};
Expand Down Expand Up @@ -47,9 +46,8 @@ impl Debug for PQTransformer {
}
}

#[async_trait]
impl Transformer for PQTransformer {
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let input_arr = batch
.column_by_name(&self.input_column)
.ok_or(Error::Index {
Expand Down Expand Up @@ -109,7 +107,7 @@ mod tests {
.unwrap();

let transformer = PQTransformer::new(pq, "vec", "pq_code");
let batch = transformer.transform(&batch).await.unwrap();
let batch = transformer.transform(&batch).unwrap();
assert!(batch.column_by_name("vec").is_none());
assert!(batch.column_by_name("pq_code").is_some());
assert!(batch.column_by_name("other").is_some());
Expand Down
4 changes: 1 addition & 3 deletions rust/lance-index/src/vector/residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use arrow_array::types::UInt32Type;
use arrow_array::{cast::AsArray, Array, FixedSizeListArray, RecordBatch};
use arrow_schema::Field;
use async_trait::async_trait;
use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt};
use lance_core::{Error, Result};
use lance_linalg::MatrixView;
Expand Down Expand Up @@ -46,12 +45,11 @@ impl<T: ArrowFloatType> ResidualTransform<T> {
}
}

#[async_trait]
impl<T: ArrowFloatType> Transformer for ResidualTransform<T> {
/// Replace the original vector in the [`RecordBatch`] to residual vectors.
///
/// The new [`RecordBatch`] will have a new column named [`RESIDUAL_COLUMN`].
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let part_ids = batch.column_by_name(&self.part_col).ok_or(Error::Index {
message: format!(
"Compute residual vector: partition id column not found: {}",
Expand Down
29 changes: 12 additions & 17 deletions rust/lance-index/src/vector/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::sync::Arc;
use arrow_array::types::{Float16Type, Float32Type, Float64Type};
use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array};
use arrow_schema::{DataType, Field};
use async_trait::async_trait;
use lance_arrow::RecordBatchExt;
use num_traits::Float;
use snafu::{location, Location};
Expand All @@ -21,11 +20,10 @@ use lance_linalg::kernels::normalize_fsl;
/// Transform of a Vector Matrix.
///
///
#[async_trait]
pub trait Transformer: Debug + Sync + Send {
pub trait Transformer: Debug + Send + Sync {
/// Transform a [`RecordBatch`] of vectors
///
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
}

/// Normalize Transformer
Expand Down Expand Up @@ -55,9 +53,8 @@ impl NormalizeTransformer {
}
}

#[async_trait]
impl Transformer for NormalizeTransformer {
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let arr = batch
.column_by_name(&self.input_column)
.ok_or(Error::Index {
Expand All @@ -78,9 +75,9 @@ impl Transformer for NormalizeTransformer {
let norm = normalize_fsl(data)?;
if let Some(output_column) = &self.output_column {
let field = Field::new(output_column, norm.data_type().clone(), true);
return Ok(batch.try_with_column(field, Arc::new(norm))?);
Ok(batch.try_with_column(field, Arc::new(norm))?)
} else {
return Ok(batch.replace_column_by_name(&self.input_column, Arc::new(norm))?);
Ok(batch.replace_column_by_name(&self.input_column, Arc::new(norm))?)
}
}
}
Expand Down Expand Up @@ -109,9 +106,8 @@ where
.any(|&v| !v.is_finite())
}

#[async_trait]
impl Transformer for KeepFiniteVectors {
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let arr = batch.column_by_name(&self.column).ok_or(Error::Index {
message: format!(
"KeepFiniteVectors: column {} not found in RecordBatch",
Expand Down Expand Up @@ -169,9 +165,8 @@ impl DropColumn {
}
}

#[async_trait]
impl Transformer for DropColumn {
async fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
Ok(batch.drop_column(&self.column)?)
}
}
Expand All @@ -197,7 +192,7 @@ mod tests {
)]);
let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
let transformer = NormalizeTransformer::new("v");
let output = transformer.transform(&batch).await.unwrap();
let output = transformer.transform(&batch).unwrap();
let actual = output.column_by_name("v").unwrap();
let act_fsl = actual.as_fixed_size_list();
assert_eq!(act_fsl.len(), 2);
Expand All @@ -223,7 +218,7 @@ mod tests {
)]);
let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
let transformer = NormalizeTransformer::new("v");
let output = transformer.transform(&batch).await.unwrap();
let output = transformer.transform(&batch).unwrap();
let actual = output.column_by_name("v").unwrap();
let act_fsl = actual.as_fixed_size_list();
assert_eq!(act_fsl.len(), 2);
Expand Down Expand Up @@ -257,7 +252,7 @@ mod tests {
)]);
let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl.clone())]).unwrap();
let transformer = NormalizeTransformer::new_with_output("v", "o");
let output = transformer.transform(&batch).await.unwrap();
let output = transformer.transform(&batch).unwrap();
let input = output.column_by_name("v").unwrap();
assert_eq!(input.as_ref(), &fsl);
let actual = output.column_by_name("o").unwrap();
Expand Down Expand Up @@ -289,10 +284,10 @@ mod tests {
let batch =
RecordBatch::try_new(schema.into(), vec![Arc::new(i32_array), Arc::new(fsl)]).unwrap();
let transformer = DropColumn::new("v");
let output = transformer.transform(&batch).await.unwrap();
let output = transformer.transform(&batch).unwrap();
assert!(output.column_by_name("v").is_none());

let dup_drop_result = transformer.transform(&output).await;
let dup_drop_result = transformer.transform(&output);
assert!(dup_drop_result.is_ok());
}
}
4 changes: 2 additions & 2 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl IVFIndex {
/// Internal API with no stability guarantees.
///
/// Assumes the query vector is normalized if the metric type is cosine.
pub async fn find_partitions(&self, query: &Query) -> Result<UInt32Array> {
pub fn find_partitions(&self, query: &Query) -> Result<UInt32Array> {
let mt = if self.metric_type == MetricType::Cosine {
MetricType::L2
} else {
Expand Down Expand Up @@ -680,7 +680,7 @@ impl VectorIndex for IVFIndex {
query.key = key;
};

let partition_ids = self.find_partitions(&query).await?;
let partition_ids = self.find_partitions(&query)?;
assert!(partition_ids.len() <= query.nprobes);
let part_ids = partition_ids.values().to_vec();
let batches = stream::iter(part_ids)
Expand Down

0 comments on commit 6845edf

Please sign in to comment.