ruzstd/encoding/
frame_compressor.rs

//! Utilities and interfaces for encoding an entire frame. The compressor can be reused to encode several frames.
2
3use alloc::vec::Vec;
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver, CompressionLevel, Matcher,
14};
15use crate::fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable};
16
17use crate::io::{Read, Write};
18
19/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
20///
21/// `FrameCompressor` will generally be used by:
22/// 1. Initializing a compressor by providing a buffer of data using `FrameCompressor::new()`
23/// 2. Starting compression and writing that compression into a vec using `FrameCompressor::begin`
24///
25/// # Examples
26/// ```
27/// use ruzstd::encoding::{FrameCompressor, CompressionLevel};
28/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
29/// let mut output = std::vec::Vec::new();
30/// // Initialize a compressor.
31/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
32/// compressor.set_source(mock_data);
33/// compressor.set_drain(&mut output);
34///
35/// // `compress` writes the compressed output into the provided buffer.
36/// compressor.compress();
37/// ```
38pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
39    uncompressed_data: Option<R>,
40    compressed_data: Option<W>,
41    compression_level: CompressionLevel,
42    state: CompressState<M>,
43    #[cfg(feature = "hash")]
44    hasher: XxHash64,
45}
46
47pub(crate) struct FseTables {
48    pub(crate) ll_default: FSETable,
49    pub(crate) ll_previous: Option<FSETable>,
50    pub(crate) ml_default: FSETable,
51    pub(crate) ml_previous: Option<FSETable>,
52    pub(crate) of_default: FSETable,
53    pub(crate) of_previous: Option<FSETable>,
54}
55
56impl FseTables {
57    pub fn new() -> Self {
58        Self {
59            ll_default: default_ll_table(),
60            ll_previous: None,
61            ml_default: default_ml_table(),
62            ml_previous: None,
63            of_default: default_of_table(),
64            of_previous: None,
65        }
66    }
67}
68
69pub(crate) struct CompressState<M: Matcher> {
70    pub(crate) matcher: M,
71    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
72    pub(crate) fse_tables: FseTables,
73}
74
75impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
76    /// Create a new `FrameCompressor`
77    pub fn new(compression_level: CompressionLevel) -> Self {
78        Self {
79            uncompressed_data: None,
80            compressed_data: None,
81            compression_level,
82            state: CompressState {
83                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
84                last_huff_table: None,
85                fse_tables: FseTables::new(),
86            },
87            #[cfg(feature = "hash")]
88            hasher: XxHash64::with_seed(0),
89        }
90    }
91}
92
93impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
94    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
95    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
96        Self {
97            uncompressed_data: None,
98            compressed_data: None,
99            state: CompressState {
100                matcher,
101                last_huff_table: None,
102                fse_tables: FseTables::new(),
103            },
104            compression_level,
105            #[cfg(feature = "hash")]
106            hasher: XxHash64::with_seed(0),
107        }
108    }
109
110    /// Before calling [FrameCompressor::compress] you need to set the source.
111    ///
112    /// This is the data that is compressed and written into the drain.
113    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
114        self.uncompressed_data.replace(uncompressed_data)
115    }
116
117    /// Before calling [FrameCompressor::compress] you need to set the drain.
118    ///
119    /// As the compressor compresses data, the drain serves as a place for the output to be writte.
120    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
121        self.compressed_data.replace(compressed_data)
122    }
123
124    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
125    ///
126    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
127    /// Also [Write::write_all] will be called on the drain after each block has been encoded.
128    ///
129    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
130    /// [Read::take] function
131    pub fn compress(&mut self) {
132        // Clearing buffers to allow re-using of the compressor
133        self.state.matcher.reset(self.compression_level);
134        self.state.last_huff_table = None;
135        let source = self.uncompressed_data.as_mut().unwrap();
136        let drain = self.compressed_data.as_mut().unwrap();
137        // As the frame is compressed, it's stored here
138        let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
139        // First write the frame header
140        let header = FrameHeader {
141            frame_content_size: None,
142            single_segment: false,
143            content_checksum: cfg!(feature = "hash"),
144            dictionary_id: None,
145            window_size: Some(self.state.matcher.window_size()),
146        };
147        header.serialize(output);
148        // Now compress block by block
149        loop {
150            // Read a single block's worth of uncompressed data from the input
151            let mut uncompressed_data = self.state.matcher.get_next_space();
152            let mut read_bytes = 0;
153            let last_block;
154            'read_loop: loop {
155                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
156                if new_bytes == 0 {
157                    last_block = true;
158                    break 'read_loop;
159                }
160                read_bytes += new_bytes;
161                if read_bytes == uncompressed_data.len() {
162                    last_block = false;
163                    break 'read_loop;
164                }
165            }
166            uncompressed_data.resize(read_bytes, 0);
167            // As we read, hash that data too
168            #[cfg(feature = "hash")]
169            self.hasher.write(&uncompressed_data);
170            // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know)
171            if uncompressed_data.is_empty() {
172                let header = BlockHeader {
173                    last_block: true,
174                    block_type: crate::blocks::block::BlockType::Raw,
175                    block_size: 0,
176                };
177                // Write the header, then the block
178                header.serialize(output);
179                drain.write_all(output).unwrap();
180                output.clear();
181                break;
182            }
183
184            match self.compression_level {
185                CompressionLevel::Uncompressed => {
186                    let header = BlockHeader {
187                        last_block,
188                        block_type: crate::blocks::block::BlockType::Raw,
189                        block_size: read_bytes.try_into().unwrap(),
190                    };
191                    // Write the header, then the block
192                    header.serialize(output);
193                    output.extend_from_slice(&uncompressed_data);
194                }
195                CompressionLevel::Fastest => {
196                    compress_fastest(&mut self.state, last_block, uncompressed_data, output)
197                }
198                _ => {
199                    unimplemented!();
200                }
201            }
202            drain.write_all(output).unwrap();
203            output.clear();
204            if last_block {
205                break;
206            }
207        }
208
209        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
210        // and a 32 bit hash is written at the end of the data.
211        #[cfg(feature = "hash")]
212        {
213            // Because we only have the data as a reader, we need to read all of it to calculate the checksum
214            // Possible TODO: create a wrapper around self.uncompressed data that hashes the data as it's read?
215            let content_checksum = self.hasher.finish();
216            drain
217                .write_all(&(content_checksum as u32).to_le_bytes())
218                .unwrap();
219        }
220    }
221
222    /// Get a mutable reference to the source
223    pub fn source_mut(&mut self) -> Option<&mut R> {
224        self.uncompressed_data.as_mut()
225    }
226
227    /// Get a mutable reference to the drain
228    pub fn drain_mut(&mut self) -> Option<&mut W> {
229        self.compressed_data.as_mut()
230    }
231
232    /// Get a reference to the source
233    pub fn source(&self) -> Option<&R> {
234        self.uncompressed_data.as_ref()
235    }
236
237    /// Get a reference to the drain
238    pub fn drain(&self) -> Option<&W> {
239        self.compressed_data.as_ref()
240    }
241
242    /// Retrieve the source
243    pub fn take_source(&mut self) -> Option<R> {
244        self.uncompressed_data.take()
245    }
246
247    /// Retrieve the drain
248    pub fn take_drain(&mut self) -> Option<W> {
249        self.compressed_data.take()
250    }
251
252    /// Before calling [FrameCompressor::compress] you can replace the matcher
253    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
254        core::mem::swap(&mut match_generator, &mut self.state.matcher);
255        match_generator
256    }
257
258    /// Before calling [FrameCompressor::compress] you can replace the compression level
259    pub fn set_compression_level(
260        &mut self,
261        compression_level: CompressionLevel,
262    ) -> CompressionLevel {
263        let old = self.compression_level;
264        self.compression_level = compression_level;
265        old
266    }
267
268    /// Get the current compression level
269    pub fn compression_level(&self) -> CompressionLevel {
270        self.compression_level
271    }
272}
273
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::{CompressionLevel, FrameCompressor};
    use crate::common::MAGIC_NUM;
    use crate::decoding::FrameDecoder;
    use alloc::vec::Vec;

    /// Encodes `data` as one frame using `CompressionLevel::Uncompressed`.
    fn compress_uncompressed(data: &[u8]) -> Vec<u8> {
        let mut output: Vec<u8> = Vec::new();
        {
            // Scoped so the mutable borrow of `output` ends before it is returned
            let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
            compressor.set_source(data);
            compressor.set_drain(&mut output);
            compressor.compress();
        }
        output
    }

    /// Decodes a whole frame with this crate's decoder.
    ///
    /// `expected_len` is used to pre-size the target buffer, as the tests always know
    /// how much data they expect to get back.
    fn decode_with_ruzstd(frame: &[u8], expected_len: usize) -> Vec<u8> {
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(expected_len);
        decoder.decode_all_to_vec(frame, &mut decoded).unwrap();
        decoded
    }

    /// Decodes a whole frame with the reference `zstd` implementation.
    fn decode_with_zstd(frame: &[u8]) -> Vec<u8> {
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(frame, &mut decoded).unwrap();
        decoded
    }

    #[test]
    fn frame_starts_with_magic_num() {
        let output = compress_uncompressed(&[1_u8, 2, 3]);
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data: Vec<u8> = vec![1, 2, 3];
        let output = compress_uncompressed(&mock_data);

        // The frame must decode back to exactly the input
        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn very_simple_compress() {
        // Runs of different lengths around the block size to exercise block boundaries
        let mut mock_data: Vec<u8> = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_zstd(&output);
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data: Vec<u8> = vec![0; 1 << 19];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn aaa_compress() {
        let mock_data: Vec<u8> = vec![0, 1, 3, 4, 5];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_ruzstd(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_zstd(&output);
        assert_eq!(mock_data, decoded);
    }

    /// Replays every file found in `fuzz/artifacts/interop` (if that directory exists) through
    /// the same encode/decode combinations the interop fuzz target uses.
    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;
        fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        fn decode_ruzstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        fn encode_ruzstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        fn encode_ruzstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                let entry = file.unwrap();
                if entry.file_type().unwrap().is_file() {
                    let data = std::fs::read(entry.path()).unwrap();
                    let data = data.as_slice();
                    // Decoding
                    let compressed = encode_zstd(data).unwrap();
                    let decoded = decode_ruzstd(&mut compressed.as_slice());
                    let decoded2 = decode_ruzstd_writer(&mut compressed.as_slice());
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during decompression"
                    );
                    assert_eq!(
                        decoded2, data,
                        "Decoded data did not match the original input during decompression"
                    );

                    // Encoding
                    // Uncompressed encoding
                    let mut input = data;
                    let compressed = encode_ruzstd_uncompressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                    // Compressed encoding
                    let mut input = data;
                    let compressed = encode_ruzstd_compressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                }
            }
        }
    }
}