1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Const, Dim, DimMin, DimMinimum};
9use crate::storage::{Storage, StorageMut};
10use simba::scalar::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::householder;
14use std::mem::MaybeUninit;
15
16#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
18#[cfg_attr(
19 feature = "serde-serialize-no-std",
20 serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
21 Allocator<DimMinimum<R, C>>,
22 OMatrix<T, R, C>: Serialize,
23 OVector<T, DimMinimum<R, C>>: Serialize"))
24)]
25#[cfg_attr(
26 feature = "serde-serialize-no-std",
27 serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
28 Allocator<DimMinimum<R, C>>,
29 OMatrix<T, R, C>: Deserialize<'de>,
30 OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
31)]
32#[derive(Clone, Debug)]
33pub struct QR<T: ComplexField, R: DimMin<C>, C: Dim>
34where
35 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
36{
37 qr: OMatrix<T, R, C>,
38 diag: OVector<T, DimMinimum<R, C>>,
39}
40
41impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
42where
43 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
44 OMatrix<T, R, C>: Copy,
45 OVector<T, DimMinimum<R, C>>: Copy,
46{
47}
48
49impl<T: ComplexField, R: DimMin<C>, C: Dim> QR<T, R, C>
50where
51 DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
52{
53 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
55 let (nrows, ncols) = matrix.shape_generic();
56 let min_nrows_ncols = nrows.min(ncols);
57
58 if min_nrows_ncols.value() == 0 {
59 return QR {
60 qr: matrix,
61 diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
62 };
63 }
64
65 let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
66
67 for i in 0..min_nrows_ncols.value() {
68 diag[i] =
69 MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
70 }
71
72 let diag = unsafe { diag.assume_init() };
74 QR { qr: matrix, diag }
75 }
76
77 #[inline]
79 #[must_use]
80 pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
81 where
82 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
83 {
84 let (nrows, ncols) = self.qr.shape_generic();
85 let mut res = self.qr.rows_generic(0, nrows.min(ncols)).upper_triangle();
86 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
87 res
88 }
89
90 #[inline]
94 pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
95 where
96 DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
97 {
98 let (nrows, ncols) = self.qr.shape_generic();
99 let mut res = self.qr.resize_generic(nrows.min(ncols), ncols, T::zero());
100 res.fill_lower_triangle(T::zero(), 1);
101 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
102 res
103 }
104
105 #[must_use]
107 pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
108 where
109 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
110 {
111 let (nrows, ncols) = self.qr.shape_generic();
112
113 let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
116 let dim = self.diag.len();
117
118 for i in (0..dim).rev() {
119 let axis = self.qr.view_range(i.., i);
120 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
122
123 let mut res_rows = res.view_range_mut(i.., i..);
124 refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
125 }
126
127 res
128 }
129
130 pub fn unpack(
132 self,
133 ) -> (
134 OMatrix<T, R, DimMinimum<R, C>>,
135 OMatrix<T, DimMinimum<R, C>, C>,
136 )
137 where
138 DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
139 DefaultAllocator:
140 Allocator<R, DimMinimum<R, C>> + Reallocator<T, R, C, DimMinimum<R, C>, C>,
141 {
142 (self.q(), self.unpack_r())
143 }
144
145 #[doc(hidden)]
146 pub fn qr_internal(&self) -> &OMatrix<T, R, C> {
147 &self.qr
148 }
149
150 #[must_use]
151 pub(crate) fn diag_internal(&self) -> &OVector<T, DimMinimum<R, C>> {
152 &self.diag
153 }
154
155 pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
157 where
159 S2: StorageMut<T, R2, C2>,
160 {
161 let dim = self.diag.len();
162
163 for i in 0..dim {
164 let axis = self.qr.view_range(i.., i);
165 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
166
167 let mut rhs_rows = rhs.rows_range_mut(i..);
168 refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
169 }
170 }
171}
172
173impl<T: ComplexField, D: DimMin<D, Output = D>> QR<T, D, D>
174where
175 DefaultAllocator: Allocator<D, D> + Allocator<D>,
176{
177 #[must_use = "Did you mean to use solve_mut()?"]
181 pub fn solve<R2: Dim, C2: Dim, S2>(
182 &self,
183 b: &Matrix<T, R2, C2, S2>,
184 ) -> Option<OMatrix<T, R2, C2>>
185 where
186 S2: Storage<T, R2, C2>,
187 ShapeConstraint: SameNumberOfRows<R2, D>,
188 DefaultAllocator: Allocator<R2, C2>,
189 {
190 let mut res = b.clone_owned();
191
192 if self.solve_mut(&mut res) {
193 Some(res)
194 } else {
195 None
196 }
197 }
198
199 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
204 where
205 S2: StorageMut<T, R2, C2>,
206 ShapeConstraint: SameNumberOfRows<R2, D>,
207 {
208 assert_eq!(
209 self.qr.nrows(),
210 b.nrows(),
211 "QR solve matrix dimension mismatch."
212 );
213 assert!(
214 self.qr.is_square(),
215 "QR solve: unable to solve a non-square system."
216 );
217
218 self.q_tr_mul(b);
219 self.solve_upper_triangular_mut(b)
220 }
221
222 fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
224 &self,
225 b: &mut Matrix<T, R2, C2, S2>,
226 ) -> bool
227 where
228 S2: StorageMut<T, R2, C2>,
229 ShapeConstraint: SameNumberOfRows<R2, D>,
230 {
231 let dim = self.qr.nrows();
232
233 for k in 0..b.ncols() {
234 let mut b = b.column_mut(k);
235 for i in (0..dim).rev() {
236 let coeff;
237
238 unsafe {
239 let diag = self.diag.vget_unchecked(i).clone().modulus();
240
241 if diag.is_zero() {
242 return false;
243 }
244
245 coeff = b.vget_unchecked(i).clone().unscale(diag);
246 *b.vget_unchecked_mut(i) = coeff.clone();
247 }
248
249 b.rows_range_mut(..i)
250 .axpy(-coeff, &self.qr.view_range(..i, i), T::one());
251 }
252 }
253
254 true
255 }
256
257 #[must_use]
261 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
262 assert!(
263 self.qr.is_square(),
264 "QR inverse: unable to compute the inverse of a non-square matrix."
265 );
266
267 let (nrows, ncols) = self.qr.shape_generic();
269 let mut res = OMatrix::identity_generic(nrows, ncols);
270
271 if self.solve_mut(&mut res) {
272 Some(res)
273 } else {
274 None
275 }
276 }
277
278 #[must_use]
280 pub fn is_invertible(&self) -> bool {
281 assert!(
282 self.qr.is_square(),
283 "QR: unable to test the invertibility of a non-square matrix."
284 );
285
286 for i in 0..self.diag.len() {
287 if self.diag[i].is_zero() {
288 return false;
289 }
290 }
291
292 true
293 }
294
295 }