// ruzstd/fse/fse_decoder.rs

1use crate::bit_io::{BitReader, BitReaderReversed};
2use crate::decoding::errors::{FSEDecoderError, FSETableError};
3use alloc::vec::Vec;
4
5pub struct FSEDecoder<'table> {
6    /// An FSE state value represents an index in the FSE table.
7    pub state: Entry,
8    /// A reference to the table used for decoding.
9    table: &'table FSETable,
10}
11
12impl<'t> FSEDecoder<'t> {
13    /// Initialize a new Finite State Entropy decoder.
14    pub fn new(table: &'t FSETable) -> FSEDecoder<'t> {
15        FSEDecoder {
16            state: table.decode.first().copied().unwrap_or(Entry {
17                base_line: 0,
18                num_bits: 0,
19                symbol: 0,
20            }),
21            table,
22        }
23    }
24
25    /// Returns the byte associated with the symbol the internal cursor is pointing at.
26    pub fn decode_symbol(&self) -> u8 {
27        self.state.symbol
28    }
29
30    /// Initialize internal state and prepare for decoding. After this, `decode_symbol` can be called
31    /// to read the first symbol and `update_state` can be called to prepare to read the next symbol.
32    pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
33        if self.table.accuracy_log == 0 {
34            return Err(FSEDecoderError::TableIsUninitialized);
35        }
36        let new_state = bits.get_bits(self.table.accuracy_log);
37        self.state = self.table.decode[new_state as usize];
38
39        Ok(())
40    }
41
42    /// Advance the internal state to decode the next symbol in the bitstream.
43    pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
44        let num_bits = self.state.num_bits;
45        let add = bits.get_bits(num_bits);
46        let base_line = self.state.base_line;
47        let new_state = base_line + add as u32;
48        self.state = self.table.decode[new_state as usize];
49
50        //println!("Update: {}, {} -> {}", base_line, add,  self.state);
51    }
52}
53
54/// FSE decoding involves a decoding table that describes the probabilities of
55/// all literals from 0 to the highest present one
56///
57/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description>
58#[derive(Debug, Clone)]
59pub struct FSETable {
60    /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1.
61    max_symbol: u8,
62    /// The actual table containing the decoded symbol and the compression data
63    /// connected to that symbol.
64    pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state
65    /// The size of the table is stored in logarithm base 2 format,
66    /// with the **size of the table** being equal to `(1 << accuracy_log)`.
67    /// This value is used so that the decoder knows how many bits to read from the bitstream.
68    pub accuracy_log: u8,
69    /// In this context, probability refers to the likelihood that a symbol occurs in the given data.
70    /// Given this info, the encoder can assign shorter codes to symbols that appear more often,
71    /// and longer codes that appear less often, then the decoder can use the probability
72    /// to determine what code was assigned to what symbol.
73    ///
74    /// The probability of a single symbol is a value representing the proportion of times the symbol
75    /// would fall within the data.
76    ///
77    /// If a symbol probability is set to `-1`, it means that the probability of a symbol
78    /// occurring in the data is less than one.
79    pub symbol_probabilities: Vec<i32>, //used while building the decode Vector
80    /// The number of times each symbol occurs (The first entry being 0x0, the second being 0x1) and so on
81    /// up until the highest possible symbol (255).
82    symbol_counter: Vec<u32>,
83}
84
85impl FSETable {
86    /// Initialize a new empty Finite State Entropy decoding table.
87    pub fn new(max_symbol: u8) -> FSETable {
88        FSETable {
89            max_symbol,
90            symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8
91            symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8
92            decode: Vec::new(),                      //depending on acc_log.
93            accuracy_log: 0,
94        }
95    }
96
97    /// Reset `self` and update `self`'s state to mirror the provided table.
98    pub fn reinit_from(&mut self, other: &Self) {
99        self.reset();
100        self.symbol_counter.extend_from_slice(&other.symbol_counter);
101        self.symbol_probabilities
102            .extend_from_slice(&other.symbol_probabilities);
103        self.decode.extend_from_slice(&other.decode);
104        self.accuracy_log = other.accuracy_log;
105    }
106
107    /// Empty the table and clear all internal state.
108    pub fn reset(&mut self) {
109        self.symbol_counter.clear();
110        self.symbol_probabilities.clear();
111        self.decode.clear();
112        self.accuracy_log = 0;
113    }
114
115    /// returns how many BYTEs (not bits) were read while building the decoder
116    pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
117        self.accuracy_log = 0;
118
119        let bytes_read = self.read_probabilities(source, max_log)?;
120        self.build_decoding_table()?;
121
122        Ok(bytes_read)
123    }
124
125    /// Given the provided accuracy log, build a decoding table from that log.
126    pub fn build_from_probabilities(
127        &mut self,
128        acc_log: u8,
129        probs: &[i32],
130    ) -> Result<(), FSETableError> {
131        if acc_log == 0 {
132            return Err(FSETableError::AccLogIsZero);
133        }
134        self.symbol_probabilities = probs.to_vec();
135        self.accuracy_log = acc_log;
136        self.build_decoding_table()
137    }
138
139    /// Build the actual decoding table after probabilities have been read into the table.
140    /// After this function is called, the decoding process can begin.
141    fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
142        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
143            return Err(FSETableError::TooManySymbols {
144                got: self.symbol_probabilities.len(),
145            });
146        }
147
148        self.decode.clear();
149
150        let table_size = 1 << self.accuracy_log;
151        if self.decode.len() < table_size {
152            self.decode.reserve(table_size - self.decode.len());
153        }
154        //fill with dummy entries
155        self.decode.resize(
156            table_size,
157            Entry {
158                base_line: 0,
159                num_bits: 0,
160                symbol: 0,
161            },
162        );
163
164        let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol
165
166        //first scan for all -1 probabilities and place them at the top of the table
167        for symbol in 0..self.symbol_probabilities.len() {
168            if self.symbol_probabilities[symbol] == -1 {
169                negative_idx -= 1;
170                let entry = &mut self.decode[negative_idx];
171                entry.symbol = symbol as u8;
172                entry.base_line = 0;
173                entry.num_bits = self.accuracy_log;
174            }
175        }
176
177        //then place in a semi-random order all of the other symbols
178        let mut position = 0;
179        for idx in 0..self.symbol_probabilities.len() {
180            let symbol = idx as u8;
181            if self.symbol_probabilities[idx] <= 0 {
182                continue;
183            }
184
185            //for each probability point the symbol gets on slot
186            let prob = self.symbol_probabilities[idx];
187            for _ in 0..prob {
188                let entry = &mut self.decode[position];
189                entry.symbol = symbol;
190
191                position = next_position(position, table_size);
192                while position >= negative_idx {
193                    position = next_position(position, table_size);
194                    //everything above negative_idx is already taken
195                }
196            }
197        }
198
199        // baselines and num_bits can only be calculated when all symbols have been spread
200        self.symbol_counter.clear();
201        self.symbol_counter
202            .resize(self.symbol_probabilities.len(), 0);
203        for idx in 0..negative_idx {
204            let entry = &mut self.decode[idx];
205            let symbol = entry.symbol;
206            let prob = self.symbol_probabilities[symbol as usize];
207
208            let symbol_count = self.symbol_counter[symbol as usize];
209            let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
210
211            //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb);
212
213            assert!(nb <= self.accuracy_log);
214            self.symbol_counter[symbol as usize] += 1;
215
216            entry.base_line = bl;
217            entry.num_bits = nb;
218        }
219        Ok(())
220    }
221
222    /// Read the accuracy log and the probability table from the source and return the number of bytes
223    /// read. If the size of the table is larger than the provided `max_log`, return an error.
224    fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
225        self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here
226
227        let mut br = BitReader::new(source);
228        self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
229        if self.accuracy_log > max_log {
230            return Err(FSETableError::AccLogTooBig {
231                got: self.accuracy_log,
232                max: max_log,
233            });
234        }
235        if self.accuracy_log == 0 {
236            return Err(FSETableError::AccLogIsZero);
237        }
238
239        let probability_sum = 1 << self.accuracy_log;
240        let mut probability_counter = 0;
241
242        while probability_counter < probability_sum {
243            let max_remaining_value = probability_sum - probability_counter + 1;
244            let bits_to_read = highest_bit_set(max_remaining_value);
245
246            let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
247
248            let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
249            let mask = (1 << (bits_to_read - 1)) - 1;
250            let small_value = unchecked_value & mask;
251
252            let value = if small_value < low_threshold {
253                br.return_bits(1);
254                small_value
255            } else if unchecked_value > mask {
256                unchecked_value - low_threshold
257            } else {
258                unchecked_value
259            };
260            //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value);
261
262            let prob = (value as i32) - 1;
263
264            self.symbol_probabilities.push(prob);
265            if prob != 0 {
266                if prob > 0 {
267                    probability_counter += prob as u32;
268                } else {
269                    // probability -1 counts as 1
270                    assert!(prob == -1);
271                    probability_counter += 1;
272                }
273            } else {
274                //fast skip further zero probabilities
275                loop {
276                    let skip_amount = br.get_bits(2)? as usize;
277
278                    self.symbol_probabilities
279                        .resize(self.symbol_probabilities.len() + skip_amount, 0);
280                    if skip_amount != 3 {
281                        break;
282                    }
283                }
284            }
285        }
286
287        if probability_counter != probability_sum {
288            return Err(FSETableError::ProbabilityCounterMismatch {
289                got: probability_counter,
290                expected_sum: probability_sum,
291                symbol_probabilities: self.symbol_probabilities.clone(),
292            });
293        }
294        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
295            return Err(FSETableError::TooManySymbols {
296                got: self.symbol_probabilities.len(),
297            });
298        }
299
300        let bytes_read = if br.bits_read() % 8 == 0 {
301            br.bits_read() / 8
302        } else {
303            (br.bits_read() / 8) + 1
304        };
305
306        Ok(bytes_read)
307    }
308}
309
/// A single entry in an FSE table.
///
/// `Entry::default()` is the all-zero entry.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Entry {
    /// This value is used as an offset value, and it is added
    /// to a value read from the stream to determine the next state value.
    pub base_line: u32,
    /// How many bits should be read from the stream when decoding this entry.
    pub num_bits: u8,
    /// The byte that should be put in the decode output when encountering this state.
    pub symbol: u8,
}
321
/// Offset added to the 4-bit field at the start of an FSE table description to obtain the
/// `Accuracy_Log`. Because the field holds 0..=15, the resulting accuracy log is in 5..=20.
const ACC_LOG_OFFSET: u8 = 5;
325
/// Returns the number of bits needed to represent `x`, i.e. the 1-based position of its most
/// significant set bit (`1 -> 1`, `2..=3 -> 2`, `4..=7 -> 3`, ...).
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    // Occupied width = full width minus the unused high-order zero bits.
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
330
// Utility functions for building the decoding table from probabilities.

/// Returns the slot that follows `p` in the symbol-spreading walk over a table of
/// `table_size` entries: a fixed stride of `table_size / 2 + table_size / 8 + 3`, wrapped
/// around with the mask `table_size - 1`.
///
/// The mask only wraps correctly for power-of-two sizes; the table builder passes
/// `1 << accuracy_log`.
fn next_position(p: usize, table_size: usize) -> usize {
    let stride = (table_size >> 1) + (table_size >> 3) + 3;
    (p + stride) & (table_size - 1)
}
339
340fn calc_baseline_and_numbits(
341    num_states_total: u32,
342    num_states_symbol: u32,
343    state_number: u32,
344) -> (u32, u8) {
345    if num_states_symbol == 0 {
346        return (0, 0);
347    }
348    let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
349        num_states_symbol
350    } else {
351        1 << (highest_bit_set(num_states_symbol))
352    }; //always power of two
353
354    let num_double_width_state_slices = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed
355    let num_single_width_state_slices = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states
356    let slice_width = num_states_total / num_state_slices; //size of a single width slice of states
357    let num_bits = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice
358
359    if state_number < num_double_width_state_slices {
360        let baseline = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
361        (baseline, num_bits as u8 + 1)
362    } else {
363        let index_shifted = state_number - num_double_width_state_slices;
364        ((index_shifted * slice_width), num_bits as u8)
365    }
366}