nalgebra/base/
iter.rs

1//! Matrix iterators.
2
3use core::fmt::Debug;
4use core::ops::Range;
5use std::iter::FusedIterator;
6use std::marker::PhantomData;
7use std::mem;
8
9use crate::base::dimension::{Dim, U1};
10use crate::base::storage::{RawStorage, RawStorageMut};
11use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut};
12
13#[derive(Clone, Debug)]
14struct RawIter<Ptr, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> {
15    ptr: Ptr,
16    inner_ptr: Ptr,
17    inner_end: Ptr,
18    size: usize,
19    strides: (RStride, CStride),
20    _phantoms: PhantomData<(fn() -> T, R, C)>,
21}
22
macro_rules! iterator {
    (struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => {
        // TODO: we need to specialize for the case where the matrix storage is owned (in which
        // case the iterator is trivial because it does not have any stride).
        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
            RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            /// Creates a new iterator for the given matrix storage.
            ///
            /// Elements are visited column by column: `strides.0` is the step between two
            /// consecutive rows of a column and `strides.1` the step between two columns.
            fn new<'a, S: $Storage<T, R, C, RStride = RStride, CStride = CStride>>(
                storage: $SRef,
            ) -> Self {
                let shape = storage.shape();
                let strides = storage.strides();
                let inner_offset = shape.0.value() * strides.0.value();
                let size = shape.0.value() * shape.1.value();
                let ptr = storage.$ptr();

                // `inner_end` is one row stride past the last row of the first column. It is
                // never dereferenced: it is only compared with `ptr`, moved along with
                // `wrapping_offset`, and used for address arithmetic in `next_back`.
                // `wrapping_add` is used because `add` would be UB here when:
                // - `size` is 0: `ptr` may be dangling while `inner_offset` is non-zero
                //   (only one dimension is zero);
                // - the row stride is greater than one: the address can lie more than one
                //   element past the end of the allocation (e.g. a single-column view).
                let inner_end = ptr.wrapping_add(inner_offset);

                RawIter {
                    ptr,
                    inner_ptr: ptr,
                    inner_end,
                    size,
                    strides,
                    _phantoms: PhantomData,
                }
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Iterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            type Item = $Ptr;

            #[inline]
            fn next(&mut self) -> Option<Self::Item> {
                if self.size == 0 {
                    return None;
                }
                self.size -= 1;

                // Jump to the next outer dimension if needed.
                if self.ptr == self.inner_end {
                    let stride = self.strides.1.value() as isize;
                    // This might go past the end of the allocation,
                    // depending on the value of 'size'. We use
                    // `wrapping_offset` to avoid UB.
                    self.inner_end = self.ptr.wrapping_offset(stride);
                    // SAFETY: the start of the next column is the element returned below,
                    // so it lies within the allocation.
                    self.ptr = unsafe { self.inner_ptr.offset(stride) };
                    self.inner_ptr = self.ptr;
                }

                // Go to the next element.
                let old = self.ptr;

                // Don't offset `self.ptr` for the last element,
                // as this will be out of bounds. Iteration is done
                // at this point (the next call to `next` will return `None`)
                // so this is not observable.
                if self.size != 0 {
                    let stride = self.strides.0.value();
                    // SAFETY: another element remains, so this lands either on the next row
                    // of the current column or on the end marker of a column that is followed
                    // by another column; the storage layout is assumed to keep the latter
                    // inside the allocation.
                    self.ptr = unsafe { self.ptr.add(stride) };
                }

                Some(old)
            }

            #[inline]
            fn size_hint(&self) -> (usize, Option<usize>) {
                (self.size, Some(self.size))
            }

            #[inline]
            fn count(self) -> usize {
                self.size_hint().0
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> DoubleEndedIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            #[inline]
            fn next_back(&mut self) -> Option<Self::Item> {
                if self.size == 0 {
                    return None;
                }

                // Pre-decrement `size` such that it now counts, in iteration order, the
                // elements between `self.ptr` and the element we want to return.
                self.size -= 1;

                // All zero-sized elements share one address, so the front pointer is also the
                // back element. (Strides cannot be recovered from addresses in that case, and
                // `offset_from` would panic.)
                let elem_size = mem::size_of::<T>();
                if elem_size == 0 {
                    return Some(self.ptr);
                }

                // Fetch strides
                let inner_stride = self.strides.0.value();
                let outer_stride = self.strides.1.value();

                // Distances are computed on addresses instead of with `offset_from`:
                // `inner_end` may come from `wrapping_add`/`wrapping_offset` and lie outside
                // the allocation, where `offset_from` is undefined behavior.
                let base = self.inner_ptr as usize;

                // Number of rows per column. Division should be exact.
                let inner_size =
                    (self.inner_end as usize).wrapping_sub(base) / elem_size / inner_stride;

                // Row of `self.ptr` inside the current column. `next` only advances `ptr`
                // within a column and jumps lazily, so this lies in `0..=inner_size`.
                let inner_pos = (self.ptr as usize).wrapping_sub(base) / elem_size / inner_stride;

                // Position of the requested element counted from the top of the current
                // column, not from `self.ptr` (which may sit in the middle of that column).
                let target = inner_pos + self.size;
                let offset =
                    (target / inner_size) * outer_stride + (target % inner_size) * inner_stride;

                // SAFETY: `target` designates an element that has not been yielded yet, so
                // `inner_ptr + offset` points into the storage the iterator was built from.
                Some(unsafe { self.inner_ptr.add(offset) })
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> ExactSizeIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            #[inline]
            fn len(&self) -> usize {
                self.size
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> FusedIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
        }

        /// An iterator through the elements of a dense matrix with arbitrary strides.
        #[derive($($derives),*)]
        pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
            inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>,
            _marker: PhantomData<$Ref>,
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
            /// Creates a new iterator for the given matrix storage.
            pub fn new(storage: $SRef) -> Self {
                Self {
                    inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage),
                    _marker: PhantomData,
                }
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
            type Item = $Ref;

            #[inline(always)]
            fn next(&mut self) -> Option<Self::Item> {
                // We want either `& *last` or `&mut *last` here, depending
                // on the mutability of `$Ref`.
                #[allow(clippy::transmute_ptr_to_ref)]
                self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) })
            }

            #[inline(always)]
            fn size_hint(&self) -> (usize, Option<usize>) {
                self.inner.size_hint()
            }

            #[inline(always)]
            fn count(self) -> usize {
                self.inner.count()
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
            for $Name<'a, T, R, C, S>
        {
            #[inline(always)]
            fn next_back(&mut self) -> Option<Self::Item> {
                // We want either `& *last` or `&mut *last` here, depending
                // on the mutability of `$Ref`.
                #[allow(clippy::transmute_ptr_to_ref)]
                self.inner
                    .next_back()
                    .map(|ptr| unsafe { mem::transmute(ptr) })
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
            for $Name<'a, T, R, C, S>
        {
            #[inline(always)]
            fn len(&self) -> usize {
                self.inner.len()
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> FusedIterator
            for $Name<'a, T, R, C, S>
        {
        }
    };
}
241
242iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug);
243iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug);
244
245impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
246    MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
247{
248    /// Creates a new iterator for the given matrix storage view.
249    pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self {
250        Self {
251            inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage),
252            _marker: PhantomData,
253        }
254    }
255}
256
257impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
258    MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
259{
260    /// Creates a new iterator for the given matrix storage view.
261    pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self {
262        Self {
263            inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage),
264            _marker: PhantomData,
265        }
266    }
267}
268
269/*
270 *
271 * Row iterators.
272 *
273 */
274#[derive(Clone, Debug)]
275/// An iterator through the rows of a matrix.
276pub struct RowIter<'a, T, R: Dim, C: Dim, S: RawStorage<T, R, C>> {
277    mat: &'a Matrix<T, R, C, S>,
278    curr: usize,
279}
280
281impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> RowIter<'a, T, R, C, S> {
282    pub(crate) const fn new(mat: &'a Matrix<T, R, C, S>) -> Self {
283        RowIter { mat, curr: 0 }
284    }
285}
286
287impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for RowIter<'a, T, R, C, S> {
288    type Item = MatrixView<'a, T, U1, C, S::RStride, S::CStride>;
289
290    #[inline]
291    fn next(&mut self) -> Option<Self::Item> {
292        if self.curr < self.mat.nrows() {
293            let res = self.mat.row(self.curr);
294            self.curr += 1;
295            Some(res)
296        } else {
297            None
298        }
299    }
300
301    #[inline]
302    fn size_hint(&self) -> (usize, Option<usize>) {
303        (
304            self.mat.nrows() - self.curr,
305            Some(self.mat.nrows() - self.curr),
306        )
307    }
308
309    #[inline]
310    fn count(self) -> usize {
311        self.mat.nrows() - self.curr
312    }
313}
314
315impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ExactSizeIterator
316    for RowIter<'a, T, R, C, S>
317{
318    #[inline]
319    fn len(&self) -> usize {
320        self.mat.nrows() - self.curr
321    }
322}
323
324/// An iterator through the mutable rows of a matrix.
325#[derive(Debug)]
326pub struct RowIterMut<'a, T, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> {
327    mat: *mut Matrix<T, R, C, S>,
328    curr: usize,
329    phantom: PhantomData<&'a mut Matrix<T, R, C, S>>,
330}
331
332impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> RowIterMut<'a, T, R, C, S> {
333    pub(crate) const fn new(mat: &'a mut Matrix<T, R, C, S>) -> Self {
334        RowIterMut {
335            mat,
336            curr: 0,
337            phantom: PhantomData,
338        }
339    }
340
341    fn nrows(&self) -> usize {
342        unsafe { (*self.mat).nrows() }
343    }
344}
345
346impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator
347    for RowIterMut<'a, T, R, C, S>
348{
349    type Item = MatrixViewMut<'a, T, U1, C, S::RStride, S::CStride>;
350
351    #[inline]
352    fn next(&mut self) -> Option<Self::Item> {
353        if self.curr < self.nrows() {
354            let res = unsafe { (*self.mat).row_mut(self.curr) };
355            self.curr += 1;
356            Some(res)
357        } else {
358            None
359        }
360    }
361
362    #[inline]
363    fn size_hint(&self) -> (usize, Option<usize>) {
364        (self.nrows() - self.curr, Some(self.nrows() - self.curr))
365    }
366
367    #[inline]
368    fn count(self) -> usize {
369        self.nrows() - self.curr
370    }
371}
372
373impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIterator
374    for RowIterMut<'a, T, R, C, S>
375{
376    #[inline]
377    fn len(&self) -> usize {
378        self.nrows() - self.curr
379    }
380}
381
382/*
383 * Column iterators.
384 *
385 */
386#[derive(Clone, Debug)]
387/// An iterator through the columns of a matrix.
388pub struct ColumnIter<'a, T, R: Dim, C: Dim, S: RawStorage<T, R, C>> {
389    mat: &'a Matrix<T, R, C, S>,
390    range: Range<usize>,
391}
392
393impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ColumnIter<'a, T, R, C, S> {
394    /// a new column iterator covering all columns of the matrix
395    pub(crate) fn new(mat: &'a Matrix<T, R, C, S>) -> Self {
396        ColumnIter {
397            mat,
398            range: 0..mat.ncols(),
399        }
400    }
401
402    #[cfg(feature = "rayon")]
403    pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
404        // SAFETY: this makes sure the generated ranges are valid.
405        let split_pos = (self.range.start + index).min(self.range.end);
406
407        let left_iter = ColumnIter {
408            mat: self.mat,
409            range: self.range.start..split_pos,
410        };
411
412        let right_iter = ColumnIter {
413            mat: self.mat,
414            range: split_pos..self.range.end,
415        };
416
417        (left_iter, right_iter)
418    }
419}
420
421impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for ColumnIter<'a, T, R, C, S> {
422    type Item = MatrixView<'a, T, R, U1, S::RStride, S::CStride>;
423
424    #[inline]
425    fn next(&mut self) -> Option<Self::Item> {
426        debug_assert!(self.range.start <= self.range.end);
427        if self.range.start < self.range.end {
428            let res = self.mat.column(self.range.start);
429            self.range.start += 1;
430            Some(res)
431        } else {
432            None
433        }
434    }
435
436    #[inline]
437    fn size_hint(&self) -> (usize, Option<usize>) {
438        let hint = self.range.len();
439        (hint, Some(hint))
440    }
441
442    #[inline]
443    fn count(self) -> usize {
444        self.range.len()
445    }
446}
447
448impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> DoubleEndedIterator
449    for ColumnIter<'a, T, R, C, S>
450{
451    fn next_back(&mut self) -> Option<Self::Item> {
452        debug_assert!(self.range.start <= self.range.end);
453        if !self.range.is_empty() {
454            self.range.end -= 1;
455            debug_assert!(self.range.end < self.mat.ncols());
456            debug_assert!(self.range.end >= self.range.start);
457            Some(self.mat.column(self.range.end))
458        } else {
459            None
460        }
461    }
462}
463
464impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ExactSizeIterator
465    for ColumnIter<'a, T, R, C, S>
466{
467    #[inline]
468    fn len(&self) -> usize {
469        self.range.end - self.range.start
470    }
471}
472
473/// An iterator through the mutable columns of a matrix.
474#[derive(Debug)]
475pub struct ColumnIterMut<'a, T, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> {
476    mat: *mut Matrix<T, R, C, S>,
477    range: Range<usize>,
478    phantom: PhantomData<&'a mut Matrix<T, R, C, S>>,
479}
480
481impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ColumnIterMut<'a, T, R, C, S> {
482    pub(crate) fn new(mat: &'a mut Matrix<T, R, C, S>) -> Self {
483        let range = 0..mat.ncols();
484        ColumnIterMut {
485            mat,
486            range,
487            phantom: Default::default(),
488        }
489    }
490
491    #[cfg(feature = "rayon")]
492    pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
493        // SAFETY: this makes sure the generated ranges are valid.
494        let split_pos = (self.range.start + index).min(self.range.end);
495
496        let left_iter = ColumnIterMut {
497            mat: self.mat,
498            range: self.range.start..split_pos,
499            phantom: Default::default(),
500        };
501
502        let right_iter = ColumnIterMut {
503            mat: self.mat,
504            range: split_pos..self.range.end,
505            phantom: Default::default(),
506        };
507
508        (left_iter, right_iter)
509    }
510
511    fn ncols(&self) -> usize {
512        unsafe { (*self.mat).ncols() }
513    }
514}
515
516impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator
517    for ColumnIterMut<'a, T, R, C, S>
518{
519    type Item = MatrixViewMut<'a, T, R, U1, S::RStride, S::CStride>;
520
521    #[inline]
522    fn next(&'_ mut self) -> Option<Self::Item> {
523        debug_assert!(self.range.start <= self.range.end);
524        if self.range.start < self.range.end {
525            let res = unsafe { (*self.mat).column_mut(self.range.start) };
526            self.range.start += 1;
527            Some(res)
528        } else {
529            None
530        }
531    }
532
533    #[inline]
534    fn size_hint(&self) -> (usize, Option<usize>) {
535        let hint = self.range.len();
536        (hint, Some(hint))
537    }
538
539    #[inline]
540    fn count(self) -> usize {
541        self.range.len()
542    }
543}
544
545impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIterator
546    for ColumnIterMut<'a, T, R, C, S>
547{
548    #[inline]
549    fn len(&self) -> usize {
550        self.range.len()
551    }
552}
553
554impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> DoubleEndedIterator
555    for ColumnIterMut<'a, T, R, C, S>
556{
557    fn next_back(&mut self) -> Option<Self::Item> {
558        debug_assert!(self.range.start <= self.range.end);
559        if !self.range.is_empty() {
560            self.range.end -= 1;
561            debug_assert!(self.range.end < self.ncols());
562            debug_assert!(self.range.end >= self.range.start);
563            Some(unsafe { (*self.mat).column_mut(self.range.end) })
564        } else {
565            None
566        }
567    }
568}