ruzstd/decoding/
bit_reader_reverse.rs1use core::convert::TryInto;
2
3pub use super::bit_reader::GetBitsError;
4use crate::io::Read;
5
/// Reads bits from a byte slice back to front.
///
/// The first bit returned is the most significant bit of the last byte of `source`; reading
/// proceeds towards the least significant bit of the first byte. Multi-bit reads return the
/// bits right-aligned, with the earliest-read bit in the most significant position.
///
/// Reading past the start of `source` does not fail: missing bits read as zero and
/// [`BitReaderReversed::bits_remaining`] goes negative, so callers can detect the overread
/// after the fact.
pub struct BitReaderReversed<'s> {
    /// Bits of `source` not yet loaded into `bit_container`, counted from the start of the
    /// slice. A multiple of 8 while positive; negative after an overread.
    idx: isize,
    /// The data being read.
    source: &'s [u8],
    /// Loaded bits. The low `bits_in_container` bits are unread and the most significant of
    /// those is the next one handed out. Anything above them is stale.
    bit_container: u64,
    /// Number of unread bits held in the low end of `bit_container`.
    bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
    /// Widest read that is always served from a single container load.
    ///
    /// A refill keeps every byte that still holds unread bits, so the 64-bit container can end
    /// up with fewer than 64 unread bits; wider reads are split in `get_bits_cold`.
    const MAX_SINGLE_READ_BITS: u8 = 56;

    /// Number of bits left before the start of the source is reached.
    ///
    /// Becomes negative once more bits have been requested than the source contains.
    pub fn bits_remaining(&self) -> isize {
        self.idx + self.bits_in_container as isize
    }

    /// Creates a reader positioned at the end of `source`.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            idx: source.len() as isize * 8,
            source,
            bit_container: 0,
            bits_in_container: 0,
        }
    }

    /// Moves more bytes from `source` into `bit_container`.
    ///
    /// Must only be called while `idx > 0`. Bytes that still hold unread bits are kept; the
    /// rest of the container is filled with the next bytes of the source, or with as many as
    /// are left.
    #[inline(always)]
    fn refill_container(&mut self) {
        // Index of the byte holding the next bit to load.
        let byte_idx = self.byte_idx() as usize;

        // Bytes (possibly partially consumed) whose unread bits must stay in the container.
        let retain_bytes = (self.bits_in_container + 7) / 8;
        let want_to_read_bits = 64 - (retain_bytes * 8);

        if byte_idx >= 8 {
            // Far enough from the start that an 8-byte window below the read position is in bounds.
            self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
        } else {
            // Close to the start of the source: read only what is left.
            self.refill_slow(byte_idx, want_to_read_bits)
        }
    }

    /// Replaces the container with the 8 bytes ending at the last retained byte.
    ///
    /// The top `retain_bytes` bytes of the load are the ones that already hold the unread bits,
    /// so those bits keep their position while `want_to_read_bits` new bits arrive below them.
    #[inline(always)]
    fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
        let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
        // `[..8]` yields exactly 8 bytes, so the conversion cannot fail.
        let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
            .try_into()
            .unwrap();
        // Little-endian: the byte closest to the end of the source becomes the most significant.
        self.bit_container = u64::from_le_bytes(tmp_bytes);
        self.bits_in_container += want_to_read_bits;
        self.idx -= want_to_read_bits as isize;
    }

    /// Loads up to `want_to_read_bits` bits from the first bytes of the source, appending them
    /// below the unread bits already in the container.
    #[cold]
    fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
        let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
        let can_read_bytes = can_read_bits / 8;
        let mut tmp_bytes = [0u8; 8];
        // `idx` is a positive multiple of 8 here and callers only refill with fewer than 56
        // unread bits (so `want_to_read_bits >= 8`), hence between 1 and 8 bytes are read.
        let offset @ 1..=8 = can_read_bytes as usize else {
            unreachable!()
        };
        let bits_read = offset * 8;

        // `byte_idx < source.len()`, so the slice holds at least `offset` bytes and
        // `read_exact` cannot come up short.
        let _ = (&self.source[byte_idx - (offset - 1)..]).read_exact(&mut tmp_bytes[0..offset]);
        self.bits_in_container += bits_read as u8;
        self.idx -= bits_read as isize;
        if offset < 8 {
            // Make room below the unread bits, then append the new bytes.
            self.bit_container <<= bits_read;
            self.bit_container |= u64::from_le_bytes(tmp_bytes);
        } else {
            // Eight bytes means nothing was retained; replace the container instead of shifting
            // by 64, which would overflow.
            self.bit_container = u64::from_le_bytes(tmp_bytes);
        }
    }

    /// Index of the byte containing bit `idx - 1`, the next bit to be loaded.
    fn byte_idx(&self) -> isize {
        (self.idx - 1) / 8
    }

    /// Reads `n` bits and returns them right-aligned, earliest-read bit most significant.
    ///
    /// Bits before the start of the source read as zero (see [`Self::bits_remaining`]).
    /// Reads of up to 64 bits are exact; for `n > 64`, `n` bits are consumed and the low 64 of
    /// them are returned.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        if self.bits_in_container >= n {
            return self.get_bits_unchecked(n);
        }

        self.get_bits_cold(n)
    }

    /// Slow path of [`Self::get_bits`]: refills the container and emulates overreads.
    #[cold]
    fn get_bits_cold(&mut self, n: u8) -> u64 {
        if n > Self::MAX_SINGLE_READ_BITS {
            // A single refill cannot always supply this many bits. Previously `n` was silently
            // clamped to 56 here, so the result depended on how full the container happened to
            // be. Read the high part (earlier in the stream) first, then the low 32 bits; each
            // part handles overreads itself, so the spliced value matches one wide read.
            const LOW_BITS: u8 = 32;
            let high = self.get_bits(n - LOW_BITS);
            let low = self.get_bits(LOW_BITS);
            return (high << LOW_BITS) | low;
        }
        let signed_n = n as isize;

        if self.bits_remaining() <= 0 {
            // Entirely past the start: yield zeros but keep counting consumed bits.
            self.idx -= signed_n;
            return 0;
        }

        if self.bits_remaining() < signed_n {
            // Partially past the start: read what exists and pad the missing low bits with
            // zeros, as if the source were preceded by zero bytes.
            let available = self.bits_remaining();
            let emulated_read_shift = signed_n - available;
            let v = self.get_bits(available as u8);
            debug_assert!(self.idx == 0);
            let value = v.wrapping_shl(emulated_read_shift as u32);
            self.idx -= emulated_read_shift;
            return value;
        }

        while (self.bits_in_container < n) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= n);

        self.get_bits_unchecked(n)
    }

    /// Reads three consecutive fields of `n1`, `n2` and `n3` bits.
    ///
    /// Behaves like three successive [`Self::get_bits`] calls, but needs only one container
    /// check when all three fields are already loaded.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 as usize + n2 as usize + n3 as usize;
        if sum == 0 {
            return (0, 0, 0);
        }
        if sum > Self::MAX_SINGLE_READ_BITS as usize {
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }
        let sum = sum as u8;

        if self.bits_in_container >= sum {
            return self.get_bits_triple_unchecked(n1, n2, n3);
        }

        self.get_bits_triple_cold(n1, n2, n3, sum)
    }

    /// Slow path of [`Self::get_bits_triple`]; `sum` is `n1 + n2 + n3` and at most
    /// `MAX_SINGLE_READ_BITS`.
    #[cold]
    fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
        let sum_signed = sum as isize;

        if self.bits_remaining() <= 0 {
            // Entirely past the start: yield zeros but keep counting consumed bits.
            self.idx -= sum_signed;
            return (0, 0, 0);
        }

        if self.bits_remaining() < sum_signed {
            // Partial overread: let `get_bits` zero-pad each field individually.
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }

        while (self.bits_in_container < sum) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= sum);

        self.get_bits_triple_unchecked(n1, n2, n3)
    }

    /// Takes three fields straight from the container; zero-width fields yield 0.
    ///
    /// Caller must ensure `bits_in_container >= n1 + n2 + n3`.
    #[inline(always)]
    fn get_bits_triple_unchecked(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        // Zero-width fields skip `get_bits_unchecked`, whose `>> bits_in_container` would
        // overflow when the container is completely full.
        let mut take = |n: u8| if n == 0 { 0 } else { self.get_bits_unchecked(n) };
        (take(n1), take(n2), take(n3))
    }

    /// Takes the next `n` bits from the container without refilling.
    ///
    /// Caller must ensure `0 < n < 64` and `n <= bits_in_container`.
    #[inline(always)]
    fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        // Unread bits sit in the low end; the next `n` are the top `n` of them.
        let shift_by = self.bits_in_container - n;
        let mask = (1u64 << n) - 1u64;

        let value = self.bit_container >> shift_by;
        self.bits_in_container -= n;
        let value_masked = value & mask;
        debug_assert!(value_masked < (1 << n));

        value_masked
    }

    /// Restarts reading from the end of `new_source`, discarding any buffered bits.
    pub fn reset(&mut self, new_source: &'s [u8]) {
        self.idx = new_source.len() as isize * 8;
        self.source = new_source;
        self.bit_container = 0;
        self.bits_in_container = 0;
    }
}