1use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
4use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError};
5use alloc::vec::Vec;
6#[cfg(feature = "std")]
7use std::error::Error as StdError;
8
/// Decoding table for a Huffman-coded stream, bundled with the scratch
/// buffers that are reused every time the table is rebuilt.
pub struct HuffmanTable {
    /// Lookup table indexed by the decoder state. It is resized to
    /// `1 << max_num_bits` entries when the table is built.
    decode: Vec<Entry>,
    /// Weights read from the table description, one per symbol. The weight of
    /// the final symbol is not stored here; it is derived while building.
    weights: Vec<u8>,
    /// Width in bits of the decoder state, i.e. the length of the longest
    /// code. Building the table fails if this would exceed `MAX_MAX_NUM_BITS`.
    pub max_num_bits: u8,
    /// Code length in bits for every symbol (including the derived last one);
    /// `0` marks a symbol that does not occur.
    bits: Vec<u8>,
    /// Number of symbols per code length, indexed by the code length.
    bit_ranks: Vec<u32>,
    /// Per code length: the next free index in `decode` while the table is
    /// being filled.
    rank_indexes: Vec<usize>,
    /// FSE table used to decompress the weights when the description stores
    /// them FSE-compressed.
    fse_table: FSETable,
}
26
/// Everything that can go wrong while reading a Huffman table description and
/// building the decoding table from it.
#[derive(Debug)]
#[non_exhaustive]
pub enum HuffmanTableError {
    /// Error forwarded from the bit reader.
    GetBitsError(GetBitsError),
    /// Error forwarded from the FSE decoder used for compressed weights.
    FSEDecoderError(FSEDecoderError),
    /// Error forwarded from building the FSE table for compressed weights.
    FSETableError(FSETableError),
    /// The source slice did not even contain the header byte.
    SourceIsEmpty,
    /// The header announces more compressed bytes than follow it in the source.
    NotEnoughBytesForWeights {
        got_bytes: usize,
        expected_bytes: u8,
    },
    /// No terminating `1` bit was found within the first 8 bits read from the
    /// compressed weight bitstream.
    ExtraPadding {
        skipped_bits: i32,
    },
    /// More than 255 weights were decoded.
    TooManyWeights {
        got: usize,
    },
    /// All weights were zero, so no table can be built.
    MissingWeights,
    /// The amount left for the derived last weight is not a power of two.
    LeftoverIsNotAPowerOf2 {
        got: u32,
    },
    /// Fewer bytes remain after the FSE table than the header requires.
    NotEnoughBytesToDecompressWeights {
        have: usize,
        need: usize,
    },
    /// The FSE table description consumed more bytes than the header allots
    /// to the whole compressed weight section.
    FSETableUsedTooManyBytes {
        used: usize,
        available_bytes: u8,
    },
    /// Directly encoded weights need more bytes than the source provides.
    NotEnoughBytesInSource {
        got: usize,
        need: usize,
    },
    /// A single weight exceeded `MAX_MAX_NUM_BITS`.
    WeightBiggerThanMaxNumBits {
        got: u8,
    },
    /// The code length derived from the weight sum exceeded `MAX_MAX_NUM_BITS`.
    MaxBitsTooHigh {
        got: u8,
    },
}
67
#[cfg(feature = "std")]
impl StdError for HuffmanTableError {
    /// Returns the wrapped lower-level error for the three forwarding
    /// variants and `None` for every other variant.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = match self {
            Self::GetBitsError(inner) => inner,
            Self::FSEDecoderError(inner) => inner,
            Self::FSETableError(inner) => inner,
            _ => return None,
        };
        Some(cause)
    }
}
79
80impl core::fmt::Display for HuffmanTableError {
81 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result {
82 match self {
83 HuffmanTableError::GetBitsError(e) => write!(f, "{:?}", e),
84 HuffmanTableError::FSEDecoderError(e) => write!(f, "{:?}", e),
85 HuffmanTableError::FSETableError(e) => write!(f, "{:?}", e),
86 HuffmanTableError::SourceIsEmpty => write!(f, "Source needs to have at least one byte"),
87 HuffmanTableError::NotEnoughBytesForWeights {
88 got_bytes,
89 expected_bytes,
90 } => {
91 write!(f, "Header says there should be {} bytes for the weights but there are only {} bytes in the stream",
92 expected_bytes,
93 got_bytes)
94 }
95 HuffmanTableError::ExtraPadding { skipped_bits } => {
96 write!(f,
97 "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
98 skipped_bits,
99 )
100 }
101 HuffmanTableError::TooManyWeights { got } => {
102 write!(
103 f,
104 "More than 255 weights decoded (got {} weights). Stream is probably corrupted",
105 got,
106 )
107 }
108 HuffmanTableError::MissingWeights => {
109 write!(f, "Can\'t build huffman table without any weights")
110 }
111 HuffmanTableError::LeftoverIsNotAPowerOf2 { got } => {
112 write!(f, "Leftover must be power of two but is: {}", got)
113 }
114 HuffmanTableError::NotEnoughBytesToDecompressWeights { have, need } => {
115 write!(
116 f,
117 "Not enough bytes in stream to decompress weights. Is: {}, Should be: {}",
118 have, need,
119 )
120 }
121 HuffmanTableError::FSETableUsedTooManyBytes {
122 used,
123 available_bytes,
124 } => {
125 write!(f,
126 "FSE table used more bytes: {} than were meant to be used for the whole stream of huffman weights ({})",
127 used,
128 available_bytes,
129 )
130 }
131 HuffmanTableError::NotEnoughBytesInSource { got, need } => {
132 write!(
133 f,
134 "Source needs to have at least {} bytes, got: {}",
135 need, got,
136 )
137 }
138 HuffmanTableError::WeightBiggerThanMaxNumBits { got } => {
139 write!(
140 f,
141 "Cant have weight: {} bigger than max_num_bits: {}",
142 got, MAX_MAX_NUM_BITS,
143 )
144 }
145 HuffmanTableError::MaxBitsTooHigh { got } => {
146 write!(
147 f,
148 "max_bits derived from weights is: {} should be lower than: {}",
149 got, MAX_MAX_NUM_BITS,
150 )
151 }
152 }
153 }
154}
155
156impl From<GetBitsError> for HuffmanTableError {
157 fn from(val: GetBitsError) -> Self {
158 Self::GetBitsError(val)
159 }
160}
161
162impl From<FSEDecoderError> for HuffmanTableError {
163 fn from(val: FSEDecoderError) -> Self {
164 Self::FSEDecoderError(val)
165 }
166}
167
168impl From<FSETableError> for HuffmanTableError {
169 fn from(val: FSETableError) -> Self {
170 Self::FSETableError(val)
171 }
172}
173
/// Decodes symbols from a bitstream by walking a borrowed [`HuffmanTable`].
pub struct HuffmanDecoder<'table> {
    /// Table that maps the current state to a symbol and a bit count.
    table: &'table HuffmanTable,
    /// Current index into the table's decode entries; it is refilled with
    /// bits from the stream by `init_state` and `next_state`.
    pub state: u64,
}
180
/// Errors that can be reported while decoding with a [`HuffmanDecoder`].
#[derive(Debug)]
#[non_exhaustive]
pub enum HuffmanDecoderError {
    /// Error forwarded from the bit reader.
    GetBitsError(GetBitsError),
}
186
187impl core::fmt::Display for HuffmanDecoderError {
188 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
189 match self {
190 HuffmanDecoderError::GetBitsError(e) => write!(f, "{:?}", e),
191 }
192 }
193}
194
#[cfg(feature = "std")]
impl StdError for HuffmanDecoderError {
    /// Returns the wrapped bit reader error as the underlying cause.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // The enum has a single variant, so the pattern is irrefutable here.
        let Self::GetBitsError(inner) = self;
        Some(inner)
    }
}
203
204impl From<GetBitsError> for HuffmanDecoderError {
205 fn from(val: GetBitsError) -> Self {
206 Self::GetBitsError(val)
207 }
208}
209
/// One slot of the Huffman decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Symbol produced when the decoder state selects this slot.
    symbol: u8,
    /// Length in bits of that symbol's code; the decoder pulls this many new
    /// bits from the stream when it advances past the symbol.
    num_bits: u8,
}
219
/// Upper limit enforced while building a table: no single weight and no
/// derived maximum code length may exceed this value.
const MAX_MAX_NUM_BITS: u8 = 11;
222
/// Returns the 1-based position of the most significant set bit of `x`
/// (so `1 -> 1`, `2 -> 2`, `4 -> 3`).
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
229
230impl<'t> HuffmanDecoder<'t> {
231 pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
233 HuffmanDecoder { table, state: 0 }
234 }
235
236 pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) {
240 self.state = 0;
241 if let Some(next_table) = new_table {
242 self.table = next_table;
243 }
244 }
245
246 pub fn decode_symbol(&mut self) -> u8 {
249 self.table.decode[self.state as usize].symbol
250 }
251
252 pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
256 let num_bits = self.table.max_num_bits;
257 let new_bits = br.get_bits(num_bits);
258 self.state = new_bits;
259 num_bits
260 }
261
262 pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
265 let num_bits = self.table.decode[self.state as usize].num_bits;
268 let new_bits = br.get_bits(num_bits);
270 self.state <<= num_bits;
272 self.state &= self.table.decode.len() as u64 - 1;
273 self.state |= new_bits;
275 num_bits
276 }
277}
278
279impl Default for HuffmanTable {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285impl HuffmanTable {
286 pub fn new() -> HuffmanTable {
288 HuffmanTable {
289 decode: Vec::new(),
290
291 weights: Vec::with_capacity(256),
292 max_num_bits: 0,
293 bits: Vec::with_capacity(256),
294 bit_ranks: Vec::with_capacity(11),
295 rank_indexes: Vec::with_capacity(11),
296 fse_table: FSETable::new(100),
297 }
298 }
299
300 pub fn reinit_from(&mut self, other: &Self) {
303 self.reset();
304 self.decode.extend_from_slice(&other.decode);
305 self.weights.extend_from_slice(&other.weights);
306 self.max_num_bits = other.max_num_bits;
307 self.bits.extend_from_slice(&other.bits);
308 self.rank_indexes.extend_from_slice(&other.rank_indexes);
309 self.fse_table.reinit_from(&other.fse_table);
310 }
311
312 pub fn reset(&mut self) {
314 self.decode.clear();
315 self.weights.clear();
316 self.max_num_bits = 0;
317 self.bits.clear();
318 self.bit_ranks.clear();
319 self.rank_indexes.clear();
320 self.fse_table.reset();
321 }
322
323 pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
327 self.decode.clear();
328
329 let bytes_used = self.read_weights(source)?;
330 self.build_table_from_weights()?;
331 Ok(bytes_used)
332 }
333
334 fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
342 use HuffmanTableError as err;
343
344 if source.is_empty() {
345 return Err(err::SourceIsEmpty);
346 }
347 let header = source[0];
348 let mut bits_read = 8;
349
350 match header {
351 0..=127 => {
355 let fse_stream = &source[1..];
356 if header as usize > fse_stream.len() {
357 return Err(err::NotEnoughBytesForWeights {
358 got_bytes: fse_stream.len(),
359 expected_bytes: header,
360 });
361 }
362 let bytes_used_by_fse_header = self
364 .fse_table
365 .build_decoder(fse_stream, 100)?;
366
367 if bytes_used_by_fse_header > header as usize {
368 return Err(err::FSETableUsedTooManyBytes {
369 used: bytes_used_by_fse_header,
370 available_bytes: header,
371 });
372 }
373
374 vprintln!(
375 "Building fse table for huffman weights used: {}",
376 bytes_used_by_fse_header
377 );
378 let mut dec1 = FSEDecoder::new(&self.fse_table);
382 let mut dec2 = FSEDecoder::new(&self.fse_table);
383
384 let compressed_start = bytes_used_by_fse_header;
385 let compressed_length = header as usize - bytes_used_by_fse_header;
386
387 let compressed_weights = &fse_stream[compressed_start..];
388 if compressed_weights.len() < compressed_length {
389 return Err(err::NotEnoughBytesToDecompressWeights {
390 have: compressed_weights.len(),
391 need: compressed_length,
392 });
393 }
394 let compressed_weights = &compressed_weights[..compressed_length];
395 let mut br = BitReaderReversed::new(compressed_weights);
396
397 bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
398
399 let mut skipped_bits = 0;
401 loop {
402 let val = br.get_bits(1);
403 skipped_bits += 1;
404 if val == 1 || skipped_bits > 8 {
405 break;
406 }
407 }
408 if skipped_bits > 8 {
409 return Err(err::ExtraPadding { skipped_bits });
411 }
412
413 dec1.init_state(&mut br)?;
414 dec2.init_state(&mut br)?;
415
416 self.weights.clear();
417
418 loop {
420 let w = dec1.decode_symbol();
421 self.weights.push(w);
422 dec1.update_state(&mut br);
423
424 if br.bits_remaining() <= -1 {
425 self.weights.push(dec2.decode_symbol());
427 break;
428 }
429
430 let w = dec2.decode_symbol();
431 self.weights.push(w);
432 dec2.update_state(&mut br);
433
434 if br.bits_remaining() <= -1 {
435 self.weights.push(dec1.decode_symbol());
437 break;
438 }
439 if self.weights.len() > 255 {
441 return Err(err::TooManyWeights {
442 got: self.weights.len(),
443 });
444 }
445 }
446 }
447 _ => {
454 let weights_raw = &source[1..];
456 let num_weights = header - 127;
457 self.weights.resize(num_weights as usize, 0);
458
459 let bytes_needed = if num_weights % 2 == 0 {
460 num_weights as usize / 2
461 } else {
462 (num_weights as usize / 2) + 1
463 };
464
465 if weights_raw.len() < bytes_needed {
466 return Err(err::NotEnoughBytesInSource {
467 got: weights_raw.len(),
468 need: bytes_needed,
469 });
470 }
471
472 for idx in 0..num_weights {
473 if idx % 2 == 0 {
474 self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
475 } else {
476 self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
477 }
478 bits_read += 4;
479 }
480 }
481 }
482
483 let bytes_read = if bits_read % 8 == 0 {
484 bits_read / 8
485 } else {
486 (bits_read / 8) + 1
487 };
488 Ok(bytes_read as u32)
489 }
490
491 fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> {
496 use HuffmanTableError as err;
497
498 self.bits.clear();
499 self.bits.resize(self.weights.len() + 1, 0);
500
501 let mut weight_sum: u32 = 0;
502 for w in &self.weights {
503 if *w > MAX_MAX_NUM_BITS {
504 return Err(err::WeightBiggerThanMaxNumBits { got: *w });
505 }
506 weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
507 }
508
509 if weight_sum == 0 {
510 return Err(err::MissingWeights);
511 }
512
513 let max_bits = highest_bit_set(weight_sum) as u8;
514 let left_over = (1 << max_bits) - weight_sum;
515
516 if !left_over.is_power_of_two() {
518 return Err(err::LeftoverIsNotAPowerOf2 { got: left_over });
519 }
520
521 let last_weight = highest_bit_set(left_over) as u8;
522
523 for symbol in 0..self.weights.len() {
524 let bits = if self.weights[symbol] > 0 {
525 max_bits + 1 - self.weights[symbol]
526 } else {
527 0
528 };
529 self.bits[symbol] = bits;
530 }
531
532 self.bits[self.weights.len()] = max_bits + 1 - last_weight;
533 self.max_num_bits = max_bits;
534
535 if max_bits > MAX_MAX_NUM_BITS {
536 return Err(err::MaxBitsTooHigh { got: max_bits });
537 }
538
539 self.bit_ranks.clear();
540 self.bit_ranks.resize((max_bits + 1) as usize, 0);
541 for num_bits in &self.bits {
542 self.bit_ranks[(*num_bits) as usize] += 1;
543 }
544
545 self.decode.resize(
547 1 << self.max_num_bits,
548 Entry {
549 symbol: 0,
550 num_bits: 0,
551 },
552 );
553
554 self.rank_indexes.clear();
556 self.rank_indexes.resize((max_bits + 1) as usize, 0);
557
558 self.rank_indexes[max_bits as usize] = 0;
559 for bits in (1..self.rank_indexes.len() as u8).rev() {
560 self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
561 + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
562 }
563
564 assert!(
565 self.rank_indexes[0] == self.decode.len(),
566 "rank_idx[0]: {} should be: {}",
567 self.rank_indexes[0],
568 self.decode.len()
569 );
570
571 for symbol in 0..self.bits.len() {
572 let bits_for_symbol = self.bits[symbol];
573 if bits_for_symbol != 0 {
574 let base_idx = self.rank_indexes[bits_for_symbol as usize];
578 let len = 1 << (max_bits - bits_for_symbol);
579 self.rank_indexes[bits_for_symbol as usize] += len;
580 for idx in 0..len {
581 self.decode[base_idx + idx].symbol = symbol as u8;
582 self.decode[base_idx + idx].num_bits = bits_for_symbol;
583 }
584 }
585 }
586
587 Ok(())
588 }
589}