ruzstd/fse/
fse_decoder.rs

1use crate::decoding::bit_reader::BitReader;
2use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
3use alloc::vec::Vec;
4
5/// FSE decoding involves a decoding table that describes the probabilities of
6/// all literals from 0 to the highest present one
7///
8/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description>
9pub struct FSETable {
10    /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1.
11    max_symbol: u8,
12    /// The actual table containing the decoded symbol and the compression data
13    /// connected to that symbol.
14    pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state
15    /// The size of the table is stored in logarithm base 2 format,
16    /// with the **size of the table** being equal to `(1 << accuracy_log)`.
17    /// This value is used so that the decoder knows how many bits to read from the bitstream.
18    pub accuracy_log: u8,
19    /// In this context, probability refers to the likelihood that a symbol occurs in the given data.
20    /// Given this info, the encoder can assign shorter codes to symbols that appear more often,
21    /// and longer codes that appear less often, then the decoder can use the probability
22    /// to determine what code was assigned to what symbol.
23    ///
24    /// The probability of a single symbol is a value representing the proportion of times the symbol
25    /// would fall within the data.
26    ///
27    /// If a symbol probability is set to `-1`, it means that the probability of a symbol
28    /// occurring in the data is less than one.
29    pub symbol_probabilities: Vec<i32>, //used while building the decode Vector
30    /// The number of times each symbol occurs (The first entry being 0x0, the second being 0x1) and so on
31    /// up until the highest possible symbol (255).
32    symbol_counter: Vec<u32>,
33}
34
35#[derive(Debug)]
36#[non_exhaustive]
37pub enum FSETableError {
38    AccLogIsZero,
39    AccLogTooBig {
40        got: u8,
41        max: u8,
42    },
43    GetBitsError(GetBitsError),
44    ProbabilityCounterMismatch {
45        got: u32,
46        expected_sum: u32,
47        symbol_probabilities: Vec<i32>,
48    },
49    TooManySymbols {
50        got: usize,
51    },
52}
53
#[cfg(feature = "std")]
impl std::error::Error for FSETableError {
    /// Only a wrapped bit-reader error has an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::GetBitsError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
63
64impl core::fmt::Display for FSETableError {
65    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66        match self {
67            FSETableError::AccLogIsZero => write!(f, "Acclog must be at least 1"),
68            FSETableError::AccLogTooBig { got, max } => {
69                write!(
70                    f,
71                    "Found FSE acc_log: {0} bigger than allowed maximum in this case: {1}",
72                    got, max
73                )
74            }
75            FSETableError::GetBitsError(e) => write!(f, "{:?}", e),
76            FSETableError::ProbabilityCounterMismatch {
77                got,
78                expected_sum,
79                symbol_probabilities,
80            } => {
81                write!(f,
82                    "The counter ({}) exceeded the expected sum: {}. This means an error or corrupted data \n {:?}",
83                    got,
84                    expected_sum,
85                    symbol_probabilities,
86                )
87            }
88            FSETableError::TooManySymbols { got } => {
89                write!(
90                    f,
91                    "There are too many symbols in this distribution: {}. Max: 256",
92                    got,
93                )
94            }
95        }
96    }
97}
98
99impl From<GetBitsError> for FSETableError {
100    fn from(val: GetBitsError) -> Self {
101        Self::GetBitsError(val)
102    }
103}
104
105pub struct FSEDecoder<'table> {
106    /// An FSE state value represents an index in the FSE table.
107    pub state: Entry,
108    /// A reference to the table used for decoding.
109    table: &'table FSETable,
110}
111
112#[derive(Debug)]
113#[non_exhaustive]
114pub enum FSEDecoderError {
115    GetBitsError(GetBitsError),
116    TableIsUninitialized,
117}
118
#[cfg(feature = "std")]
impl std::error::Error for FSEDecoderError {
    /// Only a wrapped bit-reader error has an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::GetBitsError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
128
129impl core::fmt::Display for FSEDecoderError {
130    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
131        match self {
132            FSEDecoderError::GetBitsError(e) => write!(f, "{:?}", e),
133            FSEDecoderError::TableIsUninitialized => {
134                write!(f, "Tried to use an uninitialized table!")
135            }
136        }
137    }
138}
139
140impl From<GetBitsError> for FSEDecoderError {
141    fn from(val: GetBitsError) -> Self {
142        Self::GetBitsError(val)
143    }
144}
145
/// One state of an FSE decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// The byte that is written to the decode output when this state is reached.
    pub symbol: u8,
    /// The number of bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// The offset that is added to the value read from the stream to obtain
    /// the next state.
    pub base_line: u32,
}
157
/// Offset added to the 4-bit value at the start of an FSE table description to
/// obtain the table's `Accuracy_Log`.
const ACC_LOG_OFFSET: u8 = 5;
161
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. the number of bits needed to represent `x`.
///
/// # Panics
///
/// Panics if `x` is zero, which has no set bit.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
166
167impl<'t> FSEDecoder<'t> {
168    /// Initialize a new Finite State Entropy decoder.
169    pub fn new(table: &'t FSETable) -> FSEDecoder<'t> {
170        FSEDecoder {
171            state: table.decode.first().copied().unwrap_or(Entry {
172                base_line: 0,
173                num_bits: 0,
174                symbol: 0,
175            }),
176            table,
177        }
178    }
179
180    /// Returns the byte associated with the symbol the internal cursor is pointing at.
181    pub fn decode_symbol(&self) -> u8 {
182        self.state.symbol
183    }
184
185    /// Initialize internal state and prepare for decoding. After this, `decode_symbol` can be called
186    /// to read the first symbol and `update_state` can be called to prepare to read the next symbol.
187    pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
188        if self.table.accuracy_log == 0 {
189            return Err(FSEDecoderError::TableIsUninitialized);
190        }
191        self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize];
192
193        Ok(())
194    }
195
196    /// Advance the internal state to decode the next symbol in the bitstream.
197    pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
198        let num_bits = self.state.num_bits;
199        let add = bits.get_bits(num_bits);
200        let base_line = self.state.base_line;
201        let new_state = base_line + add as u32;
202        self.state = self.table.decode[new_state as usize];
203
204        //println!("Update: {}, {} -> {}", base_line, add,  self.state);
205    }
206}
207
208impl FSETable {
209    /// Initialize a new empty Finite State Entropy decoding table.
210    pub fn new(max_symbol: u8) -> FSETable {
211        FSETable {
212            max_symbol,
213            symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8
214            symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8
215            decode: Vec::new(),                      //depending on acc_log.
216            accuracy_log: 0,
217        }
218    }
219
220    /// Reset `self` and update `self`'s state to mirror the provided table.
221    pub fn reinit_from(&mut self, other: &Self) {
222        self.reset();
223        self.symbol_counter.extend_from_slice(&other.symbol_counter);
224        self.symbol_probabilities
225            .extend_from_slice(&other.symbol_probabilities);
226        self.decode.extend_from_slice(&other.decode);
227        self.accuracy_log = other.accuracy_log;
228    }
229
230    /// Empty the table and clear all internal state.
231    pub fn reset(&mut self) {
232        self.symbol_counter.clear();
233        self.symbol_probabilities.clear();
234        self.decode.clear();
235        self.accuracy_log = 0;
236    }
237
238    /// returns how many BYTEs (not bits) were read while building the decoder
239    pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
240        self.accuracy_log = 0;
241
242        let bytes_read = self.read_probabilities(source, max_log)?;
243        self.build_decoding_table()?;
244
245        Ok(bytes_read)
246    }
247
248    /// Given the provided accuracy log, build a decoding table from that log.
249    pub fn build_from_probabilities(
250        &mut self,
251        acc_log: u8,
252        probs: &[i32],
253    ) -> Result<(), FSETableError> {
254        if acc_log == 0 {
255            return Err(FSETableError::AccLogIsZero);
256        }
257        self.symbol_probabilities = probs.to_vec();
258        self.accuracy_log = acc_log;
259        self.build_decoding_table()
260    }
261
262    /// Build the actual decoding table after probabilities have been read into the table.
263    /// After this function is called, the decoding process can begin.
264    fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
265        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
266            return Err(FSETableError::TooManySymbols {
267                got: self.symbol_probabilities.len(),
268            });
269        }
270
271        self.decode.clear();
272
273        let table_size = 1 << self.accuracy_log;
274        if self.decode.len() < table_size {
275            self.decode.reserve(table_size - self.decode.len());
276        }
277        //fill with dummy entries
278        self.decode.resize(
279            table_size,
280            Entry {
281                base_line: 0,
282                num_bits: 0,
283                symbol: 0,
284            },
285        );
286
287        let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol
288
289        //first scan for all -1 probabilities and place them at the top of the table
290        for symbol in 0..self.symbol_probabilities.len() {
291            if self.symbol_probabilities[symbol] == -1 {
292                negative_idx -= 1;
293                let entry = &mut self.decode[negative_idx];
294                entry.symbol = symbol as u8;
295                entry.base_line = 0;
296                entry.num_bits = self.accuracy_log;
297            }
298        }
299
300        //then place in a semi-random order all of the other symbols
301        let mut position = 0;
302        for idx in 0..self.symbol_probabilities.len() {
303            let symbol = idx as u8;
304            if self.symbol_probabilities[idx] <= 0 {
305                continue;
306            }
307
308            //for each probability point the symbol gets on slot
309            let prob = self.symbol_probabilities[idx];
310            for _ in 0..prob {
311                let entry = &mut self.decode[position];
312                entry.symbol = symbol;
313
314                position = next_position(position, table_size);
315                while position >= negative_idx {
316                    position = next_position(position, table_size);
317                    //everything above negative_idx is already taken
318                }
319            }
320        }
321
322        // baselines and num_bits can only be calculated when all symbols have been spread
323        self.symbol_counter.clear();
324        self.symbol_counter
325            .resize(self.symbol_probabilities.len(), 0);
326        for idx in 0..negative_idx {
327            let entry = &mut self.decode[idx];
328            let symbol = entry.symbol;
329            let prob = self.symbol_probabilities[symbol as usize];
330
331            let symbol_count = self.symbol_counter[symbol as usize];
332            let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
333
334            //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb);
335
336            assert!(nb <= self.accuracy_log);
337            self.symbol_counter[symbol as usize] += 1;
338
339            entry.base_line = bl;
340            entry.num_bits = nb;
341        }
342        Ok(())
343    }
344
345    /// Read the accuracy log and the probability table from the source and return the number of bytes
346    /// read. If the size of the table is larger than the provided `max_log`, return an error.
347    fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
348        self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here
349
350        let mut br = BitReader::new(source);
351        self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
352        if self.accuracy_log > max_log {
353            return Err(FSETableError::AccLogTooBig {
354                got: self.accuracy_log,
355                max: max_log,
356            });
357        }
358        if self.accuracy_log == 0 {
359            return Err(FSETableError::AccLogIsZero);
360        }
361
362        let probability_sum = 1 << self.accuracy_log;
363        let mut probability_counter = 0;
364
365        while probability_counter < probability_sum {
366            let max_remaining_value = probability_sum - probability_counter + 1;
367            let bits_to_read = highest_bit_set(max_remaining_value);
368
369            let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
370
371            let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
372            let mask = (1 << (bits_to_read - 1)) - 1;
373            let small_value = unchecked_value & mask;
374
375            let value = if small_value < low_threshold {
376                br.return_bits(1);
377                small_value
378            } else if unchecked_value > mask {
379                unchecked_value - low_threshold
380            } else {
381                unchecked_value
382            };
383            //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value);
384
385            let prob = (value as i32) - 1;
386
387            self.symbol_probabilities.push(prob);
388            if prob != 0 {
389                if prob > 0 {
390                    probability_counter += prob as u32;
391                } else {
392                    // probability -1 counts as 1
393                    assert!(prob == -1);
394                    probability_counter += 1;
395                }
396            } else {
397                //fast skip further zero probabilities
398                loop {
399                    let skip_amount = br.get_bits(2)? as usize;
400
401                    self.symbol_probabilities
402                        .resize(self.symbol_probabilities.len() + skip_amount, 0);
403                    if skip_amount != 3 {
404                        break;
405                    }
406                }
407            }
408        }
409
410        if probability_counter != probability_sum {
411            return Err(FSETableError::ProbabilityCounterMismatch {
412                got: probability_counter,
413                expected_sum: probability_sum,
414                symbol_probabilities: self.symbol_probabilities.clone(),
415            });
416        }
417        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
418            return Err(FSETableError::TooManySymbols {
419                got: self.symbol_probabilities.len(),
420            });
421        }
422
423        let bytes_read = if br.bits_read() % 8 == 0 {
424            br.bits_read() / 8
425        } else {
426            (br.bits_read() / 8) + 1
427        };
428
429        Ok(bytes_read)
430    }
431}
432
// Utility functions for building the decoding table from probabilities.

/// Returns the table position that follows `p` when spreading symbols over a table
/// of `table_size` entries.
///
/// The position advances by a fixed step and wraps around by masking with
/// `table_size - 1`, which assumes that `table_size` is a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let wrap_mask = table_size - 1;
    (p + step) & wrap_mask
}
441
442fn calc_baseline_and_numbits(
443    num_states_total: u32,
444    num_states_symbol: u32,
445    state_number: u32,
446) -> (u32, u8) {
447    let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
448        num_states_symbol
449    } else {
450        1 << (highest_bit_set(num_states_symbol))
451    }; //always power of two
452
453    let num_double_width_state_slices = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed
454    let num_single_width_state_slices = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states
455    let slice_width = num_states_total / num_state_slices; //size of a single width slice of states
456    let num_bits = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice
457
458    if state_number < num_double_width_state_slices {
459        let baseline = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
460        (baseline, num_bits as u8 + 1)
461    } else {
462        let index_shifted = state_number - num_double_width_state_slices;
463        ((index_shifted * slice_width), num_bits as u8)
464    }
465}