1use crate::io::{Error, Read, Write};
2use alloc::vec::Vec;
3#[cfg(feature = "hash")]
4use core::hash::Hasher;
5
6use super::ringbuffer::RingBuffer;
7use crate::decoding::errors::DecodeBufferError;
8
9pub struct DecodeBuffer {
10 buffer: RingBuffer,
11 pub dict_content: Vec<u8>,
12
13 pub window_size: usize,
14 total_output_counter: u64,
15 #[cfg(feature = "hash")]
16 pub hash: twox_hash::XxHash64,
17}
18
19impl Read for DecodeBuffer {
20 fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
21 let max_amount = self.can_drain_to_window_size().unwrap_or(0);
22 let amount = max_amount.min(target.len());
23
24 let mut written = 0;
25 self.drain_to(amount, |buf| {
26 target[written..][..buf.len()].copy_from_slice(buf);
27 written += buf.len();
28 (buf.len(), Ok(()))
29 })?;
30 Ok(amount)
31 }
32}
33
34impl DecodeBuffer {
35 pub fn new(window_size: usize) -> DecodeBuffer {
36 DecodeBuffer {
37 buffer: RingBuffer::new(),
38 dict_content: Vec::new(),
39 window_size,
40 total_output_counter: 0,
41 #[cfg(feature = "hash")]
42 hash: twox_hash::XxHash64::with_seed(0),
43 }
44 }
45
46 pub fn reset(&mut self, window_size: usize) {
47 self.window_size = window_size;
48 self.buffer.clear();
49 self.buffer.reserve(self.window_size);
50 self.dict_content.clear();
51 self.total_output_counter = 0;
52 #[cfg(feature = "hash")]
53 {
54 self.hash = twox_hash::XxHash64::with_seed(0);
55 }
56 }
57
58 pub fn len(&self) -> usize {
59 self.buffer.len()
60 }
61
62 pub fn push(&mut self, data: &[u8]) {
63 self.buffer.extend(data);
64 self.total_output_counter += data.len() as u64;
65 }
66
67 pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
68 if offset > self.buffer.len() {
69 self.repeat_from_dict(offset, match_length)
70 } else {
71 let buf_len = self.buffer.len();
72 let start_idx = buf_len - offset;
73 let end_idx = start_idx + match_length;
74
75 self.buffer.reserve(match_length);
76 if end_idx > buf_len {
77 self.repeat_in_chunks(offset, match_length, start_idx);
79 } else {
80 unsafe {
91 self.buffer
92 .extend_from_within_unchecked(start_idx, match_length)
93 };
94 }
95
96 self.total_output_counter += match_length as u64;
97 Ok(())
98 }
99 }
100
101 fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
102 let mut start_idx = start_idx;
104 let mut copied_counter_left = match_length;
105 while copied_counter_left > 0 {
108 let chunksize = usize::min(offset, copied_counter_left);
109
110 unsafe {
123 self.buffer
124 .extend_from_within_unchecked(start_idx, chunksize)
125 };
126 copied_counter_left -= chunksize;
127 start_idx += chunksize;
128 }
129 }
130
131 #[cold]
132 fn repeat_from_dict(
133 &mut self,
134 offset: usize,
135 match_length: usize,
136 ) -> Result<(), DecodeBufferError> {
137 if self.total_output_counter <= self.window_size as u64 {
138 let bytes_from_dict = offset - self.buffer.len();
140
141 if bytes_from_dict > self.dict_content.len() {
142 return Err(DecodeBufferError::NotEnoughBytesInDictionary {
143 got: self.dict_content.len(),
144 need: bytes_from_dict,
145 });
146 }
147
148 if bytes_from_dict < match_length {
149 let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
150 self.buffer.extend(dict_slice);
151
152 self.total_output_counter += bytes_from_dict as u64;
153 return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
154 } else {
155 let low = self.dict_content.len() - bytes_from_dict;
156 let high = low + match_length;
157 let dict_slice = &self.dict_content[low..high];
158 self.buffer.extend(dict_slice);
159 }
160 Ok(())
161 } else {
162 Err(DecodeBufferError::OffsetTooBig {
163 offset,
164 buf_len: self.buffer.len(),
165 })
166 }
167 }
168
169 pub fn can_drain_to_window_size(&self) -> Option<usize> {
171 if self.buffer.len() > self.window_size {
172 Some(self.buffer.len() - self.window_size)
173 } else {
174 None
175 }
176 }
177
178 pub fn can_drain(&self) -> usize {
180 self.buffer.len()
181 }
182
183 pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
186 match self.can_drain_to_window_size() {
188 None => None,
189 Some(can_drain) => {
190 let mut vec = Vec::with_capacity(can_drain);
191 self.drain_to(can_drain, |buf| {
192 vec.extend_from_slice(buf);
193 (buf.len(), Ok(()))
194 })
195 .ok()?;
196 Some(vec)
197 }
198 }
199 }
200
201 pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
202 match self.can_drain_to_window_size() {
203 None => Ok(0),
204 Some(can_drain) => self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf)),
205 }
206 }
207
208 pub fn drain(&mut self) -> Vec<u8> {
210 let (slice1, slice2) = self.buffer.as_slices();
211 #[cfg(feature = "hash")]
212 {
213 self.hash.write(slice1);
214 self.hash.write(slice2);
215 }
216
217 let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
218 vec.extend_from_slice(slice1);
219 vec.extend_from_slice(slice2);
220 self.buffer.clear();
221 vec
222 }
223
224 pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
225 let write_limit = self.buffer.len();
226 self.drain_to(write_limit, |buf| write_all_bytes(&mut sink, buf))
227 }
228
229 pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
230 let amount = self.buffer.len().min(target.len());
231
232 let mut written = 0;
233 self.drain_to(amount, |buf| {
234 target[written..][..buf.len()].copy_from_slice(buf);
235 written += buf.len();
236 (buf.len(), Ok(()))
237 })?;
238 Ok(amount)
239 }
240
241 fn drain_to(
245 &mut self,
246 amount: usize,
247 mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
248 ) -> Result<usize, Error> {
249 if amount == 0 {
250 return Ok(0);
251 }
252
253 struct DrainGuard<'a> {
254 buffer: &'a mut RingBuffer,
255 amount: usize,
256 }
257
258 impl Drop for DrainGuard<'_> {
259 fn drop(&mut self) {
260 if self.amount != 0 {
261 self.buffer.drop_first_n(self.amount);
262 }
263 }
264 }
265
266 let mut drain_guard = DrainGuard {
267 buffer: &mut self.buffer,
268 amount: 0,
269 };
270
271 let (slice1, slice2) = drain_guard.buffer.as_slices();
272 let n1 = slice1.len().min(amount);
273 let n2 = slice2.len().min(amount - n1);
274
275 if n1 != 0 {
276 let (written1, res1) = write_bytes(&slice1[..n1]);
277 #[cfg(feature = "hash")]
278 self.hash.write(&slice1[..written1]);
279 drain_guard.amount += written1;
280
281 res1?;
283
284 if written1 == n1 && n2 != 0 {
287 let (written2, res2) = write_bytes(&slice2[..n2]);
288 #[cfg(feature = "hash")]
289 self.hash.write(&slice2[..written2]);
290 drain_guard.amount += written2;
291
292 res2?;
294 }
295 }
296
297 let amount_written = drain_guard.amount;
298 drop(drain_guard);
300
301 Ok(amount_written)
302 }
303}
304
/// Keeps calling `sink.write` until all of `buf` has been accepted, the sink accepts
/// zero bytes, or the sink reports an error.
///
/// Returns how many bytes were accepted together with the outcome. A sink that
/// accepts zero bytes ends the attempt with `Ok(())` and a short count rather than an
/// error, so the caller can tell from the count how far the write got.
fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
    let mut accepted = 0;

    let outcome = loop {
        if accepted >= buf.len() {
            break Ok(());
        }
        match sink.write(&buf[accepted..]) {
            Ok(0) => break Ok(()),
            Ok(count) => accepted += count,
            Err(err) => break Err(err),
        }
    };

    (accepted, outcome)
}
317
#[cfg(test)]
mod tests {
    use super::DecodeBuffer;
    use crate::io::{Error, ErrorKind, Write};

    extern crate std;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Calls `attempt` until it reports that nothing was written. `WouldBlock` errors
    /// are retried; any other error fails the test.
    fn retry_until_nothing_written(mut attempt: impl FnMut() -> std::result::Result<usize, Error>) {
        loop {
            match attempt() {
                Ok(0) => break,
                Ok(_) => {}
                Err(e) => {
                    if e.kind() != ErrorKind::WouldBlock {
                        panic!("Unexpected error {:?}", e);
                    }
                }
            }
        }
    }

    /// Draining must cope with a sink that accepts only a few bytes per `write` call.
    #[test]
    fn short_writer() {
        struct ShortWriter {
            buf: Vec<u8>,
            write_len: usize,
        }

        impl Write for ShortWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                // Accept at most `write_len` bytes per call.
                let accepted = buf.len().min(self.write_len);
                self.buf.extend_from_slice(&buf[..accepted]);
                Ok(accepted)
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut sink = ShortWriter {
            buf: vec![],
            write_len: 10,
        };

        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();

        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            decode_buf.drain_to_window_size_writer(&mut sink).unwrap();
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(sink.buf.len(), repeats * 50);
        decode_buf.drain_to_writer(&mut sink).unwrap();
        assert_eq!(sink.buf.len(), repeats * 50 + 100);
    }

    /// Draining must not lose or duplicate bytes when the sink periodically fails
    /// with `WouldBlock` and the drain is retried.
    #[test]
    fn wouldblock_writer() {
        struct WouldblockWriter {
            buf: Vec<u8>,
            last_blocked: usize,
            block_every: usize,
        }

        impl Write for WouldblockWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                // After `block_every` successful writes, fail once and start over.
                if self.last_blocked >= self.block_every {
                    self.last_blocked = 0;
                    return Err(Error::from(ErrorKind::WouldBlock));
                }

                self.buf.extend_from_slice(buf);
                self.last_blocked += 1;
                Ok(buf.len())
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut sink = WouldblockWriter {
            buf: vec![],
            last_blocked: 0,
            block_every: 5,
        };

        let mut decode_buf = DecodeBuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();

        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            retry_until_nothing_written(|| decode_buf.drain_to_window_size_writer(&mut sink));
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(sink.buf.len(), repeats * 50);
        retry_until_nothing_written(|| decode_buf.drain_to_writer(&mut sink));
        assert_eq!(sink.buf.len(), repeats * 50 + 100);
    }
}