// ruzstd/decoding/bit_reader.rs
use super::errors::GetBitsError;
/// Reads values bit by bit out of a borrowed byte slice.
///
/// The reader only borrows its input (`&'s [u8]`) and keeps an absolute bit
/// position into it; no bytes are copied.
pub struct BitReader<'s> {
    /// The borrowed input bytes.
    source: &'s [u8],
    /// Absolute bit offset (from the start of `source`) of the next bit to read.
    idx: usize,
}
/// With the `std` feature enabled, `GetBitsError` implements `std::error::Error`.
/// Its message comes from the `Display` impl; every trait method keeps its default.
#[cfg(feature = "std")]
impl std::error::Error for GetBitsError {
    // Intentionally empty: there is no underlying cause to report via `source()`.
}
12impl core::fmt::Display for GetBitsError {
13 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
14 match self {
15 GetBitsError::TooManyBits {
16 num_requested_bits,
17 limit,
18 } => {
19 write!(
20 f,
21 "Cant serve this request. The reader is limited to {} bits, requested {} bits",
22 limit, num_requested_bits,
23 )
24 }
25 GetBitsError::NotEnoughRemainingBits {
26 requested,
27 remaining,
28 } => {
29 write!(
30 f,
31 "Can\'t read {} bits, only have {} bits left",
32 requested, remaining,
33 )
34 }
35 }
36 }
37}
38
39impl<'s> BitReader<'s> {
40 pub fn new(source: &'s [u8]) -> BitReader<'s> {
41 BitReader { idx: 0, source }
42 }
43
44 pub fn bits_left(&self) -> usize {
45 self.source.len() * 8 - self.idx
46 }
47
48 pub fn bits_read(&self) -> usize {
49 self.idx
50 }
51
52 pub fn return_bits(&mut self, n: usize) {
53 if n > self.idx {
54 panic!("Cant return this many bits");
55 }
56 self.idx -= n;
57 }
58
59 pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
60 if n > 64 {
61 return Err(GetBitsError::TooManyBits {
62 num_requested_bits: n,
63 limit: 64,
64 });
65 }
66 if self.bits_left() < n {
67 return Err(GetBitsError::NotEnoughRemainingBits {
68 requested: n,
69 remaining: self.bits_left(),
70 });
71 }
72
73 let old_idx = self.idx;
74
75 let bits_left_in_current_byte = 8 - (self.idx % 8);
76 let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
77
78 let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
80
81 if bits_left_in_current_byte >= n {
82 value &= (1 << n) - 1;
86 self.idx += n;
87 } else {
88 self.idx += bits_left_in_current_byte;
89
90 let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
92 let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
93
94 assert!(
95 bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
96 );
97
98 let mut bit_shift = bits_left_in_current_byte; assert!(self.idx % 8 == 0);
101
102 for _ in 0..full_bytes_needed {
104 value |= u64::from(self.source[self.idx / 8]) << bit_shift;
105 self.idx += 8;
106 bit_shift += 8;
107 }
108
109 assert!(n - bit_shift == bits_in_last_byte_needed);
110
111 if bits_in_last_byte_needed > 0 {
112 let val_las_byte =
113 u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
114 value |= val_las_byte << bit_shift;
115 self.idx += bits_in_last_byte_needed;
116 }
117 }
118
119 assert!(self.idx == old_idx + n);
120
121 Ok(value)
122 }
123}