nalgebra/linalg/
bidiagonal.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
6use crate::dimension::{Const, Dim, DimDiff, DimMin, DimMinimum, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::geometry::Reflection;
10use crate::linalg::householder;
11use crate::num::Zero;
12use std::mem::MaybeUninit;
13
/// The bidiagonalization of a general matrix.
///
/// Holds the matrix transformed in place by `Bidiagonal::new`, together with the diagonal and
/// off-diagonal coefficients returned by the successive Householder eliminations.
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DimMinimum<R, C>: DimSub<U1>,
         DefaultAllocator: Allocator<R, C>             +
                           Allocator<DimMinimum<R, C>> +
                           Allocator<DimDiff<DimMinimum<R, C>, U1>>,
         OMatrix<T, R, C>: Serialize,
         OVector<T, DimMinimum<R, C>>: Serialize,
         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DimMinimum<R, C>: DimSub<U1>,
         DefaultAllocator: Allocator<R, C>             +
                           Allocator<DimMinimum<R, C>> +
                           Allocator<DimDiff<DimMinimum<R, C>, U1>>,
         OMatrix<T, R, C>: Deserialize<'de>,
         OVector<T, DimMinimum<R, C>>: Deserialize<'de>,
         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct Bidiagonal<T: ComplexField, R: DimMin<C>, C: Dim>
where
    DimMinimum<R, C>: DimSub<U1>,
    DefaultAllocator:
        Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
{
    // TODO: perhaps we should pack the axes into different vectors so that axes for `v_t` are
    // contiguous. This prevents some useless copies.
    /// The input matrix after the in-place Householder eliminations performed by `new`.
    /// `u()` reads its reflection axes from the columns of this matrix, `v_t()` from its rows.
    uv: OMatrix<T, R, C>,
    /// The diagonal elements of the decomposed matrix.
    diagonal: OVector<T, DimMinimum<R, C>>,
    /// The off-diagonal elements of the decomposed matrix.
    off_diagonal: OVector<T, DimDiff<DimMinimum<R, C>, U1>>,
    /// Set by `new` to `true` when the input has at least as many rows as columns. The
    /// off-diagonal coefficients are then placed above the main diagonal of `D`, and below it
    /// otherwise.
    upper_diagonal: bool,
}
53
// `Copy` is only provided when the three stored buffers (`uv`, `diagonal` and `off_diagonal`)
// are themselves `Copy`; the bounds below spell out exactly that requirement.
impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for Bidiagonal<T, R, C>
where
    DimMinimum<R, C>: DimSub<U1>,
    DefaultAllocator:
        Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
    OMatrix<T, R, C>: Copy,
    OVector<T, DimMinimum<R, C>>: Copy,
    OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Copy,
{
}
64
65impl<T: ComplexField, R: DimMin<C>, C: Dim> Bidiagonal<T, R, C>
66where
67    DimMinimum<R, C>: DimSub<U1>,
68    DefaultAllocator: Allocator<R, C>
69        + Allocator<C>
70        + Allocator<R>
71        + Allocator<DimMinimum<R, C>>
72        + Allocator<DimDiff<DimMinimum<R, C>, U1>>,
73{
74    /// Computes the Bidiagonal decomposition using householder reflections.
75    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
76        let (nrows, ncols) = matrix.shape_generic();
77        let min_nrows_ncols = nrows.min(ncols);
78        let dim = min_nrows_ncols.value();
79        assert!(
80            dim != 0,
81            "Cannot compute the bidiagonalization of an empty matrix."
82        );
83
84        let mut diagonal = Matrix::uninit(min_nrows_ncols, Const::<1>);
85        let mut off_diagonal = Matrix::uninit(min_nrows_ncols.sub(Const::<1>), Const::<1>);
86        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
87        let mut work = Matrix::zeros_generic(nrows, Const::<1>);
88
89        let upper_diagonal = nrows.value() >= ncols.value();
90        if upper_diagonal {
91            for ite in 0..dim - 1 {
92                diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
93                    &mut matrix,
94                    ite,
95                    0,
96                    None,
97                ));
98                off_diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
99                    &mut matrix,
100                    &mut axis_packed,
101                    &mut work,
102                    ite,
103                    1,
104                ));
105            }
106
107            diagonal[dim - 1] = MaybeUninit::new(householder::clear_column_unchecked(
108                &mut matrix,
109                dim - 1,
110                0,
111                None,
112            ));
113        } else {
114            for ite in 0..dim - 1 {
115                diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
116                    &mut matrix,
117                    &mut axis_packed,
118                    &mut work,
119                    ite,
120                    0,
121                ));
122                off_diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
123                    &mut matrix,
124                    ite,
125                    1,
126                    None,
127                ));
128            }
129
130            diagonal[dim - 1] = MaybeUninit::new(householder::clear_row_unchecked(
131                &mut matrix,
132                &mut axis_packed,
133                &mut work,
134                dim - 1,
135                0,
136            ));
137        }
138
139        // Safety: diagonal and off_diagonal have been fully initialized.
140        let (diagonal, off_diagonal) =
141            unsafe { (diagonal.assume_init(), off_diagonal.assume_init()) };
142
143        Bidiagonal {
144            uv: matrix,
145            diagonal,
146            off_diagonal,
147            upper_diagonal,
148        }
149    }
150
    /// Indicates whether this decomposition contains an upper-diagonal matrix.
    ///
    /// `Bidiagonal::new` sets this flag when the decomposed matrix has at least as many rows
    /// as columns.
    #[inline]
    #[must_use]
    pub const fn is_upper_diagonal(&self) -> bool {
        self.upper_diagonal
    }
157
158    #[inline]
159    const fn axis_shift(&self) -> (usize, usize) {
160        if self.upper_diagonal { (0, 1) } else { (1, 0) }
161    }
162
163    /// Unpacks this decomposition into its three matrix factors `(U, D, V^t)`.
164    ///
165    /// The decomposed matrix `M` is equal to `U * D * V^t`.
166    #[inline]
167    pub fn unpack(
168        self,
169    ) -> (
170        OMatrix<T, R, DimMinimum<R, C>>,
171        OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>,
172        OMatrix<T, DimMinimum<R, C>, C>,
173    )
174    where
175        DefaultAllocator: Allocator<DimMinimum<R, C>, DimMinimum<R, C>>
176            + Allocator<R, DimMinimum<R, C>>
177            + Allocator<DimMinimum<R, C>, C>,
178    {
179        // TODO: optimize by calling a reallocator.
180        (self.u(), self.d(), self.v_t())
181    }
182
183    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
184    #[inline]
185    #[must_use]
186    pub fn d(&self) -> OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>
187    where
188        DefaultAllocator: Allocator<DimMinimum<R, C>, DimMinimum<R, C>>,
189    {
190        let (nrows, ncols) = self.uv.shape_generic();
191
192        let d = nrows.min(ncols);
193        let mut res = OMatrix::identity_generic(d, d);
194        res.set_partial_diagonal(
195            self.diagonal
196                .iter()
197                .map(|e| T::from_real(e.clone().modulus())),
198        );
199
200        let start = self.axis_shift();
201        res.view_mut(start, (d.value() - 1, d.value() - 1))
202            .set_partial_diagonal(
203                self.off_diagonal
204                    .iter()
205                    .map(|e| T::from_real(e.clone().modulus())),
206            );
207        res
208    }
209
210    /// Computes the orthogonal matrix `U` of this `U * D * V` decomposition.
211    // TODO: code duplication with householder::assemble_q.
212    // Except that we are returning a rectangular matrix here.
213    #[must_use]
214    pub fn u(&self) -> OMatrix<T, R, DimMinimum<R, C>>
215    where
216        DefaultAllocator: Allocator<R, DimMinimum<R, C>>,
217    {
218        let (nrows, ncols) = self.uv.shape_generic();
219
220        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
221        let dim = self.diagonal.len();
222        let shift = self.axis_shift().0;
223
224        for i in (0..dim - shift).rev() {
225            let axis = self.uv.view_range(i + shift.., i);
226
227            // Sometimes, the axis might have a zero magnitude.
228            if axis.norm_squared().is_zero() {
229                continue;
230            }
231            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
232
233            let mut res_rows = res.view_range_mut(i + shift.., i..);
234
235            let sign = if self.upper_diagonal {
236                self.diagonal[i].clone().signum()
237            } else {
238                self.off_diagonal[i].clone().signum()
239            };
240
241            refl.reflect_with_sign(&mut res_rows, sign);
242        }
243
244        res
245    }
246
247    /// Computes the orthogonal matrix `V_t` of this `U * D * V_t` decomposition.
248    #[must_use]
249    pub fn v_t(&self) -> OMatrix<T, DimMinimum<R, C>, C>
250    where
251        DefaultAllocator: Allocator<DimMinimum<R, C>, C>,
252    {
253        let (nrows, ncols) = self.uv.shape_generic();
254        let min_nrows_ncols = nrows.min(ncols);
255
256        let mut res = Matrix::identity_generic(min_nrows_ncols, ncols);
257        let mut work = Matrix::zeros_generic(min_nrows_ncols, Const::<1>);
258        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
259
260        let shift = self.axis_shift().1;
261
262        for i in (0..min_nrows_ncols.value() - shift).rev() {
263            let axis = self.uv.view_range(i, i + shift..);
264            let mut axis_packed = axis_packed.rows_range_mut(i + shift..);
265            axis_packed.tr_copy_from(&axis);
266
267            // Sometimes, the axis might have a zero magnitude.
268            if axis_packed.norm_squared().is_zero() {
269                continue;
270            }
271            let refl = Reflection::new(Unit::new_unchecked(axis_packed), T::zero());
272
273            let mut res_rows = res.view_range_mut(i.., i + shift..);
274
275            let sign = if self.upper_diagonal {
276                self.off_diagonal[i].clone().signum()
277            } else {
278                self.diagonal[i].clone().signum()
279            };
280
281            refl.reflect_rows_with_sign(&mut res_rows, &mut work.rows_range_mut(i..), sign);
282        }
283
284        res
285    }
286
287    /// The diagonal part of this decomposed matrix.
288    #[must_use]
289    pub fn diagonal(&self) -> OVector<T::RealField, DimMinimum<R, C>>
290    where
291        DefaultAllocator: Allocator<DimMinimum<R, C>>,
292    {
293        self.diagonal.map(|e| e.modulus())
294    }
295
296    /// The off-diagonal part of this decomposed matrix.
297    #[must_use]
298    pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<DimMinimum<R, C>, U1>>
299    where
300        DefaultAllocator: Allocator<DimDiff<DimMinimum<R, C>, U1>>,
301    {
302        self.off_diagonal.map(|e| e.modulus())
303    }
304
    /// Borrows the packed matrix `uv` stored by this decomposition.
    ///
    /// Hidden from the generated documentation.
    #[doc(hidden)]
    pub const fn uv_internal(&self) -> &OMatrix<T, R, C> {
        &self.uv
    }
309}
310
311// impl<T: ComplexField, D: DimMin<D, Output = D> + DimSub<Dyn>> Bidiagonal<T, D, D>
312//     where DefaultAllocator: Allocator<D, D> +
313//                             Allocator<D> {
314//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
315//     pub fn solve<R2: Dim, C2: Dim, S2>(&self, b: &Matrix<T, R2, C2, S2>) -> OMatrix<T, R2, C2>
316//         where S2: StorageMut<T, R2, C2>,
317//               ShapeConstraint: SameNumberOfRows<R2, D> {
318//         let mut res = b.clone_owned();
319//         self.solve_mut(&mut res);
320//         res
321//     }
322//
323//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
324//     pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
325//         where S2: StorageMut<T, R2, C2>,
326//               ShapeConstraint: SameNumberOfRows<R2, D> {
327//
328//         assert_eq!(self.uv.nrows(), b.nrows(), "Bidiagonal solve matrix dimension mismatch.");
329//         assert!(self.uv.is_square(), "Bidiagonal solve: unable to solve a non-square system.");
330//
331//         self.q_tr_mul(b);
332//         self.solve_upper_triangular_mut(b);
333//     }
334//
335//     // TODO: duplicate code from the `solve` module.
336//     fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
337//         where S2: StorageMut<T, R2, C2>,
338//               ShapeConstraint: SameNumberOfRows<R2, D> {
339//
340//         let dim  = self.uv.nrows();
341//
342//         for k in 0 .. b.ncols() {
343//             let mut b = b.column_mut(k);
344//             for i in (0 .. dim).rev() {
345//                 let coeff;
346//
347//                 unsafe {
348//                     let diag = *self.diag.vget_unchecked(i);
349//                     coeff = *b.vget_unchecked(i) / diag;
350//                     *b.vget_unchecked_mut(i) = coeff;
351//                 }
352//
353//                 b.rows_range_mut(.. i).axpy(-coeff, &self.uv.view_range(.. i, i), T::one());
354//             }
355//         }
356//     }
357//
358//     /// Computes the inverse of the decomposed matrix.
359//     pub fn inverse(&self) -> OMatrix<T, D, D> {
360//         assert!(self.uv.is_square(), "Bidiagonal inverse: unable to compute the inverse of a non-square matrix.");
361//
362//         // TODO: is there a less naive method ?
363//         let (nrows, ncols) = self.uv.shape_generic();
364//         let mut res = OMatrix::identity_generic(nrows, ncols);
365//         self.solve_mut(&mut res);
366//         res
367//     }
368//
369//     // /// Computes the determinant of the decomposed matrix.
370//     // pub fn determinant(&self) -> T {
371//     //     let dim = self.uv.nrows();
372//     //     assert!(self.uv.is_square(), "Bidiagonal determinant: unable to compute the determinant of a non-square matrix.");
373//
374//     //     let mut res = T::one();
375//     //     for i in 0 .. dim {
376//     //         res *= unsafe { *self.diag.vget_unchecked(i) };
377//     //     }
378//
379//     //     res self.q_determinant()
380//     // }
381// }