ruzstd/decoding/
bit_reader_reverse.rs

1use crate::io::Read;
2use core::convert::TryInto;
3
/// Zstandard encodes some types of data in a way that the data must be read
/// back to front to decode it properly. `BitReaderReversed` provides a
/// convenient interface to do that.
///
/// Bits are handed out starting with the most significant bit of the last
/// byte of the source and moving towards the first byte.
pub struct BitReaderReversed<'s> {
    /// Number of bits of `source` that have not been loaded into
    /// `bit_container` yet. Starts at `source.len() * 8` and only shrinks.
    /// It becomes negative when callers keep reading past the end of the
    /// source; the over-read is tracked here so `bits_remaining` reports it.
    idx: isize,
    /// The bytes being decoded.
    source: &'s [u8],
    /// The reader doesn't read directly from the source,
    /// it reads bits from here, and the container is
    /// "refilled" as it's emptied.
    bit_container: u64,
    /// How many of the low bits of `bit_container` are still unread.
    bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
    /// How many bits are left to read by the reader.
    ///
    /// The value is negative if more bits were requested than the source
    /// contained; its magnitude is the number of bits that were over-read.
    pub fn bits_remaining(&self) -> isize {
        self.idx + self.bits_in_container as isize
    }

    /// Creates a reader positioned at the end of `source`. Nothing is loaded
    /// into the container until the first read.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            idx: source.len() as isize * 8,
            source,
            bit_container: 0,
            bits_in_container: 0,
        }
    }

    /// We refill the container in full bytes, shifting the still unread portion to the left,
    /// and filling the lower bits with new data.
    ///
    /// Must only be called while `self.idx > 0`, i.e. while there are source bytes left to load.
    #[inline(always)]
    fn refill_container(&mut self) {
        let byte_idx = self.byte_idx() as usize;

        // Bytes of the container that still hold at least one unread bit must be kept.
        let retain_bytes = (self.bits_in_container + 7) / 8;
        let want_to_read_bits = 64 - (retain_bytes * 8);

        // If there are >= 8 bytes left to read we take a fast path:
        // The slice looks something like this |U..UCCCCCCCCR..R| where U are some unread bytes,
        // C are the bytes in the container, and R are already read bytes.
        // We shift the container a few bytes to the left by simply reading a u64 from the
        // correct position, re-reading the portion we did not yet return from the container.
        // Technically this would still work for positions lower than 8, but this guarantees that
        // enough bytes are in the source and generally makes for fewer edge cases.
        if byte_idx >= 8 {
            self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
        } else {
            // In the slow path we just read however many bytes we can.
            self.refill_slow(byte_idx, want_to_read_bits)
        }
    }

    /// Replaces the whole container with the 8 source bytes that end at the oldest byte which
    /// still has unread bits (`byte_idx + retain_bytes`).
    #[inline(always)]
    fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
        let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
        let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
            .try_into()
            .unwrap();
        // The retained bytes are loaded again together with the new ones, so a plain overwrite
        // is equivalent to shifting the old content up and appending below it.
        self.bit_container = u64::from_le_bytes(tmp_bytes);
        self.bits_in_container += want_to_read_bits;
        self.idx -= want_to_read_bits as isize;
    }

    /// Loads as many whole bytes as are both wanted and still available (1 to 8) and appends
    /// them below the bits that are already in the container.
    #[cold]
    fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
        let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
        let can_read_bytes = can_read_bits / 8;
        let mut tmp_bytes = [0u8; 8];
        // The caller only refills while `idx > 0` and while fewer than 57 bits are buffered,
        // so at least one byte and at most eight bytes are loaded here.
        let offset @ 1..=8 = can_read_bytes as usize else {
            unreachable!()
        };
        let bits_read = offset * 8;

        // `offset <= byte_idx + 1`, so the `offset` bytes ending at `byte_idx` always exist and
        // the read cannot come up short.
        let filled =
            (&self.source[byte_idx - (offset - 1)..]).read_exact(&mut tmp_bytes[0..offset]);
        debug_assert!(filled.is_ok());
        self.bits_in_container += bits_read as u8;
        self.idx -= bits_read as isize;
        if offset < 8 {
            self.bit_container <<= bits_read;
            self.bit_container |= u64::from_le_bytes(tmp_bytes);
        } else {
            // Shifting a u64 by 64 bits would overflow; a full 8 byte load replaces everything.
            self.bit_container = u64::from_le_bytes(tmp_bytes);
        }
    }

    /// Next byte that should be read into the container.
    /// Negative values mean that the source buffer has been read into the container completely.
    fn byte_idx(&self) -> isize {
        // Floor division: plain `/` truncates towards zero and would yield the valid index 0
        // for `idx == 0`, even though nothing is left to load in that state.
        (self.idx - 1).div_euclid(8)
    }

    /// Read `n` number of bits from the source. Will read at most 56 bits.
    /// If there are no more bits to be read from the source zero bits will be returned instead.
    ///
    /// If fewer than `n` bits are left, the remaining bits are returned in the upper part of
    /// the `n` bit wide result and the missing low bits are zero. The over-read is reflected
    /// as a negative `bits_remaining`.
    ///
    /// Callers should not ask for more than 56 bits: such a request is cut down to 56 bits
    /// whenever the container has to be refilled to serve it.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        if self.bits_in_container >= n {
            return self.get_bits_unchecked(n);
        }

        self.get_bits_cold(n)
    }

    /// Out-of-line part of `get_bits` for when the container does not hold `n` bits.
    #[cold]
    fn get_bits_cold(&mut self, n: u8) -> u64 {
        let n = u8::min(n, 56);
        let signed_n = n as isize;

        // Nothing left at all: only account for the over-read.
        if self.bits_remaining() <= 0 {
            self.idx -= signed_n;
            return 0;
        }

        // Some, but not enough bits left: hand out what is there and behave as if the missing
        // low bits were zeros.
        if self.bits_remaining() < signed_n {
            let emulated_read_shift = signed_n - self.bits_remaining();
            let v = self.get_bits(self.bits_remaining() as u8);
            debug_assert!(self.idx == 0);
            let value = v.wrapping_shl(emulated_read_shift as u32);
            self.idx -= emulated_read_shift;
            return value;
        }

        while (self.bits_in_container < n) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= n);

        // If we reach this point there are enough bits in the container.
        self.get_bits_unchecked(n)
    }

    /// Same as calling get_bits three times but slightly more performant
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 as usize + n2 as usize + n3 as usize;
        if sum == 0 {
            return (0, 0, 0);
        }
        if sum > 56 {
            // Too many bits for a single refill; try and get the values separately.
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }
        let sum = sum as u8;

        if self.bits_in_container >= sum {
            let v1 = self.get_bits_unchecked_or_zero(n1);
            let v2 = self.get_bits_unchecked_or_zero(n2);
            let v3 = self.get_bits_unchecked_or_zero(n3);

            return (v1, v2, v3);
        }

        self.get_bits_triple_cold(n1, n2, n3, sum)
    }

    /// Out-of-line part of `get_bits_triple` for when the container does not hold `sum` bits.
    /// `sum` is `n1 + n2 + n3` and is at most 56.
    #[cold]
    fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
        let sum_signed = sum as isize;

        // Nothing left at all: only account for the over-read.
        if self.bits_remaining() <= 0 {
            self.idx -= sum_signed;
            return (0, 0, 0);
        }

        // Not enough bits for all three values: let the single reads handle the padding.
        if self.bits_remaining() < sum_signed {
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }

        while (self.bits_in_container < sum) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= sum);

        // If we reach this point there are enough bits in the container.
        let v1 = self.get_bits_unchecked_or_zero(n1);
        let v2 = self.get_bits_unchecked_or_zero(n2);
        let v3 = self.get_bits_unchecked_or_zero(n3);

        (v1, v2, v3)
    }

    /// Like `get_bits_unchecked`, but returns 0 for `n == 0` without touching the container.
    /// A zero width read on a completely full container would otherwise shift by 64 bits.
    #[inline(always)]
    fn get_bits_unchecked_or_zero(&mut self, n: u8) -> u64 {
        if n == 0 {
            0
        } else {
            self.get_bits_unchecked(n)
        }
    }

    /// Takes the next `n` bits out of the container.
    /// The caller must make sure that `1 <= n <= self.bits_in_container`.
    #[inline(always)]
    fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        // The unread bits sit in the low `bits_in_container` bits; the next `n` of them are
        // the topmost ones of that range.
        let shift_by = self.bits_in_container - n;
        let mask = (1u64 << n) - 1u64;

        let value = self.bit_container >> shift_by;
        self.bits_in_container -= n;
        let value_masked = value & mask;
        debug_assert!(value_masked < (1 << n));

        value_masked
    }
}