nalgebra/linalg/convolution.rs

1use std::cmp;
2
3use crate::base::allocator::Allocator;
4use crate::base::default_allocator::DefaultAllocator;
5use crate::base::dimension::{Const, Dim, DimAdd, DimDiff, DimSub, DimSum};
6use crate::storage::Storage;
7use crate::{OVector, RealField, U1, Vector, zero};
8
9impl<T: RealField, D1: Dim, S1: Storage<T, D1>> Vector<T, D1, S1> {
10    /// Returns the convolution of the target vector and a kernel.
11    ///
12    /// # Arguments
13    ///
14    /// * `kernel` - A Vector with size > 0
15    ///
16    /// # Errors
17    /// Inputs must satisfy `vector.len() >= kernel.len() > 0`.
18    ///
19    pub fn convolve_full<D2, S2>(
20        &self,
21        kernel: Vector<T, D2, S2>,
22    ) -> OVector<T, DimDiff<DimSum<D1, D2>, U1>>
23    where
24        D1: DimAdd<D2>,
25        D2: DimAdd<D1, Output = DimSum<D1, D2>>,
26        DimSum<D1, D2>: DimSub<U1>,
27        S2: Storage<T, D2>,
28        DefaultAllocator: Allocator<DimDiff<DimSum<D1, D2>, U1>>,
29    {
30        let vec = self.len();
31        let ker = kernel.len();
32
33        if ker == 0 || ker > vec {
34            panic!(
35                "convolve_full expects `self.len() >= kernel.len() > 0`, received {vec} and {ker} respectively."
36            );
37        }
38
39        let result_len = self
40            .data
41            .shape()
42            .0
43            .add(kernel.shape_generic().0)
44            .sub(Const::<1>);
45        let mut conv = OVector::zeros_generic(result_len, Const::<1>);
46
47        for i in 0..(vec + ker - 1) {
48            let u_i = if i > vec { i - ker } else { 0 };
49            let u_f = cmp::min(i, vec - 1);
50
51            if u_i == u_f {
52                conv[i] += self[u_i].clone() * kernel[i - u_i].clone();
53            } else {
54                for u in u_i..(u_f + 1) {
55                    if i - u < ker {
56                        conv[i] += self[u].clone() * kernel[i - u].clone();
57                    }
58                }
59            }
60        }
61        conv
62    }
63    /// Returns the convolution of the target vector and a kernel.
64    ///
65    /// The output convolution consists only of those elements that do not rely on the zero-padding.
66    /// # Arguments
67    ///
68    /// * `kernel` - A Vector with size > 0
69    ///
70    ///
71    /// # Errors
72    /// Inputs must satisfy `self.len() >= kernel.len() > 0`.
73    ///
74    pub fn convolve_valid<D2, S2>(
75        &self,
76        kernel: Vector<T, D2, S2>,
77    ) -> OVector<T, DimDiff<DimSum<D1, U1>, D2>>
78    where
79        D1: DimAdd<U1>,
80        D2: Dim,
81        DimSum<D1, U1>: DimSub<D2>,
82        S2: Storage<T, D2>,
83        DefaultAllocator: Allocator<DimDiff<DimSum<D1, U1>, D2>>,
84    {
85        let vec = self.len();
86        let ker = kernel.len();
87
88        if ker == 0 || ker > vec {
89            panic!(
90                "convolve_valid expects `self.len() >= kernel.len() > 0`, received {vec} and {ker} respectively."
91            );
92        }
93
94        let result_len = self
95            .data
96            .shape()
97            .0
98            .add(Const::<1>)
99            .sub(kernel.shape_generic().0);
100        let mut conv = OVector::zeros_generic(result_len, Const::<1>);
101
102        for i in 0..(vec - ker + 1) {
103            for j in 0..ker {
104                conv[i] += self[i + j].clone() * kernel[ker - j - 1].clone();
105            }
106        }
107        conv
108    }
109
110    /// Returns the convolution of the target vector and a kernel.
111    ///
112    /// The output convolution is the same size as vector, centered with respect to the ‘full’ output.
113    /// # Arguments
114    ///
115    /// * `kernel` - A Vector with size > 0
116    ///
117    /// # Errors
118    /// Inputs must satisfy `self.len() >= kernel.len() > 0`.
119    #[must_use]
120    pub fn convolve_same<D2, S2>(&self, kernel: Vector<T, D2, S2>) -> OVector<T, D1>
121    where
122        D2: Dim,
123        S2: Storage<T, D2>,
124        DefaultAllocator: Allocator<D1>,
125    {
126        let vec = self.len();
127        let ker = kernel.len();
128
129        if ker == 0 || ker > vec {
130            panic!(
131                "convolve_same expects `self.len() >= kernel.len() > 0`, received {vec} and {ker} respectively."
132            );
133        }
134
135        let mut conv = OVector::zeros_generic(self.shape_generic().0, Const::<1>);
136
137        for i in 0..vec {
138            for j in 0..ker {
139                let val = if i + j < 1 || i + j >= vec + 1 {
140                    zero::<T>()
141                } else {
142                    self[i + j - 1].clone()
143                };
144                conv[i] += val * kernel[ker - j - 1].clone();
145            }
146        }
147
148        conv
149    }
150}