ruzstd/decoding/frame_decoder.rs
//! FrameDecoder is the main low-level struct users interact with to decode zstd frames.
//!
//! Zstandard compressed data is made of one or more frames. Each frame is independent and can be
//! decompressed independently of other frames. This module contains structures
//! and utilities that can be used to decode a frame.
6
7use super::frame;
8use crate::decoding;
9use crate::decoding::dictionary::Dictionary;
10use crate::decoding::errors::FrameDecoderError;
11use crate::decoding::scratch::DecoderScratch;
12use crate::io::{Error, Read, Write};
13use alloc::collections::BTreeMap;
14use alloc::vec::Vec;
15use core::convert::TryInto;
16
/// Largest window size (in bytes) this implementation accepts: 100 MiB.
///
/// The spec allows significantly larger windows; the cap exists to protect
/// against malformed frames.
const MAXIMUM_ALLOWED_WINDOW_SIZE: u64 = 100 * 1024 * 1024;
20
21/// Low level Zstandard decoder that can be used to decompress frames with fine control over when and how many bytes are decoded.
22///
23/// This decoder is able to decode frames only partially and gives control
24/// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once).
25/// It reads bytes as needed from a provided source and can be read from to collect partial results.
26///
27/// If you want to just read the whole frame with an `io::Read` without having to deal with manually calling [FrameDecoder::decode_blocks]
28/// you can use the provided [crate::decoding::StreamingDecoder] wich wraps this FrameDecoder.
29///
30/// Workflow is as follows:
31/// ```
32/// use ruzstd::decoding::BlockDecodingStrategy;
33///
34/// # #[cfg(feature = "std")]
35/// use std::io::{Read, Write};
36///
37/// // no_std environments can use the crate's own Read traits
38/// # #[cfg(not(feature = "std"))]
39/// use ruzstd::io::{Read, Write};
40///
41/// fn decode_this(mut file: impl Read) {
42/// //Create a new decoder
43/// let mut frame_dec = ruzstd::decoding::FrameDecoder::new();
44/// let mut result = Vec::new();
45///
46/// // Use reset or init to make the decoder ready to decode the frame from the io::Read
47/// frame_dec.reset(&mut file).unwrap();
48///
49/// // Loop until the frame has been decoded completely
50/// while !frame_dec.is_finished() {
51/// // decode (roughly) batch_size many bytes
52/// frame_dec.decode_blocks(&mut file, BlockDecodingStrategy::UptoBytes(1024)).unwrap();
53///
54/// // read from the decoder to collect bytes from the internal buffer
55/// let bytes_read = frame_dec.read(result.as_mut_slice()).unwrap();
56///
57/// // then do something with it
58/// do_something(&result[0..bytes_read]);
59/// }
60///
61/// // handle the last chunk of data
62/// while frame_dec.can_collect() > 0 {
63/// let x = frame_dec.read(result.as_mut_slice()).unwrap();
64///
65/// do_something(&result[0..x]);
66/// }
67/// }
68///
69/// fn do_something(data: &[u8]) {
70/// # #[cfg(feature = "std")]
71/// std::io::stdout().write_all(data).unwrap();
72/// }
73/// ```
74pub struct FrameDecoder {
75 state: Option<FrameDecoderState>,
76 dicts: BTreeMap<u32, Dictionary>,
77}
78
79struct FrameDecoderState {
80 pub frame_header: frame::FrameHeader,
81 decoder_scratch: DecoderScratch,
82 frame_finished: bool,
83 block_counter: usize,
84 bytes_read_counter: u64,
85 check_sum: Option<u32>,
86 using_dict: Option<u32>,
87}
88
/// Controls how many blocks [`FrameDecoder::decode_blocks`] decodes before it returns.
///
/// Regardless of the strategy, decoding always stops after the last block of the frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockDecodingStrategy {
    /// Decode blocks until the last block of the frame has been decoded.
    All,
    /// Stop once at least this many blocks have been decoded during the call.
    UptoBlocks(usize),
    /// Stop once the decode buffer has grown by at least this many bytes during the call.
    /// Whole blocks are always decoded, so the limit may be exceeded.
    UptoBytes(usize),
}
94
95impl FrameDecoderState {
96 pub fn new(source: impl Read) -> Result<FrameDecoderState, FrameDecoderError> {
97 let (frame, header_size) = frame::read_frame_header(source)?;
98 let window_size = frame.window_size()?;
99 Ok(FrameDecoderState {
100 frame_header: frame,
101 frame_finished: false,
102 block_counter: 0,
103 decoder_scratch: DecoderScratch::new(window_size as usize),
104 bytes_read_counter: u64::from(header_size),
105 check_sum: None,
106 using_dict: None,
107 })
108 }
109
110 pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
111 let (frame_header, header_size) = frame::read_frame_header(source)?;
112 let window_size = frame_header.window_size()?;
113
114 if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
115 return Err(FrameDecoderError::WindowSizeTooBig {
116 requested: window_size,
117 });
118 }
119
120 self.frame_header = frame_header;
121 self.frame_finished = false;
122 self.block_counter = 0;
123 self.decoder_scratch.reset(window_size as usize);
124 self.bytes_read_counter = u64::from(header_size);
125 self.check_sum = None;
126 self.using_dict = None;
127 Ok(())
128 }
129}
130
131impl Default for FrameDecoder {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl FrameDecoder {
138 /// This will create a new decoder without allocating anything yet.
139 /// init()/reset() will allocate all needed buffers if it is the first time this decoder is used
140 /// else they just reset these buffers with not further allocations
141 pub fn new() -> FrameDecoder {
142 FrameDecoder {
143 state: None,
144 dicts: BTreeMap::new(),
145 }
146 }
147
148 /// init() will allocate all needed buffers if it is the first time this decoder is used
149 /// else they just reset these buffers with not further allocations
150 ///
151 /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
152 ///
153 /// equivalent to reset()
154 pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
155 self.reset(source)
156 }
157
158 /// reset() will allocate all needed buffers if it is the first time this decoder is used
159 /// else they just reset these buffers with not further allocations
160 ///
161 /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
162 ///
163 /// equivalent to init()
164 pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
165 use FrameDecoderError as err;
166 let state = match &mut self.state {
167 Some(s) => {
168 s.reset(source)?;
169 s
170 }
171 None => {
172 self.state = Some(FrameDecoderState::new(source)?);
173 self.state.as_mut().unwrap()
174 }
175 };
176 if let Some(dict_id) = state.frame_header.dictionary_id() {
177 let dict = self
178 .dicts
179 .get(&dict_id)
180 .ok_or(err::DictNotProvided { dict_id })?;
181 state.decoder_scratch.init_from_dict(dict);
182 state.using_dict = Some(dict_id);
183 }
184 Ok(())
185 }
186
187 /// Add a dict to the FrameDecoder that can be used when needed. The FrameDecoder uses the appropriate one dynamically
188 pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> {
189 self.dicts.insert(dict.id, dict);
190 Ok(())
191 }
192
193 pub fn force_dict(&mut self, dict_id: u32) -> Result<(), FrameDecoderError> {
194 use FrameDecoderError as err;
195 let Some(state) = self.state.as_mut() else {
196 return Err(err::NotYetInitialized);
197 };
198
199 let dict = self
200 .dicts
201 .get(&dict_id)
202 .ok_or(err::DictNotProvided { dict_id })?;
203 state.decoder_scratch.init_from_dict(dict);
204 state.using_dict = Some(dict_id);
205
206 Ok(())
207 }
208
209 /// Returns how many bytes the frame contains after decompression
210 pub fn content_size(&self) -> u64 {
211 match &self.state {
212 None => 0,
213 Some(s) => s.frame_header.frame_content_size(),
214 }
215 }
216
217 /// Returns the checksum that was read from the data. Only available after all bytes have been read. It is the last 4 bytes of a zstd-frame
218 pub fn get_checksum_from_data(&self) -> Option<u32> {
219 let state = match &self.state {
220 None => return None,
221 Some(s) => s,
222 };
223
224 state.check_sum
225 }
226
/// Returns the checksum that was calculated while decoding, or `None` if the decoder
/// has not been initialised.
/// The value is only meaningful after all decoded bytes have been collected/read from the FrameDecoder.
#[cfg(feature = "hash")]
pub fn get_calculated_checksum(&self) -> Option<u32> {
    use core::hash::Hasher;

    self.state.as_ref().map(|state| {
        let digest = state.decoder_scratch.buffer.hash.finish();
        // The checksum stored in the frame is 32 bits wide, so only the
        // lower 32 bits of the 64-bit digest are kept.
        digest as u32
    })
}
241
242 /// Counter for how many bytes have been consumed while decoding the frame
243 pub fn bytes_read_from_source(&self) -> u64 {
244 let state = match &self.state {
245 None => return 0,
246 Some(s) => s,
247 };
248 state.bytes_read_counter
249 }
250
251 /// Whether the current frames last block has been decoded yet
252 /// If this returns true you can call the drain* functions to get all content
253 /// (the read() function will drain automatically if this returns true)
254 pub fn is_finished(&self) -> bool {
255 let state = match &self.state {
256 None => return true,
257 Some(s) => s,
258 };
259 if state.frame_header.descriptor.content_checksum_flag() {
260 state.frame_finished && state.check_sum.is_some()
261 } else {
262 state.frame_finished
263 }
264 }
265
266 /// Counter for how many blocks have already been decoded
267 pub fn blocks_decoded(&self) -> usize {
268 let state = match &self.state {
269 None => return 0,
270 Some(s) => s,
271 };
272 state.block_counter
273 }
274
275 /// Decodes blocks from a reader. It requires that the framedecoder has been initialized first.
276 /// The Strategy influences how many blocks will be decoded before the function returns
277 /// This is important if you want to manage memory consumption carefully. If you don't care
278 /// about that you can just choose the strategy "All" and have all blocks of the frame decoded into the buffer
279 pub fn decode_blocks(
280 &mut self,
281 mut source: impl Read,
282 strat: BlockDecodingStrategy,
283 ) -> Result<bool, FrameDecoderError> {
284 use FrameDecoderError as err;
285 let state = self.state.as_mut().ok_or(err::NotYetInitialized)?;
286
287 let mut block_dec = decoding::block_decoder::new();
288
289 let buffer_size_before = state.decoder_scratch.buffer.len();
290 let block_counter_before = state.block_counter;
291 loop {
292 vprintln!("################");
293 vprintln!("Next Block: {}", state.block_counter);
294 vprintln!("################");
295 let (block_header, block_header_size) = block_dec
296 .read_block_header(&mut source)
297 .map_err(err::FailedToReadBlockHeader)?;
298 state.bytes_read_counter += u64::from(block_header_size);
299
300 vprintln!();
301 vprintln!(
302 "Found {} block with size: {}, which will be of size: {}",
303 block_header.block_type,
304 block_header.content_size,
305 block_header.decompressed_size
306 );
307
308 let bytes_read_in_block_body = block_dec
309 .decode_block_content(&block_header, &mut state.decoder_scratch, &mut source)
310 .map_err(err::FailedToReadBlockBody)?;
311 state.bytes_read_counter += bytes_read_in_block_body;
312
313 state.block_counter += 1;
314
315 vprintln!("Output: {}", state.decoder_scratch.buffer.len());
316
317 if block_header.last_block {
318 state.frame_finished = true;
319 if state.frame_header.descriptor.content_checksum_flag() {
320 let mut chksum = [0u8; 4];
321 source
322 .read_exact(&mut chksum)
323 .map_err(err::FailedToReadChecksum)?;
324 state.bytes_read_counter += 4;
325 let chksum = u32::from_le_bytes(chksum);
326 state.check_sum = Some(chksum);
327 }
328 break;
329 }
330
331 match strat {
332 BlockDecodingStrategy::All => { /* keep going */ }
333 BlockDecodingStrategy::UptoBlocks(n) => {
334 if state.block_counter - block_counter_before >= n {
335 break;
336 }
337 }
338 BlockDecodingStrategy::UptoBytes(n) => {
339 if state.decoder_scratch.buffer.len() - buffer_size_before >= n {
340 break;
341 }
342 }
343 }
344 }
345
346 Ok(state.frame_finished)
347 }
348
349 /// Collect bytes and retain window_size bytes while decoding is still going on.
350 /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
351 pub fn collect(&mut self) -> Option<Vec<u8>> {
352 let finished = self.is_finished();
353 let state = self.state.as_mut()?;
354 if finished {
355 Some(state.decoder_scratch.buffer.drain())
356 } else {
357 state.decoder_scratch.buffer.drain_to_window_size()
358 }
359 }
360
361 /// Collect bytes and retain window_size bytes while decoding is still going on.
362 /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
363 pub fn collect_to_writer(&mut self, w: impl Write) -> Result<usize, Error> {
364 let finished = self.is_finished();
365 let state = match &mut self.state {
366 None => return Ok(0),
367 Some(s) => s,
368 };
369 if finished {
370 state.decoder_scratch.buffer.drain_to_writer(w)
371 } else {
372 state.decoder_scratch.buffer.drain_to_window_size_writer(w)
373 }
374 }
375
376 /// How many bytes can currently be collected from the decodebuffer, while decoding is going on this will be lower than the actual decodbuffer size
377 /// because window_size bytes need to be retained for decoding.
378 /// After decoding of the frame (is_finished() == true) has finished it will report all remaining bytes
379 pub fn can_collect(&self) -> usize {
380 let finished = self.is_finished();
381 let state = match &self.state {
382 None => return 0,
383 Some(s) => s,
384 };
385 if finished {
386 state.decoder_scratch.buffer.can_drain()
387 } else {
388 state
389 .decoder_scratch
390 .buffer
391 .can_drain_to_window_size()
392 .unwrap_or(0)
393 }
394 }
395
396 /// Decodes as many blocks as possible from the source slice and reads from the decodebuffer into the target slice
397 /// The source slice may contain only parts of a frame but must contain at least one full block to make progress
398 ///
399 /// By all means use decode_blocks if you have a io.Reader available. This is just for compatibility with other decompressors
400 /// which try to serve an old-style c api
401 ///
402 /// Returns (read, written), if read == 0 then the source did not contain a full block and further calls with the same
403 /// input will not make any progress!
404 ///
405 /// Note that no kind of block can be bigger than 128kb.
406 /// So to be safe use at least 128*1024 (max block content size) + 3 (block_header size) + 18 (max frame_header size) bytes as your source buffer
407 ///
408 /// You may call this function with an empty source after all bytes have been decoded. This is equivalent to just call decoder.read(&mut target)
409 pub fn decode_from_to(
410 &mut self,
411 source: &[u8],
412 target: &mut [u8],
413 ) -> Result<(usize, usize), FrameDecoderError> {
414 use FrameDecoderError as err;
415 let bytes_read_at_start = match &self.state {
416 Some(s) => s.bytes_read_counter,
417 None => 0,
418 };
419
420 if !self.is_finished() || self.state.is_none() {
421 let mut mt_source = source;
422
423 if self.state.is_none() {
424 self.init(&mut mt_source)?;
425 }
426
427 //pseudo block to scope "state" so we can borrow self again after the block
428 {
429 let state = match &mut self.state {
430 Some(s) => s,
431 None => panic!("Bug in library"),
432 };
433 let mut block_dec = decoding::block_decoder::new();
434
435 if state.frame_header.descriptor.content_checksum_flag()
436 && state.frame_finished
437 && state.check_sum.is_none()
438 {
439 //this block is needed if the checksum were the only 4 bytes that were not included in the last decode_from_to call for a frame
440 if mt_source.len() >= 4 {
441 let chksum = mt_source[..4].try_into().expect("optimized away");
442 state.bytes_read_counter += 4;
443 let chksum = u32::from_le_bytes(chksum);
444 state.check_sum = Some(chksum);
445 }
446 return Ok((4, 0));
447 }
448
449 loop {
450 //check if there are enough bytes for the next header
451 if mt_source.len() < 3 {
452 break;
453 }
454 let (block_header, block_header_size) = block_dec
455 .read_block_header(&mut mt_source)
456 .map_err(err::FailedToReadBlockHeader)?;
457
458 // check the needed size for the block before updating counters.
459 // If not enough bytes are in the source, the header will have to be read again, so act like we never read it in the first place
460 if mt_source.len() < block_header.content_size as usize {
461 break;
462 }
463 state.bytes_read_counter += u64::from(block_header_size);
464
465 let bytes_read_in_block_body = block_dec
466 .decode_block_content(
467 &block_header,
468 &mut state.decoder_scratch,
469 &mut mt_source,
470 )
471 .map_err(err::FailedToReadBlockBody)?;
472 state.bytes_read_counter += bytes_read_in_block_body;
473 state.block_counter += 1;
474
475 if block_header.last_block {
476 state.frame_finished = true;
477 if state.frame_header.descriptor.content_checksum_flag() {
478 //if there are enough bytes handle this here. Else the block at the start of this function will handle it at the next call
479 if mt_source.len() >= 4 {
480 let chksum = mt_source[..4].try_into().expect("optimized away");
481 state.bytes_read_counter += 4;
482 let chksum = u32::from_le_bytes(chksum);
483 state.check_sum = Some(chksum);
484 }
485 }
486 break;
487 }
488 }
489 }
490 }
491
492 let result_len = self.read(target).map_err(err::FailedToDrainDecodebuffer)?;
493 let bytes_read_at_end = match &mut self.state {
494 Some(s) => s.bytes_read_counter,
495 None => panic!("Bug in library"),
496 };
497 let read_len = bytes_read_at_end - bytes_read_at_start;
498 Ok((read_len as usize, result_len))
499 }
500
501 /// Decode multiple frames into the output slice.
502 ///
503 /// `input` must contain an exact number of frames.
504 ///
505 /// `output` must be large enough to hold the decompressed data. If you don't know
506 /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
507 ///
508 /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
509 ///
510 /// Returns the number of bytes written to `output`.
511 pub fn decode_all(
512 &mut self,
513 mut input: &[u8],
514 mut output: &mut [u8],
515 ) -> Result<usize, FrameDecoderError> {
516 let mut total_bytes_written = 0;
517 while !input.is_empty() {
518 match self.init(&mut input) {
519 Ok(_) => {}
520 Err(FrameDecoderError::ReadFrameHeaderError(
521 crate::decoding::errors::ReadFrameHeaderError::SkipFrame { length, .. },
522 )) => {
523 input = input
524 .get(length as usize..)
525 .ok_or(FrameDecoderError::FailedToSkipFrame)?;
526 continue;
527 }
528 Err(e) => return Err(e),
529 };
530 loop {
531 self.decode_blocks(&mut input, BlockDecodingStrategy::UptoBytes(1024 * 1024))?;
532 let bytes_written = self
533 .read(output)
534 .map_err(FrameDecoderError::FailedToDrainDecodebuffer)?;
535 output = &mut output[bytes_written..];
536 total_bytes_written += bytes_written;
537 if self.can_collect() != 0 {
538 return Err(FrameDecoderError::TargetTooSmall);
539 }
540 if self.is_finished() {
541 break;
542 }
543 }
544 }
545
546 Ok(total_bytes_written)
547 }
548
549 /// Decode multiple frames into the extra capacity of the output vector.
550 ///
551 /// `input` must contain an exact number of frames.
552 ///
553 /// `output` must have enough extra capacity to hold the decompressed data.
554 /// This function will not reallocate or grow the vector. If you don't know
555 /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
556 ///
557 /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
558 ///
559 /// The length of the output vector is updated to include the decompressed data.
560 /// The length is not changed if an error occurs.
561 pub fn decode_all_to_vec(
562 &mut self,
563 input: &[u8],
564 output: &mut Vec<u8>,
565 ) -> Result<(), FrameDecoderError> {
566 let len = output.len();
567 let cap = output.capacity();
568 output.resize(cap, 0);
569 match self.decode_all(input, &mut output[len..]) {
570 Ok(bytes_written) => {
571 let new_len = core::cmp::min(len + bytes_written, cap); // Sanitizes `bytes_written`.
572 output.resize(new_len, 0);
573 Ok(())
574 }
575 Err(e) => {
576 output.resize(len, 0);
577 Err(e)
578 }
579 }
580 }
581}
582
583/// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished
584/// this will retain window_size bytes, else it will drain it completely
585impl Read for FrameDecoder {
586 fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
587 let state = match &mut self.state {
588 None => return Ok(0),
589 Some(s) => s,
590 };
591 if state.frame_finished {
592 state.decoder_scratch.buffer.read_all(target)
593 } else {
594 state.decoder_scratch.buffer.read(target)
595 }
596 }
597}