ruzstd/decoding/
frame_decoder.rs

1//! Framedecoder is the main low-level struct users interact with to decode zstd frames
2//!
3//! Zstandard compressed data is made of one or more frames. Each frame is independent and can be
4//! decompressed independently of other frames. This module contains structures
5//! and utilities that can be used to decode a frame.
6
7use super::frame;
8use crate::decoding;
9use crate::decoding::dictionary::Dictionary;
10use crate::decoding::errors::FrameDecoderError;
11use crate::decoding::scratch::DecoderScratch;
12use crate::io::{Error, Read, Write};
13use alloc::collections::BTreeMap;
14use alloc::vec::Vec;
15use core::convert::TryInto;
16
/// Largest window size (100 MiB) this decoder is willing to work with.
///
/// The format itself permits much larger windows; the lower cap exists to
/// protect against malformed frames.
const MAXIMUM_ALLOWED_WINDOW_SIZE: u64 = 100 * 1024 * 1024;
20
21/// Low level Zstandard decoder that can be used to decompress frames with fine control over when and how many bytes are decoded.
22///
23/// This decoder is able to decode frames only partially and gives control
24/// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once).
25/// It reads bytes as needed from a provided source and can be read from to collect partial results.
26///
27/// If you want to just read the whole frame with an `io::Read` without having to deal with manually calling [FrameDecoder::decode_blocks]
28/// you can use the provided [crate::decoding::StreamingDecoder] wich wraps this FrameDecoder.
29///
30/// Workflow is as follows:
31/// ```
32/// use ruzstd::decoding::BlockDecodingStrategy;
33///
34/// # #[cfg(feature = "std")]
35/// use std::io::{Read, Write};
36///
37/// // no_std environments can use the crate's own Read traits
38/// # #[cfg(not(feature = "std"))]
39/// use ruzstd::io::{Read, Write};
40///
41/// fn decode_this(mut file: impl Read) {
42///     //Create a new decoder
43///     let mut frame_dec = ruzstd::decoding::FrameDecoder::new();
44///     let mut result = Vec::new();
45///
46///     // Use reset or init to make the decoder ready to decode the frame from the io::Read
47///     frame_dec.reset(&mut file).unwrap();
48///
49///     // Loop until the frame has been decoded completely
50///     while !frame_dec.is_finished() {
51///         // decode (roughly) batch_size many bytes
52///         frame_dec.decode_blocks(&mut file, BlockDecodingStrategy::UptoBytes(1024)).unwrap();
53///
54///         // read from the decoder to collect bytes from the internal buffer
55///         let bytes_read = frame_dec.read(result.as_mut_slice()).unwrap();
56///
57///         // then do something with it
58///         do_something(&result[0..bytes_read]);
59///     }
60///
61///     // handle the last chunk of data
62///     while frame_dec.can_collect() > 0 {
63///         let x = frame_dec.read(result.as_mut_slice()).unwrap();
64///
65///         do_something(&result[0..x]);
66///     }
67/// }
68///
69/// fn do_something(data: &[u8]) {
70/// # #[cfg(feature = "std")]
71///     std::io::stdout().write_all(data).unwrap();
72/// }
73/// ```
74pub struct FrameDecoder {
75    state: Option<FrameDecoderState>,
76    dicts: BTreeMap<u32, Dictionary>,
77}
78
79struct FrameDecoderState {
80    pub frame_header: frame::FrameHeader,
81    decoder_scratch: DecoderScratch,
82    frame_finished: bool,
83    block_counter: usize,
84    bytes_read_counter: u64,
85    check_sum: Option<u32>,
86    using_dict: Option<u32>,
87}
88
/// Controls how much work a single call to `FrameDecoder::decode_blocks` performs.
///
/// Whatever the strategy, decoding also stops after the last block of the frame.
pub enum BlockDecodingStrategy {
    /// Keep decoding until the last block of the frame has been decoded.
    All,
    /// Stop once at least this many blocks have been decoded during the call.
    UptoBlocks(usize),
    /// Stop once the decode buffer grew by at least this many bytes during the call.
    /// The limit is only checked between blocks, so it can be overshot.
    UptoBytes(usize),
}
94
95impl FrameDecoderState {
96    pub fn new(source: impl Read) -> Result<FrameDecoderState, FrameDecoderError> {
97        let (frame, header_size) = frame::read_frame_header(source)?;
98        let window_size = frame.window_size()?;
99        Ok(FrameDecoderState {
100            frame_header: frame,
101            frame_finished: false,
102            block_counter: 0,
103            decoder_scratch: DecoderScratch::new(window_size as usize),
104            bytes_read_counter: u64::from(header_size),
105            check_sum: None,
106            using_dict: None,
107        })
108    }
109
110    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
111        let (frame_header, header_size) = frame::read_frame_header(source)?;
112        let window_size = frame_header.window_size()?;
113
114        if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
115            return Err(FrameDecoderError::WindowSizeTooBig {
116                requested: window_size,
117            });
118        }
119
120        self.frame_header = frame_header;
121        self.frame_finished = false;
122        self.block_counter = 0;
123        self.decoder_scratch.reset(window_size as usize);
124        self.bytes_read_counter = u64::from(header_size);
125        self.check_sum = None;
126        self.using_dict = None;
127        Ok(())
128    }
129}
130
131impl Default for FrameDecoder {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl FrameDecoder {
138    /// This will create a new decoder without allocating anything yet.
139    /// init()/reset() will allocate all needed buffers if it is the first time this decoder is used
140    /// else they just reset these buffers with not further allocations
141    pub fn new() -> FrameDecoder {
142        FrameDecoder {
143            state: None,
144            dicts: BTreeMap::new(),
145        }
146    }
147
148    /// init() will allocate all needed buffers if it is the first time this decoder is used
149    /// else they just reset these buffers with not further allocations
150    ///
151    /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
152    ///
153    /// equivalent to reset()
154    pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
155        self.reset(source)
156    }
157
158    /// reset() will allocate all needed buffers if it is the first time this decoder is used
159    /// else they just reset these buffers with not further allocations
160    ///
161    /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
162    ///
163    /// equivalent to init()
164    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
165        use FrameDecoderError as err;
166        let state = match &mut self.state {
167            Some(s) => {
168                s.reset(source)?;
169                s
170            }
171            None => {
172                self.state = Some(FrameDecoderState::new(source)?);
173                self.state.as_mut().unwrap()
174            }
175        };
176        if let Some(dict_id) = state.frame_header.dictionary_id() {
177            let dict = self
178                .dicts
179                .get(&dict_id)
180                .ok_or(err::DictNotProvided { dict_id })?;
181            state.decoder_scratch.init_from_dict(dict);
182            state.using_dict = Some(dict_id);
183        }
184        Ok(())
185    }
186
187    /// Add a dict to the FrameDecoder that can be used when needed. The FrameDecoder uses the appropriate one dynamically
188    pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> {
189        self.dicts.insert(dict.id, dict);
190        Ok(())
191    }
192
193    pub fn force_dict(&mut self, dict_id: u32) -> Result<(), FrameDecoderError> {
194        use FrameDecoderError as err;
195        let Some(state) = self.state.as_mut() else {
196            return Err(err::NotYetInitialized);
197        };
198
199        let dict = self
200            .dicts
201            .get(&dict_id)
202            .ok_or(err::DictNotProvided { dict_id })?;
203        state.decoder_scratch.init_from_dict(dict);
204        state.using_dict = Some(dict_id);
205
206        Ok(())
207    }
208
209    /// Returns how many bytes the frame contains after decompression
210    pub fn content_size(&self) -> u64 {
211        match &self.state {
212            None => 0,
213            Some(s) => s.frame_header.frame_content_size(),
214        }
215    }
216
217    /// Returns the checksum that was read from the data. Only available after all bytes have been read. It is the last 4 bytes of a zstd-frame
218    pub fn get_checksum_from_data(&self) -> Option<u32> {
219        let state = match &self.state {
220            None => return None,
221            Some(s) => s,
222        };
223
224        state.check_sum
225    }
226
227    /// Returns the checksum that was calculated while decoding.
228    /// Only a sensible value after all decoded bytes have been collected/read from the FrameDecoder
229    #[cfg(feature = "hash")]
230    pub fn get_calculated_checksum(&self) -> Option<u32> {
231        use core::hash::Hasher;
232
233        let state = match &self.state {
234            None => return None,
235            Some(s) => s,
236        };
237        let cksum_64bit = state.decoder_scratch.buffer.hash.finish();
238        //truncate to lower 32bit because reasons...
239        Some(cksum_64bit as u32)
240    }
241
242    /// Counter for how many bytes have been consumed while decoding the frame
243    pub fn bytes_read_from_source(&self) -> u64 {
244        let state = match &self.state {
245            None => return 0,
246            Some(s) => s,
247        };
248        state.bytes_read_counter
249    }
250
251    /// Whether the current frames last block has been decoded yet
252    /// If this returns true you can call the drain* functions to get all content
253    /// (the read() function will drain automatically if this returns true)
254    pub fn is_finished(&self) -> bool {
255        let state = match &self.state {
256            None => return true,
257            Some(s) => s,
258        };
259        if state.frame_header.descriptor.content_checksum_flag() {
260            state.frame_finished && state.check_sum.is_some()
261        } else {
262            state.frame_finished
263        }
264    }
265
266    /// Counter for how many blocks have already been decoded
267    pub fn blocks_decoded(&self) -> usize {
268        let state = match &self.state {
269            None => return 0,
270            Some(s) => s,
271        };
272        state.block_counter
273    }
274
275    /// Decodes blocks from a reader. It requires that the framedecoder has been initialized first.
276    /// The Strategy influences how many blocks will be decoded before the function returns
277    /// This is important if you want to manage memory consumption carefully. If you don't care
278    /// about that you can just choose the strategy "All" and have all blocks of the frame decoded into the buffer
279    pub fn decode_blocks(
280        &mut self,
281        mut source: impl Read,
282        strat: BlockDecodingStrategy,
283    ) -> Result<bool, FrameDecoderError> {
284        use FrameDecoderError as err;
285        let state = self.state.as_mut().ok_or(err::NotYetInitialized)?;
286
287        let mut block_dec = decoding::block_decoder::new();
288
289        let buffer_size_before = state.decoder_scratch.buffer.len();
290        let block_counter_before = state.block_counter;
291        loop {
292            vprintln!("################");
293            vprintln!("Next Block: {}", state.block_counter);
294            vprintln!("################");
295            let (block_header, block_header_size) = block_dec
296                .read_block_header(&mut source)
297                .map_err(err::FailedToReadBlockHeader)?;
298            state.bytes_read_counter += u64::from(block_header_size);
299
300            vprintln!();
301            vprintln!(
302                "Found {} block with size: {}, which will be of size: {}",
303                block_header.block_type,
304                block_header.content_size,
305                block_header.decompressed_size
306            );
307
308            let bytes_read_in_block_body = block_dec
309                .decode_block_content(&block_header, &mut state.decoder_scratch, &mut source)
310                .map_err(err::FailedToReadBlockBody)?;
311            state.bytes_read_counter += bytes_read_in_block_body;
312
313            state.block_counter += 1;
314
315            vprintln!("Output: {}", state.decoder_scratch.buffer.len());
316
317            if block_header.last_block {
318                state.frame_finished = true;
319                if state.frame_header.descriptor.content_checksum_flag() {
320                    let mut chksum = [0u8; 4];
321                    source
322                        .read_exact(&mut chksum)
323                        .map_err(err::FailedToReadChecksum)?;
324                    state.bytes_read_counter += 4;
325                    let chksum = u32::from_le_bytes(chksum);
326                    state.check_sum = Some(chksum);
327                }
328                break;
329            }
330
331            match strat {
332                BlockDecodingStrategy::All => { /* keep going */ }
333                BlockDecodingStrategy::UptoBlocks(n) => {
334                    if state.block_counter - block_counter_before >= n {
335                        break;
336                    }
337                }
338                BlockDecodingStrategy::UptoBytes(n) => {
339                    if state.decoder_scratch.buffer.len() - buffer_size_before >= n {
340                        break;
341                    }
342                }
343            }
344        }
345
346        Ok(state.frame_finished)
347    }
348
349    /// Collect bytes and retain window_size bytes while decoding is still going on.
350    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
351    pub fn collect(&mut self) -> Option<Vec<u8>> {
352        let finished = self.is_finished();
353        let state = self.state.as_mut()?;
354        if finished {
355            Some(state.decoder_scratch.buffer.drain())
356        } else {
357            state.decoder_scratch.buffer.drain_to_window_size()
358        }
359    }
360
361    /// Collect bytes and retain window_size bytes while decoding is still going on.
362    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
363    pub fn collect_to_writer(&mut self, w: impl Write) -> Result<usize, Error> {
364        let finished = self.is_finished();
365        let state = match &mut self.state {
366            None => return Ok(0),
367            Some(s) => s,
368        };
369        if finished {
370            state.decoder_scratch.buffer.drain_to_writer(w)
371        } else {
372            state.decoder_scratch.buffer.drain_to_window_size_writer(w)
373        }
374    }
375
376    /// How many bytes can currently be collected from the decodebuffer, while decoding is going on this will be lower than the actual decodbuffer size
377    /// because window_size bytes need to be retained for decoding.
378    /// After decoding of the frame (is_finished() == true) has finished it will report all remaining bytes
379    pub fn can_collect(&self) -> usize {
380        let finished = self.is_finished();
381        let state = match &self.state {
382            None => return 0,
383            Some(s) => s,
384        };
385        if finished {
386            state.decoder_scratch.buffer.can_drain()
387        } else {
388            state
389                .decoder_scratch
390                .buffer
391                .can_drain_to_window_size()
392                .unwrap_or(0)
393        }
394    }
395
396    /// Decodes as many blocks as possible from the source slice and reads from the decodebuffer into the target slice
397    /// The source slice may contain only parts of a frame but must contain at least one full block to make progress
398    ///
399    /// By all means use decode_blocks if you have a io.Reader available. This is just for compatibility with other decompressors
400    /// which try to serve an old-style c api
401    ///
402    /// Returns (read, written), if read == 0 then the source did not contain a full block and further calls with the same
403    /// input will not make any progress!
404    ///
405    /// Note that no kind of block can be bigger than 128kb.
406    /// So to be safe use at least 128*1024 (max block content size) + 3 (block_header size) + 18 (max frame_header size) bytes as your source buffer
407    ///
408    /// You may call this function with an empty source after all bytes have been decoded. This is equivalent to just call decoder.read(&mut target)
409    pub fn decode_from_to(
410        &mut self,
411        source: &[u8],
412        target: &mut [u8],
413    ) -> Result<(usize, usize), FrameDecoderError> {
414        use FrameDecoderError as err;
415        let bytes_read_at_start = match &self.state {
416            Some(s) => s.bytes_read_counter,
417            None => 0,
418        };
419
420        if !self.is_finished() || self.state.is_none() {
421            let mut mt_source = source;
422
423            if self.state.is_none() {
424                self.init(&mut mt_source)?;
425            }
426
427            //pseudo block to scope "state" so we can borrow self again after the block
428            {
429                let state = match &mut self.state {
430                    Some(s) => s,
431                    None => panic!("Bug in library"),
432                };
433                let mut block_dec = decoding::block_decoder::new();
434
435                if state.frame_header.descriptor.content_checksum_flag()
436                    && state.frame_finished
437                    && state.check_sum.is_none()
438                {
439                    //this block is needed if the checksum were the only 4 bytes that were not included in the last decode_from_to call for a frame
440                    if mt_source.len() >= 4 {
441                        let chksum = mt_source[..4].try_into().expect("optimized away");
442                        state.bytes_read_counter += 4;
443                        let chksum = u32::from_le_bytes(chksum);
444                        state.check_sum = Some(chksum);
445                    }
446                    return Ok((4, 0));
447                }
448
449                loop {
450                    //check if there are enough bytes for the next header
451                    if mt_source.len() < 3 {
452                        break;
453                    }
454                    let (block_header, block_header_size) = block_dec
455                        .read_block_header(&mut mt_source)
456                        .map_err(err::FailedToReadBlockHeader)?;
457
458                    // check the needed size for the block before updating counters.
459                    // If not enough bytes are in the source, the header will have to be read again, so act like we never read it in the first place
460                    if mt_source.len() < block_header.content_size as usize {
461                        break;
462                    }
463                    state.bytes_read_counter += u64::from(block_header_size);
464
465                    let bytes_read_in_block_body = block_dec
466                        .decode_block_content(
467                            &block_header,
468                            &mut state.decoder_scratch,
469                            &mut mt_source,
470                        )
471                        .map_err(err::FailedToReadBlockBody)?;
472                    state.bytes_read_counter += bytes_read_in_block_body;
473                    state.block_counter += 1;
474
475                    if block_header.last_block {
476                        state.frame_finished = true;
477                        if state.frame_header.descriptor.content_checksum_flag() {
478                            //if there are enough bytes handle this here. Else the block at the start of this function will handle it at the next call
479                            if mt_source.len() >= 4 {
480                                let chksum = mt_source[..4].try_into().expect("optimized away");
481                                state.bytes_read_counter += 4;
482                                let chksum = u32::from_le_bytes(chksum);
483                                state.check_sum = Some(chksum);
484                            }
485                        }
486                        break;
487                    }
488                }
489            }
490        }
491
492        let result_len = self.read(target).map_err(err::FailedToDrainDecodebuffer)?;
493        let bytes_read_at_end = match &mut self.state {
494            Some(s) => s.bytes_read_counter,
495            None => panic!("Bug in library"),
496        };
497        let read_len = bytes_read_at_end - bytes_read_at_start;
498        Ok((read_len as usize, result_len))
499    }
500
501    /// Decode multiple frames into the output slice.
502    ///
503    /// `input` must contain an exact number of frames.
504    ///
505    /// `output` must be large enough to hold the decompressed data. If you don't know
506    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
507    ///
508    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
509    ///
510    /// Returns the number of bytes written to `output`.
511    pub fn decode_all(
512        &mut self,
513        mut input: &[u8],
514        mut output: &mut [u8],
515    ) -> Result<usize, FrameDecoderError> {
516        let mut total_bytes_written = 0;
517        while !input.is_empty() {
518            match self.init(&mut input) {
519                Ok(_) => {}
520                Err(FrameDecoderError::ReadFrameHeaderError(
521                    crate::decoding::errors::ReadFrameHeaderError::SkipFrame { length, .. },
522                )) => {
523                    input = input
524                        .get(length as usize..)
525                        .ok_or(FrameDecoderError::FailedToSkipFrame)?;
526                    continue;
527                }
528                Err(e) => return Err(e),
529            };
530            loop {
531                self.decode_blocks(&mut input, BlockDecodingStrategy::UptoBytes(1024 * 1024))?;
532                let bytes_written = self
533                    .read(output)
534                    .map_err(FrameDecoderError::FailedToDrainDecodebuffer)?;
535                output = &mut output[bytes_written..];
536                total_bytes_written += bytes_written;
537                if self.can_collect() != 0 {
538                    return Err(FrameDecoderError::TargetTooSmall);
539                }
540                if self.is_finished() {
541                    break;
542                }
543            }
544        }
545
546        Ok(total_bytes_written)
547    }
548
549    /// Decode multiple frames into the extra capacity of the output vector.
550    ///
551    /// `input` must contain an exact number of frames.
552    ///
553    /// `output` must have enough extra capacity to hold the decompressed data.
554    /// This function will not reallocate or grow the vector. If you don't know
555    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
556    ///
557    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
558    ///
559    /// The length of the output vector is updated to include the decompressed data.
560    /// The length is not changed if an error occurs.
561    pub fn decode_all_to_vec(
562        &mut self,
563        input: &[u8],
564        output: &mut Vec<u8>,
565    ) -> Result<(), FrameDecoderError> {
566        let len = output.len();
567        let cap = output.capacity();
568        output.resize(cap, 0);
569        match self.decode_all(input, &mut output[len..]) {
570            Ok(bytes_written) => {
571                let new_len = core::cmp::min(len + bytes_written, cap); // Sanitizes `bytes_written`.
572                output.resize(new_len, 0);
573                Ok(())
574            }
575            Err(e) => {
576                output.resize(len, 0);
577                Err(e)
578            }
579        }
580    }
581}
582
583/// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished
584/// this will retain window_size bytes, else it will drain it completely
585impl Read for FrameDecoder {
586    fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
587        let state = match &mut self.state {
588            None => return Ok(0),
589            Some(s) => s,
590        };
591        if state.frame_finished {
592            state.decoder_scratch.buffer.read_all(target)
593        } else {
594            state.decoder_scratch.buffer.read(target)
595        }
596    }
597}