memchr/vector.rs

/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low-level vector operations needed.
/// In general, it was invented mostly to be generic over x86's `__m128i` and
/// `__m256i` types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types.
///
/// # Safety
///
/// All methods are `unsafe` since they are intended to be implemented using
/// vendor intrinsics, which are also `unsafe`. Callers must ensure that the
/// appropriate target features are enabled in the calling function, and that
/// the current CPU supports them. All implementations should avoid marking
/// the routines with `#[target_feature]` and instead mark them as
/// `#[inline(always)]` to ensure they get appropriately inlined.
/// (`inline(always)` cannot be used with `target_feature`.)
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bits in the vector.
    const BITS: usize;
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}
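
// As an illustration of how these operations compose, here is a sketch
// (hypothetical; not part of this module's API) of the core of a vectorized
// byte search over a single chunk: splat the needle into every lane, compare,
// collapse the comparison to a mask and pull out the first matching offset.
// Alignment handling and the scalar tail are elided.
//
//     unsafe fn find_in_chunk_sketch<V: Vector>(
//         needle: u8,
//         chunk: *const u8, // must point at >= V::BYTES readable bytes
//     ) -> Option<usize> {
//         let vn = V::splat(needle);
//         let candidate = V::load_unaligned(chunk);
//         // Lanes equal to `needle` become 0xFF, all others 0x00.
//         let eqs = candidate.cmpeq(vn);
//         let mask = eqs.movemask();
//         if mask.has_non_zero() {
//             Some(mask.first_offset())
//         } else {
//             None
//         }
//     }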

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would require additional costs in the hot path
/// for `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}
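
// A companion sketch (again hypothetical, not part of this module) of how a
// searcher can combine these operations to visit every matching lane a mask
// reports: repeatedly take the first offset, then clear that bit and continue.
//
//     fn for_each_match_sketch<M: MoveMask>(
//         mut mask: M,
//         mut f: impl FnMut(usize),
//     ) {
//         while mask.has_non_zero() {
//             f(mask.first_offset());
//             mask = mask.clear_least_significant_bit();
//         }
//     }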

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}
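
// Worked example (illustrative only): for a vector whose lanes 1 and 4
// matched, the mask is 0b10010. Then:
//
//     let m = SensibleMoveMask(0b10010);
//     assert_eq!(m.first_offset(), 1); // 1 trailing zero
//     assert_eq!(m.last_offset(), 4);  // 32 - 27 leading zeros - 1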

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BITS: usize = 256;
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            // The lanes of `self` are expected to be all-zeros or all-ones
            // (e.g. the result of `cmpeq`). Reinterpreting them as 16-bit
            // lanes and doing a "shift right and narrow" by 4 packs 4 bits
            // per original 8-bit lane into a single 64-bit scalar.
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            // Keep one bit out of every group of 4, so that bit `4*i + 3` is
            // set precisely when lane `i` is 0xFF.
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// an adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            NeonMoveMask(!(((1 << n) << 2) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            //   0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            //   10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
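
    // Worked example (illustrative only): matches in lanes 1 and 4 produce,
    // after the `0x8888...` masking in `movemask`, a 64-bit value with bits
    // 7 and 19 set (one bit per 4-bit group). Then:
    //
    //     let m = NeonMoveMask((1 << 7) | (1 << 19));
    //     assert_eq!(m.first_offset(), 1); // 7 trailing zeros >> 2
    //     assert_eq!(m.last_offset(), 4);  // 16 - (44 leading zeros >> 2) - 1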
}

#[cfg(target_arch = "wasm32")]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
}