nalgebra/linalg/
symmetric_tridiagonal.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, OMatrix, OVector};
6use crate::dimension::{Const, DimDiff, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::linalg::householder;
10use crate::Matrix;
11use std::mem::MaybeUninit;
12
/// Tridiagonalization of a symmetric matrix.
///
/// A value of this type is produced by [`SymmetricTridiagonal::new`], which takes
/// ownership of the input matrix and transforms it in place. The result is kept in
/// a compact form (see the fields below); use `diagonal`, `off_diagonal`, `q`,
/// `unpack`, `unpack_tridiagonal` or `recompose` to extract usable quantities.
///
/// The struct requires allocators for a `D x D` matrix and for a vector with
/// `D - 1` rows (`DimDiff<D, U1>`). When the `serde-serialize-no-std` feature is
/// enabled, serialization and deserialization are derived with bounds placed on
/// the two stored field types rather than on `T` and `D` directly.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<D, D> +
                           Allocator<DimDiff<D, U1>>,
         OMatrix<T, D, D>: Serialize,
         OVector<T, DimDiff<D, U1>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<D, D> +
                           Allocator<DimDiff<D, U1>>,
         OMatrix<T, D, D>: Deserialize<'de>,
         OVector<T, DimDiff<D, U1>>: Deserialize<'de>"))
)]
#[derive(Clone, Debug)]
pub struct SymmetricTridiagonal<T: ComplexField, D: DimSub<U1>>
where
    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
{
    /// The input matrix after the in-place reduction performed by `new`.
    ///
    /// Its diagonal is what `diagonal()` reports (real parts only), and the whole
    /// matrix is handed to `householder::assemble_q` by `q()`.
    tri: OMatrix<T, D, D>,
    /// One entry per reduction step: the `norm` value returned by
    /// `householder::reflection_axis_mut` for that step.
    ///
    /// Public accessors expose the modulus of each entry; the raw values are also
    /// passed to `householder::assemble_q` by `q()`.
    off_diagonal: OVector<T, DimDiff<D, U1>>,
}
37
// `SymmetricTridiagonal` is `Copy` whenever both of its stored fields (the
// `D x D` matrix and the `DimDiff<D, U1>`-row vector) are themselves `Copy`.
// The impl is written by hand so that the bounds sit on the field types.
impl<T: ComplexField, D: DimSub<U1>> Copy for SymmetricTridiagonal<T, D>
where
    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
    OMatrix<T, D, D>: Copy,
    OVector<T, DimDiff<D, U1>>: Copy,
{
}
45
46impl<T: ComplexField, D: DimSub<U1>> SymmetricTridiagonal<T, D>
47where
48    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
49{
50    /// Computes the tridiagonalization of the symmetric matrix `m`.
51    ///
52    /// Only the lower-triangular part (including the diagonal) of `m` is read.
53    pub fn new(mut m: OMatrix<T, D, D>) -> Self {
54        let dim = m.shape_generic().0;
55
56        assert!(
57            m.is_square(),
58            "Unable to compute the symmetric tridiagonal decomposition of a non-square matrix."
59        );
60        assert!(
61            dim.value() != 0,
62            "Unable to compute the symmetric tridiagonal decomposition of an empty matrix."
63        );
64
65        let mut off_diagonal = Matrix::uninit(dim.sub(Const::<1>), Const::<1>);
66        let mut p = Matrix::zeros_generic(dim.sub(Const::<1>), Const::<1>);
67
68        for i in 0..dim.value() - 1 {
69            let mut m = m.rows_range_mut(i + 1..);
70            let (mut axis, mut m) = m.columns_range_pair_mut(i, i + 1..);
71
72            let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
73            off_diagonal[i] = MaybeUninit::new(norm);
74
75            if not_zero {
76                let mut p = p.rows_range_mut(i..);
77
78                p.hegemv(crate::convert(2.0), &m, &axis, T::zero());
79
80                let dot = axis.dotc(&p);
81                m.hegerc(-T::one(), &p, &axis, T::one());
82                m.hegerc(-T::one(), &axis, &p, T::one());
83                m.hegerc(dot * crate::convert(2.0), &axis, &axis, T::one());
84            }
85        }
86
87        // Safety: off_diagonal has been fully initialized.
88        let off_diagonal = unsafe { off_diagonal.assume_init() };
89        Self {
90            tri: m,
91            off_diagonal,
92        }
93    }
94
95    #[doc(hidden)]
96    // For debugging.
97    pub fn internal_tri(&self) -> &OMatrix<T, D, D> {
98        &self.tri
99    }
100
101    /// Retrieve the orthogonal transformation, diagonal, and off diagonal elements of this
102    /// decomposition.
103    pub fn unpack(
104        self,
105    ) -> (
106        OMatrix<T, D, D>,
107        OVector<T::RealField, D>,
108        OVector<T::RealField, DimDiff<D, U1>>,
109    )
110    where
111        DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
112    {
113        let diag = self.diagonal();
114        let q = self.q();
115
116        (q, diag, self.off_diagonal.map(T::modulus))
117    }
118
119    /// Retrieve the diagonal, and off diagonal elements of this decomposition.
120    pub fn unpack_tridiagonal(
121        self,
122    ) -> (
123        OVector<T::RealField, D>,
124        OVector<T::RealField, DimDiff<D, U1>>,
125    )
126    where
127        DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
128    {
129        (self.diagonal(), self.off_diagonal.map(T::modulus))
130    }
131
132    /// The diagonal components of this decomposition.
133    #[must_use]
134    pub fn diagonal(&self) -> OVector<T::RealField, D>
135    where
136        DefaultAllocator: Allocator<D>,
137    {
138        self.tri.map_diagonal(|e| e.real())
139    }
140
141    /// The off-diagonal components of this decomposition.
142    #[must_use]
143    pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<D, U1>>
144    where
145        DefaultAllocator: Allocator<DimDiff<D, U1>>,
146    {
147        self.off_diagonal.map(T::modulus)
148    }
149
150    /// Computes the orthogonal matrix `Q` of this decomposition.
151    #[must_use]
152    pub fn q(&self) -> OMatrix<T, D, D> {
153        householder::assemble_q(&self.tri, self.off_diagonal.as_slice())
154    }
155
156    /// Recomputes the original symmetric matrix.
157    pub fn recompose(mut self) -> OMatrix<T, D, D> {
158        let q = self.q();
159        self.tri.fill_lower_triangle(T::zero(), 2);
160        self.tri.fill_upper_triangle(T::zero(), 2);
161
162        for i in 0..self.off_diagonal.len() {
163            let val = T::from_real(self.off_diagonal[i].clone().modulus());
164            self.tri[(i + 1, i)] = val.clone();
165            self.tri[(i, i + 1)] = val;
166        }
167
168        &q * self.tri * q.adjoint()
169    }
170}