nalgebra/linalg/
inverse.rs

1use simba::scalar::ComplexField;
2
3use crate::base::allocator::Allocator;
4use crate::base::dimension::Dim;
5use crate::base::storage::{Storage, StorageMut};
6use crate::base::{DefaultAllocator, OMatrix, SquareMatrix};
7
8use crate::linalg::lu;
9
10impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S> {
11    /// Attempts to invert this square matrix.
12    ///
13    /// # Panics
14    ///
15    /// Panics if `self` isn’t a square matrix.
16    #[inline]
17    #[must_use = "Did you mean to use try_inverse_mut()?"]
18    pub fn try_inverse(self) -> Option<OMatrix<T, D, D>>
19    where
20        DefaultAllocator: Allocator<D, D>,
21    {
22        let mut me = self.into_owned();
23        if me.try_inverse_mut() {
24            Some(me)
25        } else {
26            None
27        }
28    }
29}
30
31impl<T: ComplexField, D: Dim, S: StorageMut<T, D, D>> SquareMatrix<T, D, S> {
32    /// Attempts to invert this square matrix in-place. Returns `false` and leaves `self` untouched if
33    /// inversion fails.
34    ///
35    /// # Panics
36    ///
37    /// Panics if `self` isn’t a square matrix.
38    #[inline]
39    pub fn try_inverse_mut(&mut self) -> bool
40    where
41        DefaultAllocator: Allocator<D, D>,
42    {
43        assert!(self.is_square(), "Unable to invert a non-square matrix.");
44
45        let dim = self.shape().0;
46
47        unsafe {
48            match dim {
49                0 => true,
50                1 => {
51                    let determinant = self.get_unchecked((0, 0)).clone();
52                    if determinant.is_zero() {
53                        false
54                    } else {
55                        *self.get_unchecked_mut((0, 0)) = T::one() / determinant;
56                        true
57                    }
58                }
59                2 => {
60                    let m11 = self.get_unchecked((0, 0)).clone();
61                    let m12 = self.get_unchecked((0, 1)).clone();
62                    let m21 = self.get_unchecked((1, 0)).clone();
63                    let m22 = self.get_unchecked((1, 1)).clone();
64
65                    let determinant = m11.clone() * m22.clone() - m21.clone() * m12.clone();
66
67                    if determinant.is_zero() {
68                        false
69                    } else {
70                        *self.get_unchecked_mut((0, 0)) = m22 / determinant.clone();
71                        *self.get_unchecked_mut((0, 1)) = -m12 / determinant.clone();
72
73                        *self.get_unchecked_mut((1, 0)) = -m21 / determinant.clone();
74                        *self.get_unchecked_mut((1, 1)) = m11 / determinant;
75
76                        true
77                    }
78                }
79                3 => {
80                    let m11 = self.get_unchecked((0, 0)).clone();
81                    let m12 = self.get_unchecked((0, 1)).clone();
82                    let m13 = self.get_unchecked((0, 2)).clone();
83
84                    let m21 = self.get_unchecked((1, 0)).clone();
85                    let m22 = self.get_unchecked((1, 1)).clone();
86                    let m23 = self.get_unchecked((1, 2)).clone();
87
88                    let m31 = self.get_unchecked((2, 0)).clone();
89                    let m32 = self.get_unchecked((2, 1)).clone();
90                    let m33 = self.get_unchecked((2, 2)).clone();
91
92                    let minor_m12_m23 = m22.clone() * m33.clone() - m32.clone() * m23.clone();
93                    let minor_m11_m23 = m21.clone() * m33.clone() - m31.clone() * m23.clone();
94                    let minor_m11_m22 = m21.clone() * m32.clone() - m31.clone() * m22.clone();
95
96                    let determinant = m11.clone() * minor_m12_m23.clone()
97                        - m12.clone() * minor_m11_m23.clone()
98                        + m13.clone() * minor_m11_m22.clone();
99
100                    if determinant.is_zero() {
101                        false
102                    } else {
103                        *self.get_unchecked_mut((0, 0)) = minor_m12_m23 / determinant.clone();
104                        *self.get_unchecked_mut((0, 1)) = (m13.clone() * m32.clone()
105                            - m33.clone() * m12.clone())
106                            / determinant.clone();
107                        *self.get_unchecked_mut((0, 2)) = (m12.clone() * m23.clone()
108                            - m22.clone() * m13.clone())
109                            / determinant.clone();
110
111                        *self.get_unchecked_mut((1, 0)) = -minor_m11_m23 / determinant.clone();
112                        *self.get_unchecked_mut((1, 1)) =
113                            (m11.clone() * m33 - m31.clone() * m13.clone()) / determinant.clone();
114                        *self.get_unchecked_mut((1, 2)) =
115                            (m13 * m21.clone() - m23 * m11.clone()) / determinant.clone();
116
117                        *self.get_unchecked_mut((2, 0)) = minor_m11_m22 / determinant.clone();
118                        *self.get_unchecked_mut((2, 1)) =
119                            (m12.clone() * m31 - m32 * m11.clone()) / determinant.clone();
120                        *self.get_unchecked_mut((2, 2)) = (m11 * m22 - m21 * m12) / determinant;
121
122                        true
123                    }
124                }
125                4 => {
126                    let oself = self.clone_owned();
127                    do_inverse4(&oself, self)
128                }
129                _ => {
130                    let oself = self.clone_owned();
131                    lu::try_invert_to(oself, self)
132                }
133            }
134        }
135    }
136}
137
138// NOTE: this is an extremely efficient, loop-unrolled matrix inverse from MESA (MIT licensed).
139fn do_inverse4<T: ComplexField, D: Dim, S: StorageMut<T, D, D>>(
140    m: &OMatrix<T, D, D>,
141    out: &mut SquareMatrix<T, D, S>,
142) -> bool
143where
144    DefaultAllocator: Allocator<D, D>,
145{
146    let m = m.as_slice();
147
148    let cofactor00 = m[5].clone() * m[10].clone() * m[15].clone()
149        - m[5].clone() * m[11].clone() * m[14].clone()
150        - m[9].clone() * m[6].clone() * m[15].clone()
151        + m[9].clone() * m[7].clone() * m[14].clone()
152        + m[13].clone() * m[6].clone() * m[11].clone()
153        - m[13].clone() * m[7].clone() * m[10].clone();
154
155    let cofactor01 = -m[4].clone() * m[10].clone() * m[15].clone()
156        + m[4].clone() * m[11].clone() * m[14].clone()
157        + m[8].clone() * m[6].clone() * m[15].clone()
158        - m[8].clone() * m[7].clone() * m[14].clone()
159        - m[12].clone() * m[6].clone() * m[11].clone()
160        + m[12].clone() * m[7].clone() * m[10].clone();
161
162    let cofactor02 = m[4].clone() * m[9].clone() * m[15].clone()
163        - m[4].clone() * m[11].clone() * m[13].clone()
164        - m[8].clone() * m[5].clone() * m[15].clone()
165        + m[8].clone() * m[7].clone() * m[13].clone()
166        + m[12].clone() * m[5].clone() * m[11].clone()
167        - m[12].clone() * m[7].clone() * m[9].clone();
168
169    let cofactor03 = -m[4].clone() * m[9].clone() * m[14].clone()
170        + m[4].clone() * m[10].clone() * m[13].clone()
171        + m[8].clone() * m[5].clone() * m[14].clone()
172        - m[8].clone() * m[6].clone() * m[13].clone()
173        - m[12].clone() * m[5].clone() * m[10].clone()
174        + m[12].clone() * m[6].clone() * m[9].clone();
175
176    let det = m[0].clone() * cofactor00.clone()
177        + m[1].clone() * cofactor01.clone()
178        + m[2].clone() * cofactor02.clone()
179        + m[3].clone() * cofactor03.clone();
180
181    if det.is_zero() {
182        return false;
183    }
184    out[(0, 0)] = cofactor00;
185
186    out[(1, 0)] = -m[1].clone() * m[10].clone() * m[15].clone()
187        + m[1].clone() * m[11].clone() * m[14].clone()
188        + m[9].clone() * m[2].clone() * m[15].clone()
189        - m[9].clone() * m[3].clone() * m[14].clone()
190        - m[13].clone() * m[2].clone() * m[11].clone()
191        + m[13].clone() * m[3].clone() * m[10].clone();
192
193    out[(2, 0)] = m[1].clone() * m[6].clone() * m[15].clone()
194        - m[1].clone() * m[7].clone() * m[14].clone()
195        - m[5].clone() * m[2].clone() * m[15].clone()
196        + m[5].clone() * m[3].clone() * m[14].clone()
197        + m[13].clone() * m[2].clone() * m[7].clone()
198        - m[13].clone() * m[3].clone() * m[6].clone();
199
200    out[(3, 0)] = -m[1].clone() * m[6].clone() * m[11].clone()
201        + m[1].clone() * m[7].clone() * m[10].clone()
202        + m[5].clone() * m[2].clone() * m[11].clone()
203        - m[5].clone() * m[3].clone() * m[10].clone()
204        - m[9].clone() * m[2].clone() * m[7].clone()
205        + m[9].clone() * m[3].clone() * m[6].clone();
206
207    out[(0, 1)] = cofactor01;
208
209    out[(1, 1)] = m[0].clone() * m[10].clone() * m[15].clone()
210        - m[0].clone() * m[11].clone() * m[14].clone()
211        - m[8].clone() * m[2].clone() * m[15].clone()
212        + m[8].clone() * m[3].clone() * m[14].clone()
213        + m[12].clone() * m[2].clone() * m[11].clone()
214        - m[12].clone() * m[3].clone() * m[10].clone();
215
216    out[(2, 1)] = -m[0].clone() * m[6].clone() * m[15].clone()
217        + m[0].clone() * m[7].clone() * m[14].clone()
218        + m[4].clone() * m[2].clone() * m[15].clone()
219        - m[4].clone() * m[3].clone() * m[14].clone()
220        - m[12].clone() * m[2].clone() * m[7].clone()
221        + m[12].clone() * m[3].clone() * m[6].clone();
222
223    out[(3, 1)] = m[0].clone() * m[6].clone() * m[11].clone()
224        - m[0].clone() * m[7].clone() * m[10].clone()
225        - m[4].clone() * m[2].clone() * m[11].clone()
226        + m[4].clone() * m[3].clone() * m[10].clone()
227        + m[8].clone() * m[2].clone() * m[7].clone()
228        - m[8].clone() * m[3].clone() * m[6].clone();
229
230    out[(0, 2)] = cofactor02;
231
232    out[(1, 2)] = -m[0].clone() * m[9].clone() * m[15].clone()
233        + m[0].clone() * m[11].clone() * m[13].clone()
234        + m[8].clone() * m[1].clone() * m[15].clone()
235        - m[8].clone() * m[3].clone() * m[13].clone()
236        - m[12].clone() * m[1].clone() * m[11].clone()
237        + m[12].clone() * m[3].clone() * m[9].clone();
238
239    out[(2, 2)] = m[0].clone() * m[5].clone() * m[15].clone()
240        - m[0].clone() * m[7].clone() * m[13].clone()
241        - m[4].clone() * m[1].clone() * m[15].clone()
242        + m[4].clone() * m[3].clone() * m[13].clone()
243        + m[12].clone() * m[1].clone() * m[7].clone()
244        - m[12].clone() * m[3].clone() * m[5].clone();
245
246    out[(0, 3)] = cofactor03;
247
248    out[(3, 2)] = -m[0].clone() * m[5].clone() * m[11].clone()
249        + m[0].clone() * m[7].clone() * m[9].clone()
250        + m[4].clone() * m[1].clone() * m[11].clone()
251        - m[4].clone() * m[3].clone() * m[9].clone()
252        - m[8].clone() * m[1].clone() * m[7].clone()
253        + m[8].clone() * m[3].clone() * m[5].clone();
254
255    out[(1, 3)] = m[0].clone() * m[9].clone() * m[14].clone()
256        - m[0].clone() * m[10].clone() * m[13].clone()
257        - m[8].clone() * m[1].clone() * m[14].clone()
258        + m[8].clone() * m[2].clone() * m[13].clone()
259        + m[12].clone() * m[1].clone() * m[10].clone()
260        - m[12].clone() * m[2].clone() * m[9].clone();
261
262    out[(2, 3)] = -m[0].clone() * m[5].clone() * m[14].clone()
263        + m[0].clone() * m[6].clone() * m[13].clone()
264        + m[4].clone() * m[1].clone() * m[14].clone()
265        - m[4].clone() * m[2].clone() * m[13].clone()
266        - m[12].clone() * m[1].clone() * m[6].clone()
267        + m[12].clone() * m[2].clone() * m[5].clone();
268
269    out[(3, 3)] = m[0].clone() * m[5].clone() * m[10].clone()
270        - m[0].clone() * m[6].clone() * m[9].clone()
271        - m[4].clone() * m[1].clone() * m[10].clone()
272        + m[4].clone() * m[2].clone() * m[9].clone()
273        + m[8].clone() * m[1].clone() * m[6].clone()
274        - m[8].clone() * m[2].clone() * m[5].clone();
275
276    let inv_det = T::one() / det;
277
278    for j in 0..4 {
279        for i in 0..4 {
280            out[(i, j)] *= inv_det.clone();
281        }
282    }
283    true
284}