1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, OMatrix, OVector};
6use crate::dimension::{Const, DimDiff, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::linalg::householder;
10use crate::Matrix;
11use std::mem::MaybeUninit;
12
13#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
15#[cfg_attr(
16 feature = "serde-serialize-no-std",
17 serde(bound(serialize = "DefaultAllocator: Allocator<D, D> +
18 Allocator<DimDiff<D, U1>>,
19 OMatrix<T, D, D>: Serialize,
20 OVector<T, DimDiff<D, U1>>: Serialize"))
21)]
22#[cfg_attr(
23 feature = "serde-serialize-no-std",
24 serde(bound(deserialize = "DefaultAllocator: Allocator<D, D> +
25 Allocator<DimDiff<D, U1>>,
26 OMatrix<T, D, D>: Deserialize<'de>,
27 OVector<T, DimDiff<D, U1>>: Deserialize<'de>"))
28)]
29#[derive(Clone, Debug)]
30pub struct SymmetricTridiagonal<T: ComplexField, D: DimSub<U1>>
31where
32 DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
33{
34 tri: OMatrix<T, D, D>,
35 off_diagonal: OVector<T, DimDiff<D, U1>>,
36}
37
38impl<T: ComplexField, D: DimSub<U1>> Copy for SymmetricTridiagonal<T, D>
39where
40 DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
41 OMatrix<T, D, D>: Copy,
42 OVector<T, DimDiff<D, U1>>: Copy,
43{
44}
45
46impl<T: ComplexField, D: DimSub<U1>> SymmetricTridiagonal<T, D>
47where
48 DefaultAllocator: Allocator<D, D> + Allocator<DimDiff<D, U1>>,
49{
50 pub fn new(mut m: OMatrix<T, D, D>) -> Self {
54 let dim = m.shape_generic().0;
55
56 assert!(
57 m.is_square(),
58 "Unable to compute the symmetric tridiagonal decomposition of a non-square matrix."
59 );
60 assert!(
61 dim.value() != 0,
62 "Unable to compute the symmetric tridiagonal decomposition of an empty matrix."
63 );
64
65 let mut off_diagonal = Matrix::uninit(dim.sub(Const::<1>), Const::<1>);
66 let mut p = Matrix::zeros_generic(dim.sub(Const::<1>), Const::<1>);
67
68 for i in 0..dim.value() - 1 {
69 let mut m = m.rows_range_mut(i + 1..);
70 let (mut axis, mut m) = m.columns_range_pair_mut(i, i + 1..);
71
72 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
73 off_diagonal[i] = MaybeUninit::new(norm);
74
75 if not_zero {
76 let mut p = p.rows_range_mut(i..);
77
78 p.hegemv(crate::convert(2.0), &m, &axis, T::zero());
79
80 let dot = axis.dotc(&p);
81 m.hegerc(-T::one(), &p, &axis, T::one());
82 m.hegerc(-T::one(), &axis, &p, T::one());
83 m.hegerc(dot * crate::convert(2.0), &axis, &axis, T::one());
84 }
85 }
86
87 let off_diagonal = unsafe { off_diagonal.assume_init() };
89 Self {
90 tri: m,
91 off_diagonal,
92 }
93 }
94
95 #[doc(hidden)]
96 pub fn internal_tri(&self) -> &OMatrix<T, D, D> {
98 &self.tri
99 }
100
101 pub fn unpack(
104 self,
105 ) -> (
106 OMatrix<T, D, D>,
107 OVector<T::RealField, D>,
108 OVector<T::RealField, DimDiff<D, U1>>,
109 )
110 where
111 DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
112 {
113 let diag = self.diagonal();
114 let q = self.q();
115
116 (q, diag, self.off_diagonal.map(T::modulus))
117 }
118
119 pub fn unpack_tridiagonal(
121 self,
122 ) -> (
123 OVector<T::RealField, D>,
124 OVector<T::RealField, DimDiff<D, U1>>,
125 )
126 where
127 DefaultAllocator: Allocator<D> + Allocator<DimDiff<D, U1>>,
128 {
129 (self.diagonal(), self.off_diagonal.map(T::modulus))
130 }
131
132 #[must_use]
134 pub fn diagonal(&self) -> OVector<T::RealField, D>
135 where
136 DefaultAllocator: Allocator<D>,
137 {
138 self.tri.map_diagonal(|e| e.real())
139 }
140
141 #[must_use]
143 pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<D, U1>>
144 where
145 DefaultAllocator: Allocator<DimDiff<D, U1>>,
146 {
147 self.off_diagonal.map(T::modulus)
148 }
149
150 #[must_use]
152 pub fn q(&self) -> OMatrix<T, D, D> {
153 householder::assemble_q(&self.tri, self.off_diagonal.as_slice())
154 }
155
156 pub fn recompose(mut self) -> OMatrix<T, D, D> {
158 let q = self.q();
159 self.tri.fill_lower_triangle(T::zero(), 2);
160 self.tri.fill_upper_triangle(T::zero(), 2);
161
162 for i in 0..self.off_diagonal.len() {
163 let val = T::from_real(self.off_diagonal[i].clone().modulus());
164 self.tri[(i + 1, i)] = val.clone();
165 self.tri[(i, i + 1)] = val;
166 }
167
168 &q * self.tri * q.adjoint()
169 }
170}