// ruzstd/decoding/bit_reader.rs

/// Reads bit fields from a byte slice, starting at the least significant bit
/// of each byte and moving forward through the slice.
pub struct BitReader<'s> {
    /// Position of the next unread bit, counted in bits from the start of `source`.
    idx: usize,
    /// The bytes being read.
    source: &'s [u8],
}
6
/// Failure modes of `BitReader::get_bits`.
#[non_exhaustive]
#[derive(Debug)]
pub enum GetBitsError {
    /// More bits were requested in a single call than fit in the returned value.
    TooManyBits {
        /// The number of bits the caller asked for.
        num_requested_bits: usize,
        /// The maximum number of bits a single call can return.
        limit: u8,
    },
    /// The source does not contain enough unread bits to satisfy the request.
    NotEnoughRemainingBits {
        /// The number of bits the caller asked for.
        requested: usize,
        /// The number of unread bits that were left at the time of the request.
        remaining: usize,
    },
}
19
// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for GetBitsError {
    // Every method keeps the trait's default; the message comes from the
    // `Display` impl and the `Debug` derive on `GetBitsError`.
}
22
23impl core::fmt::Display for GetBitsError {
24 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
25 match self {
26 GetBitsError::TooManyBits {
27 num_requested_bits,
28 limit,
29 } => {
30 write!(
31 f,
32 "Cant serve this request. The reader is limited to {} bits, requested {} bits",
33 limit, num_requested_bits,
34 )
35 }
36 GetBitsError::NotEnoughRemainingBits {
37 requested,
38 remaining,
39 } => {
40 write!(
41 f,
42 "Can\'t read {} bits, only have {} bits left",
43 requested, remaining,
44 )
45 }
46 }
47 }
48}
49
50impl<'s> BitReader<'s> {
51 pub fn new(source: &'s [u8]) -> BitReader<'s> {
52 BitReader { idx: 0, source }
53 }
54
55 pub fn bits_left(&self) -> usize {
56 self.source.len() * 8 - self.idx
57 }
58
59 pub fn bits_read(&self) -> usize {
60 self.idx
61 }
62
63 pub fn return_bits(&mut self, n: usize) {
64 if n > self.idx {
65 panic!("Cant return this many bits");
66 }
67 self.idx -= n;
68 }
69
70 pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
71 if n > 64 {
72 return Err(GetBitsError::TooManyBits {
73 num_requested_bits: n,
74 limit: 64,
75 });
76 }
77 if self.bits_left() < n {
78 return Err(GetBitsError::NotEnoughRemainingBits {
79 requested: n,
80 remaining: self.bits_left(),
81 });
82 }
83
84 let old_idx = self.idx;
85
86 let bits_left_in_current_byte = 8 - (self.idx % 8);
87 let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
88
89 let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
91
92 if bits_left_in_current_byte >= n {
93 value &= (1 << n) - 1;
97 self.idx += n;
98 } else {
99 self.idx += bits_left_in_current_byte;
100
101 let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
103 let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
104
105 assert!(
106 bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
107 );
108
109 let mut bit_shift = bits_left_in_current_byte; assert!(self.idx % 8 == 0);
112
113 for _ in 0..full_bytes_needed {
115 value |= u64::from(self.source[self.idx / 8]) << bit_shift;
116 self.idx += 8;
117 bit_shift += 8;
118 }
119
120 assert!(n - bit_shift == bits_in_last_byte_needed);
121
122 if bits_in_last_byte_needed > 0 {
123 let val_las_byte =
124 u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
125 value |= val_las_byte << bit_shift;
126 self.idx += bits_in_last_byte_needed;
127 }
128 }
129
130 assert!(self.idx == old_idx + n);
131
132 Ok(value)
133 }
134
135 pub fn reset(&mut self, new_source: &'s [u8]) {
136 self.idx = 0;
137 self.source = new_source;
138 }
139}