1use crate::io::{Error, Read, Write};
2use alloc::vec::Vec;
3#[cfg(feature = "hash")]
4use core::hash::Hasher;
5
6use super::ringbuffer::RingBuffer;
7
8pub struct DecodeBuffer {
9 buffer: RingBuffer,
10 pub dict_content: Vec<u8>,
11
12 pub window_size: usize,
13 total_output_counter: u64,
14 #[cfg(feature = "hash")]
15 pub hash: twox_hash::XxHash64,
16}
17
/// Errors raised while copying a match into a `DecodeBuffer`.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecodeBufferError {
    /// The match needs more bytes from the dictionary than it holds.
    ///
    /// `got` is the dictionary length, `need` is how many bytes were required.
    NotEnoughBytesInDictionary { got: usize, need: usize },
    /// The match offset points before the buffered output and the dictionary
    /// can no longer be used; `buf_len` is the buffered length at that time.
    OffsetTooBig { offset: usize, buf_len: usize },
}

#[cfg(feature = "std")]
impl std::error::Error for DecodeBufferError {}

impl core::fmt::Display for DecodeBufferError {
    /// Renders a human-readable description of the failed match copy.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::NotEnoughBytesInDictionary { got, need } => f.write_fmt(format_args!(
                "Need {} bytes from the dictionary but it is only {} bytes long",
                need, got
            )),
            Self::OffsetTooBig { offset, buf_len } => f.write_fmt(format_args!(
                "offset: {} bigger than buffer: {}",
                offset, buf_len
            )),
        }
    }
}
44
45impl Read for DecodeBuffer {
46 fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
47 let max_amount = self.can_drain_to_window_size().unwrap_or(0);
48 let amount = max_amount.min(target.len());
49
50 let mut written = 0;
51 self.drain_to(amount, |buf| {
52 target[written..][..buf.len()].copy_from_slice(buf);
53 written += buf.len();
54 (buf.len(), Ok(()))
55 })?;
56 Ok(amount)
57 }
58}
59
60impl DecodeBuffer {
61 pub fn new(window_size: usize) -> DecodeBuffer {
62 DecodeBuffer {
63 buffer: RingBuffer::new(),
64 dict_content: Vec::new(),
65 window_size,
66 total_output_counter: 0,
67 #[cfg(feature = "hash")]
68 hash: twox_hash::XxHash64::with_seed(0),
69 }
70 }
71
72 pub fn reset(&mut self, window_size: usize) {
73 self.window_size = window_size;
74 self.buffer.clear();
75 self.buffer.reserve(self.window_size);
76 self.dict_content.clear();
77 self.total_output_counter = 0;
78 #[cfg(feature = "hash")]
79 {
80 self.hash = twox_hash::XxHash64::with_seed(0);
81 }
82 }
83
84 pub fn len(&self) -> usize {
85 self.buffer.len()
86 }
87
88 pub fn is_empty(&self) -> bool {
89 self.buffer.is_empty()
90 }
91
92 pub fn push(&mut self, data: &[u8]) {
93 self.buffer.extend(data);
94 self.total_output_counter += data.len() as u64;
95 }
96
97 pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
98 if offset > self.buffer.len() {
99 self.repeat_from_dict(offset, match_length)
100 } else {
101 let buf_len = self.buffer.len();
102 let start_idx = buf_len - offset;
103 let end_idx = start_idx + match_length;
104
105 self.buffer.reserve(match_length);
106 if end_idx > buf_len {
107 self.repeat_in_chunks(offset, match_length, start_idx);
109 } else {
110 unsafe {
121 self.buffer
122 .extend_from_within_unchecked(start_idx, match_length)
123 };
124 }
125
126 self.total_output_counter += match_length as u64;
127 Ok(())
128 }
129 }
130
131 fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
132 let mut start_idx = start_idx;
134 let mut copied_counter_left = match_length;
135 while copied_counter_left > 0 {
138 let chunksize = usize::min(offset, copied_counter_left);
139
140 unsafe {
153 self.buffer
154 .extend_from_within_unchecked(start_idx, chunksize)
155 };
156 copied_counter_left -= chunksize;
157 start_idx += chunksize;
158 }
159 }
160
161 #[cold]
162 fn repeat_from_dict(
163 &mut self,
164 offset: usize,
165 match_length: usize,
166 ) -> Result<(), DecodeBufferError> {
167 if self.total_output_counter <= self.window_size as u64 {
168 let bytes_from_dict = offset - self.buffer.len();
170
171 if bytes_from_dict > self.dict_content.len() {
172 return Err(DecodeBufferError::NotEnoughBytesInDictionary {
173 got: self.dict_content.len(),
174 need: bytes_from_dict,
175 });
176 }
177
178 if bytes_from_dict < match_length {
179 let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
180 self.buffer.extend(dict_slice);
181
182 self.total_output_counter += bytes_from_dict as u64;
183 return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
184 } else {
185 let low = self.dict_content.len() - bytes_from_dict;
186 let high = low + match_length;
187 let dict_slice = &self.dict_content[low..high];
188 self.buffer.extend(dict_slice);
189 }
190 Ok(())
191 } else {
192 Err(DecodeBufferError::OffsetTooBig {
193 offset,
194 buf_len: self.buffer.len(),
195 })
196 }
197 }
198
199 pub fn can_drain_to_window_size(&self) -> Option<usize> {
201 if self.buffer.len() > self.window_size {
202 Some(self.buffer.len() - self.window_size)
203 } else {
204 None
205 }
206 }
207
208 pub fn can_drain(&self) -> usize {
210 self.buffer.len()
211 }
212
213 pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
216 match self.can_drain_to_window_size() {
218 None => None,
219 Some(can_drain) => {
220 let mut vec = Vec::with_capacity(can_drain);
221 self.drain_to(can_drain, |buf| {
222 vec.extend_from_slice(buf);
223 (buf.len(), Ok(()))
224 })
225 .ok()?;
226 Some(vec)
227 }
228 }
229 }
230
231 pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
232 match self.can_drain_to_window_size() {
233 None => Ok(0),
234 Some(can_drain) => self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf)),
235 }
236 }
237
238 pub fn drain(&mut self) -> Vec<u8> {
240 let (slice1, slice2) = self.buffer.as_slices();
241 #[cfg(feature = "hash")]
242 {
243 self.hash.write(slice1);
244 self.hash.write(slice2);
245 }
246
247 let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
248 vec.extend_from_slice(slice1);
249 vec.extend_from_slice(slice2);
250 self.buffer.clear();
251 vec
252 }
253
254 pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
255 let write_limit = self.buffer.len();
256 self.drain_to(write_limit, |buf| write_all_bytes(&mut sink, buf))
257 }
258
259 pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
260 let amount = self.buffer.len().min(target.len());
261
262 let mut written = 0;
263 self.drain_to(amount, |buf| {
264 target[written..][..buf.len()].copy_from_slice(buf);
265 written += buf.len();
266 (buf.len(), Ok(()))
267 })?;
268 Ok(amount)
269 }
270
271 fn drain_to(
275 &mut self,
276 amount: usize,
277 mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
278 ) -> Result<usize, Error> {
279 if amount == 0 {
280 return Ok(0);
281 }
282
283 struct DrainGuard<'a> {
284 buffer: &'a mut RingBuffer,
285 amount: usize,
286 }
287
288 impl<'a> Drop for DrainGuard<'a> {
289 fn drop(&mut self) {
290 if self.amount != 0 {
291 self.buffer.drop_first_n(self.amount);
292 }
293 }
294 }
295
296 let mut drain_guard = DrainGuard {
297 buffer: &mut self.buffer,
298 amount: 0,
299 };
300
301 let (slice1, slice2) = drain_guard.buffer.as_slices();
302 let n1 = slice1.len().min(amount);
303 let n2 = slice2.len().min(amount - n1);
304
305 if n1 != 0 {
306 let (written1, res1) = write_bytes(&slice1[..n1]);
307 #[cfg(feature = "hash")]
308 self.hash.write(&slice1[..written1]);
309 drain_guard.amount += written1;
310
311 res1?;
313
314 if written1 == n1 && n2 != 0 {
317 let (written2, res2) = write_bytes(&slice2[..n2]);
318 #[cfg(feature = "hash")]
319 self.hash.write(&slice2[..written2]);
320 drain_guard.amount += written2;
321
322 res2?;
324 }
325 }
326
327 let amount_written = drain_guard.amount;
328 drop(drain_guard);
330
331 Ok(amount_written)
332 }
333}
334
/// Writes as much of `buf` to `sink` as it will take.
///
/// Returns the number of bytes accepted together with the outcome. A write
/// that accepts 0 bytes ends the attempt with `Ok(())` and a partial count.
/// An error ends it with `Err`, carrying the count accepted before the error.
fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
    let mut done = 0usize;
    while let Some(rest) = buf.get(done..).filter(|rest| !rest.is_empty()) {
        match sink.write(rest) {
            Ok(0) => break,
            Ok(accepted) => done += accepted,
            Err(err) => return (done, Err(err)),
        }
    }
    (done, Ok(()))
}
347
#[cfg(test)]
mod tests {
    use super::DecodeBuffer;
    use crate::io::{Error, ErrorKind, Write};

    extern crate std;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Builds a buffer with a 100-byte window that is exactly full, holding the
    /// pattern `0123456789` repeated ten times.
    fn primed_buffer() -> DecodeBuffer {
        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        decode_buf
    }

    /// Calls `drain` until it reports that nothing more was written.
    ///
    /// `WouldBlock` counts as transient; any other error fails the test.
    fn drain_until_done(mut drain: impl FnMut() -> std::result::Result<usize, Error>) {
        loop {
            match drain() {
                Ok(0) => break,
                Ok(_) => {}
                Err(e) if e.kind() == ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error {:?}", e),
            }
        }
    }

    #[test]
    fn short_writer() {
        /// Accepts at most `write_len` bytes per `write` call.
        struct ShortWriter {
            buf: Vec<u8>,
            write_len: usize,
        }

        impl Write for ShortWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                let accepted = buf.len().min(self.write_len);
                self.buf.extend_from_slice(&buf[..accepted]);
                Ok(accepted)
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut sink = ShortWriter {
            buf: vec![],
            write_len: 10,
        };
        let mut decode_buf = primed_buffer();

        let rounds = 1000;
        for _ in 0..rounds {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            decode_buf.drain_to_window_size_writer(&mut sink).unwrap();
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(sink.buf.len(), rounds * 50);
        decode_buf.drain_to_writer(&mut sink).unwrap();
        assert_eq!(sink.buf.len(), rounds * 50 + 100);
    }

    #[test]
    fn wouldblock_writer() {
        /// Accepts `block_every` writes in full, then fails once with `WouldBlock`.
        struct WouldblockWriter {
            buf: Vec<u8>,
            last_blocked: usize,
            block_every: usize,
        }

        impl Write for WouldblockWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                if self.last_blocked >= self.block_every {
                    self.last_blocked = 0;
                    return Err(Error::from(ErrorKind::WouldBlock));
                }
                self.last_blocked += 1;
                self.buf.extend_from_slice(buf);
                Ok(buf.len())
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut sink = WouldblockWriter {
            buf: vec![],
            last_blocked: 0,
            block_every: 5,
        };
        let mut decode_buf = primed_buffer();

        let rounds = 1000;
        for _ in 0..rounds {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            drain_until_done(|| decode_buf.drain_to_window_size_writer(&mut sink));
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(sink.buf.len(), rounds * 50);
        drain_until_done(|| decode_buf.drain_to_writer(&mut sink));
        assert_eq!(sink.buf.len(), rounds * 50 + 100);
    }
}