nalgebra/linalg/
qr.rs

1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Const, Dim, DimMin, DimMinimum};
9use crate::storage::{Storage, StorageMut};
10use simba::scalar::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::householder;
14use std::mem::MaybeUninit;
15
/// The QR decomposition of a general matrix.
///
/// Instances are built by `QR::new`, which overwrites the input matrix column by column
/// through `householder::clear_column_unchecked` and records the value returned for each
/// processed column.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         OVector<T, DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[derive(Clone, Debug)]
pub struct QR<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    // The decomposed matrix after the in-place column elimination. Entries strictly above
    // the diagonal are read back as the off-diagonal part of `R`, and the sub-column
    // `qr[i.., i]` is used as the axis of the `i`-th reflection when applying `Q`.
    qr: OMatrix<T, R, C>,
    // One entry per processed column (`min(nrows, ncols)` of them). The modulus of each
    // entry is used as the diagonal of `R`; its signum is handed to the reflections.
    diag: OVector<T, DimMinimum<R, C>>,
}
40
// `QR` is `Copy` exactly when both of its stored buffers (`qr` and `diag`) are `Copy`.
impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
    OMatrix<T, R, C>: Copy,
    OVector<T, DimMinimum<R, C>>: Copy,
{
}
48
49impl<T: ComplexField, R: DimMin<C>, C: Dim> QR<T, R, C>
50where
51    DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
52{
53    /// Computes the QR decomposition using householder reflections.
54    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
55        let (nrows, ncols) = matrix.shape_generic();
56        let min_nrows_ncols = nrows.min(ncols);
57
58        if min_nrows_ncols.value() == 0 {
59            return QR {
60                qr: matrix,
61                diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
62            };
63        }
64
65        let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
66
67        for i in 0..min_nrows_ncols.value() {
68            diag[i] =
69                MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
70        }
71
72        // Safety: diag is now fully initialized.
73        let diag = unsafe { diag.assume_init() };
74        QR { qr: matrix, diag }
75    }
76
77    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
78    #[inline]
79    #[must_use]
80    pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
81    where
82        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
83    {
84        let (nrows, ncols) = self.qr.shape_generic();
85        let mut res = self.qr.rows_generic(0, nrows.min(ncols)).upper_triangle();
86        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
87        res
88    }
89
90    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
91    ///
92    /// This is usually faster than `r` but consumes `self`.
93    #[inline]
94    pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
95    where
96        DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
97    {
98        let (nrows, ncols) = self.qr.shape_generic();
99        let mut res = self.qr.resize_generic(nrows.min(ncols), ncols, T::zero());
100        res.fill_lower_triangle(T::zero(), 1);
101        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
102        res
103    }
104
105    /// Computes the orthogonal matrix `Q` of this decomposition.
106    #[must_use]
107    pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
108    where
109        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
110    {
111        let (nrows, ncols) = self.qr.shape_generic();
112
113        // NOTE: we could build the identity matrix and call q_mul on it.
114        // Instead we don't so that we take in account the matrix sparseness.
115        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
116        let dim = self.diag.len();
117
118        for i in (0..dim).rev() {
119            let axis = self.qr.view_range(i.., i);
120            // TODO: sometimes, the axis might have a zero magnitude.
121            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
122
123            let mut res_rows = res.view_range_mut(i.., i..);
124            refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
125        }
126
127        res
128    }
129
130    /// Unpacks this decomposition into its two matrix factors.
131    pub fn unpack(
132        self,
133    ) -> (
134        OMatrix<T, R, DimMinimum<R, C>>,
135        OMatrix<T, DimMinimum<R, C>, C>,
136    )
137    where
138        DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
139        DefaultAllocator:
140            Allocator<R, DimMinimum<R, C>> + Reallocator<T, R, C, DimMinimum<R, C>, C>,
141    {
142        (self.q(), self.unpack_r())
143    }
144
145    #[doc(hidden)]
146    pub fn qr_internal(&self) -> &OMatrix<T, R, C> {
147        &self.qr
148    }
149
150    #[must_use]
151    pub(crate) fn diag_internal(&self) -> &OVector<T, DimMinimum<R, C>> {
152        &self.diag
153    }
154
155    /// Multiplies the provided matrix by the transpose of the `Q` matrix of this decomposition.
156    pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
157    // TODO: do we need a static constraint on the number of rows of rhs?
158    where
159        S2: StorageMut<T, R2, C2>,
160    {
161        let dim = self.diag.len();
162
163        for i in 0..dim {
164            let axis = self.qr.view_range(i.., i);
165            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
166
167            let mut rhs_rows = rhs.rows_range_mut(i..);
168            refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
169        }
170    }
171}
172
173impl<T: ComplexField, D: DimMin<D, Output = D>> QR<T, D, D>
174where
175    DefaultAllocator: Allocator<D, D> + Allocator<D>,
176{
177    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
178    ///
179    /// Returns `None` if `self` is not invertible.
180    #[must_use = "Did you mean to use solve_mut()?"]
181    pub fn solve<R2: Dim, C2: Dim, S2>(
182        &self,
183        b: &Matrix<T, R2, C2, S2>,
184    ) -> Option<OMatrix<T, R2, C2>>
185    where
186        S2: Storage<T, R2, C2>,
187        ShapeConstraint: SameNumberOfRows<R2, D>,
188        DefaultAllocator: Allocator<R2, C2>,
189    {
190        let mut res = b.clone_owned();
191
192        if self.solve_mut(&mut res) {
193            Some(res)
194        } else {
195            None
196        }
197    }
198
199    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
200    ///
201    /// If the decomposed matrix is not invertible, this returns `false` and its input `b` is
202    /// overwritten with garbage.
203    pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
204    where
205        S2: StorageMut<T, R2, C2>,
206        ShapeConstraint: SameNumberOfRows<R2, D>,
207    {
208        assert_eq!(
209            self.qr.nrows(),
210            b.nrows(),
211            "QR solve matrix dimension mismatch."
212        );
213        assert!(
214            self.qr.is_square(),
215            "QR solve: unable to solve a non-square system."
216        );
217
218        self.q_tr_mul(b);
219        self.solve_upper_triangular_mut(b)
220    }
221
222    // TODO: duplicate code from the `solve` module.
223    fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
224        &self,
225        b: &mut Matrix<T, R2, C2, S2>,
226    ) -> bool
227    where
228        S2: StorageMut<T, R2, C2>,
229        ShapeConstraint: SameNumberOfRows<R2, D>,
230    {
231        let dim = self.qr.nrows();
232
233        for k in 0..b.ncols() {
234            let mut b = b.column_mut(k);
235            for i in (0..dim).rev() {
236                let coeff;
237
238                unsafe {
239                    let diag = self.diag.vget_unchecked(i).clone().modulus();
240
241                    if diag.is_zero() {
242                        return false;
243                    }
244
245                    coeff = b.vget_unchecked(i).clone().unscale(diag);
246                    *b.vget_unchecked_mut(i) = coeff.clone();
247                }
248
249                b.rows_range_mut(..i)
250                    .axpy(-coeff, &self.qr.view_range(..i, i), T::one());
251            }
252        }
253
254        true
255    }
256
257    /// Computes the inverse of the decomposed matrix.
258    ///
259    /// Returns `None` if the decomposed matrix is not invertible.
260    #[must_use]
261    pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
262        assert!(
263            self.qr.is_square(),
264            "QR inverse: unable to compute the inverse of a non-square matrix."
265        );
266
267        // TODO: is there a less naive method ?
268        let (nrows, ncols) = self.qr.shape_generic();
269        let mut res = OMatrix::identity_generic(nrows, ncols);
270
271        if self.solve_mut(&mut res) {
272            Some(res)
273        } else {
274            None
275        }
276    }
277
278    /// Indicates if the decomposed matrix is invertible.
279    #[must_use]
280    pub fn is_invertible(&self) -> bool {
281        assert!(
282            self.qr.is_square(),
283            "QR: unable to test the invertibility of a non-square matrix."
284        );
285
286        for i in 0..self.diag.len() {
287            if self.diag[i].is_zero() {
288                return false;
289            }
290        }
291
292        true
293    }
294
295    // /// Computes the determinant of the decomposed matrix.
296    // pub fn determinant(&self) -> T {
297    //     let dim = self.qr.nrows();
298    //     assert!(self.qr.is_square(), "QR determinant: unable to compute the determinant of a non-square matrix.");
299
300    //     let mut res = T::one();
301    //     for i in 0 .. dim {
302    //         res *= unsafe { *self.diag.vget_unchecked(i) };
303    //     }
304
305    //     res self.q_determinant()
306    // }
307}