1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use simba::scalar::{ComplexField, RealField};
8use std::cmp;
9
10use crate::allocator::Allocator;
11use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dyn, U1, U2};
12use crate::base::storage::Storage;
13use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
14
15use crate::geometry::Reflection;
16use crate::linalg::givens::GivensRotation;
17use crate::linalg::householder;
18use crate::linalg::Hessenberg;
19use crate::{Matrix, UninitVector};
20use std::mem::MaybeUninit;
21
22#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
26#[cfg_attr(
27 feature = "serde-serialize-no-std",
28 serde(bound(serialize = "DefaultAllocator: Allocator<D, D>,
29 OMatrix<T, D, D>: Serialize"))
30)]
31#[cfg_attr(
32 feature = "serde-serialize-no-std",
33 serde(bound(deserialize = "DefaultAllocator: Allocator<D, D>,
34 OMatrix<T, D, D>: Deserialize<'de>"))
35)]
36#[derive(Clone, Debug)]
37pub struct Schur<T: ComplexField, D: Dim>
38where
39 DefaultAllocator: Allocator<D, D>,
40{
41 q: OMatrix<T, D, D>,
42 t: OMatrix<T, D, D>,
43}
44
45impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
46where
47 DefaultAllocator: Allocator<D, D>,
48 OMatrix<T, D, D>: Copy,
49{
50}
51
52impl<T: ComplexField, D: Dim> Schur<T, D>
53where
54 D: DimSub<U1>, DefaultAllocator:
56 Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
57{
58 pub fn new(m: OMatrix<T, D, D>) -> Self {
60 Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
61 }
62
63 pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
75 let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
76
77 Self::do_decompose(m, &mut work, eps, max_niter, true)
78 .map(|(q, t)| Schur { q: q.unwrap(), t })
79 }
80
81 fn do_decompose(
82 mut m: OMatrix<T, D, D>,
83 work: &mut OVector<T, D>,
84 eps: T::RealField,
85 max_niter: usize,
86 compute_q: bool,
87 ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
88 assert!(
89 m.is_square(),
90 "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
91 );
92
93 let dim = m.shape_generic().0;
94
95 if dim.value() == 0 {
97 let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
98 let vals = OMatrix::from_element_generic(dim, dim, T::zero());
99 return Some((vecs, vals));
100 } else if dim.value() == 1 {
101 if compute_q {
102 let q = OMatrix::from_element_generic(dim, dim, T::one());
103 return Some((Some(q), m));
104 } else {
105 return Some((None, m));
106 }
107 } else if dim.value() == 2 {
108 return decompose_2x2(m, compute_q);
109 }
110
111 let amax_m = m.camax();
112 m.unscale_mut(amax_m.clone());
113
114 let hess = Hessenberg::new_with_workspace(m, work);
115 let mut q;
116 let mut t;
117
118 if compute_q {
119 let (vecs, vals) = hess.unpack();
122 q = Some(vecs);
123 t = vals;
124 } else {
125 q = None;
126 t = hess.unpack_h()
127 }
128
129 let mut niter = 0;
131 let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);
132
133 while end != start {
134 let subdim = end - start + 1;
135
136 if subdim > 2 {
137 let m = end - 1;
138 let n = end;
139
140 let h11 = t[(start, start)].clone();
141 let h12 = t[(start, start + 1)].clone();
142 let h21 = t[(start + 1, start)].clone();
143 let h22 = t[(start + 1, start + 1)].clone();
144 let h32 = t[(start + 2, start + 1)].clone();
145
146 let hnn = t[(n, n)].clone();
147 let hmm = t[(m, m)].clone();
148 let hnm = t[(n, m)].clone();
149 let hmn = t[(m, n)].clone();
150
151 let tra = hnn.clone() + hmm.clone();
152 let det = hnn * hmm - hnm * hmn;
153
154 let mut axis = Vector3::new(
155 h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
156 h21.clone() * (h11 + h22 - tra),
157 h21 * h32,
158 );
159
160 for k in start..n - 1 {
161 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
162
163 if not_zero {
164 if k > start {
165 t[(k, k - 1)] = norm;
166 t[(k + 1, k - 1)] = T::zero();
167 t[(k + 2, k - 1)] = T::zero();
168 }
169
170 let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());
171
172 {
173 let krows = cmp::min(k + 4, end + 1);
174 let mut work = work.rows_mut(0, krows);
175 refl.reflect(
176 &mut t.generic_view_mut((k, k), (Const::<3>, Dyn(dim.value() - k))),
177 );
178 refl.reflect_rows(
179 &mut t.generic_view_mut((0, k), (Dyn(krows), Const::<3>)),
180 &mut work,
181 );
182 }
183
184 if let Some(ref mut q) = q {
185 refl.reflect_rows(
186 &mut q.generic_view_mut((0, k), (dim, Const::<3>)),
187 work,
188 );
189 }
190 }
191
192 axis.x = t[(k + 1, k)].clone();
193 axis.y = t[(k + 2, k)].clone();
194
195 if k < n - 2 {
196 axis.z = t[(k + 3, k)].clone();
197 }
198 }
199
200 let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
201 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
202
203 if not_zero {
204 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
205
206 t[(m, m - 1)] = norm;
207 t[(n, m - 1)] = T::zero();
208
209 {
210 let mut work = work.rows_mut(0, end + 1);
211 refl.reflect(
212 &mut t.generic_view_mut((m, m), (Const::<2>, Dyn(dim.value() - m))),
213 );
214 refl.reflect_rows(
215 &mut t.generic_view_mut((0, m), (Dyn(end + 1), Const::<2>)),
216 &mut work,
217 );
218 }
219
220 if let Some(ref mut q) = q {
221 refl.reflect_rows(&mut q.generic_view_mut((0, m), (dim, Const::<2>)), work);
222 }
223 }
224 } else {
225 if let Some(rot) = compute_2x2_basis(&t.fixed_view::<2, 2>(start, start)) {
227 let inv_rot = rot.inverse();
228 inv_rot.rotate(
229 &mut t.generic_view_mut(
230 (start, start),
231 (Const::<2>, Dyn(dim.value() - start)),
232 ),
233 );
234 rot.rotate_rows(
235 &mut t.generic_view_mut((0, start), (Dyn(end + 1), Const::<2>)),
236 );
237 t[(end, start)] = T::zero();
238
239 if let Some(ref mut q) = q {
240 rot.rotate_rows(&mut q.generic_view_mut((0, start), (dim, Const::<2>)));
241 }
242 }
243
244 if end > 2 {
246 end -= 2;
247 } else {
248 break;
249 }
250 }
251
252 let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);
253
254 start = sub.0;
255 end = sub.1;
256
257 niter += 1;
258 if niter == max_niter {
259 return None;
260 }
261 }
262
263 t.scale_mut(amax_m);
264
265 Some((q, t))
266 }
267
268 fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
270 let dim = t.nrows();
271 let mut m = 0;
272
273 while m < dim - 1 {
274 let n = m + 1;
275
276 if t[(n, m)].is_zero() {
277 out[m] = t[(m, m)].clone();
278 m += 1;
279 } else {
280 return false;
282 }
283 }
284
285 if m == dim - 1 {
286 out[m] = t[(m, m)].clone();
287 }
288
289 true
290 }
291
292 fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
294 where
295 T: RealField,
296 DefaultAllocator: Allocator<D>,
297 {
298 let dim = t.nrows();
299 let mut m = 0;
300
301 while m < dim - 1 {
302 let n = m + 1;
303
304 if t[(n, m)].is_zero() {
305 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
306 m += 1;
307 } else {
308 let hmm = t[(m, m)].clone();
310 let hnm = t[(n, m)].clone();
311 let hmn = t[(m, n)].clone();
312 let hnn = t[(n, n)].clone();
313
314 let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
316 let discr = hnm * hmn + val.clone() * val;
317
318 let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
321
322 let half_tra = (hnn + hmm) * crate::convert(0.5);
323 out[m] = MaybeUninit::new(
324 NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
325 );
326 out[m + 1] =
327 MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
328
329 m += 2;
330 }
331 }
332
333 if m == dim - 1 {
334 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
335 }
336 }
337
338 fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
339 where
340 D: DimSub<U1>,
341 DefaultAllocator: Allocator<DimDiff<D, U1>>,
342 {
343 let mut n = end;
344
345 while n > 0 {
346 let m = n - 1;
347
348 if t[(n, m)].clone().norm1()
349 <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
350 {
351 t[(n, m)] = T::zero();
352 } else {
353 break;
354 }
355
356 n -= 1;
357 }
358
359 if n == 0 {
360 return (0, 0);
361 }
362
363 let mut new_start = n - 1;
364 while new_start > 0 {
365 let m = new_start - 1;
366
367 let off_diag = t[(new_start, m)].clone();
368 if off_diag.is_zero()
369 || off_diag.norm1()
370 <= eps.clone()
371 * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
372 {
373 t[(new_start, m)] = T::zero();
374 break;
375 }
376
377 new_start -= 1;
378 }
379
380 (new_start, n)
381 }
382
383 pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
386 (self.q, self.t)
387 }
388
389 #[must_use]
393 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
394 let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
395 if Self::do_eigenvalues(&self.t, &mut out) {
396 Some(out)
397 } else {
398 None
399 }
400 }
401
402 #[must_use]
404 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
405 where
406 T: RealField,
407 DefaultAllocator: Allocator<D>,
408 {
409 let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
410 Self::do_complex_eigenvalues(&self.t, &mut out);
411 unsafe { out.assume_init() }
413 }
414}
415
416fn decompose_2x2<T: ComplexField, D: Dim>(
417 mut m: OMatrix<T, D, D>,
418 compute_q: bool,
419) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
420where
421 DefaultAllocator: Allocator<D, D>,
422{
423 let dim = m.shape_generic().0;
424 let mut q = None;
425 match compute_2x2_basis(&m.fixed_view::<2, 2>(0, 0)) {
426 Some(rot) => {
427 let mut m = m.fixed_view_mut::<2, 2>(0, 0);
428 let inv_rot = rot.inverse();
429 inv_rot.rotate(&mut m);
430 rot.rotate_rows(&mut m);
431 m[(1, 0)] = T::zero();
432
433 if compute_q {
434 let c = T::from_real(rot.c());
437 q = Some(OMatrix::from_column_slice_generic(
438 dim,
439 dim,
440 &[c.clone(), rot.s(), -rot.s().conjugate(), c],
441 ));
442 }
443 }
444 None => {
445 if compute_q {
446 q = Some(OMatrix::identity_generic(dim, dim));
447 }
448 }
449 };
450
451 Some((q, m))
452}
453
454fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
455 m: &SquareMatrix<T, U2, S>,
456) -> Option<(T, T)> {
457 let h00 = m[(0, 0)].clone();
459 let h10 = m[(1, 0)].clone();
460 let h01 = m[(0, 1)].clone();
461 let h11 = m[(1, 1)].clone();
462
463 let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
467 let discr = h10 * h01 + val.clone() * val;
468
469 discr.try_sqrt().map(|sqrt_discr| {
470 let half_tra = (h00 + h11) * crate::convert(0.5);
471 (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
472 })
473}
474
475fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
481 m: &SquareMatrix<T, U2, S>,
482) -> Option<GivensRotation<T>> {
483 let h10 = m[(1, 0)].clone();
484
485 if h10.is_zero() {
486 return None;
487 }
488
489 if let Some((eigval1, eigval2)) = compute_2x2_eigvals(m) {
490 let x1 = eigval1 - m[(1, 1)].clone();
491 let x2 = eigval2 - m[(1, 1)].clone();
492
493 if x1.clone().norm1() > x2.clone().norm1() {
497 Some(GivensRotation::new(x1, h10).0)
498 } else {
499 Some(GivensRotation::new(x2, h10).0)
500 }
501 } else {
502 None
503 }
504}
505
506impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
507where
508 D: DimSub<U1>, DefaultAllocator:
510 Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
511{
512 #[must_use]
514 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
515 assert!(
516 self.is_square(),
517 "Unable to compute eigenvalues of a non-square matrix."
518 );
519
520 let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
521
522 if self.nrows() == 2 {
524 let me = self.fixed_view::<2, 2>(0, 0);
527 return match compute_2x2_eigvals(&me) {
528 Some((a, b)) => {
529 work[0] = a;
530 work[1] = b;
531 Some(work)
532 }
533 None => None,
534 };
535 }
536
537 let schur = Schur::do_decompose(
539 self.clone_owned(),
540 &mut work,
541 T::RealField::default_epsilon(),
542 0,
543 false,
544 )
545 .unwrap();
546
547 if Schur::do_eigenvalues(&schur.1, &mut work) {
548 Some(work)
549 } else {
550 None
551 }
552 }
553
554 #[must_use]
556 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
557 where
559 T: RealField,
560 DefaultAllocator: Allocator<D>,
561 {
562 let dim = self.shape_generic().0;
563 let mut work = Matrix::zeros_generic(dim, Const::<1>);
564
565 let schur = Schur::do_decompose(
566 self.clone_owned(),
567 &mut work,
568 T::default_epsilon(),
569 0,
570 false,
571 )
572 .unwrap();
573 let mut eig = Matrix::uninit(dim, Const::<1>);
574 Schur::do_complex_eigenvalues(&schur.1, &mut eig);
575 unsafe { eig.assume_init() }
577 }
578}