nalgebra/base/
array_storage.rs

1// Needed otherwise the rkyv macros generate code incompatible with rust-2024
2#![cfg_attr(feature = "rkyv-serialize", allow(unsafe_op_in_unsafe_fn))]
3
4use std::fmt::{self, Debug, Formatter};
5// use std::hash::{Hash, Hasher};
6use std::ops::Mul;
7
8#[cfg(feature = "serde-serialize-no-std")]
9use serde::de::{Error, SeqAccess, Visitor};
10#[cfg(feature = "serde-serialize-no-std")]
11use serde::ser::SerializeTuple;
12#[cfg(feature = "serde-serialize-no-std")]
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14#[cfg(feature = "serde-serialize-no-std")]
15use std::marker::PhantomData;
16
17#[cfg(feature = "rkyv-serialize")]
18use rkyv::bytecheck;
19
20use crate::Storage;
21use crate::base::Scalar;
22use crate::base::allocator::Allocator;
23use crate::base::default_allocator::DefaultAllocator;
24use crate::base::dimension::{Const, ToTypenum};
25use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage};
26use std::mem;
27
28/*
29 *
30 * Static RawStorage.
31 *
32 */
/// An array-based statically sized matrix data storage.
///
/// The wrapped `[[T; R]; C]` holds the `C` columns of an `R × C` matrix, each
/// column being an array of `R` entries. The entry at row `i`, column `j` is
/// therefore `self.0[j][i]` (row stride 1, column stride `R`, as reported by the
/// `RawStorage` impl). `#[repr(transparent)]` gives this type exactly the memory
/// layout of the inner array.
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "rkyv-serialize-no-std",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    archive(
        as = "ArrayStorage<T::Archived, R, C>",
        bound(archive = "
        T: rkyv::Archive,
        [[T; R]; C]: rkyv::Archive<Archived = [[T::Archived; R]; C]>
    ")
    )
)]
// Byte validation is only derived with the `rkyv-serialize` feature, which is also
// the feature that brings `rkyv::bytecheck` into scope at the top of this file.
#[cfg_attr(feature = "rkyv-serialize", derive(bytecheck::CheckBytes))]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ArrayStorage<T, const R: usize, const C: usize>(pub [[T; R]; C]);
50
51impl<T, const R: usize, const C: usize> ArrayStorage<T, R, C> {
52    /// Converts this array storage to a slice.
53    #[inline]
54    pub fn as_slice(&self) -> &[T] {
55        // SAFETY: this is OK because ArrayStorage is contiguous.
56        unsafe { self.as_slice_unchecked() }
57    }
58
59    /// Converts this array storage to a mutable slice.
60    #[inline]
61    pub fn as_mut_slice(&mut self) -> &mut [T] {
62        // SAFETY: this is OK because ArrayStorage is contiguous.
63        unsafe { self.as_mut_slice_unchecked() }
64    }
65}
66
67// TODO: remove this once the stdlib implements Default for arrays.
68impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
69where
70    [[T; R]; C]: Default,
71{
72    #[inline]
73    fn default() -> Self {
74        Self(Default::default())
75    }
76}
77
78impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
79    #[inline]
80    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
81        self.0.fmt(fmt)
82    }
83}
84
85unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
86    for ArrayStorage<T, R, C>
87{
88    type RStride = Const<1>;
89    type CStride = Const<R>;
90
91    #[inline]
92    fn ptr(&self) -> *const T {
93        self.0.as_ptr() as *const T
94    }
95
96    #[inline]
97    fn shape(&self) -> (Const<R>, Const<C>) {
98        (Const, Const)
99    }
100
101    #[inline]
102    fn strides(&self) -> (Self::RStride, Self::CStride) {
103        (Const, Const)
104    }
105
106    #[inline]
107    fn is_contiguous(&self) -> bool {
108        true
109    }
110
111    #[inline]
112    unsafe fn as_slice_unchecked(&self) -> &[T] {
113        unsafe { std::slice::from_raw_parts(self.ptr(), R * C) }
114    }
115}
116
117unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
118    for ArrayStorage<T, R, C>
119where
120    DefaultAllocator: Allocator<Const<R>, Const<C>, Buffer<T> = Self>,
121{
122    #[inline]
123    fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
124    where
125        DefaultAllocator: Allocator<Const<R>, Const<C>>,
126    {
127        self
128    }
129
130    #[inline]
131    fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
132    where
133        DefaultAllocator: Allocator<Const<R>, Const<C>>,
134    {
135        self.clone()
136    }
137
138    #[inline]
139    fn forget_elements(self) {
140        // No additional cleanup required.
141        std::mem::forget(self);
142    }
143}
144
145unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
146    for ArrayStorage<T, R, C>
147{
148    #[inline]
149    fn ptr_mut(&mut self) -> *mut T {
150        self.0.as_mut_ptr() as *mut T
151    }
152
153    #[inline]
154    unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
155        unsafe { std::slice::from_raw_parts_mut(self.ptr_mut(), R * C) }
156    }
157}
158
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over a plain `[[T; R]; C]`, whose
// `R * C` entries are stored back to back, so the storage is always contiguous.
unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}
160
161impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
162    ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
163where
164    T: Scalar,
165    Const<R1>: ToTypenum,
166    Const<C1>: ToTypenum,
167    Const<R2>: ToTypenum,
168    Const<C2>: ToTypenum,
169    <Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
170    <Const<R2> as ToTypenum>::Typenum: Mul<
171            <Const<C2> as ToTypenum>::Typenum,
172            Output = typenum::Prod<
173                <Const<R1> as ToTypenum>::Typenum,
174                <Const<C1> as ToTypenum>::Typenum,
175            >,
176        >,
177{
178    type Output = ArrayStorage<T, R2, C2>;
179
180    fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
181        unsafe {
182            let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
183            mem::forget(self.0);
184            ArrayStorage(data)
185        }
186    }
187}
188
189/*
190 *
191 * Serialization.
192 *
193 */
194// XXX: open an issue for serde so that it allows the serialization/deserialization of all arrays?
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Writes the storage as a fixed-size tuple of `R * C` entries, emitted in
    /// the same order as [`ArrayStorage::as_slice`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut tuple = serializer.serialize_tuple(R * C)?;
        self.as_slice()
            .iter()
            .try_for_each(|entry| tuple.serialize_element(entry))?;
        tuple.end()
    }
}
213
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Reads a tuple of `R * C` entries, in storage order, back into an array storage.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let visitor = ArrayStorageVisitor::<T, R, C>::new();
        deserializer.deserialize_tuple(R * C, visitor)
    }
}
226
#[cfg(feature = "serde-serialize-no-std")]
/// A serde visitor that rebuilds an `ArrayStorage<T, R, C>` from a sequence of
/// exactly `R * C` elements, filling the storage slots in order.
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    // Only ties the element type `T` to the visitor; no `T` is stored.
    marker: PhantomData<T>,
}
232
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C>
where
    T: Scalar,
{
    /// Creates a visitor for an `R × C` array storage of `T`.
    pub fn new() -> Self {
        Self { marker: PhantomData }
    }
}
245
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Fills an uninitialized `R × C` storage from the sequence, in storage order.
    ///
    /// # Errors
    /// Returns `invalid_length` if the sequence has fewer or more than `R * C`
    /// elements, and forwards any error produced while reading an element.
    /// On every error path the elements already written are dropped, so nothing leaks.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            <DefaultAllocator as Allocator<_, _>>::allocate_uninit(Const::<R>, Const::<C>);
        // Number of leading slots of `out` that currently hold an initialized `T`.
        let mut curr = 0;

        // Failures are captured in `status` rather than returned with `?`, so the
        // cleanup below always runs for the slots that were already initialized.
        let status: Result<(), V::Error> = loop {
            match visitor.next_element::<T>() {
                Ok(Some(value)) => match out.as_mut_slice().get_mut(curr) {
                    Some(slot) => {
                        *slot = core::mem::MaybeUninit::new(value);
                        curr += 1;
                    }
                    // More elements than the storage can hold; `value` is dropped here.
                    None => break Err(V::Error::invalid_length(curr, &self)),
                },
                Ok(None) if curr == R * C => break Ok(()),
                Ok(None) => break Err(V::Error::invalid_length(curr, &self)),
                Err(err) => break Err(err),
            }
        };

        match status {
            Ok(()) => {
                // Safety: the loop only succeeds once all `R * C` slots have been written.
                unsafe { Ok(<DefaultAllocator as Allocator<Const<R>, Const<C>>>::assume_init(out)) }
            }
            Err(err) => {
                for slot in &mut out.as_mut_slice()[..curr] {
                    // Safety: exactly the first `curr` slots were initialized above,
                    // and each of them is dropped once here.
                    unsafe { slot.assume_init_drop() };
                }
                Err(err)
            }
        }
    }
}
288
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[T; R]; C]`, so an all-zero
// bit pattern is valid for it whenever it is valid for `T` (required by the bound).
#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Zeroable, const R: usize, const C: usize>
    bytemuck::Zeroable for ArrayStorage<T, R, C>
{
}
294
// SAFETY: a `#[repr(transparent)]` wrapper around nested arrays of a `Pod` element type
// adds no padding and no invalid bit patterns, so it is `Pod` whenever `T` is.
#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Pod, const R: usize, const C: usize> bytemuck::Pod
    for ArrayStorage<T, R, C>
{
}