ruzstd/decoding/
bit_reader.rs

1use super::errors::GetBitsError;
2
/// A cursor that reads a borrowed byte slice bit by bit.
pub struct BitReader<'s> {
    /// The bytes this reader draws its bits from.
    source: &'s [u8],
    /// Cursor position, counted as the number of bits already read.
    idx: usize,
}
8
// Only compiled when the `std` feature is enabled, since the trait lives in
// `std`. The body is empty, so every method of `std::error::Error` keeps its
// default implementation.
#[cfg(feature = "std")]
impl std::error::Error for GetBitsError {}
11
12impl core::fmt::Display for GetBitsError {
13    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
14        match self {
15            GetBitsError::TooManyBits {
16                num_requested_bits,
17                limit,
18            } => {
19                write!(
20                    f,
21                    "Cant serve this request. The reader is limited to {} bits, requested {} bits",
22                    limit, num_requested_bits,
23                )
24            }
25            GetBitsError::NotEnoughRemainingBits {
26                requested,
27                remaining,
28            } => {
29                write!(
30                    f,
31                    "Can\'t read {} bits, only have {} bits left",
32                    requested, remaining,
33                )
34            }
35        }
36    }
37}
38
39impl<'s> BitReader<'s> {
40    pub fn new(source: &'s [u8]) -> BitReader<'s> {
41        BitReader { idx: 0, source }
42    }
43
44    pub fn bits_left(&self) -> usize {
45        self.source.len() * 8 - self.idx
46    }
47
48    pub fn bits_read(&self) -> usize {
49        self.idx
50    }
51
52    pub fn return_bits(&mut self, n: usize) {
53        if n > self.idx {
54            panic!("Cant return this many bits");
55        }
56        self.idx -= n;
57    }
58
59    pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
60        if n > 64 {
61            return Err(GetBitsError::TooManyBits {
62                num_requested_bits: n,
63                limit: 64,
64            });
65        }
66        if self.bits_left() < n {
67            return Err(GetBitsError::NotEnoughRemainingBits {
68                requested: n,
69                remaining: self.bits_left(),
70            });
71        }
72
73        let old_idx = self.idx;
74
75        let bits_left_in_current_byte = 8 - (self.idx % 8);
76        let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
77
78        //collect bits from the currently pointed to byte
79        let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
80
81        if bits_left_in_current_byte >= n {
82            //no need for fancy stuff
83
84            //just mask all but the needed n bit
85            value &= (1 << n) - 1;
86            self.idx += n;
87        } else {
88            self.idx += bits_left_in_current_byte;
89
90            //n spans over multiple bytes
91            let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
92            let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
93
94            assert!(
95                bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
96            );
97
98            let mut bit_shift = bits_left_in_current_byte; //this many bits are already set in value
99
100            assert!(self.idx % 8 == 0);
101
102            //collect full bytes
103            for _ in 0..full_bytes_needed {
104                value |= u64::from(self.source[self.idx / 8]) << bit_shift;
105                self.idx += 8;
106                bit_shift += 8;
107            }
108
109            assert!(n - bit_shift == bits_in_last_byte_needed);
110
111            if bits_in_last_byte_needed > 0 {
112                let val_las_byte =
113                    u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
114                value |= val_las_byte << bit_shift;
115                self.idx += bits_in_last_byte_needed;
116            }
117        }
118
119        assert!(self.idx == old_idx + n);
120
121        Ok(value)
122    }
123}