1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdsafe;
35
36 import java.util.Arrays;
37
38 import static org.waarp.compress.zstdsafe.BitInputStream.*;
39 import static org.waarp.compress.zstdsafe.Constants.*;
40 import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
41 import static org.waarp.compress.zstdsafe.Util.*;
42
43 class Huffman {
44 public static final int MAX_SYMBOL = 255;
45 public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
46
47 public static final int MAX_TABLE_LOG = 12;
48 public static final int MIN_TABLE_LOG = 5;
49 public static final int MAX_FSE_TABLE_LOG = 6;
50 public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
51 public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
52
53
54 private final byte[] weights = new byte[MAX_SYMBOL + 1];
55 private final int[] ranks = new int[MAX_TABLE_LOG + 1];
56
57
58 private int tableLog = -1;
59 private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
60 private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
61
62 private final FseTableReader reader = new FseTableReader();
63 private final FiniteStateEntropy.Table fseTable =
64 new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
65
66 public boolean isLoaded() {
67 return tableLog != -1;
68 }
69
70 public int readTable(final byte[] inputBase, final int inputAddress,
71 final int size) {
72 Arrays.fill(ranks, 0);
73 int input = inputAddress;
74
75
76 verify(size > 0, input, NOT_ENOUGH_INPUT_BYTES);
77 int inputSize = inputBase[input++] & 0xFF;
78
79 final int outputSize;
80 if (inputSize >= 128) {
81 outputSize = inputSize - 127;
82 inputSize = ((outputSize + 1) / 2);
83
84 verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
85 verify(outputSize <= MAX_SYMBOL + 1, input, INPUT_IS_CORRUPTED);
86
87 for (int i = 0; i < outputSize; i += 2) {
88 final int value = inputBase[input + i / 2] & 0xFF;
89 weights[i] = (byte) (value >>> 4);
90 weights[i + 1] = (byte) (value & 0xF);
91 }
92 } else {
93 verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
94
95 final int inputLimit = input + inputSize;
96 input += reader.readFseTable(fseTable, inputBase, input, inputLimit,
97 FiniteStateEntropy.MAX_SYMBOL,
98 MAX_FSE_TABLE_LOG);
99 outputSize =
100 FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit,
101 weights);
102 }
103
104 int totalWeight = 0;
105 for (int i = 0; i < outputSize; i++) {
106 ranks[weights[i]]++;
107 totalWeight +=
108 (1 << weights[i]) >> 1;
109 }
110 verify(totalWeight != 0, input, INPUT_IS_CORRUPTED);
111
112 tableLog = Util.highestBit(totalWeight) + 1;
113 verify(tableLog <= MAX_TABLE_LOG, input, INPUT_IS_CORRUPTED);
114
115 final int total = 1 << tableLog;
116 final int rest = total - totalWeight;
117 verify(isPowerOf2(rest), input, INPUT_IS_CORRUPTED);
118
119 final int lastWeight = Util.highestBit(rest) + 1;
120
121 weights[outputSize] = (byte) lastWeight;
122 ranks[lastWeight]++;
123
124 final int numberOfSymbols = outputSize + 1;
125
126
127 int nextRankStart = 0;
128 for (int i = 1; i < tableLog + 1; ++i) {
129 final int current = nextRankStart;
130 nextRankStart += ranks[i] << (i - 1);
131 ranks[i] = current;
132 }
133
134 for (int n = 0; n < numberOfSymbols; n++) {
135 final int weight = weights[n];
136 final int length = (1 << weight) >> 1;
137
138 final byte symbol = (byte) n;
139 final byte numberOfBits = (byte) (tableLog + 1 - weight);
140 for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
141 symbols[i] = symbol;
142 numbersOfBits[i] = numberOfBits;
143 }
144 ranks[weight] += length;
145 }
146
147 verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, INPUT_IS_CORRUPTED);
148
149 return inputSize + 1;
150 }
151
152 public void decodeSingleStream(final byte[] inputBase, final int inputAddress,
153 final int inputLimit, final byte[] outputBase,
154 final int outputAddress,
155 final int outputLimit) {
156 final BitInputStream.Initializer initializer =
157 new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
158 initializer.initialize();
159
160 long bits = initializer.getBits();
161 int bitsConsumed = initializer.getBitsConsumed();
162 int currentAddress = initializer.getCurrentAddress();
163
164 final int tableLog1 = this.tableLog;
165 final byte[] numbersOfBits1 = this.numbersOfBits;
166 final byte[] symbols1 = this.symbols;
167
168
169 int output = outputAddress;
170 final int fastOutputLimit = outputLimit - 4;
171 while (output < fastOutputLimit) {
172 final BitInputStream.Loader loader =
173 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
174 bits, bitsConsumed);
175 final boolean done = loader.load();
176 bits = loader.getBits();
177 bitsConsumed = loader.getBitsConsumed();
178 currentAddress = loader.getCurrentAddress();
179 if (done) {
180 break;
181 }
182
183 bitsConsumed =
184 decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog1,
185 numbersOfBits1, symbols1);
186 bitsConsumed =
187 decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog1,
188 numbersOfBits1, symbols1);
189 bitsConsumed =
190 decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog1,
191 numbersOfBits1, symbols1);
192 bitsConsumed =
193 decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog1,
194 numbersOfBits1, symbols1);
195 output += SIZE_OF_INT;
196 }
197
198 decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits,
199 outputBase, output, outputLimit);
200 }
201
202 public void decode4Streams(final byte[] inputBase, final int inputAddress,
203 final int inputLimit, final byte[] outputBase,
204 final int outputAddress, final int outputLimit) {
205 verify(inputLimit - inputAddress >= 10, inputAddress,
206 INPUT_IS_CORRUPTED);
207
208 final int start1 =
209 inputAddress + 3 * SIZE_OF_SHORT;
210 final int start2 = start1 + (getShort(inputBase, inputAddress) & 0xFFFF);
211 final int start3 =
212 start2 + (getShort(inputBase, inputAddress + 2) & 0xFFFF);
213 final int start4 =
214 start3 + (getShort(inputBase, inputAddress + 4) & 0xFFFF);
215
216 BitInputStream.Initializer initializer =
217 new BitInputStream.Initializer(inputBase, start1, start2);
218 initializer.initialize();
219 int stream1bitsConsumed = initializer.getBitsConsumed();
220 int stream1currentAddress = initializer.getCurrentAddress();
221 long stream1bits = initializer.getBits();
222
223 initializer = new BitInputStream.Initializer(inputBase, start2, start3);
224 initializer.initialize();
225 int stream2bitsConsumed = initializer.getBitsConsumed();
226 int stream2currentAddress = initializer.getCurrentAddress();
227 long stream2bits = initializer.getBits();
228
229 initializer = new BitInputStream.Initializer(inputBase, start3, start4);
230 initializer.initialize();
231 int stream3bitsConsumed = initializer.getBitsConsumed();
232 int stream3currentAddress = initializer.getCurrentAddress();
233 long stream3bits = initializer.getBits();
234
235 initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
236 initializer.initialize();
237 int stream4bitsConsumed = initializer.getBitsConsumed();
238 int stream4currentAddress = initializer.getCurrentAddress();
239 long stream4bits = initializer.getBits();
240
241 final int segmentSize = (outputLimit - outputAddress + 3) / 4;
242
243 final int outputStart2 = outputAddress + segmentSize;
244 final int outputStart3 = outputStart2 + segmentSize;
245 final int outputStart4 = outputStart3 + segmentSize;
246
247 int output1 = outputAddress;
248 int output2 = outputStart2;
249 int output3 = outputStart3;
250 int output4 = outputStart4;
251
252 final int fastOutputLimit = outputLimit - 7;
253 final int tableLog1 = this.tableLog;
254 final byte[] numbersOfBits1 = this.numbersOfBits;
255 final byte[] symbols1 = this.symbols;
256
257 while (output4 < fastOutputLimit) {
258 stream1bitsConsumed =
259 decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed,
260 tableLog1, numbersOfBits1, symbols1);
261 stream2bitsConsumed =
262 decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed,
263 tableLog1, numbersOfBits1, symbols1);
264 stream3bitsConsumed =
265 decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed,
266 tableLog1, numbersOfBits1, symbols1);
267 stream4bitsConsumed =
268 decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed,
269 tableLog1, numbersOfBits1, symbols1);
270
271 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits,
272 stream1bitsConsumed, tableLog1,
273 numbersOfBits1, symbols1);
274 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits,
275 stream2bitsConsumed, tableLog1,
276 numbersOfBits1, symbols1);
277 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits,
278 stream3bitsConsumed, tableLog1,
279 numbersOfBits1, symbols1);
280 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits,
281 stream4bitsConsumed, tableLog1,
282 numbersOfBits1, symbols1);
283
284 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits,
285 stream1bitsConsumed, tableLog1,
286 numbersOfBits1, symbols1);
287 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits,
288 stream2bitsConsumed, tableLog1,
289 numbersOfBits1, symbols1);
290 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits,
291 stream3bitsConsumed, tableLog1,
292 numbersOfBits1, symbols1);
293 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits,
294 stream4bitsConsumed, tableLog1,
295 numbersOfBits1, symbols1);
296
297 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits,
298 stream1bitsConsumed, tableLog1,
299 numbersOfBits1, symbols1);
300 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits,
301 stream2bitsConsumed, tableLog1,
302 numbersOfBits1, symbols1);
303 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits,
304 stream3bitsConsumed, tableLog1,
305 numbersOfBits1, symbols1);
306 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits,
307 stream4bitsConsumed, tableLog1,
308 numbersOfBits1, symbols1);
309
310 output1 += SIZE_OF_INT;
311 output2 += SIZE_OF_INT;
312 output3 += SIZE_OF_INT;
313 output4 += SIZE_OF_INT;
314
315 BitInputStream.Loader loader =
316 new BitInputStream.Loader(inputBase, start1, stream1currentAddress,
317 stream1bits, stream1bitsConsumed);
318 boolean done = loader.load();
319 stream1bitsConsumed = loader.getBitsConsumed();
320 stream1bits = loader.getBits();
321 stream1currentAddress = loader.getCurrentAddress();
322
323 if (done) {
324 break;
325 }
326
327 loader =
328 new BitInputStream.Loader(inputBase, start2, stream2currentAddress,
329 stream2bits, stream2bitsConsumed);
330 done = loader.load();
331 stream2bitsConsumed = loader.getBitsConsumed();
332 stream2bits = loader.getBits();
333 stream2currentAddress = loader.getCurrentAddress();
334
335 if (done) {
336 break;
337 }
338
339 loader =
340 new BitInputStream.Loader(inputBase, start3, stream3currentAddress,
341 stream3bits, stream3bitsConsumed);
342 done = loader.load();
343 stream3bitsConsumed = loader.getBitsConsumed();
344 stream3bits = loader.getBits();
345 stream3currentAddress = loader.getCurrentAddress();
346 if (done) {
347 break;
348 }
349
350 loader =
351 new BitInputStream.Loader(inputBase, start4, stream4currentAddress,
352 stream4bits, stream4bitsConsumed);
353 done = loader.load();
354 stream4bitsConsumed = loader.getBitsConsumed();
355 stream4bits = loader.getBits();
356 stream4currentAddress = loader.getCurrentAddress();
357 if (done) {
358 break;
359 }
360 }
361
362 verify(output1 <= outputStart2 && output2 <= outputStart3 &&
363 output3 <= outputStart4, inputAddress, INPUT_IS_CORRUPTED);
364
365
366 decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed,
367 stream1bits, outputBase, output1, outputStart2);
368 decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed,
369 stream2bits, outputBase, output2, outputStart3);
370 decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed,
371 stream3bits, outputBase, output3, outputStart4);
372 decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed,
373 stream4bits, outputBase, output4, outputLimit);
374 }
375
376 private void decodeTail(final byte[] inputBase, final int startAddress,
377 int currentAddress, int bitsConsumed, long bits,
378 final byte[] outputBase, int outputAddress,
379 final int outputLimit) {
380 final int tableLog1 = this.tableLog;
381 final byte[] numbersOfBits1 = this.numbersOfBits;
382 final byte[] symbols1 = this.symbols;
383
384
385 while (outputAddress < outputLimit) {
386 final BitInputStream.Loader loader =
387 new BitInputStream.Loader(inputBase, startAddress, currentAddress,
388 bits, bitsConsumed);
389 final boolean done = loader.load();
390 bitsConsumed = loader.getBitsConsumed();
391 bits = loader.getBits();
392 currentAddress = loader.getCurrentAddress();
393 if (done) {
394 break;
395 }
396
397 bitsConsumed =
398 decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
399 tableLog1, numbersOfBits1, symbols1);
400 }
401
402
403 while (outputAddress < outputLimit) {
404 bitsConsumed =
405 decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
406 tableLog1, numbersOfBits1, symbols1);
407 }
408
409 verify(isEndOfStream(startAddress, currentAddress, bitsConsumed),
410 startAddress, "Bit stream is not fully consumed");
411 }
412
413 private static int decodeSymbol(final byte[] outputBase,
414 final int outputAddress,
415 final long bitContainer,
416 final int bitsConsumed, final int tableLog,
417 final byte[] numbersOfBits,
418 final byte[] symbols) {
419 final int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
420 outputBase[outputAddress] = symbols[value];
421 return bitsConsumed + numbersOfBits[value];
422 }
423 }