nalgebra/linalg/lu.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::{Allocator, Reallocator};
5use crate::base::{DefaultAllocator, Matrix, OMatrix, Scalar};
6use crate::constraint::{SameNumberOfRows, ShapeConstraint};
7use crate::dimension::{Dim, DimMin, DimMinimum};
8use crate::storage::{Storage, StorageMut};
9use simba::scalar::{ComplexField, Field};
10use std::mem;
11
12use crate::linalg::PermutationSequence;
13
14/// LU decomposition with partial (row) pivoting.
15#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
16#[cfg_attr(
17    feature = "serde-serialize-no-std",
18    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
19                           Allocator<DimMinimum<R, C>>,
20         OMatrix<T, R, C>: Serialize,
21         PermutationSequence<DimMinimum<R, C>>: Serialize"))
22)]
23#[cfg_attr(
24    feature = "serde-serialize-no-std",
25    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
26                           Allocator<DimMinimum<R, C>>,
27         OMatrix<T, R, C>: Deserialize<'de>,
28         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>"))
29)]
30#[cfg_attr(feature = "defmt", derive(defmt::Format))]
31#[derive(Clone, Debug)]
32pub struct LU<T: ComplexField, R: DimMin<C>, C: Dim>
33where
34    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
35{
36    lu: OMatrix<T, R, C>,
37    p: PermutationSequence<DimMinimum<R, C>>,
38}
39
40impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for LU<T, R, C>
41where
42    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
43    OMatrix<T, R, C>: Copy,
44    PermutationSequence<DimMinimum<R, C>>: Copy,
45{
46}
47
48/// Performs a LU decomposition to overwrite `out` with the inverse of `matrix`.
49///
50/// If `matrix` is not invertible, `false` is returned and `out` may contain invalid data.
51pub fn try_invert_to<T: ComplexField, D: Dim, S>(
52    mut matrix: OMatrix<T, D, D>,
53    out: &mut Matrix<T, D, D, S>,
54) -> bool
55where
56    S: StorageMut<T, D, D>,
57    DefaultAllocator: Allocator<D, D>,
58{
59    assert!(
60        matrix.is_square(),
61        "LU inversion: unable to invert a rectangular matrix."
62    );
63    let dim = matrix.nrows();
64
65    out.fill_with_identity();
66
67    for i in 0..dim {
68        let piv = matrix.view_range(i.., i).icamax() + i;
69        let diag = matrix[(piv, i)].clone();
70
71        if diag.is_zero() {
72            return false;
73        }
74
75        if piv != i {
76            out.swap_rows(i, piv);
77            matrix.columns_range_mut(..i).swap_rows(i, piv);
78            gauss_step_swap(&mut matrix, diag, i, piv);
79        } else {
80            gauss_step(&mut matrix, diag, i);
81        }
82    }
83
84    let _ = matrix.solve_lower_triangular_with_diag_mut(out, T::one());
85    matrix.solve_upper_triangular_mut(out)
86}
87
88impl<T: ComplexField, R: DimMin<C>, C: Dim> LU<T, R, C>
89where
90    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
91{
92    /// Computes the LU decomposition with partial (row) pivoting of `matrix`.
93    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
94        let (nrows, ncols) = matrix.shape_generic();
95        let min_nrows_ncols = nrows.min(ncols);
96
97        let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
98
99        if min_nrows_ncols.value() == 0 {
100            return LU { lu: matrix, p };
101        }
102
103        for i in 0..min_nrows_ncols.value() {
104            let piv = matrix.view_range(i.., i).icamax() + i;
105            let diag = matrix[(piv, i)].clone();
106
107            if diag.is_zero() {
108                // No non-zero entries on this column.
109                continue;
110            }
111
112            if piv != i {
113                p.append_permutation(i, piv);
114                matrix.columns_range_mut(..i).swap_rows(i, piv);
115                gauss_step_swap(&mut matrix, diag, i, piv);
116            } else {
117                gauss_step(&mut matrix, diag, i);
118            }
119        }
120
121        LU { lu: matrix, p }
122    }
123
124    #[doc(hidden)]
125    pub const fn lu_internal(&self) -> &OMatrix<T, R, C> {
126        &self.lu
127    }
128
129    /// The lower triangular matrix of this decomposition.
130    #[inline]
131    #[must_use]
132    pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>>
133    where
134        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
135    {
136        let (nrows, ncols) = self.lu.shape_generic();
137        let mut m = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
138        m.fill_upper_triangle(T::zero(), 1);
139        m.fill_diagonal(T::one());
140        m
141    }
142
143    /// The lower triangular matrix of this decomposition.
144    fn l_unpack_with_p(
145        self,
146    ) -> (
147        OMatrix<T, R, DimMinimum<R, C>>,
148        PermutationSequence<DimMinimum<R, C>>,
149    )
150    where
151        DefaultAllocator: Reallocator<T, R, C, R, DimMinimum<R, C>>,
152    {
153        let (nrows, ncols) = self.lu.shape_generic();
154        let mut m = self.lu.resize_generic(nrows, nrows.min(ncols), T::zero());
155        m.fill_upper_triangle(T::zero(), 1);
156        m.fill_diagonal(T::one());
157        (m, self.p)
158    }
159
160    /// The lower triangular matrix of this decomposition.
161    #[inline]
162    pub fn l_unpack(self) -> OMatrix<T, R, DimMinimum<R, C>>
163    where
164        DefaultAllocator: Reallocator<T, R, C, R, DimMinimum<R, C>>,
165    {
166        let (nrows, ncols) = self.lu.shape_generic();
167        let mut m = self.lu.resize_generic(nrows, nrows.min(ncols), T::zero());
168        m.fill_upper_triangle(T::zero(), 1);
169        m.fill_diagonal(T::one());
170        m
171    }
172
173    /// The upper triangular matrix of this decomposition.
174    #[inline]
175    #[must_use]
176    pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C>
177    where
178        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
179    {
180        let (nrows, ncols) = self.lu.shape_generic();
181        self.lu.rows_generic(0, nrows.min(ncols)).upper_triangle()
182    }
183
184    /// The row permutations of this decomposition.
185    #[inline]
186    #[must_use]
187    pub const fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
188        &self.p
189    }
190
191    /// The row permutations and two triangular matrices of this decomposition: `(P, L, U)`.
192    #[inline]
193    pub fn unpack(
194        self,
195    ) -> (
196        PermutationSequence<DimMinimum<R, C>>,
197        OMatrix<T, R, DimMinimum<R, C>>,
198        OMatrix<T, DimMinimum<R, C>, C>,
199    )
200    where
201        DefaultAllocator: Allocator<R, DimMinimum<R, C>>
202            + Allocator<DimMinimum<R, C>, C>
203            + Reallocator<T, R, C, R, DimMinimum<R, C>>,
204    {
205        // Use reallocation for either l or u.
206        let u = self.u();
207        let (l, p) = self.l_unpack_with_p();
208
209        (p, l, u)
210    }
211}
212
213impl<T: ComplexField, D: DimMin<D, Output = D>> LU<T, D, D>
214where
215    DefaultAllocator: Allocator<D, D> + Allocator<D>,
216{
217    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
218    ///
219    /// Returns `None` if `self` is not invertible.
220    #[must_use = "Did you mean to use solve_mut()?"]
221    pub fn solve<R2: Dim, C2: Dim, S2>(
222        &self,
223        b: &Matrix<T, R2, C2, S2>,
224    ) -> Option<OMatrix<T, R2, C2>>
225    where
226        S2: Storage<T, R2, C2>,
227        ShapeConstraint: SameNumberOfRows<R2, D>,
228        DefaultAllocator: Allocator<R2, C2>,
229    {
230        let mut res = b.clone_owned();
231        if self.solve_mut(&mut res) {
232            Some(res)
233        } else {
234            None
235        }
236    }
237
238    /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
239    ///
240    /// If the decomposed matrix is not invertible, this returns `false` and its input `b` may
241    /// be overwritten with garbage.
242    pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
243    where
244        S2: StorageMut<T, R2, C2>,
245        ShapeConstraint: SameNumberOfRows<R2, D>,
246    {
247        assert_eq!(
248            self.lu.nrows(),
249            b.nrows(),
250            "LU solve matrix dimension mismatch."
251        );
252        assert!(
253            self.lu.is_square(),
254            "LU solve: unable to solve a non-square system."
255        );
256
257        self.p.permute_rows(b);
258        let _ = self.lu.solve_lower_triangular_with_diag_mut(b, T::one());
259        self.lu.solve_upper_triangular_mut(b)
260    }
261
262    /// Computes the inverse of the decomposed matrix.
263    ///
264    /// Returns `None` if the matrix is not invertible.
265    #[must_use]
266    pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
267        assert!(
268            self.lu.is_square(),
269            "LU inverse: unable to compute the inverse of a non-square matrix."
270        );
271
272        let (nrows, ncols) = self.lu.shape_generic();
273        let mut res = OMatrix::identity_generic(nrows, ncols);
274        if self.try_inverse_to(&mut res) {
275            Some(res)
276        } else {
277            None
278        }
279    }
280
281    /// Computes the inverse of the decomposed matrix and outputs the result to `out`.
282    ///
283    /// If the decomposed matrix is not invertible, this returns `false` and `out` may be
284    /// overwritten with garbage.
285    pub fn try_inverse_to<S2: StorageMut<T, D, D>>(&self, out: &mut Matrix<T, D, D, S2>) -> bool {
286        assert!(
287            self.lu.is_square(),
288            "LU inverse: unable to compute the inverse of a non-square matrix."
289        );
290        assert!(
291            self.lu.shape() == out.shape(),
292            "LU inverse: mismatched output shape."
293        );
294
295        out.fill_with_identity();
296        self.solve_mut(out)
297    }
298
299    /// Computes the determinant of the decomposed matrix.
300    #[must_use]
301    pub fn determinant(&self) -> T {
302        let dim = self.lu.nrows();
303        assert!(
304            self.lu.is_square(),
305            "LU determinant: unable to compute the determinant of a non-square matrix."
306        );
307
308        let mut res = T::one();
309        for i in 0..dim {
310            res *= unsafe { self.lu.get_unchecked((i, i)).clone() };
311        }
312
313        res * self.p.determinant()
314    }
315
316    /// Indicates if the decomposed matrix is invertible.
317    #[must_use]
318    pub fn is_invertible(&self) -> bool {
319        assert!(
320            self.lu.is_square(),
321            "LU: unable to test the invertibility of a non-square matrix."
322        );
323
324        for i in 0..self.lu.nrows() {
325            if self.lu[(i, i)].is_zero() {
326                return false;
327            }
328        }
329
330        true
331    }
332}
333
334#[doc(hidden)]
335/// Executes one step of gaussian elimination on the i-th row and column of `matrix`. The diagonal
336/// element `matrix[(i, i)]` is provided as argument.
337pub fn gauss_step<T, R: Dim, C: Dim, S>(matrix: &mut Matrix<T, R, C, S>, diag: T, i: usize)
338where
339    T: Scalar + Field,
340    S: StorageMut<T, R, C>,
341{
342    let mut submat = matrix.view_range_mut(i.., i..);
343
344    let inv_diag = T::one() / diag;
345
346    let (mut coeffs, mut submat) = submat.columns_range_pair_mut(0, 1..);
347
348    let mut coeffs = coeffs.rows_range_mut(1..);
349    coeffs *= inv_diag;
350
351    let (pivot_row, mut down) = submat.rows_range_pair_mut(0, 1..);
352
353    for k in 0..pivot_row.ncols() {
354        down.column_mut(k)
355            .axpy(-pivot_row[k].clone(), &coeffs, T::one());
356    }
357}
358
359#[doc(hidden)]
360/// Swaps the rows `i` with the row `piv` and executes one step of gaussian elimination on the i-th
361/// row and column of `matrix`. The diagonal element `matrix[(i, i)]` is provided as argument.
362pub fn gauss_step_swap<T, R: Dim, C: Dim, S>(
363    matrix: &mut Matrix<T, R, C, S>,
364    diag: T,
365    i: usize,
366    piv: usize,
367) where
368    T: Scalar + Field,
369    S: StorageMut<T, R, C>,
370{
371    let piv = piv - i;
372    let mut submat = matrix.view_range_mut(i.., i..);
373
374    let inv_diag = T::one() / diag;
375
376    let (mut coeffs, mut submat) = submat.columns_range_pair_mut(0, 1..);
377
378    coeffs.swap((0, 0), (piv, 0));
379    let mut coeffs = coeffs.rows_range_mut(1..);
380    coeffs *= inv_diag;
381
382    let (mut pivot_row, mut down) = submat.rows_range_pair_mut(0, 1..);
383
384    for k in 0..pivot_row.ncols() {
385        mem::swap(&mut pivot_row[k], &mut down[(piv - 1, k)]);
386        down.column_mut(k)
387            .axpy(-pivot_row[k].clone(), &coeffs, T::one());
388    }
389}