simba/simd/
simd_value.rs

1use crate::simd::SimdBool;
2
3/// Base trait for every SIMD types.
4pub trait SimdValue: Sized {
5    /// The number of lanes of this SIMD value.
6    const LANES: usize;
7    /// The type of the elements of each lane of this SIMD value.
8    type Element: SimdValue<Element = Self::Element, SimdBool = bool>;
9    /// Type of the result of comparing two SIMD values like `self`.
10    type SimdBool: SimdBool;
11
12    /// Initializes an SIMD value with each lanes set to `val`.
13    fn splat(val: Self::Element) -> Self;
14    /// Extracts the i-th lane of `self`.
15    ///
16    /// Panics if `i >= Self::LANES`.
17    fn extract(&self, i: usize) -> Self::Element;
18    /// Extracts the i-th lane of `self` without bound-checking.
19    ///
20    /// # Safety
21    /// Undefined behavior if `i >= Self::LANES`.
22    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element;
23    /// Replaces the i-th lane of `self` by `val`.
24    ///
25    /// Panics if `i >= Self::LANES`.
26    fn replace(&mut self, i: usize, val: Self::Element);
27    /// Replaces the i-th lane of `self` by `val` without bound-checking.
28    ///
29    /// # Safety
30    /// Undefined behavior if `i >= Self::LANES`.
31    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element);
32
33    /// Merges `self` and `other` depending on the lanes of `cond`.
34    ///
35    /// For each lane of `cond` with bits set to 1, the result's will contain the value of the lane of `self`.
36    /// For each lane of `cond` with bits set to 0, the result's will contain the value of the lane of `other`.
37    fn select(self, cond: Self::SimdBool, other: Self) -> Self;
38
39    /// Applies a function to each lane of `self`.
40    ///
41    /// Note that, while convenient, this method can be extremely slow as this
42    /// requires to extract each lane of `self` and then combine them again into
43    /// a new SIMD value.
44    #[inline(always)]
45    fn map_lanes(self, f: impl Fn(Self::Element) -> Self::Element) -> Self
46    where
47        Self: Clone,
48    {
49        let mut result = self.clone();
50
51        for i in 0..Self::LANES {
52            unsafe { result.replace_unchecked(i, f(self.extract_unchecked(i))) }
53        }
54
55        result
56    }
57
58    /// Applies a function to each lane of `self` paired with the corresponding lane of `b`.
59    ///
60    /// Note that, while convenient, this method can be extremely slow as this
61    /// requires to extract each lane of `self` and then combine them again into
62    /// a new SIMD value.
63    #[inline(always)]
64    fn zip_map_lanes(
65        self,
66        b: Self,
67        f: impl Fn(Self::Element, Self::Element) -> Self::Element,
68    ) -> Self
69    where
70        Self: Clone,
71    {
72        let mut result = self.clone();
73
74        for i in 0..Self::LANES {
75            unsafe {
76                let a = self.extract_unchecked(i);
77                let b = b.extract_unchecked(i);
78                result.replace_unchecked(i, f(a, b))
79            }
80        }
81
82        result
83    }
84}
85
86/// Marker trait implemented by SIMD and non-SIMD primitive numeric values.
87///
88/// This trait is useful for some disambiguations when writing blanked impls.
89/// This is implemented by all unsigned integer, integer, float, and complex types, as
90/// with only one lane, i.e., `f32`, `f64`, `u32`, `i64`, etc. as well as SIMD types like
91/// `f32x4, i32x8`, etc..
92pub trait PrimitiveSimdValue: Copy + SimdValue {}
93
94impl<N: SimdValue> SimdValue for num_complex::Complex<N> {
95    const LANES: usize = N::LANES;
96    type Element = num_complex::Complex<N::Element>;
97    type SimdBool = N::SimdBool;
98
99    #[inline(always)]
100    fn splat(val: Self::Element) -> Self {
101        num_complex::Complex {
102            re: N::splat(val.re),
103            im: N::splat(val.im),
104        }
105    }
106
107    #[inline(always)]
108    fn extract(&self, i: usize) -> Self::Element {
109        num_complex::Complex {
110            re: self.re.extract(i),
111            im: self.im.extract(i),
112        }
113    }
114
115    #[inline(always)]
116    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
117        num_complex::Complex {
118            re: self.re.extract_unchecked(i),
119            im: self.im.extract_unchecked(i),
120        }
121    }
122
123    #[inline(always)]
124    fn replace(&mut self, i: usize, val: Self::Element) {
125        self.re.replace(i, val.re);
126        self.im.replace(i, val.im);
127    }
128
129    #[inline(always)]
130    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
131        self.re.replace_unchecked(i, val.re);
132        self.im.replace_unchecked(i, val.im);
133    }
134
135    #[inline(always)]
136    fn select(self, cond: Self::SimdBool, other: Self) -> Self {
137        num_complex::Complex {
138            re: self.re.select(cond, other.re),
139            im: self.im.select(cond, other.im),
140        }
141    }
142}
143
144impl<N: PrimitiveSimdValue> PrimitiveSimdValue for num_complex::Complex<N> {}
145
// Implements `PrimitiveSimdValue` and `SimdValue` for each listed scalar type, treating
// the scalar as a SIMD value with exactly one lane.
//
// Since there is only one lane, the lane index passed to `extract`, `replace` and their
// unchecked variants is ignored by these implementations.
macro_rules! impl_primitive_simd_value_for_scalar {
    ($($scalar: ty),*) => {
        $(
            impl PrimitiveSimdValue for $scalar {}

            impl SimdValue for $scalar {
                const LANES: usize = 1;
                type Element = $scalar;
                type SimdBool = bool;

                // The scalar is its own single lane.
                #[inline(always)]
                fn splat(val: Self::Element) -> Self {
                    val
                }

                #[inline(always)]
                fn extract(&self, _lane: usize) -> Self::Element {
                    *self
                }

                #[inline(always)]
                unsafe fn extract_unchecked(&self, _lane: usize) -> Self::Element {
                    *self
                }

                #[inline(always)]
                fn replace(&mut self, _lane: usize, val: Self::Element) {
                    *self = val
                }

                #[inline(always)]
                unsafe fn replace_unchecked(&mut self, _lane: usize, val: Self::Element) {
                    *self = val
                }

                // With a plain `bool` condition, selection is an ordinary branch.
                #[inline(always)]
                fn select(self, cond: Self::SimdBool, other: Self) -> Self {
                    if cond { self } else { other }
                }
            }
        )*
    };
}
190
191impl_primitive_simd_value_for_scalar!(
192    bool, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64
193);
194#[cfg(feature = "decimal")]
195impl_primitive_simd_value_for_scalar!(decimal::d128);