Skip to content

Commit

Permalink
fix: f32 reduce_min for x86 (#2385)
Browse files Browse the repository at this point in the history
Shift did not reach only 1 value left.

For f32x8:

```
// min(4 elements)
m1  = [v0, v1, v2, v3, v4, v5, v6, v7]
m2  = [v4, v5, v6, v7, xx, xx, xx, xx]
min = [v0, v5, v2, v3, xx, xx, xx, xx]

// min(2 elements)
m1  = [v0, v5, v2, v3, xx, xx, xx, xx]
m2  = [v2, v3, xx, xx, xx, xx, xx, xx]
min = [v0, v3, xx, xx, xx, xx, xx, xx]

// min(1 element) (This step is missing now)
m1  = [v0, v3, xx, xx, xx, xx, xx, xx]
m2  = [v3, xx, xx, xx, xx, xx, xx, xx]
min = [v0, xx, xx, xx, xx, xx, xx, xx]
```
  • Loading branch information
heiher authored May 28, 2024
1 parent da2b295 commit 35c066a
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions rust/lance-linalg/src/simd/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ impl SIMD<f32, 8> for f32x8 {
min = _mm256_min_ps(min, shift);
shift = _mm256_permute_ps(min, 14);
min = _mm256_min_ps(min, shift);
let mut results: [f32; 8] = [0f32; 8];
_mm256_storeu_ps(results.as_mut_ptr(), min);
results[0]
shift = _mm256_permute_ps(min, 1);
min = _mm256_min_ps(min, shift);
_mm256_cvtss_f32(min)
}
}
#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -520,9 +520,9 @@ impl SIMD<f32, 16> for f32x16 {
m1 = _mm256_min_ps(m1, m2);
m2 = _mm256_permute_ps(m1, 14);
m1 = _mm256_min_ps(m1, m2);
let mut results: [f32; 8] = [0f32; 8];
_mm256_storeu_ps(results.as_mut_ptr(), m1);
results[0]
m2 = _mm256_permute_ps(m1, 1);
m1 = _mm256_min_ps(m1, m2);
_mm256_cvtss_f32(m1)
}

#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -782,8 +782,10 @@ mod tests {
fn test_f32x8_cmp_ops() {
let a = [1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0];
let b = [2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0];
let c = [2.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0];
let simd_a: f32x8 = (&a).into();
let simd_b: f32x8 = (&b).into();
let simd_c: f32x8 = (&c).into();

let min_simd = simd_a.min(&simd_b);
assert_eq!(
Expand All @@ -792,6 +794,8 @@ mod tests {
);
let min_val = min_simd.reduce_min();
assert_eq!(min_val, 1.0);
let min_val = simd_c.reduce_min();
assert_eq!(min_val, 1.0);

assert_eq!(Some(2), simd_a.find(5.0));
assert_eq!(Some(1), simd_a.find(2.0));
Expand Down Expand Up @@ -836,8 +840,12 @@ mod tests {
let b = [
2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 1.0,
];
let c = [
1.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, -1.0,
];
let simd_a: f32x16 = (&a).into();
let simd_b: f32x16 = (&b).into();
let simd_c: f32x16 = (&c).into();

let min_simd = simd_a.min(&simd_b);
assert_eq!(
Expand All @@ -846,6 +854,8 @@ mod tests {
);
let min_val = min_simd.reduce_min();
assert_eq!(min_val, -0.5);
let min_val = simd_c.reduce_min();
assert_eq!(min_val, -1.0);

assert_eq!(Some(2), simd_a.find(5.0));
assert_eq!(Some(1), simd_a.find(2.0));
Expand Down

0 comments on commit 35c066a

Please sign in to comment.