ruzstd/decoding/
sequence_section_decoder.rs

1use super::super::blocks::sequence_section::ModeType;
2use super::super::blocks::sequence_section::Sequence;
3use super::super::blocks::sequence_section::SequencesHeader;
4use super::scratch::FSEScratch;
5use crate::bit_io::BitReaderReversed;
6use crate::blocks::sequence_section::{
7    MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
8};
9use crate::decoding::errors::DecodeSequenceError;
10use crate::fse::FSEDecoder;
11use alloc::vec::Vec;
12
13/// Decode the provided source as a series of sequences into the supplied `target`.
14pub fn decode_sequences(
15    section: &SequencesHeader,
16    source: &[u8],
17    scratch: &mut FSEScratch,
18    target: &mut Vec<Sequence>,
19) -> Result<(), DecodeSequenceError> {
20    let bytes_read = maybe_update_fse_tables(section, source, scratch)?;
21
22    vprintln!("Updating tables used {} bytes", bytes_read);
23
24    let bit_stream = &source[bytes_read..];
25
26    let mut br = BitReaderReversed::new(bit_stream);
27
28    //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
29    let mut skipped_bits = 0;
30    loop {
31        let val = br.get_bits(1);
32        skipped_bits += 1;
33        if val == 1 || skipped_bits > 8 {
34            break;
35        }
36    }
37    if skipped_bits > 8 {
38        //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
39        return Err(DecodeSequenceError::ExtraPadding { skipped_bits });
40    }
41
42    if scratch.ll_rle.is_some() || scratch.ml_rle.is_some() || scratch.of_rle.is_some() {
43        decode_sequences_with_rle(section, &mut br, scratch, target)
44    } else {
45        decode_sequences_without_rle(section, &mut br, scratch, target)
46    }
47}
48
49fn decode_sequences_with_rle(
50    section: &SequencesHeader,
51    br: &mut BitReaderReversed<'_>,
52    scratch: &FSEScratch,
53    target: &mut Vec<Sequence>,
54) -> Result<(), DecodeSequenceError> {
55    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
56    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
57    let mut of_dec = FSEDecoder::new(&scratch.offsets);
58
59    if scratch.ll_rle.is_none() {
60        ll_dec.init_state(br)?;
61    }
62    if scratch.of_rle.is_none() {
63        of_dec.init_state(br)?;
64    }
65    if scratch.ml_rle.is_none() {
66        ml_dec.init_state(br)?;
67    }
68
69    target.clear();
70    target.reserve(section.num_sequences as usize);
71
72    for _seq_idx in 0..section.num_sequences {
73        //get the codes from either the RLE byte or from the decoder
74        let ll_code = if scratch.ll_rle.is_some() {
75            scratch.ll_rle.unwrap()
76        } else {
77            ll_dec.decode_symbol()
78        };
79        let ml_code = if scratch.ml_rle.is_some() {
80            scratch.ml_rle.unwrap()
81        } else {
82            ml_dec.decode_symbol()
83        };
84        let of_code = if scratch.of_rle.is_some() {
85            scratch.of_rle.unwrap()
86        } else {
87            of_dec.decode_symbol()
88        };
89
90        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
91        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);
92
93        //println!("Sequence: {}", i);
94        //println!("of stat: {}", of_dec.state);
95        //println!("of Code: {}", of_code);
96        //println!("ll stat: {}", ll_dec.state);
97        //println!("ll bits: {}", ll_num_bits);
98        //println!("ll Code: {}", ll_value);
99        //println!("ml stat: {}", ml_dec.state);
100        //println!("ml bits: {}", ml_num_bits);
101        //println!("ml Code: {}", ml_value);
102        //println!("");
103
104        if of_code > MAX_OFFSET_CODE {
105            return Err(DecodeSequenceError::UnsupportedOffset {
106                offset_code: of_code,
107            });
108        }
109
110        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
111        let offset = obits as u32 + (1u32 << of_code);
112
113        if offset == 0 {
114            return Err(DecodeSequenceError::ZeroOffset);
115        }
116
117        target.push(Sequence {
118            ll: ll_value + ll_add as u32,
119            ml: ml_value + ml_add as u32,
120            of: offset,
121        });
122
123        if target.len() < section.num_sequences as usize {
124            //println!(
125            //    "Bits left: {} ({} bytes)",
126            //    br.bits_remaining(),
127            //    br.bits_remaining() / 8,
128            //);
129            if scratch.ll_rle.is_none() {
130                ll_dec.update_state(br);
131            }
132            if scratch.ml_rle.is_none() {
133                ml_dec.update_state(br);
134            }
135            if scratch.of_rle.is_none() {
136                of_dec.update_state(br);
137            }
138        }
139
140        if br.bits_remaining() < 0 {
141            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
142        }
143    }
144
145    if br.bits_remaining() > 0 {
146        Err(DecodeSequenceError::ExtraBits {
147            bits_remaining: br.bits_remaining(),
148        })
149    } else {
150        Ok(())
151    }
152}
153
154fn decode_sequences_without_rle(
155    section: &SequencesHeader,
156    br: &mut BitReaderReversed<'_>,
157    scratch: &FSEScratch,
158    target: &mut Vec<Sequence>,
159) -> Result<(), DecodeSequenceError> {
160    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
161    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
162    let mut of_dec = FSEDecoder::new(&scratch.offsets);
163
164    ll_dec.init_state(br)?;
165    of_dec.init_state(br)?;
166    ml_dec.init_state(br)?;
167
168    target.clear();
169    target.reserve(section.num_sequences as usize);
170
171    for _seq_idx in 0..section.num_sequences {
172        let ll_code = ll_dec.decode_symbol();
173        let ml_code = ml_dec.decode_symbol();
174        let of_code = of_dec.decode_symbol();
175
176        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
177        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);
178
179        if of_code > MAX_OFFSET_CODE {
180            return Err(DecodeSequenceError::UnsupportedOffset {
181                offset_code: of_code,
182            });
183        }
184
185        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
186        let offset = obits as u32 + (1u32 << of_code);
187
188        if offset == 0 {
189            return Err(DecodeSequenceError::ZeroOffset);
190        }
191
192        target.push(Sequence {
193            ll: ll_value + ll_add as u32,
194            ml: ml_value + ml_add as u32,
195            of: offset,
196        });
197
198        if target.len() < section.num_sequences as usize {
199            //println!(
200            //    "Bits left: {} ({} bytes)",
201            //    br.bits_remaining(),
202            //    br.bits_remaining() / 8,
203            //);
204            ll_dec.update_state(br);
205            ml_dec.update_state(br);
206            of_dec.update_state(br);
207        }
208
209        if br.bits_remaining() < 0 {
210            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
211        }
212    }
213
214    if br.bits_remaining() > 0 {
215        Err(DecodeSequenceError::ExtraBits {
216            bits_remaining: br.bits_remaining(),
217        })
218    } else {
219        Ok(())
220    }
221}
222
/// Translate a literal length code into its `(baseline, number of extra bits)` pair, as
/// tabulated in the Zstandard format description.
///
/// Codes `0..=15` stand for themselves and need no extra bits; codes `16..=35` are looked
/// up in a fixed table.
///
/// # Panics
///
/// Panics if `code` is greater than 35.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#appendix-a---decoding-tables-for-predefined-codes>
fn lookup_ll_code(code: u8) -> (u32, u8) {
    /// First code that is resolved through `EXTENDED`.
    const FIRST_EXTENDED: u8 = 16;
    /// `(baseline, extra bits)` for codes `16..=35`, indexed by `code - 16`.
    const EXTENDED: [(u32, u8); 20] = [
        (16, 1),
        (18, 1),
        (20, 1),
        (22, 1),
        (24, 2),
        (28, 2),
        (32, 3),
        (40, 3),
        (48, 4),
        (64, 6),
        (128, 7),
        (256, 8),
        (512, 9),
        (1024, 10),
        (2048, 11),
        (4096, 12),
        (8192, 13),
        (16384, 14),
        (32768, 15),
        (65536, 16),
    ];

    if code < FIRST_EXTENDED {
        return (u32::from(code), 0);
    }

    match EXTENDED.get(usize::from(code - FIRST_EXTENDED)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal literal length code was: {}", code),
    }
}
253
/// Translate a match length code into its `(baseline, number of extra bits)` pair, as
/// tabulated in the Zstandard format description.
///
/// Codes `0..=31` map to `code + 3` and need no extra bits; codes `32..=52` are looked up
/// in a fixed table.
///
/// # Panics
///
/// Panics if `code` is greater than 52.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#appendix-a---decoding-tables-for-predefined-codes>
fn lookup_ml_code(code: u8) -> (u32, u8) {
    /// First code that is resolved through `EXTENDED`.
    const FIRST_EXTENDED: u8 = 32;
    /// `(baseline, extra bits)` for codes `32..=52`, indexed by `code - 32`.
    const EXTENDED: [(u32, u8); 21] = [
        (35, 1),
        (37, 1),
        (39, 1),
        (41, 1),
        (43, 2),
        (47, 2),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 5),
        (131, 7),
        (259, 8),
        (515, 9),
        (1027, 10),
        (2051, 11),
        (4099, 12),
        (8195, 13),
        (16387, 14),
        (32771, 15),
        (65539, 16),
    ];

    if code < FIRST_EXTENDED {
        return (u32::from(code) + 3, 0);
    }

    match EXTENDED.get(usize::from(code - FIRST_EXTENDED)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal match length code was: {}", code),
    }
}
285
// Upper bounds on the accuracy log of an FSE table description read from the bitstream.
// `maybe_update_fse_tables` passes them to `build_decoder` for the respective table.
// This info is buried in the symbol compression mode table of the format description.

/// Accuracy log limit for literal length tables.
///
/// "The maximum allowed accuracy log for literals length and match length tables is 9"
pub const LL_MAX_LOG: u8 = 9;
/// Accuracy log limit for match length tables.
///
/// "The maximum allowed accuracy log for literals length and match length tables is 9"
pub const ML_MAX_LOG: u8 = 9;
/// Accuracy log limit for offset tables.
///
/// "The maximum accuracy log for the offset table is 8."
pub const OF_MAX_LOG: u8 = 8;
293
294fn maybe_update_fse_tables(
295    section: &SequencesHeader,
296    source: &[u8],
297    scratch: &mut FSEScratch,
298) -> Result<usize, DecodeSequenceError> {
299    let modes = section
300        .modes
301        .ok_or(DecodeSequenceError::MissingCompressionMode)?;
302
303    let mut bytes_read = 0;
304
305    match modes.ll_mode() {
306        ModeType::FSECompressed => {
307            let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
308            bytes_read += bytes;
309
310            vprintln!("Updating ll table");
311            vprintln!("Used bytes: {}", bytes);
312            scratch.ll_rle = None;
313        }
314        ModeType::RLE => {
315            vprintln!("Use RLE ll table");
316            if source.is_empty() {
317                return Err(DecodeSequenceError::MissingByteForRleLlTable);
318            }
319            bytes_read += 1;
320            if source[0] > MAX_LITERAL_LENGTH_CODE {
321                return Err(DecodeSequenceError::MissingByteForRleMlTable);
322            }
323            scratch.ll_rle = Some(source[0]);
324        }
325        ModeType::Predefined => {
326            vprintln!("Use predefined ll table");
327            scratch.literal_lengths.build_from_probabilities(
328                LL_DEFAULT_ACC_LOG,
329                &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
330            )?;
331            scratch.ll_rle = None;
332        }
333        ModeType::Repeat => {
334            vprintln!("Repeat ll table");
335            /* Nothing to do */
336        }
337    };
338
339    let of_source = &source[bytes_read..];
340
341    match modes.of_mode() {
342        ModeType::FSECompressed => {
343            let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
344            vprintln!("Updating of table");
345            vprintln!("Used bytes: {}", bytes);
346            bytes_read += bytes;
347            scratch.of_rle = None;
348        }
349        ModeType::RLE => {
350            vprintln!("Use RLE of table");
351            if of_source.is_empty() {
352                return Err(DecodeSequenceError::MissingByteForRleOfTable);
353            }
354            bytes_read += 1;
355            if of_source[0] > MAX_OFFSET_CODE {
356                return Err(DecodeSequenceError::MissingByteForRleMlTable);
357            }
358            scratch.of_rle = Some(of_source[0]);
359        }
360        ModeType::Predefined => {
361            vprintln!("Use predefined of table");
362            scratch.offsets.build_from_probabilities(
363                OF_DEFAULT_ACC_LOG,
364                &Vec::from(&OFFSET_DEFAULT_DISTRIBUTION[..]),
365            )?;
366            scratch.of_rle = None;
367        }
368        ModeType::Repeat => {
369            vprintln!("Repeat of table");
370            /* Nothing to do */
371        }
372    };
373
374    let ml_source = &source[bytes_read..];
375
376    match modes.ml_mode() {
377        ModeType::FSECompressed => {
378            let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
379            bytes_read += bytes;
380            vprintln!("Updating ml table");
381            vprintln!("Used bytes: {}", bytes);
382            scratch.ml_rle = None;
383        }
384        ModeType::RLE => {
385            vprintln!("Use RLE ml table");
386            if ml_source.is_empty() {
387                return Err(DecodeSequenceError::MissingByteForRleMlTable);
388            }
389            bytes_read += 1;
390            if ml_source[0] > MAX_MATCH_LENGTH_CODE {
391                return Err(DecodeSequenceError::MissingByteForRleMlTable);
392            }
393            scratch.ml_rle = Some(ml_source[0]);
394        }
395        ModeType::Predefined => {
396            vprintln!("Use predefined ml table");
397            scratch.match_lengths.build_from_probabilities(
398                ML_DEFAULT_ACC_LOG,
399                &Vec::from(&MATCH_LENGTH_DEFAULT_DISTRIBUTION[..]),
400            )?;
401            scratch.ml_rle = None;
402        }
403        ModeType::Repeat => {
404            vprintln!("Repeat ml table");
405            /* Nothing to do */
406        }
407    };
408
409    Ok(bytes_read)
410}
411
/// The default Literal Length decoding table uses an accuracy logarithm of 6 bits.
const LL_DEFAULT_ACC_LOG: u8 = 6;
/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding
/// table is generated using a predefined distribution table.
///
/// One entry per literal length code (`0..=35`); the absolute values of the entries sum
/// to `1 << LL_DEFAULT_ACC_LOG`.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals-length>
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
    -1, -1, -1, -1,
];

/// The default Match Length decoding table uses an accuracy logarithm of 6 bits.
const ML_DEFAULT_ACC_LOG: u8 = 6;
/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding
/// table is generated using a predefined distribution table.
///
/// One entry per match length code (`0..=52`); the absolute values of the entries sum
/// to `1 << ML_DEFAULT_ACC_LOG`.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#match-length>
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];

/// The default Offset decoding table uses an accuracy logarithm of 5 bits.
const OF_DEFAULT_ACC_LOG: u8 = 5;
/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding
/// table is generated using a predefined distribution table.
///
/// One entry per offset code (`0..=28`); the absolute values of the entries sum
/// to `1 << OF_DEFAULT_ACC_LOG`.
///
/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#default-distributions>
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
443
/// Builds the literal length table from the predefined distribution and spot-checks its
/// size and a handful of its entries.
#[test]
fn test_ll_default() {
    let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
    table
        .build_from_probabilities(
            LL_DEFAULT_ACC_LOG,
            &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
        )
        .unwrap();

    // Dump the whole table to ease debugging when an assertion below fails.
    #[cfg(feature = "std")]
    for idx in 0..table.decode.len() {
        std::println!(
            "{:3}: {:3} {:3} {:3}",
            idx,
            table.decode[idx].symbol,
            table.decode[idx].num_bits,
            table.decode[idx].base_line
        );
    }

    assert_eq!(table.decode.len(), 64);

    //just test a few values. TODO test all values
    // Each row is (state, symbol, num_bits, base_line).
    let expected = [
        (0usize, 0, 4, 0),
        (19, 27, 6, 0),
        (39, 25, 4, 16),
        (60, 35, 6, 0),
        (59, 24, 5, 32),
    ];
    for &(state, symbol, num_bits, base_line) in expected.iter() {
        let entry = &table.decode[state];
        assert_eq!(entry.symbol, symbol, "symbol of state {}", state);
        assert_eq!(entry.num_bits, num_bits, "num_bits of state {}", state);
        assert_eq!(entry.base_line, base_line, "base_line of state {}", state);
    }
}