ruzstd/decoding/
frame_decoder.rs

//! `FrameDecoder` is the main low-level struct users interact with to decode zstd frames.
2//!
3//! Zstandard compressed data is made of one or more frames. Each frame is independent and can be
4//! decompressed independently of other frames. This module contains structures
5//! and utilities that can be used to decode a frame.
6
7use super::frame;
8use crate::decoding;
9use crate::decoding::dictionary::Dictionary;
10use crate::decoding::errors::FrameDecoderError;
11use crate::decoding::scratch::DecoderScratch;
12use crate::io::{Error, Read, Write};
13use alloc::collections::BTreeMap;
14use alloc::vec::Vec;
15use core::convert::TryInto;
16
/// Low level Zstandard decoder that can be used to decompress frames with fine control over when and how many bytes are decoded.
///
/// This decoder is able to decode frames only partially and gives control
/// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once).
/// It reads bytes as needed from a provided source and can be read from to collect partial results.
///
/// If you want to just read the whole frame with an `io::Read` without having to deal with manually calling [FrameDecoder::decode_blocks]
/// you can use the provided [crate::decoding::StreamingDecoder] which wraps this FrameDecoder.
///
/// Workflow is as follows:
/// ```
/// use ruzstd::decoding::BlockDecodingStrategy;
///
/// # #[cfg(feature = "std")]
/// use std::io::{Read, Write};
///
/// // no_std environments can use the crate's own Read traits
/// # #[cfg(not(feature = "std"))]
/// use ruzstd::io::{Read, Write};
///
/// fn decode_this(mut file: impl Read) {
///     // Create a new decoder
///     let mut frame_dec = ruzstd::decoding::FrameDecoder::new();
///     // The collection buffer must not be empty: reading into an empty slice never yields any bytes
///     let mut result = vec![0u8; 4096];
///
///     // Use reset or init to make the decoder ready to decode the frame from the io::Read
///     frame_dec.reset(&mut file).unwrap();
///
///     // Loop until the frame has been decoded completely
///     while !frame_dec.is_finished() {
///         // decode (roughly) batch_size many bytes
///         frame_dec.decode_blocks(&mut file, BlockDecodingStrategy::UptoBytes(1024)).unwrap();
///
///         // read from the decoder to collect bytes from the internal buffer
///         let bytes_read = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         // then do something with it
///         do_something(&result[0..bytes_read]);
///     }
///
///     // handle the last chunk of data
///     while frame_dec.can_collect() > 0 {
///         let x = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         do_something(&result[0..x]);
///     }
/// }
///
/// fn do_something(data: &[u8]) {
/// # #[cfg(feature = "std")]
///     std::io::stdout().write_all(data).unwrap();
/// }
/// ```
pub struct FrameDecoder {
    /// Decoding state for the current frame; `None` until `init()`/`reset()` succeeded once.
    state: Option<FrameDecoderState>,
    /// Dictionaries registered through `add_dict`, keyed by dictionary id.
    dicts: BTreeMap<u32, Dictionary>,
}
74
/// Everything the decoder tracks for the frame that is currently being decoded.
/// Created on the first `init()`/`reset()` and reset in place for every following frame.
struct FrameDecoderState {
    /// The frame whose header was read most recently.
    pub frame: frame::Frame,
    /// Scratch space, including the decode buffer, that is reused between frames.
    decoder_scratch: DecoderScratch,
    /// Set once the block flagged as the last block of the frame has been decoded.
    frame_finished: bool,
    /// Number of blocks decoded so far in the current frame.
    block_counter: usize,
    /// Number of bytes consumed from the source for the current frame, frame header included.
    bytes_read_counter: u64,
    /// Content checksum read from the end of the frame, once it has been read.
    check_sum: Option<u32>,
    /// Id of the dictionary the scratch space was initialized from, if any.
    using_dict: Option<u32>,
}
84
/// Controls how much work a single call to [`FrameDecoder::decode_blocks`] performs.
///
/// Every call decodes at least one block, and decoding always stops after the
/// last block of the frame, regardless of the chosen strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockDecodingStrategy {
    /// Keep decoding until the last block of the frame has been decoded.
    All,
    /// Stop once at least this many blocks have been decoded by the current call.
    UptoBlocks(usize),
    /// Stop once the current call has added at least this many bytes to the decode buffer.
    /// Blocks are never split, so the limit may be overshot.
    UptoBytes(usize),
}
90
/// Largest window size (100 MiB) a frame header may request; `FrameDecoderState::reset`
/// rejects anything bigger with `FrameDecoderError::WindowSizeTooBig`.
const MAX_WINDOW_SIZE: u64 = 1024 * 1024 * 100;
92
93impl FrameDecoderState {
94    pub fn new(source: impl Read) -> Result<FrameDecoderState, FrameDecoderError> {
95        let (frame, header_size) = frame::read_frame_header(source)?;
96        let window_size = frame.header.window_size()?;
97        Ok(FrameDecoderState {
98            frame,
99            frame_finished: false,
100            block_counter: 0,
101            decoder_scratch: DecoderScratch::new(window_size as usize),
102            bytes_read_counter: u64::from(header_size),
103            check_sum: None,
104            using_dict: None,
105        })
106    }
107
108    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
109        let (frame, header_size) = frame::read_frame_header(source)?;
110        let window_size = frame.header.window_size()?;
111
112        if window_size > MAX_WINDOW_SIZE {
113            return Err(FrameDecoderError::WindowSizeTooBig {
114                requested: window_size,
115            });
116        }
117
118        self.frame = frame;
119        self.frame_finished = false;
120        self.block_counter = 0;
121        self.decoder_scratch.reset(window_size as usize);
122        self.bytes_read_counter = u64::from(header_size);
123        self.check_sum = None;
124        self.using_dict = None;
125        Ok(())
126    }
127}
128
129impl Default for FrameDecoder {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl FrameDecoder {
136    /// This will create a new decoder without allocating anything yet.
137    /// init()/reset() will allocate all needed buffers if it is the first time this decoder is used
138    /// else they just reset these buffers with not further allocations
139    pub fn new() -> FrameDecoder {
140        FrameDecoder {
141            state: None,
142            dicts: BTreeMap::new(),
143        }
144    }
145
    /// Prepares the decoder for a new frame by reading the frame header from `source`.
    ///
    /// init() creates the decoding state if this is the first time the decoder is used,
    /// otherwise the existing state is reset in place.
    ///
    /// Note that all bytes currently in the decodebuffer from any previous frame will be lost.
    /// Collect them with collect()/collect_to_writer() first.
    ///
    /// Equivalent to reset(): this simply forwards to it.
    pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
        self.reset(source)
    }
155
156    /// reset() will allocate all needed buffers if it is the first time this decoder is used
157    /// else they just reset these buffers with not further allocations
158    ///
159    /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
160    ///
161    /// equivalent to init()
162    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
163        use FrameDecoderError as err;
164        let state = match &mut self.state {
165            Some(s) => {
166                s.reset(source)?;
167                s
168            }
169            None => {
170                self.state = Some(FrameDecoderState::new(source)?);
171                self.state.as_mut().unwrap()
172            }
173        };
174        if let Some(dict_id) = state.frame.header.dictionary_id() {
175            let dict = self
176                .dicts
177                .get(&dict_id)
178                .ok_or(err::DictNotProvided { dict_id })?;
179            state.decoder_scratch.init_from_dict(dict);
180            state.using_dict = Some(dict_id);
181        }
182        Ok(())
183    }
184
    /// Add a dict to the FrameDecoder that can be used when needed. The FrameDecoder uses the appropriate one dynamically
    ///
    /// The dictionary is stored under its own id; a dictionary previously registered under the
    /// same id is replaced. This currently always returns `Ok(())`.
    pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> {
        self.dicts.insert(dict.id, dict);
        Ok(())
    }
190
191    pub fn force_dict(&mut self, dict_id: u32) -> Result<(), FrameDecoderError> {
192        use FrameDecoderError as err;
193        let Some(state) = self.state.as_mut() else {
194            return Err(err::NotYetInitialized);
195        };
196
197        let dict = self
198            .dicts
199            .get(&dict_id)
200            .ok_or(err::DictNotProvided { dict_id })?;
201        state.decoder_scratch.init_from_dict(dict);
202        state.using_dict = Some(dict_id);
203
204        Ok(())
205    }
206
207    /// Returns how many bytes the frame contains after decompression
208    pub fn content_size(&self) -> u64 {
209        match &self.state {
210            None => 0,
211            Some(s) => s.frame.header.frame_content_size(),
212        }
213    }
214
215    /// Returns the checksum that was read from the data. Only available after all bytes have been read. It is the last 4 bytes of a zstd-frame
216    pub fn get_checksum_from_data(&self) -> Option<u32> {
217        let state = match &self.state {
218            None => return None,
219            Some(s) => s,
220        };
221
222        state.check_sum
223    }
224
225    /// Returns the checksum that was calculated while decoding.
226    /// Only a sensible value after all decoded bytes have been collected/read from the FrameDecoder
227    #[cfg(feature = "hash")]
228    pub fn get_calculated_checksum(&self) -> Option<u32> {
229        use core::hash::Hasher;
230
231        let state = match &self.state {
232            None => return None,
233            Some(s) => s,
234        };
235        let cksum_64bit = state.decoder_scratch.buffer.hash.finish();
236        //truncate to lower 32bit because reasons...
237        Some(cksum_64bit as u32)
238    }
239
240    /// Counter for how many bytes have been consumed while decoding the frame
241    pub fn bytes_read_from_source(&self) -> u64 {
242        let state = match &self.state {
243            None => return 0,
244            Some(s) => s,
245        };
246        state.bytes_read_counter
247    }
248
249    /// Whether the current frames last block has been decoded yet
250    /// If this returns true you can call the drain* functions to get all content
251    /// (the read() function will drain automatically if this returns true)
252    pub fn is_finished(&self) -> bool {
253        let state = match &self.state {
254            None => return true,
255            Some(s) => s,
256        };
257        if state.frame.header.descriptor.content_checksum_flag() {
258            state.frame_finished && state.check_sum.is_some()
259        } else {
260            state.frame_finished
261        }
262    }
263
264    /// Counter for how many blocks have already been decoded
265    pub fn blocks_decoded(&self) -> usize {
266        let state = match &self.state {
267            None => return 0,
268            Some(s) => s,
269        };
270        state.block_counter
271    }
272
273    /// Decodes blocks from a reader. It requires that the framedecoder has been initialized first.
274    /// The Strategy influences how many blocks will be decoded before the function returns
275    /// This is important if you want to manage memory consumption carefully. If you don't care
276    /// about that you can just choose the strategy "All" and have all blocks of the frame decoded into the buffer
277    pub fn decode_blocks(
278        &mut self,
279        mut source: impl Read,
280        strat: BlockDecodingStrategy,
281    ) -> Result<bool, FrameDecoderError> {
282        use FrameDecoderError as err;
283        let state = self.state.as_mut().ok_or(err::NotYetInitialized)?;
284
285        let mut block_dec = decoding::block_decoder::new();
286
287        let buffer_size_before = state.decoder_scratch.buffer.len();
288        let block_counter_before = state.block_counter;
289        loop {
290            vprintln!("################");
291            vprintln!("Next Block: {}", state.block_counter);
292            vprintln!("################");
293            let (block_header, block_header_size) = block_dec
294                .read_block_header(&mut source)
295                .map_err(err::FailedToReadBlockHeader)?;
296            state.bytes_read_counter += u64::from(block_header_size);
297
298            vprintln!();
299            vprintln!(
300                "Found {} block with size: {}, which will be of size: {}",
301                block_header.block_type,
302                block_header.content_size,
303                block_header.decompressed_size
304            );
305
306            let bytes_read_in_block_body = block_dec
307                .decode_block_content(&block_header, &mut state.decoder_scratch, &mut source)
308                .map_err(err::FailedToReadBlockBody)?;
309            state.bytes_read_counter += bytes_read_in_block_body;
310
311            state.block_counter += 1;
312
313            vprintln!("Output: {}", state.decoder_scratch.buffer.len());
314
315            if block_header.last_block {
316                state.frame_finished = true;
317                if state.frame.header.descriptor.content_checksum_flag() {
318                    let mut chksum = [0u8; 4];
319                    source
320                        .read_exact(&mut chksum)
321                        .map_err(err::FailedToReadChecksum)?;
322                    state.bytes_read_counter += 4;
323                    let chksum = u32::from_le_bytes(chksum);
324                    state.check_sum = Some(chksum);
325                }
326                break;
327            }
328
329            match strat {
330                BlockDecodingStrategy::All => { /* keep going */ }
331                BlockDecodingStrategy::UptoBlocks(n) => {
332                    if state.block_counter - block_counter_before >= n {
333                        break;
334                    }
335                }
336                BlockDecodingStrategy::UptoBytes(n) => {
337                    if state.decoder_scratch.buffer.len() - buffer_size_before >= n {
338                        break;
339                    }
340                }
341            }
342        }
343
344        Ok(state.frame_finished)
345    }
346
347    /// Collect bytes and retain window_size bytes while decoding is still going on.
348    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
349    pub fn collect(&mut self) -> Option<Vec<u8>> {
350        let finished = self.is_finished();
351        let state = self.state.as_mut()?;
352        if finished {
353            Some(state.decoder_scratch.buffer.drain())
354        } else {
355            state.decoder_scratch.buffer.drain_to_window_size()
356        }
357    }
358
359    /// Collect bytes and retain window_size bytes while decoding is still going on.
360    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
361    pub fn collect_to_writer(&mut self, w: impl Write) -> Result<usize, Error> {
362        let finished = self.is_finished();
363        let state = match &mut self.state {
364            None => return Ok(0),
365            Some(s) => s,
366        };
367        if finished {
368            state.decoder_scratch.buffer.drain_to_writer(w)
369        } else {
370            state.decoder_scratch.buffer.drain_to_window_size_writer(w)
371        }
372    }
373
374    /// How many bytes can currently be collected from the decodebuffer, while decoding is going on this will be lower than the actual decodbuffer size
375    /// because window_size bytes need to be retained for decoding.
376    /// After decoding of the frame (is_finished() == true) has finished it will report all remaining bytes
377    pub fn can_collect(&self) -> usize {
378        let finished = self.is_finished();
379        let state = match &self.state {
380            None => return 0,
381            Some(s) => s,
382        };
383        if finished {
384            state.decoder_scratch.buffer.can_drain()
385        } else {
386            state
387                .decoder_scratch
388                .buffer
389                .can_drain_to_window_size()
390                .unwrap_or(0)
391        }
392    }
393
394    /// Decodes as many blocks as possible from the source slice and reads from the decodebuffer into the target slice
395    /// The source slice may contain only parts of a frame but must contain at least one full block to make progress
396    ///
397    /// By all means use decode_blocks if you have a io.Reader available. This is just for compatibility with other decompressors
398    /// which try to serve an old-style c api
399    ///
400    /// Returns (read, written), if read == 0 then the source did not contain a full block and further calls with the same
401    /// input will not make any progress!
402    ///
403    /// Note that no kind of block can be bigger than 128kb.
404    /// So to be safe use at least 128*1024 (max block content size) + 3 (block_header size) + 18 (max frame_header size) bytes as your source buffer
405    ///
406    /// You may call this function with an empty source after all bytes have been decoded. This is equivalent to just call decoder.read(&mut target)
407    pub fn decode_from_to(
408        &mut self,
409        source: &[u8],
410        target: &mut [u8],
411    ) -> Result<(usize, usize), FrameDecoderError> {
412        use FrameDecoderError as err;
413        let bytes_read_at_start = match &self.state {
414            Some(s) => s.bytes_read_counter,
415            None => 0,
416        };
417
418        if !self.is_finished() || self.state.is_none() {
419            let mut mt_source = source;
420
421            if self.state.is_none() {
422                self.init(&mut mt_source)?;
423            }
424
425            //pseudo block to scope "state" so we can borrow self again after the block
426            {
427                let state = match &mut self.state {
428                    Some(s) => s,
429                    None => panic!("Bug in library"),
430                };
431                let mut block_dec = decoding::block_decoder::new();
432
433                if state.frame.header.descriptor.content_checksum_flag()
434                    && state.frame_finished
435                    && state.check_sum.is_none()
436                {
437                    //this block is needed if the checksum were the only 4 bytes that were not included in the last decode_from_to call for a frame
438                    if mt_source.len() >= 4 {
439                        let chksum = mt_source[..4].try_into().expect("optimized away");
440                        state.bytes_read_counter += 4;
441                        let chksum = u32::from_le_bytes(chksum);
442                        state.check_sum = Some(chksum);
443                    }
444                    return Ok((4, 0));
445                }
446
447                loop {
448                    //check if there are enough bytes for the next header
449                    if mt_source.len() < 3 {
450                        break;
451                    }
452                    let (block_header, block_header_size) = block_dec
453                        .read_block_header(&mut mt_source)
454                        .map_err(err::FailedToReadBlockHeader)?;
455
456                    // check the needed size for the block before updating counters.
457                    // If not enough bytes are in the source, the header will have to be read again, so act like we never read it in the first place
458                    if mt_source.len() < block_header.content_size as usize {
459                        break;
460                    }
461                    state.bytes_read_counter += u64::from(block_header_size);
462
463                    let bytes_read_in_block_body = block_dec
464                        .decode_block_content(
465                            &block_header,
466                            &mut state.decoder_scratch,
467                            &mut mt_source,
468                        )
469                        .map_err(err::FailedToReadBlockBody)?;
470                    state.bytes_read_counter += bytes_read_in_block_body;
471                    state.block_counter += 1;
472
473                    if block_header.last_block {
474                        state.frame_finished = true;
475                        if state.frame.header.descriptor.content_checksum_flag() {
476                            //if there are enough bytes handle this here. Else the block at the start of this function will handle it at the next call
477                            if mt_source.len() >= 4 {
478                                let chksum = mt_source[..4].try_into().expect("optimized away");
479                                state.bytes_read_counter += 4;
480                                let chksum = u32::from_le_bytes(chksum);
481                                state.check_sum = Some(chksum);
482                            }
483                        }
484                        break;
485                    }
486                }
487            }
488        }
489
490        let result_len = self.read(target).map_err(err::FailedToDrainDecodebuffer)?;
491        let bytes_read_at_end = match &mut self.state {
492            Some(s) => s.bytes_read_counter,
493            None => panic!("Bug in library"),
494        };
495        let read_len = bytes_read_at_end - bytes_read_at_start;
496        Ok((read_len as usize, result_len))
497    }
498
499    /// Decode multiple frames into the output slice.
500    ///
501    /// `input` must contain an exact number of frames.
502    ///
503    /// `output` must be large enough to hold the decompressed data. If you don't know
504    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
505    ///
506    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
507    ///
508    /// Returns the number of bytes written to `output`.
509    pub fn decode_all(
510        &mut self,
511        mut input: &[u8],
512        mut output: &mut [u8],
513    ) -> Result<usize, FrameDecoderError> {
514        let mut total_bytes_written = 0;
515        while !input.is_empty() {
516            match self.init(&mut input) {
517                Ok(_) => {}
518                Err(FrameDecoderError::ReadFrameHeaderError(
519                    crate::decoding::errors::ReadFrameHeaderError::SkipFrame { length, .. },
520                )) => {
521                    input = input
522                        .get(length as usize..)
523                        .ok_or(FrameDecoderError::FailedToSkipFrame)?;
524                    continue;
525                }
526                Err(e) => return Err(e),
527            };
528            loop {
529                self.decode_blocks(&mut input, BlockDecodingStrategy::UptoBytes(1024 * 1024))?;
530                let bytes_written = self
531                    .read(output)
532                    .map_err(FrameDecoderError::FailedToDrainDecodebuffer)?;
533                output = &mut output[bytes_written..];
534                total_bytes_written += bytes_written;
535                if self.can_collect() != 0 {
536                    return Err(FrameDecoderError::TargetTooSmall);
537                }
538                if self.is_finished() {
539                    break;
540                }
541            }
542        }
543
544        Ok(total_bytes_written)
545    }
546
547    /// Decode multiple frames into the extra capacity of the output vector.
548    ///
549    /// `input` must contain an exact number of frames.
550    ///
551    /// `output` must have enough extra capacity to hold the decompressed data.
552    /// This function will not reallocate or grow the vector. If you don't know
553    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
554    ///
555    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
556    ///
557    /// The length of the output vector is updated to include the decompressed data.
558    /// The length is not changed if an error occurs.
559    pub fn decode_all_to_vec(
560        &mut self,
561        input: &[u8],
562        output: &mut Vec<u8>,
563    ) -> Result<(), FrameDecoderError> {
564        let len = output.len();
565        let cap = output.capacity();
566        output.resize(cap, 0);
567        match self.decode_all(input, &mut output[len..]) {
568            Ok(bytes_written) => {
569                let new_len = core::cmp::min(len + bytes_written, cap); // Sanitizes `bytes_written`.
570                output.resize(new_len, 0);
571                Ok(())
572            }
573            Err(e) => {
574                output.resize(len, 0);
575                Err(e)
576            }
577        }
578    }
579}
580
581/// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished
582/// this will retain window_size bytes, else it will drain it completely
583impl Read for FrameDecoder {
584    fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
585        let state = match &mut self.state {
586            None => return Ok(0),
587            Some(s) => s,
588        };
589        if state.frame_finished {
590            state.decoder_scratch.buffer.read_all(target)
591        } else {
592            state.decoder_scratch.buffer.read(target)
593        }
594    }
595}