// nalgebra/linalg/col_piv_qr.rs

1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::ComplexField;
6use crate::allocator::{Allocator, Reallocator};
7use crate::base::{Const, DefaultAllocator, Matrix, OMatrix, OVector, Unit};
8use crate::constraint::{SameNumberOfRows, ShapeConstraint};
9use crate::dimension::{Dim, DimMin, DimMinimum};
10use crate::storage::StorageMut;
11
12use crate::geometry::Reflection;
13use crate::linalg::{PermutationSequence, householder};
14use std::mem::MaybeUninit;
15
/// The QR decomposition (with column pivoting) of a general matrix.
///
/// Built by [`ColPivQR::new`]; the factors are retrieved with `q`, `r`/`unpack_r`, `p`
/// or `unpack`.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         PermutationSequence<DimMinimum<R, C>>: Serialize,
         OVector<T, DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>,
         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct ColPivQR<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    /// Packed factorization of the column-permuted input.
    ///
    /// Entries strictly above the diagonal are those of `R`. Column `i`, from row `i`
    /// downward, holds the axis of the `i`-th Householder reflection used to build `Q`.
    col_piv_qr: OMatrix<T, R, C>,
    /// Column swaps applied to the input, one appended per elimination step.
    p: PermutationSequence<DimMinimum<R, C>>,
    /// Pivot value returned by each Householder elimination.
    ///
    /// `R`'s diagonal is the modulus of these values; `Q` uses their `signum`.
    diag: OVector<T, DimMinimum<R, C>>,
}
44
45impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for ColPivQR<T, R, C>
46where
47    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
48    OMatrix<T, R, C>: Copy,
49    PermutationSequence<DimMinimum<R, C>>: Copy,
50    OVector<T, DimMinimum<R, C>>: Copy,
51{
52}
53
54impl<T: ComplexField, R: DimMin<C>, C: Dim> ColPivQR<T, R, C>
55where
56    DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
57{
58    /// Computes the `ColPivQR` decomposition using householder reflections.
59    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
60        let (nrows, ncols) = matrix.shape_generic();
61        let min_nrows_ncols = nrows.min(ncols);
62        let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
63
64        if min_nrows_ncols.value() == 0 {
65            return ColPivQR {
66                col_piv_qr: matrix,
67                p,
68                diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
69            };
70        }
71
72        let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
73
74        for i in 0..min_nrows_ncols.value() {
75            let piv = matrix.view_range(i.., i..).icamax_full();
76            let col_piv = piv.1 + i;
77            matrix.swap_columns(i, col_piv);
78            p.append_permutation(i, col_piv);
79
80            diag[i] =
81                MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
82        }
83
84        // Safety: diag is now fully initialized.
85        let diag = unsafe { diag.assume_init() };
86
87        ColPivQR {
88            col_piv_qr: matrix,
89            p,
90            diag,
91        }
92    }
93
94    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
95    #[inline]
96    #[must_use]
97    pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
98    where
99        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
100    {
101        let (nrows, ncols) = self.col_piv_qr.shape_generic();
102        let mut res = self
103            .col_piv_qr
104            .rows_generic(0, nrows.min(ncols))
105            .upper_triangle();
106        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
107        res
108    }
109
110    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
111    ///
112    /// This is usually faster than `r` but consumes `self`.
113    #[inline]
114    pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
115    where
116        DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
117    {
118        let (nrows, ncols) = self.col_piv_qr.shape_generic();
119        let mut res = self
120            .col_piv_qr
121            .resize_generic(nrows.min(ncols), ncols, T::zero());
122        res.fill_lower_triangle(T::zero(), 1);
123        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
124        res
125    }
126
127    /// Computes the orthogonal matrix `Q` of this decomposition.
128    #[must_use]
129    pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
130    where
131        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
132    {
133        let (nrows, ncols) = self.col_piv_qr.shape_generic();
134
135        // NOTE: we could build the identity matrix and call q_mul on it.
136        // Instead we don't so that we take in account the matrix sparseness.
137        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
138        let dim = self.diag.len();
139
140        for i in (0..dim).rev() {
141            let axis = self.col_piv_qr.view_range(i.., i);
142            // TODO: sometimes, the axis might have a zero magnitude.
143            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
144
145            let mut res_rows = res.view_range_mut(i.., i..);
146            refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
147        }
148
149        res
150    }
151    /// Retrieves the column permutation of this decomposition.
152    #[inline]
153    #[must_use]
154    pub const fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
155        &self.p
156    }
157
158    /// Unpacks this decomposition into its two matrix factors.
159    pub fn unpack(
160        self,
161    ) -> (
162        OMatrix<T, R, DimMinimum<R, C>>,
163        OMatrix<T, DimMinimum<R, C>, C>,
164        PermutationSequence<DimMinimum<R, C>>,
165    )
166    where
167        DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
168        DefaultAllocator: Allocator<R, DimMinimum<R, C>>
169            + Reallocator<T, R, C, DimMinimum<R, C>, C>
170            + Allocator<DimMinimum<R, C>>,
171    {
172        (self.q(), self.r(), self.p)
173    }
174
175    #[doc(hidden)]
176    pub const fn col_piv_qr_internal(&self) -> &OMatrix<T, R, C> {
177        &self.col_piv_qr
178    }
179
180    /// Multiplies the provided matrix by the transpose of the `Q` matrix of this decomposition.
181    pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
182    where
183        S2: StorageMut<T, R2, C2>,
184    {
185        let dim = self.diag.len();
186
187        for i in 0..dim {
188            let axis = self.col_piv_qr.view_range(i.., i);
189            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
190
191            let mut rhs_rows = rhs.rows_range_mut(i..);
192            refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
193        }
194    }
195}
196
197impl<T: ComplexField, D: DimMin<D, Output = D>> ColPivQR<T, D, D>
198where
199    DefaultAllocator: Allocator<D, D> + Allocator<D> + Allocator<DimMinimum<D, D>>,
200{
201    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
202    ///
203    /// Returns `None` if `self` is not invertible.
204    #[must_use = "Did you mean to use solve_mut()?"]
205    pub fn solve<R2: Dim, C2: Dim, S2>(
206        &self,
207        b: &Matrix<T, R2, C2, S2>,
208    ) -> Option<OMatrix<T, R2, C2>>
209    where
210        S2: StorageMut<T, R2, C2>,
211        ShapeConstraint: SameNumberOfRows<R2, D>,
212        DefaultAllocator: Allocator<R2, C2>,
213    {
214        let mut res = b.clone_owned();
215
216        if self.solve_mut(&mut res) {
217            Some(res)
218        } else {
219            None
220        }
221    }
222
223    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
224    ///
225    /// If the decomposed matrix is not invertible, this returns `false` and its input `b` is
226    /// overwritten with garbage.
227    pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
228    where
229        S2: StorageMut<T, R2, C2>,
230        ShapeConstraint: SameNumberOfRows<R2, D>,
231    {
232        assert_eq!(
233            self.col_piv_qr.nrows(),
234            b.nrows(),
235            "ColPivQR solve matrix dimension mismatch."
236        );
237        assert!(
238            self.col_piv_qr.is_square(),
239            "ColPivQR solve: unable to solve a non-square system."
240        );
241
242        self.q_tr_mul(b);
243        let solved = self.solve_upper_triangular_mut(b);
244        self.p.inv_permute_rows(b);
245
246        solved
247    }
248
249    // TODO: duplicate code from the `solve` module.
250    fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
251        &self,
252        b: &mut Matrix<T, R2, C2, S2>,
253    ) -> bool
254    where
255        S2: StorageMut<T, R2, C2>,
256        ShapeConstraint: SameNumberOfRows<R2, D>,
257    {
258        let dim = self.col_piv_qr.nrows();
259
260        for k in 0..b.ncols() {
261            let mut b = b.column_mut(k);
262            for i in (0..dim).rev() {
263                let coeff;
264
265                unsafe {
266                    let diag = self.diag.vget_unchecked(i).clone().modulus();
267
268                    if diag.is_zero() {
269                        return false;
270                    }
271
272                    coeff = b.vget_unchecked(i).clone().unscale(diag);
273                    *b.vget_unchecked_mut(i) = coeff.clone();
274                }
275
276                b.rows_range_mut(..i)
277                    .axpy(-coeff, &self.col_piv_qr.view_range(..i, i), T::one());
278            }
279        }
280
281        true
282    }
283
284    /// Computes the inverse of the decomposed matrix.
285    ///
286    /// Returns `None` if the decomposed matrix is not invertible.
287    #[must_use]
288    pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
289        assert!(
290            self.col_piv_qr.is_square(),
291            "ColPivQR inverse: unable to compute the inverse of a non-square matrix."
292        );
293
294        // TODO: is there a less naive method ?
295        let (nrows, ncols) = self.col_piv_qr.shape_generic();
296        let mut res = OMatrix::identity_generic(nrows, ncols);
297
298        if self.solve_mut(&mut res) {
299            Some(res)
300        } else {
301            None
302        }
303    }
304
305    /// Indicates if the decomposed matrix is invertible.
306    #[must_use]
307    pub fn is_invertible(&self) -> bool {
308        assert!(
309            self.col_piv_qr.is_square(),
310            "ColPivQR: unable to test the invertibility of a non-square matrix."
311        );
312
313        for i in 0..self.diag.len() {
314            if self.diag[i].is_zero() {
315                return false;
316            }
317        }
318
319        true
320    }
321
322    /// Computes the determinant of the decomposed matrix.
323    #[must_use]
324    pub fn determinant(&self) -> T {
325        let dim = self.col_piv_qr.nrows();
326        assert!(
327            self.col_piv_qr.is_square(),
328            "ColPivQR determinant: unable to compute the determinant of a non-square matrix."
329        );
330
331        let mut res = T::one();
332        for i in 0..dim {
333            res *= unsafe { self.diag.vget_unchecked(i).clone() };
334        }
335
336        res * self.p.determinant()
337    }
338}