1use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
5use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
6use super::scratch::HuffmanScratch;
7use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError};
8use alloc::vec::Vec;
9
10#[derive(Debug)]
11#[non_exhaustive]
12pub enum DecompressLiteralsError {
13 MissingCompressedSize,
14 MissingNumStreams,
15 GetBitsError(GetBitsError),
16 HuffmanTableError(HuffmanTableError),
17 HuffmanDecoderError(HuffmanDecoderError),
18 UninitializedHuffmanTable,
19 MissingBytesForJumpHeader { got: usize },
20 MissingBytesForLiterals { got: usize, needed: usize },
21 ExtraPadding { skipped_bits: i32 },
22 BitstreamReadMismatch { read_til: isize, expected: isize },
23 DecodedLiteralCountMismatch { decoded: usize, expected: usize },
24}
25
#[cfg(feature = "std")]
impl std::error::Error for DecompressLiteralsError {
    /// Returns the wrapped lower-level error for the variants that carry one,
    /// and `None` for all purely local error conditions.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::GetBitsError(cause) => cause,
            Self::HuffmanTableError(cause) => cause,
            Self::HuffmanDecoderError(cause) => cause,
            _ => return None,
        };
        Some(inner)
    }
}
37impl core::fmt::Display for DecompressLiteralsError {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 match self {
40 DecompressLiteralsError::MissingCompressedSize => {
41 write!(f,
42 "compressed size was none even though it must be set to something for compressed literals",
43 )
44 }
45 DecompressLiteralsError::MissingNumStreams => {
46 write!(f,
47 "num_streams was none even though it must be set to something (1 or 4) for compressed literals",
48 )
49 }
50 DecompressLiteralsError::GetBitsError(e) => write!(f, "{:?}", e),
51 DecompressLiteralsError::HuffmanTableError(e) => write!(f, "{:?}", e),
52 DecompressLiteralsError::HuffmanDecoderError(e) => write!(f, "{:?}", e),
53 DecompressLiteralsError::UninitializedHuffmanTable => {
54 write!(
55 f,
56 "Tried to reuse huffman table but it was never initialized",
57 )
58 }
59 DecompressLiteralsError::MissingBytesForJumpHeader { got } => {
60 write!(f, "Need 6 bytes to decode jump header, got {} bytes", got,)
61 }
62 DecompressLiteralsError::MissingBytesForLiterals { got, needed } => {
63 write!(
64 f,
65 "Need at least {} bytes to decode literals. Have: {} bytes",
66 needed, got,
67 )
68 }
69 DecompressLiteralsError::ExtraPadding { skipped_bits } => {
70 write!(f,
71 "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
72 skipped_bits,
73 )
74 }
75 DecompressLiteralsError::BitstreamReadMismatch { read_til, expected } => {
76 write!(
77 f,
78 "Bitstream was read till: {}, should have been: {}",
79 read_til, expected,
80 )
81 }
82 DecompressLiteralsError::DecodedLiteralCountMismatch { decoded, expected } => {
83 write!(
84 f,
85 "Did not decode enough literals: {}, Should have been: {}",
86 decoded, expected,
87 )
88 }
89 }
90 }
91}
92
93impl From<HuffmanDecoderError> for DecompressLiteralsError {
94 fn from(val: HuffmanDecoderError) -> Self {
95 Self::HuffmanDecoderError(val)
96 }
97}
98
99impl From<GetBitsError> for DecompressLiteralsError {
100 fn from(val: GetBitsError) -> Self {
101 Self::GetBitsError(val)
102 }
103}
104
105impl From<HuffmanTableError> for DecompressLiteralsError {
106 fn from(val: HuffmanTableError) -> Self {
107 Self::HuffmanTableError(val)
108 }
109}
110
111pub fn decode_literals(
113 section: &LiteralsSection,
114 scratch: &mut HuffmanScratch,
115 source: &[u8],
116 target: &mut Vec<u8>,
117) -> Result<u32, DecompressLiteralsError> {
118 match section.ls_type {
119 LiteralsSectionType::Raw => {
120 target.extend(&source[0..section.regenerated_size as usize]);
121 Ok(section.regenerated_size)
122 }
123 LiteralsSectionType::RLE => {
124 target.resize(target.len() + section.regenerated_size as usize, source[0]);
125 Ok(1)
126 }
127 LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
128 let bytes_read = decompress_literals(section, scratch, source, target)?;
129
130 Ok(bytes_read)
132 }
133 }
134}
135
136fn decompress_literals(
141 section: &LiteralsSection,
142 scratch: &mut HuffmanScratch,
143 source: &[u8],
144 target: &mut Vec<u8>,
145) -> Result<u32, DecompressLiteralsError> {
146 use DecompressLiteralsError as err;
147
148 let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
149 let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;
150
151 target.reserve(section.regenerated_size as usize);
152 let source = &source[0..compressed_size];
153 let mut bytes_read = 0;
154
155 match section.ls_type {
156 LiteralsSectionType::Compressed => {
157 bytes_read += scratch.table.build_decoder(source)?;
159 vprintln!("Built huffman table using {} bytes", bytes_read);
160 }
161 LiteralsSectionType::Treeless => {
162 if scratch.table.max_num_bits == 0 {
163 return Err(err::UninitializedHuffmanTable);
164 }
165 }
166 _ => { }
167 }
168
169 let source = &source[bytes_read as usize..];
170
171 if num_streams == 4 {
172 if source.len() < 6 {
174 return Err(err::MissingBytesForJumpHeader { got: source.len() });
175 }
176 let jump1 = source[0] as usize + ((source[1] as usize) << 8);
177 let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
178 let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
179 bytes_read += 6;
180 let source = &source[6..];
181
182 if source.len() < jump3 {
183 return Err(err::MissingBytesForLiterals {
184 got: source.len(),
185 needed: jump3,
186 });
187 }
188
189 let stream1 = &source[..jump1];
191 let stream2 = &source[jump1..jump2];
192 let stream3 = &source[jump2..jump3];
193 let stream4 = &source[jump3..];
194
195 for stream in &[stream1, stream2, stream3, stream4] {
196 let mut decoder = HuffmanDecoder::new(&scratch.table);
197 let mut br = BitReaderReversed::new(stream);
198 let mut skipped_bits = 0;
200 loop {
201 let val = br.get_bits(1);
202 skipped_bits += 1;
203 if val == 1 || skipped_bits > 8 {
204 break;
205 }
206 }
207 if skipped_bits > 8 {
208 return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
210 }
211 decoder.init_state(&mut br);
212
213 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
214 target.push(decoder.decode_symbol());
215 decoder.next_state(&mut br);
216 }
217 if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
218 return Err(DecompressLiteralsError::BitstreamReadMismatch {
219 read_til: br.bits_remaining(),
220 expected: -(scratch.table.max_num_bits as isize),
221 });
222 }
223 }
224
225 bytes_read += source.len() as u32;
226 } else {
227 assert!(num_streams == 1);
229 let mut decoder = HuffmanDecoder::new(&scratch.table);
230 let mut br = BitReaderReversed::new(source);
231 let mut skipped_bits = 0;
232 loop {
233 let val = br.get_bits(1);
234 skipped_bits += 1;
235 if val == 1 || skipped_bits > 8 {
236 break;
237 }
238 }
239 if skipped_bits > 8 {
240 return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
242 }
243 decoder.init_state(&mut br);
244 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
245 target.push(decoder.decode_symbol());
246 decoder.next_state(&mut br);
247 }
248 bytes_read += source.len() as u32;
249 }
250
251 if target.len() != section.regenerated_size as usize {
252 return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
253 decoded: target.len(),
254 expected: section.regenerated_size as usize,
255 });
256 }
257
258 Ok(bytes_read)
259}