nalgebra/linalg/qr.rs

1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Const, Dim, DimMin, DimMinimum};
9use crate::storage::{Storage, StorageMut};
10use simba::scalar::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::householder;
14use std::mem::MaybeUninit;
15
16/// The QR decomposition of a general matrix.
17#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
18#[cfg_attr(
19    feature = "serde-serialize-no-std",
20    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
21                           Allocator<DimMinimum<R, C>>,
22         OMatrix<T, R, C>: Serialize,
23         OVector<T, DimMinimum<R, C>>: Serialize"))
24)]
25#[cfg_attr(
26    feature = "serde-serialize-no-std",
27    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
28                           Allocator<DimMinimum<R, C>>,
29         OMatrix<T, R, C>: Deserialize<'de>,
30         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
31)]
32#[cfg_attr(feature = "defmt", derive(defmt::Format))]
33#[derive(Clone, Debug)]
34pub struct QR<T: ComplexField, R: DimMin<C>, C: Dim>
35where
36    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
37{
38    qr: OMatrix<T, R, C>,
39    diag: OVector<T, DimMinimum<R, C>>,
40}
41
42impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
43where
44    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
45    OMatrix<T, R, C>: Copy,
46    OVector<T, DimMinimum<R, C>>: Copy,
47{
48}
49
50impl<T: ComplexField, R: DimMin<C>, C: Dim> QR<T, R, C>
51where
52    DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
53{
54    /// Computes the QR decomposition using householder reflections.
55    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
56        let (nrows, ncols) = matrix.shape_generic();
57        let min_nrows_ncols = nrows.min(ncols);
58
59        if min_nrows_ncols.value() == 0 {
60            return QR {
61                qr: matrix,
62                diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
63            };
64        }
65
66        let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
67
68        for i in 0..min_nrows_ncols.value() {
69            diag[i] =
70                MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
71        }
72
73        // Safety: diag is now fully initialized.
74        let diag = unsafe { diag.assume_init() };
75        QR { qr: matrix, diag }
76    }
77
78    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
79    #[inline]
80    #[must_use]
81    pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
82    where
83        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
84    {
85        let (nrows, ncols) = self.qr.shape_generic();
86        let mut res = self.qr.rows_generic(0, nrows.min(ncols)).upper_triangle();
87        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
88        res
89    }
90
91    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
92    ///
93    /// This is usually faster than `r` but consumes `self`.
94    #[inline]
95    pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
96    where
97        DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
98    {
99        let (nrows, ncols) = self.qr.shape_generic();
100        let mut res = self.qr.resize_generic(nrows.min(ncols), ncols, T::zero());
101        res.fill_lower_triangle(T::zero(), 1);
102        res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
103        res
104    }
105
106    /// Computes the orthogonal matrix `Q` of this decomposition.
107    #[must_use]
108    pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
109    where
110        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
111    {
112        let (nrows, ncols) = self.qr.shape_generic();
113
114        // NOTE: we could build the identity matrix and call q_mul on it.
115        // Instead we don't so that we take in account the matrix sparseness.
116        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
117        let dim = self.diag.len();
118
119        for i in (0..dim).rev() {
120            let axis = self.qr.view_range(i.., i);
121            // TODO: sometimes, the axis might have a zero magnitude.
122            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
123
124            let mut res_rows = res.view_range_mut(i.., i..);
125            refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
126        }
127
128        res
129    }
130
131    /// Unpacks this decomposition into its two matrix factors.
132    pub fn unpack(
133        self,
134    ) -> (
135        OMatrix<T, R, DimMinimum<R, C>>,
136        OMatrix<T, DimMinimum<R, C>, C>,
137    )
138    where
139        DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
140        DefaultAllocator:
141            Allocator<R, DimMinimum<R, C>> + Reallocator<T, R, C, DimMinimum<R, C>, C>,
142    {
143        (self.q(), self.unpack_r())
144    }
145
146    #[doc(hidden)]
147    pub const fn qr_internal(&self) -> &OMatrix<T, R, C> {
148        &self.qr
149    }
150
151    #[must_use]
152    pub(crate) const fn diag_internal(&self) -> &OVector<T, DimMinimum<R, C>> {
153        &self.diag
154    }
155
156    /// Multiplies the provided matrix by the transpose of the `Q` matrix of this decomposition.
157    pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
158    // TODO: do we need a static constraint on the number of rows of rhs?
159    where
160        S2: StorageMut<T, R2, C2>,
161    {
162        let dim = self.diag.len();
163
164        for i in 0..dim {
165            let axis = self.qr.view_range(i.., i);
166            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
167
168            let mut rhs_rows = rhs.rows_range_mut(i..);
169            refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
170        }
171    }
172}
173
174impl<T: ComplexField, D: DimMin<D, Output = D>> QR<T, D, D>
175where
176    DefaultAllocator: Allocator<D, D> + Allocator<D>,
177{
178    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
179    ///
180    /// Returns `None` if `self` is not invertible.
181    #[must_use = "Did you mean to use solve_mut()?"]
182    pub fn solve<R2: Dim, C2: Dim, S2>(
183        &self,
184        b: &Matrix<T, R2, C2, S2>,
185    ) -> Option<OMatrix<T, R2, C2>>
186    where
187        S2: Storage<T, R2, C2>,
188        ShapeConstraint: SameNumberOfRows<R2, D>,
189        DefaultAllocator: Allocator<R2, C2>,
190    {
191        let mut res = b.clone_owned();
192
193        if self.solve_mut(&mut res) {
194            Some(res)
195        } else {
196            None
197        }
198    }
199
200    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
201    ///
202    /// If the decomposed matrix is not invertible, this returns `false` and its input `b` is
203    /// overwritten with garbage.
204    pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
205    where
206        S2: StorageMut<T, R2, C2>,
207        ShapeConstraint: SameNumberOfRows<R2, D>,
208    {
209        assert_eq!(
210            self.qr.nrows(),
211            b.nrows(),
212            "QR solve matrix dimension mismatch."
213        );
214        assert!(
215            self.qr.is_square(),
216            "QR solve: unable to solve a non-square system."
217        );
218
219        self.q_tr_mul(b);
220        self.solve_upper_triangular_mut(b)
221    }
222
223    // TODO: duplicate code from the `solve` module.
224    fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
225        &self,
226        b: &mut Matrix<T, R2, C2, S2>,
227    ) -> bool
228    where
229        S2: StorageMut<T, R2, C2>,
230        ShapeConstraint: SameNumberOfRows<R2, D>,
231    {
232        let dim = self.qr.nrows();
233
234        for k in 0..b.ncols() {
235            let mut b = b.column_mut(k);
236            for i in (0..dim).rev() {
237                let coeff;
238
239                unsafe {
240                    let diag = self.diag.vget_unchecked(i).clone().modulus();
241
242                    if diag.is_zero() {
243                        return false;
244                    }
245
246                    coeff = b.vget_unchecked(i).clone().unscale(diag);
247                    *b.vget_unchecked_mut(i) = coeff.clone();
248                }
249
250                b.rows_range_mut(..i)
251                    .axpy(-coeff, &self.qr.view_range(..i, i), T::one());
252            }
253        }
254
255        true
256    }
257
258    /// Computes the inverse of the decomposed matrix.
259    ///
260    /// Returns `None` if the decomposed matrix is not invertible.
261    #[must_use]
262    pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
263        assert!(
264            self.qr.is_square(),
265            "QR inverse: unable to compute the inverse of a non-square matrix."
266        );
267
268        // TODO: is there a less naive method ?
269        let (nrows, ncols) = self.qr.shape_generic();
270        let mut res = OMatrix::identity_generic(nrows, ncols);
271
272        if self.solve_mut(&mut res) {
273            Some(res)
274        } else {
275            None
276        }
277    }
278
279    /// Indicates if the decomposed matrix is invertible.
280    #[must_use]
281    pub fn is_invertible(&self) -> bool {
282        assert!(
283            self.qr.is_square(),
284            "QR: unable to test the invertibility of a non-square matrix."
285        );
286
287        for i in 0..self.diag.len() {
288            if self.diag[i].is_zero() {
289                return false;
290            }
291        }
292
293        true
294    }
295
296    // /// Computes the determinant of the decomposed matrix.
297    // pub fn determinant(&self) -> T {
298    //     let dim = self.qr.nrows();
299    //     assert!(self.qr.is_square(), "QR determinant: unable to compute the determinant of a non-square matrix.");
300
301    //     let mut res = T::one();
302    //     for i in 0 .. dim {
303    //         res *= unsafe { *self.diag.vget_unchecked(i) };
304    //     }
305
306    //     res self.q_determinant()
307    // }
308}