From 35c066a39b07d607a5270d66d274037aee06f103 Mon Sep 17 00:00:00 2001 From: hev Date: Wed, 29 May 2024 04:29:47 +0800 Subject: [PATCH] fix: f32 reduce_min for x86 (#2385) The shift-and-min reduction stopped before only one value was left. For f32x8: ``` // min(4 elements) m1 = [v0, v1, v2, v3, v4, v5, v6, v7] m2 = [v4, v5, v6, v7, xx, xx, xx, xx] min = [v0, v5, v2, v3, xx, xx, xx, xx] // min(2 elements) m1 = [v0, v5, v2, v3, xx, xx, xx, xx] m2 = [v2, v3, xx, xx, xx, xx, xx, xx] min = [v0, v3, xx, xx, xx, xx, xx, xx] // min(1 element) (this step was missing before this patch) m1 = [v0, v3, xx, xx, xx, xx, xx, xx] m2 = [v3, xx, xx, xx, xx, xx, xx, xx] min = [v0, xx, xx, xx, xx, xx, xx, xx] ``` --- rust/lance-linalg/src/simd/f32.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/rust/lance-linalg/src/simd/f32.rs b/rust/lance-linalg/src/simd/f32.rs index f365185fa3..c204e116a9 100644 --- a/rust/lance-linalg/src/simd/f32.rs +++ b/rust/lance-linalg/src/simd/f32.rs @@ -177,9 +177,9 @@ impl SIMD for f32x8 { min = _mm256_min_ps(min, shift); shift = _mm256_permute_ps(min, 14); min = _mm256_min_ps(min, shift); - let mut results: [f32; 8] = [0f32; 8]; - _mm256_storeu_ps(results.as_mut_ptr(), min); - results[0] + shift = _mm256_permute_ps(min, 1); + min = _mm256_min_ps(min, shift); + _mm256_cvtss_f32(min) } } #[cfg(target_arch = "aarch64")] @@ -520,9 +520,9 @@ impl SIMD for f32x16 { m1 = _mm256_min_ps(m1, m2); m2 = _mm256_permute_ps(m1, 14); m1 = _mm256_min_ps(m1, m2); - let mut results: [f32; 8] = [0f32; 8]; - _mm256_storeu_ps(results.as_mut_ptr(), m1); - results[0] + m2 = _mm256_permute_ps(m1, 1); + m1 = _mm256_min_ps(m1, m2); + _mm256_cvtss_f32(m1) } #[cfg(target_arch = "aarch64")] @@ -782,8 +782,10 @@ mod tests { fn test_f32x8_cmp_ops() { let a = [1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0]; let b = [2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0]; + let c = [2.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0]; let simd_a: f32x8 = (&a).into(); let simd_b: f32x8 = (&b).into(); + let 
simd_c: f32x8 = (&c).into(); let min_simd = simd_a.min(&simd_b); assert_eq!( @@ -792,6 +794,8 @@ mod tests { ); let min_val = min_simd.reduce_min(); assert_eq!(min_val, 1.0); + let min_val = simd_c.reduce_min(); + assert_eq!(min_val, 1.0); assert_eq!(Some(2), simd_a.find(5.0)); assert_eq!(Some(1), simd_a.find(2.0)); @@ -836,8 +840,12 @@ mod tests { let b = [ 2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 1.0, ]; + let c = [ + 1.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, -1.0, + ]; let simd_a: f32x16 = (&a).into(); let simd_b: f32x16 = (&b).into(); + let simd_c: f32x16 = (&c).into(); let min_simd = simd_a.min(&simd_b); assert_eq!( @@ -846,6 +854,8 @@ mod tests { ); let min_val = min_simd.reduce_min(); assert_eq!(min_val, -0.5); + let min_val = simd_c.reduce_min(); + assert_eq!(min_val, -1.0); assert_eq!(Some(2), simd_a.find(5.0)); assert_eq!(Some(1), simd_a.find(2.0));