nalgebra/linalg/bidiagonal.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
6use crate::dimension::{Const, Dim, DimDiff, DimMin, DimMinimum, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::geometry::Reflection;
10use crate::linalg::householder;
11use crate::num::Zero;
12use std::mem::MaybeUninit;
13
14/// The bidiagonalization of a general matrix.
15#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
16#[cfg_attr(
17    feature = "serde-serialize-no-std",
18    serde(bound(serialize = "DimMinimum<R, C>: DimSub<U1>,
19         DefaultAllocator: Allocator<R, C>             +
20                           Allocator<DimMinimum<R, C>> +
21                           Allocator<DimDiff<DimMinimum<R, C>, U1>>,
22         OMatrix<T, R, C>: Serialize,
23         OVector<T, DimMinimum<R, C>>: Serialize,
24         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Serialize"))
25)]
26#[cfg_attr(
27    feature = "serde-serialize-no-std",
28    serde(bound(deserialize = "DimMinimum<R, C>: DimSub<U1>,
29         DefaultAllocator: Allocator<R, C>             +
30                           Allocator<DimMinimum<R, C>> +
31                           Allocator<DimDiff<DimMinimum<R, C>, U1>>,
32         OMatrix<T, R, C>: Deserialize<'de>,
33         OVector<T, DimMinimum<R, C>>: Deserialize<'de>,
34         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Deserialize<'de>"))
35)]
36#[derive(Clone, Debug)]
37pub struct Bidiagonal<T: ComplexField, R: DimMin<C>, C: Dim>
38where
39    DimMinimum<R, C>: DimSub<U1>,
40    DefaultAllocator:
41        Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
42{
43    // TODO: perhaps we should pack the axes into different vectors so that axes for `v_t` are
44    // contiguous. This prevents some useless copies.
45    uv: OMatrix<T, R, C>,
46    /// The diagonal elements of the decomposed matrix.
47    diagonal: OVector<T, DimMinimum<R, C>>,
48    /// The off-diagonal elements of the decomposed matrix.
49    off_diagonal: OVector<T, DimDiff<DimMinimum<R, C>, U1>>,
50    upper_diagonal: bool,
51}
52
53impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for Bidiagonal<T, R, C>
54where
55    DimMinimum<R, C>: DimSub<U1>,
56    DefaultAllocator:
57        Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
58    OMatrix<T, R, C>: Copy,
59    OVector<T, DimMinimum<R, C>>: Copy,
60    OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Copy,
61{
62}
63
64impl<T: ComplexField, R: DimMin<C>, C: Dim> Bidiagonal<T, R, C>
65where
66    DimMinimum<R, C>: DimSub<U1>,
67    DefaultAllocator: Allocator<R, C>
68        + Allocator<C>
69        + Allocator<R>
70        + Allocator<DimMinimum<R, C>>
71        + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
72{
73    /// Computes the Bidiagonal decomposition using householder reflections.
74    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
75        let (nrows, ncols) = matrix.shape_generic();
76        let min_nrows_ncols = nrows.min(ncols);
77        let dim = min_nrows_ncols.value();
78        assert!(
79            dim != 0,
80            "Cannot compute the bidiagonalization of an empty matrix."
81        );
82
83        let mut diagonal = Matrix::uninit(min_nrows_ncols, Const::<1>);
84        let mut off_diagonal = Matrix::uninit(min_nrows_ncols.sub(Const::<1>), Const::<1>);
85        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
86        let mut work = Matrix::zeros_generic(nrows, Const::<1>);
87
88        let upper_diagonal = nrows.value() >= ncols.value();
89        if upper_diagonal {
90            for ite in 0..dim - 1 {
91                diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
92                    &mut matrix,
93                    ite,
94                    0,
95                    None,
96                ));
97                off_diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
98                    &mut matrix,
99                    &mut axis_packed,
100                    &mut work,
101                    ite,
102                    1,
103                ));
104            }
105
106            diagonal[dim - 1] = MaybeUninit::new(householder::clear_column_unchecked(
107                &mut matrix,
108                dim - 1,
109                0,
110                None,
111            ));
112        } else {
113            for ite in 0..dim - 1 {
114                diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
115                    &mut matrix,
116                    &mut axis_packed,
117                    &mut work,
118                    ite,
119                    0,
120                ));
121                off_diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
122                    &mut matrix,
123                    ite,
124                    1,
125                    None,
126                ));
127            }
128
129            diagonal[dim - 1] = MaybeUninit::new(householder::clear_row_unchecked(
130                &mut matrix,
131                &mut axis_packed,
132                &mut work,
133                dim - 1,
134                0,
135            ));
136        }
137
138        // Safety: diagonal and off_diagonal have been fully initialized.
139        let (diagonal, off_diagonal) =
140            unsafe { (diagonal.assume_init(), off_diagonal.assume_init()) };
141
142        Bidiagonal {
143            uv: matrix,
144            diagonal,
145            off_diagonal,
146            upper_diagonal,
147        }
148    }
149
150    /// Indicates whether this decomposition contains an upper-diagonal matrix.
151    #[inline]
152    #[must_use]
153    pub fn is_upper_diagonal(&self) -> bool {
154        self.upper_diagonal
155    }
156
157    #[inline]
158    fn axis_shift(&self) -> (usize, usize) {
159        if self.upper_diagonal {
160            (0, 1)
161        } else {
162            (1, 0)
163        }
164    }
165
166    /// Unpacks this decomposition into its three matrix factors `(U, D, V^t)`.
167    ///
168    /// The decomposed matrix `M` is equal to `U * D * V^t`.
169    #[inline]
170    pub fn unpack(
171        self,
172    ) -> (
173        OMatrix<T, R, DimMinimum<R, C>>,
174        OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>,
175        OMatrix<T, DimMinimum<R, C>, C>,
176    )
177    where
178        DefaultAllocator: Allocator<DimMinimum<R, C>, DimMinimum<R, C>>
179            + Allocator<R, DimMinimum<R, C>>
180            + Allocator<DimMinimum<R, C>, C>,
181    {
182        // TODO: optimize by calling a reallocator.
183        (self.u(), self.d(), self.v_t())
184    }
185
186    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
187    #[inline]
188    #[must_use]
189    pub fn d(&self) -> OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>
190    where
191        DefaultAllocator: Allocator<DimMinimum<R, C>, DimMinimum<R, C>>,
192    {
193        let (nrows, ncols) = self.uv.shape_generic();
194
195        let d = nrows.min(ncols);
196        let mut res = OMatrix::identity_generic(d, d);
197        res.set_partial_diagonal(
198            self.diagonal
199                .iter()
200                .map(|e| T::from_real(e.clone().modulus())),
201        );
202
203        let start = self.axis_shift();
204        res.view_mut(start, (d.value() - 1, d.value() - 1))
205            .set_partial_diagonal(
206                self.off_diagonal
207                    .iter()
208                    .map(|e| T::from_real(e.clone().modulus())),
209            );
210        res
211    }
212
213    /// Computes the orthogonal matrix `U` of this `U * D * V` decomposition.
214    // TODO: code duplication with householder::assemble_q.
215    // Except that we are returning a rectangular matrix here.
216    #[must_use]
217    pub fn u(&self) -> OMatrix<T, R, DimMinimum<R, C>>
218    where
219        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
220    {
221        let (nrows, ncols) = self.uv.shape_generic();
222
223        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
224        let dim = self.diagonal.len();
225        let shift = self.axis_shift().0;
226
227        for i in (0..dim - shift).rev() {
228            let axis = self.uv.view_range(i + shift.., i);
229
230            // Sometimes, the axis might have a zero magnitude.
231            if axis.norm_squared().is_zero() {
232                continue;
233            }
234            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
235
236            let mut res_rows = res.view_range_mut(i + shift.., i..);
237
238            let sign = if self.upper_diagonal {
239                self.diagonal[i].clone().signum()
240            } else {
241                self.off_diagonal[i].clone().signum()
242            };
243
244            refl.reflect_with_sign(&mut res_rows, sign);
245        }
246
247        res
248    }
249
250    /// Computes the orthogonal matrix `V_t` of this `U * D * V_t` decomposition.
251    #[must_use]
252    pub fn v_t(&self) -> OMatrix<T, DimMinimum<R, C>, C>
253    where
254        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
255    {
256        let (nrows, ncols) = self.uv.shape_generic();
257        let min_nrows_ncols = nrows.min(ncols);
258
259        let mut res = Matrix::identity_generic(min_nrows_ncols, ncols);
260        let mut work = Matrix::zeros_generic(min_nrows_ncols, Const::<1>);
261        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
262
263        let shift = self.axis_shift().1;
264
265        for i in (0..min_nrows_ncols.value() - shift).rev() {
266            let axis = self.uv.view_range(i, i + shift..);
267            let mut axis_packed = axis_packed.rows_range_mut(i + shift..);
268            axis_packed.tr_copy_from(&axis);
269
270            // Sometimes, the axis might have a zero magnitude.
271            if axis_packed.norm_squared().is_zero() {
272                continue;
273            }
274            let refl = Reflection::new(Unit::new_unchecked(axis_packed), T::zero());
275
276            let mut res_rows = res.view_range_mut(i.., i + shift..);
277
278            let sign = if self.upper_diagonal {
279                self.off_diagonal[i].clone().signum()
280            } else {
281                self.diagonal[i].clone().signum()
282            };
283
284            refl.reflect_rows_with_sign(&mut res_rows, &mut work.rows_range_mut(i..), sign);
285        }
286
287        res
288    }
289
290    /// The diagonal part of this decomposed matrix.
291    #[must_use]
292    pub fn diagonal(&self) -> OVector<T::RealField, DimMinimum<R, C>>
293    where
294        DefaultAllocator: Allocator<DimMinimum<R, C>>,
295    {
296        self.diagonal.map(|e| e.modulus())
297    }
298
299    /// The off-diagonal part of this decomposed matrix.
300    #[must_use]
301    pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<DimMinimum<R, C>, U1>>
302    where
303        DefaultAllocator: Allocator<DimDiff<DimMinimum<R, C>, U1>>,
304    {
305        self.off_diagonal.map(|e| e.modulus())
306    }
307
308    #[doc(hidden)]
309    pub fn uv_internal(&self) -> &OMatrix<T, R, C> {
310        &self.uv
311    }
312}
313
314// impl<T: ComplexField, D: DimMin<D, Output = D> + DimSub<Dyn>> Bidiagonal<T, D, D>
315//     where DefaultAllocator: Allocator<D, D> +
316//                             Allocator<D> {
317//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
318//     pub fn solve<R2: Dim, C2: Dim, S2>(&self, b: &Matrix<T, R2, C2, S2>) -> OMatrix<T, R2, C2>
319//         where S2: StorageMut<T, R2, C2>,
320//               ShapeConstraint: SameNumberOfRows<R2, D> {
321//         let mut res = b.clone_owned();
322//         self.solve_mut(&mut res);
323//         res
324//     }
325//
326//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
327//     pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
328//         where S2: StorageMut<T, R2, C2>,
329//               ShapeConstraint: SameNumberOfRows<R2, D> {
330//
331//         assert_eq!(self.uv.nrows(), b.nrows(), "Bidiagonal solve matrix dimension mismatch.");
332//         assert!(self.uv.is_square(), "Bidiagonal solve: unable to solve a non-square system.");
333//
334//         self.q_tr_mul(b);
335//         self.solve_upper_triangular_mut(b);
336//     }
337//
338//     // TODO: duplicate code from the `solve` module.
339//     fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
340//         where S2: StorageMut<T, R2, C2>,
341//               ShapeConstraint: SameNumberOfRows<R2, D> {
342//
343//         let dim  = self.uv.nrows();
344//
345//         for k in 0 .. b.ncols() {
346//             let mut b = b.column_mut(k);
347//             for i in (0 .. dim).rev() {
348//                 let coeff;
349//
350//                 unsafe {
351//                     let diag = *self.diag.vget_unchecked(i);
352//                     coeff = *b.vget_unchecked(i) / diag;
353//                     *b.vget_unchecked_mut(i) = coeff;
354//                 }
355//
356//                 b.rows_range_mut(.. i).axpy(-coeff, &self.uv.view_range(.. i, i), T::one());
357//             }
358//         }
359//     }
360//
361//     /// Computes the inverse of the decomposed matrix.
362//     pub fn inverse(&self) -> OMatrix<T, D, D> {
363//         assert!(self.uv.is_square(), "Bidiagonal inverse: unable to compute the inverse of a non-square matrix.");
364//
365//         // TODO: is there a less naive method ?
366//         let (nrows, ncols) = self.uv.shape_generic();
367//         let mut res = OMatrix::identity_generic(nrows, ncols);
368//         self.solve_mut(&mut res);
369//         res
370//     }
371//
372//     // /// Computes the determinant of the decomposed matrix.
373//     // pub fn determinant(&self) -> T {
374//     //     let dim = self.uv.nrows();
375//     //     assert!(self.uv.is_square(), "Bidiagonal determinant: unable to compute the determinant of a non-square matrix.");
376//
377//     //     let mut res = T::one();
378//     //     for i in 0 .. dim {
379//     //         res *= unsafe { *self.diag.vget_unchecked(i) };
380//     //     }
381//
382//     //     res self.q_determinant()
383//     // }
384// }