1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{Const, DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Dim, DimMin, DimMinimum};
9use crate::storage::StorageMut;
10use crate::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::{householder, PermutationSequence};
14use std::mem::MaybeUninit;
15
/// The QR decomposition, with column pivoting, of a general matrix.
///
/// The factors are stored compactly:
/// - the strictly-upper triangle of `col_piv_qr` holds the off-diagonal entries of `R`;
/// - column `i` of `col_piv_qr`, from row `i` downward, holds the `i`-th Householder axis;
/// - `diag` holds the value produced by each Householder step (its modulus is `R`'s diagonal,
///   its `signum` is folded into `Q`);
/// - `p` records the column swaps performed while pivoting.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Serialize,
         PermutationSequence<DimMinimum<R, C>>: Serialize,
         OVector<T, DimMinimum<R, C>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
                           Allocator<DimMinimum<R, C>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>,
         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
)]
#[derive(Clone, Debug)]
pub struct ColPivQR<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
{
    // Packed storage: Householder axes on/below the diagonal, `R` strictly above it.
    col_piv_qr: OMatrix<T, R, C>,
    // Sequence of column transpositions applied during pivoting.
    p: PermutationSequence<DimMinimum<R, C>>,
    // Per-step Householder results; `R`'s diagonal is their modulus.
    diag: OVector<T, DimMinimum<R, C>>,
}
43
// `ColPivQR` is `Copy` whenever its matrix, permutation and diagonal storage types all are.
impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
    OMatrix<T, R, C>: Copy,
    PermutationSequence<DimMinimum<R, C>>: Copy,
    OVector<T, DimMinimum<R, C>>: Copy,
{
}
52
53impl<T: ComplexField, R: DimMin<C>, C: Dim> ColPivQR<T, R, C>
54where
55 DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
56{
57 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
59 let (nrows, ncols) = matrix.shape_generic();
60 let min_nrows_ncols = nrows.min(ncols);
61 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
62
63 if min_nrows_ncols.value() == 0 {
64 return ColPivQR {
65 col_piv_qr: matrix,
66 p,
67 diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
68 };
69 }
70
71 let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
72
73 for i in 0..min_nrows_ncols.value() {
74 let piv = matrix.view_range(i.., i..).icamax_full();
75 let col_piv = piv.1 + i;
76 matrix.swap_columns(i, col_piv);
77 p.append_permutation(i, col_piv);
78
79 diag[i] =
80 MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
81 }
82
83 let diag = unsafe { diag.assume_init() };
85
86 ColPivQR {
87 col_piv_qr: matrix,
88 p,
89 diag,
90 }
91 }
92
93 #[inline]
95 #[must_use]
96 pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
97 where
98 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
99 {
100 let (nrows, ncols) = self.col_piv_qr.shape_generic();
101 let mut res = self
102 .col_piv_qr
103 .rows_generic(0, nrows.min(ncols))
104 .upper_triangle();
105 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
106 res
107 }
108
109 #[inline]
113 pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
114 where
115 DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
116 {
117 let (nrows, ncols) = self.col_piv_qr.shape_generic();
118 let mut res = self
119 .col_piv_qr
120 .resize_generic(nrows.min(ncols), ncols, T::zero());
121 res.fill_lower_triangle(T::zero(), 1);
122 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
123 res
124 }
125
126 #[must_use]
128 pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
129 where
130 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
131 {
132 let (nrows, ncols) = self.col_piv_qr.shape_generic();
133
134 let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
137 let dim = self.diag.len();
138
139 for i in (0..dim).rev() {
140 let axis = self.col_piv_qr.view_range(i.., i);
141 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
143
144 let mut res_rows = res.view_range_mut(i.., i..);
145 refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
146 }
147
148 res
149 }
150 #[inline]
152 #[must_use]
153 pub fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
154 &self.p
155 }
156
157 pub fn unpack(
159 self,
160 ) -> (
161 OMatrix<T, R, DimMinimum<R, C>>,
162 OMatrix<T, DimMinimum<R, C>, C>,
163 PermutationSequence<DimMinimum<R, C>>,
164 )
165 where
166 DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
167 DefaultAllocator: Allocator<R, DimMinimum<R, C>>
168 + Reallocator<T, R, C, DimMinimum<R, C>, C>
169 + Allocator<DimMinimum<R, C>>,
170 {
171 (self.q(), self.r(), self.p)
172 }
173
174 #[doc(hidden)]
175 pub fn col_piv_qr_internal(&self) -> &OMatrix<T, R, C> {
176 &self.col_piv_qr
177 }
178
179 pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
181 where
182 S2: StorageMut<T, R2, C2>,
183 {
184 let dim = self.diag.len();
185
186 for i in 0..dim {
187 let axis = self.col_piv_qr.view_range(i.., i);
188 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
189
190 let mut rhs_rows = rhs.rows_range_mut(i..);
191 refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
192 }
193 }
194}
195
196impl<T: ComplexField, D: DimMin<D, Output = D>> ColPivQR<T, D, D>
197where
198 DefaultAllocator: Allocator<D, D> + Allocator<D> + Allocator<DimMinimum<D, D>>,
199{
200 #[must_use = "Did you mean to use solve_mut()?"]
204 pub fn solve<R2: Dim, C2: Dim, S2>(
205 &self,
206 b: &Matrix<T, R2, C2, S2>,
207 ) -> Option<OMatrix<T, R2, C2>>
208 where
209 S2: StorageMut<T, R2, C2>,
210 ShapeConstraint: SameNumberOfRows<R2, D>,
211 DefaultAllocator: Allocator<R2, C2>,
212 {
213 let mut res = b.clone_owned();
214
215 if self.solve_mut(&mut res) {
216 Some(res)
217 } else {
218 None
219 }
220 }
221
222 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
227 where
228 S2: StorageMut<T, R2, C2>,
229 ShapeConstraint: SameNumberOfRows<R2, D>,
230 {
231 assert_eq!(
232 self.col_piv_qr.nrows(),
233 b.nrows(),
234 "ColPivQR solve matrix dimension mismatch."
235 );
236 assert!(
237 self.col_piv_qr.is_square(),
238 "ColPivQR solve: unable to solve a non-square system."
239 );
240
241 self.q_tr_mul(b);
242 let solved = self.solve_upper_triangular_mut(b);
243 self.p.inv_permute_rows(b);
244
245 solved
246 }
247
248 fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
250 &self,
251 b: &mut Matrix<T, R2, C2, S2>,
252 ) -> bool
253 where
254 S2: StorageMut<T, R2, C2>,
255 ShapeConstraint: SameNumberOfRows<R2, D>,
256 {
257 let dim = self.col_piv_qr.nrows();
258
259 for k in 0..b.ncols() {
260 let mut b = b.column_mut(k);
261 for i in (0..dim).rev() {
262 let coeff;
263
264 unsafe {
265 let diag = self.diag.vget_unchecked(i).clone().modulus();
266
267 if diag.is_zero() {
268 return false;
269 }
270
271 coeff = b.vget_unchecked(i).clone().unscale(diag);
272 *b.vget_unchecked_mut(i) = coeff.clone();
273 }
274
275 b.rows_range_mut(..i)
276 .axpy(-coeff, &self.col_piv_qr.view_range(..i, i), T::one());
277 }
278 }
279
280 true
281 }
282
283 #[must_use]
287 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
288 assert!(
289 self.col_piv_qr.is_square(),
290 "ColPivQR inverse: unable to compute the inverse of a non-square matrix."
291 );
292
293 let (nrows, ncols) = self.col_piv_qr.shape_generic();
295 let mut res = OMatrix::identity_generic(nrows, ncols);
296
297 if self.solve_mut(&mut res) {
298 Some(res)
299 } else {
300 None
301 }
302 }
303
304 #[must_use]
306 pub fn is_invertible(&self) -> bool {
307 assert!(
308 self.col_piv_qr.is_square(),
309 "ColPivQR: unable to test the invertibility of a non-square matrix."
310 );
311
312 for i in 0..self.diag.len() {
313 if self.diag[i].is_zero() {
314 return false;
315 }
316 }
317
318 true
319 }
320
321 #[must_use]
323 pub fn determinant(&self) -> T {
324 let dim = self.col_piv_qr.nrows();
325 assert!(
326 self.col_piv_qr.is_square(),
327 "ColPivQR determinant: unable to compute the determinant of a non-square matrix."
328 );
329
330 let mut res = T::one();
331 for i in 0..dim {
332 res *= unsafe { self.diag.vget_unchecked(i).clone() };
333 }
334
335 res * self.p.determinant()
336 }
337}