nalgebra/linalg/
full_piv_lu.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix};
6use crate::constraint::{SameNumberOfRows, ShapeConstraint};
7use crate::dimension::{Dim, DimMin, DimMinimum};
8use crate::storage::{Storage, StorageMut};
9use simba::scalar::ComplexField;
10
11use crate::linalg::PermutationSequence;
12use crate::linalg::lu;
13
/// LU decomposition with full row and column pivoting.
///
/// The decomposition is stored in packed form: a single matrix holds both triangular
/// factors, and two permutation sequences record the row and column swaps performed by
/// [`FullPivLU::new`], which computes `P, L, U, Q` such that `P * matrix * Q = LU`.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         PermutationSequence<DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct FullPivLU<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    /// The input matrix after in-place elimination. `l()` reads its part below the
    /// diagonal (substituting a unit diagonal) and `u()` reads its upper triangle.
    lu: OMatrix<T, R, C>,
    /// Row swaps recorded while pivoting (the `P` of `P * matrix * Q = LU`).
    p: PermutationSequence<DimMinimum<R, C>>,
    /// Column swaps recorded while pivoting (the `Q` of `P * matrix * Q = LU`).
    q: PermutationSequence<DimMinimum<R, C>>,
}
40
// `Copy` is implemented by hand rather than derived so that it is only available when
// both the packed matrix and the permutation sequences are themselves `Copy`.
impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for FullPivLU<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
    OMatrix<T, R, C>: Copy,
    PermutationSequence<DimMinimum<R, C>>: Copy,
{
}
48
49impl<T: ComplexField, R: DimMin<C>, C: Dim> FullPivLU<T, R, C>
50where
51    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
52{
53    /// Computes the LU decomposition with full pivoting of `matrix`.
54    ///
55    /// This effectively computes `P, L, U, Q` such that `P * matrix * Q = LU`.
56    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
57        let (nrows, ncols) = matrix.shape_generic();
58        let min_nrows_ncols = nrows.min(ncols);
59
60        let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
61        let mut q = PermutationSequence::identity_generic(min_nrows_ncols);
62
63        if min_nrows_ncols.value() == 0 {
64            return Self { lu: matrix, p, q };
65        }
66
67        for i in 0..min_nrows_ncols.value() {
68            let piv = matrix.view_range(i.., i..).icamax_full();
69            let row_piv = piv.0 + i;
70            let col_piv = piv.1 + i;
71            let diag = matrix[(row_piv, col_piv)].clone();
72
73            if diag.is_zero() {
74                // The remaining of the matrix is zero.
75                break;
76            }
77
78            matrix.swap_columns(i, col_piv);
79            q.append_permutation(i, col_piv);
80
81            if row_piv != i {
82                p.append_permutation(i, row_piv);
83                matrix.columns_range_mut(..i).swap_rows(i, row_piv);
84                lu::gauss_step_swap(&mut matrix, diag, i, row_piv);
85            } else {
86                lu::gauss_step(&mut matrix, diag, i);
87            }
88        }
89
90        Self { lu: matrix, p, q }
91    }
92
93    #[doc(hidden)]
94    pub const fn lu_internal(&self) -> &OMatrix<T, R, C> {
95        &self.lu
96    }
97
98    /// The lower triangular matrix of this decomposition.
99    #[inline]
100    #[must_use]
101    pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>>
102    where
103        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
104    {
105        let (nrows, ncols) = self.lu.shape_generic();
106        let mut m = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
107        m.fill_upper_triangle(T::zero(), 1);
108        m.fill_diagonal(T::one());
109        m
110    }
111
112    /// The upper triangular matrix of this decomposition.
113    #[inline]
114    #[must_use]
115    pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C>
116    where
117        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
118    {
119        let (nrows, ncols) = self.lu.shape_generic();
120        self.lu.rows_generic(0, nrows.min(ncols)).upper_triangle()
121    }
122
123    /// The row permutations of this decomposition.
124    #[inline]
125    #[must_use]
126    pub const fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
127        &self.p
128    }
129
130    /// The column permutations of this decomposition.
131    #[inline]
132    #[must_use]
133    pub const fn q(&self) -> &PermutationSequence<DimMinimum<R, C>> {
134        &self.q
135    }
136
137    /// The two matrices of this decomposition and the row and column permutations: `(P, L, U, Q)`.
138    #[inline]
139    pub fn unpack(
140        self,
141    ) -> (
142        PermutationSequence<DimMinimum<R, C>>,
143        OMatrix<T, R, DimMinimum<R, C>>,
144        OMatrix<T, DimMinimum<R, C>, C>,
145        PermutationSequence<DimMinimum<R, C>>,
146    )
147    where
148        DefaultAllocator: Allocator<R, DimMinimum<R, C>> + Allocator<DimMinimum<R, C>, C>,
149    {
150        // Use reallocation for either l or u.
151        let l = self.l();
152        let u = self.u();
153        let p = self.p;
154        let q = self.q;
155
156        (p, l, u, q)
157    }
158}
159
160impl<T: ComplexField, D: DimMin<D, Output = D>> FullPivLU<T, D, D>
161where
162    DefaultAllocator: Allocator<D, D> + Allocator<D>,
163{
164    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
165    ///
166    /// Returns `None` if the decomposed matrix is not invertible.
167    #[must_use = "Did you mean to use solve_mut()?"]
168    pub fn solve<R2: Dim, C2: Dim, S2>(
169        &self,
170        b: &Matrix<T, R2, C2, S2>,
171    ) -> Option<OMatrix<T, R2, C2>>
172    where
173        S2: Storage<T, R2, C2>,
174        ShapeConstraint: SameNumberOfRows<R2, D>,
175        DefaultAllocator: Allocator<R2, C2>,
176    {
177        let mut res = b.clone_owned();
178        if self.solve_mut(&mut res) {
179            Some(res)
180        } else {
181            None
182        }
183    }
184
185    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
186    ///
187    /// If the decomposed matrix is not invertible, this returns `false` and its input `b` may
188    /// be overwritten with garbage.
189    pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
190    where
191        S2: StorageMut<T, R2, C2>,
192        ShapeConstraint: SameNumberOfRows<R2, D>,
193    {
194        assert_eq!(
195            self.lu.nrows(),
196            b.nrows(),
197            "FullPivLU solve matrix dimension mismatch."
198        );
199        assert!(
200            self.lu.is_square(),
201            "FullPivLU solve: unable to solve a non-square system."
202        );
203
204        if self.is_invertible() {
205            self.p.permute_rows(b);
206            let _ = self.lu.solve_lower_triangular_with_diag_mut(b, T::one());
207            let _ = self.lu.solve_upper_triangular_mut(b);
208            self.q.inv_permute_rows(b);
209
210            true
211        } else {
212            false
213        }
214    }
215
216    /// Computes the inverse of the decomposed matrix.
217    ///
218    /// Returns `None` if the decomposed matrix is not invertible.
219    #[must_use]
220    pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
221        assert!(
222            self.lu.is_square(),
223            "FullPivLU inverse: unable to compute the inverse of a non-square matrix."
224        );
225
226        let (nrows, ncols) = self.lu.shape_generic();
227
228        let mut res = OMatrix::identity_generic(nrows, ncols);
229        if self.solve_mut(&mut res) {
230            Some(res)
231        } else {
232            None
233        }
234    }
235
236    /// Indicates if the decomposed matrix is invertible.
237    #[must_use]
238    pub fn is_invertible(&self) -> bool {
239        assert!(
240            self.lu.is_square(),
241            "FullPivLU: unable to test the invertibility of a non-square matrix."
242        );
243
244        let dim = self.lu.nrows();
245        !self.lu[(dim - 1, dim - 1)].is_zero()
246    }
247
248    /// Computes the determinant of the decomposed matrix.
249    #[must_use]
250    pub fn determinant(&self) -> T {
251        assert!(
252            self.lu.is_square(),
253            "FullPivLU determinant: unable to compute the determinant of a non-square matrix."
254        );
255
256        let dim = self.lu.nrows();
257        let mut res = self.lu[(dim - 1, dim - 1)].clone();
258        if !res.is_zero() {
259            for i in 0..dim - 1 {
260                res *= unsafe { self.lu.get_unchecked((i, i)).clone() };
261            }
262
263            res * self.p.determinant() * self.q.determinant()
264        } else {
265            T::zero()
266        }
267    }
268}