Skip to content

Commit

Permalink
chore: fix cosine residual calculation (#2015)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Mar 3, 2024
1 parent b436b22 commit f1218e3
Show file tree
Hide file tree
Showing 13 changed files with 562 additions and 442 deletions.
2 changes: 1 addition & 1 deletion python/python/lance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class KMeans:
>>> import numpy as np
>>> import lance
>>> data = np.random.randn(1000, 128).astype(np.float32)
>>> kmeans = lance.util.KMeans(8, metric_type="cosine")
>>> kmeans = lance.util.KMeans(8, metric_type="l2")
>>> kmeans.fit(data)
>>> centroids = np.stack(kmeans.centroids.to_numpy(zero_copy_only=False))
>>> clusters = kmeans.predict(data)
Expand Down
24 changes: 0 additions & 24 deletions python/python/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,6 @@
import numpy as np
import pyarrow as pa
import pytest
from numpy.linalg import norm


def test_train_cosine():
kmeans = lance.util.KMeans(32, metric_type="cosine")
data = np.random.randn(1000, 128).astype(np.float32)

assert kmeans.centroids is None
kmeans.fit(data)
assert kmeans.centroids is not None
centroids = kmeans.centroids.to_numpy_ndarray()
assert centroids.shape == (32, 128)

# test predict
pred = kmeans.predict(data)

# compute predict using numpy brute-force
expected = []
for row in data:
# Cosine distance
dist = 1 - np.dot(centroids, row) / (norm(centroids, axis=1) * norm(row))
cluster_id = np.argmin(dist)
expected.append(cluster_id)
assert np.allclose(pred, expected)


def test_invalid_inputs():
Expand Down
47 changes: 22 additions & 25 deletions rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,20 +253,17 @@ impl<T: ArrowFloatType + Dot + L2 + ArrowPrimitiveType> IvfImpl<T> {
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> = vec![];

// Re-enable it after search path fixed.
// if metric_type == MetricType::Cosine {
// transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
// vector_column,
// )));
// };
let mt = if metric_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
metric_type
};

// TODO: add range filter
let ivf_transform = Arc::new(IvfTransformer::new(
centroids.clone(),
metric_type,
vector_column,
));
transforms.push(ivf_transform.clone() as Arc<dyn Transformer>);
let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), mt, vector_column));
transforms.push(ivf_transform.clone());

if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
Expand Down Expand Up @@ -349,16 +346,15 @@ impl<T: ArrowFloatType + Dot + L2 + ArrowPrimitiveType> Ivf for IvfImpl<T> {
self.compute_partitions(original).await?
};
let dim = original.value_length() as usize;
let mut residual_arr: Vec<<T as ArrowFloatType>::Native> =
Vec::with_capacity(original.values().len());
flatten_arr
let residual_arr = flatten_arr
.as_slice()
.chunks_exact(dim)
.zip(part_ids.values())
.for_each(|(vector, &part_id)| {
.flat_map(|(vector, &part_id)| {
let centroid = self.centroids.row(part_id as usize).unwrap();
residual_arr.extend(vector.iter().zip(centroid.iter()).map(|(&v, &c)| v - c));
});
vector.iter().zip(centroid.iter()).map(|(&v, &c)| v - c)
})
.collect::<Vec<_>>();
let arr = T::ArrayType::from(residual_arr);
Ok(FixedSizeListArray::try_new_from_values(arr, dim as i32)?)
}
Expand All @@ -375,12 +371,13 @@ impl<T: ArrowFloatType + Dot + L2 + ArrowPrimitiveType> Ivf for IvfImpl<T> {
),
location: Default::default(),
})?;
// todo: hold kmeans in this struct.
let kmeans = KMeans::<T>::with_centroids(
self.centroids.data().clone(),
self.dimension(),
self.metric_type,
);
let mt = if self.metric_type == MetricType::Cosine {
MetricType::L2
} else {
self.metric_type
};
let kmeans =
KMeans::<T>::with_centroids(self.centroids.data().clone(), self.dimension(), mt);
Ok(kmeans.find_partitions(query.as_slice(), nprobes)?)
}
}
Expand Down
17 changes: 17 additions & 0 deletions rust/lance-index/src/vector/ivf/transformer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! IVF Transformer
//!
//! It transforms a column of vectors into a column of IVF
4 changes: 2 additions & 2 deletions rust/lance-index/src/vector/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ use std::sync::Arc;

use lance_core::{Error, Result};
use lance_linalg::{
distance::{Cosine, Dot, MetricType, L2},
distance::{Dot, MetricType, L2},
kmeans::{KMeans, KMeansParams},
};

/// Train KMeans model and returns the centroids of each cluster.
#[allow(clippy::too_many_arguments)]
pub async fn train_kmeans<T: ArrowFloatType + Dot + L2 + Cosine>(
pub async fn train_kmeans<T: ArrowFloatType + Dot + L2>(
array: &T::ArrayType,
centroids: Option<Arc<T::ArrayType>>,
dimension: usize,
Expand Down
84 changes: 16 additions & 68 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ use async_trait::async_trait;
use lance_arrow::floats::FloatArray;
use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{
cosine_distance_batch, dot_distance_batch, l2_distance_batch, Cosine, Dot, L2,
};
use lance_linalg::kernels::{argmin, argmin_value_float, normalize};
use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2};
use lance_linalg::kernels::{argmin, argmin_value_float};
use lance_linalg::{distance::MetricType, MatrixView};
use snafu::{location, Location};
pub mod builder;
Expand Down Expand Up @@ -82,7 +80,7 @@ pub trait ProductQuantizer: Send + Sync + std::fmt::Debug {
//
// TODO: move this to be pub(crate) once we have a better way to test it.
#[derive(Debug)]
pub struct ProductQuantizerImpl<T: ArrowFloatType + Cosine + Dot + L2> {
pub struct ProductQuantizerImpl<T: ArrowFloatType + Dot + L2> {
/// Number of bits for the centroids.
///
/// Only support 8, as one of `u8` byte now.
Expand Down Expand Up @@ -117,7 +115,7 @@ pub struct ProductQuantizerImpl<T: ArrowFloatType + Cosine + Dot + L2> {
pub codebook: Arc<T::ArrayType>,
}

impl<T: ArrowFloatType + Cosine + Dot + L2> ProductQuantizerImpl<T> {
impl<T: ArrowFloatType + Dot + L2> ProductQuantizerImpl<T> {
/// Create a [`ProductQuantizer`] with pre-trained codebook.
pub fn new(
m: usize,
Expand All @@ -126,6 +124,11 @@ impl<T: ArrowFloatType + Cosine + Dot + L2> ProductQuantizerImpl<T> {
codebook: Arc<T::ArrayType>,
metric_type: MetricType,
) -> Self {
assert_ne!(
metric_type,
MetricType::Cosine,
"Product quantization does not support cosine, use normalized L2 instead"
);
assert_eq!(nbits, 8, "nbits can only be 8");
Self {
num_bits: nbits,
Expand Down Expand Up @@ -201,12 +204,12 @@ impl<T: ArrowFloatType + Cosine + Dot + L2> ProductQuantizerImpl<T> {
lance_linalg::distance::DistanceType::L2 => {
l2_distance_batch(sub_vec, centroids, sub_vector_width)
}
lance_linalg::distance::DistanceType::Cosine => {
cosine_distance_batch(sub_vec, centroids, sub_vector_width)
}
lance_linalg::distance::DistanceType::Dot => {
dot_distance_batch(sub_vec, centroids, sub_vector_width)
}
lance_linalg::distance::DistanceType::Cosine => {
panic!("There should not be cosine for PQ");
}
};
argmin_value_float(distances).map(|(_, v)| v).unwrap_or(0.0)
})
Expand Down Expand Up @@ -321,7 +324,7 @@ impl<T: ArrowFloatType + Cosine + Dot + L2> ProductQuantizerImpl<T> {
}

#[async_trait]
impl<T: ArrowFloatType + Cosine + Dot + L2 + 'static> ProductQuantizer for ProductQuantizerImpl<T> {
impl<T: ArrowFloatType + Dot + L2 + 'static> ProductQuantizer for ProductQuantizerImpl<T> {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -331,36 +334,13 @@ impl<T: ArrowFloatType + Cosine + Dot + L2 + 'static> ProductQuantizer for Produ
.as_fixed_size_list_opt()
.ok_or(Error::Index {
message: format!(
"Expect to be a float vector array, got: {:?}",
"Expect to be a FixedSizeList<float> vector array, got: {:?} array",
data.data_type()
),
location: location!(),
})?
.clone();

let fsl = if self.metric_type == MetricType::Cosine {
// Normalize cosine vectors to unit length.
let values = fsl
.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Expect to be a float vector array, got: {:?}",
fsl.value_type()
),
location: location!(),
})?
.as_slice()
.chunks(self.dimension)
.flat_map(normalize)
.collect::<Vec<_>>();
let data = T::ArrayType::from(values);
FixedSizeListArray::try_new_from_values(data, self.dimension as i32)?
} else {
fsl
};

let num_sub_vectors = self.num_sub_vectors;
let dim = self.dimension;
let num_rows = fsl.len();
Expand Down Expand Up @@ -435,20 +415,10 @@ impl<T: ArrowFloatType + Cosine + Dot + L2 + 'static> ProductQuantizer for Produ
match self.metric_type {
MetricType::L2 => self.l2_distances(query, code),
MetricType::Cosine => {
let query: &T::ArrayType = query.as_any().downcast_ref().ok_or(Error::Index {
message: format!(
"Build cosine distance table, type mismatch: {}",
query.data_type()
),
location: Default::default(),
})?;

// Normalized query vector.
let query = T::ArrayType::from(normalize(query.as_slice()).collect::<Vec<_>>());
// L2 over normalized vectors: ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy)
// Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy
// Therefore, Cosine = L2 / 2
let l2_dists = self.l2_distances(&query, code)?;
let l2_dists = self.l2_distances(query, code)?;
Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect())
}
MetricType::Dot => self.dot_distances(query, code),
Expand Down Expand Up @@ -506,7 +476,7 @@ mod tests {
use approx::assert_relative_eq;
use arrow_array::{
types::{Float16Type, Float32Type},
Float16Array, Float32Array,
Float16Array,
};
use half::f16;
use lance_testing::datagen::generate_random_array;
Expand Down Expand Up @@ -535,28 +505,6 @@ mod tests {
assert_eq!(tensor.shape, vec![256, 16]);
}

#[tokio::test]
async fn test_empty_dist_iter() {
let pq = ProductQuantizerImpl::<Float32Type> {
num_bits: 8,
num_sub_vectors: 4,
dimension: 16,
codebook: Arc::new(Float32Array::from_iter_values(
(0..256 * 16).map(|v| v as f32),
)),
metric_type: MetricType::Cosine,
};

let data = Float32Array::from_iter_values(repeat(0.0).take(16));
let data = FixedSizeListArray::try_new_from_values(data, 16).unwrap();
let rst = pq.transform(&data).await;
assert!(rst.is_err());
assert!(rst
.unwrap_err()
.to_string()
.contains("it is likely that distance is NaN"));
}

#[tokio::test]
async fn test_l2_distance() {
const DIM: usize = 512;
Expand Down
Loading

0 comments on commit f1218e3

Please sign in to comment.