1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix};
6use crate::constraint::{SameNumberOfRows, ShapeConstraint};
7use crate::dimension::{Dim, DimMin, DimMinimum};
8use crate::storage::{Storage, StorageMut};
9use simba::scalar::ComplexField;
10
11use crate::linalg::lu;
12use crate::linalg::PermutationSequence;
13
/// LU decomposition of a matrix with full (row and column) pivoting.
///
/// The factors are stored packed in a single matrix: the strictly-lower part holds `L`
/// (whose unit diagonal is implicit), and the diagonal plus upper part holds `U`. The row
/// and column exchanges performed while pivoting are recorded in two separate permutation
/// sequences.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
 Allocator<DimMinimum<R, C>>,
 OMatrix<T, R, C>: Serialize,
 PermutationSequence<DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
 Allocator<DimMinimum<R, C>>,
 OMatrix<T, R, C>: Deserialize<'de>,
 PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[derive(Clone, Debug)]
pub struct FullPivLU<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    // Packed factors: `L` strictly below the diagonal (unit diagonal implied), `U` on and
    // above it.
    lu: OMatrix<T, R, C>,
    // Row exchanges applied during pivoting (at most `min(R, C)` of them).
    p: PermutationSequence<DimMinimum<R, C>>,
    // Column exchanges applied during pivoting (at most `min(R, C)` of them).
    q: PermutationSequence<DimMinimum<R, C>>,
}
39
40impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for FullPivLU<T, R, C>
41where
42 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
43 OMatrix<T, R, C>: Copy,
44 PermutationSequence<DimMinimum<R, C>>: Copy,
45{
46}
47
48impl<T: ComplexField, R: DimMin<C>, C: Dim> FullPivLU<T, R, C>
49where
50 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
51{
52 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
56 let (nrows, ncols) = matrix.shape_generic();
57 let min_nrows_ncols = nrows.min(ncols);
58
59 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
60 let mut q = PermutationSequence::identity_generic(min_nrows_ncols);
61
62 if min_nrows_ncols.value() == 0 {
63 return Self { lu: matrix, p, q };
64 }
65
66 for i in 0..min_nrows_ncols.value() {
67 let piv = matrix.view_range(i.., i..).icamax_full();
68 let row_piv = piv.0 + i;
69 let col_piv = piv.1 + i;
70 let diag = matrix[(row_piv, col_piv)].clone();
71
72 if diag.is_zero() {
73 break;
75 }
76
77 matrix.swap_columns(i, col_piv);
78 q.append_permutation(i, col_piv);
79
80 if row_piv != i {
81 p.append_permutation(i, row_piv);
82 matrix.columns_range_mut(..i).swap_rows(i, row_piv);
83 lu::gauss_step_swap(&mut matrix, diag, i, row_piv);
84 } else {
85 lu::gauss_step(&mut matrix, diag, i);
86 }
87 }
88
89 Self { lu: matrix, p, q }
90 }
91
92 #[doc(hidden)]
93 pub fn lu_internal(&self) -> &OMatrix<T, R, C> {
94 &self.lu
95 }
96
97 #[inline]
99 #[must_use]
100 pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>>
101 where
102 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
103 {
104 let (nrows, ncols) = self.lu.shape_generic();
105 let mut m = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
106 m.fill_upper_triangle(T::zero(), 1);
107 m.fill_diagonal(T::one());
108 m
109 }
110
111 #[inline]
113 #[must_use]
114 pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C>
115 where
116 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
117 {
118 let (nrows, ncols) = self.lu.shape_generic();
119 self.lu.rows_generic(0, nrows.min(ncols)).upper_triangle()
120 }
121
122 #[inline]
124 #[must_use]
125 pub fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
126 &self.p
127 }
128
129 #[inline]
131 #[must_use]
132 pub fn q(&self) -> &PermutationSequence<DimMinimum<R, C>> {
133 &self.q
134 }
135
136 #[inline]
138 pub fn unpack(
139 self,
140 ) -> (
141 PermutationSequence<DimMinimum<R, C>>,
142 OMatrix<T, R, DimMinimum<R, C>>,
143 OMatrix<T, DimMinimum<R, C>, C>,
144 PermutationSequence<DimMinimum<R, C>>,
145 )
146 where
147 DefaultAllocator: Allocator<R, DimMinimum<R, C>> + Allocator<DimMinimum<R, C>, C>,
148 {
149 let l = self.l();
151 let u = self.u();
152 let p = self.p;
153 let q = self.q;
154
155 (p, l, u, q)
156 }
157}
158
159impl<T: ComplexField, D: DimMin<D, Output = D>> FullPivLU<T, D, D>
160where
161 DefaultAllocator: Allocator<D, D> + Allocator<D>,
162{
163 #[must_use = "Did you mean to use solve_mut()?"]
167 pub fn solve<R2: Dim, C2: Dim, S2>(
168 &self,
169 b: &Matrix<T, R2, C2, S2>,
170 ) -> Option<OMatrix<T, R2, C2>>
171 where
172 S2: Storage<T, R2, C2>,
173 ShapeConstraint: SameNumberOfRows<R2, D>,
174 DefaultAllocator: Allocator<R2, C2>,
175 {
176 let mut res = b.clone_owned();
177 if self.solve_mut(&mut res) {
178 Some(res)
179 } else {
180 None
181 }
182 }
183
184 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
189 where
190 S2: StorageMut<T, R2, C2>,
191 ShapeConstraint: SameNumberOfRows<R2, D>,
192 {
193 assert_eq!(
194 self.lu.nrows(),
195 b.nrows(),
196 "FullPivLU solve matrix dimension mismatch."
197 );
198 assert!(
199 self.lu.is_square(),
200 "FullPivLU solve: unable to solve a non-square system."
201 );
202
203 if self.is_invertible() {
204 self.p.permute_rows(b);
205 let _ = self.lu.solve_lower_triangular_with_diag_mut(b, T::one());
206 let _ = self.lu.solve_upper_triangular_mut(b);
207 self.q.inv_permute_rows(b);
208
209 true
210 } else {
211 false
212 }
213 }
214
215 #[must_use]
219 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
220 assert!(
221 self.lu.is_square(),
222 "FullPivLU inverse: unable to compute the inverse of a non-square matrix."
223 );
224
225 let (nrows, ncols) = self.lu.shape_generic();
226
227 let mut res = OMatrix::identity_generic(nrows, ncols);
228 if self.solve_mut(&mut res) {
229 Some(res)
230 } else {
231 None
232 }
233 }
234
235 #[must_use]
237 pub fn is_invertible(&self) -> bool {
238 assert!(
239 self.lu.is_square(),
240 "FullPivLU: unable to test the invertibility of a non-square matrix."
241 );
242
243 let dim = self.lu.nrows();
244 !self.lu[(dim - 1, dim - 1)].is_zero()
245 }
246
247 #[must_use]
249 pub fn determinant(&self) -> T {
250 assert!(
251 self.lu.is_square(),
252 "FullPivLU determinant: unable to compute the determinant of a non-square matrix."
253 );
254
255 let dim = self.lu.nrows();
256 let mut res = self.lu[(dim - 1, dim - 1)].clone();
257 if !res.is_zero() {
258 for i in 0..dim - 1 {
259 res *= unsafe { self.lu.get_unchecked((i, i)).clone() };
260 }
261
262 res * self.p.determinant() * self.q.determinant()
263 } else {
264 T::zero()
265 }
266 }
267}