// ruzstd/huff0/huff0_decoder.rs

//! Utilities for decoding Huff0-encoded Huffman data.
2
3use crate::bit_io::BitReaderReversed;
4use crate::decoding::errors::HuffmanTableError;
5use crate::fse::{FSEDecoder, FSETable};
6use alloc::vec::Vec;
7
/// The Zstandard specification limits the maximum length of a code to 11 bits.
///
/// When a table is built, every individual symbol weight and the table's
/// resulting `max_num_bits` are both checked against this bound.
pub(crate) const MAX_MAX_NUM_BITS: u8 = 11;
10
11pub struct HuffmanDecoder<'table> {
12    table: &'table HuffmanTable,
13    /// State is used to index into the table.
14    pub state: u64,
15}
16
17impl<'t> HuffmanDecoder<'t> {
18    /// Create a new decoder with the provided table
19    pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
20        HuffmanDecoder { table, state: 0 }
21    }
22
23    /// Decode the symbol the internal state (cursor) is pointed at and return the
24    /// decoded literal.
25    pub fn decode_symbol(&mut self) -> u8 {
26        self.table.decode[self.state as usize].symbol
27    }
28
29    /// Initialize internal state and prepare to decode data. Then, `decode_symbol` can be called
30    /// to read the byte the internal cursor is pointing at, and `next_state` can be called to advance
31    /// the cursor until the max number of bits has been read.
32    pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
33        let num_bits = self.table.max_num_bits;
34        let new_bits = br.get_bits(num_bits);
35        self.state = new_bits;
36        num_bits
37    }
38
39    /// Advance the internal cursor to the next symbol. After this, you can call `decode_symbol`
40    /// to read from the new position.
41    pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
42        // self.state stores a small section, or a window of the bit stream. The table can be indexed via this state,
43        // telling you how many bits identify the current symbol.
44        let num_bits = self.table.decode[self.state as usize].num_bits;
45        // New bits are read from the stream
46        let new_bits = br.get_bits(num_bits);
47        // Shift and mask out the bits that identify the current symbol
48        self.state <<= num_bits;
49        self.state &= self.table.decode.len() as u64 - 1;
50        // The new bits are appended at the end of the current state.
51        self.state |= new_bits;
52        num_bits
53    }
54}
55
56/// A Huffman decoding table contains a list of Huffman prefix codes and their associated values
57pub struct HuffmanTable {
58    decode: Vec<Entry>,
59    /// The weight of a symbol is the number of occurences in a table.
60    /// This value is used in constructing a binary tree referred to as
61    /// a Huffman tree. Once this tree is constructed, it can be used to build the
62    /// lookup table
63    weights: Vec<u8>,
64    /// The maximum size in bits a prefix code in the encoded data can be.
65    /// This value is used so that the decoder knows how many bits
66    /// to read from the bitstream before checking the table. This
67    /// value must be 11 or lower.
68    pub max_num_bits: u8,
69    bits: Vec<u8>,
70    bit_ranks: Vec<u32>,
71    rank_indexes: Vec<usize>,
72    /// In some cases, the list of weights is compressed using FSE compression.
73    fse_table: FSETable,
74}
75
76impl HuffmanTable {
77    /// Create a new, empty table.
78    pub fn new() -> HuffmanTable {
79        HuffmanTable {
80            decode: Vec::new(),
81
82            weights: Vec::with_capacity(256),
83            max_num_bits: 0,
84            bits: Vec::with_capacity(256),
85            bit_ranks: Vec::with_capacity(11),
86            rank_indexes: Vec::with_capacity(11),
87            fse_table: FSETable::new(255),
88        }
89    }
90
91    /// Completely empty the table then repopulate as a replica
92    /// of `other`.
93    pub fn reinit_from(&mut self, other: &Self) {
94        self.reset();
95        self.decode.extend_from_slice(&other.decode);
96        self.weights.extend_from_slice(&other.weights);
97        self.max_num_bits = other.max_num_bits;
98        self.bits.extend_from_slice(&other.bits);
99        self.rank_indexes.extend_from_slice(&other.rank_indexes);
100        self.fse_table.reinit_from(&other.fse_table);
101    }
102
103    /// Completely empty the table of all data.
104    pub fn reset(&mut self) {
105        self.decode.clear();
106        self.weights.clear();
107        self.max_num_bits = 0;
108        self.bits.clear();
109        self.bit_ranks.clear();
110        self.rank_indexes.clear();
111        self.fse_table.reset();
112    }
113
114    /// Read from `source` and decode the input, populating the huffman decoding table.
115    ///
116    /// Returns the number of bytes read.
117    pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
118        self.decode.clear();
119
120        let bytes_used = self.read_weights(source)?;
121        self.build_table_from_weights()?;
122        Ok(bytes_used)
123    }
124
125    /// Read weights from the provided source.
126    ///
127    /// The huffman table is represented in the input data as a list of weights.
128    /// After the header, weights are read, then a Huffman decoding table
129    /// can be constructed using that list of weights.
130    ///
131    /// Returns the number of bytes read.
132    fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
133        use HuffmanTableError as err;
134
135        if source.is_empty() {
136            return Err(err::SourceIsEmpty);
137        }
138        let header = source[0];
139        let mut bits_read = 8;
140
141        match header {
142            // If the header byte is less than 128, the series of weights
143            // is compressed using two interleaved FSE streams that share
144            // a distribution table.
145            0..=127 => {
146                let fse_stream = &source[1..];
147                if header as usize > fse_stream.len() {
148                    return Err(err::NotEnoughBytesForWeights {
149                        got_bytes: fse_stream.len(),
150                        expected_bytes: header,
151                    });
152                }
153                //fse decompress weights
154                let bytes_used_by_fse_header = self.fse_table.build_decoder(fse_stream, 6)?;
155
156                if bytes_used_by_fse_header > header as usize {
157                    return Err(err::FSETableUsedTooManyBytes {
158                        used: bytes_used_by_fse_header,
159                        available_bytes: header,
160                    });
161                }
162
163                vprintln!(
164                    "Building fse table for huffman weights used: {}",
165                    bytes_used_by_fse_header
166                );
167                // Huffman headers are compressed using two interleaved
168                // FSE bitstreams, where the first state (decoder) handles
169                // even symbols, and the second handles odd symbols.
170                let mut dec1 = FSEDecoder::new(&self.fse_table);
171                let mut dec2 = FSEDecoder::new(&self.fse_table);
172
173                let compressed_start = bytes_used_by_fse_header;
174                let compressed_length = header as usize - bytes_used_by_fse_header;
175
176                let compressed_weights = &fse_stream[compressed_start..];
177                if compressed_weights.len() < compressed_length {
178                    return Err(err::NotEnoughBytesToDecompressWeights {
179                        have: compressed_weights.len(),
180                        need: compressed_length,
181                    });
182                }
183                let compressed_weights = &compressed_weights[..compressed_length];
184                let mut br = BitReaderReversed::new(compressed_weights);
185
186                bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
187
188                //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
189                let mut skipped_bits = 0;
190                loop {
191                    let val = br.get_bits(1);
192                    skipped_bits += 1;
193                    if val == 1 || skipped_bits > 8 {
194                        break;
195                    }
196                }
197                if skipped_bits > 8 {
198                    //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
199                    return Err(err::ExtraPadding { skipped_bits });
200                }
201
202                dec1.init_state(&mut br)?;
203                dec2.init_state(&mut br)?;
204
205                self.weights.clear();
206
207                // The two decoders take turns decoding a single symbol and updating their state.
208                loop {
209                    let w = dec1.decode_symbol();
210                    self.weights.push(w);
211                    dec1.update_state(&mut br);
212
213                    if br.bits_remaining() <= -1 {
214                        //collect final states
215                        self.weights.push(dec2.decode_symbol());
216                        break;
217                    }
218
219                    let w = dec2.decode_symbol();
220                    self.weights.push(w);
221                    dec2.update_state(&mut br);
222
223                    if br.bits_remaining() <= -1 {
224                        //collect final states
225                        self.weights.push(dec1.decode_symbol());
226                        break;
227                    }
228                    //maximum number of weights is 255 because we use u8 symbols and the last weight is inferred from the sum of all others
229                    if self.weights.len() > 255 {
230                        return Err(err::TooManyWeights {
231                            got: self.weights.len(),
232                        });
233                    }
234                }
235            }
236            // If the header byte is greater than or equal to 128,
237            // weights are directly represented, where each weight is
238            // encoded directly as a 4 bit field. The weights will
239            // always be encoded with full bytes, meaning if there's
240            // an odd number of weights, the last weight will still
241            // occupy a full byte.
242            _ => {
243                // weights are directly encoded
244                let weights_raw = &source[1..];
245                let num_weights = header - 127;
246                self.weights.resize(num_weights as usize, 0);
247
248                let bytes_needed = if num_weights % 2 == 0 {
249                    num_weights as usize / 2
250                } else {
251                    (num_weights as usize / 2) + 1
252                };
253
254                if weights_raw.len() < bytes_needed {
255                    return Err(err::NotEnoughBytesInSource {
256                        got: weights_raw.len(),
257                        need: bytes_needed,
258                    });
259                }
260
261                for idx in 0..num_weights {
262                    if idx % 2 == 0 {
263                        self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
264                    } else {
265                        self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
266                    }
267                    bits_read += 4;
268                }
269            }
270        }
271
272        let bytes_read = if bits_read % 8 == 0 {
273            bits_read / 8
274        } else {
275            (bits_read / 8) + 1
276        };
277        Ok(bytes_read as u32)
278    }
279
280    /// Once the weights have been read from the data, you can decode the weights
281    /// into a table, and use that table to decode the actual compressed data.
282    ///
283    /// This function populates the rest of the table from the series of weights.
284    fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> {
285        use HuffmanTableError as err;
286
287        self.bits.clear();
288        self.bits.resize(self.weights.len() + 1, 0);
289
290        let mut weight_sum: u32 = 0;
291        for w in &self.weights {
292            if *w > MAX_MAX_NUM_BITS {
293                return Err(err::WeightBiggerThanMaxNumBits { got: *w });
294            }
295            weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
296        }
297
298        if weight_sum == 0 {
299            return Err(err::MissingWeights);
300        }
301
302        let max_bits = highest_bit_set(weight_sum) as u8;
303        let left_over = (1 << max_bits) - weight_sum;
304
305        //left_over must be power of two
306        if !left_over.is_power_of_two() {
307            return Err(err::LeftoverIsNotAPowerOf2 { got: left_over });
308        }
309
310        let last_weight = highest_bit_set(left_over) as u8;
311
312        for symbol in 0..self.weights.len() {
313            let bits = if self.weights[symbol] > 0 {
314                max_bits + 1 - self.weights[symbol]
315            } else {
316                0
317            };
318            self.bits[symbol] = bits;
319        }
320
321        self.bits[self.weights.len()] = max_bits + 1 - last_weight;
322        self.max_num_bits = max_bits;
323
324        if max_bits > MAX_MAX_NUM_BITS {
325            return Err(err::MaxBitsTooHigh { got: max_bits });
326        }
327
328        self.bit_ranks.clear();
329        self.bit_ranks.resize((max_bits + 1) as usize, 0);
330        for num_bits in &self.bits {
331            self.bit_ranks[(*num_bits) as usize] += 1;
332        }
333
334        //fill with dummy symbols
335        self.decode.resize(
336            1 << self.max_num_bits,
337            Entry {
338                symbol: 0,
339                num_bits: 0,
340            },
341        );
342
343        //starting codes for each rank
344        self.rank_indexes.clear();
345        self.rank_indexes.resize((max_bits + 1) as usize, 0);
346
347        self.rank_indexes[max_bits as usize] = 0;
348        for bits in (1..self.rank_indexes.len() as u8).rev() {
349            self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
350                + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
351        }
352
353        assert!(
354            self.rank_indexes[0] == self.decode.len(),
355            "rank_idx[0]: {} should be: {}",
356            self.rank_indexes[0],
357            self.decode.len()
358        );
359
360        for symbol in 0..self.bits.len() {
361            let bits_for_symbol = self.bits[symbol];
362            if bits_for_symbol != 0 {
363                // allocate code for the symbol and set in the table
364                // a code ignores all max_bits - bits[symbol] bits, so it gets
365                // a range that spans all of those in the decoding table
366                let base_idx = self.rank_indexes[bits_for_symbol as usize];
367                let len = 1 << (max_bits - bits_for_symbol);
368                self.rank_indexes[bits_for_symbol as usize] += len;
369                for idx in 0..len {
370                    self.decode[base_idx + idx].symbol = symbol as u8;
371                    self.decode[base_idx + idx].num_bits = bits_for_symbol;
372                }
373            }
374        }
375
376        Ok(())
377    }
378}
379
380impl Default for HuffmanTable {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
/// A single entry in the table contains the decoded symbol/literal and the
/// size of the prefix code.
///
/// A code shorter than the table's `max_num_bits` fills several consecutive
/// entries. All of them hold the same symbol and code length, because the
/// trailing bits of the lookup window do not belong to that code.
#[derive(Copy, Clone, Debug)]
pub struct Entry {
    /// The byte that the prefix code replaces during encoding.
    symbol: u8,
    /// The number of bits the prefix code occupies.
    num_bits: u8,
}
395
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. the number of bits needed to represent `x` (`1 -> 1`, `4 -> 3`).
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    // Full type width minus the high-order bits that are unused.
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}