1use alloc::vec::Vec;
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12 block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13 match_generator::MatchGeneratorDriver, CompressionLevel, Matcher,
14};
15use crate::fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable};
16
17use crate::io::{Read, Write};
18
/// Compresses data read from a source `R` into a frame written to a drain `W`.
///
/// Source and drain are optional so they can be attached, swapped and taken back
/// independently (see `set_source`, `set_drain`, `take_source`, `take_drain`).
/// The matcher `M` is the match-finding strategy handed to the block compressor.
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// Where the uncompressed input is read from; `None` until `set_source` is called.
    uncompressed_data: Option<R>,
    /// Where the compressed frame is written to; `None` until `set_drain` is called.
    compressed_data: Option<W>,
    /// The level used by `compress` to pick the block encoding strategy.
    compression_level: CompressionLevel,
    /// Matcher and cached entropy tables shared with the block compressor.
    state: CompressState<M>,
    /// Hasher fed with the uncompressed data; `compress` writes the low 32 bits of its
    /// digest at the end of the frame.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
46
/// The FSE tables kept by the compressor: one default table and one optional
/// "previous" table per symbol kind.
///
/// NOTE(review): `ll`, `ml` and `of` presumably stand for literal lengths, match
/// lengths and offsets as in the Zstandard format — confirm against the block encoder.
pub(crate) struct FseTables {
    /// Default table, built by `default_ll_table()`.
    pub(crate) ll_default: FSETable,
    /// Previously used `ll` table, `None` if there is none.
    pub(crate) ll_previous: Option<FSETable>,
    /// Default table, built by `default_ml_table()`.
    pub(crate) ml_default: FSETable,
    /// Previously used `ml` table, `None` if there is none.
    pub(crate) ml_previous: Option<FSETable>,
    /// Default table, built by `default_of_table()`.
    pub(crate) of_default: FSETable,
    /// Previously used `of` table, `None` if there is none.
    pub(crate) of_previous: Option<FSETable>,
}
55
56impl FseTables {
57 pub fn new() -> Self {
58 Self {
59 ll_default: default_ll_table(),
60 ll_previous: None,
61 ml_default: default_ml_table(),
62 ml_previous: None,
63 of_default: default_of_table(),
64 of_previous: None,
65 }
66 }
67}
68
/// The mutable state `FrameCompressor::compress` hands to the block compressor.
pub(crate) struct CompressState<M: Matcher> {
    /// The match-finding strategy; also provides the buffers the input is read into.
    pub(crate) matcher: M,
    /// Huffman table remembered from an earlier block, `None` if there is none.
    /// `compress` clears it at the start of every frame.
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Default and previously used FSE tables.
    pub(crate) fse_tables: FseTables,
}
74
75impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
76 pub fn new(compression_level: CompressionLevel) -> Self {
78 Self {
79 uncompressed_data: None,
80 compressed_data: None,
81 compression_level,
82 state: CompressState {
83 matcher: MatchGeneratorDriver::new(1024 * 128, 1),
84 last_huff_table: None,
85 fse_tables: FseTables::new(),
86 },
87 #[cfg(feature = "hash")]
88 hasher: XxHash64::with_seed(0),
89 }
90 }
91}
92
93impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
94 pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
96 Self {
97 uncompressed_data: None,
98 compressed_data: None,
99 state: CompressState {
100 matcher,
101 last_huff_table: None,
102 fse_tables: FseTables::new(),
103 },
104 compression_level,
105 #[cfg(feature = "hash")]
106 hasher: XxHash64::with_seed(0),
107 }
108 }
109
110 pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
114 self.uncompressed_data.replace(uncompressed_data)
115 }
116
117 pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
121 self.compressed_data.replace(compressed_data)
122 }
123
124 pub fn compress(&mut self) {
132 self.state.matcher.reset(self.compression_level);
134 self.state.last_huff_table = None;
135 let source = self.uncompressed_data.as_mut().unwrap();
136 let drain = self.compressed_data.as_mut().unwrap();
137 let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
139 let header = FrameHeader {
141 frame_content_size: None,
142 single_segment: false,
143 content_checksum: cfg!(feature = "hash"),
144 dictionary_id: None,
145 window_size: Some(self.state.matcher.window_size()),
146 };
147 header.serialize(output);
148 loop {
150 let mut uncompressed_data = self.state.matcher.get_next_space();
152 let mut read_bytes = 0;
153 let last_block;
154 'read_loop: loop {
155 let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
156 if new_bytes == 0 {
157 last_block = true;
158 break 'read_loop;
159 }
160 read_bytes += new_bytes;
161 if read_bytes == uncompressed_data.len() {
162 last_block = false;
163 break 'read_loop;
164 }
165 }
166 uncompressed_data.resize(read_bytes, 0);
167 #[cfg(feature = "hash")]
169 self.hasher.write(&uncompressed_data);
170 if uncompressed_data.is_empty() {
172 let header = BlockHeader {
173 last_block: true,
174 block_type: crate::blocks::block::BlockType::Raw,
175 block_size: 0,
176 };
177 header.serialize(output);
179 drain.write_all(output).unwrap();
180 output.clear();
181 break;
182 }
183
184 match self.compression_level {
185 CompressionLevel::Uncompressed => {
186 let header = BlockHeader {
187 last_block,
188 block_type: crate::blocks::block::BlockType::Raw,
189 block_size: read_bytes.try_into().unwrap(),
190 };
191 header.serialize(output);
193 output.extend_from_slice(&uncompressed_data);
194 }
195 CompressionLevel::Fastest => {
196 compress_fastest(&mut self.state, last_block, uncompressed_data, output)
197 }
198 _ => {
199 unimplemented!();
200 }
201 }
202 drain.write_all(output).unwrap();
203 output.clear();
204 if last_block {
205 break;
206 }
207 }
208
209 #[cfg(feature = "hash")]
212 {
213 let content_checksum = self.hasher.finish();
216 drain
217 .write_all(&(content_checksum as u32).to_le_bytes())
218 .unwrap();
219 }
220 }
221
222 pub fn source_mut(&mut self) -> Option<&mut R> {
224 self.uncompressed_data.as_mut()
225 }
226
227 pub fn drain_mut(&mut self) -> Option<&mut W> {
229 self.compressed_data.as_mut()
230 }
231
232 pub fn source(&self) -> Option<&R> {
234 self.uncompressed_data.as_ref()
235 }
236
237 pub fn drain(&self) -> Option<&W> {
239 self.compressed_data.as_ref()
240 }
241
242 pub fn take_source(&mut self) -> Option<R> {
244 self.uncompressed_data.take()
245 }
246
247 pub fn take_drain(&mut self) -> Option<W> {
249 self.compressed_data.take()
250 }
251
252 pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
254 core::mem::swap(&mut match_generator, &mut self.state.matcher);
255 match_generator
256 }
257
258 pub fn set_compression_level(
260 &mut self,
261 compression_level: CompressionLevel,
262 ) -> CompressionLevel {
263 let old = self.compression_level;
264 self.compression_level = compression_level;
265 old
266 }
267
268 pub fn compression_level(&self) -> CompressionLevel {
270 self.compression_level
271 }
272}
273
274#[cfg(test)]
275mod tests {
276 use alloc::vec;
277
278 use super::FrameCompressor;
279 use crate::common::MAGIC_NUM;
280 use crate::decoding::FrameDecoder;
281 use alloc::vec::Vec;
282
/// The first bytes of every frame must be the little-endian magic number.
#[test]
fn frame_starts_with_magic_num() {
    let input: &[u8] = &[1, 2, 3];
    let mut frame: Vec<u8> = Vec::new();

    let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
    compressor.set_source(input);
    compressor.set_drain(&mut frame);
    compressor.compress();

    let magic = MAGIC_NUM.to_le_bytes();
    assert!(frame.starts_with(&magic));
}
294
/// A tiny input stored as raw blocks must decode back to the original bytes.
#[test]
fn very_simple_raw_compress() {
    let mock_data = [1_u8, 2, 3].as_slice();
    let mut output: Vec<u8> = Vec::new();
    let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
    compressor.set_source(mock_data);
    compressor.set_drain(&mut output);

    compressor.compress();

    // Verify the produced frame instead of only checking that nothing panicked.
    // `decode_all_to_vec` is given a vector with enough capacity for the decoded
    // data, like in the other round-trip tests of this module.
    let mut decoder = FrameDecoder::new();
    let mut decoded = Vec::with_capacity(mock_data.len());
    decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
    assert_eq!(mock_data, decoded.as_slice());
}
305
/// Several long runs of identical bytes must survive a round trip through both the
/// crate's own decoder and the reference `zstd` decoder.
#[test]
fn very_simple_compress() {
    // (byte value, run length) pairs; the lengths deliberately straddle powers of two.
    let runs: [(u8, usize); 5] = [
        (0, 1 << 17),
        (1, (1 << 17) - 1),
        (2, (1 << 18) - 1),
        (2, 1 << 17),
        (3, (1 << 17) - 1),
    ];
    let mut mock_data: Vec<u8> = Vec::new();
    for &(byte, run_len) in runs.iter() {
        mock_data.resize(mock_data.len() + run_len, byte);
    }

    let mut frame: Vec<u8> = Vec::new();
    let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
    compressor.set_source(mock_data.as_slice());
    compressor.set_drain(&mut frame);
    compressor.compress();

    // Round trip through the decoder of this crate.
    let mut own_decoded = Vec::with_capacity(mock_data.len());
    FrameDecoder::new()
        .decode_all_to_vec(&frame, &mut own_decoded)
        .unwrap();
    assert_eq!(mock_data, own_decoded);

    // Round trip through the reference implementation.
    let mut reference_decoded = Vec::new();
    zstd::stream::copy_decode(frame.as_slice(), &mut reference_decoded).unwrap();
    assert_eq!(mock_data, reference_decoded);
}
329
/// Half a MiB of zeroes must survive a round trip through the crate's own decoder.
#[test]
fn rle_compress() {
    let zeroes = vec![0; 1 << 19];
    let mut frame: Vec<u8> = Vec::new();

    let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
    compressor.set_source(zeroes.as_slice());
    compressor.set_drain(&mut frame);
    compressor.compress();

    let mut round_trip = Vec::with_capacity(zeroes.len());
    FrameDecoder::new()
        .decode_all_to_vec(&frame, &mut round_trip)
        .unwrap();
    assert_eq!(zeroes, round_trip);
}
345
/// A handful of bytes must survive a round trip through both the crate's own decoder
/// and the reference `zstd` decoder.
#[test]
fn aaa_compress() {
    let original = vec![0, 1, 3, 4, 5];
    let mut frame: Vec<u8> = Vec::new();

    let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
    compressor.set_source(original.as_slice());
    compressor.set_drain(&mut frame);
    compressor.compress();

    // Round trip through the decoder of this crate.
    let mut own_decoded = Vec::with_capacity(original.len());
    FrameDecoder::new()
        .decode_all_to_vec(&frame, &mut own_decoded)
        .unwrap();
    assert_eq!(original, own_decoded);

    // Round trip through the reference implementation.
    let mut reference_decoded = Vec::new();
    zstd::stream::copy_decode(frame.as_slice(), &mut reference_decoded).unwrap();
    assert_eq!(original, reference_decoded);
}
365
/// Replays the artifacts of the interop fuzz target, if there are any, as a
/// regression test: data encoded by `zstd` must be decodable by this crate, and data
/// encoded by this crate must be decodable by `zstd`.
#[cfg(feature = "std")]
#[test]
fn fuzz_targets() {
    use std::io::Read;

    /// Decode with the streaming decoder of this crate.
    fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec<u8> {
        let mut stream = crate::decoding::StreamingDecoder::new(data).unwrap();
        let mut decoded: Vec<u8> = Vec::new();
        stream.read_to_end(&mut decoded).expect("Decoding failed");
        decoded
    }

    /// Decode with the frame decoder of this crate, collecting into a writer.
    fn decode_ruzstd_writer(mut data: impl Read) -> Vec<u8> {
        let mut frame_decoder = crate::decoding::FrameDecoder::new();
        frame_decoder.reset(&mut data).unwrap();
        let mut decoded = vec![];
        loop {
            if frame_decoder.is_finished() && frame_decoder.can_collect() == 0 {
                break;
            }
            let strategy = crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024);
            frame_decoder.decode_blocks(&mut data, strategy).unwrap();
            frame_decoder.collect_to_writer(&mut decoded).unwrap();
        }
        decoded
    }

    /// Encode with the reference implementation.
    fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        let cursor = std::io::Cursor::new(data);
        zstd::stream::encode_all(cursor, 3)
    }

    /// Encode with this crate at the given level.
    fn encode_ruzstd(
        data: &mut dyn std::io::Read,
        level: crate::encoding::CompressionLevel,
    ) -> Vec<u8> {
        let mut input = Vec::new();
        data.read_to_end(&mut input).unwrap();

        crate::encoding::compress_to_vec(input.as_slice(), level)
    }

    /// Decode with the reference implementation.
    fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(data, &mut decoded)?;
        Ok(decoded)
    }

    // Without recorded artifacts there is nothing to replay.
    if !std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
        return;
    }

    for entry in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
        let entry = entry.unwrap();
        if !entry.file_type().unwrap().is_file() {
            continue;
        }
        let contents = std::fs::read(entry.path()).unwrap();
        let data = contents.as_slice();

        // Decompression: zstd -> this crate (both decoder front ends).
        let reference_frame = encode_zstd(data).unwrap();
        let streamed = decode_ruzstd(&mut reference_frame.as_slice());
        let collected = decode_ruzstd_writer(&mut reference_frame.as_slice());
        assert!(
            streamed == data,
            "Decoded data did not match the original input during decompression"
        );
        assert_eq!(
            collected, data,
            "Decoded data did not match the original input during decompression"
        );

        // Compression: this crate (raw blocks) -> zstd.
        let mut reader = data;
        let raw_frame = encode_ruzstd(
            &mut reader,
            crate::encoding::CompressionLevel::Uncompressed,
        );
        let raw_round_trip = decode_zstd(&raw_frame).unwrap();
        assert_eq!(
            raw_round_trip, data,
            "Decoded data did not match the original input during compression"
        );

        // Compression: this crate (fastest level) -> zstd.
        let mut reader = data;
        let fast_frame = encode_ruzstd(&mut reader, crate::encoding::CompressionLevel::Fastest);
        let fast_round_trip = decode_zstd(&fast_frame).unwrap();
        assert_eq!(
            fast_round_trip, data,
            "Decoded data did not match the original input during compression"
        );
    }
}
461}