1use crate::{
2 ComplexField, DefaultAllocator, Dim, Matrix, OMatrix, OVector, RealField, Storage, U1,
3 allocator::Allocator, convert,
4};
5use num_traits::{One, Zero};
6
7#[cfg(feature = "serde-serialize-no-std")]
8use serde::{Deserialize, Serialize};
9
/// Pivot record stored for every row of an [`LBLT`] decomposition: `(pivot_index, block_size)`.
///
/// `pivot_index` is the row/column that was interchanged at that step. `block_size` (1 or 2) is
/// the size of the diagonal block of `D` the row belongs to. Both rows of a 2×2 block carry the
/// same record.
type Pivot = (usize, usize);
12
/// Bunch–Kaufman decomposition `A = (P·L)·D·(P·L)ᴴ` of a Hermitian (or real symmetric) matrix.
///
/// `L` is unit lower triangular, `P` a product of symmetric row/column interchanges, and `D` is
/// block diagonal with 1×1 and Hermitian 2×2 blocks. The factors are stored packed in a single
/// square matrix, in the layout used by LAPACK's `?hetf2` (lower variant).
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<N, N>,
    OMatrix<T, N, N>: Serialize,
    OVector<Pivot, N>: Serialize,
    Option<usize>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<N, N>,
    OMatrix<T, N, N>: Deserialize<'de>,
    OVector<Pivot, N>: Deserialize<'de>,
    Option<usize>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct LBLT<T: ComplexField, N: Dim>
where
    DefaultAllocator: Allocator<N> + Allocator<N, N>,
{
    // Packed factors. The diagonal, plus the subdiagonal entry of each 2×2 block, hold `D`;
    // the remaining strictly-lower entries hold the multipliers of `L`. The strictly upper
    // triangle is never written and keeps whatever the input contained.
    matrix: OMatrix<T, N, N>,
    // One `(pivot_index, block_size)` record per row; see `Pivot`.
    pivots: OVector<Pivot, N>,
    // Index of the first step whose pivot column was exactly zero, if any.
    zero_pivot: Option<usize>,
}
39
40impl<T: Copy + ComplexField, N: Dim> LBLT<T, N>
41where
42 T::RealField: Copy,
43 DefaultAllocator: Allocator<N> + Allocator<N, N>,
44{
45 pub fn new(mut matrix: OMatrix<T, N, N>) -> Self {
57 assert!(matrix.is_square());
58 let n = matrix.nrows();
59
60 let mut pivots = OVector::from_element_generic(matrix.shape_generic().0, U1, (0, 0));
61 let mut zero_pivot = None;
62
63 let alpha: T::RealField = convert(0.6403882032022076);
65
66 let mut k = 0;
68
69 while k < n {
70 let mut block_size = 1;
71
72 matrix[(k, k)] = T::from_real(matrix[(k, k)].real());
74 let diag_abs = matrix[(k, k)].real().abs();
75
76 let (imax, colmax) = if k + 1 < n {
79 let mut imax = k + 1;
80 let mut colmax = matrix[(imax, k)].norm1();
81 for i in (k + 2)..n {
82 let magnitude = matrix[(i, k)].norm1();
83 if magnitude > colmax {
84 imax = i;
85 colmax = magnitude;
86 }
87 }
88 (imax, colmax)
89 } else {
90 (0, T::RealField::zero())
92 };
93
94 if diag_abs.max(colmax) == T::RealField::zero() {
95 if zero_pivot.is_none() {
97 zero_pivot = Some(k);
98 }
99
100 pivots[k] = (k, 1);
101 k += 1;
102 continue;
103 }
104
105 let pivot_index: usize;
106
107 if diag_abs < alpha * colmax {
108 let mut rowmax = T::RealField::zero();
109 for j in k..imax {
110 rowmax = rowmax.max(matrix[(imax, j)].norm1());
111 }
112 for j in (imax + 1)..matrix.nrows() {
113 rowmax = rowmax.max(matrix[(j, imax)].norm1());
114 }
115
116 if diag_abs >= alpha * colmax * (colmax / rowmax) {
117 pivot_index = k;
120 } else {
121 pivot_index = imax;
122
123 if matrix[(imax, imax)].real().abs() < alpha * rowmax {
124 block_size = 2;
127 }
128 }
129 } else {
130 pivot_index = k;
133 }
134
135 let pivot_target = k + block_size - 1;
136
137 if pivot_index != pivot_target {
138 for i in (pivot_index + 1)..matrix.nrows() {
140 matrix.swap((i, pivot_target), (i, pivot_index));
142 }
143
144 for j in (pivot_target + 1)..pivot_index {
145 matrix.swap((j, pivot_target), (pivot_index, j));
147 matrix[(j, pivot_target)] = matrix[(j, pivot_target)].conjugate();
148 matrix[(pivot_index, j)] = matrix[(pivot_index, j)].conjugate();
149 }
150
151 matrix[(pivot_index, pivot_target)] =
153 matrix[(pivot_index, pivot_target)].conjugate();
154
155 matrix.swap((pivot_target, pivot_target), (pivot_index, pivot_index));
157
158 if k + 1 == pivot_target {
159 matrix.swap((k + 1, k), (pivot_index, k));
161 }
162 }
163
164 if block_size == 1 {
165 if k + 1 < n {
167 let inv_diag = T::RealField::one() / matrix[(k, k)].real();
168
169 for j in (k + 1)..n {
170 let jk_conj = matrix[(j, k)].conjugate();
171
172 for i in j..n {
173 matrix[(i, j)] =
175 matrix[(i, j)] + matrix[(i, k)].scale(-inv_diag) * jk_conj;
176 }
177
178 matrix[(j, j)] = T::from_real(matrix[(j, j)].real());
180 }
181
182 for i in (k + 1)..n {
184 matrix[(i, k)] = matrix[(i, k)].scale(inv_diag);
185 }
186 }
187
188 pivots[k] = (pivot_index, 1);
189 } else {
190 if k + 2 < n {
192 let d = matrix[(k + 1, k)].abs();
194 let d11 = matrix[(k + 1, k + 1)].real() / d;
195 let d22 = matrix[(k, k)].real() / d;
196 let d21 = matrix[(k + 1, k)].unscale(d);
197 let scale = T::RealField::one() / (d * (d11 * d22 - T::RealField::one()));
198
199 for j in (k + 2)..n {
200 let work1 =
203 (matrix[(j, k)].scale(d11) - matrix[(j, k + 1)] * d21).scale(scale);
204 let work2 = (matrix[(j, k + 1)].scale(d22)
205 - matrix[(j, k)] * d21.conjugate())
206 .scale(scale);
207
208 for i in j..n {
209 matrix[(i, j)] = matrix[(i, j)]
211 - matrix[(i, k)] * work1.conjugate()
212 - matrix[(i, k + 1)] * work2.conjugate();
213 }
214
215 matrix[(j, k)] = work1;
216 matrix[(j, k + 1)] = work2;
217
218 matrix[(j, j)] = T::from_real(matrix[(j, j)].real());
220 }
221 }
222
223 pivots[k] = (pivot_index, 2);
224 pivots[k + 1] = (pivot_index, 2);
225 }
226
227 k += block_size;
228 }
229
230 Self {
231 matrix,
232 pivots,
233 zero_pivot,
234 }
235 }
236
237 pub fn l_permuted(&self) -> OMatrix<T, N, N> {
243 let n = self.matrix.nrows();
244 let (nrows, ncols) = self.matrix.shape_generic();
245 let mut l_permuted = OMatrix::identity_generic(nrows, ncols);
246
247 let mut k = 0;
248 while k < n {
249 let (pivot_index, block_size) = self.pivots[k];
250
251 if block_size == 1 {
252 l_permuted.swap_columns(k, pivot_index);
254
255 for row in 0..n {
257 for i in (k + 1)..n {
258 l_permuted[(row, k)] =
259 l_permuted[(row, k)] + l_permuted[(row, i)] * self.matrix[(i, k)];
260 }
261 }
262
263 k += 1;
264 } else {
265 l_permuted.swap_columns(k + 1, pivot_index);
267
268 for row in 0..n {
270 for i in (k + 2)..n {
271 l_permuted[(row, k)] =
272 l_permuted[(row, k)] + l_permuted[(row, i)] * self.matrix[(i, k)];
273 l_permuted[(row, k + 1)] = l_permuted[(row, k + 1)]
274 + l_permuted[(row, i)] * self.matrix[(i, k + 1)];
275 }
276 }
277
278 k += 2;
279 }
280 }
281
282 l_permuted
283 }
284
285 pub fn d(&self) -> OMatrix<T, N, N> {
287 let n = self.matrix.nrows();
288 let (nrows, ncols) = self.matrix.shape_generic();
289 let mut d = OMatrix::zeros_generic(nrows, ncols);
290
291 let mut k = 0;
292 while k < n {
293 d[(k, k)] = self.matrix[(k, k)];
294 if self.pivots[k].1 == 2 {
295 d[(k + 1, k)] = self.matrix[(k + 1, k)];
296 d[(k, k + 1)] = self.matrix[(k + 1, k)].conjugate();
297 d[(k + 1, k + 1)] = self.matrix[(k + 1, k + 1)];
298 k += 1;
299 }
300 k += 1;
301 }
302
303 d
304 }
305
306 pub fn solve<M: Dim, S>(&self, b: &Matrix<T, N, M, S>) -> Option<OMatrix<T, N, M>>
308 where
309 S: Storage<T, N, M>,
310 DefaultAllocator: Allocator<N, M>,
311 {
312 let mut result = b.clone_owned();
313
314 if self.solve_mut(&mut result) {
315 Some(result)
316 } else {
317 None
318 }
319 }
320
321 pub fn solve_mut<M: Dim>(&self, b: &mut OMatrix<T, N, M>) -> bool
323 where
324 DefaultAllocator: Allocator<N, M>,
325 {
326 assert_eq!(self.matrix.nrows(), b.nrows());
327
328 if self.zero_pivot.is_some() {
329 return false;
330 }
331
332 let (n, m) = b.shape();
333
334 let mut k = 0;
336 while k < n {
337 let (pivot_index, block_size) = self.pivots[k];
338
339 if block_size == 1 {
340 b.swap_rows(k, pivot_index);
341
342 for j in 0..m {
343 for i in (k + 1)..n {
344 b[(i, j)] = b[(i, j)] - self.matrix[(i, k)] * b[(k, j)];
345 }
346 }
347
348 k += 1;
349 } else {
350 b.swap_rows(k + 1, pivot_index);
351
352 for j in 0..m {
353 for i in (k + 2)..n {
354 b[(i, j)] = b[(i, j)]
355 - self.matrix[(i, k)] * b[(k, j)]
356 - self.matrix[(i, k + 1)] * b[(k + 1, j)];
357 }
358 }
359
360 k += 2;
361 }
362 }
363
364 let mut k = 0;
366 while k < n {
367 if self.pivots[k].1 == 1 {
368 for j in 0..m {
369 b[(k, j)] = b[(k, j)].unscale(self.matrix[(k, k)].real());
370 }
371 k += 1;
372 } else {
373 let d11 = self.matrix[(k, k)].real();
374 let d22 = self.matrix[(k + 1, k + 1)].real();
375 let d21 = self.matrix[(k + 1, k)];
376
377 let det = d11 * d22 - d21.modulus_squared();
378
379 for j in 0..m {
380 let b_k = b[(k, j)];
381 let b_k1 = b[(k + 1, j)];
382
383 b[(k, j)] = (b_k.scale(d22) - b_k1 * d21.conjugate()).unscale(det);
384 b[(k + 1, j)] = (b_k1.scale(d11) - b_k * d21).unscale(det);
385 }
386 k += 2;
387 }
388 }
389
390 let mut k = n;
392 while k > 0 {
393 let k1 = k - 1;
394
395 for j in 0..m {
396 for i in k..n {
397 b[(k1, j)] = b[(k1, j)] - self.matrix[(i, k1)].conjugate() * b[(i, j)];
398 }
399 }
400
401 if self.pivots[k1].1 == 1 {
402 k -= 1;
403 } else {
404 let k2 = k - 2;
405 for j in 0..m {
406 for i in k..n {
407 b[(k2, j)] = b[(k2, j)] - self.matrix[(i, k2)].conjugate() * b[(i, j)];
408 }
409 }
410 k -= 2;
411 }
412
413 b.swap_rows(k1, self.pivots[k1].0);
414 }
415
416 true
417 }
418
419 pub fn determinant(&self) -> T::RealField {
421 let n = self.matrix.nrows();
422 let mut determinant = T::RealField::one();
423
424 let mut k = 0;
425 while k < n {
426 if self.pivots[k].1 == 1 {
427 determinant *= self.matrix[(k, k)].real();
428 k += 1;
429 } else {
430 determinant *= self.matrix[(k, k)].real() * self.matrix[(k + 1, k + 1)].real()
431 - self.matrix[(k + 1, k)].modulus_squared();
432 k += 2;
433 }
434 }
435
436 determinant
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DMatrix, DVector};

    /// Pivot records of a factorization that never interchanges and uses only 1×1 blocks.
    fn unpivoted(size: usize) -> DVector<Pivot> {
        DVector::from_fn(size, |row, _| (row, 1))
    }

    /// The zero matrix is singular from the first step on; `solve` must refuse it.
    #[test]
    fn zero_matrix() {
        for size in 1..=5 {
            let decomposition = DMatrix::<f64>::zeros(size, size).lblt();
            let rhs = DVector::from_element(size, 1.0);

            assert_eq!(decomposition.l_permuted(), DMatrix::identity(size, size));
            assert_eq!(decomposition.d(), DMatrix::zeros(size, size));
            assert_eq!(decomposition.zero_pivot, Some(0));
            assert_eq!(decomposition.pivots, unpivoted(size));
            assert!(decomposition.determinant().is_zero());
            assert!(decomposition.solve(&rhs).is_none());
        }
    }

    /// The identity factors trivially: `P·L = D = I` with no interchanges.
    #[test]
    fn identity_matrix() {
        for size in 1..=5 {
            let eye = DMatrix::<f64>::identity(size, size);
            let decomposition = eye.clone().lblt();

            assert_eq!(decomposition.l_permuted(), eye);
            assert_eq!(decomposition.d(), eye);
            assert_eq!(decomposition.zero_pivot, None);
            assert_eq!(decomposition.pivots, unpivoted(size));
            assert!(decomposition.determinant().is_one());
        }
    }

    /// The exchange (anti-identity) matrix has a zero diagonal. It is factored with 2×2
    /// blocks, plus one central 1×1 block when the size is odd.
    #[test]
    fn exchange_matrix() {
        for size in 1..=15 {
            let exchange =
                DMatrix::from_fn(size, size, |i, j| if i + j + 1 == size { 1.0 } else { 0.0 });
            let decomposition = exchange.clone().lblt();

            let two_by_two = |pivot: usize| [(pivot, 2), (pivot, 2)];
            let leading = (size + 2) / 4;

            let mut expected: Vec<Pivot> = (0..leading)
                .flat_map(|r| two_by_two(size - 2 * r - 1))
                .collect();
            if !size.is_multiple_of(2) {
                expected.push((2 * leading, 1));
            }
            expected.extend((leading..size / 2).flat_map(|r| two_by_two(2 * r + size % 2 + 1)));

            let l_permuted = decomposition.l_permuted();
            let reconstruction = &l_permuted * decomposition.d() * l_permuted.adjoint();

            assert_eq!(exchange, reconstruction);
            assert_eq!(decomposition.pivots.as_slice(), expected);
        }
    }
}