1use crate::decoding::bit_reader::BitReader;
2use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
3use alloc::vec::Vec;
4
5pub struct FSETable {
10 max_symbol: u8,
12 pub decode: Vec<Entry>, pub accuracy_log: u8,
19 pub symbol_probabilities: Vec<i32>, symbol_counter: Vec<u32>,
33}
34
35#[derive(Debug)]
36#[non_exhaustive]
37pub enum FSETableError {
38 AccLogIsZero,
39 AccLogTooBig {
40 got: u8,
41 max: u8,
42 },
43 GetBitsError(GetBitsError),
44 ProbabilityCounterMismatch {
45 got: u32,
46 expected_sum: u32,
47 symbol_probabilities: Vec<i32>,
48 },
49 TooManySymbols {
50 got: usize,
51 },
52}
53
#[cfg(feature = "std")]
impl std::error::Error for FSETableError {
    /// Exposes a wrapped bit-reader error as the underlying cause; every
    /// other variant has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let FSETableError::GetBitsError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
63
64impl core::fmt::Display for FSETableError {
65 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66 match self {
67 FSETableError::AccLogIsZero => write!(f, "Acclog must be at least 1"),
68 FSETableError::AccLogTooBig { got, max } => {
69 write!(
70 f,
71 "Found FSE acc_log: {0} bigger than allowed maximum in this case: {1}",
72 got, max
73 )
74 }
75 FSETableError::GetBitsError(e) => write!(f, "{:?}", e),
76 FSETableError::ProbabilityCounterMismatch {
77 got,
78 expected_sum,
79 symbol_probabilities,
80 } => {
81 write!(f,
82 "The counter ({}) exceeded the expected sum: {}. This means an error or corrupted data \n {:?}",
83 got,
84 expected_sum,
85 symbol_probabilities,
86 )
87 }
88 FSETableError::TooManySymbols { got } => {
89 write!(
90 f,
91 "There are too many symbols in this distribution: {}. Max: 256",
92 got,
93 )
94 }
95 }
96 }
97}
98
99impl From<GetBitsError> for FSETableError {
100 fn from(val: GetBitsError) -> Self {
101 Self::GetBitsError(val)
102 }
103}
104
105pub struct FSEDecoder<'table> {
106 pub state: Entry,
108 table: &'table FSETable,
110}
111
112#[derive(Debug)]
113#[non_exhaustive]
114pub enum FSEDecoderError {
115 GetBitsError(GetBitsError),
116 TableIsUninitialized,
117}
118
#[cfg(feature = "std")]
impl std::error::Error for FSEDecoderError {
    /// Exposes a wrapped bit-reader error as the underlying cause; every
    /// other variant has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let FSEDecoderError::GetBitsError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
128
129impl core::fmt::Display for FSEDecoderError {
130 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
131 match self {
132 FSEDecoderError::GetBitsError(e) => write!(f, "{:?}", e),
133 FSEDecoderError::TableIsUninitialized => {
134 write!(f, "Tried to use an uninitialized table!")
135 }
136 }
137 }
138}
139
140impl From<GetBitsError> for FSEDecoderError {
141 fn from(val: GetBitsError) -> Self {
142 Self::GetBitsError(val)
143 }
144}
145
/// A single state of an FSE decoding table.
///
/// `Default` is the all-zero entry used as a placeholder for unfilled slots.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Entry {
    /// Added to the `num_bits` bits read from the stream to form the next state.
    pub base_line: u32,
    /// Number of bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// The symbol reported while the decoder is in this state.
    pub symbol: u8,
}
157
/// Added to the 4-bit field at the start of an FSE table description to get
/// the accuracy log, so a description encodes accuracy logs from 5 to 20.
const ACC_LOG_OFFSET: u8 = 5;
161
/// Returns the number of bits needed to represent `x`, i.e. the 1-based index
/// of its most significant set bit (`1 -> 1`, `5 -> 3`, `8 -> 4`).
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
166
167impl<'t> FSEDecoder<'t> {
168 pub fn new(table: &'t FSETable) -> FSEDecoder<'t> {
170 FSEDecoder {
171 state: table.decode.first().copied().unwrap_or(Entry {
172 base_line: 0,
173 num_bits: 0,
174 symbol: 0,
175 }),
176 table,
177 }
178 }
179
180 pub fn decode_symbol(&self) -> u8 {
182 self.state.symbol
183 }
184
185 pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
188 if self.table.accuracy_log == 0 {
189 return Err(FSEDecoderError::TableIsUninitialized);
190 }
191 self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize];
192
193 Ok(())
194 }
195
196 pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
198 let num_bits = self.state.num_bits;
199 let add = bits.get_bits(num_bits);
200 let base_line = self.state.base_line;
201 let new_state = base_line + add as u32;
202 self.state = self.table.decode[new_state as usize];
203
204 }
206}
207
208impl FSETable {
209 pub fn new(max_symbol: u8) -> FSETable {
211 FSETable {
212 max_symbol,
213 symbol_probabilities: Vec::with_capacity(256), symbol_counter: Vec::with_capacity(256), decode: Vec::new(), accuracy_log: 0,
217 }
218 }
219
220 pub fn reinit_from(&mut self, other: &Self) {
222 self.reset();
223 self.symbol_counter.extend_from_slice(&other.symbol_counter);
224 self.symbol_probabilities
225 .extend_from_slice(&other.symbol_probabilities);
226 self.decode.extend_from_slice(&other.decode);
227 self.accuracy_log = other.accuracy_log;
228 }
229
230 pub fn reset(&mut self) {
232 self.symbol_counter.clear();
233 self.symbol_probabilities.clear();
234 self.decode.clear();
235 self.accuracy_log = 0;
236 }
237
238 pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
240 self.accuracy_log = 0;
241
242 let bytes_read = self.read_probabilities(source, max_log)?;
243 self.build_decoding_table()?;
244
245 Ok(bytes_read)
246 }
247
248 pub fn build_from_probabilities(
250 &mut self,
251 acc_log: u8,
252 probs: &[i32],
253 ) -> Result<(), FSETableError> {
254 if acc_log == 0 {
255 return Err(FSETableError::AccLogIsZero);
256 }
257 self.symbol_probabilities = probs.to_vec();
258 self.accuracy_log = acc_log;
259 self.build_decoding_table()
260 }
261
262 fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
265 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
266 return Err(FSETableError::TooManySymbols {
267 got: self.symbol_probabilities.len(),
268 });
269 }
270
271 self.decode.clear();
272
273 let table_size = 1 << self.accuracy_log;
274 if self.decode.len() < table_size {
275 self.decode.reserve(table_size - self.decode.len());
276 }
277 self.decode.resize(
279 table_size,
280 Entry {
281 base_line: 0,
282 num_bits: 0,
283 symbol: 0,
284 },
285 );
286
287 let mut negative_idx = table_size; for symbol in 0..self.symbol_probabilities.len() {
291 if self.symbol_probabilities[symbol] == -1 {
292 negative_idx -= 1;
293 let entry = &mut self.decode[negative_idx];
294 entry.symbol = symbol as u8;
295 entry.base_line = 0;
296 entry.num_bits = self.accuracy_log;
297 }
298 }
299
300 let mut position = 0;
302 for idx in 0..self.symbol_probabilities.len() {
303 let symbol = idx as u8;
304 if self.symbol_probabilities[idx] <= 0 {
305 continue;
306 }
307
308 let prob = self.symbol_probabilities[idx];
310 for _ in 0..prob {
311 let entry = &mut self.decode[position];
312 entry.symbol = symbol;
313
314 position = next_position(position, table_size);
315 while position >= negative_idx {
316 position = next_position(position, table_size);
317 }
319 }
320 }
321
322 self.symbol_counter.clear();
324 self.symbol_counter
325 .resize(self.symbol_probabilities.len(), 0);
326 for idx in 0..negative_idx {
327 let entry = &mut self.decode[idx];
328 let symbol = entry.symbol;
329 let prob = self.symbol_probabilities[symbol as usize];
330
331 let symbol_count = self.symbol_counter[symbol as usize];
332 let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
333
334 assert!(nb <= self.accuracy_log);
337 self.symbol_counter[symbol as usize] += 1;
338
339 entry.base_line = bl;
340 entry.num_bits = nb;
341 }
342 Ok(())
343 }
344
345 fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
348 self.symbol_probabilities.clear(); let mut br = BitReader::new(source);
351 self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
352 if self.accuracy_log > max_log {
353 return Err(FSETableError::AccLogTooBig {
354 got: self.accuracy_log,
355 max: max_log,
356 });
357 }
358 if self.accuracy_log == 0 {
359 return Err(FSETableError::AccLogIsZero);
360 }
361
362 let probability_sum = 1 << self.accuracy_log;
363 let mut probability_counter = 0;
364
365 while probability_counter < probability_sum {
366 let max_remaining_value = probability_sum - probability_counter + 1;
367 let bits_to_read = highest_bit_set(max_remaining_value);
368
369 let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
370
371 let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
372 let mask = (1 << (bits_to_read - 1)) - 1;
373 let small_value = unchecked_value & mask;
374
375 let value = if small_value < low_threshold {
376 br.return_bits(1);
377 small_value
378 } else if unchecked_value > mask {
379 unchecked_value - low_threshold
380 } else {
381 unchecked_value
382 };
383 let prob = (value as i32) - 1;
386
387 self.symbol_probabilities.push(prob);
388 if prob != 0 {
389 if prob > 0 {
390 probability_counter += prob as u32;
391 } else {
392 assert!(prob == -1);
394 probability_counter += 1;
395 }
396 } else {
397 loop {
399 let skip_amount = br.get_bits(2)? as usize;
400
401 self.symbol_probabilities
402 .resize(self.symbol_probabilities.len() + skip_amount, 0);
403 if skip_amount != 3 {
404 break;
405 }
406 }
407 }
408 }
409
410 if probability_counter != probability_sum {
411 return Err(FSETableError::ProbabilityCounterMismatch {
412 got: probability_counter,
413 expected_sum: probability_sum,
414 symbol_probabilities: self.symbol_probabilities.clone(),
415 });
416 }
417 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
418 return Err(FSETableError::TooManySymbols {
419 got: self.symbol_probabilities.len(),
420 });
421 }
422
423 let bytes_read = if br.bits_read() % 8 == 0 {
424 br.bits_read() / 8
425 } else {
426 (br.bits_read() / 8) + 1
427 };
428
429 Ok(bytes_read)
430 }
431}
432
/// Returns the slot that follows `p` when spreading symbols over a table of
/// `table_size` slots.
///
/// The step is `table_size / 2 + table_size / 8 + 3`. The wrap-around is a bit
/// mask, which assumes `table_size` is a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let advanced = p + step;
    advanced & (table_size - 1)
}
441
442fn calc_baseline_and_numbits(
443 num_states_total: u32,
444 num_states_symbol: u32,
445 state_number: u32,
446) -> (u32, u8) {
447 let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
448 num_states_symbol
449 } else {
450 1 << (highest_bit_set(num_states_symbol))
451 }; let num_double_width_state_slices = num_state_slices - num_states_symbol; let num_single_width_state_slices = num_states_symbol - num_double_width_state_slices; let slice_width = num_states_total / num_state_slices; let num_bits = highest_bit_set(slice_width) - 1; if state_number < num_double_width_state_slices {
459 let baseline = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
460 (baseline, num_bits as u8 + 1)
461 } else {
462 let index_shifted = state_number - num_double_width_state_slices;
463 ((index_shifted * slice_width), num_bits as u8)
464 }
465}