nalgebra/base/
array_storage.rs

1use std::fmt::{self, Debug, Formatter};
2// use std::hash::{Hash, Hasher};
3use std::ops::Mul;
4
5#[cfg(feature = "serde-serialize-no-std")]
6use serde::de::{Error, SeqAccess, Visitor};
7#[cfg(feature = "serde-serialize-no-std")]
8use serde::ser::SerializeTuple;
9#[cfg(feature = "serde-serialize-no-std")]
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11#[cfg(feature = "serde-serialize-no-std")]
12use std::marker::PhantomData;
13
14#[cfg(feature = "rkyv-serialize")]
15use rkyv::bytecheck;
16
17use crate::base::allocator::Allocator;
18use crate::base::default_allocator::DefaultAllocator;
19use crate::base::dimension::{Const, ToTypenum};
20use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage};
21use crate::base::Scalar;
22use crate::Storage;
23use std::mem;
24
25/*
26 *
27 * Static RawStorage.
28 *
29 */
/// An array-based statically sized matrix data storage.
///
/// The elements live inline in a `[[T; R]; C]` nested array. The outer index selects one of
/// the `C` columns and the inner index one of the `R` rows. The storage is therefore
/// column-major and contiguous: the `RawStorage` impl reports a row stride of 1 and a
/// column stride of `R`.
///
/// `#[repr(transparent)]` gives this wrapper exactly the layout of the wrapped array.
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "rkyv-serialize-no-std",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    archive(
        as = "ArrayStorage<T::Archived, R, C>",
        bound(archive = "
        T: rkyv::Archive,
        [[T; R]; C]: rkyv::Archive<Archived = [[T::Archived; R]; C]>
    ")
    )
)]
#[cfg_attr(feature = "rkyv-serialize", derive(bytecheck::CheckBytes))]
pub struct ArrayStorage<T, const R: usize, const C: usize>(pub [[T; R]; C]);
46
47impl<T, const R: usize, const C: usize> ArrayStorage<T, R, C> {
48    /// Converts this array storage to a slice.
49    #[inline]
50    pub fn as_slice(&self) -> &[T] {
51        // SAFETY: this is OK because ArrayStorage is contiguous.
52        unsafe { self.as_slice_unchecked() }
53    }
54
55    /// Converts this array storage to a mutable slice.
56    #[inline]
57    pub fn as_mut_slice(&mut self) -> &mut [T] {
58        // SAFETY: this is OK because ArrayStorage is contiguous.
59        unsafe { self.as_mut_slice_unchecked() }
60    }
61}
62
63// TODO: remove this once the stdlib implements Default for arrays.
64impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
65where
66    [[T; R]; C]: Default,
67{
68    #[inline]
69    fn default() -> Self {
70        Self(Default::default())
71    }
72}
73
74impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
75    #[inline]
76    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
77        self.0.fmt(fmt)
78    }
79}
80
81unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
82    for ArrayStorage<T, R, C>
83{
84    type RStride = Const<1>;
85    type CStride = Const<R>;
86
87    #[inline]
88    fn ptr(&self) -> *const T {
89        self.0.as_ptr() as *const T
90    }
91
92    #[inline]
93    fn shape(&self) -> (Const<R>, Const<C>) {
94        (Const, Const)
95    }
96
97    #[inline]
98    fn strides(&self) -> (Self::RStride, Self::CStride) {
99        (Const, Const)
100    }
101
102    #[inline]
103    fn is_contiguous(&self) -> bool {
104        true
105    }
106
107    #[inline]
108    unsafe fn as_slice_unchecked(&self) -> &[T] {
109        std::slice::from_raw_parts(self.ptr(), R * C)
110    }
111}
112
113unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
114    for ArrayStorage<T, R, C>
115where
116    DefaultAllocator: Allocator<Const<R>, Const<C>, Buffer<T> = Self>,
117{
118    #[inline]
119    fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
120    where
121        DefaultAllocator: Allocator<Const<R>, Const<C>>,
122    {
123        self
124    }
125
126    #[inline]
127    fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
128    where
129        DefaultAllocator: Allocator<Const<R>, Const<C>>,
130    {
131        self.clone()
132    }
133
134    #[inline]
135    fn forget_elements(self) {
136        // No additional cleanup required.
137        std::mem::forget(self);
138    }
139}
140
141unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
142    for ArrayStorage<T, R, C>
143{
144    #[inline]
145    fn ptr_mut(&mut self) -> *mut T {
146        self.0.as_mut_ptr() as *mut T
147    }
148
149    #[inline]
150    unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
151        std::slice::from_raw_parts_mut(self.ptr_mut(), R * C)
152    }
153}
154
// SAFETY: the `R * C` elements of the wrapped `[[T; R]; C]` are laid out back to back with
// no padding between columns. This matches the unit row stride and `R` column stride that
// `RawStorage` reports.
unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}
156
157impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
158    ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
159where
160    T: Scalar,
161    Const<R1>: ToTypenum,
162    Const<C1>: ToTypenum,
163    Const<R2>: ToTypenum,
164    Const<C2>: ToTypenum,
165    <Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
166    <Const<R2> as ToTypenum>::Typenum: Mul<
167        <Const<C2> as ToTypenum>::Typenum,
168        Output = typenum::Prod<
169            <Const<R1> as ToTypenum>::Typenum,
170            <Const<C1> as ToTypenum>::Typenum,
171        >,
172    >,
173{
174    type Output = ArrayStorage<T, R2, C2>;
175
176    fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
177        unsafe {
178            let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
179            mem::forget(self.0);
180            ArrayStorage(data)
181        }
182    }
183}
184
185/*
186 *
187 * Serialization.
188 *
189 */
190// XXX: open an issue for serde so that it allows the serialization/deserialization of all arrays?
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Writes the elements as a flat tuple of `R * C` values, column after column.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut tuple = serializer.serialize_tuple(R * C)?;
        self.as_slice()
            .iter()
            .try_for_each(|element| tuple.serialize_element(element))?;
        tuple.end()
    }
}
209
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Reads back a flat tuple of `R * C` elements, the format produced by `serialize`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let len = R * C;
        let visitor = ArrayStorageVisitor::<T, R, C>::new();
        deserializer.deserialize_tuple(len, visitor)
    }
}
222
#[cfg(feature = "serde-serialize-no-std")]
/// A serde visitor that builds an `ArrayStorage<T, R, C>` from a sequence of `R * C` elements.
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    // Zero-sized marker tying the visitor to its element type `T`. The dimensions are
    // carried by the const parameters `R` and `C` alone.
    marker: PhantomData<T>,
}
228
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C>
where
    T: Scalar,
{
    /// Creates a visitor for an `R x C` array storage. It carries no state beyond its types.
    pub fn new() -> Self {
        Self { marker: PhantomData }
    }
}
241
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Fills the storage, column after column, from a sequence of exactly `R * C` elements.
    ///
    /// # Errors
    ///
    /// - Returns `invalid_length` if the sequence holds fewer or more than `R * C` elements.
    /// - Propagates any error raised while deserializing an individual element.
    ///
    /// On every error path, and when unwinding, the elements read so far are dropped
    /// instead of being leaked.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        /// Drop guard: drops the first `len` slots of `slots` unless defused with
        /// `mem::forget`.
        struct InitializedPrefix<'s, U> {
            slots: &'s mut [core::mem::MaybeUninit<U>],
            len: usize,
        }

        impl<U> Drop for InitializedPrefix<'_, U> {
            fn drop(&mut self) {
                for slot in &mut self.slots[..self.len] {
                    // Safety: the first `len` slots have been initialized, and the guard
                    // drops them exactly once, here.
                    unsafe { std::ptr::drop_in_place(slot.as_mut_ptr()) };
                }
            }
        }

        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            <DefaultAllocator as Allocator<_, _>>::allocate_uninit(Const::<R>, Const::<C>);
        let mut prefix = InitializedPrefix {
            slots: out.as_mut_slice(),
            len: 0,
        };

        while let Some(value) = visitor.next_element()? {
            let index = prefix.len;
            let slot = prefix
                .slots
                .get_mut(index)
                // Too many elements: report how many were seen, including this one.
                .ok_or_else(|| V::Error::invalid_length(index + 1, &self))?;
            *slot = core::mem::MaybeUninit::new(value);
            prefix.len += 1;
        }

        let read = prefix.len;
        if read != R * C {
            // Dropping the guard releases the `read` elements deserialized so far.
            drop(prefix);
            return Err(V::Error::invalid_length(read, &self));
        }

        // Every slot is initialized: defuse the guard so ownership passes to the result.
        mem::forget(prefix);
        // Safety: all the `R * C` elements have been initialized.
        unsafe { Ok(<DefaultAllocator as Allocator<Const<R>, Const<C>>>::assume_init(out)) }
    }
}
284
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[T; R]; C]`. An all-zero bit
// pattern is therefore valid for it whenever it is valid for `T: Zeroable`.
#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Zeroable, const R: usize, const C: usize>
    bytemuck::Zeroable for ArrayStorage<T, R, C>
{
}
290
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[T; R]; C]`, i.e. exactly
// `R * C` contiguous `T`s with no padding. It is plain old data whenever `T: Pod` is.
#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Pod, const R: usize, const C: usize> bytemuck::Pod
    for ArrayStorage<T, R, C>
{
}