ruzstd/fse/
fse_decoder.rs1use crate::bit_io::{BitReader, BitReaderReversed};
2use crate::decoding::errors::{FSEDecoderError, FSETableError};
3use alloc::vec::Vec;
4
5pub struct FSEDecoder<'table> {
6 pub state: Entry,
8 table: &'table FSETable,
10}
11
12impl<'t> FSEDecoder<'t> {
13 pub fn new(table: &'t FSETable) -> FSEDecoder<'t> {
15 FSEDecoder {
16 state: table.decode.first().copied().unwrap_or(Entry {
17 base_line: 0,
18 num_bits: 0,
19 symbol: 0,
20 }),
21 table,
22 }
23 }
24
25 pub fn decode_symbol(&self) -> u8 {
27 self.state.symbol
28 }
29
30 pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
33 if self.table.accuracy_log == 0 {
34 return Err(FSEDecoderError::TableIsUninitialized);
35 }
36 let new_state = bits.get_bits(self.table.accuracy_log);
37 self.state = self.table.decode[new_state as usize];
38
39 Ok(())
40 }
41
42 pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
44 let num_bits = self.state.num_bits;
45 let add = bits.get_bits(num_bits);
46 let base_line = self.state.base_line;
47 let new_state = base_line + add as u32;
48 self.state = self.table.decode[new_state as usize];
49
50 }
52}
53
54#[derive(Debug, Clone)]
59pub struct FSETable {
60 max_symbol: u8,
62 pub decode: Vec<Entry>, pub accuracy_log: u8,
69 pub symbol_probabilities: Vec<i32>, symbol_counter: Vec<u32>,
83}
84
85impl FSETable {
86 pub fn new(max_symbol: u8) -> FSETable {
88 FSETable {
89 max_symbol,
90 symbol_probabilities: Vec::with_capacity(256), symbol_counter: Vec::with_capacity(256), decode: Vec::new(), accuracy_log: 0,
94 }
95 }
96
97 pub fn reinit_from(&mut self, other: &Self) {
99 self.reset();
100 self.symbol_counter.extend_from_slice(&other.symbol_counter);
101 self.symbol_probabilities
102 .extend_from_slice(&other.symbol_probabilities);
103 self.decode.extend_from_slice(&other.decode);
104 self.accuracy_log = other.accuracy_log;
105 }
106
107 pub fn reset(&mut self) {
109 self.symbol_counter.clear();
110 self.symbol_probabilities.clear();
111 self.decode.clear();
112 self.accuracy_log = 0;
113 }
114
115 pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
117 self.accuracy_log = 0;
118
119 let bytes_read = self.read_probabilities(source, max_log)?;
120 self.build_decoding_table()?;
121
122 Ok(bytes_read)
123 }
124
125 pub fn build_from_probabilities(
127 &mut self,
128 acc_log: u8,
129 probs: &[i32],
130 ) -> Result<(), FSETableError> {
131 if acc_log == 0 {
132 return Err(FSETableError::AccLogIsZero);
133 }
134 self.symbol_probabilities = probs.to_vec();
135 self.accuracy_log = acc_log;
136 self.build_decoding_table()
137 }
138
139 fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
142 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
143 return Err(FSETableError::TooManySymbols {
144 got: self.symbol_probabilities.len(),
145 });
146 }
147
148 self.decode.clear();
149
150 let table_size = 1 << self.accuracy_log;
151 if self.decode.len() < table_size {
152 self.decode.reserve(table_size - self.decode.len());
153 }
154 self.decode.resize(
156 table_size,
157 Entry {
158 base_line: 0,
159 num_bits: 0,
160 symbol: 0,
161 },
162 );
163
164 let mut negative_idx = table_size; for symbol in 0..self.symbol_probabilities.len() {
168 if self.symbol_probabilities[symbol] == -1 {
169 negative_idx -= 1;
170 let entry = &mut self.decode[negative_idx];
171 entry.symbol = symbol as u8;
172 entry.base_line = 0;
173 entry.num_bits = self.accuracy_log;
174 }
175 }
176
177 let mut position = 0;
179 for idx in 0..self.symbol_probabilities.len() {
180 let symbol = idx as u8;
181 if self.symbol_probabilities[idx] <= 0 {
182 continue;
183 }
184
185 let prob = self.symbol_probabilities[idx];
187 for _ in 0..prob {
188 let entry = &mut self.decode[position];
189 entry.symbol = symbol;
190
191 position = next_position(position, table_size);
192 while position >= negative_idx {
193 position = next_position(position, table_size);
194 }
196 }
197 }
198
199 self.symbol_counter.clear();
201 self.symbol_counter
202 .resize(self.symbol_probabilities.len(), 0);
203 for idx in 0..negative_idx {
204 let entry = &mut self.decode[idx];
205 let symbol = entry.symbol;
206 let prob = self.symbol_probabilities[symbol as usize];
207
208 let symbol_count = self.symbol_counter[symbol as usize];
209 let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
210
211 assert!(nb <= self.accuracy_log);
214 self.symbol_counter[symbol as usize] += 1;
215
216 entry.base_line = bl;
217 entry.num_bits = nb;
218 }
219 Ok(())
220 }
221
222 fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
225 self.symbol_probabilities.clear(); let mut br = BitReader::new(source);
228 self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
229 if self.accuracy_log > max_log {
230 return Err(FSETableError::AccLogTooBig {
231 got: self.accuracy_log,
232 max: max_log,
233 });
234 }
235 if self.accuracy_log == 0 {
236 return Err(FSETableError::AccLogIsZero);
237 }
238
239 let probability_sum = 1 << self.accuracy_log;
240 let mut probability_counter = 0;
241
242 while probability_counter < probability_sum {
243 let max_remaining_value = probability_sum - probability_counter + 1;
244 let bits_to_read = highest_bit_set(max_remaining_value);
245
246 let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
247
248 let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
249 let mask = (1 << (bits_to_read - 1)) - 1;
250 let small_value = unchecked_value & mask;
251
252 let value = if small_value < low_threshold {
253 br.return_bits(1);
254 small_value
255 } else if unchecked_value > mask {
256 unchecked_value - low_threshold
257 } else {
258 unchecked_value
259 };
260 let prob = (value as i32) - 1;
263
264 self.symbol_probabilities.push(prob);
265 if prob != 0 {
266 if prob > 0 {
267 probability_counter += prob as u32;
268 } else {
269 assert!(prob == -1);
271 probability_counter += 1;
272 }
273 } else {
274 loop {
276 let skip_amount = br.get_bits(2)? as usize;
277
278 self.symbol_probabilities
279 .resize(self.symbol_probabilities.len() + skip_amount, 0);
280 if skip_amount != 3 {
281 break;
282 }
283 }
284 }
285 }
286
287 if probability_counter != probability_sum {
288 return Err(FSETableError::ProbabilityCounterMismatch {
289 got: probability_counter,
290 expected_sum: probability_sum,
291 symbol_probabilities: self.symbol_probabilities.clone(),
292 });
293 }
294 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
295 return Err(FSETableError::TooManySymbols {
296 got: self.symbol_probabilities.len(),
297 });
298 }
299
300 let bytes_read = if br.bits_read() % 8 == 0 {
301 br.bits_read() / 8
302 } else {
303 (br.bits_read() / 8) + 1
304 };
305
306 Ok(bytes_read)
307 }
308}
309
/// A single state of the decoding table.
#[derive(Copy, Clone, Debug)]
pub struct Entry {
    /// Offset added to the freshly read bits to get the index of the next state.
    pub base_line: u32,
    /// How many bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// Symbol produced while the decoder is in this state.
    pub symbol: u8,
}
321
/// Value added to the 4-bit field at the start of a table description to
/// obtain the accuracy log.
const ACC_LOG_OFFSET: u8 = 5;
325
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. the number of bits needed to represent `x`.
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
330
/// Advances a table position by `table_size / 2 + table_size / 8 + 3` and
/// wraps the result using `table_size - 1` as a bit mask.
///
/// The mask only acts as a modulo when `table_size` is a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    (p + step) & (table_size - 1)
}
339
340fn calc_baseline_and_numbits(
341 num_states_total: u32,
342 num_states_symbol: u32,
343 state_number: u32,
344) -> (u32, u8) {
345 if num_states_symbol == 0 {
346 return (0, 0);
347 }
348 let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
349 num_states_symbol
350 } else {
351 1 << (highest_bit_set(num_states_symbol))
352 }; let num_double_width_state_slices = num_state_slices - num_states_symbol; let num_single_width_state_slices = num_states_symbol - num_double_width_state_slices; let slice_width = num_states_total / num_state_slices; let num_bits = highest_bit_set(slice_width) - 1; if state_number < num_double_width_state_slices {
360 let baseline = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
361 (baseline, num_bits as u8 + 1)
362 } else {
363 let index_shifted = state_number - num_double_width_state_slices;
364 ((index_shifted * slice_width), num_bits as u8)
365 }
366}