// ruzstd/decoding/decodebuffer.rs
1use crate::io::{Error, Read, Write};
2use alloc::vec::Vec;
3#[cfg(feature = "hash")]
4use core::hash::Hasher;
5
6use super::ringbuffer::RingBuffer;
7use crate::decoding::errors::DecodeBufferError;
8
/// Output buffer of the decoder: holds recently decoded bytes so that
/// back-references (`offset`/`match_length` pairs) can be resolved against
/// them, and hands finished bytes onward once they fall outside the window.
pub struct DecodeBuffer {
    // Decoded bytes; matches are copied from within this ring buffer.
    buffer: RingBuffer,
    // Content of an external dictionary. Consulted when an offset reaches
    // further back than the bytes decoded so far (see `repeat_from_dict`).
    pub dict_content: Vec<u8>,

    // Number of trailing bytes that must be retained so later matches can
    // still reference them; only bytes beyond this may be drained.
    pub window_size: usize,
    // Total number of bytes ever pushed/repeated into this buffer.
    total_output_counter: u64,
    // Running XxHash64 over all bytes drained out of the buffer.
    #[cfg(feature = "hash")]
    pub hash: twox_hash::XxHash64,
}
18
19impl Read for DecodeBuffer {
20    fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
21        let max_amount = self.can_drain_to_window_size().unwrap_or(0);
22        let amount = max_amount.min(target.len());
23
24        let mut written = 0;
25        self.drain_to(amount, |buf| {
26            target[written..][..buf.len()].copy_from_slice(buf);
27            written += buf.len();
28            (buf.len(), Ok(()))
29        })?;
30        Ok(amount)
31    }
32}
33
impl DecodeBuffer {
    /// Create an empty decode buffer for a frame with the given window size.
    /// No memory is reserved here; see `reset` and `repeat`.
    pub fn new(window_size: usize) -> DecodeBuffer {
        DecodeBuffer {
            buffer: RingBuffer::new(),
            dict_content: Vec::new(),
            window_size,
            total_output_counter: 0,
            #[cfg(feature = "hash")]
            hash: twox_hash::XxHash64::with_seed(0),
        }
    }

    /// Prepare the buffer for a new frame: drop all buffered and dictionary
    /// content, pre-reserve `window_size` bytes, and reset the output counter
    /// (and the hash, if enabled).
    pub fn reset(&mut self, window_size: usize) {
        self.window_size = window_size;
        self.buffer.clear();
        self.buffer.reserve(self.window_size);
        self.dict_content.clear();
        self.total_output_counter = 0;
        #[cfg(feature = "hash")]
        {
            self.hash = twox_hash::XxHash64::with_seed(0);
        }
    }

    /// Number of decoded bytes currently held in the buffer.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Append literal bytes to the end of the buffer.
    pub fn push(&mut self, data: &[u8]) {
        self.buffer.extend(data);
        self.total_output_counter += data.len() as u64;
    }

    /// Execute a match: copy `match_length` bytes starting `offset` bytes
    /// before the current end of the buffer. Offsets that reach further back
    /// than the buffered data are served (partially) from `dict_content`.
    pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
        if offset > self.buffer.len() {
            self.repeat_from_dict(offset, match_length)
        } else {
            let buf_len = self.buffer.len();
            let start_idx = buf_len - offset;
            let end_idx = start_idx + match_length;

            self.buffer.reserve(match_length);
            if end_idx > buf_len {
                // Overlapping match: the source range runs past the current
                // end of the buffer. We need to copy in chunks.
                self.repeat_in_chunks(offset, match_length, start_idx);
            } else {
                // can just copy parts of the existing buffer
                // SAFETY: Requirements checked:
                // 1. start_idx + match_length must be <= self.buffer.len()
                //      We know that:
                //      1. start_idx = self.buffer.len() - offset
                //      2. end_idx = start_idx + match_length
                //      3. end_idx <= self.buffer.len()
                //      Thus follows: start_idx + match_length <= self.buffer.len()
                //
                // 2. explicitly reserved enough memory for the whole match_length
                unsafe {
                    self.buffer
                        .extend_from_within_unchecked(start_idx, match_length)
                };
            }

            self.total_output_counter += match_length as u64;
            Ok(())
        }
    }

    /// Copy an overlapping match in chunks of at most `offset` bytes each.
    /// After each chunk the buffer has grown by that chunk, so the next
    /// iteration may legally read the bytes that were just written.
    fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
        // We have at max offset bytes in one chunk, the last one can be smaller
        let mut start_idx = start_idx;
        let mut copied_counter_left = match_length;
        // TODO this can be optimized further I think.
        // Each time we copy a chunk we have a repetition of length 'offset', so we can copy offset * iteration many bytes from start_idx
        while copied_counter_left > 0 {
            let chunksize = usize::min(offset, copied_counter_left);

            // SAFETY: Requirements checked:
            // 1. start_idx + chunksize must be <= self.buffer.len()
            //      We know that:
            //      1. start_idx starts at buffer.len() - offset
            //      2. chunksize <= offset (== offset for each iteration but the last, and match_length modulo offset in the last iteration)
            //      3. the buffer grows by offset many bytes each iteration but the last
            //      4. start_idx is increased by the same amount as the buffer grows each iteration
            //
            //      Thus follows: start_idx + chunksize == self.buffer.len() in each iteration but the last, where match_length modulo offset == chunksize < offset
            //          Meaning: start_idx + chunksize <= self.buffer.len()
            //
            // 2. explicitly reserved enough memory for the whole match_length
            unsafe {
                self.buffer
                    .extend_from_within_unchecked(start_idx, chunksize)
            };
            copied_counter_left -= chunksize;
            start_idx += chunksize;
        }
    }

    /// Cold path of `repeat`: (part of) the match lies in the dictionary.
    /// Only permitted while the total output still fits within the first
    /// window; after that, an offset beyond the buffer is simply invalid.
    #[cold]
    fn repeat_from_dict(
        &mut self,
        offset: usize,
        match_length: usize,
    ) -> Result<(), DecodeBufferError> {
        if self.total_output_counter <= self.window_size as u64 {
            // at least part of that repeat is from the dictionary content
            let bytes_from_dict = offset - self.buffer.len();

            if bytes_from_dict > self.dict_content.len() {
                return Err(DecodeBufferError::NotEnoughBytesInDictionary {
                    got: self.dict_content.len(),
                    need: bytes_from_dict,
                });
            }

            if bytes_from_dict < match_length {
                // Copy the dictionary tail into the buffer, then serve the
                // remainder of the match from the buffer itself via `repeat`
                // (which also accounts for the remaining length).
                let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
                self.buffer.extend(dict_slice);

                self.total_output_counter += bytes_from_dict as u64;
                return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
            } else {
                // The whole match fits inside the dictionary content.
                // NOTE(review): total_output_counter is not advanced on this
                // path, unlike every other path that extends the buffer —
                // confirm this is intentional.
                let low = self.dict_content.len() - bytes_from_dict;
                let high = low + match_length;
                let dict_slice = &self.dict_content[low..high];
                self.buffer.extend(dict_slice);
            }
            Ok(())
        } else {
            Err(DecodeBufferError::OffsetTooBig {
                offset,
                buf_len: self.buffer.len(),
            })
        }
    }

    /// Check if and how many bytes can currently be drawn from the buffer
    /// while still keeping `window_size` bytes for future matches.
    pub fn can_drain_to_window_size(&self) -> Option<usize> {
        if self.buffer.len() > self.window_size {
            Some(self.buffer.len() - self.window_size)
        } else {
            None
        }
    }

    //How many bytes can be drained if the window_size does not have to be maintained
    pub fn can_drain(&self) -> usize {
        self.buffer.len()
    }

    /// Drain as much as possible while retaining enough so that decoding is still possible with the required window_size
    /// At best call only if can_drain_to_window_size reports a 'high' number of bytes to reduce allocations
    pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
        //TODO investigate if it is possible to return the std::vec::Drain iterator directly without collecting here
        match self.can_drain_to_window_size() {
            None => None,
            Some(can_drain) => {
                let mut vec = Vec::with_capacity(can_drain);
                self.drain_to(can_drain, |buf| {
                    vec.extend_from_slice(buf);
                    (buf.len(), Ok(()))
                })
                .ok()?;
                Some(vec)
            }
        }
    }

    /// Like `drain_to_window_size` but writes straight into `sink` instead of
    /// allocating a `Vec`. Returns the number of bytes written.
    pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
        match self.can_drain_to_window_size() {
            None => Ok(0),
            Some(can_drain) => self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf)),
        }
    }

    /// drain the buffer completely, hashing everything that leaves it
    pub fn drain(&mut self) -> Vec<u8> {
        // The ring buffer may wrap, so its content is exposed as two slices.
        let (slice1, slice2) = self.buffer.as_slices();
        #[cfg(feature = "hash")]
        {
            self.hash.write(slice1);
            self.hash.write(slice2);
        }

        let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
        vec.extend_from_slice(slice1);
        vec.extend_from_slice(slice2);
        self.buffer.clear();
        vec
    }

    /// Drain the whole buffer into `sink`; returns the number of bytes written.
    pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
        let write_limit = self.buffer.len();
        self.drain_to(write_limit, |buf| write_all_bytes(&mut sink, buf))
    }

    /// Copy up to `target.len()` bytes out of the buffer (ignoring the
    /// window), removing them from the buffer. Returns the amount copied.
    pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
        let amount = self.buffer.len().min(target.len());

        let mut written = 0;
        self.drain_to(amount, |buf| {
            target[written..][..buf.len()].copy_from_slice(buf);
            written += buf.len();
            (buf.len(), Ok(()))
        })?;
        Ok(amount)
    }

    /// Common drain helper: feed up to `amount` buffered bytes to `write_bytes`
    /// and remove exactly the accepted bytes from the buffer, hashing them.
    ///
    /// Semantics of write_bytes:
    /// Should dump as many of the provided bytes as possible to whatever sink until no bytes are left or an error is encountered
    /// Return how many bytes have actually been dumped to the sink.
    fn drain_to(
        &mut self,
        amount: usize,
        mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
    ) -> Result<usize, Error> {
        if amount == 0 {
            return Ok(0);
        }

        // On drop, removes the bytes that were accepted by the sink from the
        // ring buffer — even if `write_bytes` panics or an error causes an
        // early return via `?` below.
        struct DrainGuard<'a> {
            buffer: &'a mut RingBuffer,
            amount: usize,
        }

        impl Drop for DrainGuard<'_> {
            fn drop(&mut self) {
                if self.amount != 0 {
                    self.buffer.drop_first_n(self.amount);
                }
            }
        }

        let mut drain_guard = DrainGuard {
            buffer: &mut self.buffer,
            amount: 0,
        };

        // Content may wrap around the ring buffer, hence two slices.
        let (slice1, slice2) = drain_guard.buffer.as_slices();
        let n1 = slice1.len().min(amount);
        let n2 = slice2.len().min(amount - n1);

        if n1 != 0 {
            let (written1, res1) = write_bytes(&slice1[..n1]);
            // Hash only the bytes the sink actually accepted.
            #[cfg(feature = "hash")]
            self.hash.write(&slice1[..written1]);
            drain_guard.amount += written1;

            // Apparently this is what clippy thinks is the best way of expressing this
            res1?;

            // Only if the first call to write_bytes was not a partial write we can continue with slice2
            // Partial writes SHOULD never happen without res1 being an error, but lets just protect against it anyways.
            if written1 == n1 && n2 != 0 {
                let (written2, res2) = write_bytes(&slice2[..n2]);
                #[cfg(feature = "hash")]
                self.hash.write(&slice2[..written2]);
                drain_guard.amount += written2;

                // Apparently this is what clippy thinks is the best way of expressing this
                res2?;
            }
        }

        let amount_written = drain_guard.amount;
        // Make sure we don't accidentally drop `DrainGuard` earlier.
        drop(drain_guard);

        Ok(amount_written)
    }
}
304
305/// Like Write::write_all but returns partial write length even on error
306fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
307    let mut written = 0;
308    while written < buf.len() {
309        match sink.write(&buf[written..]) {
310            Ok(0) => return (written, Ok(())),
311            Ok(w) => written += w,
312            Err(e) => return (written, Err(e)),
313        }
314    }
315    (written, Ok(()))
316}
317
#[cfg(test)]
mod tests {
    use super::DecodeBuffer;
    use crate::io::{Error, ErrorKind, Write};

    extern crate std;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Exercises the partial-write path of `drain_to`: the sink accepts at
    /// most `write_len` bytes per call, so every drain is split into several
    /// short writes, yet no bytes may be lost or duplicated.
    #[test]
    fn short_writer() {
        // Accepts at most `write_len` bytes per write call.
        struct ShortWriter {
            buf: Vec<u8>,
            write_len: usize,
        }

        impl Write for ShortWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                if buf.len() > self.write_len {
                    self.buf.extend_from_slice(&buf[..self.write_len]);
                    Ok(self.write_len)
                } else {
                    self.buf.extend_from_slice(buf);
                    Ok(buf.len())
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = ShortWriter {
            buf: vec![],
            write_len: 10,
        };

        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        // Fill the buffer up to exactly the window size.
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            // Only the 50 bytes beyond the window may be drained.
            decode_buf
                .drain_to_window_size_writer(&mut short_writer)
                .unwrap();
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(short_writer.buf.len(), repeats * 50);
        decode_buf.drain_to_writer(&mut short_writer).unwrap();
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }

    /// Exercises the error path of `drain_to`: the sink periodically fails
    /// with `WouldBlock`. Retrying must eventually drain everything without
    /// losing or duplicating bytes (the `DrainGuard` drops exactly the
    /// accepted bytes even when an error aborts the drain).
    #[test]
    fn wouldblock_writer() {
        // Fails with WouldBlock after every `block_every` successful writes.
        struct WouldblockWriter {
            buf: Vec<u8>,
            last_blocked: usize,
            block_every: usize,
        }

        impl Write for WouldblockWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                if self.last_blocked < self.block_every {
                    self.buf.extend_from_slice(buf);
                    self.last_blocked += 1;
                    Ok(buf.len())
                } else {
                    self.last_blocked = 0;
                    Err(Error::from(ErrorKind::WouldBlock))
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = WouldblockWriter {
            buf: vec![],
            last_blocked: 0,
            block_every: 5,
        };

        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            // Retry until the drainable excess has been fully written.
            loop {
                match decode_buf.drain_to_window_size_writer(&mut short_writer) {
                    Ok(written) => {
                        if written == 0 {
                            break;
                        }
                    }
                    Err(e) => {
                        if e.kind() == ErrorKind::WouldBlock {
                            continue;
                        } else {
                            panic!("Unexpected error {:?}", e);
                        }
                    }
                }
            }
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(short_writer.buf.len(), repeats * 50);
        // Drain the remaining window content, again retrying on WouldBlock.
        loop {
            match decode_buf.drain_to_writer(&mut short_writer) {
                Ok(written) => {
                    if written == 0 {
                        break;
                    }
                }
                Err(e) => {
                    if e.kind() == ErrorKind::WouldBlock {
                        continue;
                    } else {
                        panic!("Unexpected error {:?}", e);
                    }
                }
            }
        }
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }
}
451}