ruzstd/huff0/
huff0_decoder.rs

1//! Utilities for decoding Huff0 encoded huffman data.
2
3use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
4use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError};
5use alloc::vec::Vec;
6#[cfg(feature = "std")]
7use std::error::Error as StdError;
8
/// Decoding table for Huff0 (huffman) encoded data, together with the
/// scratch buffers that are reused every time a table is rebuilt.
pub struct HuffmanTable {
    /// Lookup table indexed by the decoder state. It holds
    /// `1 << max_num_bits` entries once a table has been built.
    decode: Vec<Entry>,
    /// The weights read from the table description, one per symbol.
    /// A weight of zero marks an unused symbol; a non-zero weight `w`
    /// gives the symbol a code that is `max_num_bits + 1 - w` bits long.
    /// These values are what the huffman table is constructed from.
    weights: Vec<u8>,
    /// The maximum size in bits a prefix code in the encoded data can be.
    /// This value is used so that the decoder knows how many bits
    /// to read from the bitstream before checking the table. This
    /// value must be 11 or lower.
    pub max_num_bits: u8,
    /// Code length in bits for every symbol (zero for unused symbols),
    /// including the final symbol whose weight is inferred rather than read.
    bits: Vec<u8>,
    /// Number of symbols for each code length, indexed by code length.
    bit_ranks: Vec<u32>,
    /// For each code length, the next free index in `decode` while the
    /// table is being filled.
    rank_indexes: Vec<usize>,
    /// In some cases, the list of weights is compressed using FSE compression.
    /// This is the FSE table used to decompress them.
    fse_table: FSETable,
}
26
/// Everything that can go wrong while reading a huffman table description
/// and turning it into a decoding table.
#[derive(Debug)]
#[non_exhaustive]
pub enum HuffmanTableError {
    /// Wraps a [`GetBitsError`].
    GetBitsError(GetBitsError),
    /// Initializing one of the FSE decoders used for the weights failed.
    FSEDecoderError(FSEDecoderError),
    /// Building the FSE table used to decompress the weights failed.
    FSETableError(FSETableError),
    /// The source slice did not contain a single byte.
    SourceIsEmpty,
    /// The header byte announces more bytes of FSE compressed weights than
    /// the source contains after the header.
    NotEnoughBytesForWeights {
        got_bytes: usize,
        expected_bytes: u8,
    },
    /// No set bit was found within the first eight bits read from the
    /// weight bitstream.
    ExtraPadding {
        skipped_bits: i32,
    },
    /// More than 255 weights were decoded.
    TooManyWeights {
        got: usize,
    },
    /// All weights were zero, so no table can be built.
    MissingWeights,
    /// The part of the code space not covered by the explicit weights is
    /// not a power of two, so no weight exists for the last symbol.
    LeftoverIsNotAPowerOf2 {
        got: u32,
    },
    /// Fewer bytes are left for the compressed weights than required.
    NotEnoughBytesToDecompressWeights {
        have: usize,
        need: usize,
    },
    /// The FSE table description alone consumed more bytes than the header
    /// byte grants to the whole compressed weight section.
    FSETableUsedTooManyBytes {
        used: usize,
        available_bytes: u8,
    },
    /// The source is too short to hold all directly encoded weights.
    NotEnoughBytesInSource {
        got: usize,
        need: usize,
    },
    /// A weight is larger than the maximum code length allowed.
    WeightBiggerThanMaxNumBits {
        got: u8,
    },
    /// The maximum code length derived from the weights is larger than
    /// the maximum code length allowed.
    MaxBitsTooHigh {
        got: u8,
    },
}
67
#[cfg(feature = "std")]
impl StdError for HuffmanTableError {
    /// Exposes the wrapped error for the variants that merely forward a
    /// lower-level failure; every other variant has no underlying source.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let inner: &(dyn StdError + 'static) = match self {
            Self::GetBitsError(wrapped) => wrapped,
            Self::FSEDecoderError(wrapped) => wrapped,
            Self::FSETableError(wrapped) => wrapped,
            _ => return None,
        };
        Some(inner)
    }
}
79
80impl core::fmt::Display for HuffmanTableError {
81    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result {
82        match self {
83            HuffmanTableError::GetBitsError(e) => write!(f, "{:?}", e),
84            HuffmanTableError::FSEDecoderError(e) => write!(f, "{:?}", e),
85            HuffmanTableError::FSETableError(e) => write!(f, "{:?}", e),
86            HuffmanTableError::SourceIsEmpty => write!(f, "Source needs to have at least one byte"),
87            HuffmanTableError::NotEnoughBytesForWeights {
88                got_bytes,
89                expected_bytes,
90            } => {
91                write!(f, "Header says there should be {} bytes for the weights but there are only {} bytes in the stream",
92                    expected_bytes,
93                    got_bytes)
94            }
95            HuffmanTableError::ExtraPadding { skipped_bits } => {
96                write!(f,
97                    "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
98                    skipped_bits,
99                )
100            }
101            HuffmanTableError::TooManyWeights { got } => {
102                write!(
103                    f,
104                    "More than 255 weights decoded (got {} weights). Stream is probably corrupted",
105                    got,
106                )
107            }
108            HuffmanTableError::MissingWeights => {
109                write!(f, "Can\'t build huffman table without any weights")
110            }
111            HuffmanTableError::LeftoverIsNotAPowerOf2 { got } => {
112                write!(f, "Leftover must be power of two but is: {}", got)
113            }
114            HuffmanTableError::NotEnoughBytesToDecompressWeights { have, need } => {
115                write!(
116                    f,
117                    "Not enough bytes in stream to decompress weights. Is: {}, Should be: {}",
118                    have, need,
119                )
120            }
121            HuffmanTableError::FSETableUsedTooManyBytes {
122                used,
123                available_bytes,
124            } => {
125                write!(f,
126                    "FSE table used more bytes: {} than were meant to be used for the whole stream of huffman weights ({})",
127                    used,
128                    available_bytes,
129                )
130            }
131            HuffmanTableError::NotEnoughBytesInSource { got, need } => {
132                write!(
133                    f,
134                    "Source needs to have at least {} bytes, got: {}",
135                    need, got,
136                )
137            }
138            HuffmanTableError::WeightBiggerThanMaxNumBits { got } => {
139                write!(
140                    f,
141                    "Cant have weight: {} bigger than max_num_bits: {}",
142                    got, MAX_MAX_NUM_BITS,
143                )
144            }
145            HuffmanTableError::MaxBitsTooHigh { got } => {
146                write!(
147                    f,
148                    "max_bits derived from weights is: {} should be lower than: {}",
149                    got, MAX_MAX_NUM_BITS,
150                )
151            }
152        }
153    }
154}
155
156impl From<GetBitsError> for HuffmanTableError {
157    fn from(val: GetBitsError) -> Self {
158        Self::GetBitsError(val)
159    }
160}
161
162impl From<FSEDecoderError> for HuffmanTableError {
163    fn from(val: FSEDecoderError) -> Self {
164        Self::FSEDecoderError(val)
165    }
166}
167
168impl From<FSETableError> for HuffmanTableError {
169    fn from(val: FSETableError) -> Self {
170        Self::FSETableError(val)
171    }
172}
173
/// An interface around a huffman table used to decode data.
///
/// The decoder only borrows the table, so the table has to outlive it.
pub struct HuffmanDecoder<'table> {
    /// The table that is consulted for every decoded symbol.
    table: &'table HuffmanTable,
    /// State is used to index into the table. It holds the most recently
    /// read bits of the stream.
    pub state: u64,
}
180
/// Errors that can be reported while decoding huffman encoded data.
#[derive(Debug)]
#[non_exhaustive]
pub enum HuffmanDecoderError {
    /// Wraps a [`GetBitsError`].
    GetBitsError(GetBitsError),
}
186
187impl core::fmt::Display for HuffmanDecoderError {
188    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
189        match self {
190            HuffmanDecoderError::GetBitsError(e) => write!(f, "{:?}", e),
191        }
192    }
193}
194
#[cfg(feature = "std")]
impl StdError for HuffmanDecoderError {
    /// Exposes the wrapped lower-level error as the source of this error.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let inner: &(dyn StdError + 'static) = match self {
            Self::GetBitsError(wrapped) => wrapped,
        };
        Some(inner)
    }
}
203
204impl From<GetBitsError> for HuffmanDecoderError {
205    fn from(val: GetBitsError) -> Self {
206        Self::GetBitsError(val)
207    }
208}
209
/// A single entry in the table contains the decoded symbol/literal and the
/// size of the prefix code.
///
/// The decoding table holds one entry for every possible decoder state, so
/// a code that is shorter than the maximum length occupies several
/// consecutive entries.
#[derive(Copy, Clone)]
pub struct Entry {
    /// The byte that the prefix code replaces during encoding.
    symbol: u8,
    /// The number of bits the prefix code occupies. This is also the number
    /// of bits the decoder reads from the stream to advance past the symbol.
    num_bits: u8,
}
219
/// The Zstandard specification limits the maximum length of a code to 11 bits.
/// Tables whose weights or derived maximum code length exceed this value are
/// rejected while the table is built.
const MAX_MAX_NUM_BITS: u8 = 11;
222
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. 32 minus the number of leading zeros (`1` yields 1, `u32::MAX`
/// yields 32).
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
229
230impl<'t> HuffmanDecoder<'t> {
231    /// Create a new decoder with the provided table
232    pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
233        HuffmanDecoder { table, state: 0 }
234    }
235
236    /// Re-initialize the decoder, using the new table if one is provided.
237    /// This might used for treeless blocks, because they re-use the table from old
238    /// data.
239    pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) {
240        self.state = 0;
241        if let Some(next_table) = new_table {
242            self.table = next_table;
243        }
244    }
245
246    /// Decode the symbol the internal state (cursor) is pointed at and return the
247    /// decoded literal.
248    pub fn decode_symbol(&mut self) -> u8 {
249        self.table.decode[self.state as usize].symbol
250    }
251
252    /// Initialize internal state and prepare to decode data. Then, `decode_symbol` can be called
253    /// to read the byte the internal cursor is pointing at, and `next_state` can be called to advance
254    /// the cursor until the max number of bits has been read.
255    pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
256        let num_bits = self.table.max_num_bits;
257        let new_bits = br.get_bits(num_bits);
258        self.state = new_bits;
259        num_bits
260    }
261
262    /// Advance the internal cursor to the next symbol. After this, you can call `decode_symbol`
263    /// to read from the new position.
264    pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
265        // self.state stores a small section, or a window of the bit stream. The table can be indexed via this state,
266        // telling you how many bits identify the current symbol.
267        let num_bits = self.table.decode[self.state as usize].num_bits;
268        // New bits are read from the stream
269        let new_bits = br.get_bits(num_bits);
270        // Shift and mask out the bits that identify the current symbol
271        self.state <<= num_bits;
272        self.state &= self.table.decode.len() as u64 - 1;
273        // The new bits are appended at the end of the current state.
274        self.state |= new_bits;
275        num_bits
276    }
277}
278
279impl Default for HuffmanTable {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285impl HuffmanTable {
286    /// Create a new, empty table.
287    pub fn new() -> HuffmanTable {
288        HuffmanTable {
289            decode: Vec::new(),
290
291            weights: Vec::with_capacity(256),
292            max_num_bits: 0,
293            bits: Vec::with_capacity(256),
294            bit_ranks: Vec::with_capacity(11),
295            rank_indexes: Vec::with_capacity(11),
296            fse_table: FSETable::new(100),
297        }
298    }
299
300    /// Completely empty the table then repopulate as a replica
301    /// of `other`.
302    pub fn reinit_from(&mut self, other: &Self) {
303        self.reset();
304        self.decode.extend_from_slice(&other.decode);
305        self.weights.extend_from_slice(&other.weights);
306        self.max_num_bits = other.max_num_bits;
307        self.bits.extend_from_slice(&other.bits);
308        self.rank_indexes.extend_from_slice(&other.rank_indexes);
309        self.fse_table.reinit_from(&other.fse_table);
310    }
311
312    /// Completely empty the table of all data.
313    pub fn reset(&mut self) {
314        self.decode.clear();
315        self.weights.clear();
316        self.max_num_bits = 0;
317        self.bits.clear();
318        self.bit_ranks.clear();
319        self.rank_indexes.clear();
320        self.fse_table.reset();
321    }
322
323    /// Read from `source` and parse it into a huffman table.
324    ///
325    /// Returns the number of bytes read.
326    pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
327        self.decode.clear();
328
329        let bytes_used = self.read_weights(source)?;
330        self.build_table_from_weights()?;
331        Ok(bytes_used)
332    }
333
334    /// Read weights from the provided source.
335    ///
336    /// The huffman table is represented in the encoded data as a list of weights
337    /// at the most basic level. After the header, weights are read, then the table
338    /// can be built using that list of weights.
339    ///
340    /// Returns the number of bytes read.
341    fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
342        use HuffmanTableError as err;
343
344        if source.is_empty() {
345            return Err(err::SourceIsEmpty);
346        }
347        let header = source[0];
348        let mut bits_read = 8;
349
350        match header {
351            // If the header byte is less than 128, the series of weights
352            // is compressed using two interleaved FSE streams that share
353            // a distribution table.
354            0..=127 => {
355                let fse_stream = &source[1..];
356                if header as usize > fse_stream.len() {
357                    return Err(err::NotEnoughBytesForWeights {
358                        got_bytes: fse_stream.len(),
359                        expected_bytes: header,
360                    });
361                }
362                //fse decompress weights
363                let bytes_used_by_fse_header = self
364                    .fse_table
365                    .build_decoder(fse_stream, /*TODO find actual max*/ 100)?;
366
367                if bytes_used_by_fse_header > header as usize {
368                    return Err(err::FSETableUsedTooManyBytes {
369                        used: bytes_used_by_fse_header,
370                        available_bytes: header,
371                    });
372                }
373
374                vprintln!(
375                    "Building fse table for huffman weights used: {}",
376                    bytes_used_by_fse_header
377                );
378                // Huffman headers are compressed using two interleaved
379                // FSE bitstreams, where the first state (decoder) handles
380                // even symbols, and the second handles odd symbols.
381                let mut dec1 = FSEDecoder::new(&self.fse_table);
382                let mut dec2 = FSEDecoder::new(&self.fse_table);
383
384                let compressed_start = bytes_used_by_fse_header;
385                let compressed_length = header as usize - bytes_used_by_fse_header;
386
387                let compressed_weights = &fse_stream[compressed_start..];
388                if compressed_weights.len() < compressed_length {
389                    return Err(err::NotEnoughBytesToDecompressWeights {
390                        have: compressed_weights.len(),
391                        need: compressed_length,
392                    });
393                }
394                let compressed_weights = &compressed_weights[..compressed_length];
395                let mut br = BitReaderReversed::new(compressed_weights);
396
397                bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
398
399                //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
400                let mut skipped_bits = 0;
401                loop {
402                    let val = br.get_bits(1);
403                    skipped_bits += 1;
404                    if val == 1 || skipped_bits > 8 {
405                        break;
406                    }
407                }
408                if skipped_bits > 8 {
409                    //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
410                    return Err(err::ExtraPadding { skipped_bits });
411                }
412
413                dec1.init_state(&mut br)?;
414                dec2.init_state(&mut br)?;
415
416                self.weights.clear();
417
418                // The two decoders take turns decoding a single symbol and updating their state.
419                loop {
420                    let w = dec1.decode_symbol();
421                    self.weights.push(w);
422                    dec1.update_state(&mut br);
423
424                    if br.bits_remaining() <= -1 {
425                        //collect final states
426                        self.weights.push(dec2.decode_symbol());
427                        break;
428                    }
429
430                    let w = dec2.decode_symbol();
431                    self.weights.push(w);
432                    dec2.update_state(&mut br);
433
434                    if br.bits_remaining() <= -1 {
435                        //collect final states
436                        self.weights.push(dec1.decode_symbol());
437                        break;
438                    }
439                    //maximum number of weights is 255 because we use u8 symbols and the last weight is inferred from the sum of all others
440                    if self.weights.len() > 255 {
441                        return Err(err::TooManyWeights {
442                            got: self.weights.len(),
443                        });
444                    }
445                }
446            }
447            // If the header byte is greater than or equal to 128,
448            // weights are directly represented, where each weight is
449            // encoded directly as a 4 bit field. The weights will
450            // always be encoded with full bytes, meaning if there's
451            // an odd number of weights, the last weight will still
452            // occupy a full byte.
453            _ => {
454                // weights are directly encoded
455                let weights_raw = &source[1..];
456                let num_weights = header - 127;
457                self.weights.resize(num_weights as usize, 0);
458
459                let bytes_needed = if num_weights % 2 == 0 {
460                    num_weights as usize / 2
461                } else {
462                    (num_weights as usize / 2) + 1
463                };
464
465                if weights_raw.len() < bytes_needed {
466                    return Err(err::NotEnoughBytesInSource {
467                        got: weights_raw.len(),
468                        need: bytes_needed,
469                    });
470                }
471
472                for idx in 0..num_weights {
473                    if idx % 2 == 0 {
474                        self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
475                    } else {
476                        self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
477                    }
478                    bits_read += 4;
479                }
480            }
481        }
482
483        let bytes_read = if bits_read % 8 == 0 {
484            bits_read / 8
485        } else {
486            (bits_read / 8) + 1
487        };
488        Ok(bytes_read as u32)
489    }
490
491    /// Once the weights have been read from the data, you can decode the weights
492    /// into a table, and use that table to decode the actual compressed data.
493    ///
494    /// This function populates the rest of the table from the series of weights.
495    fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> {
496        use HuffmanTableError as err;
497
498        self.bits.clear();
499        self.bits.resize(self.weights.len() + 1, 0);
500
501        let mut weight_sum: u32 = 0;
502        for w in &self.weights {
503            if *w > MAX_MAX_NUM_BITS {
504                return Err(err::WeightBiggerThanMaxNumBits { got: *w });
505            }
506            weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
507        }
508
509        if weight_sum == 0 {
510            return Err(err::MissingWeights);
511        }
512
513        let max_bits = highest_bit_set(weight_sum) as u8;
514        let left_over = (1 << max_bits) - weight_sum;
515
516        //left_over must be power of two
517        if !left_over.is_power_of_two() {
518            return Err(err::LeftoverIsNotAPowerOf2 { got: left_over });
519        }
520
521        let last_weight = highest_bit_set(left_over) as u8;
522
523        for symbol in 0..self.weights.len() {
524            let bits = if self.weights[symbol] > 0 {
525                max_bits + 1 - self.weights[symbol]
526            } else {
527                0
528            };
529            self.bits[symbol] = bits;
530        }
531
532        self.bits[self.weights.len()] = max_bits + 1 - last_weight;
533        self.max_num_bits = max_bits;
534
535        if max_bits > MAX_MAX_NUM_BITS {
536            return Err(err::MaxBitsTooHigh { got: max_bits });
537        }
538
539        self.bit_ranks.clear();
540        self.bit_ranks.resize((max_bits + 1) as usize, 0);
541        for num_bits in &self.bits {
542            self.bit_ranks[(*num_bits) as usize] += 1;
543        }
544
545        //fill with dummy symbols
546        self.decode.resize(
547            1 << self.max_num_bits,
548            Entry {
549                symbol: 0,
550                num_bits: 0,
551            },
552        );
553
554        //starting codes for each rank
555        self.rank_indexes.clear();
556        self.rank_indexes.resize((max_bits + 1) as usize, 0);
557
558        self.rank_indexes[max_bits as usize] = 0;
559        for bits in (1..self.rank_indexes.len() as u8).rev() {
560            self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
561                + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
562        }
563
564        assert!(
565            self.rank_indexes[0] == self.decode.len(),
566            "rank_idx[0]: {} should be: {}",
567            self.rank_indexes[0],
568            self.decode.len()
569        );
570
571        for symbol in 0..self.bits.len() {
572            let bits_for_symbol = self.bits[symbol];
573            if bits_for_symbol != 0 {
574                // allocate code for the symbol and set in the table
575                // a code ignores all max_bits - bits[symbol] bits, so it gets
576                // a range that spans all of those in the decoding table
577                let base_idx = self.rank_indexes[bits_for_symbol as usize];
578                let len = 1 << (max_bits - bits_for_symbol);
579                self.rank_indexes[bits_for_symbol as usize] += len;
580                for idx in 0..len {
581                    self.decode[base_idx + idx].symbol = symbol as u8;
582                    self.decode[base_idx + idx].num_bits = bits_for_symbol;
583                }
584            }
585        }
586
587        Ok(())
588    }
589}