1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix};
6use crate::constraint::{SameNumberOfRows, ShapeConstraint};
7use crate::dimension::{Dim, DimMin, DimMinimum};
8use crate::storage::{Storage, StorageMut};
9use simba::scalar::ComplexField;
10
11use crate::linalg::PermutationSequence;
12use crate::linalg::lu;
13
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
/// LU decomposition of a matrix with full pivoting (both row and column permutations).
///
/// At every elimination step the entry of largest magnitude (as selected by `icamax_full`)
/// in the remaining trailing submatrix is moved to the diagonal by a row swap and a column
/// swap. Row swaps are recorded in `p`, column swaps in `q`.
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         PermutationSequence<DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct FullPivLU<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    /// Packed factors: the strictly-lower part holds `L` (whose unit diagonal is implicit)
    /// and the diagonal plus upper part holds `U`.
    lu: OMatrix<T, R, C>,
    /// Row permutations applied during pivoting.
    p: PermutationSequence<DimMinimum<R, C>>,
    /// Column permutations applied during pivoting.
    q: PermutationSequence<DimMinimum<R, C>>,
}
40
41impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for FullPivLU<T, R, C>
42where
43 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
44 OMatrix<T, R, C>: Copy,
45 PermutationSequence<DimMinimum<R, C>>: Copy,
46{
47}
48
49impl<T: ComplexField, R: DimMin<C>, C: Dim> FullPivLU<T, R, C>
50where
51 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
52{
53 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
57 let (nrows, ncols) = matrix.shape_generic();
58 let min_nrows_ncols = nrows.min(ncols);
59
60 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
61 let mut q = PermutationSequence::identity_generic(min_nrows_ncols);
62
63 if min_nrows_ncols.value() == 0 {
64 return Self { lu: matrix, p, q };
65 }
66
67 for i in 0..min_nrows_ncols.value() {
68 let piv = matrix.view_range(i.., i..).icamax_full();
69 let row_piv = piv.0 + i;
70 let col_piv = piv.1 + i;
71 let diag = matrix[(row_piv, col_piv)].clone();
72
73 if diag.is_zero() {
74 break;
76 }
77
78 matrix.swap_columns(i, col_piv);
79 q.append_permutation(i, col_piv);
80
81 if row_piv != i {
82 p.append_permutation(i, row_piv);
83 matrix.columns_range_mut(..i).swap_rows(i, row_piv);
84 lu::gauss_step_swap(&mut matrix, diag, i, row_piv);
85 } else {
86 lu::gauss_step(&mut matrix, diag, i);
87 }
88 }
89
90 Self { lu: matrix, p, q }
91 }
92
93 #[doc(hidden)]
94 pub const fn lu_internal(&self) -> &OMatrix<T, R, C> {
95 &self.lu
96 }
97
98 #[inline]
100 #[must_use]
101 pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>>
102 where
103 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
104 {
105 let (nrows, ncols) = self.lu.shape_generic();
106 let mut m = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
107 m.fill_upper_triangle(T::zero(), 1);
108 m.fill_diagonal(T::one());
109 m
110 }
111
112 #[inline]
114 #[must_use]
115 pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C>
116 where
117 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
118 {
119 let (nrows, ncols) = self.lu.shape_generic();
120 self.lu.rows_generic(0, nrows.min(ncols)).upper_triangle()
121 }
122
123 #[inline]
125 #[must_use]
126 pub const fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
127 &self.p
128 }
129
130 #[inline]
132 #[must_use]
133 pub const fn q(&self) -> &PermutationSequence<DimMinimum<R, C>> {
134 &self.q
135 }
136
137 #[inline]
139 pub fn unpack(
140 self,
141 ) -> (
142 PermutationSequence<DimMinimum<R, C>>,
143 OMatrix<T, R, DimMinimum<R, C>>,
144 OMatrix<T, DimMinimum<R, C>, C>,
145 PermutationSequence<DimMinimum<R, C>>,
146 )
147 where
148 DefaultAllocator: Allocator<R, DimMinimum<R, C>> + Allocator<DimMinimum<R, C>, C>,
149 {
150 let l = self.l();
152 let u = self.u();
153 let p = self.p;
154 let q = self.q;
155
156 (p, l, u, q)
157 }
158}
159
160impl<T: ComplexField, D: DimMin<D, Output = D>> FullPivLU<T, D, D>
161where
162 DefaultAllocator: Allocator<D, D> + Allocator<D>,
163{
164 #[must_use = "Did you mean to use solve_mut()?"]
168 pub fn solve<R2: Dim, C2: Dim, S2>(
169 &self,
170 b: &Matrix<T, R2, C2, S2>,
171 ) -> Option<OMatrix<T, R2, C2>>
172 where
173 S2: Storage<T, R2, C2>,
174 ShapeConstraint: SameNumberOfRows<R2, D>,
175 DefaultAllocator: Allocator<R2, C2>,
176 {
177 let mut res = b.clone_owned();
178 if self.solve_mut(&mut res) {
179 Some(res)
180 } else {
181 None
182 }
183 }
184
185 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
190 where
191 S2: StorageMut<T, R2, C2>,
192 ShapeConstraint: SameNumberOfRows<R2, D>,
193 {
194 assert_eq!(
195 self.lu.nrows(),
196 b.nrows(),
197 "FullPivLU solve matrix dimension mismatch."
198 );
199 assert!(
200 self.lu.is_square(),
201 "FullPivLU solve: unable to solve a non-square system."
202 );
203
204 if self.is_invertible() {
205 self.p.permute_rows(b);
206 let _ = self.lu.solve_lower_triangular_with_diag_mut(b, T::one());
207 let _ = self.lu.solve_upper_triangular_mut(b);
208 self.q.inv_permute_rows(b);
209
210 true
211 } else {
212 false
213 }
214 }
215
216 #[must_use]
220 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
221 assert!(
222 self.lu.is_square(),
223 "FullPivLU inverse: unable to compute the inverse of a non-square matrix."
224 );
225
226 let (nrows, ncols) = self.lu.shape_generic();
227
228 let mut res = OMatrix::identity_generic(nrows, ncols);
229 if self.solve_mut(&mut res) {
230 Some(res)
231 } else {
232 None
233 }
234 }
235
236 #[must_use]
238 pub fn is_invertible(&self) -> bool {
239 assert!(
240 self.lu.is_square(),
241 "FullPivLU: unable to test the invertibility of a non-square matrix."
242 );
243
244 let dim = self.lu.nrows();
245 !self.lu[(dim - 1, dim - 1)].is_zero()
246 }
247
248 #[must_use]
250 pub fn determinant(&self) -> T {
251 assert!(
252 self.lu.is_square(),
253 "FullPivLU determinant: unable to compute the determinant of a non-square matrix."
254 );
255
256 let dim = self.lu.nrows();
257 let mut res = self.lu[(dim - 1, dim - 1)].clone();
258 if !res.is_zero() {
259 for i in 0..dim - 1 {
260 res *= unsafe { self.lu.get_unchecked((i, i)).clone() };
261 }
262
263 res * self.p.determinant() * self.q.determinant()
264 } else {
265 T::zero()
266 }
267 }
268}