// ruzstd/decoding/decodebuffer.rs

1use crate::io::{Error, Read, Write};
2use alloc::vec::Vec;
3#[cfg(feature = "hash")]
4use core::hash::Hasher;
5
6use super::ringbuffer::RingBuffer;
7
/// Output buffer for decoded frame data.
///
/// Holds recently decoded bytes in a ring buffer so that back-references
/// (offset/length matches) can be resolved against them, and tracks how much
/// may be drained while still retaining `window_size` bytes for future matches.
pub struct DecodeBuffer {
    // Decoded-but-not-yet-drained bytes; matches copy from here.
    buffer: RingBuffer,
    // Dictionary content, consulted when a match offset reaches behind `buffer`.
    pub dict_content: Vec<u8>,

    // Number of decoded bytes that must be kept in `buffer` so that any
    // back-reference within the frame's window can still be resolved.
    pub window_size: usize,
    // Total bytes ever produced into this buffer; NOT reduced by draining.
    total_output_counter: u64,
    // Running hash over all bytes that have been drained out of the buffer
    // (presumably compared against the frame checksum by the caller — verify there).
    #[cfg(feature = "hash")]
    pub hash: twox_hash::XxHash64,
}
17
/// Errors that can occur while executing a match/repeat in the decode buffer.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecodeBufferError {
    /// The match needs more bytes from the dictionary than it contains.
    NotEnoughBytesInDictionary { got: usize, need: usize },
    /// The match offset reaches further back than buffer plus dictionary allow.
    OffsetTooBig { offset: usize, buf_len: usize },
}
24
// Marker impl only — `Display`/`Debug` provide the messages. Gated on `std`
// because `std::error::Error` is unavailable in no_std builds.
#[cfg(feature = "std")]
impl std::error::Error for DecodeBufferError {}
27
28impl core::fmt::Display for DecodeBufferError {
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        match self {
31            DecodeBufferError::NotEnoughBytesInDictionary { got, need } => {
32                write!(
33                    f,
34                    "Need {} bytes from the dictionary but it is only {} bytes long",
35                    need, got,
36                )
37            }
38            DecodeBufferError::OffsetTooBig { offset, buf_len } => {
39                write!(f, "offset: {} bigger than buffer: {}", offset, buf_len,)
40            }
41        }
42    }
43}
44
45impl Read for DecodeBuffer {
46    fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
47        let max_amount = self.can_drain_to_window_size().unwrap_or(0);
48        let amount = max_amount.min(target.len());
49
50        let mut written = 0;
51        self.drain_to(amount, |buf| {
52            target[written..][..buf.len()].copy_from_slice(buf);
53            written += buf.len();
54            (buf.len(), Ok(()))
55        })?;
56        Ok(amount)
57    }
58}
59
60impl DecodeBuffer {
61    pub fn new(window_size: usize) -> DecodeBuffer {
62        DecodeBuffer {
63            buffer: RingBuffer::new(),
64            dict_content: Vec::new(),
65            window_size,
66            total_output_counter: 0,
67            #[cfg(feature = "hash")]
68            hash: twox_hash::XxHash64::with_seed(0),
69        }
70    }
71
72    pub fn reset(&mut self, window_size: usize) {
73        self.window_size = window_size;
74        self.buffer.clear();
75        self.buffer.reserve(self.window_size);
76        self.dict_content.clear();
77        self.total_output_counter = 0;
78        #[cfg(feature = "hash")]
79        {
80            self.hash = twox_hash::XxHash64::with_seed(0);
81        }
82    }
83
    /// Number of decoded bytes currently buffered (dictionary content excluded).
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
87
    /// True if no decoded bytes are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
91
92    pub fn push(&mut self, data: &[u8]) {
93        self.buffer.extend(data);
94        self.total_output_counter += data.len() as u64;
95    }
96
    /// Execute a match: append `match_length` bytes that are a copy of the
    /// bytes starting `offset` bytes before the current end of the buffer.
    ///
    /// Offsets reaching behind the buffered data are served (partly) from
    /// `dict_content`; see [`Self::repeat_from_dict`].
    pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
        if offset > self.buffer.len() {
            self.repeat_from_dict(offset, match_length)
        } else {
            let buf_len = self.buffer.len();
            let start_idx = buf_len - offset;
            let end_idx = start_idx + match_length;

            self.buffer.reserve(match_length);
            if end_idx > buf_len {
                // The match overlaps the bytes it is itself producing
                // (offset < match_length). We need to copy in chunks.
                self.repeat_in_chunks(offset, match_length, start_idx);
            } else {
                // can just copy parts of the existing buffer
                // SAFETY: Requirements checked:
                // 1. start_idx + match_length must be <= self.buffer.len()
                //      We know that:
                //      1. start_idx = self.buffer.len() - offset
                //      2. end_idx = start_idx + match_length
                //      3. end_idx <= self.buffer.len()
                //      Thus follows: start_idx + match_length <= self.buffer.len()
                //
                // 2. explicitly reserved enough memory for the whole match_length
                unsafe {
                    self.buffer
                        .extend_from_within_unchecked(start_idx, match_length)
                };
            }

            self.total_output_counter += match_length as u64;
            Ok(())
        }
    }
130
    /// Copy a self-overlapping match (`offset < match_length`) in chunks of at
    /// most `offset` bytes; each completed chunk extends the valid source
    /// region the next chunk may copy from. The caller (`repeat`) has already
    /// reserved space for the whole `match_length`.
    fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
        // We have at max offset bytes in one chunk, the last one can be smaller
        let mut start_idx = start_idx;
        let mut copied_counter_left = match_length;
        // TODO this can be optimized further I think.
        // Each time we copy a chunk we have a repetition of length 'offset', so we can copy offset * iteration many bytes from start_idx
        while copied_counter_left > 0 {
            let chunksize = usize::min(offset, copied_counter_left);

            // SAFETY: Requirements checked:
            // 1. start_idx + chunksize must be <= self.buffer.len()
            //      We know that:
            //      1. start_idx starts at buffer.len() - offset
            //      2. chunksize <= offset (== offset for each iteration but the last, and match_length modulo offset in the last iteration)
            //      3. the buffer grows by offset many bytes each iteration but the last
            //      4. start_idx is increased by the same amount as the buffer grows each iteration
            //
            //      Thus follows: start_idx + chunksize == self.buffer.len() in each iteration but the last, where match_length modulo offset == chunksize < offset
            //          Meaning: start_idx + chunksize <= self.buffer.len()
            //
            // 2. explicitly reserved enough memory for the whole match_length (done by the caller)
            unsafe {
                self.buffer
                    .extend_from_within_unchecked(start_idx, chunksize)
            };
            copied_counter_left -= chunksize;
            start_idx += chunksize;
        }
    }
160
    /// Slow path of [`Self::repeat`]: the offset reaches behind the decoded
    /// data, so (part of) the match must come from the dictionary content.
    /// Only permitted while total output so far still fits within one window.
    #[cold]
    fn repeat_from_dict(
        &mut self,
        offset: usize,
        match_length: usize,
    ) -> Result<(), DecodeBufferError> {
        if self.total_output_counter <= self.window_size as u64 {
            // at least part of that repeat is from the dictionary content
            let bytes_from_dict = offset - self.buffer.len();

            if bytes_from_dict > self.dict_content.len() {
                return Err(DecodeBufferError::NotEnoughBytesInDictionary {
                    got: self.dict_content.len(),
                    need: bytes_from_dict,
                });
            }

            if bytes_from_dict < match_length {
                // Match starts in the dictionary and continues into the buffer:
                // copy the dictionary tail, then finish with a normal repeat
                // whose offset equals the (new) buffer length.
                let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
                self.buffer.extend(dict_slice);

                self.total_output_counter += bytes_from_dict as u64;
                return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
            } else {
                // The whole match lies within the dictionary.
                // NOTE(review): this branch does not add match_length to
                // total_output_counter, unlike every other copy path — confirm
                // this is intentional.
                let low = self.dict_content.len() - bytes_from_dict;
                let high = low + match_length;
                let dict_slice = &self.dict_content[low..high];
                self.buffer.extend(dict_slice);
            }
            Ok(())
        } else {
            Err(DecodeBufferError::OffsetTooBig {
                offset,
                buf_len: self.buffer.len(),
            })
        }
    }
198
199    /// Check if and how many bytes can currently be drawn from the buffer
200    pub fn can_drain_to_window_size(&self) -> Option<usize> {
201        if self.buffer.len() > self.window_size {
202            Some(self.buffer.len() - self.window_size)
203        } else {
204            None
205        }
206    }
207
    /// How many bytes can be drained if the window_size does not have to be maintained
    pub fn can_drain(&self) -> usize {
        self.buffer.len()
    }
212
213    /// Drain as much as possible while retaining enough so that decoding si still possible with the required window_size
214    /// At best call only if can_drain_to_window_size reports a 'high' number of bytes to reduce allocations
215    pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
216        //TODO investigate if it is possible to return the std::vec::Drain iterator directly without collecting here
217        match self.can_drain_to_window_size() {
218            None => None,
219            Some(can_drain) => {
220                let mut vec = Vec::with_capacity(can_drain);
221                self.drain_to(can_drain, |buf| {
222                    vec.extend_from_slice(buf);
223                    (buf.len(), Ok(()))
224                })
225                .ok()?;
226                Some(vec)
227            }
228        }
229    }
230
231    pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
232        match self.can_drain_to_window_size() {
233            None => Ok(0),
234            Some(can_drain) => self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf)),
235        }
236    }
237
238    /// drain the buffer completely
239    pub fn drain(&mut self) -> Vec<u8> {
240        let (slice1, slice2) = self.buffer.as_slices();
241        #[cfg(feature = "hash")]
242        {
243            self.hash.write(slice1);
244            self.hash.write(slice2);
245        }
246
247        let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
248        vec.extend_from_slice(slice1);
249        vec.extend_from_slice(slice2);
250        self.buffer.clear();
251        vec
252    }
253
254    pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
255        let write_limit = self.buffer.len();
256        self.drain_to(write_limit, |buf| write_all_bytes(&mut sink, buf))
257    }
258
259    pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
260        let amount = self.buffer.len().min(target.len());
261
262        let mut written = 0;
263        self.drain_to(amount, |buf| {
264            target[written..][..buf.len()].copy_from_slice(buf);
265            written += buf.len();
266            (buf.len(), Ok(()))
267        })?;
268        Ok(amount)
269    }
270
    /// Core drain primitive: hand up to `amount` buffered bytes to
    /// `write_bytes` and remove whatever was consumed from the buffer.
    ///
    /// Semantics of write_bytes:
    /// Should dump as many of the provided bytes as possible to whatever sink until no bytes are left or an error is encountered
    /// Return how many bytes have actually been dumped to the sink.
    ///
    /// Bytes reported as consumed are fed into the running hash (with the
    /// `hash` feature) and dropped from the ring buffer even when
    /// `write_bytes` errors part-way through.
    fn drain_to(
        &mut self,
        amount: usize,
        mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
    ) -> Result<usize, Error> {
        if amount == 0 {
            return Ok(0);
        }

        // Ensures the consumed prefix is removed from the ring buffer on every
        // exit path, including early returns through `?` below.
        struct DrainGuard<'a> {
            buffer: &'a mut RingBuffer,
            amount: usize,
        }

        impl<'a> Drop for DrainGuard<'a> {
            fn drop(&mut self) {
                if self.amount != 0 {
                    self.buffer.drop_first_n(self.amount);
                }
            }
        }

        let mut drain_guard = DrainGuard {
            buffer: &mut self.buffer,
            amount: 0,
        };

        // The ring buffer's content may wrap around; process both halves.
        let (slice1, slice2) = drain_guard.buffer.as_slices();
        let n1 = slice1.len().min(amount);
        let n2 = slice2.len().min(amount - n1);

        if n1 != 0 {
            let (written1, res1) = write_bytes(&slice1[..n1]);
            // Only bytes actually accepted by the sink enter the checksum.
            #[cfg(feature = "hash")]
            self.hash.write(&slice1[..written1]);
            drain_guard.amount += written1;

            // Apparently this is what clippy thinks is the best way of expressing this
            res1?;

            // Only if the first call to write_bytes was not a partial write we can continue with slice2
            // Partial writes SHOULD never happen without res1 being an error, but lets just protect against it anyways.
            if written1 == n1 && n2 != 0 {
                let (written2, res2) = write_bytes(&slice2[..n2]);
                #[cfg(feature = "hash")]
                self.hash.write(&slice2[..written2]);
                drain_guard.amount += written2;

                // Apparently this is what clippy thinks is the best way of expressing this
                res2?;
            }
        }

        let amount_written = drain_guard.amount;
        // Make sure we don't accidentally drop `DrainGuard` earlier.
        drop(drain_guard);

        Ok(amount_written)
    }
333}
334
/// Like Write::write_all but returns partial write length even on error
///
/// Drives `sink.write` until the whole buffer is written, the sink stops
/// accepting bytes (`Ok(0)`), or an error occurs. The count returned is the
/// number of bytes the sink actually accepted, in every case.
fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
    let total = buf.len();
    let mut remaining = buf;
    while !remaining.is_empty() {
        match sink.write(remaining) {
            // Sink accepts nothing more: report a (possibly partial) write
            // without an error, matching Write::write's Ok(0) contract.
            Ok(0) => break,
            Ok(n) => remaining = &remaining[n..],
            Err(e) => return (total - remaining.len(), Err(e)),
        }
    }
    (total - remaining.len(), Ok(()))
}
347
348#[cfg(test)]
349mod tests {
350    use super::DecodeBuffer;
351    use crate::io::{Error, ErrorKind, Write};
352
353    extern crate std;
354    use alloc::vec;
355    use alloc::vec::Vec;
356
    /// The drain-to-writer paths must cope with sinks that accept only part
    /// of each `write` call, without losing or duplicating bytes.
    #[test]
    fn short_writer() {
        // Sink that accepts at most `write_len` bytes per write call.
        struct ShortWriter {
            buf: Vec<u8>,
            write_len: usize,
        }

        impl Write for ShortWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                if buf.len() > self.write_len {
                    self.buf.extend_from_slice(&buf[..self.write_len]);
                    Ok(self.write_len)
                } else {
                    self.buf.extend_from_slice(buf);
                    Ok(buf.len())
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = ShortWriter {
            buf: vec![],
            write_len: 10,
        };

        // Fill the buffer up to exactly the window size (100 bytes).
        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            // Produce 50 bytes beyond the window; exactly those become drainable.
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            decode_buf
                .drain_to_window_size_writer(&mut short_writer)
                .unwrap();
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(short_writer.buf.len(), repeats * 50);
        // Finally drain the retained window as well.
        decode_buf.drain_to_writer(&mut short_writer).unwrap();
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }
403
    /// Draining must be resumable after a sink reports `WouldBlock`:
    /// retries may not lose or duplicate bytes.
    #[test]
    fn wouldblock_writer() {
        // Sink that fails with `WouldBlock` after every `block_every` writes.
        struct WouldblockWriter {
            buf: Vec<u8>,
            last_blocked: usize,
            block_every: usize,
        }

        impl Write for WouldblockWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                if self.last_blocked < self.block_every {
                    self.buf.extend_from_slice(buf);
                    self.last_blocked += 1;
                    Ok(buf.len())
                } else {
                    self.last_blocked = 0;
                    Err(Error::from(ErrorKind::WouldBlock))
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = WouldblockWriter {
            buf: vec![],
            last_blocked: 0,
            block_every: 5,
        };

        // Fill the buffer up to exactly the window size (100 bytes).
        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            // Retry until everything beyond the window has been flushed.
            loop {
                match decode_buf.drain_to_window_size_writer(&mut short_writer) {
                    Ok(written) => {
                        if written == 0 {
                            break;
                        }
                    }
                    Err(e) => {
                        if e.kind() == ErrorKind::WouldBlock {
                            continue;
                        } else {
                            panic!("Unexpected error {:?}", e);
                        }
                    }
                }
            }
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(short_writer.buf.len(), repeats * 50);
        // Drain the retained window too, again retrying on WouldBlock.
        loop {
            match decode_buf.drain_to_writer(&mut short_writer) {
                Ok(written) => {
                    if written == 0 {
                        break;
                    }
                }
                Err(e) => {
                    if e.kind() == ErrorKind::WouldBlock {
                        continue;
                    } else {
                        panic!("Unexpected error {:?}", e);
                    }
                }
            }
        }
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }
481}