nalgebra/base/
iter.rs

1//! Matrix iterators.
2
3// only enables the `doc_cfg` feature when
4// the `docsrs` configuration attribute is defined
5#![cfg_attr(docsrs, feature(doc_cfg))]
6
7use core::fmt::Debug;
8use core::ops::Range;
9use std::iter::FusedIterator;
10use std::marker::PhantomData;
11use std::mem;
12
13use crate::base::dimension::{Dim, U1};
14use crate::base::storage::{RawStorage, RawStorageMut};
15use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut};
16
/// Low-level pointer iterator shared by [`MatrixIter`] and [`MatrixIterMut`].
///
/// It walks the first dimension (`R`) of a strided storage by the row stride and,
/// at the end of each column, jumps by the column stride to the next column.
#[derive(Clone, Debug)]
struct RawIter<Ptr, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> {
    // Pointer to the next element yielded from the front.
    ptr: Ptr,
    // Pointer to the first element of the column `ptr` is currently walking.
    inner_ptr: Ptr,
    // One row-stride step past the last row of that column. Only compared against,
    // never dereferenced; it may point past the end of the allocation.
    inner_end: Ptr,
    // Number of elements not yet yielded from either end.
    size: usize,
    // (row stride, column stride), measured in elements.
    strides: (RStride, CStride),
    // Ties `T`, `R` and `C` to the type without storing values of them.
    _phantoms: PhantomData<(fn() -> T, R, C)>,
}
26
macro_rules! iterator {
    (struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => {
        // TODO: we need to specialize for the case where the matrix storage is owned (in which
        // case the iterator is trivial because it does not have any stride).
        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
            RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            /// Creates a new iterator for the given matrix storage.
            ///
            /// Elements are visited along the first (row) dimension using the row
            /// stride; at the end of each column the iterator jumps by the column
            /// stride to the start of the next column.
            fn new<'a, S: $Storage<T, R, C, RStride = RStride, CStride = CStride>>(
                storage: $SRef,
            ) -> Self {
                let shape = storage.shape();
                let strides = storage.strides();
                let inner_offset = shape.0.value() * strides.0.value();
                let size = shape.0.value() * shape.1.value();
                let ptr = storage.$ptr();

                // `inner_end` is one row-stride step past the last row of the first
                // column. It is only ever compared against, never dereferenced, and for
                // strided storage it can lie beyond the end of the allocation, so it is
                // computed with wrapping arithmetic (`add` would be UB there).
                // For an empty matrix `ptr` may be dangling and `inner_end` is never
                // consulted, so `ptr` itself is stored.
                let inner_end = if size == 0 {
                    ptr
                } else {
                    ptr.wrapping_add(inner_offset)
                };

                RawIter {
                    ptr,
                    inner_ptr: ptr,
                    inner_end,
                    size,
                    strides,
                    _phantoms: PhantomData,
                }
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Iterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            type Item = $Ptr;

            #[inline]
            fn next(&mut self) -> Option<Self::Item> {
                if self.size == 0 {
                    return None;
                }
                self.size -= 1;

                // Jump to the next outer dimension if the current column is exhausted.
                if self.ptr == self.inner_end {
                    let stride = self.strides.1.value() as isize;
                    // The new end marker might go past the end of the allocation,
                    // so `wrapping_offset` is used to avoid UB.
                    self.inner_end = self.ptr.wrapping_offset(stride);
                    // SAFETY: another element remains (`size` was non-zero) and it is
                    // the first element of the next column, so this stays in bounds.
                    self.ptr = unsafe { self.inner_ptr.offset(stride) };
                    self.inner_ptr = self.ptr;
                }

                let old = self.ptr;

                // Step to the next row, except after the very last element (iteration
                // is over then, so the position is unobservable). Stepping past the last
                // row of a column can leave the allocation for strided storage; wrapping
                // arithmetic keeps that defined, and such a pointer equals `inner_end`,
                // so the branch above re-targets it before it is ever dereferenced.
                if self.size != 0 {
                    self.ptr = self.ptr.wrapping_add(self.strides.0.value());
                }

                Some(old)
            }

            #[inline]
            fn size_hint(&self) -> (usize, Option<usize>) {
                (self.size, Some(self.size))
            }

            #[inline]
            fn count(self) -> usize {
                self.size_hint().0
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> DoubleEndedIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            #[inline]
            fn next_back(&mut self) -> Option<Self::Item> {
                if self.size == 0 {
                    return None;
                }

                // Pre-decrement `size` so that it is now the distance, in iteration
                // order, from the element at `self.ptr` to the element we return.
                self.size -= 1;

                let elem_size = mem::size_of::<T>();
                if elem_size == 0 {
                    // Every element of a zero-sized type shares the same address.
                    return Some(self.ptr);
                }

                let inner_stride = self.strides.0.value();
                let outer_stride = self.strides.1.value();

                // Distance, in row-stride steps, between two pointers of the current
                // column. `inner_end` (and `ptr`, once a column has been consumed by
                // `next`) may lie outside the allocation, so plain address arithmetic
                // is used instead of `offset_from`, which would be UB for them.
                let steps = |to: $Ptr, from: $Ptr| {
                    (to as usize).wrapping_sub(from as usize) / elem_size / inner_stride
                };

                // Number of rows: `inner_end - inner_ptr` is always `nrows * inner_stride`.
                // It is non-zero because `size` was non-zero.
                let inner_size = steps(self.inner_end, self.inner_ptr);
                // Row of `self.ptr` within the column starting at `inner_ptr`. It equals
                // `inner_size` when that column has been fully consumed by `next`.
                let inner_pos = steps(self.ptr, self.inner_ptr);

                // Position of the target relative to the start of the current column.
                // It must be computed from `inner_ptr` rather than `ptr`: for strided
                // views the next column does not start right after the current one ends.
                let linear = inner_pos + self.size;
                let outer = linear / inner_size;
                let inner = linear % inner_size;

                // SAFETY: `linear` is the index of the last element not yet yielded,
                // counted from the current column start, so the result is a real
                // element of the storage.
                Some(unsafe { self.inner_ptr.add(outer * outer_stride + inner * inner_stride) })
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> ExactSizeIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
            #[inline]
            fn len(&self) -> usize {
                self.size
            }
        }

        impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> FusedIterator
            for RawIter<$Ptr, T, R, C, RStride, CStride>
        {
        }

        /// An iterator through the elements of a dense matrix with arbitrary strides.
        #[derive($($derives),*)]
        pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
            inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>,
            _marker: PhantomData<$Ref>,
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
            /// Creates a new iterator for the given matrix storage.
            pub fn new(storage: $SRef) -> Self {
                Self {
                    inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage),
                    _marker: PhantomData,
                }
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
            type Item = $Ref;

            #[inline(always)]
            fn next(&mut self) -> Option<Self::Item> {
                // We want either `& *last` or `&mut *last` here, depending
                // on the mutability of `$Ref`.
                #[allow(clippy::transmute_ptr_to_ref)]
                self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) })
            }

            #[inline(always)]
            fn size_hint(&self) -> (usize, Option<usize>) {
                self.inner.size_hint()
            }

            #[inline(always)]
            fn count(self) -> usize {
                self.inner.count()
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
            for $Name<'a, T, R, C, S>
        {
            #[inline(always)]
            fn next_back(&mut self) -> Option<Self::Item> {
                // We want either `& *last` or `&mut *last` here, depending
                // on the mutability of `$Ref`.
                #[allow(clippy::transmute_ptr_to_ref)]
                self.inner
                    .next_back()
                    .map(|ptr| unsafe { mem::transmute(ptr) })
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
            for $Name<'a, T, R, C, S>
        {
            #[inline(always)]
            fn len(&self) -> usize {
                self.inner.len()
            }
        }

        impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> FusedIterator
            for $Name<'a, T, R, C, S>
        {
        }
    };
}
245
// Shared-reference element iterator: reads through `RawStorage::ptr` and yields `&'a T`.
iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug);
// Mutable element iterator: goes through `RawStorageMut::ptr_mut` and yields `&'a mut T`.
// Only `Debug` is derived: a `Clone` would let two iterators hand out `&mut` to the
// same elements.
iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug);
248
249impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
250    MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
251{
252    /// Creates a new iterator for the given matrix storage view.
253    pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self {
254        Self {
255            inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage),
256            _marker: PhantomData,
257        }
258    }
259}
260
261impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
262    MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
263{
264    /// Creates a new iterator for the given matrix storage view.
265    pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self {
266        Self {
267            inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage),
268            _marker: PhantomData,
269        }
270    }
271}
272
273/*
274 *
275 * Row iterators.
276 *
277 */
#[derive(Clone, Debug)]
/// An iterator through the rows of a matrix.
///
/// Each call to `next` yields an immutable view of one row, from the first row to
/// the last.
pub struct RowIter<'a, T, R: Dim, C: Dim, S: RawStorage<T, R, C>> {
    // The matrix whose rows are visited.
    mat: &'a Matrix<T, R, C, S>,
    // Index of the next row to yield.
    curr: usize,
}
284
285impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> RowIter<'a, T, R, C, S> {
286    pub(crate) fn new(mat: &'a Matrix<T, R, C, S>) -> Self {
287        RowIter { mat, curr: 0 }
288    }
289}
290
291impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for RowIter<'a, T, R, C, S> {
292    type Item = MatrixView<'a, T, U1, C, S::RStride, S::CStride>;
293
294    #[inline]
295    fn next(&mut self) -> Option<Self::Item> {
296        if self.curr < self.mat.nrows() {
297            let res = self.mat.row(self.curr);
298            self.curr += 1;
299            Some(res)
300        } else {
301            None
302        }
303    }
304
305    #[inline]
306    fn size_hint(&self) -> (usize, Option<usize>) {
307        (
308            self.mat.nrows() - self.curr,
309            Some(self.mat.nrows() - self.curr),
310        )
311    }
312
313    #[inline]
314    fn count(self) -> usize {
315        self.mat.nrows() - self.curr
316    }
317}
318
319impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ExactSizeIterator
320    for RowIter<'a, T, R, C, S>
321{
322    #[inline]
323    fn len(&self) -> usize {
324        self.mat.nrows() - self.curr
325    }
326}
327
/// An iterator through the mutable rows of a matrix.
///
/// Each call to `next` yields a mutable view of one row, from the first row to the
/// last.
#[derive(Debug)]
pub struct RowIterMut<'a, T, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> {
    // Raw pointer to the matrix, created from the `&'a mut` passed to `new`, so that
    // several row views borrowing for `'a` can be handed out.
    mat: *mut Matrix<T, R, C, S>,
    // Index of the next row to yield.
    curr: usize,
    // Records the exclusive `'a` borrow of the matrix that `mat` came from.
    phantom: PhantomData<&'a mut Matrix<T, R, C, S>>,
}
335
336impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> RowIterMut<'a, T, R, C, S> {
337    pub(crate) fn new(mat: &'a mut Matrix<T, R, C, S>) -> Self {
338        RowIterMut {
339            mat,
340            curr: 0,
341            phantom: PhantomData,
342        }
343    }
344
345    fn nrows(&self) -> usize {
346        unsafe { (*self.mat).nrows() }
347    }
348}
349
350impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator
351    for RowIterMut<'a, T, R, C, S>
352{
353    type Item = MatrixViewMut<'a, T, U1, C, S::RStride, S::CStride>;
354
355    #[inline]
356    fn next(&mut self) -> Option<Self::Item> {
357        if self.curr < self.nrows() {
358            let res = unsafe { (*self.mat).row_mut(self.curr) };
359            self.curr += 1;
360            Some(res)
361        } else {
362            None
363        }
364    }
365
366    #[inline]
367    fn size_hint(&self) -> (usize, Option<usize>) {
368        (self.nrows() - self.curr, Some(self.nrows() - self.curr))
369    }
370
371    #[inline]
372    fn count(self) -> usize {
373        self.nrows() - self.curr
374    }
375}
376
377impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIterator
378    for RowIterMut<'a, T, R, C, S>
379{
380    #[inline]
381    fn len(&self) -> usize {
382        self.nrows() - self.curr
383    }
384}
385
386/*
387 * Column iterators.
388 *
389 */
#[derive(Clone, Debug)]
/// An iterator through the columns of a matrix.
///
/// Yields immutable column views from either end of the remaining column range.
pub struct ColumnIter<'a, T, R: Dim, C: Dim, S: RawStorage<T, R, C>> {
    // The matrix whose columns are visited.
    mat: &'a Matrix<T, R, C, S>,
    // Indices of the columns not yet yielded; `next` consumes from `start`,
    // `next_back` from `end`.
    range: Range<usize>,
}
396
397impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ColumnIter<'a, T, R, C, S> {
398    /// a new column iterator covering all columns of the matrix
399    pub(crate) fn new(mat: &'a Matrix<T, R, C, S>) -> Self {
400        ColumnIter {
401            mat,
402            range: 0..mat.ncols(),
403        }
404    }
405
406    #[cfg(feature = "rayon")]
407    pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
408        // SAFETY: this makes sure the generated ranges are valid.
409        let split_pos = (self.range.start + index).min(self.range.end);
410
411        let left_iter = ColumnIter {
412            mat: self.mat,
413            range: self.range.start..split_pos,
414        };
415
416        let right_iter = ColumnIter {
417            mat: self.mat,
418            range: split_pos..self.range.end,
419        };
420
421        (left_iter, right_iter)
422    }
423}
424
425impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for ColumnIter<'a, T, R, C, S> {
426    type Item = MatrixView<'a, T, R, U1, S::RStride, S::CStride>;
427
428    #[inline]
429    fn next(&mut self) -> Option<Self::Item> {
430        debug_assert!(self.range.start <= self.range.end);
431        if self.range.start < self.range.end {
432            let res = self.mat.column(self.range.start);
433            self.range.start += 1;
434            Some(res)
435        } else {
436            None
437        }
438    }
439
440    #[inline]
441    fn size_hint(&self) -> (usize, Option<usize>) {
442        let hint = self.range.len();
443        (hint, Some(hint))
444    }
445
446    #[inline]
447    fn count(self) -> usize {
448        self.range.len()
449    }
450}
451
452impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> DoubleEndedIterator
453    for ColumnIter<'a, T, R, C, S>
454{
455    fn next_back(&mut self) -> Option<Self::Item> {
456        debug_assert!(self.range.start <= self.range.end);
457        if !self.range.is_empty() {
458            self.range.end -= 1;
459            debug_assert!(self.range.end < self.mat.ncols());
460            debug_assert!(self.range.end >= self.range.start);
461            Some(self.mat.column(self.range.end))
462        } else {
463            None
464        }
465    }
466}
467
468impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ExactSizeIterator
469    for ColumnIter<'a, T, R, C, S>
470{
471    #[inline]
472    fn len(&self) -> usize {
473        self.range.end - self.range.start
474    }
475}
476
/// An iterator through the mutable columns of a matrix.
///
/// Yields mutable column views from either end of the remaining column range.
#[derive(Debug)]
pub struct ColumnIterMut<'a, T, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> {
    // Raw pointer to the matrix, created from the `&'a mut` passed to `new`, so that
    // several column views borrowing for `'a` can be handed out.
    mat: *mut Matrix<T, R, C, S>,
    // Indices of the columns not yet yielded; `next` consumes from `start`,
    // `next_back` from `end`.
    range: Range<usize>,
    // Records the exclusive `'a` borrow of the matrix that `mat` came from.
    phantom: PhantomData<&'a mut Matrix<T, R, C, S>>,
}
484
485impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ColumnIterMut<'a, T, R, C, S> {
486    pub(crate) fn new(mat: &'a mut Matrix<T, R, C, S>) -> Self {
487        let range = 0..mat.ncols();
488        ColumnIterMut {
489            mat,
490            range,
491            phantom: Default::default(),
492        }
493    }
494
495    #[cfg(feature = "rayon")]
496    pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
497        // SAFETY: this makes sure the generated ranges are valid.
498        let split_pos = (self.range.start + index).min(self.range.end);
499
500        let left_iter = ColumnIterMut {
501            mat: self.mat,
502            range: self.range.start..split_pos,
503            phantom: Default::default(),
504        };
505
506        let right_iter = ColumnIterMut {
507            mat: self.mat,
508            range: split_pos..self.range.end,
509            phantom: Default::default(),
510        };
511
512        (left_iter, right_iter)
513    }
514
515    fn ncols(&self) -> usize {
516        unsafe { (*self.mat).ncols() }
517    }
518}
519
520impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator
521    for ColumnIterMut<'a, T, R, C, S>
522{
523    type Item = MatrixViewMut<'a, T, R, U1, S::RStride, S::CStride>;
524
525    #[inline]
526    fn next(&'_ mut self) -> Option<Self::Item> {
527        debug_assert!(self.range.start <= self.range.end);
528        if self.range.start < self.range.end {
529            let res = unsafe { (*self.mat).column_mut(self.range.start) };
530            self.range.start += 1;
531            Some(res)
532        } else {
533            None
534        }
535    }
536
537    #[inline]
538    fn size_hint(&self) -> (usize, Option<usize>) {
539        let hint = self.range.len();
540        (hint, Some(hint))
541    }
542
543    #[inline]
544    fn count(self) -> usize {
545        self.range.len()
546    }
547}
548
549impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIterator
550    for ColumnIterMut<'a, T, R, C, S>
551{
552    #[inline]
553    fn len(&self) -> usize {
554        self.range.len()
555    }
556}
557
558impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> DoubleEndedIterator
559    for ColumnIterMut<'a, T, R, C, S>
560{
561    fn next_back(&mut self) -> Option<Self::Item> {
562        debug_assert!(self.range.start <= self.range.end);
563        if !self.range.is_empty() {
564            self.range.end -= 1;
565            debug_assert!(self.range.end < self.ncols());
566            debug_assert!(self.range.end >= self.range.start);
567            Some(unsafe { (*self.mat).column_mut(self.range.end) })
568        } else {
569            None
570        }
571    }
572}