ruzstd/decoding/
literals_section_decoder.rs

//! This module contains the [decode_literals] function, used to take a
//! parsed literals header and a source and decode the literals it describes.
3
4use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
5use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
6use super::scratch::HuffmanScratch;
7use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError};
8use alloc::vec::Vec;
9
10#[derive(Debug)]
11#[non_exhaustive]
12pub enum DecompressLiteralsError {
13    MissingCompressedSize,
14    MissingNumStreams,
15    GetBitsError(GetBitsError),
16    HuffmanTableError(HuffmanTableError),
17    HuffmanDecoderError(HuffmanDecoderError),
18    UninitializedHuffmanTable,
19    MissingBytesForJumpHeader { got: usize },
20    MissingBytesForLiterals { got: usize, needed: usize },
21    ExtraPadding { skipped_bits: i32 },
22    BitstreamReadMismatch { read_til: isize, expected: isize },
23    DecodedLiteralCountMismatch { decoded: usize, expected: usize },
24}
25
#[cfg(feature = "std")]
impl std::error::Error for DecompressLiteralsError {
    /// Returns the wrapped lower-level error for the variants that carry one,
    /// and `None` for every other variant.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::GetBitsError(cause) => cause,
            Self::HuffmanTableError(cause) => cause,
            Self::HuffmanDecoderError(cause) => cause,
            _ => return None,
        };
        Some(inner)
    }
}
37impl core::fmt::Display for DecompressLiteralsError {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        match self {
40            DecompressLiteralsError::MissingCompressedSize => {
41                write!(f,
42                    "compressed size was none even though it must be set to something for compressed literals",
43                )
44            }
45            DecompressLiteralsError::MissingNumStreams => {
46                write!(f,
47                    "num_streams was none even though it must be set to something (1 or 4) for compressed literals",
48                )
49            }
50            DecompressLiteralsError::GetBitsError(e) => write!(f, "{:?}", e),
51            DecompressLiteralsError::HuffmanTableError(e) => write!(f, "{:?}", e),
52            DecompressLiteralsError::HuffmanDecoderError(e) => write!(f, "{:?}", e),
53            DecompressLiteralsError::UninitializedHuffmanTable => {
54                write!(
55                    f,
56                    "Tried to reuse huffman table but it was never initialized",
57                )
58            }
59            DecompressLiteralsError::MissingBytesForJumpHeader { got } => {
60                write!(f, "Need 6 bytes to decode jump header, got {} bytes", got,)
61            }
62            DecompressLiteralsError::MissingBytesForLiterals { got, needed } => {
63                write!(
64                    f,
65                    "Need at least {} bytes to decode literals. Have: {} bytes",
66                    needed, got,
67                )
68            }
69            DecompressLiteralsError::ExtraPadding { skipped_bits } => {
70                write!(f,
71                    "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
72                    skipped_bits,
73                )
74            }
75            DecompressLiteralsError::BitstreamReadMismatch { read_til, expected } => {
76                write!(
77                    f,
78                    "Bitstream was read till: {}, should have been: {}",
79                    read_til, expected,
80                )
81            }
82            DecompressLiteralsError::DecodedLiteralCountMismatch { decoded, expected } => {
83                write!(
84                    f,
85                    "Did not decode enough literals: {}, Should have been: {}",
86                    decoded, expected,
87                )
88            }
89        }
90    }
91}
92
93impl From<HuffmanDecoderError> for DecompressLiteralsError {
94    fn from(val: HuffmanDecoderError) -> Self {
95        Self::HuffmanDecoderError(val)
96    }
97}
98
99impl From<GetBitsError> for DecompressLiteralsError {
100    fn from(val: GetBitsError) -> Self {
101        Self::GetBitsError(val)
102    }
103}
104
105impl From<HuffmanTableError> for DecompressLiteralsError {
106    fn from(val: HuffmanTableError) -> Self {
107        Self::HuffmanTableError(val)
108    }
109}
110
111/// Decode and decompress the provided literals section into `target`, returning the number of bytes read.
112pub fn decode_literals(
113    section: &LiteralsSection,
114    scratch: &mut HuffmanScratch,
115    source: &[u8],
116    target: &mut Vec<u8>,
117) -> Result<u32, DecompressLiteralsError> {
118    match section.ls_type {
119        LiteralsSectionType::Raw => {
120            target.extend(&source[0..section.regenerated_size as usize]);
121            Ok(section.regenerated_size)
122        }
123        LiteralsSectionType::RLE => {
124            target.resize(target.len() + section.regenerated_size as usize, source[0]);
125            Ok(1)
126        }
127        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
128            let bytes_read = decompress_literals(section, scratch, source, target)?;
129
130            //return sum of used bytes
131            Ok(bytes_read)
132        }
133    }
134}
135
136/// Decompress the provided literals section and source into the provided `target`.
137/// This function is used when the literals section is `Compressed` or `Treeless`
138///
139/// Returns the number of bytes read.
140fn decompress_literals(
141    section: &LiteralsSection,
142    scratch: &mut HuffmanScratch,
143    source: &[u8],
144    target: &mut Vec<u8>,
145) -> Result<u32, DecompressLiteralsError> {
146    use DecompressLiteralsError as err;
147
148    let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
149    let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;
150
151    target.reserve(section.regenerated_size as usize);
152    let source = &source[0..compressed_size];
153    let mut bytes_read = 0;
154
155    match section.ls_type {
156        LiteralsSectionType::Compressed => {
157            //read Huffman tree description
158            bytes_read += scratch.table.build_decoder(source)?;
159            vprintln!("Built huffman table using {} bytes", bytes_read);
160        }
161        LiteralsSectionType::Treeless => {
162            if scratch.table.max_num_bits == 0 {
163                return Err(err::UninitializedHuffmanTable);
164            }
165        }
166        _ => { /* nothing to do, huffman tree has been provided by previous block */ }
167    }
168
169    let source = &source[bytes_read as usize..];
170
171    if num_streams == 4 {
172        //build jumptable
173        if source.len() < 6 {
174            return Err(err::MissingBytesForJumpHeader { got: source.len() });
175        }
176        let jump1 = source[0] as usize + ((source[1] as usize) << 8);
177        let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
178        let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
179        bytes_read += 6;
180        let source = &source[6..];
181
182        if source.len() < jump3 {
183            return Err(err::MissingBytesForLiterals {
184                got: source.len(),
185                needed: jump3,
186            });
187        }
188
189        //decode 4 streams
190        let stream1 = &source[..jump1];
191        let stream2 = &source[jump1..jump2];
192        let stream3 = &source[jump2..jump3];
193        let stream4 = &source[jump3..];
194
195        for stream in &[stream1, stream2, stream3, stream4] {
196            let mut decoder = HuffmanDecoder::new(&scratch.table);
197            let mut br = BitReaderReversed::new(stream);
198            //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
199            let mut skipped_bits = 0;
200            loop {
201                let val = br.get_bits(1);
202                skipped_bits += 1;
203                if val == 1 || skipped_bits > 8 {
204                    break;
205                }
206            }
207            if skipped_bits > 8 {
208                //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
209                return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
210            }
211            decoder.init_state(&mut br);
212
213            while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
214                target.push(decoder.decode_symbol());
215                decoder.next_state(&mut br);
216            }
217            if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
218                return Err(DecompressLiteralsError::BitstreamReadMismatch {
219                    read_til: br.bits_remaining(),
220                    expected: -(scratch.table.max_num_bits as isize),
221                });
222            }
223        }
224
225        bytes_read += source.len() as u32;
226    } else {
227        //just decode the one stream
228        assert!(num_streams == 1);
229        let mut decoder = HuffmanDecoder::new(&scratch.table);
230        let mut br = BitReaderReversed::new(source);
231        let mut skipped_bits = 0;
232        loop {
233            let val = br.get_bits(1);
234            skipped_bits += 1;
235            if val == 1 || skipped_bits > 8 {
236                break;
237            }
238        }
239        if skipped_bits > 8 {
240            //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
241            return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
242        }
243        decoder.init_state(&mut br);
244        while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
245            target.push(decoder.decode_symbol());
246            decoder.next_state(&mut br);
247        }
248        bytes_read += source.len() as u32;
249    }
250
251    if target.len() != section.regenerated_size as usize {
252        return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
253            decoded: target.len(),
254            expected: section.regenerated_size as usize,
255        });
256    }
257
258    Ok(bytes_read)
259}