ruzstd/decoding/
bit_reader_reverse.rs

1use core::convert::TryInto;
2
3pub use super::bit_reader::GetBitsError;
4use crate::io::Read;
5
/// Zstandard encodes some types of data in a way that the data must be read
/// back to front to decode it properly. `BitReaderReversed` provides a
/// convenient interface to do that.
///
/// Bits are handed out starting at the most significant bit of the last byte
/// of the source and moving towards the least significant bit of the first byte.
pub struct BitReaderReversed<'s> {
    /// Number of source bits (counted from the start of `source`) that have not
    /// yet been loaded into `bit_container`. Once the source is exhausted and the
    /// caller keeps reading, it goes negative by the number of bits over-read.
    idx: isize,
    source: &'s [u8],
    /// The reader doesn't read directly from the source,
    /// it reads bits from here, and the container is
    /// "refilled" as it's emptied.
    bit_container: u64,
    /// Number of not-yet-returned bits held in the low end of `bit_container`.
    bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
    /// How many bits are left to read by the reader.
    ///
    /// Negative once the caller has requested more bits than the source contained.
    pub fn bits_remaining(&self) -> isize {
        self.idx + self.bits_in_container as isize
    }

    /// Creates a reader over `source`; the first bits returned are the most
    /// significant bits of its last byte.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            idx: source.len() as isize * 8,
            source,
            bit_container: 0,
            bits_in_container: 0,
        }
    }

    /// We refill the container in full bytes, shifting the still unread portion to the left,
    /// and filling the lower bits with new data.
    ///
    /// Must only be called while `idx > 0`.
    #[inline(always)]
    fn refill_container(&mut self) {
        debug_assert!(self.idx > 0);
        let byte_idx = self.byte_idx() as usize;

        // Source bytes that still hold unread container bits (a partially consumed
        // byte counts as a whole one, since refills happen in whole bytes).
        let retain_bytes = (self.bits_in_container + 7) / 8;
        let want_to_read_bits = 64 - (retain_bytes * 8);

        // If there are >= 8 bytes left to read we take a fast path.
        // The slice looks something like |U..UCCCCCCCCR..R| where U are unread bytes, C are the
        // bytes in the container, and R are already read bytes.
        // We shift the container left by a few bytes simply by reading a u64 from the correct
        // position, re-reading the portion of the container that has not been returned yet.
        // Technically this would still work for positions lower than 8, but requiring 8
        // guarantees that enough bytes are in the source and makes for fewer edge cases.
        if byte_idx >= 8 {
            self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
        } else {
            // In the slow path we just read however many bytes we can.
            self.refill_slow(byte_idx, want_to_read_bits)
        }
    }

    /// Reloads the whole container with one 8-byte little-endian load that ends at the
    /// last source byte still holding unread container bits.
    #[inline(always)]
    fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
        let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
        let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
            .try_into()
            .expect("an 8-byte slice always converts to [u8; 8]");
        let refill = u64::from_le_bytes(tmp_bytes);
        self.bit_container = refill;
        self.bits_in_container += want_to_read_bits;
        self.idx -= want_to_read_bits as isize;
    }

    /// Loads as many whole bytes as are left (at most `want_to_read_bits / 8`) into the
    /// low end of the container, below the bits it still holds.
    #[cold]
    fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
        let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
        let can_read_bytes = can_read_bits / 8;
        let mut tmp_bytes = [0u8; 8];
        let offset @ 1..=8 = can_read_bytes as usize else {
            unreachable!()
        };
        let bits_read = offset * 8;

        // The `offset` bytes ending at `byte_idx` sit directly below the container's bits.
        let start = byte_idx + 1 - offset;
        let read_result = (&self.source[start..]).read_exact(&mut tmp_bytes[0..offset]);
        // Cannot fail: `idx` guarantees at least `offset` bytes from `start`.
        debug_assert!(read_result.is_ok(), "source is shorter than idx claims");
        self.bits_in_container += bits_read as u8;
        self.idx -= bits_read as isize;
        if offset < 8 {
            self.bit_container <<= bits_read;
            self.bit_container |= u64::from_le_bytes(tmp_bytes);
        } else {
            // A shift by 64 would overflow; the container is empty in this case anyway.
            self.bit_container = u64::from_le_bytes(tmp_bytes);
        }
    }

    /// Next byte that should be read into the container.
    /// Negative values mean that the source buffer has been read into the container completely.
    fn byte_idx(&self) -> isize {
        // Floor division: `idx == 0` must map to -1, whereas `/` would truncate it to 0.
        (self.idx - 1).div_euclid(8)
    }

    /// Reads `n` bits from the source and returns them in the low bits of the result.
    /// Will read at most 56 bits: larger requests are not supported (the refill path
    /// truncates them to 56).
    /// If there are no more bits to be read from the source, zero bits are returned instead;
    /// use `bits_remaining` to detect such an over-read.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        if self.bits_in_container >= n {
            return self.get_bits_unchecked(n);
        }

        self.get_bits_cold(n)
    }

    /// Slow path of `get_bits`: refills the container, or emulates reading zero bits
    /// once the source is exhausted.
    #[cold]
    fn get_bits_cold(&mut self, n: u8) -> u64 {
        let n = u8::min(n, 56);
        let signed_n = n as isize;

        if self.bits_remaining() <= 0 {
            // Nothing left: return zeros but keep counting so the over-read is visible.
            self.idx -= signed_n;
            return 0;
        }

        if self.bits_remaining() < signed_n {
            // Only part of the request is available: read what is left and pad the
            // missing low bits with zeros.
            let emulated_read_shift = signed_n - self.bits_remaining();
            let v = self.get_bits(self.bits_remaining() as u8);
            debug_assert!(self.idx == 0);
            let value = v.wrapping_shl(emulated_read_shift as u32);
            self.idx -= emulated_read_shift;
            return value;
        }

        while (self.bits_in_container < n) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= n);

        // If we reach this point there are enough bits in the container.
        self.get_bits_unchecked(n)
    }

    /// Same as calling `get_bits` three times (for `n1`, `n2`, `n3`, in that order)
    /// but slightly more performant.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 as usize + n2 as usize + n3 as usize;
        if sum == 0 {
            return (0, 0, 0);
        }
        if sum > 56 {
            // Too wide for one container refill: try and get the values separately.
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }
        let sum = sum as u8;

        if self.bits_in_container >= sum {
            return self.get_bits_triple_unchecked(n1, n2, n3);
        }

        self.get_bits_triple_cold(n1, n2, n3, sum)
    }

    /// Slow path of `get_bits_triple`; `sum` is `n1 + n2 + n3` and is at most 56.
    #[cold]
    fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
        let sum_signed = sum as isize;

        if self.bits_remaining() <= 0 {
            self.idx -= sum_signed;
            return (0, 0, 0);
        }

        if self.bits_remaining() < sum_signed {
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }

        while (self.bits_in_container < sum) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= sum);

        // If we reach this point there are enough bits in the container.
        self.get_bits_triple_unchecked(n1, n2, n3)
    }

    /// Reads three consecutive values whose combined width is already in the container.
    #[inline(always)]
    fn get_bits_triple_unchecked(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let v1 = self.get_bits_unchecked_or_zero(n1);
        let v2 = self.get_bits_unchecked_or_zero(n2);
        let v3 = self.get_bits_unchecked_or_zero(n3);
        (v1, v2, v3)
    }

    /// Like `get_bits_unchecked`, but skips zero-width reads (for which the shift in
    /// `get_bits_unchecked` could overflow when the container is full).
    #[inline(always)]
    fn get_bits_unchecked_or_zero(&mut self, n: u8) -> u64 {
        if n == 0 {
            0
        } else {
            self.get_bits_unchecked(n)
        }
    }

    /// Takes the next `n` bits from the top of the container's valid bits without refilling.
    /// The caller must ensure `0 < n < 64` and `n <= self.bits_in_container`.
    #[inline(always)]
    fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        let shift_by = self.bits_in_container - n;
        let mask = (1u64 << n) - 1u64;

        let value = self.bit_container >> shift_by;
        self.bits_in_container -= n;
        let value_masked = value & mask;
        debug_assert!(value_masked < (1 << n));

        value_masked
    }

    /// Points the reader at `new_source`, discarding all state from the previous source.
    pub fn reset(&mut self, new_source: &'s [u8]) {
        *self = Self::new(new_source);
    }
}
232}