ruzstd/huff0/
huff0_decoder.rs1use crate::bit_io::BitReaderReversed;
4use crate::decoding::errors::HuffmanTableError;
5use crate::fse::{FSEDecoder, FSETable};
6use alloc::vec::Vec;
7
/// Largest weight, and largest code length in bits, that table construction accepts.
///
/// Weights above this value and tables whose longest code would exceed it are rejected
/// with an error while the decode table is being built.
pub(crate) const MAX_MAX_NUM_BITS: u8 = 11;
10
11pub struct HuffmanDecoder<'table> {
12 table: &'table HuffmanTable,
13 pub state: u64,
15}
16
17impl<'t> HuffmanDecoder<'t> {
18 pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
20 HuffmanDecoder { table, state: 0 }
21 }
22
23 pub fn decode_symbol(&mut self) -> u8 {
26 self.table.decode[self.state as usize].symbol
27 }
28
29 pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
33 let num_bits = self.table.max_num_bits;
34 let new_bits = br.get_bits(num_bits);
35 self.state = new_bits;
36 num_bits
37 }
38
39 pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
42 let num_bits = self.table.decode[self.state as usize].num_bits;
45 let new_bits = br.get_bits(num_bits);
47 self.state <<= num_bits;
49 self.state &= self.table.decode.len() as u64 - 1;
50 self.state |= new_bits;
52 num_bits
53 }
54}
55
56pub struct HuffmanTable {
58 decode: Vec<Entry>,
59 weights: Vec<u8>,
64 pub max_num_bits: u8,
69 bits: Vec<u8>,
70 bit_ranks: Vec<u32>,
71 rank_indexes: Vec<usize>,
72 fse_table: FSETable,
74}
75
76impl HuffmanTable {
77 pub fn new() -> HuffmanTable {
79 HuffmanTable {
80 decode: Vec::new(),
81
82 weights: Vec::with_capacity(256),
83 max_num_bits: 0,
84 bits: Vec::with_capacity(256),
85 bit_ranks: Vec::with_capacity(11),
86 rank_indexes: Vec::with_capacity(11),
87 fse_table: FSETable::new(255),
88 }
89 }
90
91 pub fn reinit_from(&mut self, other: &Self) {
94 self.reset();
95 self.decode.extend_from_slice(&other.decode);
96 self.weights.extend_from_slice(&other.weights);
97 self.max_num_bits = other.max_num_bits;
98 self.bits.extend_from_slice(&other.bits);
99 self.rank_indexes.extend_from_slice(&other.rank_indexes);
100 self.fse_table.reinit_from(&other.fse_table);
101 }
102
103 pub fn reset(&mut self) {
105 self.decode.clear();
106 self.weights.clear();
107 self.max_num_bits = 0;
108 self.bits.clear();
109 self.bit_ranks.clear();
110 self.rank_indexes.clear();
111 self.fse_table.reset();
112 }
113
114 pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
118 self.decode.clear();
119
120 let bytes_used = self.read_weights(source)?;
121 self.build_table_from_weights()?;
122 Ok(bytes_used)
123 }
124
125 fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
133 use HuffmanTableError as err;
134
135 if source.is_empty() {
136 return Err(err::SourceIsEmpty);
137 }
138 let header = source[0];
139 let mut bits_read = 8;
140
141 match header {
142 0..=127 => {
146 let fse_stream = &source[1..];
147 if header as usize > fse_stream.len() {
148 return Err(err::NotEnoughBytesForWeights {
149 got_bytes: fse_stream.len(),
150 expected_bytes: header,
151 });
152 }
153 let bytes_used_by_fse_header = self.fse_table.build_decoder(fse_stream, 6)?;
155
156 if bytes_used_by_fse_header > header as usize {
157 return Err(err::FSETableUsedTooManyBytes {
158 used: bytes_used_by_fse_header,
159 available_bytes: header,
160 });
161 }
162
163 vprintln!(
164 "Building fse table for huffman weights used: {}",
165 bytes_used_by_fse_header
166 );
167 let mut dec1 = FSEDecoder::new(&self.fse_table);
171 let mut dec2 = FSEDecoder::new(&self.fse_table);
172
173 let compressed_start = bytes_used_by_fse_header;
174 let compressed_length = header as usize - bytes_used_by_fse_header;
175
176 let compressed_weights = &fse_stream[compressed_start..];
177 if compressed_weights.len() < compressed_length {
178 return Err(err::NotEnoughBytesToDecompressWeights {
179 have: compressed_weights.len(),
180 need: compressed_length,
181 });
182 }
183 let compressed_weights = &compressed_weights[..compressed_length];
184 let mut br = BitReaderReversed::new(compressed_weights);
185
186 bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
187
188 let mut skipped_bits = 0;
190 loop {
191 let val = br.get_bits(1);
192 skipped_bits += 1;
193 if val == 1 || skipped_bits > 8 {
194 break;
195 }
196 }
197 if skipped_bits > 8 {
198 return Err(err::ExtraPadding { skipped_bits });
200 }
201
202 dec1.init_state(&mut br)?;
203 dec2.init_state(&mut br)?;
204
205 self.weights.clear();
206
207 loop {
209 let w = dec1.decode_symbol();
210 self.weights.push(w);
211 dec1.update_state(&mut br);
212
213 if br.bits_remaining() <= -1 {
214 self.weights.push(dec2.decode_symbol());
216 break;
217 }
218
219 let w = dec2.decode_symbol();
220 self.weights.push(w);
221 dec2.update_state(&mut br);
222
223 if br.bits_remaining() <= -1 {
224 self.weights.push(dec1.decode_symbol());
226 break;
227 }
228 if self.weights.len() > 255 {
230 return Err(err::TooManyWeights {
231 got: self.weights.len(),
232 });
233 }
234 }
235 }
236 _ => {
243 let weights_raw = &source[1..];
245 let num_weights = header - 127;
246 self.weights.resize(num_weights as usize, 0);
247
248 let bytes_needed = if num_weights % 2 == 0 {
249 num_weights as usize / 2
250 } else {
251 (num_weights as usize / 2) + 1
252 };
253
254 if weights_raw.len() < bytes_needed {
255 return Err(err::NotEnoughBytesInSource {
256 got: weights_raw.len(),
257 need: bytes_needed,
258 });
259 }
260
261 for idx in 0..num_weights {
262 if idx % 2 == 0 {
263 self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
264 } else {
265 self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
266 }
267 bits_read += 4;
268 }
269 }
270 }
271
272 let bytes_read = if bits_read % 8 == 0 {
273 bits_read / 8
274 } else {
275 (bits_read / 8) + 1
276 };
277 Ok(bytes_read as u32)
278 }
279
280 fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> {
285 use HuffmanTableError as err;
286
287 self.bits.clear();
288 self.bits.resize(self.weights.len() + 1, 0);
289
290 let mut weight_sum: u32 = 0;
291 for w in &self.weights {
292 if *w > MAX_MAX_NUM_BITS {
293 return Err(err::WeightBiggerThanMaxNumBits { got: *w });
294 }
295 weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
296 }
297
298 if weight_sum == 0 {
299 return Err(err::MissingWeights);
300 }
301
302 let max_bits = highest_bit_set(weight_sum) as u8;
303 let left_over = (1 << max_bits) - weight_sum;
304
305 if !left_over.is_power_of_two() {
307 return Err(err::LeftoverIsNotAPowerOf2 { got: left_over });
308 }
309
310 let last_weight = highest_bit_set(left_over) as u8;
311
312 for symbol in 0..self.weights.len() {
313 let bits = if self.weights[symbol] > 0 {
314 max_bits + 1 - self.weights[symbol]
315 } else {
316 0
317 };
318 self.bits[symbol] = bits;
319 }
320
321 self.bits[self.weights.len()] = max_bits + 1 - last_weight;
322 self.max_num_bits = max_bits;
323
324 if max_bits > MAX_MAX_NUM_BITS {
325 return Err(err::MaxBitsTooHigh { got: max_bits });
326 }
327
328 self.bit_ranks.clear();
329 self.bit_ranks.resize((max_bits + 1) as usize, 0);
330 for num_bits in &self.bits {
331 self.bit_ranks[(*num_bits) as usize] += 1;
332 }
333
334 self.decode.resize(
336 1 << self.max_num_bits,
337 Entry {
338 symbol: 0,
339 num_bits: 0,
340 },
341 );
342
343 self.rank_indexes.clear();
345 self.rank_indexes.resize((max_bits + 1) as usize, 0);
346
347 self.rank_indexes[max_bits as usize] = 0;
348 for bits in (1..self.rank_indexes.len() as u8).rev() {
349 self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
350 + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
351 }
352
353 assert!(
354 self.rank_indexes[0] == self.decode.len(),
355 "rank_idx[0]: {} should be: {}",
356 self.rank_indexes[0],
357 self.decode.len()
358 );
359
360 for symbol in 0..self.bits.len() {
361 let bits_for_symbol = self.bits[symbol];
362 if bits_for_symbol != 0 {
363 let base_idx = self.rank_indexes[bits_for_symbol as usize];
367 let len = 1 << (max_bits - bits_for_symbol);
368 self.rank_indexes[bits_for_symbol as usize] += len;
369 for idx in 0..len {
370 self.decode[base_idx + idx].symbol = symbol as u8;
371 self.decode[base_idx + idx].num_bits = bits_for_symbol;
372 }
373 }
374 }
375
376 Ok(())
377 }
378}
379
380impl Default for HuffmanTable {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
/// One slot of the decode table.
#[derive(Clone, Copy, Debug)]
pub struct Entry {
    /// Symbol produced when the decoder state selects this slot.
    symbol: u8,
    /// Number of bits the decoder reads from the stream to move on to its next state.
    num_bits: u8,
}
395
/// Returns the 1-based position of the most significant set bit of `x`, i.e. the number
/// of bits needed to represent `x`.
///
/// # Panics
/// Panics if `x` is zero, since no bit is set in that case.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}