nalgebra/linalg/
symmetric_tridiagonal.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, OMatrix, OVector};
6use crate::dimension::{Const, DimDiff, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::Matrix;
10use crate::linalg::householder;
11use std::mem::MaybeUninit;
12
13/// Tridiagonalization of a symmetric matrix.
14#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
15#[cfg_attr(
16    feature = "serde-serialize-no-std",
17    serde(bound(serialize = "DefaultAllocator: Allocator<D, D> +
18                           Allocator<DimDiff<D, U1>>,
19         OMatrix<T, D, D>: Serialize,
20         OVector<T, DimDiff<D, U1>>: Serialize"))
21)]
22#[cfg_attr(
23    feature = "serde-serialize-no-std",
24    serde(bound(deserialize = "DefaultAllocator: Allocator<D, D> +
25                           Allocator<DimDiff<D, U1>>,
26         OMatrix<T, D, D>: Deserialize<'de>,
27         OVector<T, DimDiff<D, U1>>: Deserialize<'de>"))
28)]
29#[cfg_attr(feature = "defmt", derive(defmt::Format))]
30#[derive(Clone, Debug)]
31pub struct SymmetricTridiagonal<T: ComplexField, D: DimSub<U1>>
32where
33    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
34{
35    tri: OMatrix<T, D, D>,
36    off_diagonal: OVector<T, DimDiff<D, U1>>,
37}
38
39impl<T: ComplexField, D: DimSub<U1>> Copy for SymmetricTridiagonal<T, D>
40where
41    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
42    OMatrix<T, D, D>: Copy,
43    OVector<T, DimDiff<D, U1>>: Copy,
44{
45}
46
47impl<T: ComplexField, D: DimSub<U1>> SymmetricTridiagonal<T, D>
48where
49    DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
50{
51    /// Computes the tridiagonalization of the symmetric matrix `m`.
52    ///
53    /// Only the lower-triangular part (including the diagonal) of `m` is read.
54    pub fn new(mut m: OMatrix<T, D, D>) -> Self {
55        let dim = m.shape_generic().0;
56
57        assert!(
58            m.is_square(),
59            "Unable to compute the symmetric tridiagonal decomposition of a non-square matrix."
60        );
61        assert!(
62            dim.value() != 0,
63            "Unable to compute the symmetric tridiagonal decomposition of an empty matrix."
64        );
65
66        let mut off_diagonal = Matrix::uninit(dim.sub(Const::<1>), Const::<1>);
67        let mut p = Matrix::zeros_generic(dim.sub(Const::<1>), Const::<1>);
68
69        for i in 0..dim.value() - 1 {
70            let mut m = m.rows_range_mut(i + 1..);
71            let (mut axis, mut m) = m.columns_range_pair_mut(i, i + 1..);
72
73            let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
74            off_diagonal[i] = MaybeUninit::new(norm);
75
76            if not_zero {
77                let mut p = p.rows_range_mut(i..);
78
79                p.hegemv(crate::convert(2.0), &m, &axis, T::zero());
80
81                let dot = axis.dotc(&p);
82                m.hegerc(-T::one(), &p, &axis, T::one());
83                m.hegerc(-T::one(), &axis, &p, T::one());
84                m.hegerc(dot * crate::convert(2.0), &axis, &axis, T::one());
85            }
86        }
87
88        // Safety: off_diagonal has been fully initialized.
89        let off_diagonal = unsafe { off_diagonal.assume_init() };
90        Self {
91            tri: m,
92            off_diagonal,
93        }
94    }
95
96    #[doc(hidden)]
97    // For debugging.
98    pub const fn internal_tri(&self) -> &OMatrix<T, D, D> {
99        &self.tri
100    }
101
102    /// Retrieve the orthogonal transformation, diagonal, and off diagonal elements of this
103    /// decomposition.
104    pub fn unpack(
105        self,
106    ) -> (
107        OMatrix<T, D, D>,
108        OVector<T::RealField, D>,
109        OVector<T::RealField, DimDiff<D, U1>>,
110    )
111    where
112        DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
113    {
114        let diag = self.diagonal();
115        let q = self.q();
116
117        (q, diag, self.off_diagonal.map(T::modulus))
118    }
119
120    /// Retrieve the diagonal, and off diagonal elements of this decomposition.
121    pub fn unpack_tridiagonal(
122        self,
123    ) -> (
124        OVector<T::RealField, D>,
125        OVector<T::RealField, DimDiff<D, U1>>,
126    )
127    where
128        DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
129    {
130        (self.diagonal(), self.off_diagonal.map(T::modulus))
131    }
132
133    /// The diagonal components of this decomposition.
134    #[must_use]
135    pub fn diagonal(&self) -> OVector<T::RealField, D>
136    where
137        DefaultAllocator: Allocator<D>,
138    {
139        self.tri.map_diagonal(|e| e.real())
140    }
141
142    /// The off-diagonal components of this decomposition.
143    #[must_use]
144    pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<D, U1>>
145    where
146        DefaultAllocator: Allocator<DimDiff<D, U1>>,
147    {
148        self.off_diagonal.map(T::modulus)
149    }
150
151    /// Computes the orthogonal matrix `Q` of this decomposition.
152    #[must_use]
153    pub fn q(&self) -> OMatrix<T, D, D> {
154        householder::assemble_q(&self.tri, self.off_diagonal.as_slice())
155    }
156
157    /// Recomputes the original symmetric matrix.
158    pub fn recompose(mut self) -> OMatrix<T, D, D> {
159        let q = self.q();
160        self.tri.fill_lower_triangle(T::zero(), 2);
161        self.tri.fill_upper_triangle(T::zero(), 2);
162
163        for i in 0..self.off_diagonal.len() {
164            let val = T::from_real(self.off_diagonal[i].clone().modulus());
165            self.tri[(i + 1, i)] = val.clone();
166            self.tri[(i, i + 1)] = val;
167        }
168
169        &q * self.tri * q.adjoint()
170    }
171}