ruzstd/decoding/
literals_section_decoder.rs

//! This module contains the `decode_literals` function, which takes a parsed literals
//! section header and its source bytes and decodes (decompressing if necessary) the literals.
3
4use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
5use super::scratch::HuffmanScratch;
6use crate::bit_io::BitReaderReversed;
7use crate::decoding::errors::DecompressLiteralsError;
8use crate::huff0::HuffmanDecoder;
9use alloc::vec::Vec;
10
11/// Decode and decompress the provided literals section into `target`, returning the number of bytes read.
12pub fn decode_literals(
13    section: &LiteralsSection,
14    scratch: &mut HuffmanScratch,
15    source: &[u8],
16    target: &mut Vec<u8>,
17) -> Result<u32, DecompressLiteralsError> {
18    match section.ls_type {
19        LiteralsSectionType::Raw => {
20            target.extend(&source[0..section.regenerated_size as usize]);
21            Ok(section.regenerated_size)
22        }
23        LiteralsSectionType::RLE => {
24            target.resize(target.len() + section.regenerated_size as usize, source[0]);
25            Ok(1)
26        }
27        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
28            let bytes_read = decompress_literals(section, scratch, source, target)?;
29
30            //return sum of used bytes
31            Ok(bytes_read)
32        }
33    }
34}
35
36/// Decompress the provided literals section and source into the provided `target`.
37/// This function is used when the literals section is `Compressed` or `Treeless`
38///
39/// Returns the number of bytes read.
40fn decompress_literals(
41    section: &LiteralsSection,
42    scratch: &mut HuffmanScratch,
43    source: &[u8],
44    target: &mut Vec<u8>,
45) -> Result<u32, DecompressLiteralsError> {
46    use DecompressLiteralsError as err;
47
48    let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
49    let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;
50
51    target.reserve(section.regenerated_size as usize);
52    let source = &source[0..compressed_size];
53    let mut bytes_read = 0;
54
55    match section.ls_type {
56        LiteralsSectionType::Compressed => {
57            //read Huffman tree description
58            bytes_read += scratch.table.build_decoder(source)?;
59            vprintln!("Built huffman table using {} bytes", bytes_read);
60        }
61        LiteralsSectionType::Treeless => {
62            if scratch.table.max_num_bits == 0 {
63                return Err(err::UninitializedHuffmanTable);
64            }
65        }
66        _ => { /* nothing to do, huffman tree has been provided by previous block */ }
67    }
68
69    let source = &source[bytes_read as usize..];
70
71    if num_streams == 4 {
72        //build jumptable
73        if source.len() < 6 {
74            return Err(err::MissingBytesForJumpHeader { got: source.len() });
75        }
76        let jump1 = source[0] as usize + ((source[1] as usize) << 8);
77        let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
78        let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
79        bytes_read += 6;
80        let source = &source[6..];
81
82        if source.len() < jump3 {
83            return Err(err::MissingBytesForLiterals {
84                got: source.len(),
85                needed: jump3,
86            });
87        }
88
89        //decode 4 streams
90        let stream1 = &source[..jump1];
91        let stream2 = &source[jump1..jump2];
92        let stream3 = &source[jump2..jump3];
93        let stream4 = &source[jump3..];
94
95        for stream in &[stream1, stream2, stream3, stream4] {
96            let mut decoder = HuffmanDecoder::new(&scratch.table);
97            let mut br = BitReaderReversed::new(stream);
98            //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
99            let mut skipped_bits = 0;
100            loop {
101                let val = br.get_bits(1);
102                skipped_bits += 1;
103                if val == 1 || skipped_bits > 8 {
104                    break;
105                }
106            }
107            if skipped_bits > 8 {
108                //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
109                return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
110            }
111            decoder.init_state(&mut br);
112
113            while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
114                target.push(decoder.decode_symbol());
115                decoder.next_state(&mut br);
116            }
117            if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
118                return Err(DecompressLiteralsError::BitstreamReadMismatch {
119                    read_til: br.bits_remaining(),
120                    expected: -(scratch.table.max_num_bits as isize),
121                });
122            }
123        }
124
125        bytes_read += source.len() as u32;
126    } else {
127        //just decode the one stream
128        assert!(num_streams == 1);
129        let mut decoder = HuffmanDecoder::new(&scratch.table);
130        let mut br = BitReaderReversed::new(source);
131        let mut skipped_bits = 0;
132        loop {
133            let val = br.get_bits(1);
134            skipped_bits += 1;
135            if val == 1 || skipped_bits > 8 {
136                break;
137            }
138        }
139        if skipped_bits > 8 {
140            //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
141            return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
142        }
143        decoder.init_state(&mut br);
144        while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
145            target.push(decoder.decode_symbol());
146            decoder.next_state(&mut br);
147        }
148        bytes_read += source.len() as u32;
149    }
150
151    if target.len() != section.regenerated_size as usize {
152        return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
153            decoded: target.len(),
154            expected: section.regenerated_size as usize,
155        });
156    }
157
158    Ok(bytes_read)
159}