Skip to main content

nalgebra/linalg/
lblt.rs

1use crate::{
2    ComplexField, DefaultAllocator, Dim, Matrix, OMatrix, OVector, RealField, Storage, U1,
3    allocator::Allocator, convert,
4};
5use num_traits::{One, Zero};
6
7#[cfg(feature = "serde-serialize-no-std")]
8use serde::{Deserialize, Serialize};
9
10/// A pivot represented as `(index, block_size)`, where `block_size` is either 1 or 2.
11type Pivot = (usize, usize);
12
13/// Bunch–Kaufman LBL^H factorization of a Hermitian matrix with symmetric pivoting.
14#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
15#[cfg_attr(
16    feature = "serde-serialize-no-std",
17    serde(bound(serialize = "DefaultAllocator: Allocator<N, N>,
18         OMatrix<T, N, N>: Serialize,
19         OVector<Pivot, N>: Serialize,
20         Option<usize>: Serialize"))
21)]
22#[cfg_attr(
23    feature = "serde-serialize-no-std",
24    serde(bound(deserialize = "DefaultAllocator: Allocator<N, N>,
25         OMatrix<T, N, N>: Deserialize<'de>,
26         OVector<Pivot, N>: Deserialize<'de>,
27         Option<usize>: Deserialize<'de>"))
28)]
29#[cfg_attr(feature = "defmt", derive(defmt::Format))]
30#[derive(Clone, Debug)]
31pub struct LBLT<T: ComplexField, N: Dim>
32where
33    DefaultAllocator: Allocator<N> + Allocator<N, N>,
34{
35    matrix: OMatrix<T, N, N>,
36    pivots: OVector<Pivot, N>,
37    zero_pivot: Option<usize>,
38}
39
40impl<T: Copy + ComplexField, N: Dim> LBLT<T, N>
41where
42    T::RealField: Copy,
43    DefaultAllocator: Allocator<N> + Allocator<N, N>,
44{
45    /// Compute the factorization of a complex Hermitian matrix using the Bunch-Kaufman
46    /// block-diagonal pivoting method:
47    ///
48    /// P A P^T = L * B * L^H
49    ///
50    /// where P is the permutation induced by the pivot sequence, L is unit lower
51    /// triangular in the permuted basis, and B is Hermitian block diagonal with
52    /// 1-by-1 and 2-by-2 diagonal blocks.
53    ///
54    /// This implementation follows the partial pivoting (Algorithm A) variant from
55    /// Bunch & Kaufman (1977), which is also the basis for LAPACK’s `?sytrf/?hetrf` routines.
56    pub fn new(mut matrix: OMatrix<T, N, N>) -> Self {
57        assert!(matrix.is_square());
58        let n = matrix.nrows();
59
60        let mut pivots = OVector::from_element_generic(matrix.shape_generic().0, U1, (0, 0));
61        let mut zero_pivot = None;
62
63        // Bunch–Kaufman pivot threshold: (1 + sqrt(17)) / 8
64        let alpha: T::RealField = convert(0.6403882032022076);
65
66        // current pivot position
67        let mut k = 0;
68
69        while k < n {
70            let mut block_size = 1;
71
72            // Ensure the diagonal element is real
73            matrix[(k, k)] = T::from_real(matrix[(k, k)].real());
74            let diag_abs = matrix[(k, k)].real().abs();
75
76            // Row index and magnitude of the largest off-diagonal entry in the active part
77            // of column k.
78            let (imax, colmax) = if k + 1 < n {
79                let mut imax = k + 1;
80                let mut colmax = matrix[(imax, k)].norm1();
81                for i in (k + 2)..n {
82                    let magnitude = matrix[(i, k)].norm1();
83                    if magnitude > colmax {
84                        imax = i;
85                        colmax = magnitude;
86                    }
87                }
88                (imax, colmax)
89            } else {
90                // If k is the last column, there is no off-diagonal candidate.
91                (0, T::RealField::zero())
92            };
93
94            if diag_abs.max(colmax) == T::RealField::zero() {
95                // Column k is zero: store a 1x1 pivot, and skip all other logic.
96                if zero_pivot.is_none() {
97                    zero_pivot = Some(k);
98                }
99
100                pivots[k] = (k, 1);
101                k += 1;
102                continue;
103            }
104
105            let pivot_index: usize;
106
107            if diag_abs < alpha * colmax {
108                let mut rowmax = T::RealField::zero();
109                for j in k..imax {
110                    rowmax = rowmax.max(matrix[(imax, j)].norm1());
111                }
112                for j in (imax + 1)..matrix.nrows() {
113                    rowmax = rowmax.max(matrix[(j, imax)].norm1());
114                }
115
116                if diag_abs >= alpha * colmax * (colmax / rowmax) {
117                    // Even though A[k, k] is not diagonally dominant, it is still large enough
118                    // compared with the candidate row/column growth, so keep a 1x1 pivot at k.
119                    pivot_index = k;
120                } else {
121                    pivot_index = imax;
122
123                    if matrix[(imax, imax)].real().abs() < alpha * rowmax {
124                        // The candidate diagonal at imax is also too small relative to its row
125                        // maximum, so use a 2x2 pivot block involving k and k+1.
126                        block_size = 2;
127                    }
128                }
129            } else {
130                // The diagonal dominates column k strongly enough to use A[k, k] as a 1x1 pivot
131                // without any row/column interchange.
132                pivot_index = k;
133            }
134
135            let pivot_target = k + block_size - 1;
136
137            if pivot_index != pivot_target {
138                // Hermitian two-sided interchange for the chosen pivot.
139                for i in (pivot_index + 1)..matrix.nrows() {
140                    // Swap entries below both indices
141                    matrix.swap((i, pivot_target), (i, pivot_index));
142                }
143
144                for j in (pivot_target + 1)..pivot_index {
145                    // Swap the strip between the two indices.
146                    matrix.swap((j, pivot_target), (pivot_index, j));
147                    matrix[(j, pivot_target)] = matrix[(j, pivot_target)].conjugate();
148                    matrix[(pivot_index, j)] = matrix[(pivot_index, j)].conjugate();
149                }
150
151                // The cross entry between the swapped indices remains in the same slot.
152                matrix[(pivot_index, pivot_target)] =
153                    matrix[(pivot_index, pivot_target)].conjugate();
154
155                // Swap the diagonal entries.
156                matrix.swap((pivot_target, pivot_target), (pivot_index, pivot_index));
157
158                if k + 1 == pivot_target {
159                    // For a 2x2 pivot block, move the off-diagonal block entry.
160                    matrix.swap((k + 1, k), (pivot_index, k));
161                }
162            }
163
164            if block_size == 1 {
165                // 1x1 pivot block D(k)
166                if k + 1 < n {
167                    let inv_diag = T::RealField::one() / matrix[(k, k)].real();
168
169                    for j in (k + 1)..n {
170                        let jk_conj = matrix[(j, k)].conjugate();
171
172                        for i in j..n {
173                            // Rank-1 Hermitian update of the trailing submatrix.
174                            matrix[(i, j)] =
175                                matrix[(i, j)] + matrix[(i, k)].scale(-inv_diag) * jk_conj;
176                        }
177
178                        // Keep the Hermitian diagonal explicitly real.
179                        matrix[(j, j)] = T::from_real(matrix[(j, j)].real());
180                    }
181
182                    // Normalize column k so that it stores the multipliers of L.
183                    for i in (k + 1)..n {
184                        matrix[(i, k)] = matrix[(i, k)].scale(inv_diag);
185                    }
186                }
187
188                pivots[k] = (pivot_index, 1);
189            } else {
190                // 2x2 pivot block D(k:k+1)
191                if k + 2 < n {
192                    // Form the scaled inverse-coefficient data for the 2x2 Hermitian pivot block.
193                    let d = matrix[(k + 1, k)].abs();
194                    let d11 = matrix[(k + 1, k + 1)].real() / d;
195                    let d22 = matrix[(k, k)].real() / d;
196                    let d21 = matrix[(k + 1, k)].unscale(d);
197                    let scale = T::RealField::one() / (d * (d11 * d22 - T::RealField::one()));
198
199                    for j in (k + 2)..n {
200                        // These are the two transformed entries for row j. Together they represent
201                        // the action of inv(D(k:k+1)) on the stored columns k and k+1.
202                        let work1 =
203                            (matrix[(j, k)].scale(d11) - matrix[(j, k + 1)] * d21).scale(scale);
204                        let work2 = (matrix[(j, k + 1)].scale(d22)
205                            - matrix[(j, k)] * d21.conjugate())
206                        .scale(scale);
207
208                        for i in j..n {
209                            // Rank-2 Hermitian update of the trailing submatrix.
210                            matrix[(i, j)] = matrix[(i, j)]
211                                - matrix[(i, k)] * work1.conjugate()
212                                - matrix[(i, k + 1)] * work2.conjugate();
213                        }
214
215                        matrix[(j, k)] = work1;
216                        matrix[(j, k + 1)] = work2;
217
218                        // Keep the Hermitian diagonal explicitly real.
219                        matrix[(j, j)] = T::from_real(matrix[(j, j)].real());
220                    }
221                }
222
223                pivots[k] = (pivot_index, 2);
224                pivots[k + 1] = (pivot_index, 2);
225            }
226
227            k += block_size;
228        }
229
230        Self {
231            matrix,
232            pivots,
233            zero_pivot,
234        }
235    }
236
237    /// Returns the permutation-aware factor P^T L.
238    ///
239    /// This factor can be combined directly with `d()` to reconstruct the original
240    /// matrix. In general `P^T L` is not lower triangular, even though `L` itself is
241    /// unit lower triangular in the permuted basis.
242    pub fn l_permuted(&self) -> OMatrix<T, N, N> {
243        let n = self.matrix.nrows();
244        let (nrows, ncols) = self.matrix.shape_generic();
245        let mut l_permuted = OMatrix::identity_generic(nrows, ncols);
246
247        let mut k = 0;
248        while k < n {
249            let (pivot_index, block_size) = self.pivots[k];
250
251            if block_size == 1 {
252                // Right-multiply by the permutation: swap the affected columns.
253                l_permuted.swap_columns(k, pivot_index);
254
255                // Right-multiply by the unit-lower factor for this 1x1 step.
256                for row in 0..n {
257                    for i in (k + 1)..n {
258                        l_permuted[(row, k)] =
259                            l_permuted[(row, k)] + l_permuted[(row, i)] * self.matrix[(i, k)];
260                    }
261                }
262
263                k += 1;
264            } else {
265                // Right-multiply by the permutation: swap the affected columns.
266                l_permuted.swap_columns(k + 1, pivot_index);
267
268                // Right-multiply by the unit-lower factor for this 2x2 step.
269                for row in 0..n {
270                    for i in (k + 2)..n {
271                        l_permuted[(row, k)] =
272                            l_permuted[(row, k)] + l_permuted[(row, i)] * self.matrix[(i, k)];
273                        l_permuted[(row, k + 1)] = l_permuted[(row, k + 1)]
274                            + l_permuted[(row, i)] * self.matrix[(i, k + 1)];
275                    }
276                }
277
278                k += 2;
279            }
280        }
281
282        l_permuted
283    }
284
285    /// The block diagonal matrix of this decomposition.
286    pub fn d(&self) -> OMatrix<T, N, N> {
287        let n = self.matrix.nrows();
288        let (nrows, ncols) = self.matrix.shape_generic();
289        let mut d = OMatrix::zeros_generic(nrows, ncols);
290
291        let mut k = 0;
292        while k < n {
293            d[(k, k)] = self.matrix[(k, k)];
294            if self.pivots[k].1 == 2 {
295                d[(k + 1, k)] = self.matrix[(k + 1, k)];
296                d[(k, k + 1)] = self.matrix[(k + 1, k)].conjugate();
297                d[(k + 1, k + 1)] = self.matrix[(k + 1, k + 1)];
298                k += 1;
299            }
300            k += 1;
301        }
302
303        d
304    }
305
306    /// Solves the linear system A * x = b using this factorization.
307    pub fn solve<M: Dim, S>(&self, b: &Matrix<T, N, M, S>) -> Option<OMatrix<T, N, M>>
308    where
309        S: Storage<T, N, M>,
310        DefaultAllocator: Allocator<N, M>,
311    {
312        let mut result = b.clone_owned();
313
314        if self.solve_mut(&mut result) {
315            Some(result)
316        } else {
317            None
318        }
319    }
320
321    /// Solves the linear system A * x = b in place, overwriting `b` with the solution.
322    pub fn solve_mut<M: Dim>(&self, b: &mut OMatrix<T, N, M>) -> bool
323    where
324        DefaultAllocator: Allocator<N, M>,
325    {
326        assert_eq!(self.matrix.nrows(), b.nrows());
327
328        if self.zero_pivot.is_some() {
329            return false;
330        }
331
332        let (n, m) = b.shape();
333
334        // Solve L * y = P^T * b using the stored pivot sequence and multipliers.
335        let mut k = 0;
336        while k < n {
337            let (pivot_index, block_size) = self.pivots[k];
338
339            if block_size == 1 {
340                b.swap_rows(k, pivot_index);
341
342                for j in 0..m {
343                    for i in (k + 1)..n {
344                        b[(i, j)] = b[(i, j)] - self.matrix[(i, k)] * b[(k, j)];
345                    }
346                }
347
348                k += 1;
349            } else {
350                b.swap_rows(k + 1, pivot_index);
351
352                for j in 0..m {
353                    for i in (k + 2)..n {
354                        b[(i, j)] = b[(i, j)]
355                            - self.matrix[(i, k)] * b[(k, j)]
356                            - self.matrix[(i, k + 1)] * b[(k + 1, j)];
357                    }
358                }
359
360                k += 2;
361            }
362        }
363
364        // Solve D * z = y, handling 1x1 and 2x2 diagonal blocks.
365        let mut k = 0;
366        while k < n {
367            if self.pivots[k].1 == 1 {
368                for j in 0..m {
369                    b[(k, j)] = b[(k, j)].unscale(self.matrix[(k, k)].real());
370                }
371                k += 1;
372            } else {
373                let d11 = self.matrix[(k, k)].real();
374                let d22 = self.matrix[(k + 1, k + 1)].real();
375                let d21 = self.matrix[(k + 1, k)];
376
377                let det = d11 * d22 - d21.modulus_squared();
378
379                for j in 0..m {
380                    let b_k = b[(k, j)];
381                    let b_k1 = b[(k + 1, j)];
382
383                    b[(k, j)] = (b_k.scale(d22) - b_k1 * d21.conjugate()).unscale(det);
384                    b[(k + 1, j)] = (b_k1.scale(d11) - b_k * d21).unscale(det);
385                }
386                k += 2;
387            }
388        }
389
390        // Solve L^H * x = z, undoing the pivot sequence in reverse order.
391        let mut k = n;
392        while k > 0 {
393            let k1 = k - 1;
394
395            for j in 0..m {
396                for i in k..n {
397                    b[(k1, j)] = b[(k1, j)] - self.matrix[(i, k1)].conjugate() * b[(i, j)];
398                }
399            }
400
401            if self.pivots[k1].1 == 1 {
402                k -= 1;
403            } else {
404                let k2 = k - 2;
405                for j in 0..m {
406                    for i in k..n {
407                        b[(k2, j)] = b[(k2, j)] - self.matrix[(i, k2)].conjugate() * b[(i, j)];
408                    }
409                }
410                k -= 2;
411            }
412
413            b.swap_rows(k1, self.pivots[k1].0);
414        }
415
416        true
417    }
418
419    /// Computes the determinant of the decomposed matrix.
420    pub fn determinant(&self) -> T::RealField {
421        let n = self.matrix.nrows();
422        let mut determinant = T::RealField::one();
423
424        let mut k = 0;
425        while k < n {
426            if self.pivots[k].1 == 1 {
427                determinant *= self.matrix[(k, k)].real();
428                k += 1;
429            } else {
430                determinant *= self.matrix[(k, k)].real() * self.matrix[(k + 1, k + 1)].real()
431                    - self.matrix[(k + 1, k)].modulus_squared();
432                k += 2;
433            }
434        }
435
436        determinant
437    }
438}
439
#[cfg(test)]
mod tests {
    use crate::{DMatrix, DVector};

    use super::*;

    #[test]
    fn zero_matrix() {
        for n in 1..=5 {
            let lblt = DMatrix::<f64>::zeros(n, n).lblt();
            assert_eq!(lblt.l_permuted(), DMatrix::identity(n, n));
            assert_eq!(lblt.d(), DMatrix::zeros(n, n));
            assert_eq!(lblt.zero_pivot, Some(0));
            assert_eq!(lblt.pivots, DVector::from_fn(n, |i, _| (i, 1)));
            assert!(lblt.determinant().is_zero());
            assert!(lblt.solve(&DVector::from_element(n, 1.0)).is_none());
        }
    }

    #[test]
    fn identity_matrix() {
        for n in 1..=5 {
            let identity = DMatrix::<f64>::identity(n, n);
            let lblt = identity.clone().lblt();

            assert_eq!(lblt.l_permuted(), identity);
            assert_eq!(lblt.d(), identity);
            assert_eq!(lblt.zero_pivot, None);
            assert_eq!(lblt.pivots, DVector::from_fn(n, |i, _| (i, 1)));
            assert!(lblt.determinant().is_one());
        }
    }

    #[test]
    fn exchange_matrix() {
        for n in 1..=15 {
            let exchange = DMatrix::from_fn(n, n, |i, j| if i + j + 1 == n { 1.0 } else { 0.0 });
            let lblt = exchange.clone().lblt();

            let mut expected = Vec::with_capacity(n);
            let m = (n + 2) / 4;
            for r in 0..m {
                let pivot = n - 2 * r - 1;
                expected.push((pivot, 2));
                expected.push((pivot, 2));
            }
            if !n.is_multiple_of(2) {
                expected.push((2 * m, 1));
            }

            for r in m..(n / 2) {
                let pivot = 2 * r + n % 2 + 1;
                expected.push((pivot, 2));
                expected.push((pivot, 2));
            }

            let l_permuted = lblt.l_permuted();
            let reconstruction = &l_permuted * lblt.d() * l_permuted.adjoint();

            assert_eq!(exchange, reconstruction);
            assert_eq!(lblt.pivots.as_slice(), expected);
        }
    }

    /// A nonsingular indefinite matrix with a zero diagonal. The first step must choose a
    /// 2x2 pivot with an interchange, followed by a rank-2 trailing update. This exercises
    /// `solve` and `determinant` on a matrix where they must succeed.
    #[test]
    fn indefinite_matrix_with_2x2_pivot() {
        let a = DMatrix::from_row_slice(3, 3, &[0.0, 1.0, 2.0, 1.0, 0.0, 3.0, 2.0, 3.0, 0.0]);
        let lblt = a.clone().lblt();

        assert_eq!(lblt.zero_pivot, None);
        assert_eq!(lblt.pivots.as_slice(), &[(2, 2), (2, 2), (2, 1)][..]);

        let l_permuted = lblt.l_permuted();
        let reconstruction = &l_permuted * lblt.d() * l_permuted.adjoint();
        assert!((reconstruction - &a).norm() < 1e-12);

        // det(A) = 0 * (0 - 9) - 1 * (0 - 6) + 2 * (3 - 0) = 12
        assert!((lblt.determinant() - 12.0).abs() < 1e-12);

        let expected = DVector::from_vec(vec![1.0, -2.0, 3.0]);
        let b = &a * &expected;
        let x = lblt.solve(&b).expect("matrix is nonsingular");
        assert!((x - expected).norm() < 1e-12);
    }
}
503}