diff --git a/lib/line-index/src/lib.rs b/lib/line-index/src/lib.rs
index 58f266d67f62..1614504f80a3 100644
--- a/lib/line-index/src/lib.rs
+++ b/lib/line-index/src/lib.rs
@@ -227,6 +227,22 @@ fn analyze_source_file_dispatch(
     }
 }
 
+#[cfg(target_arch = "aarch64")]
+fn analyze_source_file_dispatch(
+    src: &str,
+    lines: &mut Vec<TextSize>,
+    multi_byte_chars: &mut IntMap<u32, Vec<WideChar>>,
+) {
+    if std::arch::is_aarch64_feature_detected!("neon") {
+        // SAFETY: NEON support was checked
+        unsafe {
+            analyze_source_file_neon(src, lines, multi_byte_chars);
+        }
+    } else {
+        analyze_source_file_generic(src, src.len(), TextSize::from(0), lines, multi_byte_chars);
+    }
+}
+
 /// Checks 16 byte chunks of text at a time. If the chunk contains
 /// something other than printable ASCII characters and newlines, the
 /// function falls back to the generic implementation. Otherwise it uses
@@ -322,7 +338,102 @@ unsafe fn analyze_source_file_sse2(
     }
 }
 
-#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+#[target_feature(enable = "neon")]
+#[cfg(any(target_arch = "aarch64"))]
+#[inline]
+// See https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
+//
+// The mask is a 64-bit integer, where each 4-bit group corresponds to a u8 in
+// the input vector. The least significant 4 bits correspond to the first byte
+// in the vector.
+unsafe fn move_mask(v: std::arch::aarch64::uint8x16_t) -> u64 {
+    use std::arch::aarch64::*;
+
+    let nibble_mask = vshrn_n_u16(vreinterpretq_u16_u8(v), 4);
+    vget_lane_u64(vreinterpret_u64_u8(nibble_mask), 0)
+}
+
+#[target_feature(enable = "neon")]
+#[cfg(any(target_arch = "aarch64"))]
+unsafe fn analyze_source_file_neon(
+    src: &str,
+    lines: &mut Vec<TextSize>,
+    multi_byte_chars: &mut IntMap<u32, Vec<WideChar>>,
+) {
+    use std::arch::aarch64::*;
+
+    const CHUNK_SIZE: usize = 16;
+
+    let src_bytes = src.as_bytes();
+
+    let chunk_count = src.len() / CHUNK_SIZE;
+
+    let newline = vdupq_n_s8(b'\n' as i8);
+
+    // This variable keeps track of where we should start decoding a
+    // chunk. If a multi-byte character spans across chunk boundaries,
+    // we need to skip that part in the next chunk because we already
+    // handled it.
+    let mut intra_chunk_offset = 0;
+
+    for chunk_index in 0..chunk_count {
+        let ptr = src_bytes.as_ptr() as *const i8;
+        let chunk = vld1q_s8(ptr.add(chunk_index * CHUNK_SIZE));
+
+        // For each character in the chunk, see if its byte value is < 0, which
+        // indicates that it's part of a UTF-8 char.
+        let multibyte_test = vcltzq_s8(chunk);
+        // Create a bit mask from the comparison results.
+        let multibyte_mask = move_mask(multibyte_test);
+
+        // If the bit mask is all zero, we only have ASCII chars here:
+        if multibyte_mask == 0 {
+            assert!(intra_chunk_offset == 0);
+
+            // Check for newlines in the chunk
+            let newlines_test = vceqq_s8(chunk, newline);
+            let mut newlines_mask = move_mask(newlines_test);
+
+            // If the bit mask is not all zero, there are newlines in this chunk.
+            if newlines_mask != 0 {
+                let output_offset = TextSize::from((chunk_index * CHUNK_SIZE + 1) as u32);
+
+                while newlines_mask != 0 {
+                    let trailing_zeros = newlines_mask.trailing_zeros();
+                    let index = trailing_zeros / 4;
+
+                    lines.push(TextSize::from(index) + output_offset);
+
+                    // Clear the current 4-bit group, so we can find the next one.
+                    newlines_mask &= (!0xF) << trailing_zeros;
+                }
+            }
+            continue;
+        }
+
+        let scan_start = chunk_index * CHUNK_SIZE + intra_chunk_offset;
+        intra_chunk_offset = analyze_source_file_generic(
+            &src[scan_start..],
+            CHUNK_SIZE - intra_chunk_offset,
+            TextSize::from(scan_start as u32),
+            lines,
+            multi_byte_chars,
+        );
+    }
+
+    let tail_start = chunk_count * CHUNK_SIZE + intra_chunk_offset;
+    if tail_start < src.len() {
+        analyze_source_file_generic(
+            &src[tail_start..],
+            src.len() - tail_start,
+            TextSize::from(tail_start as u32),
+            lines,
+            multi_byte_chars,
+        );
+    }
+}
+
+#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
 // The target (or compiler version) does not support SSE2 ...
 fn analyze_source_file_dispatch(
     src: &str,
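
Not part of the patch, for reviewers: a minimal standalone sketch of the vshrn-based nibble-mask trick that move_mask relies on, under the assumption that only the std::arch::aarch64 intrinsics used above are available. The helper name newline_nibble_mask and the sample input are made up for illustration.

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn newline_nibble_mask(bytes: &[u8; 16]) -> u64 {
    use std::arch::aarch64::*;

    let chunk = vld1q_u8(bytes.as_ptr());
    // 0xFF in every lane that equals b'\n', 0x00 elsewhere.
    let eq_newline = vceqq_u8(chunk, vdupq_n_u8(b'\n'));
    // Shift-right-narrow by 4: each 0x00/0xFF comparison byte collapses into
    // one nibble of a 64-bit mask (least significant nibble = byte 0).
    let nibbles = vshrn_n_u16(vreinterpretq_u16_u8(eq_newline), 4);
    vget_lane_u64(vreinterpret_u64_u8(nibbles), 0)
}

#[cfg(target_arch = "aarch64")]
fn main() {
    assert!(std::arch::is_aarch64_feature_detected!("neon"));

    // Newlines at byte indices 3 and 9.
    let bytes = *b"abc\ndefgh\nijklmn";
    // SAFETY: NEON support was checked above.
    let mut mask = unsafe { newline_nibble_mask(&bytes) };

    // trailing_zeros / 4 recovers the byte index of each match, mirroring the
    // newline loop in analyze_source_file_neon.
    assert_eq!(mask.trailing_zeros() / 4, 3);
    mask &= (!0xF) << mask.trailing_zeros(); // clear the current 4-bit group
    assert_eq!(mask.trailing_zeros() / 4, 9);
}

#[cfg(not(target_arch = "aarch64"))]
fn main() {}

Each matching byte sets one 4-bit group in the mask, which is why the newline loop divides trailing_zeros by 4 and clears a whole nibble per iteration.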