1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use num_traits::identities::Zero;
8use simba::scalar::{ComplexField, RealField};
9use std::cmp;
10
11use crate::allocator::Allocator;
12use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dyn, U1, U2};
13use crate::base::storage::Storage;
14use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
15
16use crate::geometry::Reflection;
17use crate::linalg::Hessenberg;
18use crate::linalg::givens::GivensRotation;
19use crate::linalg::householder;
20use crate::{Matrix, UninitVector};
21use std::mem::MaybeUninit;
22
23#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
27#[cfg_attr(
28 feature = "serde-serialize-no-std",
29 serde(bound(serialize = "DefaultAllocator: Allocator<D, D>,
30 OMatrix<T, D, D>: Serialize"))
31)]
32#[cfg_attr(
33 feature = "serde-serialize-no-std",
34 serde(bound(deserialize = "DefaultAllocator: Allocator<D, D>,
35 OMatrix<T, D, D>: Deserialize<'de>"))
36)]
37#[cfg_attr(feature = "defmt", derive(defmt::Format))]
38#[derive(Clone, Debug)]
39pub struct Schur<T: ComplexField, D: Dim>
40where
41 DefaultAllocator: Allocator<D, D>,
42{
43 q: OMatrix<T, D, D>,
44 t: OMatrix<T, D, D>,
45}
46
47impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
48where
49 DefaultAllocator: Allocator<D, D>,
50 OMatrix<T, D, D>: Copy,
51{
52}
53
54impl<T: ComplexField, D: Dim> Schur<T, D>
55where
56 D: DimSub<U1>, DefaultAllocator:
58 Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
59{
60 pub fn new(m: OMatrix<T, D, D>) -> Self {
62 Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
63 }
64
65 pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
77 let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
78
79 Self::do_decompose(m, &mut work, eps, max_niter, true)
80 .map(|(q, t)| Schur { q: q.unwrap(), t })
81 }
82
83 fn do_decompose(
84 mut m: OMatrix<T, D, D>,
85 work: &mut OVector<T, D>,
86 eps: T::RealField,
87 max_niter: usize,
88 compute_q: bool,
89 ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
90 assert!(
91 m.is_square(),
92 "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
93 );
94
95 let dim = m.shape_generic().0;
96
97 if dim.value() == 0 {
99 let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
100 let vals = OMatrix::from_element_generic(dim, dim, T::zero());
101 return Some((vecs, vals));
102 } else if dim.value() == 1 {
103 if compute_q {
104 let q = OMatrix::from_element_generic(dim, dim, T::one());
105 return Some((Some(q), m));
106 } else {
107 return Some((None, m));
108 }
109 } else if dim.value() == 2 {
110 return decompose_2x2(m, compute_q);
111 }
112
113 let amax_m = m.camax();
114 if !amax_m.is_zero() {
118 m.unscale_mut(amax_m.clone());
119 }
120
121 let hess = Hessenberg::new_with_workspace(m, work);
122 let mut q;
123 let mut t;
124
125 if compute_q {
126 let (vecs, vals) = hess.unpack();
129 q = Some(vecs);
130 t = vals;
131 } else {
132 q = None;
133 t = hess.unpack_h()
134 }
135
136 let mut niter = 0;
138 let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);
139
140 while end != start {
141 let subdim = end - start + 1;
142
143 if subdim > 2 {
144 let m = end - 1;
145 let n = end;
146
147 let h11 = t[(start, start)].clone();
148 let h12 = t[(start, start + 1)].clone();
149 let h21 = t[(start + 1, start)].clone();
150 let h22 = t[(start + 1, start + 1)].clone();
151 let h32 = t[(start + 2, start + 1)].clone();
152
153 let hnn = t[(n, n)].clone();
154 let hmm = t[(m, m)].clone();
155 let hnm = t[(n, m)].clone();
156 let hmn = t[(m, n)].clone();
157
158 let tra = hnn.clone() + hmm.clone();
159 let det = hnn * hmm - hnm * hmn;
160
161 let mut axis = Vector3::new(
162 h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
163 h21.clone() * (h11 + h22 - tra),
164 h21 * h32,
165 );
166
167 for k in start..n - 1 {
168 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
169
170 if not_zero {
171 if k > start {
172 t[(k, k - 1)] = norm;
173 t[(k + 1, k - 1)] = T::zero();
174 t[(k + 2, k - 1)] = T::zero();
175 }
176
177 let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());
178
179 {
180 let krows = cmp::min(k + 4, end + 1);
181 let mut work = work.rows_mut(0, krows);
182 refl.reflect(
183 &mut t.generic_view_mut((k, k), (Const::<3>, Dyn(dim.value() - k))),
184 );
185 refl.reflect_rows(
186 &mut t.generic_view_mut((0, k), (Dyn(krows), Const::<3>)),
187 &mut work,
188 );
189 }
190
191 if let Some(ref mut q) = q {
192 refl.reflect_rows(
193 &mut q.generic_view_mut((0, k), (dim, Const::<3>)),
194 work,
195 );
196 }
197 }
198
199 axis.x = t[(k + 1, k)].clone();
200 axis.y = t[(k + 2, k)].clone();
201
202 if k < n - 2 {
203 axis.z = t[(k + 3, k)].clone();
204 }
205 }
206
207 let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
208 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
209
210 if not_zero {
211 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
212
213 t[(m, m - 1)] = norm;
214 t[(n, m - 1)] = T::zero();
215
216 {
217 let mut work = work.rows_mut(0, end + 1);
218 refl.reflect(
219 &mut t.generic_view_mut((m, m), (Const::<2>, Dyn(dim.value() - m))),
220 );
221 refl.reflect_rows(
222 &mut t.generic_view_mut((0, m), (Dyn(end + 1), Const::<2>)),
223 &mut work,
224 );
225 }
226
227 if let Some(ref mut q) = q {
228 refl.reflect_rows(&mut q.generic_view_mut((0, m), (dim, Const::<2>)), work);
229 }
230 }
231 } else {
232 if let Some(rot) = compute_2x2_basis(&t.fixed_view::<2, 2>(start, start)) {
234 let inv_rot = rot.inverse();
235 inv_rot.rotate(
236 &mut t.generic_view_mut(
237 (start, start),
238 (Const::<2>, Dyn(dim.value() - start)),
239 ),
240 );
241 rot.rotate_rows(
242 &mut t.generic_view_mut((0, start), (Dyn(end + 1), Const::<2>)),
243 );
244 t[(end, start)] = T::zero();
245
246 if let Some(ref mut q) = q {
247 rot.rotate_rows(&mut q.generic_view_mut((0, start), (dim, Const::<2>)));
248 }
249 }
250
251 if end > 2 {
253 end -= 2;
254 } else {
255 break;
256 }
257 }
258
259 let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);
260
261 start = sub.0;
262 end = sub.1;
263
264 niter += 1;
265 if niter == max_niter {
266 return None;
267 }
268 }
269
270 t.scale_mut(amax_m);
271
272 Some((q, t))
273 }
274
275 fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
277 let dim = t.nrows();
278 let mut m = 0;
279
280 while m < dim - 1 {
281 let n = m + 1;
282
283 if t[(n, m)].is_zero() {
284 out[m] = t[(m, m)].clone();
285 m += 1;
286 } else {
287 return false;
289 }
290 }
291
292 if m == dim - 1 {
293 out[m] = t[(m, m)].clone();
294 }
295
296 true
297 }
298
299 fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
301 where
302 T: RealField,
303 DefaultAllocator: Allocator<D>,
304 {
305 let dim = t.nrows();
306 let mut m = 0;
307
308 while m < dim - 1 {
309 let n = m + 1;
310
311 if t[(n, m)].is_zero() {
312 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
313 m += 1;
314 } else {
315 let hmm = t[(m, m)].clone();
317 let hnm = t[(n, m)].clone();
318 let hmn = t[(m, n)].clone();
319 let hnn = t[(n, n)].clone();
320
321 let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
323 let discr = hnm * hmn + val.clone() * val;
324
325 let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
328
329 let half_tra = (hnn + hmm) * crate::convert(0.5);
330 out[m] = MaybeUninit::new(
331 NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
332 );
333 out[m + 1] =
334 MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
335
336 m += 2;
337 }
338 }
339
340 if m == dim - 1 {
341 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
342 }
343 }
344
345 fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
346 where
347 D: DimSub<U1>,
348 DefaultAllocator: Allocator<DimDiff<D, U1>>,
349 {
350 let mut n = end;
351
352 while n > 0 {
353 let m = n - 1;
354
355 if t[(n, m)].clone().norm1()
356 <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
357 {
358 t[(n, m)] = T::zero();
359 } else {
360 break;
361 }
362
363 n -= 1;
364 }
365
366 if n == 0 {
367 return (0, 0);
368 }
369
370 let mut new_start = n - 1;
371 while new_start > 0 {
372 let m = new_start - 1;
373
374 let off_diag = t[(new_start, m)].clone();
375 if off_diag.is_zero()
376 || off_diag.norm1()
377 <= eps.clone()
378 * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
379 {
380 t[(new_start, m)] = T::zero();
381 break;
382 }
383
384 new_start -= 1;
385 }
386
387 (new_start, n)
388 }
389
390 pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
393 (self.q, self.t)
394 }
395
396 #[must_use]
400 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
401 let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
402 if Self::do_eigenvalues(&self.t, &mut out) {
403 Some(out)
404 } else {
405 None
406 }
407 }
408
409 #[must_use]
411 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
412 where
413 T: RealField,
414 DefaultAllocator: Allocator<D>,
415 {
416 let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
417 Self::do_complex_eigenvalues(&self.t, &mut out);
418 unsafe { out.assume_init() }
420 }
421}
422
423fn decompose_2x2<T: ComplexField, D: Dim>(
424 mut m: OMatrix<T, D, D>,
425 compute_q: bool,
426) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
427where
428 DefaultAllocator: Allocator<D, D>,
429{
430 let dim = m.shape_generic().0;
431 let mut q = None;
432 match compute_2x2_basis(&m.fixed_view::<2, 2>(0, 0)) {
433 Some(rot) => {
434 let mut m = m.fixed_view_mut::<2, 2>(0, 0);
435 let inv_rot = rot.inverse();
436 inv_rot.rotate(&mut m);
437 rot.rotate_rows(&mut m);
438 m[(1, 0)] = T::zero();
439
440 if compute_q {
441 let c = T::from_real(rot.c());
444 q = Some(OMatrix::from_column_slice_generic(
445 dim,
446 dim,
447 &[c.clone(), rot.s(), -rot.s().conjugate(), c],
448 ));
449 }
450 }
451 None => {
452 if compute_q {
453 q = Some(OMatrix::identity_generic(dim, dim));
454 }
455 }
456 };
457
458 Some((q, m))
459}
460
461fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
462 m: &SquareMatrix<T, U2, S>,
463) -> Option<(T, T)> {
464 let h00 = m[(0, 0)].clone();
466 let h10 = m[(1, 0)].clone();
467 let h01 = m[(0, 1)].clone();
468 let h11 = m[(1, 1)].clone();
469
470 let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
474 let discr = h10 * h01 + val.clone() * val;
475
476 discr.try_sqrt().map(|sqrt_discr| {
477 let half_tra = (h00 + h11) * crate::convert(0.5);
478 (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
479 })
480}
481
482fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
488 m: &SquareMatrix<T, U2, S>,
489) -> Option<GivensRotation<T>> {
490 let h10 = m[(1, 0)].clone();
491
492 if h10.is_zero() {
493 return None;
494 }
495
496 let (eigval1, eigval2) = compute_2x2_eigvals(m)?;
497 let x1 = eigval1 - m[(1, 1)].clone();
498 let x2 = eigval2 - m[(1, 1)].clone();
499
500 if x1.clone().norm1() > x2.clone().norm1() {
504 Some(GivensRotation::new(x1, h10).0)
505 } else {
506 Some(GivensRotation::new(x2, h10).0)
507 }
508}
509
510impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
511where
512 D: DimSub<U1>, DefaultAllocator:
514 Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
515{
516 #[must_use]
518 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
519 assert!(
520 self.is_square(),
521 "Unable to compute eigenvalues of a non-square matrix."
522 );
523
524 let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
525
526 if self.nrows() == 2 {
528 let me = self.fixed_view::<2, 2>(0, 0);
531 return match compute_2x2_eigvals(&me) {
532 Some((a, b)) => {
533 work[0] = a;
534 work[1] = b;
535 Some(work)
536 }
537 None => None,
538 };
539 }
540
541 let schur = Schur::do_decompose(
543 self.clone_owned(),
544 &mut work,
545 T::RealField::default_epsilon(),
546 0,
547 false,
548 )
549 .unwrap();
550
551 if Schur::do_eigenvalues(&schur.1, &mut work) {
552 Some(work)
553 } else {
554 None
555 }
556 }
557
558 #[must_use]
560 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
561 where
563 T: RealField,
564 DefaultAllocator: Allocator<D>,
565 {
566 let dim = self.shape_generic().0;
567 let mut work = Matrix::zeros_generic(dim, Const::<1>);
568
569 let schur = Schur::do_decompose(
570 self.clone_owned(),
571 &mut work,
572 T::default_epsilon(),
573 0,
574 false,
575 )
576 .unwrap();
577 let mut eig = Matrix::uninit(dim, Const::<1>);
578 Schur::do_complex_eigenvalues(&schur.1, &mut eig);
579 unsafe { eig.assume_init() }
581 }
582}