nalgebra/base/blas.rs
1use crate::{RawStorage, SimdComplexField};
2use num::{One, Zero};
3use simba::scalar::{ClosedAddAssign, ClosedMulAssign};
4
5use crate::base::allocator::Allocator;
6use crate::base::blas_uninit::{axcpy_uninit, gemm_uninit, gemv_uninit};
7use crate::base::constraint::{
8 AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
9};
10use crate::base::dimension::{Const, Dim, Dyn, U1, U2, U3, U4};
11use crate::base::storage::{Storage, StorageMut};
12use crate::base::uninit::Init;
13use crate::base::{
14 DVectorView, DefaultAllocator, Matrix, Scalar, SquareMatrix, Vector, VectorView,
15};
16
17/// # Dot/scalar product
18impl<T, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S>
19where
20 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
21{
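    /// Shared implementation of `dot` and `dotc`: accumulates
    /// `conjugate(self[(i, j)]) * rhs[(i, j)]` over all entries, with the inner loop manually
    /// unrolled eight rows at a time so that the optimizer can vectorize it.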
22 #[inline(always)]
23 fn dotx<R2: Dim, C2: Dim, SB>(
24 &self,
25 rhs: &Matrix<T, R2, C2, SB>,
26 conjugate: impl Fn(T) -> T,
27 ) -> T
28 where
29 SB: RawStorage<T, R2, C2>,
30 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
31 {
32 assert!(
33 self.nrows() == rhs.nrows(),
34 "Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
35 self.shape(),
36 rhs.shape(),
37 );
38
39 assert!(
40 self.ncols() == rhs.ncols(),
41 "Dot product dimensions mismatch for shapes {:?} and {:?}: left cols != right cols.",
42 self.shape(),
43 rhs.shape(),
44 );
45
        // We special-case common fixed-size vectors of dimension lower than 8
        // because the `for` loop below won't be very efficient for them.
48 if (R::is::<U2>() || R2::is::<U2>()) && (C::is::<U1>() || C2::is::<U1>()) {
49 unsafe {
50 let a = conjugate(self.get_unchecked((0, 0)).clone())
51 * rhs.get_unchecked((0, 0)).clone();
52 let b = conjugate(self.get_unchecked((1, 0)).clone())
53 * rhs.get_unchecked((1, 0)).clone();
54
55 return a + b;
56 }
57 }
58 if (R::is::<U3>() || R2::is::<U3>()) && (C::is::<U1>() || C2::is::<U1>()) {
59 unsafe {
60 let a = conjugate(self.get_unchecked((0, 0)).clone())
61 * rhs.get_unchecked((0, 0)).clone();
62 let b = conjugate(self.get_unchecked((1, 0)).clone())
63 * rhs.get_unchecked((1, 0)).clone();
64 let c = conjugate(self.get_unchecked((2, 0)).clone())
65 * rhs.get_unchecked((2, 0)).clone();
66
67 return a + b + c;
68 }
69 }
70 if (R::is::<U4>() || R2::is::<U4>()) && (C::is::<U1>() || C2::is::<U1>()) {
71 unsafe {
72 let mut a = conjugate(self.get_unchecked((0, 0)).clone())
73 * rhs.get_unchecked((0, 0)).clone();
74 let mut b = conjugate(self.get_unchecked((1, 0)).clone())
75 * rhs.get_unchecked((1, 0)).clone();
76 let c = conjugate(self.get_unchecked((2, 0)).clone())
77 * rhs.get_unchecked((2, 0)).clone();
78 let d = conjugate(self.get_unchecked((3, 0)).clone())
79 * rhs.get_unchecked((3, 0)).clone();
80
81 a += c;
82 b += d;
83
84 return a + b;
85 }
86 }
87
        // All this is inspired by the "unrolled version" discussed in:
89 // https://blog.theincredibleholk.org/blog/2012/12/10/optimizing-dot-product/
90 //
91 // And this comment from bluss:
92 // https://users.rust-lang.org/t/how-to-zip-two-slices-efficiently/2048/12
93 let mut res = T::zero();
94
        // We have to declare the accumulators outside of the loop (rather than at their first
        // assignment inside it); otherwise, vectorization won't kick in for some reason.
97 let mut acc0;
98 let mut acc1;
99 let mut acc2;
100 let mut acc3;
101 let mut acc4;
102 let mut acc5;
103 let mut acc6;
104 let mut acc7;
105
106 for j in 0..self.ncols() {
107 let mut i = 0;
108
109 acc0 = T::zero();
110 acc1 = T::zero();
111 acc2 = T::zero();
112 acc3 = T::zero();
113 acc4 = T::zero();
114 acc5 = T::zero();
115 acc6 = T::zero();
116 acc7 = T::zero();
117
118 while self.nrows() - i >= 8 {
119 acc0 += unsafe {
120 conjugate(self.get_unchecked((i, j)).clone())
121 * rhs.get_unchecked((i, j)).clone()
122 };
123 acc1 += unsafe {
124 conjugate(self.get_unchecked((i + 1, j)).clone())
125 * rhs.get_unchecked((i + 1, j)).clone()
126 };
127 acc2 += unsafe {
128 conjugate(self.get_unchecked((i + 2, j)).clone())
129 * rhs.get_unchecked((i + 2, j)).clone()
130 };
131 acc3 += unsafe {
132 conjugate(self.get_unchecked((i + 3, j)).clone())
133 * rhs.get_unchecked((i + 3, j)).clone()
134 };
135 acc4 += unsafe {
136 conjugate(self.get_unchecked((i + 4, j)).clone())
137 * rhs.get_unchecked((i + 4, j)).clone()
138 };
139 acc5 += unsafe {
140 conjugate(self.get_unchecked((i + 5, j)).clone())
141 * rhs.get_unchecked((i + 5, j)).clone()
142 };
143 acc6 += unsafe {
144 conjugate(self.get_unchecked((i + 6, j)).clone())
145 * rhs.get_unchecked((i + 6, j)).clone()
146 };
147 acc7 += unsafe {
148 conjugate(self.get_unchecked((i + 7, j)).clone())
149 * rhs.get_unchecked((i + 7, j)).clone()
150 };
151 i += 8;
152 }
153
154 res += acc0 + acc4;
155 res += acc1 + acc5;
156 res += acc2 + acc6;
157 res += acc3 + acc7;
158
159 for k in i..self.nrows() {
160 res += unsafe {
161 conjugate(self.get_unchecked((k, j)).clone())
162 * rhs.get_unchecked((k, j)).clone()
163 }
164 }
165 }
166
167 res
168 }
169
170 /// The dot product between two vectors or matrices (seen as vectors).
171 ///
172 /// This is equal to `self.transpose() * rhs`. For the sesquilinear complex dot product, use
173 /// `self.dotc(rhs)`.
174 ///
175 /// Note that this is **not** the matrix multiplication as in, e.g., numpy. For matrix
176 /// multiplication, use one of: `.gemm`, `.mul_to`, `.mul`, the `*` operator.
177 ///
178 /// # Example
179 /// ```
180 /// # use nalgebra::{Vector3, Matrix2x3};
181 /// let vec1 = Vector3::new(1.0, 2.0, 3.0);
182 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
183 /// assert_eq!(vec1.dot(&vec2), 1.4);
184 ///
185 /// let mat1 = Matrix2x3::new(1.0, 2.0, 3.0,
186 /// 4.0, 5.0, 6.0);
187 /// let mat2 = Matrix2x3::new(0.1, 0.2, 0.3,
188 /// 0.4, 0.5, 0.6);
189 /// assert_eq!(mat1.dot(&mat2), 9.1);
190 /// ```
191 ///
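    /// As an additional sketch (assuming `f64` vectors), here is the difference with matrix
    /// multiplication: `dot` returns a plain scalar, while `self.transpose() * rhs` returns a
    /// `1x1` matrix holding the same value.
    /// ```
    /// # use nalgebra::Vector3;
    /// let v = Vector3::new(1.0, 2.0, 3.0);
    /// // 1*1 + 2*2 + 3*3 = 14
    /// assert_eq!(v.dot(&v), 14.0);
    /// assert_eq!((v.transpose() * v)[(0, 0)], 14.0);
    /// ```
    ///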
192 #[inline]
193 #[must_use]
194 pub fn dot<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
195 where
196 SB: RawStorage<T, R2, C2>,
197 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
198 {
199 self.dotx(rhs, |e| e)
200 }
201
202 /// The conjugate-linear dot product between two vectors or matrices (seen as vectors).
203 ///
204 /// This is equal to `self.adjoint() * rhs`.
205 /// For real vectors, this is identical to `self.dot(&rhs)`.
206 /// Note that this is **not** the matrix multiplication as in, e.g., numpy. For matrix
207 /// multiplication, use one of: `.gemm`, `.mul_to`, `.mul`, the `*` operator.
208 ///
209 /// # Example
210 /// ```
211 /// # use nalgebra::{Vector2, Complex};
212 /// let vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
213 /// let vec2 = Vector2::new(Complex::new(0.4, 0.3), Complex::new(0.2, 0.1));
214 /// assert_eq!(vec1.dotc(&vec2), Complex::new(2.0, -1.0));
215 ///
216 /// // Note that for complex vectors, we generally have:
    /// // vec1.dotc(&vec2) != vec1.dot(&vec2)
218 /// assert_ne!(vec1.dotc(&vec2), vec1.dot(&vec2));
219 /// ```
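    ///
    /// As a further sketch (assuming `f64` components), the conjugate-linear product of a
    /// vector with itself is its squared norm, so its imaginary part vanishes:
    /// ```
    /// # use nalgebra::{Vector2, Complex};
    /// let vec = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
    /// // conj(1 + 2i) * (1 + 2i) + conj(3 + 4i) * (3 + 4i) = 5 + 25 = 30
    /// assert_eq!(vec.dotc(&vec), Complex::new(30.0, 0.0));
    /// ```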
220 #[inline]
221 #[must_use]
222 pub fn dotc<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
223 where
224 T: SimdComplexField,
225 SB: RawStorage<T, R2, C2>,
226 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
227 {
228 self.dotx(rhs, T::simd_conjugate)
229 }
230
231 /// The dot product between the transpose of `self` and `rhs`.
232 ///
233 /// # Example
234 /// ```
235 /// # use nalgebra::{Vector3, RowVector3, Matrix2x3, Matrix3x2};
236 /// let vec1 = Vector3::new(1.0, 2.0, 3.0);
237 /// let vec2 = RowVector3::new(0.1, 0.2, 0.3);
238 /// assert_eq!(vec1.tr_dot(&vec2), 1.4);
239 ///
240 /// let mat1 = Matrix2x3::new(1.0, 2.0, 3.0,
241 /// 4.0, 5.0, 6.0);
242 /// let mat2 = Matrix3x2::new(0.1, 0.4,
243 /// 0.2, 0.5,
244 /// 0.3, 0.6);
245 /// assert_eq!(mat1.tr_dot(&mat2), 9.1);
246 /// ```
247 #[inline]
248 #[must_use]
249 pub fn tr_dot<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
250 where
251 SB: RawStorage<T, R2, C2>,
252 ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
253 {
254 let (nrows, ncols) = self.shape();
255 assert_eq!(
256 (ncols, nrows),
257 rhs.shape(),
258 "Transposed dot product dimension mismatch."
259 );
260
261 let mut res = T::zero();
262
263 for j in 0..self.nrows() {
264 for i in 0..self.ncols() {
265 res += unsafe {
266 self.get_unchecked((j, i)).clone() * rhs.get_unchecked((i, j)).clone()
267 }
268 }
269 }
270
271 res
272 }
273}
274
275/// # BLAS functions
276impl<T, D: Dim, S> Vector<T, D, S>
277where
278 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
279 S: StorageMut<T, D>,
280{
281 /// Computes `self = a * x * c + b * self`.
282 ///
283 /// If `b` is zero, `self` is never read from.
284 ///
285 /// # Example
286 /// ```
287 /// # use nalgebra::Vector3;
288 /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
289 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
290 /// vec1.axcpy(5.0, &vec2, 2.0, 5.0);
291 /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
292 /// ```
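    ///
    /// A sketch of the `b == 0` case (assuming `f64` components): the initial content of
    /// `self` is ignored entirely, so it may even be `NaN`:
    /// ```
    /// # use nalgebra::Vector3;
    /// let mut vec1 = Vector3::new(f64::NAN, f64::NAN, f64::NAN);
    /// let vec2 = Vector3::new(1.0, 2.0, 3.0);
    /// vec1.axcpy(2.0, &vec2, 3.0, 0.0);
    /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
    /// ```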
293 #[inline]
294 #[allow(clippy::many_single_char_names)]
295 pub fn axcpy<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, c: T, b: T)
296 where
297 SB: Storage<T, D2>,
298 ShapeConstraint: DimEq<D, D2>,
299 {
300 unsafe { axcpy_uninit(Init, self, a, x, c, b) };
301 }
302
303 /// Computes `self = a * x + b * self`.
304 ///
305 /// If `b` is zero, `self` is never read from.
306 ///
307 /// # Example
308 /// ```
309 /// # use nalgebra::Vector3;
310 /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
311 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
312 /// vec1.axpy(10.0, &vec2, 5.0);
313 /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
314 /// ```
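    ///
    /// Equivalently, written with operators (a sketch assuming `f64` vectors):
    /// ```
    /// # use nalgebra::Vector3;
    /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
    /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
    /// let expected = vec2 * 10.0 + vec1 * 5.0;
    /// vec1.axpy(10.0, &vec2, 5.0);
    /// assert_eq!(vec1, expected);
    /// ```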
315 #[inline]
316 pub fn axpy<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, b: T)
317 where
318 T: One,
319 SB: Storage<T, D2>,
320 ShapeConstraint: DimEq<D, D2>,
321 {
322 assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
323 self.axcpy(a, x, T::one(), b)
324 }
325
326 /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
327 /// `alpha, beta` two scalars.
328 ///
329 /// If `beta` is zero, `self` is never read.
330 ///
331 /// # Example
332 /// ```
333 /// # use nalgebra::{Matrix2, Vector2};
334 /// let mut vec1 = Vector2::new(1.0, 2.0);
335 /// let vec2 = Vector2::new(0.1, 0.2);
336 /// let mat = Matrix2::new(1.0, 2.0,
337 /// 3.0, 4.0);
338 /// vec1.gemv(10.0, &mat, &vec2, 5.0);
339 /// assert_eq!(vec1, Vector2::new(10.0, 21.0));
340 /// ```
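    ///
    /// The same update written with operators (a sketch assuming `f64` matrices):
    /// ```
    /// # use nalgebra::{Matrix2, Vector2};
    /// let mut vec1 = Vector2::new(1.0, 2.0);
    /// let vec2 = Vector2::new(0.1, 0.2);
    /// let mat = Matrix2::new(1.0, 2.0,
    ///                        3.0, 4.0);
    /// let expected = mat * vec2 * 10.0 + vec1 * 5.0;
    /// vec1.gemv(10.0, &mat, &vec2, 5.0);
    /// assert_eq!(vec1, expected);
    /// ```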
341 #[inline]
342 pub fn gemv<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
343 &mut self,
344 alpha: T,
345 a: &Matrix<T, R2, C2, SB>,
346 x: &Vector<T, D3, SC>,
347 beta: T,
348 ) where
349 T: One,
350 SB: Storage<T, R2, C2>,
351 SC: Storage<T, D3>,
352 ShapeConstraint: DimEq<D, R2> + AreMultipliable<R2, C2, D3, U1>,
353 {
354 // Safety: this is safe because we are passing Status == Init.
355 unsafe { gemv_uninit(Init, self, alpha, a, x, beta) }
356 }
357
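    /// Shared implementation of `sygemv` and `hegemv`: `dot` is either the plain dot product
    /// (symmetric case) or the conjugate-linear one (hermitian case). Only the lower-triangular
    /// part of `a` (including the diagonal) is read.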
358 #[inline(always)]
359 fn xxgemv<D2: Dim, D3: Dim, SB, SC>(
360 &mut self,
361 alpha: T,
362 a: &SquareMatrix<T, D2, SB>,
363 x: &Vector<T, D3, SC>,
364 beta: T,
365 dot: impl Fn(
366 &DVectorView<'_, T, SB::RStride, SB::CStride>,
367 &DVectorView<'_, T, SC::RStride, SC::CStride>,
368 ) -> T,
369 ) where
370 T: One,
371 SB: Storage<T, D2, D2>,
372 SC: Storage<T, D3>,
373 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
374 {
375 let dim1 = self.nrows();
376 let dim2 = a.nrows();
377 let dim3 = x.nrows();
378
379 assert!(
380 a.is_square(),
381 "Symmetric cgemv: the input matrix must be square."
382 );
383 assert!(
384 dim2 == dim3 && dim1 == dim2,
385 "Symmetric cgemv: dimensions mismatch."
386 );
387
388 if dim2 == 0 {
389 return;
390 }
391
392 // TODO: avoid bound checks.
393 let col2 = a.column(0);
394 let val = unsafe { x.vget_unchecked(0).clone() };
395 self.axpy(alpha.clone() * val, &col2, beta);
396 self[0] += alpha.clone() * dot(&a.view_range(1.., 0), &x.rows_range(1..));
397
398 for j in 1..dim2 {
399 let col2 = a.column(j);
400 let dot = dot(&col2.rows_range(j..), &x.rows_range(j..));
401
402 let val;
403 unsafe {
404 val = x.vget_unchecked(j).clone();
405 *self.vget_unchecked_mut(j) += alpha.clone() * dot;
406 }
407 self.rows_range_mut(j + 1..).axpy(
408 alpha.clone() * val,
409 &col2.rows_range(j + 1..),
410 T::one(),
411 );
412 }
413 }
414
415 /// Computes `self = alpha * a * x + beta * self`, where `a` is a **symmetric** matrix, `x` a
416 /// vector, and `alpha, beta` two scalars.
417 ///
418 /// For hermitian matrices, use `.hegemv` instead.
    /// If `beta` is zero, `self` is never read. Only the lower-triangular part of `a`
    /// (including the diagonal) is actually read.
421 ///
422 /// # Examples
423 /// ```
424 /// # use nalgebra::{Matrix2, Vector2};
425 /// let mat = Matrix2::new(1.0, 2.0,
426 /// 2.0, 4.0);
427 /// let mut vec1 = Vector2::new(1.0, 2.0);
428 /// let vec2 = Vector2::new(0.1, 0.2);
429 /// vec1.sygemv(10.0, &mat, &vec2, 5.0);
430 /// assert_eq!(vec1, Vector2::new(10.0, 20.0));
431 ///
432 ///
    /// // The matrix upper-triangular elements can be garbage because they are never
    /// // read by this method. Therefore, it is not necessary for the caller to
    /// // fill the matrix's upper triangle.
436 /// let mat = Matrix2::new(1.0, 9999999.9999999,
437 /// 2.0, 4.0);
438 /// let mut vec1 = Vector2::new(1.0, 2.0);
439 /// vec1.sygemv(10.0, &mat, &vec2, 5.0);
440 /// assert_eq!(vec1, Vector2::new(10.0, 20.0));
441 /// ```
442 #[inline]
443 pub fn sygemv<D2: Dim, D3: Dim, SB, SC>(
444 &mut self,
445 alpha: T,
446 a: &SquareMatrix<T, D2, SB>,
447 x: &Vector<T, D3, SC>,
448 beta: T,
449 ) where
450 T: One,
451 SB: Storage<T, D2, D2>,
452 SC: Storage<T, D3>,
453 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
454 {
455 self.xxgemv(alpha, a, x, beta, |a, b| a.dot(b))
456 }
457
    /// Computes `self = alpha * a * x + beta * self`, where `a` is a **hermitian** matrix, `x` a
459 /// vector, and `alpha, beta` two scalars.
460 ///
    /// If `beta` is zero, `self` is never read. Only the lower-triangular part of `a`
    /// (including the diagonal) is actually read.
463 ///
464 /// # Examples
465 /// ```
466 /// # use nalgebra::{Matrix2, Vector2, Complex};
467 /// let mat = Matrix2::new(Complex::new(1.0, 0.0), Complex::new(2.0, -0.1),
468 /// Complex::new(2.0, 1.0), Complex::new(4.0, 0.0));
469 /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
470 /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
    /// vec1.hegemv(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
    /// assert_eq!(vec1, Vector2::new(Complex::new(-28.0, 54.0), Complex::new(-75.0, 110.0)));
    ///
    ///
    /// // The matrix upper-triangular elements can be garbage because they are never
    /// // read by this method. Therefore, it is not necessary for the caller to
    /// // fill the matrix's upper triangle.
    ///
    /// let mat = Matrix2::new(Complex::new(1.0, 0.0), Complex::new(99999999.9, 999999999.9),
    ///                        Complex::new(2.0, 1.0), Complex::new(4.0, 0.0));
    /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
    /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
    /// vec1.hegemv(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
    /// assert_eq!(vec1, Vector2::new(Complex::new(-28.0, 54.0), Complex::new(-75.0, 110.0)));
485 /// ```
486 #[inline]
487 pub fn hegemv<D2: Dim, D3: Dim, SB, SC>(
488 &mut self,
489 alpha: T,
490 a: &SquareMatrix<T, D2, SB>,
491 x: &Vector<T, D3, SC>,
492 beta: T,
493 ) where
494 T: SimdComplexField,
495 SB: Storage<T, D2, D2>,
496 SC: Storage<T, D3>,
497 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
498 {
499 self.xxgemv(alpha, a, x, beta, |a, b| a.dotc(b))
500 }
501
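    /// Shared implementation of `gemv_tr` and `gemv_ad`: each component `self[j]` is computed
    /// from the dot product of the `j`-th column of `a` with `x`, which amounts to multiplying
    /// by `a.transpose()` (or by `a.adjoint()` when `dot` conjugates its first argument).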
502 #[inline(always)]
503 fn gemv_xx<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
504 &mut self,
505 alpha: T,
506 a: &Matrix<T, R2, C2, SB>,
507 x: &Vector<T, D3, SC>,
508 beta: T,
509 dot: impl Fn(&VectorView<'_, T, R2, SB::RStride, SB::CStride>, &Vector<T, D3, SC>) -> T,
510 ) where
511 T: One,
512 SB: Storage<T, R2, C2>,
513 SC: Storage<T, D3>,
514 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
515 {
516 let dim1 = self.nrows();
517 let (nrows2, ncols2) = a.shape();
518 let dim3 = x.nrows();
519
520 assert!(
521 nrows2 == dim3 && dim1 == ncols2,
522 "Gemv: dimensions mismatch."
523 );
524
525 if ncols2 == 0 {
526 return;
527 }
528
529 if beta.is_zero() {
530 for j in 0..ncols2 {
531 let val = unsafe { self.vget_unchecked_mut(j) };
532 *val = alpha.clone() * dot(&a.column(j), x)
533 }
534 } else {
535 for j in 0..ncols2 {
536 let val = unsafe { self.vget_unchecked_mut(j) };
537 *val = alpha.clone() * dot(&a.column(j), x) + beta.clone() * val.clone();
538 }
539 }
540 }
541
542 /// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and
543 /// `alpha, beta` two scalars.
544 ///
545 /// If `beta` is zero, `self` is never read.
546 ///
547 /// # Example
548 /// ```
549 /// # use nalgebra::{Matrix2, Vector2};
550 /// let mat = Matrix2::new(1.0, 3.0,
551 /// 2.0, 4.0);
552 /// let mut vec1 = Vector2::new(1.0, 2.0);
553 /// let vec2 = Vector2::new(0.1, 0.2);
554 /// let expected = mat.transpose() * vec2 * 10.0 + vec1 * 5.0;
555 ///
556 /// vec1.gemv_tr(10.0, &mat, &vec2, 5.0);
557 /// assert_eq!(vec1, expected);
558 /// ```
559 #[inline]
560 pub fn gemv_tr<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
561 &mut self,
562 alpha: T,
563 a: &Matrix<T, R2, C2, SB>,
564 x: &Vector<T, D3, SC>,
565 beta: T,
566 ) where
567 T: One,
568 SB: Storage<T, R2, C2>,
569 SC: Storage<T, D3>,
570 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
571 {
572 self.gemv_xx(alpha, a, x, beta, |a, b| a.dot(b))
573 }
574
575 /// Computes `self = alpha * a.adjoint() * x + beta * self`, where `a` is a matrix, `x` a vector, and
576 /// `alpha, beta` two scalars.
577 ///
578 /// For real matrices, this is the same as `.gemv_tr`.
579 /// If `beta` is zero, `self` is never read.
580 ///
581 /// # Example
582 /// ```
583 /// # use nalgebra::{Matrix2, Vector2, Complex};
584 /// let mat = Matrix2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0),
585 /// Complex::new(5.0, 6.0), Complex::new(7.0, 8.0));
586 /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
587 /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
588 /// let expected = mat.adjoint() * vec2 * Complex::new(10.0, 20.0) + vec1 * Complex::new(5.0, 15.0);
589 ///
590 /// vec1.gemv_ad(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
591 /// assert_eq!(vec1, expected);
592 /// ```
593 #[inline]
594 pub fn gemv_ad<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
595 &mut self,
596 alpha: T,
597 a: &Matrix<T, R2, C2, SB>,
598 x: &Vector<T, D3, SC>,
599 beta: T,
600 ) where
601 T: SimdComplexField,
602 SB: Storage<T, R2, C2>,
603 SC: Storage<T, D3>,
604 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
605 {
606 self.gemv_xx(alpha, a, x, beta, |a, b| a.dotc(b))
607 }
608}
609
610impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
611where
612 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
613{
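    /// Shared implementation of `ger` and `gerc`: performs the rank-one update column by column,
    /// setting the `j`-th column of `self` to `alpha * conjugate(y[j]) * x + beta * column`.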
614 #[inline(always)]
615 fn gerx<D2: Dim, D3: Dim, SB, SC>(
616 &mut self,
617 alpha: T,
618 x: &Vector<T, D2, SB>,
619 y: &Vector<T, D3, SC>,
620 beta: T,
621 conjugate: impl Fn(T) -> T,
622 ) where
623 T: One,
624 SB: Storage<T, D2>,
625 SC: Storage<T, D3>,
626 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
627 {
628 let (nrows1, ncols1) = self.shape();
629 let dim2 = x.nrows();
630 let dim3 = y.nrows();
631
632 assert!(
633 nrows1 == dim2 && ncols1 == dim3,
634 "ger: dimensions mismatch."
635 );
636
637 for j in 0..ncols1 {
638 // TODO: avoid bound checks.
639 let val = unsafe { conjugate(y.vget_unchecked(j).clone()) };
640 self.column_mut(j)
641 .axpy(alpha.clone() * val, x, beta.clone());
642 }
643 }
644
645 /// Computes `self = alpha * x * y.transpose() + beta * self`.
646 ///
647 /// If `beta` is zero, `self` is never read.
648 ///
649 /// # Example
650 /// ```
651 /// # use nalgebra::{Matrix2x3, Vector2, Vector3};
652 /// let mut mat = Matrix2x3::repeat(4.0);
653 /// let vec1 = Vector2::new(1.0, 2.0);
654 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
655 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
656 ///
657 /// mat.ger(10.0, &vec1, &vec2, 5.0);
658 /// assert_eq!(mat, expected);
659 /// ```
660 #[inline]
661 pub fn ger<D2: Dim, D3: Dim, SB, SC>(
662 &mut self,
663 alpha: T,
664 x: &Vector<T, D2, SB>,
665 y: &Vector<T, D3, SC>,
666 beta: T,
667 ) where
668 T: One,
669 SB: Storage<T, D2>,
670 SC: Storage<T, D3>,
671 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
672 {
673 self.gerx(alpha, x, y, beta, |e| e)
674 }
675
676 /// Computes `self = alpha * x * y.adjoint() + beta * self`.
677 ///
678 /// If `beta` is zero, `self` is never read.
679 ///
680 /// # Example
681 /// ```
682 /// # #[macro_use] extern crate approx;
683 /// # use nalgebra::{Matrix2x3, Vector2, Vector3, Complex};
684 /// let mut mat = Matrix2x3::repeat(Complex::new(4.0, 5.0));
685 /// let vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
686 /// let vec2 = Vector3::new(Complex::new(0.6, 0.5), Complex::new(0.4, 0.5), Complex::new(0.2, 0.1));
687 /// let expected = vec1 * vec2.adjoint() * Complex::new(10.0, 20.0) + mat * Complex::new(5.0, 15.0);
688 ///
689 /// mat.gerc(Complex::new(10.0, 20.0), &vec1, &vec2, Complex::new(5.0, 15.0));
690 /// assert_eq!(mat, expected);
691 /// ```
692 #[inline]
693 pub fn gerc<D2: Dim, D3: Dim, SB, SC>(
694 &mut self,
695 alpha: T,
696 x: &Vector<T, D2, SB>,
697 y: &Vector<T, D3, SC>,
698 beta: T,
699 ) where
700 T: SimdComplexField,
701 SB: Storage<T, D2>,
702 SC: Storage<T, D3>,
703 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
704 {
705 self.gerx(alpha, x, y, beta, SimdComplexField::simd_conjugate)
706 }
707
    /// Computes `self = alpha * a * b + beta * self`, where `a`, `b`, and `self` are matrices,
    /// and `alpha` and `beta` are scalars.
710 ///
711 /// If `beta` is zero, `self` is never read.
712 ///
713 /// # Example
714 /// ```
715 /// # #[macro_use] extern crate approx;
716 /// # use nalgebra::{Matrix2x3, Matrix3x4, Matrix2x4};
717 /// let mut mat1 = Matrix2x4::identity();
718 /// let mat2 = Matrix2x3::new(1.0, 2.0, 3.0,
719 /// 4.0, 5.0, 6.0);
720 /// let mat3 = Matrix3x4::new(0.1, 0.2, 0.3, 0.4,
721 /// 0.5, 0.6, 0.7, 0.8,
722 /// 0.9, 1.0, 1.1, 1.2);
723 /// let expected = mat2 * mat3 * 10.0 + mat1 * 5.0;
724 ///
725 /// mat1.gemm(10.0, &mat2, &mat3, 5.0);
726 /// assert_relative_eq!(mat1, expected);
727 /// ```
728 #[inline]
729 pub fn gemm<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
730 &mut self,
731 alpha: T,
732 a: &Matrix<T, R2, C2, SB>,
733 b: &Matrix<T, R3, C3, SC>,
734 beta: T,
735 ) where
736 T: One,
737 SB: Storage<T, R2, C2>,
738 SC: Storage<T, R3, C3>,
739 ShapeConstraint: SameNumberOfRows<R1, R2>
740 + SameNumberOfColumns<C1, C3>
741 + AreMultipliable<R2, C2, R3, C3>,
742 {
743 // SAFETY: this is valid because our matrices are initialized and
744 // we are using status = Init.
745 unsafe { gemm_uninit(Init, self, alpha, a, b, beta) }
746 }
747
    /// Computes `self = alpha * a.transpose() * b + beta * self`, where `a`, `b`, and `self` are
    /// matrices, and `alpha` and `beta` are scalars.
750 ///
751 /// If `beta` is zero, `self` is never read.
752 ///
753 /// # Example
754 /// ```
755 /// # #[macro_use] extern crate approx;
756 /// # use nalgebra::{Matrix3x2, Matrix3x4, Matrix2x4};
757 /// let mut mat1 = Matrix2x4::identity();
758 /// let mat2 = Matrix3x2::new(1.0, 4.0,
759 /// 2.0, 5.0,
760 /// 3.0, 6.0);
761 /// let mat3 = Matrix3x4::new(0.1, 0.2, 0.3, 0.4,
762 /// 0.5, 0.6, 0.7, 0.8,
763 /// 0.9, 1.0, 1.1, 1.2);
764 /// let expected = mat2.transpose() * mat3 * 10.0 + mat1 * 5.0;
765 ///
766 /// mat1.gemm_tr(10.0, &mat2, &mat3, 5.0);
767 /// assert_eq!(mat1, expected);
768 /// ```
769 #[inline]
770 pub fn gemm_tr<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
771 &mut self,
772 alpha: T,
773 a: &Matrix<T, R2, C2, SB>,
774 b: &Matrix<T, R3, C3, SC>,
775 beta: T,
776 ) where
777 T: One,
778 SB: Storage<T, R2, C2>,
779 SC: Storage<T, R3, C3>,
780 ShapeConstraint: SameNumberOfRows<R1, C2>
781 + SameNumberOfColumns<C1, C3>
782 + AreMultipliable<C2, R2, R3, C3>,
783 {
784 let (nrows1, ncols1) = self.shape();
785 let (nrows2, ncols2) = a.shape();
786 let (nrows3, ncols3) = b.shape();
787
788 assert_eq!(
789 nrows2, nrows3,
790 "gemm: dimensions mismatch for multiplication."
791 );
792 assert_eq!(
793 (nrows1, ncols1),
794 (ncols2, ncols3),
795 "gemm: dimensions mismatch for addition."
796 );
797
798 for j1 in 0..ncols1 {
799 // TODO: avoid bound checks.
800 self.column_mut(j1)
801 .gemv_tr(alpha.clone(), a, &b.column(j1), beta.clone());
802 }
803 }
804
    /// Computes `self = alpha * a.adjoint() * b + beta * self`, where `a`, `b`, and `self` are
    /// matrices, and `alpha` and `beta` are scalars.
807 ///
808 /// If `beta` is zero, `self` is never read.
809 ///
810 /// # Example
811 /// ```
812 /// # #[macro_use] extern crate approx;
813 /// # use nalgebra::{Matrix3x2, Matrix3x4, Matrix2x4, Complex};
814 /// let mut mat1 = Matrix2x4::identity();
815 /// let mat2 = Matrix3x2::new(Complex::new(1.0, 4.0), Complex::new(7.0, 8.0),
816 /// Complex::new(2.0, 5.0), Complex::new(9.0, 10.0),
817 /// Complex::new(3.0, 6.0), Complex::new(11.0, 12.0));
818 /// let mat3 = Matrix3x4::new(Complex::new(0.1, 1.3), Complex::new(0.2, 1.4), Complex::new(0.3, 1.5), Complex::new(0.4, 1.6),
819 /// Complex::new(0.5, 1.7), Complex::new(0.6, 1.8), Complex::new(0.7, 1.9), Complex::new(0.8, 2.0),
820 /// Complex::new(0.9, 2.1), Complex::new(1.0, 2.2), Complex::new(1.1, 2.3), Complex::new(1.2, 2.4));
821 /// let expected = mat2.adjoint() * mat3 * Complex::new(10.0, 20.0) + mat1 * Complex::new(5.0, 15.0);
822 ///
823 /// mat1.gemm_ad(Complex::new(10.0, 20.0), &mat2, &mat3, Complex::new(5.0, 15.0));
824 /// assert_eq!(mat1, expected);
825 /// ```
826 #[inline]
827 pub fn gemm_ad<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
828 &mut self,
829 alpha: T,
830 a: &Matrix<T, R2, C2, SB>,
831 b: &Matrix<T, R3, C3, SC>,
832 beta: T,
833 ) where
834 T: SimdComplexField,
835 SB: Storage<T, R2, C2>,
836 SC: Storage<T, R3, C3>,
837 ShapeConstraint: SameNumberOfRows<R1, C2>
838 + SameNumberOfColumns<C1, C3>
839 + AreMultipliable<C2, R2, R3, C3>,
840 {
841 let (nrows1, ncols1) = self.shape();
842 let (nrows2, ncols2) = a.shape();
843 let (nrows3, ncols3) = b.shape();
844
845 assert_eq!(
846 nrows2, nrows3,
847 "gemm: dimensions mismatch for multiplication."
848 );
849 assert_eq!(
850 (nrows1, ncols1),
851 (ncols2, ncols3),
852 "gemm: dimensions mismatch for addition."
853 );
854
855 for j1 in 0..ncols1 {
856 // TODO: avoid bound checks.
857 self.column_mut(j1)
858 .gemv_ad(alpha.clone(), a, &b.column(j1), beta.clone());
859 }
860 }
861}
862
863impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
864where
865 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
866{
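    /// Shared implementation of `syger` and `hegerc`: same rank-one update as `gerx`, except
    /// that only the lower-triangular part of `self` (including the diagonal) is written.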
867 #[inline(always)]
868 fn xxgerx<D2: Dim, D3: Dim, SB, SC>(
869 &mut self,
870 alpha: T,
871 x: &Vector<T, D2, SB>,
872 y: &Vector<T, D3, SC>,
873 beta: T,
874 conjugate: impl Fn(T) -> T,
875 ) where
876 T: One,
877 SB: Storage<T, D2>,
878 SC: Storage<T, D3>,
879 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
880 {
881 let dim1 = self.nrows();
882 let dim2 = x.nrows();
883 let dim3 = y.nrows();
884
885 assert!(
886 self.is_square(),
887 "Symmetric ger: the input matrix must be square."
888 );
889 assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch.");
890
891 for j in 0..dim1 {
892 let val = unsafe { conjugate(y.vget_unchecked(j).clone()) };
893 let subdim = Dyn(dim1 - j);
894 // TODO: avoid bound checks.
895 self.generic_view_mut((j, j), (subdim, Const::<1>)).axpy(
896 alpha.clone() * val,
897 &x.rows_range(j..),
898 beta.clone(),
899 );
900 }
901 }
902
903 /// Computes `self = alpha * x * y.transpose() + beta * self`, where `self` is a **symmetric**
904 /// matrix.
905 ///
906 /// If `beta` is zero, `self` is never read. The result is symmetric. Only the lower-triangular
907 /// (including the diagonal) part of `self` is read/written.
908 ///
909 /// # Example
910 /// ```
911 /// # use nalgebra::{Matrix2, Vector2};
912 /// let mut mat = Matrix2::identity();
913 /// let vec1 = Vector2::new(1.0, 2.0);
914 /// let vec2 = Vector2::new(0.1, 0.2);
915 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
916 /// mat.m12 = 99999.99999; // This component is on the upper-triangular part and will not be read/written.
917 ///
918 /// mat.ger_symm(10.0, &vec1, &vec2, 5.0);
919 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
920 /// assert_eq!(mat.m12, 99999.99999); // This was untouched.
921 /// ```
922 #[inline]
923 #[deprecated(note = "This is renamed `syger` to match the original BLAS terminology.")]
924 pub fn ger_symm<D2: Dim, D3: Dim, SB, SC>(
925 &mut self,
926 alpha: T,
927 x: &Vector<T, D2, SB>,
928 y: &Vector<T, D3, SC>,
929 beta: T,
930 ) where
931 T: One,
932 SB: Storage<T, D2>,
933 SC: Storage<T, D3>,
934 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
935 {
936 self.syger(alpha, x, y, beta)
937 }
938
939 /// Computes `self = alpha * x * y.transpose() + beta * self`, where `self` is a **symmetric**
940 /// matrix.
941 ///
942 /// For hermitian complex matrices, use `.hegerc` instead.
943 /// If `beta` is zero, `self` is never read. The result is symmetric. Only the lower-triangular
944 /// (including the diagonal) part of `self` is read/written.
945 ///
946 /// # Example
947 /// ```
948 /// # use nalgebra::{Matrix2, Vector2};
949 /// let mut mat = Matrix2::identity();
950 /// let vec1 = Vector2::new(1.0, 2.0);
951 /// let vec2 = Vector2::new(0.1, 0.2);
952 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
953 /// mat.m12 = 99999.99999; // This component is on the upper-triangular part and will not be read/written.
954 ///
955 /// mat.syger(10.0, &vec1, &vec2, 5.0);
956 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
957 /// assert_eq!(mat.m12, 99999.99999); // This was untouched.
958 /// ```
959 #[inline]
960 pub fn syger<D2: Dim, D3: Dim, SB, SC>(
961 &mut self,
962 alpha: T,
963 x: &Vector<T, D2, SB>,
964 y: &Vector<T, D3, SC>,
965 beta: T,
966 ) where
967 T: One,
968 SB: Storage<T, D2>,
969 SC: Storage<T, D3>,
970 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
971 {
972 self.xxgerx(alpha, x, y, beta, |e| e)
973 }
974
    /// Computes `self = alpha * x * y.adjoint() + beta * self`, where `self` is a **hermitian**
976 /// matrix.
977 ///
    /// If `beta` is zero, `self` is never read. The result is hermitian. Only the lower-triangular
979 /// (including the diagonal) part of `self` is read/written.
980 ///
981 /// # Example
982 /// ```
983 /// # use nalgebra::{Matrix2, Vector2, Complex};
984 /// let mut mat = Matrix2::identity();
985 /// let vec1 = Vector2::new(Complex::new(1.0, 3.0), Complex::new(2.0, 4.0));
986 /// let vec2 = Vector2::new(Complex::new(0.2, 0.4), Complex::new(0.1, 0.3));
987 /// let expected = vec1 * vec2.adjoint() * Complex::new(10.0, 20.0) + mat * Complex::new(5.0, 15.0);
988 /// mat.m12 = Complex::new(99999.99999, 88888.88888); // This component is on the upper-triangular part and will not be read/written.
989 ///
990 /// mat.hegerc(Complex::new(10.0, 20.0), &vec1, &vec2, Complex::new(5.0, 15.0));
991 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
992 /// assert_eq!(mat.m12, Complex::new(99999.99999, 88888.88888)); // This was untouched.
993 /// ```
994 #[inline]
995 pub fn hegerc<D2: Dim, D3: Dim, SB, SC>(
996 &mut self,
997 alpha: T,
998 x: &Vector<T, D2, SB>,
999 y: &Vector<T, D3, SC>,
1000 beta: T,
1001 ) where
1002 T: SimdComplexField,
1003 SB: Storage<T, D2>,
1004 SC: Storage<T, D3>,
1005 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
1006 {
1007 self.xxgerx(alpha, x, y, beta, SimdComplexField::simd_conjugate)
1008 }
1009}
1010
1011impl<T, D1: Dim, S: StorageMut<T, D1, D1>> SquareMatrix<T, D1, S>
1012where
1013 T: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign,
1014{
1015 /// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
1016 ///
1017 /// This uses the provided workspace `work` to avoid allocations for intermediate results.
1018 ///
1019 /// # Example
1020 /// ```
1021 /// # #[macro_use] extern crate approx;
1022 /// # use nalgebra::{DMatrix, DVector};
1023 /// // Note that all those would also work with statically-sized matrices.
1024 /// // We use DMatrix/DVector since that's the only case where pre-allocating the
1025 /// // workspace is actually useful (assuming the same workspace is re-used for
1026 /// // several computations) because it avoids repeated dynamic allocations.
1027 /// let mut mat = DMatrix::identity(2, 2);
1028 /// let lhs = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0,
1029 /// 4.0, 5.0, 6.0]);
1030 /// let mid = DMatrix::from_row_slice(3, 3, &[0.1, 0.2, 0.3,
1031 /// 0.5, 0.6, 0.7,
1032 /// 0.9, 1.0, 1.1]);
    /// // The workspace is initialized with random values to show that its initial content
    /// // does not matter, as it will be overwritten.
1035 /// let mut workspace = DVector::new_random(2);
1036 /// let expected = &lhs * &mid * lhs.transpose() * 10.0 + &mat * 5.0;
1037 ///
1038 /// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0);
1039 /// assert_relative_eq!(mat, expected);
1040 /// ```
1041 pub fn quadform_tr_with_workspace<D2, S2, R3, C3, S3, D4, S4>(
1042 &mut self,
1043 work: &mut Vector<T, D2, S2>,
1044 alpha: T,
1045 lhs: &Matrix<T, R3, C3, S3>,
1046 mid: &SquareMatrix<T, D4, S4>,
1047 beta: T,
1048 ) where
1049 D2: Dim,
1050 R3: Dim,
1051 C3: Dim,
1052 D4: Dim,
1053 S2: StorageMut<T, D2>,
1054 S3: Storage<T, R3, C3>,
1055 S4: Storage<T, D4, D4>,
1056 ShapeConstraint: DimEq<D1, D2> + DimEq<D1, R3> + DimEq<D2, R3> + DimEq<C3, D4>,
1057 {
1058 work.gemv(T::one(), lhs, &mid.column(0), T::zero());
1059 self.ger(alpha.clone(), work, &lhs.column(0), beta);
1060
1061 for j in 1..mid.ncols() {
1062 work.gemv(T::one(), lhs, &mid.column(j), T::zero());
1063 self.ger(alpha.clone(), work, &lhs.column(j), T::one());
1064 }
1065 }
1066
1067 /// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
1068 ///
1069 /// This allocates a workspace vector of dimension D1 for intermediate results.
1070 /// If `D1` is a type-level integer, then the allocation is performed on the stack.
1071 /// Use `.quadform_tr_with_workspace(...)` instead to avoid allocations.
1072 ///
1073 /// # Example
1074 /// ```
1075 /// # #[macro_use] extern crate approx;
1076 /// # use nalgebra::{Matrix2, Matrix3, Matrix2x3, Vector2};
1077 /// let mut mat = Matrix2::identity();
1078 /// let lhs = Matrix2x3::new(1.0, 2.0, 3.0,
1079 /// 4.0, 5.0, 6.0);
1080 /// let mid = Matrix3::new(0.1, 0.2, 0.3,
1081 /// 0.5, 0.6, 0.7,
1082 /// 0.9, 1.0, 1.1);
1083 /// let expected = lhs * mid * lhs.transpose() * 10.0 + mat * 5.0;
1084 ///
1085 /// mat.quadform_tr(10.0, &lhs, &mid, 5.0);
1086 /// assert_relative_eq!(mat, expected);
1087 /// ```
1088 pub fn quadform_tr<R3, C3, S3, D4, S4>(
1089 &mut self,
1090 alpha: T,
1091 lhs: &Matrix<T, R3, C3, S3>,
1092 mid: &SquareMatrix<T, D4, S4>,
1093 beta: T,
1094 ) where
1095 R3: Dim,
1096 C3: Dim,
1097 D4: Dim,
1098 S3: Storage<T, R3, C3>,
1099 S4: Storage<T, D4, D4>,
1100 ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
1101 DefaultAllocator: Allocator<D1>,
1102 {
1103 // TODO: would it be useful to avoid the zero-initialization of the workspace data?
1104 let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
1105 self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
1106 }
1107
1108 /// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
1109 ///
1110 /// This uses the provided workspace `work` to avoid allocations for intermediate results.
1111 ///
1112 /// # Example
1113 /// ```
1114 /// # #[macro_use] extern crate approx;
1115 /// # use nalgebra::{DMatrix, DVector};
1116 /// // Note that all those would also work with statically-sized matrices.
1117 /// // We use DMatrix/DVector since that's the only case where pre-allocating the
1118 /// // workspace is actually useful (assuming the same workspace is re-used for
1119 /// // several computations) because it avoids repeated dynamic allocations.
1120 /// let mut mat = DMatrix::identity(2, 2);
1121 /// let rhs = DMatrix::from_row_slice(3, 2, &[1.0, 2.0,
1122 /// 3.0, 4.0,
1123 /// 5.0, 6.0]);
1124 /// let mid = DMatrix::from_row_slice(3, 3, &[0.1, 0.2, 0.3,
1125 /// 0.5, 0.6, 0.7,
1126 /// 0.9, 1.0, 1.1]);
    /// // The workspace is initialized with random values to show that its initial content
    /// // does not matter, as it will be overwritten.
1129 /// let mut workspace = DVector::new_random(3);
1130 /// let expected = rhs.transpose() * &mid * &rhs * 10.0 + &mat * 5.0;
1131 ///
1132 /// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0);
1133 /// assert_relative_eq!(mat, expected);
1134 /// ```
1135 pub fn quadform_with_workspace<D2, S2, D3, S3, R4, C4, S4>(
1136 &mut self,
1137 work: &mut Vector<T, D2, S2>,
1138 alpha: T,
1139 mid: &SquareMatrix<T, D3, S3>,
1140 rhs: &Matrix<T, R4, C4, S4>,
1141 beta: T,
1142 ) where
1143 D2: Dim,
1144 D3: Dim,
1145 R4: Dim,
1146 C4: Dim,
1147 S2: StorageMut<T, D2>,
1148 S3: Storage<T, D3, D3>,
1149 S4: Storage<T, R4, C4>,
1150 ShapeConstraint:
1151 DimEq<D3, R4> + DimEq<D1, C4> + DimEq<D2, D3> + AreMultipliable<C4, R4, D2, U1>,
1152 {
1153 work.gemv(T::one(), mid, &rhs.column(0), T::zero());
1154 self.column_mut(0)
1155 .gemv_tr(alpha.clone(), rhs, work, beta.clone());
1156
1157 for j in 1..rhs.ncols() {
1158 work.gemv(T::one(), mid, &rhs.column(j), T::zero());
1159 self.column_mut(j)
1160 .gemv_tr(alpha.clone(), rhs, work, beta.clone());
1161 }
1162 }
1163
1164 /// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
1165 ///
1166 /// This allocates a workspace vector of dimension D2 for intermediate results.
1167 /// If `D2` is a type-level integer, then the allocation is performed on the stack.
1168 /// Use `.quadform_with_workspace(...)` instead to avoid allocations.
1169 ///
1170 /// # Example
1171 /// ```
1172 /// # #[macro_use] extern crate approx;
1173 /// # use nalgebra::{Matrix2, Matrix3x2, Matrix3};
1174 /// let mut mat = Matrix2::identity();
1175 /// let rhs = Matrix3x2::new(1.0, 2.0,
1176 /// 3.0, 4.0,
1177 /// 5.0, 6.0);
1178 /// let mid = Matrix3::new(0.1, 0.2, 0.3,
1179 /// 0.5, 0.6, 0.7,
1180 /// 0.9, 1.0, 1.1);
1181 /// let expected = rhs.transpose() * mid * rhs * 10.0 + mat * 5.0;
1182 ///
1183 /// mat.quadform(10.0, &mid, &rhs, 5.0);
1184 /// assert_relative_eq!(mat, expected);
1185 /// ```
1186 pub fn quadform<D2, S2, R3, C3, S3>(
1187 &mut self,
1188 alpha: T,
1189 mid: &SquareMatrix<T, D2, S2>,
1190 rhs: &Matrix<T, R3, C3, S3>,
1191 beta: T,
1192 ) where
1193 D2: Dim,
1194 R3: Dim,
1195 C3: Dim,
1196 S2: Storage<T, D2, D2>,
1197 S3: Storage<T, R3, C3>,
1198 ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
1199 DefaultAllocator: Allocator<D2>,
1200 {
1201 // TODO: would it be useful to avoid the zero-initialization of the workspace data?
1202 let mut work = Vector::zeros_generic(mid.shape_generic().0, Const::<1>);
1203 self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
1204 }
1205}