1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::ComplexField;
6use crate::allocator::{Allocator, Reallocator};
7use crate::base::{Const, DefaultAllocator, Matrix, OMatrix, OVector, Unit};
8use crate::constraint::{SameNumberOfRows, ShapeConstraint};
9use crate::dimension::{Dim, DimMin, DimMinimum};
10use crate::storage::StorageMut;
11
12use crate::geometry::Reflection;
13use crate::linalg::{PermutationSequence, householder};
14use std::mem::MaybeUninit;
15
/// The QR decomposition with column pivoting of a general matrix.
///
/// The factors are kept in compact form (see the field comments); use [`ColPivQR::q`],
/// [`ColPivQR::r`] and [`ColPivQR::p`] (or [`ColPivQR::unpack`]) to obtain them explicitly.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         PermutationSequence<DimMinimum<R, C>>: Serialize,
         OVector<T, DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>,
         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct ColPivQR<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    // Column-permuted input overwritten in place: the strict upper triangle holds the
    // off-diagonal entries of `R`, and column `i` from row `i` downwards holds the axis
    // of the `i`-th Householder reflection used to rebuild `Q`.
    col_piv_qr: OMatrix<T, R, C>,
    // Sequence of column swaps performed while pivoting.
    p: PermutationSequence<DimMinimum<R, C>>,
    // Values returned by `householder::clear_column_unchecked` for each column: their
    // modulus is `R`'s diagonal and their `signum` scales each reflection of `Q`.
    diag: OVector<T, DimMinimum<R, C>>,
}
44
45impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for ColPivQR<T, R, C>
46where
47 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
48 OMatrix<T, R, C>: Copy,
49 PermutationSequence<DimMinimum<R, C>>: Copy,
50 OVector<T, DimMinimum<R, C>>: Copy,
51{
52}
53
54impl<T: ComplexField, R: DimMin<C>, C: Dim> ColPivQR<T, R, C>
55where
56 DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
57{
58 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
60 let (nrows, ncols) = matrix.shape_generic();
61 let min_nrows_ncols = nrows.min(ncols);
62 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
63
64 if min_nrows_ncols.value() == 0 {
65 return ColPivQR {
66 col_piv_qr: matrix,
67 p,
68 diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
69 };
70 }
71
72 let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
73
74 for i in 0..min_nrows_ncols.value() {
75 let piv = matrix.view_range(i.., i..).icamax_full();
76 let col_piv = piv.1 + i;
77 matrix.swap_columns(i, col_piv);
78 p.append_permutation(i, col_piv);
79
80 diag[i] =
81 MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
82 }
83
84 let diag = unsafe { diag.assume_init() };
86
87 ColPivQR {
88 col_piv_qr: matrix,
89 p,
90 diag,
91 }
92 }
93
94 #[inline]
96 #[must_use]
97 pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
98 where
99 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
100 {
101 let (nrows, ncols) = self.col_piv_qr.shape_generic();
102 let mut res = self
103 .col_piv_qr
104 .rows_generic(0, nrows.min(ncols))
105 .upper_triangle();
106 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
107 res
108 }
109
110 #[inline]
114 pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
115 where
116 DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
117 {
118 let (nrows, ncols) = self.col_piv_qr.shape_generic();
119 let mut res = self
120 .col_piv_qr
121 .resize_generic(nrows.min(ncols), ncols, T::zero());
122 res.fill_lower_triangle(T::zero(), 1);
123 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
124 res
125 }
126
127 #[must_use]
129 pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
130 where
131 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
132 {
133 let (nrows, ncols) = self.col_piv_qr.shape_generic();
134
135 let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
138 let dim = self.diag.len();
139
140 for i in (0..dim).rev() {
141 let axis = self.col_piv_qr.view_range(i.., i);
142 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
144
145 let mut res_rows = res.view_range_mut(i.., i..);
146 refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
147 }
148
149 res
150 }
151 #[inline]
153 #[must_use]
154 pub const fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
155 &self.p
156 }
157
158 pub fn unpack(
160 self,
161 ) -> (
162 OMatrix<T, R, DimMinimum<R, C>>,
163 OMatrix<T, DimMinimum<R, C>, C>,
164 PermutationSequence<DimMinimum<R, C>>,
165 )
166 where
167 DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
168 DefaultAllocator: Allocator<R, DimMinimum<R, C>>
169 + Reallocator<T, R, C, DimMinimum<R, C>, C>
170 + Allocator<DimMinimum<R, C>>,
171 {
172 (self.q(), self.r(), self.p)
173 }
174
175 #[doc(hidden)]
176 pub const fn col_piv_qr_internal(&self) -> &OMatrix<T, R, C> {
177 &self.col_piv_qr
178 }
179
180 pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
182 where
183 S2: StorageMut<T, R2, C2>,
184 {
185 let dim = self.diag.len();
186
187 for i in 0..dim {
188 let axis = self.col_piv_qr.view_range(i.., i);
189 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
190
191 let mut rhs_rows = rhs.rows_range_mut(i..);
192 refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
193 }
194 }
195}
196
197impl<T: ComplexField, D: DimMin<D, Output = D>> ColPivQR<T, D, D>
198where
199 DefaultAllocator: Allocator<D, D> + Allocator<D> + Allocator<DimMinimum<D, D>>,
200{
201 #[must_use = "Did you mean to use solve_mut()?"]
205 pub fn solve<R2: Dim, C2: Dim, S2>(
206 &self,
207 b: &Matrix<T, R2, C2, S2>,
208 ) -> Option<OMatrix<T, R2, C2>>
209 where
210 S2: StorageMut<T, R2, C2>,
211 ShapeConstraint: SameNumberOfRows<R2, D>,
212 DefaultAllocator: Allocator<R2, C2>,
213 {
214 let mut res = b.clone_owned();
215
216 if self.solve_mut(&mut res) {
217 Some(res)
218 } else {
219 None
220 }
221 }
222
223 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
228 where
229 S2: StorageMut<T, R2, C2>,
230 ShapeConstraint: SameNumberOfRows<R2, D>,
231 {
232 assert_eq!(
233 self.col_piv_qr.nrows(),
234 b.nrows(),
235 "ColPivQR solve matrix dimension mismatch."
236 );
237 assert!(
238 self.col_piv_qr.is_square(),
239 "ColPivQR solve: unable to solve a non-square system."
240 );
241
242 self.q_tr_mul(b);
243 let solved = self.solve_upper_triangular_mut(b);
244 self.p.inv_permute_rows(b);
245
246 solved
247 }
248
249 fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
251 &self,
252 b: &mut Matrix<T, R2, C2, S2>,
253 ) -> bool
254 where
255 S2: StorageMut<T, R2, C2>,
256 ShapeConstraint: SameNumberOfRows<R2, D>,
257 {
258 let dim = self.col_piv_qr.nrows();
259
260 for k in 0..b.ncols() {
261 let mut b = b.column_mut(k);
262 for i in (0..dim).rev() {
263 let coeff;
264
265 unsafe {
266 let diag = self.diag.vget_unchecked(i).clone().modulus();
267
268 if diag.is_zero() {
269 return false;
270 }
271
272 coeff = b.vget_unchecked(i).clone().unscale(diag);
273 *b.vget_unchecked_mut(i) = coeff.clone();
274 }
275
276 b.rows_range_mut(..i)
277 .axpy(-coeff, &self.col_piv_qr.view_range(..i, i), T::one());
278 }
279 }
280
281 true
282 }
283
284 #[must_use]
288 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
289 assert!(
290 self.col_piv_qr.is_square(),
291 "ColPivQR inverse: unable to compute the inverse of a non-square matrix."
292 );
293
294 let (nrows, ncols) = self.col_piv_qr.shape_generic();
296 let mut res = OMatrix::identity_generic(nrows, ncols);
297
298 if self.solve_mut(&mut res) {
299 Some(res)
300 } else {
301 None
302 }
303 }
304
305 #[must_use]
307 pub fn is_invertible(&self) -> bool {
308 assert!(
309 self.col_piv_qr.is_square(),
310 "ColPivQR: unable to test the invertibility of a non-square matrix."
311 );
312
313 for i in 0..self.diag.len() {
314 if self.diag[i].is_zero() {
315 return false;
316 }
317 }
318
319 true
320 }
321
322 #[must_use]
324 pub fn determinant(&self) -> T {
325 let dim = self.col_piv_qr.nrows();
326 assert!(
327 self.col_piv_qr.is_square(),
328 "ColPivQR determinant: unable to compute the determinant of a non-square matrix."
329 );
330
331 let mut res = T::one();
332 for i in 0..dim {
333 res *= unsafe { self.diag.vget_unchecked(i).clone() };
334 }
335
336 res * self.p.determinant()
337 }
338}