ruzstd/decoding/sequence_section_decoder.rs

1use super::super::blocks::sequence_section::ModeType;
2use super::super::blocks::sequence_section::Sequence;
3use super::super::blocks::sequence_section::SequencesHeader;
4use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
5use super::scratch::FSEScratch;
6use crate::blocks::sequence_section::{
7    MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
8};
9use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError};
10use alloc::vec::Vec;
11
12#[derive(Debug)]
13#[non_exhaustive]
14pub enum DecodeSequenceError {
15    GetBitsError(GetBitsError),
16    FSEDecoderError(FSEDecoderError),
17    FSETableError(FSETableError),
18    ExtraPadding { skipped_bits: i32 },
19    UnsupportedOffset { offset_code: u8 },
20    ZeroOffset,
21    NotEnoughBytesForNumSequences,
22    ExtraBits { bits_remaining: isize },
23    MissingCompressionMode,
24    MissingByteForRleLlTable,
25    MissingByteForRleOfTable,
26    MissingByteForRleMlTable,
27}
28
/// Exposes the wrapped bit-reader or FSE error, if there is one, as the underlying cause.
#[cfg(feature = "std")]
impl std::error::Error for DecodeSequenceError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::GetBitsError(inner) => inner,
            Self::FSEDecoderError(inner) => inner,
            Self::FSETableError(inner) => inner,
            // All other variants originate in this module and have no inner cause.
            _ => return None,
        };
        Some(cause)
    }
}
40
41impl core::fmt::Display for DecodeSequenceError {
42    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
43        match self {
44            DecodeSequenceError::GetBitsError(e) => write!(f, "{:?}", e),
45            DecodeSequenceError::FSEDecoderError(e) => write!(f, "{:?}", e),
46            DecodeSequenceError::FSETableError(e) => write!(f, "{:?}", e),
47            DecodeSequenceError::ExtraPadding { skipped_bits } => {
48                write!(f,
49                    "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
50                    skipped_bits,
51                )
52            }
53            DecodeSequenceError::UnsupportedOffset { offset_code } => {
54                write!(
55                    f,
56                    "Do not support offsets bigger than 1<<32; got: {}",
57                    offset_code,
58                )
59            }
60            DecodeSequenceError::ZeroOffset => write!(
61                f,
62                "Read an offset == 0. That is an illegal value for offsets"
63            ),
64            DecodeSequenceError::NotEnoughBytesForNumSequences => write!(
65                f,
66                "Bytestream did not contain enough bytes to decode num_sequences"
67            ),
68            DecodeSequenceError::ExtraBits { bits_remaining } => write!(f, "{}", bits_remaining),
69            DecodeSequenceError::MissingCompressionMode => write!(
70                f,
71                "compression modes are none but they must be set to something"
72            ),
73            DecodeSequenceError::MissingByteForRleLlTable => {
74                write!(f, "Need a byte to read for RLE ll table")
75            }
76            DecodeSequenceError::MissingByteForRleOfTable => {
77                write!(f, "Need a byte to read for RLE of table")
78            }
79            DecodeSequenceError::MissingByteForRleMlTable => {
80                write!(f, "Need a byte to read for RLE ml table")
81            }
82        }
83    }
84}
85
86impl From<GetBitsError> for DecodeSequenceError {
87    fn from(val: GetBitsError) -> Self {
88        Self::GetBitsError(val)
89    }
90}
91
92impl From<FSETableError> for DecodeSequenceError {
93    fn from(val: FSETableError) -> Self {
94        Self::FSETableError(val)
95    }
96}
97
98impl From<FSEDecoderError> for DecodeSequenceError {
99    fn from(val: FSEDecoderError) -> Self {
100        Self::FSEDecoderError(val)
101    }
102}
103
104/// Decode the provided source as a series of sequences into the supplied `target`.
105pub fn decode_sequences(
106    section: &SequencesHeader,
107    source: &[u8],
108    scratch: &mut FSEScratch,
109    target: &mut Vec<Sequence>,
110) -> Result<(), DecodeSequenceError> {
111    let bytes_read = maybe_update_fse_tables(section, source, scratch)?;
112
113    vprintln!("Updating tables used {} bytes", bytes_read);
114
115    let bit_stream = &source[bytes_read..];
116
117    let mut br = BitReaderReversed::new(bit_stream);
118
119    //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
120    let mut skipped_bits = 0;
121    loop {
122        let val = br.get_bits(1);
123        skipped_bits += 1;
124        if val == 1 || skipped_bits > 8 {
125            break;
126        }
127    }
128    if skipped_bits > 8 {
129        //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
130        return Err(DecodeSequenceError::ExtraPadding { skipped_bits });
131    }
132
133    if scratch.ll_rle.is_some() || scratch.ml_rle.is_some() || scratch.of_rle.is_some() {
134        decode_sequences_with_rle(section, &mut br, scratch, target)
135    } else {
136        decode_sequences_without_rle(section, &mut br, scratch, target)
137    }
138}
139
140fn decode_sequences_with_rle(
141    section: &SequencesHeader,
142    br: &mut BitReaderReversed<'_>,
143    scratch: &FSEScratch,
144    target: &mut Vec<Sequence>,
145) -> Result<(), DecodeSequenceError> {
146    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
147    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
148    let mut of_dec = FSEDecoder::new(&scratch.offsets);
149
150    if scratch.ll_rle.is_none() {
151        ll_dec.init_state(br)?;
152    }
153    if scratch.of_rle.is_none() {
154        of_dec.init_state(br)?;
155    }
156    if scratch.ml_rle.is_none() {
157        ml_dec.init_state(br)?;
158    }
159
160    target.clear();
161    target.reserve(section.num_sequences as usize);
162
163    for _seq_idx in 0..section.num_sequences {
164        //get the codes from either the RLE byte or from the decoder
165        let ll_code = if scratch.ll_rle.is_some() {
166            scratch.ll_rle.unwrap()
167        } else {
168            ll_dec.decode_symbol()
169        };
170        let ml_code = if scratch.ml_rle.is_some() {
171            scratch.ml_rle.unwrap()
172        } else {
173            ml_dec.decode_symbol()
174        };
175        let of_code = if scratch.of_rle.is_some() {
176            scratch.of_rle.unwrap()
177        } else {
178            of_dec.decode_symbol()
179        };
180
181        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
182        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);
183
184        //println!("Sequence: {}", i);
185        //println!("of stat: {}", of_dec.state);
186        //println!("of Code: {}", of_code);
187        //println!("ll stat: {}", ll_dec.state);
188        //println!("ll bits: {}", ll_num_bits);
189        //println!("ll Code: {}", ll_value);
190        //println!("ml stat: {}", ml_dec.state);
191        //println!("ml bits: {}", ml_num_bits);
192        //println!("ml Code: {}", ml_value);
193        //println!("");
194
195        if of_code > MAX_OFFSET_CODE {
196            return Err(DecodeSequenceError::UnsupportedOffset {
197                offset_code: of_code,
198            });
199        }
200
201        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
202        let offset = obits as u32 + (1u32 << of_code);
203
204        if offset == 0 {
205            return Err(DecodeSequenceError::ZeroOffset);
206        }
207
208        target.push(Sequence {
209            ll: ll_value + ll_add as u32,
210            ml: ml_value + ml_add as u32,
211            of: offset,
212        });
213
214        if target.len() < section.num_sequences as usize {
215            //println!(
216            //    "Bits left: {} ({} bytes)",
217            //    br.bits_remaining(),
218            //    br.bits_remaining() / 8,
219            //);
220            if scratch.ll_rle.is_none() {
221                ll_dec.update_state(br);
222            }
223            if scratch.ml_rle.is_none() {
224                ml_dec.update_state(br);
225            }
226            if scratch.of_rle.is_none() {
227                of_dec.update_state(br);
228            }
229        }
230
231        if br.bits_remaining() < 0 {
232            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
233        }
234    }
235
236    if br.bits_remaining() > 0 {
237        Err(DecodeSequenceError::ExtraBits {
238            bits_remaining: br.bits_remaining(),
239        })
240    } else {
241        Ok(())
242    }
243}
244
245fn decode_sequences_without_rle(
246    section: &SequencesHeader,
247    br: &mut BitReaderReversed<'_>,
248    scratch: &FSEScratch,
249    target: &mut Vec<Sequence>,
250) -> Result<(), DecodeSequenceError> {
251    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
252    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
253    let mut of_dec = FSEDecoder::new(&scratch.offsets);
254
255    ll_dec.init_state(br)?;
256    of_dec.init_state(br)?;
257    ml_dec.init_state(br)?;
258
259    target.clear();
260    target.reserve(section.num_sequences as usize);
261
262    for _seq_idx in 0..section.num_sequences {
263        let ll_code = ll_dec.decode_symbol();
264        let ml_code = ml_dec.decode_symbol();
265        let of_code = of_dec.decode_symbol();
266
267        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
268        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);
269
270        if of_code > MAX_OFFSET_CODE {
271            return Err(DecodeSequenceError::UnsupportedOffset {
272                offset_code: of_code,
273            });
274        }
275
276        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
277        let offset = obits as u32 + (1u32 << of_code);
278
279        if offset == 0 {
280            return Err(DecodeSequenceError::ZeroOffset);
281        }
282
283        target.push(Sequence {
284            ll: ll_value + ll_add as u32,
285            ml: ml_value + ml_add as u32,
286            of: offset,
287        });
288
289        if target.len() < section.num_sequences as usize {
290            //println!(
291            //    "Bits left: {} ({} bytes)",
292            //    br.bits_remaining(),
293            //    br.bits_remaining() / 8,
294            //);
295            ll_dec.update_state(br);
296            ml_dec.update_state(br);
297            of_dec.update_state(br);
298        }
299
300        if br.bits_remaining() < 0 {
301            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
302        }
303    }
304
305    if br.bits_remaining() > 0 {
306        Err(DecodeSequenceError::ExtraBits {
307            bits_remaining: br.bits_remaining(),
308        })
309    } else {
310        Ok(())
311    }
312}
313
/// Maps a literal length code to its `(baseline, number of extra bits)` pair.
///
/// Codes `0..=15` encode the literal length directly and need no extra bits. Codes `16..=35`
/// give a baseline to which the value of the indicated number of extra bits is added.
///
/// # Panics
///
/// Panics (via `unreachable!`) for codes above 35.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals-length-codes>
fn lookup_ll_code(code: u8) -> (u32, u8) {
    // (baseline, extra bits) for codes 16..=35, indexed by `code - 16`.
    const LONG_CODES: [(u32, u8); 20] = [
        (16, 1),
        (18, 1),
        (20, 1),
        (22, 1),
        (24, 2),
        (28, 2),
        (32, 3),
        (40, 3),
        (48, 4),
        (64, 6),
        (128, 7),
        (256, 8),
        (512, 9),
        (1024, 10),
        (2048, 11),
        (4096, 12),
        (8192, 13),
        (16384, 14),
        (32768, 15),
        (65536, 16),
    ];

    if code < 16 {
        return (u32::from(code), 0);
    }
    match LONG_CODES.get(usize::from(code - 16)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal literal length code was: {}", code),
    }
}
344
/// Maps a match length code to its `(baseline, number of extra bits)` pair.
///
/// Codes `0..=31` encode match lengths `3..=34` directly and need no extra bits. Codes
/// `32..=52` give a baseline to which the value of the indicated number of extra bits is added.
///
/// # Panics
///
/// Panics (via `unreachable!`) for codes above 52.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#match-length-codes>
fn lookup_ml_code(code: u8) -> (u32, u8) {
    // (baseline, extra bits) for codes 32..=52, indexed by `code - 32`.
    const LONG_CODES: [(u32, u8); 21] = [
        (35, 1),
        (37, 1),
        (39, 1),
        (41, 1),
        (43, 2),
        (47, 2),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 5),
        (131, 7),
        (259, 8),
        (515, 9),
        (1027, 10),
        (2051, 11),
        (4099, 12),
        (8195, 13),
        (16387, 14),
        (32771, 15),
        (65539, 16),
    ];

    if code < 32 {
        return (u32::from(code) + 3, 0);
    }
    match LONG_CODES.get(usize::from(code - 32)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal match length code was: {}", code),
    }
}
376
// Accuracy-log limits quoted from the Zstandard format specification, where they appear in
// the description of the symbol compression modes.

/// "The maximum allowed accuracy log for literals length and match length tables is 9"
pub const LL_MAX_LOG: u8 = 9;
/// "The maximum allowed accuracy log for literals length and match length tables is 9"
pub const ML_MAX_LOG: u8 = LL_MAX_LOG;
/// "The maximum accuracy log for the offset table is 8."
pub const OF_MAX_LOG: u8 = 8;
384
385fn maybe_update_fse_tables(
386    section: &SequencesHeader,
387    source: &[u8],
388    scratch: &mut FSEScratch,
389) -> Result<usize, DecodeSequenceError> {
390    let modes = section
391        .modes
392        .ok_or(DecodeSequenceError::MissingCompressionMode)?;
393
394    let mut bytes_read = 0;
395
396    match modes.ll_mode() {
397        ModeType::FSECompressed => {
398            let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
399            bytes_read += bytes;
400
401            vprintln!("Updating ll table");
402            vprintln!("Used bytes: {}", bytes);
403            scratch.ll_rle = None;
404        }
405        ModeType::RLE => {
406            vprintln!("Use RLE ll table");
407            if source.is_empty() {
408                return Err(DecodeSequenceError::MissingByteForRleLlTable);
409            }
410            bytes_read += 1;
411            if source[0] > MAX_LITERAL_LENGTH_CODE {
412                return Err(DecodeSequenceError::MissingByteForRleMlTable);
413            }
414            scratch.ll_rle = Some(source[0]);
415        }
416        ModeType::Predefined => {
417            vprintln!("Use predefined ll table");
418            scratch.literal_lengths.build_from_probabilities(
419                LL_DEFAULT_ACC_LOG,
420                &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
421            )?;
422            scratch.ll_rle = None;
423        }
424        ModeType::Repeat => {
425            vprintln!("Repeat ll table");
426            /* Nothing to do */
427        }
428    };
429
430    let of_source = &source[bytes_read..];
431
432    match modes.of_mode() {
433        ModeType::FSECompressed => {
434            let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
435            vprintln!("Updating of table");
436            vprintln!("Used bytes: {}", bytes);
437            bytes_read += bytes;
438            scratch.of_rle = None;
439        }
440        ModeType::RLE => {
441            vprintln!("Use RLE of table");
442            if of_source.is_empty() {
443                return Err(DecodeSequenceError::MissingByteForRleOfTable);
444            }
445            bytes_read += 1;
446            if of_source[0] > MAX_OFFSET_CODE {
447                return Err(DecodeSequenceError::MissingByteForRleMlTable);
448            }
449            scratch.of_rle = Some(of_source[0]);
450        }
451        ModeType::Predefined => {
452            vprintln!("Use predefined of table");
453            scratch.offsets.build_from_probabilities(
454                OF_DEFAULT_ACC_LOG,
455                &Vec::from(&OFFSET_DEFAULT_DISTRIBUTION[..]),
456            )?;
457            scratch.of_rle = None;
458        }
459        ModeType::Repeat => {
460            vprintln!("Repeat of table");
461            /* Nothing to do */
462        }
463    };
464
465    let ml_source = &source[bytes_read..];
466
467    match modes.ml_mode() {
468        ModeType::FSECompressed => {
469            let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
470            bytes_read += bytes;
471            vprintln!("Updating ml table");
472            vprintln!("Used bytes: {}", bytes);
473            scratch.ml_rle = None;
474        }
475        ModeType::RLE => {
476            vprintln!("Use RLE ml table");
477            if ml_source.is_empty() {
478                return Err(DecodeSequenceError::MissingByteForRleMlTable);
479            }
480            bytes_read += 1;
481            if ml_source[0] > MAX_MATCH_LENGTH_CODE {
482                return Err(DecodeSequenceError::MissingByteForRleMlTable);
483            }
484            scratch.ml_rle = Some(ml_source[0]);
485        }
486        ModeType::Predefined => {
487            vprintln!("Use predefined ml table");
488            scratch.match_lengths.build_from_probabilities(
489                ML_DEFAULT_ACC_LOG,
490                &Vec::from(&MATCH_LENGTH_DEFAULT_DISTRIBUTION[..]),
491            )?;
492            scratch.ml_rle = None;
493        }
494        ModeType::Repeat => {
495            vprintln!("Repeat ml table");
496            /* Nothing to do */
497        }
498    };
499
500    Ok(bytes_read)
501}
502
/// Accuracy log of the predefined literal length FSE table (`1 << 6` = 64 states).
const LL_DEFAULT_ACC_LOG: u8 = 6;
/// Normalized distribution from which the literal length FSE decoding table is built when
/// [ModeType::Predefined] is selected for literal lengths. Following the format
/// specification, `-1` denotes a "less than 1" probability.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals-length>
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
    -1, -1, -1, -1,
];
513
/// Accuracy log of the predefined match length FSE table (`1 << 6` = 64 states).
const ML_DEFAULT_ACC_LOG: u8 = 6;
/// Normalized distribution from which the match length FSE decoding table is built when
/// [ModeType::Predefined] is selected for match lengths. Following the format
/// specification, `-1` denotes a "less than 1" probability.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#match-length>
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
524
/// Accuracy log of the predefined offset FSE table (`1 << 5` = 32 states).
const OF_DEFAULT_ACC_LOG: u8 = 5;
/// Normalized distribution from which the offset FSE decoding table is built when
/// [ModeType::Predefined] is selected for offsets. Following the format specification,
/// `-1` denotes a "less than 1" probability.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#offset-codes>
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
534
/// Builds the FSE decoding table from the predefined literal length distribution and checks
/// its size plus a sample of expected entries.
#[test]
fn test_ll_default() {
    let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
    table
        .build_from_probabilities(
            LL_DEFAULT_ACC_LOG,
            &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
        )
        .expect("the predefined literal length distribution must build a valid table");

    // Dump the whole table to help diagnose failures.
    #[cfg(feature = "std")]
    for (idx, entry) in table.decode.iter().enumerate() {
        std::println!(
            "{:3}: {:3} {:3} {:3}",
            idx,
            entry.symbol,
            entry.num_bits,
            entry.base_line
        );
    }

    assert_eq!(table.decode.len(), 64);

    // Just test a few values. TODO test all values.
    assert_eq!(table.decode[0].symbol, 0);
    assert_eq!(table.decode[0].num_bits, 4);
    assert_eq!(table.decode[0].base_line, 0);

    assert_eq!(table.decode[19].symbol, 27);
    assert_eq!(table.decode[19].num_bits, 6);
    assert_eq!(table.decode[19].base_line, 0);

    assert_eq!(table.decode[39].symbol, 25);
    assert_eq!(table.decode[39].num_bits, 4);
    assert_eq!(table.decode[39].base_line, 16);

    assert_eq!(table.decode[60].symbol, 35);
    assert_eq!(table.decode[60].num_bits, 6);
    assert_eq!(table.decode[60].base_line, 0);

    assert_eq!(table.decode[59].symbol, 24);
    assert_eq!(table.decode[59].num_bits, 5);
    assert_eq!(table.decode[59].base_line, 32);
}