// nalgebra/linalg/schur.rs

1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use simba::scalar::{ComplexField, RealField};
8use std::cmp;
9
10use crate::allocator::Allocator;
11use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dyn, U1, U2};
12use crate::base::storage::Storage;
13use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
14
15use crate::geometry::Reflection;
16use crate::linalg::givens::GivensRotation;
17use crate::linalg::householder;
18use crate::linalg::Hessenberg;
19use crate::{Matrix, UninitVector};
20use std::mem::MaybeUninit;
21
22/// Schur decomposition of a square matrix.
23///
24/// If this is a real matrix, this will be a `RealField` Schur decomposition.
25#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
26#[cfg_attr(
27    feature = "serde-serialize-no-std",
28    serde(bound(serialize = "DefaultAllocator: Allocator<D, D>,
29         OMatrix<T, D, D>: Serialize"))
30)]
31#[cfg_attr(
32    feature = "serde-serialize-no-std",
33    serde(bound(deserialize = "DefaultAllocator: Allocator<D, D>,
34         OMatrix<T, D, D>: Deserialize<'de>"))
35)]
36#[derive(Clone, Debug)]
37pub struct Schur<T: ComplexField, D: Dim>
38where
39    DefaultAllocator: Allocator<D, D>,
40{
41    q: OMatrix<T, D, D>,
42    t: OMatrix<T, D, D>,
43}
44
45impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
46where
47    DefaultAllocator: Allocator<D, D>,
48    OMatrix<T, D, D>: Copy,
49{
50}
51
52impl<T: ComplexField, D: Dim> Schur<T, D>
53where
54    D: DimSub<U1>, // For Hessenberg.
55    DefaultAllocator:
56        Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
57{
58    /// Computes the Schur decomposition of a square matrix.
59    pub fn new(m: OMatrix<T, D, D>) -> Self {
60        Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
61    }
62
63    /// Attempts to compute the Schur decomposition of a square matrix.
64    ///
65    /// If only eigenvalues are needed, it is more efficient to call the matrix method
66    /// `.eigenvalues()` instead.
67    ///
68    /// # Arguments
69    ///
70    /// * `eps`       − tolerance used to determine when a value converged to 0.
71    /// * `max_niter` − maximum total number of iterations performed by the algorithm. If this
72    ///   number of iteration is exceeded, `None` is returned. If `niter == 0`, then the algorithm
73    ///   continues indefinitely until convergence.
74    pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
75        let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
76
77        Self::do_decompose(m, &mut work, eps, max_niter, true)
78            .map(|(q, t)| Schur { q: q.unwrap(), t })
79    }
80
81    fn do_decompose(
82        mut m: OMatrix<T, D, D>,
83        work: &mut OVector<T, D>,
84        eps: T::RealField,
85        max_niter: usize,
86        compute_q: bool,
87    ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
88        assert!(
89            m.is_square(),
90            "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
91        );
92
93        let dim = m.shape_generic().0;
94
95        // Specialization would make this easier.
96        if dim.value() == 0 {
97            let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
98            let vals = OMatrix::from_element_generic(dim, dim, T::zero());
99            return Some((vecs, vals));
100        } else if dim.value() == 1 {
101            if compute_q {
102                let q = OMatrix::from_element_generic(dim, dim, T::one());
103                return Some((Some(q), m));
104            } else {
105                return Some((None, m));
106            }
107        } else if dim.value() == 2 {
108            return decompose_2x2(m, compute_q);
109        }
110
111        let amax_m = m.camax();
112        m.unscale_mut(amax_m.clone());
113
114        let hess = Hessenberg::new_with_workspace(m, work);
115        let mut q;
116        let mut t;
117
118        if compute_q {
119            // TODO: could we work without unpacking? Using only the internal representation of
120            // hessenberg decomposition.
121            let (vecs, vals) = hess.unpack();
122            q = Some(vecs);
123            t = vals;
124        } else {
125            q = None;
126            t = hess.unpack_h()
127        }
128
129        // Implicit double-shift QR method.
130        let mut niter = 0;
131        let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);
132
133        while end != start {
134            let subdim = end - start + 1;
135
136            if subdim > 2 {
137                let m = end - 1;
138                let n = end;
139
140                let h11 = t[(start, start)].clone();
141                let h12 = t[(start, start + 1)].clone();
142                let h21 = t[(start + 1, start)].clone();
143                let h22 = t[(start + 1, start + 1)].clone();
144                let h32 = t[(start + 2, start + 1)].clone();
145
146                let hnn = t[(n, n)].clone();
147                let hmm = t[(m, m)].clone();
148                let hnm = t[(n, m)].clone();
149                let hmn = t[(m, n)].clone();
150
151                let tra = hnn.clone() + hmm.clone();
152                let det = hnn * hmm - hnm * hmn;
153
154                let mut axis = Vector3::new(
155                    h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
156                    h21.clone() * (h11 + h22 - tra),
157                    h21 * h32,
158                );
159
160                for k in start..n - 1 {
161                    let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
162
163                    if not_zero {
164                        if k > start {
165                            t[(k, k - 1)] = norm;
166                            t[(k + 1, k - 1)] = T::zero();
167                            t[(k + 2, k - 1)] = T::zero();
168                        }
169
170                        let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());
171
172                        {
173                            let krows = cmp::min(k + 4, end + 1);
174                            let mut work = work.rows_mut(0, krows);
175                            refl.reflect(
176                                &mut t.generic_view_mut((k, k), (Const::<3>, Dyn(dim.value() - k))),
177                            );
178                            refl.reflect_rows(
179                                &mut t.generic_view_mut((0, k), (Dyn(krows), Const::<3>)),
180                                &mut work,
181                            );
182                        }
183
184                        if let Some(ref mut q) = q {
185                            refl.reflect_rows(
186                                &mut q.generic_view_mut((0, k), (dim, Const::<3>)),
187                                work,
188                            );
189                        }
190                    }
191
192                    axis.x = t[(k + 1, k)].clone();
193                    axis.y = t[(k + 2, k)].clone();
194
195                    if k < n - 2 {
196                        axis.z = t[(k + 3, k)].clone();
197                    }
198                }
199
200                let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
201                let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
202
203                if not_zero {
204                    let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
205
206                    t[(m, m - 1)] = norm;
207                    t[(n, m - 1)] = T::zero();
208
209                    {
210                        let mut work = work.rows_mut(0, end + 1);
211                        refl.reflect(
212                            &mut t.generic_view_mut((m, m), (Const::<2>, Dyn(dim.value() - m))),
213                        );
214                        refl.reflect_rows(
215                            &mut t.generic_view_mut((0, m), (Dyn(end + 1), Const::<2>)),
216                            &mut work,
217                        );
218                    }
219
220                    if let Some(ref mut q) = q {
221                        refl.reflect_rows(&mut q.generic_view_mut((0, m), (dim, Const::<2>)), work);
222                    }
223                }
224            } else {
225                // Decouple the 2x2 block if it has real eigenvalues.
226                if let Some(rot) = compute_2x2_basis(&t.fixed_view::<2, 2>(start, start)) {
227                    let inv_rot = rot.inverse();
228                    inv_rot.rotate(
229                        &mut t.generic_view_mut(
230                            (start, start),
231                            (Const::<2>, Dyn(dim.value() - start)),
232                        ),
233                    );
234                    rot.rotate_rows(
235                        &mut t.generic_view_mut((0, start), (Dyn(end + 1), Const::<2>)),
236                    );
237                    t[(end, start)] = T::zero();
238
239                    if let Some(ref mut q) = q {
240                        rot.rotate_rows(&mut q.generic_view_mut((0, start), (dim, Const::<2>)));
241                    }
242                }
243
244                // Check if we reached the beginning of the matrix.
245                if end > 2 {
246                    end -= 2;
247                } else {
248                    break;
249                }
250            }
251
252            let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);
253
254            start = sub.0;
255            end = sub.1;
256
257            niter += 1;
258            if niter == max_niter {
259                return None;
260            }
261        }
262
263        t.scale_mut(amax_m);
264
265        Some((q, t))
266    }
267
268    /// Computes the eigenvalues of the decomposed matrix.
269    fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
270        let dim = t.nrows();
271        let mut m = 0;
272
273        while m < dim - 1 {
274            let n = m + 1;
275
276            if t[(n, m)].is_zero() {
277                out[m] = t[(m, m)].clone();
278                m += 1;
279            } else {
280                // Complex eigenvalue.
281                return false;
282            }
283        }
284
285        if m == dim - 1 {
286            out[m] = t[(m, m)].clone();
287        }
288
289        true
290    }
291
292    /// Computes the complex eigenvalues of the decomposed matrix.
293    fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
294    where
295        T: RealField,
296        DefaultAllocator: Allocator<D>,
297    {
298        let dim = t.nrows();
299        let mut m = 0;
300
301        while m < dim - 1 {
302            let n = m + 1;
303
304            if t[(n, m)].is_zero() {
305                out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
306                m += 1;
307            } else {
308                // Solve the 2x2 eigenvalue subproblem.
309                let hmm = t[(m, m)].clone();
310                let hnm = t[(n, m)].clone();
311                let hmn = t[(m, n)].clone();
312                let hnn = t[(n, n)].clone();
313
314                // NOTE: use the same algorithm as in compute_2x2_eigvals.
315                let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
316                let discr = hnm * hmn + val.clone() * val;
317
318                // All 2x2 blocks have negative discriminant because we already decoupled those
319                // with positive eigenvalues.
320                let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
321
322                let half_tra = (hnn + hmm) * crate::convert(0.5);
323                out[m] = MaybeUninit::new(
324                    NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
325                );
326                out[m + 1] =
327                    MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
328
329                m += 2;
330            }
331        }
332
333        if m == dim - 1 {
334            out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
335        }
336    }
337
338    fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
339    where
340        D: DimSub<U1>,
341        DefaultAllocator: Allocator<DimDiff<D, U1>>,
342    {
343        let mut n = end;
344
345        while n > 0 {
346            let m = n - 1;
347
348            if t[(n, m)].clone().norm1()
349                <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
350            {
351                t[(n, m)] = T::zero();
352            } else {
353                break;
354            }
355
356            n -= 1;
357        }
358
359        if n == 0 {
360            return (0, 0);
361        }
362
363        let mut new_start = n - 1;
364        while new_start > 0 {
365            let m = new_start - 1;
366
367            let off_diag = t[(new_start, m)].clone();
368            if off_diag.is_zero()
369                || off_diag.norm1()
370                    <= eps.clone()
371                        * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
372            {
373                t[(new_start, m)] = T::zero();
374                break;
375            }
376
377            new_start -= 1;
378        }
379
380        (new_start, n)
381    }
382
383    /// Retrieves the unitary matrix `Q` and the upper-quasitriangular matrix `T` such that the
384    /// decomposed matrix equals `Q * T * Q.transpose()`.
385    pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
386        (self.q, self.t)
387    }
388
389    /// Computes the real eigenvalues of the decomposed matrix.
390    ///
391    /// Return `None` if some eigenvalues are complex.
392    #[must_use]
393    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
394        let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
395        if Self::do_eigenvalues(&self.t, &mut out) {
396            Some(out)
397        } else {
398            None
399        }
400    }
401
402    /// Computes the complex eigenvalues of the decomposed matrix.
403    #[must_use]
404    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
405    where
406        T: RealField,
407        DefaultAllocator: Allocator<D>,
408    {
409        let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
410        Self::do_complex_eigenvalues(&self.t, &mut out);
411        // Safety: out has been fully initialized by do_complex_eigenvalues.
412        unsafe { out.assume_init() }
413    }
414}
415
416fn decompose_2x2<T: ComplexField, D: Dim>(
417    mut m: OMatrix<T, D, D>,
418    compute_q: bool,
419) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
420where
421    DefaultAllocator: Allocator<D, D>,
422{
423    let dim = m.shape_generic().0;
424    let mut q = None;
425    match compute_2x2_basis(&m.fixed_view::<2, 2>(0, 0)) {
426        Some(rot) => {
427            let mut m = m.fixed_view_mut::<2, 2>(0, 0);
428            let inv_rot = rot.inverse();
429            inv_rot.rotate(&mut m);
430            rot.rotate_rows(&mut m);
431            m[(1, 0)] = T::zero();
432
433            if compute_q {
434                // XXX: we have to build the matrix manually because
435                // rot.to_rotation_matrix().unwrap() causes an ICE.
436                let c = T::from_real(rot.c());
437                q = Some(OMatrix::from_column_slice_generic(
438                    dim,
439                    dim,
440                    &[c.clone(), rot.s(), -rot.s().conjugate(), c],
441                ));
442            }
443        }
444        None => {
445            if compute_q {
446                q = Some(OMatrix::identity_generic(dim, dim));
447            }
448        }
449    };
450
451    Some((q, m))
452}
453
454fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
455    m: &SquareMatrix<T, U2, S>,
456) -> Option<(T, T)> {
457    // Solve the 2x2 eigenvalue subproblem.
458    let h00 = m[(0, 0)].clone();
459    let h10 = m[(1, 0)].clone();
460    let h01 = m[(0, 1)].clone();
461    let h11 = m[(1, 1)].clone();
462
463    // NOTE: this discriminant computation is more stable than the
464    // one based on the trace and determinant: 0.25 * tra * tra - det
465    // because it ensures positiveness for symmetric matrices.
466    let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
467    let discr = h10 * h01 + val.clone() * val;
468
469    discr.try_sqrt().map(|sqrt_discr| {
470        let half_tra = (h00 + h11) * crate::convert(0.5);
471        (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
472    })
473}
474
475// Computes the 2x2 transformation that upper-triangulates a 2x2 matrix with real eigenvalues.
476/// Computes the singular vectors for a 2x2 matrix.
477///
478/// Returns `None` if the matrix has complex eigenvalues, or is upper-triangular. In both case,
479/// the basis is the identity.
480fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
481    m: &SquareMatrix<T, U2, S>,
482) -> Option<GivensRotation<T>> {
483    let h10 = m[(1, 0)].clone();
484
485    if h10.is_zero() {
486        return None;
487    }
488
489    if let Some((eigval1, eigval2)) = compute_2x2_eigvals(m) {
490        let x1 = eigval1 - m[(1, 1)].clone();
491        let x2 = eigval2 - m[(1, 1)].clone();
492
493        // NOTE: Choose the one that yields a larger x component.
494        // This is necessary for numerical stability of the normalization of the complex
495        // number.
496        if x1.clone().norm1() > x2.clone().norm1() {
497            Some(GivensRotation::new(x1, h10).0)
498        } else {
499            Some(GivensRotation::new(x2, h10).0)
500        }
501    } else {
502        None
503    }
504}
505
506impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
507where
508    D: DimSub<U1>, // For Hessenberg.
509    DefaultAllocator:
510        Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
511{
512    /// Computes the eigenvalues of this matrix.
513    #[must_use]
514    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
515        assert!(
516            self.is_square(),
517            "Unable to compute eigenvalues of a non-square matrix."
518        );
519
520        let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
521
522        // Special case for 2x2 matrices.
523        if self.nrows() == 2 {
524            // TODO: can we avoid this slicing
525            // (which is needed here just to transform D to U2)?
526            let me = self.fixed_view::<2, 2>(0, 0);
527            return match compute_2x2_eigvals(&me) {
528                Some((a, b)) => {
529                    work[0] = a;
530                    work[1] = b;
531                    Some(work)
532                }
533                None => None,
534            };
535        }
536
537        // TODO: add balancing?
538        let schur = Schur::do_decompose(
539            self.clone_owned(),
540            &mut work,
541            T::RealField::default_epsilon(),
542            0,
543            false,
544        )
545        .unwrap();
546
547        if Schur::do_eigenvalues(&schur.1, &mut work) {
548            Some(work)
549        } else {
550            None
551        }
552    }
553
554    /// Computes the eigenvalues of this matrix.
555    #[must_use]
556    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
557    // TODO: add balancing?
558    where
559        T: RealField,
560        DefaultAllocator: Allocator<D>,
561    {
562        let dim = self.shape_generic().0;
563        let mut work = Matrix::zeros_generic(dim, Const::<1>);
564
565        let schur = Schur::do_decompose(
566            self.clone_owned(),
567            &mut work,
568            T::default_epsilon(),
569            0,
570            false,
571        )
572        .unwrap();
573        let mut eig = Matrix::uninit(dim, Const::<1>);
574        Schur::do_complex_eigenvalues(&schur.1, &mut eig);
575        // Safety: eig has been fully initialized by do_complex_eigenvalues.
576        unsafe { eig.assume_init() }
577    }
578}