ruzstd/encoding/
frame_compressor.rs

//! Utilities and interfaces for encoding an entire frame, allowing resources such as the matcher to be reused across multiple compressions.
2
3use alloc::vec::Vec;
4use core::convert::TryInto;
5
6use super::{
7    block_header::BlockHeader, blocks::compress_block, frame_header::FrameHeader,
8    match_generator::MatchGeneratorDriver, CompressionLevel, Matcher,
9};
10
11use crate::io::{Read, Write};
12
/// Size limit for an encoded compressed block.
///
/// Zstandard blocks are capped at 128 KiB; this limit sits 20 bytes below that cap.
/// When a compressed block reaches this size, `FrameCompressor::compress` stores the
/// block's data as a raw block instead.
/// NOTE(review): the reason for the 20-byte margin is not visible here — confirm.
const MAX_BLOCK_SIZE: usize = (1 << 17) - 20;
15
16/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
17///
18/// `FrameCompressor` will generally be used by:
19/// 1. Initializing a compressor by providing a buffer of data using `FrameCompressor::new()`
20/// 2. Starting compression and writing that compression into a vec using `FrameCompressor::begin`
21///
22/// # Examples
23/// ```
24/// use ruzstd::encoding::{FrameCompressor, CompressionLevel};
25/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
26/// let mut output = std::vec::Vec::new();
27/// // Initialize a compressor.
28/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
29/// compressor.set_source(mock_data);
30/// compressor.set_drain(&mut output);
31///
32/// // `compress` writes the compressed output into the provided buffer.
33/// compressor.compress();
34/// ```
35pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
36    uncompressed_data: Option<R>,
37    compressed_data: Option<W>,
38    compression_level: CompressionLevel,
39    match_generator: M,
40}
41
42impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
43    /// Create a new `FrameCompressor`
44    pub fn new(compression_level: CompressionLevel) -> Self {
45        Self {
46            uncompressed_data: None,
47            compressed_data: None,
48            compression_level,
49            match_generator: MatchGeneratorDriver::new(1024 * 128, 1),
50        }
51    }
52}
53
54impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
55    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
56    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
57        Self {
58            uncompressed_data: None,
59            compressed_data: None,
60            match_generator: matcher,
61            compression_level,
62        }
63    }
64
65    /// Before calling [FrameCompressor::compress] you need to set the source
66    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
67        self.uncompressed_data.replace(uncompressed_data)
68    }
69
70    /// Before calling [FrameCompressor::compress] you need to set the drain
71    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
72        self.compressed_data.replace(compressed_data)
73    }
74
75    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
76    ///
77    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
78    /// Also [Write::write_all] will be called on the drain after each block has been encoded.
79    ///
80    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
81    /// [Read::take] function
82    pub fn compress(&mut self) {
83        self.match_generator.reset(self.compression_level);
84        let source = self.uncompressed_data.as_mut().unwrap();
85        let drain = self.compressed_data.as_mut().unwrap();
86
87        let mut output = Vec::with_capacity(1024 * 130);
88        let output = &mut output;
89        let header = FrameHeader {
90            frame_content_size: None,
91            single_segment: false,
92            content_checksum: false,
93            dictionary_id: None,
94            window_size: Some(self.match_generator.window_size()),
95        };
96        header.serialize(output);
97
98        loop {
99            let mut uncompressed_data = self.match_generator.get_next_space();
100            let mut read_bytes = 0;
101            let last_block;
102            'read_loop: loop {
103                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
104                if new_bytes == 0 {
105                    last_block = true;
106                    break 'read_loop;
107                }
108                read_bytes += new_bytes;
109                if read_bytes == uncompressed_data.len() {
110                    last_block = false;
111                    break 'read_loop;
112                }
113            }
114            uncompressed_data.resize(read_bytes, 0);
115
116            // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know)
117            if uncompressed_data.is_empty() {
118                let header = BlockHeader {
119                    last_block: true,
120                    block_type: crate::blocks::block::BlockType::Raw,
121                    block_size: 0,
122                };
123                // Write the header, then the block
124                header.serialize(output);
125                drain.write_all(output).unwrap();
126                output.clear();
127                break;
128            }
129
130            match self.compression_level {
131                CompressionLevel::Uncompressed => {
132                    let header = BlockHeader {
133                        last_block,
134                        block_type: crate::blocks::block::BlockType::Raw,
135                        block_size: read_bytes.try_into().unwrap(),
136                    };
137                    // Write the header, then the block
138                    header.serialize(output);
139                    output.extend_from_slice(&uncompressed_data);
140                }
141                CompressionLevel::Fastest => {
142                    if uncompressed_data.iter().all(|x| uncompressed_data[0].eq(x)) {
143                        let rle_byte = uncompressed_data[0];
144                        self.match_generator.commit_space(uncompressed_data);
145                        self.match_generator.skip_matching();
146                        let header = BlockHeader {
147                            last_block,
148                            block_type: crate::blocks::block::BlockType::RLE,
149                            block_size: read_bytes.try_into().unwrap(),
150                        };
151                        // Write the header, then the block
152                        header.serialize(output);
153                        output.push(rle_byte);
154                    } else {
155                        let mut compressed = Vec::new();
156                        self.match_generator.commit_space(uncompressed_data);
157                        compress_block(&mut self.match_generator, &mut compressed);
158                        if compressed.len() >= MAX_BLOCK_SIZE {
159                            let header = BlockHeader {
160                                last_block,
161                                block_type: crate::blocks::block::BlockType::Raw,
162                                block_size: read_bytes.try_into().unwrap(),
163                            };
164                            // Write the header, then the block
165                            header.serialize(output);
166                            output.extend_from_slice(self.match_generator.get_last_space());
167                        } else {
168                            let header = BlockHeader {
169                                last_block,
170                                block_type: crate::blocks::block::BlockType::Compressed,
171                                block_size: (compressed.len()).try_into().unwrap(),
172                            };
173                            // Write the header, then the block
174                            header.serialize(output);
175                            output.extend(compressed);
176                        }
177                    }
178                }
179                _ => {
180                    unimplemented!();
181                }
182            }
183            drain.write_all(output).unwrap();
184            output.clear();
185            if last_block {
186                break;
187            }
188        }
189    }
190
191    /// Get a mutable reference to the source
192    pub fn source_mut(&mut self) -> Option<&mut R> {
193        self.uncompressed_data.as_mut()
194    }
195
196    /// Get a mutable reference to the drain
197    pub fn drain_mut(&mut self) -> Option<&mut W> {
198        self.compressed_data.as_mut()
199    }
200
201    /// Get a reference to the source
202    pub fn source(&self) -> Option<&R> {
203        self.uncompressed_data.as_ref()
204    }
205
206    /// Get a reference to the drain
207    pub fn drain(&self) -> Option<&W> {
208        self.compressed_data.as_ref()
209    }
210
211    /// Retrieve the source
212    pub fn take_source(&mut self) -> Option<R> {
213        self.uncompressed_data.take()
214    }
215
216    /// Retrieve the drain
217    pub fn take_drain(&mut self) -> Option<W> {
218        self.compressed_data.take()
219    }
220
221    /// Before calling [FrameCompressor::compress] you can replace the matcher
222    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
223        core::mem::swap(&mut match_generator, &mut self.match_generator);
224        match_generator
225    }
226
227    /// Before calling [FrameCompressor::compress] you can replace the compression level
228    pub fn set_compression_level(
229        &mut self,
230        compression_level: CompressionLevel,
231    ) -> CompressionLevel {
232        let old = self.compression_level;
233        self.compression_level = compression_level;
234        old
235    }
236
237    /// Get the current compression level
238    pub fn compression_level(&self) -> CompressionLevel {
239        self.compression_level
240    }
241}
242
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::FrameCompressor;
    use crate::decoding::{frame::MAGIC_NUM, FrameDecoder};
    use alloc::vec::Vec;

    /// Compress `data` into a single frame at `level` and return the encoded bytes.
    fn compress_with(data: &[u8], level: super::CompressionLevel) -> Vec<u8> {
        let mut output: Vec<u8> = Vec::new();
        {
            let mut compressor = FrameCompressor::new(level);
            compressor.set_source(data);
            compressor.set_drain(&mut output);
            compressor.compress();
        }
        output
    }

    /// Decode `encoded` with ruzstd's `FrameDecoder` into a vector presized to `capacity`
    /// bytes (the expected decoded length).
    fn decode_with_ruzstd(encoded: &[u8], capacity: usize) -> Vec<u8> {
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(capacity);
        decoder.decode_all_to_vec(encoded, &mut decoded).unwrap();
        decoded
    }

    #[test]
    fn frame_starts_with_magic_num() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let output = compress_with(mock_data, super::CompressionLevel::Uncompressed);
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let output = compress_with(mock_data, super::CompressionLevel::Uncompressed);

        // The encoded frame must decode back to the original bytes.
        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded.as_slice());
    }

    #[test]
    fn very_simple_compress() {
        let mut mock_data = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let output = compress_with(mock_data.as_slice(), super::CompressionLevel::Uncompressed);

        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data = vec![0; 1 << 19];
        // RLE blocks are only emitted at `Fastest`; `Uncompressed` is kept as a baseline.
        for level in [
            super::CompressionLevel::Uncompressed,
            super::CompressionLevel::Fastest,
        ] {
            let output = compress_with(mock_data.as_slice(), level);
            let decoded = decode_with_ruzstd(&output, mock_data.len());
            assert_eq!(mock_data, decoded);
        }
    }

    #[test]
    fn aaa_compress() {
        let mock_data = vec![0, 1, 3, 4, 5];
        let output = compress_with(mock_data.as_slice(), super::CompressionLevel::Uncompressed);

        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;
        fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        fn decode_ruzstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        fn encode_ruzstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        fn encode_ruzstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                let entry = file.unwrap();
                if entry.file_type().unwrap().is_file() {
                    let data = std::fs::read(entry.path()).unwrap();
                    let data = data.as_slice();
                    // Decoding
                    let compressed = encode_zstd(data).unwrap();
                    let decoded = decode_ruzstd(&mut compressed.as_slice());
                    let decoded2 = decode_ruzstd_writer(&mut compressed.as_slice());
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during decompression"
                    );
                    assert_eq!(
                        decoded2, data,
                        "Decoded data did not match the original input during decompression"
                    );

                    // Encoding
                    // Uncompressed encoding
                    let mut input = data;
                    let compressed = encode_ruzstd_uncompressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                    // Compressed encoding
                    let mut input = data;
                    let compressed = encode_ruzstd_compressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                }
            }
        }
    }
}