1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use num::{One, Zero};
5use simba::scalar::ComplexField;
6use simba::simd::SimdComplexField;
7
8use crate::allocator::Allocator;
9use crate::base::{Const, DefaultAllocator, Matrix, OMatrix, Vector};
10use crate::constraint::{SameNumberOfRows, ShapeConstraint};
11use crate::dimension::{Dim, DimAdd, DimDiff, DimSub, DimSum, U1};
12use crate::storage::{Storage, StorageMut};
13
14#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
16#[cfg_attr(
17 feature = "serde-serialize-no-std",
18 serde(bound(serialize = "DefaultAllocator: Allocator<D>,
19 OMatrix<T, D, D>: Serialize"))
20)]
21#[cfg_attr(
22 feature = "serde-serialize-no-std",
23 serde(bound(deserialize = "DefaultAllocator: Allocator<D>,
24 OMatrix<T, D, D>: Deserialize<'de>"))
25)]
26#[derive(Clone, Debug)]
27pub struct Cholesky<T: SimdComplexField, D: Dim>
28where
29 DefaultAllocator: Allocator<D, D>,
30{
31 chol: OMatrix<T, D, D>,
32}
33
34impl<T: SimdComplexField, D: Dim> Copy for Cholesky<T, D>
35where
36 DefaultAllocator: Allocator<D, D>,
37 OMatrix<T, D, D>: Copy,
38{
39}
40
41impl<T: SimdComplexField, D: Dim> Cholesky<T, D>
42where
43 DefaultAllocator: Allocator<D, D>,
44{
45 pub fn new_unchecked(mut matrix: OMatrix<T, D, D>) -> Self {
49 assert!(matrix.is_square(), "The input matrix must be square.");
50
51 let n = matrix.nrows();
52
53 for j in 0..n {
54 for k in 0..j {
55 let factor = unsafe { -matrix.get_unchecked((j, k)).clone() };
56
57 let (mut col_j, col_k) = matrix.columns_range_pair_mut(j, k);
58 let mut col_j = col_j.rows_range_mut(j..);
59 let col_k = col_k.rows_range(j..);
60 col_j.axpy(factor.simd_conjugate(), &col_k, T::one());
61 }
62
63 let diag = unsafe { matrix.get_unchecked((j, j)).clone() };
64 let denom = diag.simd_sqrt();
65
66 unsafe {
67 *matrix.get_unchecked_mut((j, j)) = denom.clone();
68 }
69
70 let mut col = matrix.view_range_mut(j + 1.., j);
71 col /= denom;
72 }
73
74 Cholesky { chol: matrix }
75 }
76
77 pub fn pack_dirty(matrix: OMatrix<T, D, D>) -> Self {
82 Cholesky { chol: matrix }
83 }
84
85 pub fn unpack(mut self) -> OMatrix<T, D, D> {
88 self.chol.fill_upper_triangle(T::zero(), 1);
89 self.chol
90 }
91
92 pub fn unpack_dirty(self) -> OMatrix<T, D, D> {
98 self.chol
99 }
100
101 #[must_use]
104 pub fn l(&self) -> OMatrix<T, D, D> {
105 self.chol.lower_triangle()
106 }
107
108 #[must_use]
114 pub fn l_dirty(&self) -> &OMatrix<T, D, D> {
115 &self.chol
116 }
117
118 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
122 where
123 S2: StorageMut<T, R2, C2>,
124 ShapeConstraint: SameNumberOfRows<R2, D>,
125 {
126 self.chol.solve_lower_triangular_unchecked_mut(b);
127 self.chol.ad_solve_lower_triangular_unchecked_mut(b);
128 }
129
130 #[must_use = "Did you mean to use solve_mut()?"]
133 pub fn solve<R2: Dim, C2: Dim, S2>(&self, b: &Matrix<T, R2, C2, S2>) -> OMatrix<T, R2, C2>
134 where
135 S2: Storage<T, R2, C2>,
136 DefaultAllocator: Allocator<R2, C2>,
137 ShapeConstraint: SameNumberOfRows<R2, D>,
138 {
139 let mut res = b.clone_owned();
140 self.solve_mut(&mut res);
141 res
142 }
143
144 #[must_use]
146 pub fn inverse(&self) -> OMatrix<T, D, D> {
147 let shape = self.chol.shape_generic();
148 let mut res = OMatrix::identity_generic(shape.0, shape.1);
149
150 self.solve_mut(&mut res);
151 res
152 }
153
154 #[must_use]
156 pub fn determinant(&self) -> T::SimdRealField {
157 let dim = self.chol.nrows();
158 let mut prod_diag = T::one();
159 for i in 0..dim {
160 prod_diag *= unsafe { self.chol.get_unchecked((i, i)).clone() };
161 }
162 prod_diag.simd_modulus_squared()
163 }
164
165 #[must_use]
171 pub fn ln_determinant(&self) -> T::SimdRealField {
172 let dim = self.chol.nrows();
173 let mut sum_diag = T::SimdRealField::zero();
174 for i in 0..dim {
175 sum_diag += unsafe {
176 self.chol
177 .get_unchecked((i, i))
178 .clone()
179 .simd_modulus_squared()
180 .simd_ln()
181 };
182 }
183 sum_diag
184 }
185}
186
187impl<T: ComplexField, D: Dim> Cholesky<T, D>
188where
189 DefaultAllocator: Allocator<D, D>,
190{
191 pub fn new(matrix: OMatrix<T, D, D>) -> Option<Self> {
196 Self::new_internal(matrix, None)
197 }
198
199 pub fn new_with_substitute(matrix: OMatrix<T, D, D>, substitute: T) -> Option<Self> {
216 Self::new_internal(matrix, Some(substitute))
217 }
218
219 fn new_internal(mut matrix: OMatrix<T, D, D>, substitute: Option<T>) -> Option<Self> {
221 assert!(matrix.is_square(), "The input matrix must be square.");
222
223 let n = matrix.nrows();
224
225 for j in 0..n {
226 for k in 0..j {
227 let factor = unsafe { -matrix.get_unchecked((j, k)).clone() };
228
229 let (mut col_j, col_k) = matrix.columns_range_pair_mut(j, k);
230 let mut col_j = col_j.rows_range_mut(j..);
231 let col_k = col_k.rows_range(j..);
232
233 col_j.axpy(factor.conjugate(), &col_k, T::one());
234 }
235
236 let sqrt_denom = |v: T| {
237 if v.is_zero() {
238 return None;
239 }
240 v.try_sqrt()
241 };
242
243 let diag = unsafe { matrix.get_unchecked((j, j)).clone() };
244
245 if let Some(denom) =
246 sqrt_denom(diag).or_else(|| substitute.clone().and_then(sqrt_denom))
247 {
248 unsafe {
249 *matrix.get_unchecked_mut((j, j)) = denom.clone();
250 }
251
252 let mut col = matrix.view_range_mut(j + 1.., j);
253 col /= denom;
254 continue;
255 }
256
257 return None;
260 }
261
262 Some(Cholesky { chol: matrix })
263 }
264
265 #[inline]
268 pub fn rank_one_update<R2: Dim, S2>(&mut self, x: &Vector<T, R2, S2>, sigma: T::RealField)
269 where
270 S2: Storage<T, R2, U1>,
271 DefaultAllocator: Allocator<R2, U1>,
272 ShapeConstraint: SameNumberOfRows<R2, D>,
273 {
274 Self::xx_rank_one_update(&mut self.chol, &mut x.clone_owned(), sigma)
275 }
276
277 pub fn insert_column<R2, S2>(
280 &self,
281 j: usize,
282 col: Vector<T, R2, S2>,
283 ) -> Cholesky<T, DimSum<D, U1>>
284 where
285 D: DimAdd<U1>,
286 R2: Dim,
287 S2: Storage<T, R2, U1>,
288 DefaultAllocator: Allocator<DimSum<D, U1>, DimSum<D, U1>> + Allocator<R2>,
289 ShapeConstraint: SameNumberOfRows<R2, DimSum<D, U1>>,
290 {
291 let mut col = col.into_owned();
292 let n = col.nrows();
294 assert_eq!(
295 n,
296 self.chol.nrows() + 1,
297 "The new column must have the size of the factored matrix plus one."
298 );
299 assert!(j < n, "j needs to be within the bound of the new matrix.");
300
301 let mut chol = Matrix::zeros_generic(
304 self.chol.shape_generic().0.add(Const::<1>),
305 self.chol.shape_generic().1.add(Const::<1>),
306 );
307 chol.view_range_mut(..j, ..j)
308 .copy_from(&self.chol.view_range(..j, ..j));
309 chol.view_range_mut(..j, j + 1..)
310 .copy_from(&self.chol.view_range(..j, j..));
311 chol.view_range_mut(j + 1.., ..j)
312 .copy_from(&self.chol.view_range(j.., ..j));
313 chol.view_range_mut(j + 1.., j + 1..)
314 .copy_from(&self.chol.view_range(j.., j..));
315
316 let top_left_corner = self.chol.view_range(..j, ..j);
318
319 let col_j = col[j].clone();
320 let (mut new_rowj_adjoint, mut new_colj) = col.rows_range_pair_mut(..j, j + 1..);
321 assert!(
322 top_left_corner.solve_lower_triangular_mut(&mut new_rowj_adjoint),
323 "Cholesky::insert_column : Unable to solve lower triangular system!"
324 );
325
326 new_rowj_adjoint.adjoint_to(&mut chol.view_range_mut(j, ..j));
327
328 let center_element = T::sqrt(col_j - T::from_real(new_rowj_adjoint.norm_squared()));
330 chol[(j, j)] = center_element.clone();
331
332 let bottom_left_corner = self.chol.view_range(j.., ..j);
334 new_colj.gemm(
336 -T::one() / center_element.clone(),
337 &bottom_left_corner,
338 &new_rowj_adjoint,
339 T::one() / center_element,
340 );
341 chol.view_range_mut(j + 1.., j).copy_from(&new_colj);
342
343 let mut bottom_right_corner = chol.view_range_mut(j + 1.., j + 1..);
345 Self::xx_rank_one_update(
346 &mut bottom_right_corner,
347 &mut new_colj,
348 -T::RealField::one(),
349 );
350
351 Cholesky { chol }
352 }
353
354 #[must_use]
357 pub fn remove_column(&self, j: usize) -> Cholesky<T, DimDiff<D, U1>>
358 where
359 D: DimSub<U1>,
360 DefaultAllocator: Allocator<DimDiff<D, U1>, DimDiff<D, U1>> + Allocator<D>,
361 {
362 let n = self.chol.nrows();
363 assert!(n > 0, "The matrix needs at least one column.");
364 assert!(j < n, "j needs to be within the bound of the matrix.");
365
366 let mut chol = Matrix::zeros_generic(
369 self.chol.shape_generic().0.sub(Const::<1>),
370 self.chol.shape_generic().1.sub(Const::<1>),
371 );
372 chol.view_range_mut(..j, ..j)
373 .copy_from(&self.chol.view_range(..j, ..j));
374 chol.view_range_mut(..j, j..)
375 .copy_from(&self.chol.view_range(..j, j + 1..));
376 chol.view_range_mut(j.., ..j)
377 .copy_from(&self.chol.view_range(j + 1.., ..j));
378 chol.view_range_mut(j.., j..)
379 .copy_from(&self.chol.view_range(j + 1.., j + 1..));
380
381 let mut bottom_right_corner = chol.view_range_mut(j.., j..);
383 let mut workspace = self.chol.column(j).clone_owned();
384 let mut old_colj = workspace.rows_range_mut(j + 1..);
385 Self::xx_rank_one_update(&mut bottom_right_corner, &mut old_colj, T::RealField::one());
386
387 Cholesky { chol }
388 }
389
390 fn xx_rank_one_update<Dm, Sm, Rx, Sx>(
396 chol: &mut Matrix<T, Dm, Dm, Sm>,
397 x: &mut Vector<T, Rx, Sx>,
398 sigma: T::RealField,
399 ) where
400 Dm: Dim,
402 Rx: Dim,
403 Sm: StorageMut<T, Dm, Dm>,
404 Sx: StorageMut<T, Rx, U1>,
405 {
406 let n = x.nrows();
408 assert_eq!(
409 n,
410 chol.nrows(),
411 "The input vector must be of the same size as the factorized matrix."
412 );
413
414 let mut beta = crate::one::<T::RealField>();
415
416 for j in 0..n {
417 let diag = T::real(unsafe { chol.get_unchecked((j, j)).clone() });
419 let diag2 = diag.clone() * diag.clone();
420 let xj = unsafe { x.get_unchecked(j).clone() };
421 let sigma_xj2 = sigma.clone() * T::modulus_squared(xj.clone());
422 let gamma = diag2.clone() * beta.clone() + sigma_xj2.clone();
423 let new_diag = (diag2.clone() + sigma_xj2.clone() / beta.clone()).sqrt();
424 unsafe { *chol.get_unchecked_mut((j, j)) = T::from_real(new_diag.clone()) };
425 beta += sigma_xj2 / diag2;
426 let mut xjplus = x.rows_range_mut(j + 1..);
428 let mut col_j = chol.view_range_mut(j + 1.., j);
429 xjplus.axpy(-xj.clone() / T::from_real(diag.clone()), &col_j, T::one());
431 if gamma != crate::zero::<T::RealField>() {
432 col_j.axpy(
434 T::from_real(new_diag.clone() * sigma.clone() / gamma) * T::conjugate(xj),
435 &xjplus,
436 T::from_real(new_diag / diag),
437 );
438 }
439 }
440 }
441}