1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::{Allocator, Reallocator};
5use crate::base::{DefaultAllocator, Matrix, OMatrix, Scalar};
6use crate::constraint::{SameNumberOfRows, ShapeConstraint};
7use crate::dimension::{Dim, DimMin, DimMinimum};
8use crate::storage::{Storage, StorageMut};
9use simba::scalar::{ComplexField, Field};
10use std::mem;
11
12use crate::linalg::PermutationSequence;
13
14#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
16#[cfg_attr(
17 feature = "serde-serialize-no-std",
18 serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
19 Allocator<DimMinimum<R, C>>,
20 OMatrix<T, R, C>: Serialize,
21 PermutationSequence<DimMinimum<R, C>>: Serialize"))
22)]
23#[cfg_attr(
24 feature = "serde-serialize-no-std",
25 serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
26 Allocator<DimMinimum<R, C>>,
27 OMatrix<T, R, C>: Deserialize<'de>,
28 PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>"))
29)]
30#[derive(Clone, Debug)]
31pub struct LU<T: ComplexField, R: DimMin<C>, C: Dim>
32where
33 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
34{
35 lu: OMatrix<T, R, C>,
36 p: PermutationSequence<DimMinimum<R, C>>,
37}
38
39impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for LU<T, R, C>
40where
41 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
42 OMatrix<T, R, C>: Copy,
43 PermutationSequence<DimMinimum<R, C>>: Copy,
44{
45}
46
47pub fn try_invert_to<T: ComplexField, D: Dim, S>(
51 mut matrix: OMatrix<T, D, D>,
52 out: &mut Matrix<T, D, D, S>,
53) -> bool
54where
55 S: StorageMut<T, D, D>,
56 DefaultAllocator: Allocator<D, D>,
57{
58 assert!(
59 matrix.is_square(),
60 "LU inversion: unable to invert a rectangular matrix."
61 );
62 let dim = matrix.nrows();
63
64 out.fill_with_identity();
65
66 for i in 0..dim {
67 let piv = matrix.view_range(i.., i).icamax() + i;
68 let diag = matrix[(piv, i)].clone();
69
70 if diag.is_zero() {
71 return false;
72 }
73
74 if piv != i {
75 out.swap_rows(i, piv);
76 matrix.columns_range_mut(..i).swap_rows(i, piv);
77 gauss_step_swap(&mut matrix, diag, i, piv);
78 } else {
79 gauss_step(&mut matrix, diag, i);
80 }
81 }
82
83 let _ = matrix.solve_lower_triangular_with_diag_mut(out, T::one());
84 matrix.solve_upper_triangular_mut(out)
85}
86
87impl<T: ComplexField, R: DimMin<C>, C: Dim> LU<T, R, C>
88where
89 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
90{
91 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
93 let (nrows, ncols) = matrix.shape_generic();
94 let min_nrows_ncols = nrows.min(ncols);
95
96 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
97
98 if min_nrows_ncols.value() == 0 {
99 return LU { lu: matrix, p };
100 }
101
102 for i in 0..min_nrows_ncols.value() {
103 let piv = matrix.view_range(i.., i).icamax() + i;
104 let diag = matrix[(piv, i)].clone();
105
106 if diag.is_zero() {
107 continue;
109 }
110
111 if piv != i {
112 p.append_permutation(i, piv);
113 matrix.columns_range_mut(..i).swap_rows(i, piv);
114 gauss_step_swap(&mut matrix, diag, i, piv);
115 } else {
116 gauss_step(&mut matrix, diag, i);
117 }
118 }
119
120 LU { lu: matrix, p }
121 }
122
123 #[doc(hidden)]
124 pub fn lu_internal(&self) -> &OMatrix<T, R, C> {
125 &self.lu
126 }
127
128 #[inline]
130 #[must_use]
131 pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>>
132 where
133 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
134 {
135 let (nrows, ncols) = self.lu.shape_generic();
136 let mut m = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
137 m.fill_upper_triangle(T::zero(), 1);
138 m.fill_diagonal(T::one());
139 m
140 }
141
142 fn l_unpack_with_p(
144 self,
145 ) -> (
146 OMatrix<T, R, DimMinimum<R, C>>,
147 PermutationSequence<DimMinimum<R, C>>,
148 )
149 where
150 DefaultAllocator: Reallocator<T, R, C, R, DimMinimum<R, C>>,
151 {
152 let (nrows, ncols) = self.lu.shape_generic();
153 let mut m = self.lu.resize_generic(nrows, nrows.min(ncols), T::zero());
154 m.fill_upper_triangle(T::zero(), 1);
155 m.fill_diagonal(T::one());
156 (m, self.p)
157 }
158
159 #[inline]
161 pub fn l_unpack(self) -> OMatrix<T, R, DimMinimum<R, C>>
162 where
163 DefaultAllocator: Reallocator<T, R, C, R, DimMinimum<R, C>>,
164 {
165 let (nrows, ncols) = self.lu.shape_generic();
166 let mut m = self.lu.resize_generic(nrows, nrows.min(ncols), T::zero());
167 m.fill_upper_triangle(T::zero(), 1);
168 m.fill_diagonal(T::one());
169 m
170 }
171
172 #[inline]
174 #[must_use]
175 pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C>
176 where
177 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
178 {
179 let (nrows, ncols) = self.lu.shape_generic();
180 self.lu.rows_generic(0, nrows.min(ncols)).upper_triangle()
181 }
182
183 #[inline]
185 #[must_use]
186 pub fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
187 &self.p
188 }
189
190 #[inline]
192 pub fn unpack(
193 self,
194 ) -> (
195 PermutationSequence<DimMinimum<R, C>>,
196 OMatrix<T, R, DimMinimum<R, C>>,
197 OMatrix<T, DimMinimum<R, C>, C>,
198 )
199 where
200 DefaultAllocator: Allocator<R, DimMinimum<R, C>>
201 + Allocator<DimMinimum<R, C>, C>
202 + Reallocator<T, R, C, R, DimMinimum<R, C>>,
203 {
204 let u = self.u();
206 let (l, p) = self.l_unpack_with_p();
207
208 (p, l, u)
209 }
210}
211
212impl<T: ComplexField, D: DimMin<D, Output = D>> LU<T, D, D>
213where
214 DefaultAllocator: Allocator<D, D> + Allocator<D>,
215{
216 #[must_use = "Did you mean to use solve_mut()?"]
220 pub fn solve<R2: Dim, C2: Dim, S2>(
221 &self,
222 b: &Matrix<T, R2, C2, S2>,
223 ) -> Option<OMatrix<T, R2, C2>>
224 where
225 S2: Storage<T, R2, C2>,
226 ShapeConstraint: SameNumberOfRows<R2, D>,
227 DefaultAllocator: Allocator<R2, C2>,
228 {
229 let mut res = b.clone_owned();
230 if self.solve_mut(&mut res) {
231 Some(res)
232 } else {
233 None
234 }
235 }
236
237 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
242 where
243 S2: StorageMut<T, R2, C2>,
244 ShapeConstraint: SameNumberOfRows<R2, D>,
245 {
246 assert_eq!(
247 self.lu.nrows(),
248 b.nrows(),
249 "LU solve matrix dimension mismatch."
250 );
251 assert!(
252 self.lu.is_square(),
253 "LU solve: unable to solve a non-square system."
254 );
255
256 self.p.permute_rows(b);
257 let _ = self.lu.solve_lower_triangular_with_diag_mut(b, T::one());
258 self.lu.solve_upper_triangular_mut(b)
259 }
260
261 #[must_use]
265 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
266 assert!(
267 self.lu.is_square(),
268 "LU inverse: unable to compute the inverse of a non-square matrix."
269 );
270
271 let (nrows, ncols) = self.lu.shape_generic();
272 let mut res = OMatrix::identity_generic(nrows, ncols);
273 if self.try_inverse_to(&mut res) {
274 Some(res)
275 } else {
276 None
277 }
278 }
279
280 pub fn try_inverse_to<S2: StorageMut<T, D, D>>(&self, out: &mut Matrix<T, D, D, S2>) -> bool {
285 assert!(
286 self.lu.is_square(),
287 "LU inverse: unable to compute the inverse of a non-square matrix."
288 );
289 assert!(
290 self.lu.shape() == out.shape(),
291 "LU inverse: mismatched output shape."
292 );
293
294 out.fill_with_identity();
295 self.solve_mut(out)
296 }
297
298 #[must_use]
300 pub fn determinant(&self) -> T {
301 let dim = self.lu.nrows();
302 assert!(
303 self.lu.is_square(),
304 "LU determinant: unable to compute the determinant of a non-square matrix."
305 );
306
307 let mut res = T::one();
308 for i in 0..dim {
309 res *= unsafe { self.lu.get_unchecked((i, i)).clone() };
310 }
311
312 res * self.p.determinant()
313 }
314
315 #[must_use]
317 pub fn is_invertible(&self) -> bool {
318 assert!(
319 self.lu.is_square(),
320 "LU: unable to test the invertibility of a non-square matrix."
321 );
322
323 for i in 0..self.lu.nrows() {
324 if self.lu[(i, i)].is_zero() {
325 return false;
326 }
327 }
328
329 true
330 }
331}
332
333#[doc(hidden)]
334pub fn gauss_step<T, R: Dim, C: Dim, S>(matrix: &mut Matrix<T, R, C, S>, diag: T, i: usize)
337where
338 T: Scalar + Field,
339 S: StorageMut<T, R, C>,
340{
341 let mut submat = matrix.view_range_mut(i.., i..);
342
343 let inv_diag = T::one() / diag;
344
345 let (mut coeffs, mut submat) = submat.columns_range_pair_mut(0, 1..);
346
347 let mut coeffs = coeffs.rows_range_mut(1..);
348 coeffs *= inv_diag;
349
350 let (pivot_row, mut down) = submat.rows_range_pair_mut(0, 1..);
351
352 for k in 0..pivot_row.ncols() {
353 down.column_mut(k)
354 .axpy(-pivot_row[k].clone(), &coeffs, T::one());
355 }
356}
357
358#[doc(hidden)]
359pub fn gauss_step_swap<T, R: Dim, C: Dim, S>(
362 matrix: &mut Matrix<T, R, C, S>,
363 diag: T,
364 i: usize,
365 piv: usize,
366) where
367 T: Scalar + Field,
368 S: StorageMut<T, R, C>,
369{
370 let piv = piv - i;
371 let mut submat = matrix.view_range_mut(i.., i..);
372
373 let inv_diag = T::one() / diag;
374
375 let (mut coeffs, mut submat) = submat.columns_range_pair_mut(0, 1..);
376
377 coeffs.swap((0, 0), (piv, 0));
378 let mut coeffs = coeffs.rows_range_mut(1..);
379 coeffs *= inv_diag;
380
381 let (mut pivot_row, mut down) = submat.rows_range_pair_mut(0, 1..);
382
383 for k in 0..pivot_row.ncols() {
384 mem::swap(&mut pivot_row[k], &mut down[(piv - 1, k)]);
385 down.column_mut(k)
386 .axpy(-pivot_row[k].clone(), &coeffs, T::one());
387 }
388}