// ruzstd/decoding/bit_reader.rs

/// Reads values out of a borrowed byte slice one bit at a time.
///
/// The reader keeps a single cursor measured in bits; reads advance it and
/// `return_bits` can move it backwards again.
pub struct BitReader<'s> {
    /// Number of bits already consumed, counted from the start of `source`.
    idx: usize,
    /// The bytes that bits are read from.
    source: &'s [u8],
}
6
/// Reasons a [`BitReader`] read request can be rejected.
#[derive(Debug)]
#[non_exhaustive]
pub enum GetBitsError {
    /// More bits were requested in one call than the reader can return at once.
    TooManyBits {
        /// How many bits the caller asked for.
        num_requested_bits: usize,
        /// The largest number of bits a single request may ask for.
        limit: u8,
    },
    /// The request is larger than the number of unread bits in the source.
    NotEnoughRemainingBits {
        /// How many bits the caller asked for.
        requested: usize,
        /// How many unread bits were actually available.
        remaining: usize,
    },
}
19
/// Makes `GetBitsError` usable wherever a `std::error::Error` is expected
/// (for example behind `Box<dyn std::error::Error>`).
///
/// Only compiled with the `std` feature. No trait methods are overridden;
/// the provided defaults are used, relying on the type's `Debug` and
/// `Display` implementations.
#[cfg(feature = "std")]
impl std::error::Error for GetBitsError {}
22
23impl core::fmt::Display for GetBitsError {
24    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
25        match self {
26            GetBitsError::TooManyBits {
27                num_requested_bits,
28                limit,
29            } => {
30                write!(
31                    f,
32                    "Cant serve this request. The reader is limited to {} bits, requested {} bits",
33                    limit, num_requested_bits,
34                )
35            }
36            GetBitsError::NotEnoughRemainingBits {
37                requested,
38                remaining,
39            } => {
40                write!(
41                    f,
42                    "Can\'t read {} bits, only have {} bits left",
43                    requested, remaining,
44                )
45            }
46        }
47    }
48}
49
50impl<'s> BitReader<'s> {
51    pub fn new(source: &'s [u8]) -> BitReader<'s> {
52        BitReader { idx: 0, source }
53    }
54
55    pub fn bits_left(&self) -> usize {
56        self.source.len() * 8 - self.idx
57    }
58
59    pub fn bits_read(&self) -> usize {
60        self.idx
61    }
62
63    pub fn return_bits(&mut self, n: usize) {
64        if n > self.idx {
65            panic!("Cant return this many bits");
66        }
67        self.idx -= n;
68    }
69
70    pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
71        if n > 64 {
72            return Err(GetBitsError::TooManyBits {
73                num_requested_bits: n,
74                limit: 64,
75            });
76        }
77        if self.bits_left() < n {
78            return Err(GetBitsError::NotEnoughRemainingBits {
79                requested: n,
80                remaining: self.bits_left(),
81            });
82        }
83
84        let old_idx = self.idx;
85
86        let bits_left_in_current_byte = 8 - (self.idx % 8);
87        let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
88
89        //collect bits from the currently pointed to byte
90        let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
91
92        if bits_left_in_current_byte >= n {
93            //no need for fancy stuff
94
95            //just mask all but the needed n bit
96            value &= (1 << n) - 1;
97            self.idx += n;
98        } else {
99            self.idx += bits_left_in_current_byte;
100
101            //n spans over multiple bytes
102            let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
103            let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
104
105            assert!(
106                bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
107            );
108
109            let mut bit_shift = bits_left_in_current_byte; //this many bits are already set in value
110
111            assert!(self.idx % 8 == 0);
112
113            //collect full bytes
114            for _ in 0..full_bytes_needed {
115                value |= u64::from(self.source[self.idx / 8]) << bit_shift;
116                self.idx += 8;
117                bit_shift += 8;
118            }
119
120            assert!(n - bit_shift == bits_in_last_byte_needed);
121
122            if bits_in_last_byte_needed > 0 {
123                let val_las_byte =
124                    u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
125                value |= val_las_byte << bit_shift;
126                self.idx += bits_in_last_byte_needed;
127            }
128        }
129
130        assert!(self.idx == old_idx + n);
131
132        Ok(value)
133    }
134
135    pub fn reset(&mut self, new_source: &'s [u8]) {
136        self.idx = 0;
137        self.source = new_source;
138    }
139}