1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Const, Dim, DimMin, DimMinimum};
9use crate::storage::{Storage, StorageMut};
10use simba::scalar::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::householder;
14use std::mem::MaybeUninit;
15
16#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
18#[cfg_attr(
19 feature = "serde-serialize-no-std",
20 serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
21 Allocator<DimMinimum<R, C>>,
22 OMatrix<T, R, C>: Serialize,
23 OVector<T, DimMinimum<R, C>>: Serialize"))
24)]
25#[cfg_attr(
26 feature = "serde-serialize-no-std",
27 serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
28 Allocator<DimMinimum<R, C>>,
29 OMatrix<T, R, C>: Deserialize<'de>,
30 OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
31)]
32#[cfg_attr(feature = "defmt", derive(defmt::Format))]
33#[derive(Clone, Debug)]
34pub struct QR<T: ComplexField, R: DimMin<C>, C: Dim>
35where
36 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
37{
38 qr: OMatrix<T, R, C>,
39 diag: OVector<T, DimMinimum<R, C>>,
40}
41
42impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
43where
44 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
45 OMatrix<T, R, C>: Copy,
46 OVector<T, DimMinimum<R, C>>: Copy,
47{
48}
49
50impl<T: ComplexField, R: DimMin<C>, C: Dim> QR<T, R, C>
51where
52 DefaultAllocator: Allocator<R, C> + Allocator<R> + Allocator<DimMinimum<R, C>>,
53{
54 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
56 let (nrows, ncols) = matrix.shape_generic();
57 let min_nrows_ncols = nrows.min(ncols);
58
59 if min_nrows_ncols.value() == 0 {
60 return QR {
61 qr: matrix,
62 diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
63 };
64 }
65
66 let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
67
68 for i in 0..min_nrows_ncols.value() {
69 diag[i] =
70 MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
71 }
72
73 let diag = unsafe { diag.assume_init() };
75 QR { qr: matrix, diag }
76 }
77
78 #[inline]
80 #[must_use]
81 pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
82 where
83 DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
84 {
85 let (nrows, ncols) = self.qr.shape_generic();
86 let mut res = self.qr.rows_generic(0, nrows.min(ncols)).upper_triangle();
87 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
88 res
89 }
90
91 #[inline]
95 pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
96 where
97 DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
98 {
99 let (nrows, ncols) = self.qr.shape_generic();
100 let mut res = self.qr.resize_generic(nrows.min(ncols), ncols, T::zero());
101 res.fill_lower_triangle(T::zero(), 1);
102 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
103 res
104 }
105
106 #[must_use]
108 pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
109 where
110 DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
111 {
112 let (nrows, ncols) = self.qr.shape_generic();
113
114 let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
117 let dim = self.diag.len();
118
119 for i in (0..dim).rev() {
120 let axis = self.qr.view_range(i.., i);
121 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
123
124 let mut res_rows = res.view_range_mut(i.., i..);
125 refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
126 }
127
128 res
129 }
130
131 pub fn unpack(
133 self,
134 ) -> (
135 OMatrix<T, R, DimMinimum<R, C>>,
136 OMatrix<T, DimMinimum<R, C>, C>,
137 )
138 where
139 DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
140 DefaultAllocator:
141 Allocator<R, DimMinimum<R, C>> + Reallocator<T, R, C, DimMinimum<R, C>, C>,
142 {
143 (self.q(), self.unpack_r())
144 }
145
146 #[doc(hidden)]
147 pub const fn qr_internal(&self) -> &OMatrix<T, R, C> {
148 &self.qr
149 }
150
151 #[must_use]
152 pub(crate) const fn diag_internal(&self) -> &OVector<T, DimMinimum<R, C>> {
153 &self.diag
154 }
155
156 pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
158 where
160 S2: StorageMut<T, R2, C2>,
161 {
162 let dim = self.diag.len();
163
164 for i in 0..dim {
165 let axis = self.qr.view_range(i.., i);
166 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
167
168 let mut rhs_rows = rhs.rows_range_mut(i..);
169 refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
170 }
171 }
172}
173
174impl<T: ComplexField, D: DimMin<D, Output = D>> QR<T, D, D>
175where
176 DefaultAllocator: Allocator<D, D> + Allocator<D>,
177{
178 #[must_use = "Did you mean to use solve_mut()?"]
182 pub fn solve<R2: Dim, C2: Dim, S2>(
183 &self,
184 b: &Matrix<T, R2, C2, S2>,
185 ) -> Option<OMatrix<T, R2, C2>>
186 where
187 S2: Storage<T, R2, C2>,
188 ShapeConstraint: SameNumberOfRows<R2, D>,
189 DefaultAllocator: Allocator<R2, C2>,
190 {
191 let mut res = b.clone_owned();
192
193 if self.solve_mut(&mut res) {
194 Some(res)
195 } else {
196 None
197 }
198 }
199
200 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
205 where
206 S2: StorageMut<T, R2, C2>,
207 ShapeConstraint: SameNumberOfRows<R2, D>,
208 {
209 assert_eq!(
210 self.qr.nrows(),
211 b.nrows(),
212 "QR solve matrix dimension mismatch."
213 );
214 assert!(
215 self.qr.is_square(),
216 "QR solve: unable to solve a non-square system."
217 );
218
219 self.q_tr_mul(b);
220 self.solve_upper_triangular_mut(b)
221 }
222
223 fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
225 &self,
226 b: &mut Matrix<T, R2, C2, S2>,
227 ) -> bool
228 where
229 S2: StorageMut<T, R2, C2>,
230 ShapeConstraint: SameNumberOfRows<R2, D>,
231 {
232 let dim = self.qr.nrows();
233
234 for k in 0..b.ncols() {
235 let mut b = b.column_mut(k);
236 for i in (0..dim).rev() {
237 let coeff;
238
239 unsafe {
240 let diag = self.diag.vget_unchecked(i).clone().modulus();
241
242 if diag.is_zero() {
243 return false;
244 }
245
246 coeff = b.vget_unchecked(i).clone().unscale(diag);
247 *b.vget_unchecked_mut(i) = coeff.clone();
248 }
249
250 b.rows_range_mut(..i)
251 .axpy(-coeff, &self.qr.view_range(..i, i), T::one());
252 }
253 }
254
255 true
256 }
257
258 #[must_use]
262 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
263 assert!(
264 self.qr.is_square(),
265 "QR inverse: unable to compute the inverse of a non-square matrix."
266 );
267
268 let (nrows, ncols) = self.qr.shape_generic();
270 let mut res = OMatrix::identity_generic(nrows, ncols);
271
272 if self.solve_mut(&mut res) {
273 Some(res)
274 } else {
275 None
276 }
277 }
278
279 #[must_use]
281 pub fn is_invertible(&self) -> bool {
282 assert!(
283 self.qr.is_square(),
284 "QR: unable to test the invertibility of a non-square matrix."
285 );
286
287 for i in 0..self.diag.len() {
288 if self.diag[i].is_zero() {
289 return false;
290 }
291 }
292
293 true
294 }
295
296 }