ruzstd/decoding/
bit_reader_reverse.rs1use crate::io::Read;
2use core::convert::TryInto;
3
/// Reads bits from a byte slice back to front.
///
/// Reading starts at the most significant bit of the last byte of `source` and
/// moves toward the least significant bit of `source[0]`. Every read returns
/// its bits with the earliest-read bit in the most significant position.
pub struct BitReaderReversed<'s> {
    /// Number of bits at the start of `source` that have not been loaded into
    /// `bit_container` yet. While positive it is always a multiple of 8. It goes
    /// negative once callers read past the start of the stream; the magnitude
    /// then counts the zero bits that were handed out.
    idx: isize,
    /// The bytes being read.
    source: &'s [u8],
    /// Buffered bits. Only the low `bits_in_container` bits are valid; bit `j`
    /// holds source bit `idx + j` (source bit 0 is the least significant bit of
    /// `source[0]`).
    bit_container: u64,
    /// Number of valid bits in `bit_container` (at most 64).
    bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
    /// Number of bits that can still be read before running past the start of
    /// `source`.
    ///
    /// Becomes negative once more bits have been requested than the source
    /// holds (those extra reads return zero bits).
    pub fn bits_remaining(&self) -> isize {
        self.idx + self.bits_in_container as isize
    }

    /// Creates a reader positioned at the last bit of `source`.
    ///
    /// # Panics
    ///
    /// Panics if the length of `source` in bits does not fit in an `isize`.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        // A `[u8]` never exceeds `isize::MAX` bytes, so this cast is lossless;
        // only the multiplication by 8 can overflow.
        let idx = (source.len() as isize)
            .checked_mul(8)
            .expect("BitReaderReversed: source length in bits overflows isize");
        BitReaderReversed {
            idx,
            source,
            bit_container: 0,
            bits_in_container: 0,
        }
    }

    /// Tops up `bit_container` with whole bytes from below the current position.
    ///
    /// The unread bits are kept, rounded up to whole bytes (`retain_bytes`). The
    /// rest of the 64-bit container is filled with the next source bytes.
    /// Callers only invoke this while `idx > 0`; otherwise `byte_idx()` is not
    /// a valid index.
    #[inline(always)]
    fn refill_container(&mut self) {
        let byte_idx = self.byte_idx() as usize;

        let retain_bytes = (self.bits_in_container + 7) / 8;
        let want_to_read_bits = 64 - (retain_bytes * 8);

        // With at least 8 bytes below the next unread byte, a full 8-byte window
        // ending at the retained bytes always lies inside `source`.
        if byte_idx >= 8 {
            self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
        } else {
            self.refill_slow(byte_idx, want_to_read_bits)
        }
    }

    /// Fast refill: loads a fresh little-endian `u64`.
    ///
    /// The top `retain_bytes` bytes of the loaded word are the bytes that still
    /// hold unread bits, so the old container does not need to be shifted.
    #[inline(always)]
    fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
        let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
        let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
            .try_into()
            .unwrap();
        let refill = u64::from_le_bytes(tmp_bytes);
        self.bit_container = refill;
        self.bits_in_container += want_to_read_bits;
        self.idx -= want_to_read_bits as isize;
    }

    /// Slow refill near the start of `source`.
    ///
    /// Reads as many whole bytes as are left, capped at `want_to_read_bits / 8`.
    /// The retained bits are shifted up and the new bytes go underneath them.
    #[cold]
    fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
        // `idx` is a positive multiple of 8 and `want_to_read_bits` is a
        // multiple of 8 in 8..=64, so this is a whole number of bytes in 1..=8.
        let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
        let can_read_bytes = can_read_bits / 8;
        let offset @ 1..=8 = can_read_bytes as usize else {
            unreachable!()
        };
        let bits_read = offset * 8;

        // Copy the `offset` bytes ending at (and including) `byte_idx` into the
        // low end of a little-endian word. The source slice has exactly
        // `offset` bytes, so the copy can never silently come up short.
        let first_byte = byte_idx + 1 - offset;
        let mut tmp_bytes = [0u8; 8];
        tmp_bytes[..offset].copy_from_slice(&self.source[first_byte..=byte_idx]);
        let refill = u64::from_le_bytes(tmp_bytes);

        self.bits_in_container += bits_read as u8;
        self.idx -= bits_read as isize;
        if offset < 8 {
            self.bit_container <<= bits_read;
            self.bit_container |= refill;
        } else {
            // Reading 8 bytes means nothing was retained (the container was
            // empty). Shifting a u64 by 64 would overflow, so replace it.
            self.bit_container = refill;
        }
    }

    /// Index of the source byte holding bit `idx - 1`, i.e. the next byte to
    /// load into the container.
    ///
    /// Only meaningful while `idx > 0`. For `idx` in `-7..=0` the truncating
    /// division yields 0, not a negative value.
    fn byte_idx(&self) -> isize {
        (self.idx - 1) / 8
    }

    /// Reads the next `n` bits, moving toward the start of `source`.
    ///
    /// The first bit read becomes the most significant bit of the result, and
    /// `n == 0` returns 0. Once the source runs out, the missing bits read as
    /// zero in the low positions of the result, and `bits_remaining` keeps
    /// decreasing below zero.
    ///
    /// `n` should not exceed 56. Requests that need a refill are capped at 56
    /// bits; requests served straight from the container are not.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        if self.bits_in_container >= n {
            return self.get_bits_unchecked(n);
        }

        self.get_bits_cold(n)
    }

    /// Slow path of `get_bits`, taken when the container holds fewer than `n`
    /// bits. It refills the container, or emulates reading past the start of
    /// `source`.
    #[cold]
    fn get_bits_cold(&mut self, n: u8) -> u64 {
        let n = u8::min(n, 56);
        let signed_n = n as isize;

        if self.bits_remaining() <= 0 {
            // Already past the start: hand out zeros, but keep counting so that
            // `bits_remaining` reflects the over-read.
            self.idx -= signed_n;
            return 0;
        }

        if self.bits_remaining() < signed_n {
            // Only part of the request is left. Read what remains and pad the
            // result with zero bits at the bottom.
            let emulated_read_shift = signed_n - self.bits_remaining();
            let v = self.get_bits(self.bits_remaining() as u8);
            debug_assert!(self.idx == 0);
            let value = v.wrapping_shl(emulated_read_shift as u32);
            self.idx -= emulated_read_shift;
            return value;
        }

        // With n <= 56, at most 7 bytes are retained, so every refill adds at
        // least one byte. The loop ends once enough bits are buffered or the
        // whole source has been loaded.
        while (self.bits_in_container < n) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= n);

        self.get_bits_unchecked(n)
    }

    /// Reads three consecutive values of `n1`, `n2` and `n3` bits.
    ///
    /// Same result as calling `get_bits(n1)`, `get_bits(n2)`, `get_bits(n3)` in
    /// order. When the total is at most 56 bits, the container is checked and
    /// refilled only once for all three.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 as usize + n2 as usize + n3 as usize;
        if sum == 0 {
            return (0, 0, 0);
        }
        if sum > 56 {
            // Too many bits to buffer at once; fall back to separate reads.
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }
        let sum = sum as u8;

        if self.bits_in_container >= sum {
            // Tuple fields are evaluated left to right, preserving read order.
            return (
                self.get_bits_unchecked_or_zero(n1),
                self.get_bits_unchecked_or_zero(n2),
                self.get_bits_unchecked_or_zero(n3),
            );
        }

        self.get_bits_triple_cold(n1, n2, n3, sum)
    }

    /// Slow path of `get_bits_triple`, taken when the container holds fewer
    /// than `sum` (<= 56) bits.
    #[cold]
    fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
        let sum_signed = sum as isize;

        if self.bits_remaining() <= 0 {
            // Past the start: all three values are zero; keep the count exact.
            self.idx -= sum_signed;
            return (0, 0, 0);
        }

        if self.bits_remaining() < sum_signed {
            // Partial over-read: `get_bits` handles the zero padding per value.
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }

        while (self.bits_in_container < sum) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= sum);

        (
            self.get_bits_unchecked_or_zero(n1),
            self.get_bits_unchecked_or_zero(n2),
            self.get_bits_unchecked_or_zero(n3),
        )
    }

    /// Like `get_bits_unchecked`, but returns 0 for `n == 0`. This avoids
    /// `get_bits_unchecked`'s `n >= 1` precondition.
    #[inline(always)]
    fn get_bits_unchecked_or_zero(&mut self, n: u8) -> u64 {
        if n == 0 {
            0
        } else {
            self.get_bits_unchecked(n)
        }
    }

    /// Takes the top `n` valid bits out of the container.
    ///
    /// Callers guarantee `1 <= n <= bits_in_container`.
    #[inline(always)]
    fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        let shift_by = self.bits_in_container - n;
        let mask = (1u64 << n) - 1u64;

        let value = self.bit_container >> shift_by;
        self.bits_in_container -= n;
        let value_masked = value & mask;
        debug_assert!(value_masked < (1 << n));

        value_masked
    }
}