1use simba::scalar::ComplexField;
2use simba::simd::SimdComplexField;
3
4use crate::base::allocator::Allocator;
5use crate::base::constraint::{SameNumberOfRows, ShapeConstraint};
6use crate::base::dimension::{Dim, U1};
7use crate::base::storage::{Storage, StorageMut};
8use crate::base::{DVectorView, DefaultAllocator, Matrix, OMatrix, SquareMatrix, Vector};
9
10impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S> {
11 #[must_use = "Did you mean to use solve_lower_triangular_mut()?"]
14 #[inline]
15 pub fn solve_lower_triangular<R2: Dim, C2: Dim, S2>(
16 &self,
17 b: &Matrix<T, R2, C2, S2>,
18 ) -> Option<OMatrix<T, R2, C2>>
19 where
20 S2: Storage<T, R2, C2>,
21 DefaultAllocator: Allocator<R2, C2>,
22 ShapeConstraint: SameNumberOfRows<R2, D>,
23 {
24 let mut res = b.clone_owned();
25 if self.solve_lower_triangular_mut(&mut res) {
26 Some(res)
27 } else {
28 None
29 }
30 }
31
32 #[must_use = "Did you mean to use solve_upper_triangular_mut()?"]
35 #[inline]
36 pub fn solve_upper_triangular<R2: Dim, C2: Dim, S2>(
37 &self,
38 b: &Matrix<T, R2, C2, S2>,
39 ) -> Option<OMatrix<T, R2, C2>>
40 where
41 S2: Storage<T, R2, C2>,
42 DefaultAllocator: Allocator<R2, C2>,
43 ShapeConstraint: SameNumberOfRows<R2, D>,
44 {
45 let mut res = b.clone_owned();
46 if self.solve_upper_triangular_mut(&mut res) {
47 Some(res)
48 } else {
49 None
50 }
51 }
52
53 pub fn solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
56 &self,
57 b: &mut Matrix<T, R2, C2, S2>,
58 ) -> bool
59 where
60 S2: StorageMut<T, R2, C2>,
61 ShapeConstraint: SameNumberOfRows<R2, D>,
62 {
63 let cols = b.ncols();
64
65 for i in 0..cols {
66 if !self.solve_lower_triangular_vector_mut(&mut b.column_mut(i)) {
67 return false;
68 }
69 }
70
71 true
72 }
73
74 fn solve_lower_triangular_vector_mut<R2: Dim, S2>(&self, b: &mut Vector<T, R2, S2>) -> bool
75 where
76 S2: StorageMut<T, R2, U1>,
77 ShapeConstraint: SameNumberOfRows<R2, D>,
78 {
79 let dim = self.nrows();
80
81 for i in 0..dim {
82 let coeff;
83
84 unsafe {
85 let diag = self.get_unchecked((i, i)).clone();
86
87 if diag.is_zero() {
88 return false;
89 }
90
91 coeff = b.vget_unchecked(i).clone() / diag;
92 *b.vget_unchecked_mut(i) = coeff.clone();
93 }
94
95 b.rows_range_mut(i + 1..)
96 .axpy(-coeff, &self.view_range(i + 1.., i), T::one());
97 }
98
99 true
100 }
101
102 pub fn solve_lower_triangular_with_diag_mut<R2: Dim, C2: Dim, S2>(
107 &self,
108 b: &mut Matrix<T, R2, C2, S2>,
109 diag: T,
110 ) -> bool
111 where
112 S2: StorageMut<T, R2, C2>,
113 ShapeConstraint: SameNumberOfRows<R2, D>,
114 {
115 if diag.is_zero() {
116 return false;
117 }
118
119 let dim = self.nrows();
120 let cols = b.ncols();
121
122 for k in 0..cols {
123 let mut bcol = b.column_mut(k);
124
125 for i in 0..dim - 1 {
126 let coeff = unsafe { bcol.vget_unchecked(i).clone() } / diag.clone();
127 bcol.rows_range_mut(i + 1..)
128 .axpy(-coeff, &self.view_range(i + 1.., i), T::one());
129 }
130 }
131
132 true
133 }
134
135 pub fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
138 &self,
139 b: &mut Matrix<T, R2, C2, S2>,
140 ) -> bool
141 where
142 S2: StorageMut<T, R2, C2>,
143 ShapeConstraint: SameNumberOfRows<R2, D>,
144 {
145 let cols = b.ncols();
146
147 for i in 0..cols {
148 if !self.solve_upper_triangular_vector_mut(&mut b.column_mut(i)) {
149 return false;
150 }
151 }
152
153 true
154 }
155
156 fn solve_upper_triangular_vector_mut<R2: Dim, S2>(&self, b: &mut Vector<T, R2, S2>) -> bool
157 where
158 S2: StorageMut<T, R2, U1>,
159 ShapeConstraint: SameNumberOfRows<R2, D>,
160 {
161 let dim = self.nrows();
162
163 for i in (0..dim).rev() {
164 let coeff;
165
166 unsafe {
167 let diag = self.get_unchecked((i, i)).clone();
168
169 if diag.is_zero() {
170 return false;
171 }
172
173 coeff = b.vget_unchecked(i).clone() / diag;
174 *b.vget_unchecked_mut(i) = coeff.clone();
175 }
176
177 b.rows_range_mut(..i)
178 .axpy(-coeff, &self.view_range(..i, i), T::one());
179 }
180
181 true
182 }
183
184 #[must_use = "Did you mean to use tr_solve_lower_triangular_mut()?"]
192 #[inline]
193 pub fn tr_solve_lower_triangular<R2: Dim, C2: Dim, S2>(
194 &self,
195 b: &Matrix<T, R2, C2, S2>,
196 ) -> Option<OMatrix<T, R2, C2>>
197 where
198 S2: Storage<T, R2, C2>,
199 DefaultAllocator: Allocator<R2, C2>,
200 ShapeConstraint: SameNumberOfRows<R2, D>,
201 {
202 let mut res = b.clone_owned();
203 if self.tr_solve_lower_triangular_mut(&mut res) {
204 Some(res)
205 } else {
206 None
207 }
208 }
209
210 #[must_use = "Did you mean to use tr_solve_upper_triangular_mut()?"]
213 #[inline]
214 pub fn tr_solve_upper_triangular<R2: Dim, C2: Dim, S2>(
215 &self,
216 b: &Matrix<T, R2, C2, S2>,
217 ) -> Option<OMatrix<T, R2, C2>>
218 where
219 S2: Storage<T, R2, C2>,
220 DefaultAllocator: Allocator<R2, C2>,
221 ShapeConstraint: SameNumberOfRows<R2, D>,
222 {
223 let mut res = b.clone_owned();
224 if self.tr_solve_upper_triangular_mut(&mut res) {
225 Some(res)
226 } else {
227 None
228 }
229 }
230
231 pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
234 &self,
235 b: &mut Matrix<T, R2, C2, S2>,
236 ) -> bool
237 where
238 S2: StorageMut<T, R2, C2>,
239 ShapeConstraint: SameNumberOfRows<R2, D>,
240 {
241 let cols = b.ncols();
242
243 for i in 0..cols {
244 if !self.xx_solve_lower_triangular_vector_mut(
245 &mut b.column_mut(i),
246 |e| e,
247 |a, b| a.dot(b),
248 ) {
249 return false;
250 }
251 }
252
253 true
254 }
255
256 pub fn tr_solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
259 &self,
260 b: &mut Matrix<T, R2, C2, S2>,
261 ) -> bool
262 where
263 S2: StorageMut<T, R2, C2>,
264 ShapeConstraint: SameNumberOfRows<R2, D>,
265 {
266 let cols = b.ncols();
267
268 for i in 0..cols {
269 if !self.xx_solve_upper_triangular_vector_mut(
270 &mut b.column_mut(i),
271 |e| e,
272 |a, b| a.dot(b),
273 ) {
274 return false;
275 }
276 }
277
278 true
279 }
280
281 #[must_use = "Did you mean to use ad_solve_lower_triangular_mut()?"]
284 #[inline]
285 pub fn ad_solve_lower_triangular<R2: Dim, C2: Dim, S2>(
286 &self,
287 b: &Matrix<T, R2, C2, S2>,
288 ) -> Option<OMatrix<T, R2, C2>>
289 where
290 S2: Storage<T, R2, C2>,
291 DefaultAllocator: Allocator<R2, C2>,
292 ShapeConstraint: SameNumberOfRows<R2, D>,
293 {
294 let mut res = b.clone_owned();
295 if self.ad_solve_lower_triangular_mut(&mut res) {
296 Some(res)
297 } else {
298 None
299 }
300 }
301
302 #[must_use = "Did you mean to use ad_solve_upper_triangular_mut()?"]
305 #[inline]
306 pub fn ad_solve_upper_triangular<R2: Dim, C2: Dim, S2>(
307 &self,
308 b: &Matrix<T, R2, C2, S2>,
309 ) -> Option<OMatrix<T, R2, C2>>
310 where
311 S2: Storage<T, R2, C2>,
312 DefaultAllocator: Allocator<R2, C2>,
313 ShapeConstraint: SameNumberOfRows<R2, D>,
314 {
315 let mut res = b.clone_owned();
316 if self.ad_solve_upper_triangular_mut(&mut res) {
317 Some(res)
318 } else {
319 None
320 }
321 }
322
323 pub fn ad_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
326 &self,
327 b: &mut Matrix<T, R2, C2, S2>,
328 ) -> bool
329 where
330 S2: StorageMut<T, R2, C2>,
331 ShapeConstraint: SameNumberOfRows<R2, D>,
332 {
333 let cols = b.ncols();
334
335 for i in 0..cols {
336 if !self.xx_solve_lower_triangular_vector_mut(
337 &mut b.column_mut(i),
338 |e| e.conjugate(),
339 |a, b| a.dotc(b),
340 ) {
341 return false;
342 }
343 }
344
345 true
346 }
347
348 pub fn ad_solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
351 &self,
352 b: &mut Matrix<T, R2, C2, S2>,
353 ) -> bool
354 where
355 S2: StorageMut<T, R2, C2>,
356 ShapeConstraint: SameNumberOfRows<R2, D>,
357 {
358 let cols = b.ncols();
359
360 for i in 0..cols {
361 if !self.xx_solve_upper_triangular_vector_mut(
362 &mut b.column_mut(i),
363 |e| e.conjugate(),
364 |a, b| a.dotc(b),
365 ) {
366 return false;
367 }
368 }
369
370 true
371 }
372
373 #[inline(always)]
374 fn xx_solve_lower_triangular_vector_mut<R2: Dim, S2>(
375 &self,
376 b: &mut Vector<T, R2, S2>,
377 conjugate: impl Fn(T) -> T,
378 dot: impl Fn(
379 &DVectorView<'_, T, S::RStride, S::CStride>,
380 &DVectorView<'_, T, S2::RStride, S2::CStride>,
381 ) -> T,
382 ) -> bool
383 where
384 S2: StorageMut<T, R2, U1>,
385 ShapeConstraint: SameNumberOfRows<R2, D>,
386 {
387 let dim = self.nrows();
388
389 for i in (0..dim).rev() {
390 let dot = dot(&self.view_range(i + 1.., i), &b.view_range(i + 1.., 0));
391
392 unsafe {
393 let b_i = b.vget_unchecked_mut(i);
394
395 let diag = conjugate(self.get_unchecked((i, i)).clone());
396
397 if diag.is_zero() {
398 return false;
399 }
400
401 *b_i = (b_i.clone() - dot) / diag;
402 }
403 }
404
405 true
406 }
407
408 #[inline(always)]
409 fn xx_solve_upper_triangular_vector_mut<R2: Dim, S2>(
410 &self,
411 b: &mut Vector<T, R2, S2>,
412 conjugate: impl Fn(T) -> T,
413 dot: impl Fn(
414 &DVectorView<'_, T, S::RStride, S::CStride>,
415 &DVectorView<'_, T, S2::RStride, S2::CStride>,
416 ) -> T,
417 ) -> bool
418 where
419 S2: StorageMut<T, R2, U1>,
420 ShapeConstraint: SameNumberOfRows<R2, D>,
421 {
422 let dim = self.nrows();
423
424 for i in 0..dim {
425 let dot = dot(&self.view_range(..i, i), &b.view_range(..i, 0));
426
427 unsafe {
428 let b_i = b.vget_unchecked_mut(i);
429 let diag = conjugate(self.get_unchecked((i, i)).clone());
430
431 if diag.is_zero() {
432 return false;
433 }
434
435 *b_i = (b_i.clone() - dot) / diag;
436 }
437 }
438
439 true
440 }
441}
442
443impl<T: SimdComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S> {
450 #[must_use = "Did you mean to use solve_lower_triangular_unchecked_mut()?"]
453 #[inline]
454 pub fn solve_lower_triangular_unchecked<R2: Dim, C2: Dim, S2>(
455 &self,
456 b: &Matrix<T, R2, C2, S2>,
457 ) -> OMatrix<T, R2, C2>
458 where
459 S2: Storage<T, R2, C2>,
460 DefaultAllocator: Allocator<R2, C2>,
461 ShapeConstraint: SameNumberOfRows<R2, D>,
462 {
463 let mut res = b.clone_owned();
464 self.solve_lower_triangular_unchecked_mut(&mut res);
465 res
466 }
467
468 #[must_use = "Did you mean to use solve_upper_triangular_unchecked_mut()?"]
471 #[inline]
472 pub fn solve_upper_triangular_unchecked<R2: Dim, C2: Dim, S2>(
473 &self,
474 b: &Matrix<T, R2, C2, S2>,
475 ) -> OMatrix<T, R2, C2>
476 where
477 S2: Storage<T, R2, C2>,
478 DefaultAllocator: Allocator<R2, C2>,
479 ShapeConstraint: SameNumberOfRows<R2, D>,
480 {
481 let mut res = b.clone_owned();
482 self.solve_upper_triangular_unchecked_mut(&mut res);
483 res
484 }
485
486 pub fn solve_lower_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
489 &self,
490 b: &mut Matrix<T, R2, C2, S2>,
491 ) where
492 S2: StorageMut<T, R2, C2>,
493 ShapeConstraint: SameNumberOfRows<R2, D>,
494 {
495 for i in 0..b.ncols() {
496 self.solve_lower_triangular_vector_unchecked_mut(&mut b.column_mut(i));
497 }
498 }
499
500 fn solve_lower_triangular_vector_unchecked_mut<R2: Dim, S2>(&self, b: &mut Vector<T, R2, S2>)
501 where
502 S2: StorageMut<T, R2, U1>,
503 ShapeConstraint: SameNumberOfRows<R2, D>,
504 {
505 let dim = self.nrows();
506
507 for i in 0..dim {
508 let coeff;
509
510 unsafe {
511 let diag = self.get_unchecked((i, i)).clone();
512 coeff = b.vget_unchecked(i).clone() / diag;
513 *b.vget_unchecked_mut(i) = coeff.clone();
514 }
515
516 b.rows_range_mut(i + 1..)
517 .axpy(-coeff.clone(), &self.view_range(i + 1.., i), T::one());
518 }
519 }
520
521 pub fn solve_lower_triangular_with_diag_unchecked_mut<R2: Dim, C2: Dim, S2>(
526 &self,
527 b: &mut Matrix<T, R2, C2, S2>,
528 diag: T,
529 ) where
530 S2: StorageMut<T, R2, C2>,
531 ShapeConstraint: SameNumberOfRows<R2, D>,
532 {
533 let dim = self.nrows();
534 let cols = b.ncols();
535
536 for k in 0..cols {
537 let mut bcol = b.column_mut(k);
538
539 for i in 0..dim - 1 {
540 let coeff = unsafe { bcol.vget_unchecked(i).clone() } / diag.clone();
541 bcol.rows_range_mut(i + 1..)
542 .axpy(-coeff, &self.view_range(i + 1.., i), T::one());
543 }
544 }
545 }
546
547 pub fn solve_upper_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
550 &self,
551 b: &mut Matrix<T, R2, C2, S2>,
552 ) where
553 S2: StorageMut<T, R2, C2>,
554 ShapeConstraint: SameNumberOfRows<R2, D>,
555 {
556 for i in 0..b.ncols() {
557 self.solve_upper_triangular_vector_unchecked_mut(&mut b.column_mut(i))
558 }
559 }
560
561 fn solve_upper_triangular_vector_unchecked_mut<R2: Dim, S2>(&self, b: &mut Vector<T, R2, S2>)
562 where
563 S2: StorageMut<T, R2, U1>,
564 ShapeConstraint: SameNumberOfRows<R2, D>,
565 {
566 let dim = self.nrows();
567
568 for i in (0..dim).rev() {
569 let coeff;
570
571 unsafe {
572 let diag = self.get_unchecked((i, i)).clone();
573 coeff = b.vget_unchecked(i).clone() / diag;
574 *b.vget_unchecked_mut(i) = coeff.clone();
575 }
576
577 b.rows_range_mut(..i)
578 .axpy(-coeff, &self.view_range(..i, i), T::one());
579 }
580 }
581
582 #[must_use = "Did you mean to use tr_solve_lower_triangular_unchecked_mut()?"]
590 #[inline]
591 pub fn tr_solve_lower_triangular_unchecked<R2: Dim, C2: Dim, S2>(
592 &self,
593 b: &Matrix<T, R2, C2, S2>,
594 ) -> OMatrix<T, R2, C2>
595 where
596 S2: Storage<T, R2, C2>,
597 DefaultAllocator: Allocator<R2, C2>,
598 ShapeConstraint: SameNumberOfRows<R2, D>,
599 {
600 let mut res = b.clone_owned();
601 self.tr_solve_lower_triangular_unchecked_mut(&mut res);
602 res
603 }
604
605 #[must_use = "Did you mean to use tr_solve_upper_triangular_unchecked_mut()?"]
608 #[inline]
609 pub fn tr_solve_upper_triangular_unchecked<R2: Dim, C2: Dim, S2>(
610 &self,
611 b: &Matrix<T, R2, C2, S2>,
612 ) -> OMatrix<T, R2, C2>
613 where
614 S2: Storage<T, R2, C2>,
615 DefaultAllocator: Allocator<R2, C2>,
616 ShapeConstraint: SameNumberOfRows<R2, D>,
617 {
618 let mut res = b.clone_owned();
619 self.tr_solve_upper_triangular_unchecked_mut(&mut res);
620 res
621 }
622
623 pub fn tr_solve_lower_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
626 &self,
627 b: &mut Matrix<T, R2, C2, S2>,
628 ) where
629 S2: StorageMut<T, R2, C2>,
630 ShapeConstraint: SameNumberOfRows<R2, D>,
631 {
632 for i in 0..b.ncols() {
633 self.xx_solve_lower_triangular_vector_unchecked_mut(
634 &mut b.column_mut(i),
635 |e| e,
636 |a, b| a.dot(b),
637 )
638 }
639 }
640
641 pub fn tr_solve_upper_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
644 &self,
645 b: &mut Matrix<T, R2, C2, S2>,
646 ) where
647 S2: StorageMut<T, R2, C2>,
648 ShapeConstraint: SameNumberOfRows<R2, D>,
649 {
650 for i in 0..b.ncols() {
651 self.xx_solve_upper_triangular_vector_unchecked_mut(
652 &mut b.column_mut(i),
653 |e| e,
654 |a, b| a.dot(b),
655 )
656 }
657 }
658
659 #[must_use = "Did you mean to use ad_solve_lower_triangular_unchecked_mut()?"]
662 #[inline]
663 pub fn ad_solve_lower_triangular_unchecked<R2: Dim, C2: Dim, S2>(
664 &self,
665 b: &Matrix<T, R2, C2, S2>,
666 ) -> OMatrix<T, R2, C2>
667 where
668 S2: Storage<T, R2, C2>,
669 DefaultAllocator: Allocator<R2, C2>,
670 ShapeConstraint: SameNumberOfRows<R2, D>,
671 {
672 let mut res = b.clone_owned();
673 self.ad_solve_lower_triangular_unchecked_mut(&mut res);
674 res
675 }
676
677 #[must_use = "Did you mean to use ad_solve_upper_triangular_unchecked_mut()?"]
680 #[inline]
681 pub fn ad_solve_upper_triangular_unchecked<R2: Dim, C2: Dim, S2>(
682 &self,
683 b: &Matrix<T, R2, C2, S2>,
684 ) -> OMatrix<T, R2, C2>
685 where
686 S2: Storage<T, R2, C2>,
687 DefaultAllocator: Allocator<R2, C2>,
688 ShapeConstraint: SameNumberOfRows<R2, D>,
689 {
690 let mut res = b.clone_owned();
691 self.ad_solve_upper_triangular_unchecked_mut(&mut res);
692 res
693 }
694
695 pub fn ad_solve_lower_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
698 &self,
699 b: &mut Matrix<T, R2, C2, S2>,
700 ) where
701 S2: StorageMut<T, R2, C2>,
702 ShapeConstraint: SameNumberOfRows<R2, D>,
703 {
704 for i in 0..b.ncols() {
705 self.xx_solve_lower_triangular_vector_unchecked_mut(
706 &mut b.column_mut(i),
707 |e| e.simd_conjugate(),
708 |a, b| a.dotc(b),
709 )
710 }
711 }
712
713 pub fn ad_solve_upper_triangular_unchecked_mut<R2: Dim, C2: Dim, S2>(
716 &self,
717 b: &mut Matrix<T, R2, C2, S2>,
718 ) where
719 S2: StorageMut<T, R2, C2>,
720 ShapeConstraint: SameNumberOfRows<R2, D>,
721 {
722 for i in 0..b.ncols() {
723 self.xx_solve_upper_triangular_vector_unchecked_mut(
724 &mut b.column_mut(i),
725 |e| e.simd_conjugate(),
726 |a, b| a.dotc(b),
727 )
728 }
729 }
730
731 #[inline(always)]
732 fn xx_solve_lower_triangular_vector_unchecked_mut<R2: Dim, S2>(
733 &self,
734 b: &mut Vector<T, R2, S2>,
735 conjugate: impl Fn(T) -> T,
736 dot: impl Fn(
737 &DVectorView<'_, T, S::RStride, S::CStride>,
738 &DVectorView<'_, T, S2::RStride, S2::CStride>,
739 ) -> T,
740 ) where
741 S2: StorageMut<T, R2, U1>,
742 ShapeConstraint: SameNumberOfRows<R2, D>,
743 {
744 let dim = self.nrows();
745
746 for i in (0..dim).rev() {
747 let dot = dot(&self.view_range(i + 1.., i), &b.view_range(i + 1.., 0));
748
749 unsafe {
750 let b_i = b.vget_unchecked_mut(i);
751 let diag = conjugate(self.get_unchecked((i, i)).clone());
752 *b_i = (b_i.clone() - dot) / diag;
753 }
754 }
755 }
756
757 #[inline(always)]
758 fn xx_solve_upper_triangular_vector_unchecked_mut<R2: Dim, S2>(
759 &self,
760 b: &mut Vector<T, R2, S2>,
761 conjugate: impl Fn(T) -> T,
762 dot: impl Fn(
763 &DVectorView<'_, T, S::RStride, S::CStride>,
764 &DVectorView<'_, T, S2::RStride, S2::CStride>,
765 ) -> T,
766 ) where
767 S2: StorageMut<T, R2, U1>,
768 ShapeConstraint: SameNumberOfRows<R2, D>,
769 {
770 for i in 0..self.nrows() {
771 let dot = dot(&self.view_range(..i, i), &b.view_range(..i, 0));
772
773 unsafe {
774 let b_i = b.vget_unchecked_mut(i);
775 let diag = conjugate(self.get_unchecked((i, i)).clone());
776 *b_i = (b_i.clone() - dot) / diag;
777 }
778 }
779 }
780}