1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdsafe;
35
36 import org.waarp.compress.MalformedInputException;
37
38 import java.util.Arrays;
39
40 import static org.waarp.compress.zstdsafe.BitInputStream.*;
41 import static org.waarp.compress.zstdsafe.Constants.*;
42 import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
43 import static org.waarp.compress.zstdsafe.Util.*;
44
45 class ZstdFrameDecompressor {
46 private static final int[] DEC_32_TABLE = { 4, 1, 2, 1, 4, 4, 4, 4 };
47 private static final int[] DEC_64_TABLE = { 0, 0, 0, -1, 0, 1, 2, 3 };
48
49 private static final int V07_MAGIC_NUMBER = 0xFD2FB527;
50
51 private static final int MAX_WINDOW_SIZE = 1 << 23;
52
53 private static final int[] LITERALS_LENGTH_BASE = {
54 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24,
55 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000,
56 0x4000, 0x8000, 0x10000
57 };
58
59 private static final int[] MATCH_LENGTH_BASE = {
60 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
61 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47,
62 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003,
63 0x4003, 0x8003, 0x10003
64 };
65
66 private static final int[] OFFSET_CODES_BASE = {
67 0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD,
68 0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD,
69 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD,
70 0xFFFFFFD
71 };
72
73 private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE =
74 new FiniteStateEntropy.Table(6, new int[] {
75 0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0,
76 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0,
77 0, 48, 16, 32, 32, 32, 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32,
78 0, 0, 0, 0
79 }, new byte[] {
80 0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26,
81 27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23,
82 25, 25, 26, 28, 30, 0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20,
83 21, 23, 24, 35, 34, 33, 32
84 }, new byte[] {
85 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4,
86 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5,
87 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6
88 });
89
90 private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE =
91 new FiniteStateEntropy.Table(5, new int[] {
92 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0,
93 0, 16, 0, 0, 0, 0, 0, 0, 0
94 }, new byte[] {
95 0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4,
96 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24
97 }, new byte[] {
98 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5,
99 5, 4, 5, 5, 5, 5, 5, 5, 5
100 });
101
102 private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE =
103 new FiniteStateEntropy.Table(6, new int[] {
104 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
105 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48,
106 16, 32, 32, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
107 }, new byte[] {
108 0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39,
109 41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34,
110 36, 38, 40, 42, 44, 1, 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29,
111 52, 51, 50, 49, 48, 47, 46
112 }, new byte[] {
113 6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4,
114 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4,
115 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
116 });
117 public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
118 public static final String OUTPUT_BUFFER_TOO_SMALL =
119 "Output buffer too small";
120 public static final String EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT =
121 "Expected match length table to be present";
122 public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
123 public static final String VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE =
124 "Value exceeds expected maximum value";
125
126 private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG];
127
128
129
130 private byte[] literalsBase;
131 private int literalsAddress;
132 private int literalsLimit;
133
134 private final int[] previousOffsets = new int[3];
135
136 private final FiniteStateEntropy.Table literalsLengthTable =
137 new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
138 private final FiniteStateEntropy.Table offsetCodesTable =
139 new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
140 private final FiniteStateEntropy.Table matchLengthTable =
141 new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);
142
143 private FiniteStateEntropy.Table currentLiteralsLengthTable;
144 private FiniteStateEntropy.Table currentOffsetCodesTable;
145 private FiniteStateEntropy.Table currentMatchLengthTable;
146
147 private final Huffman huffman = new Huffman();
148 private final FseTableReader fse = new FseTableReader();
149
150 public int decompress(final byte[] inputBase, final int inputAddress,
151 final int inputLimit, final byte[] outputBase,
152 final int outputAddress, final int outputLimit) {
153 if (outputAddress == outputLimit) {
154 return 0;
155 }
156
157 int input = inputAddress;
158 int output = outputAddress;
159
160 while (input < inputLimit) {
161 reset();
162 final int outputStart = output;
163 input += verifyMagic(inputBase, inputAddress, inputLimit);
164
165 final FrameHeader frameHeader =
166 readFrameHeader(inputBase, input, inputLimit);
167 input += frameHeader.headerSize;
168
169 boolean lastBlock;
170 do {
171 verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input,
172 NOT_ENOUGH_INPUT_BYTES);
173
174
175 final int header = getInt(inputBase, input) & 0xFFFFFF;
176 input += SIZE_OF_BLOCK_HEADER;
177
178 lastBlock = (header & 1) != 0;
179 final int blockType = (header >>> 1) & 0x3;
180 final int blockSize = (header >>> 3) & 0x1FFFFF;
181
182 final int decodedSize;
183 switch (blockType) {
184 case RAW_BLOCK:
185 verify(inputAddress + blockSize <= inputLimit, input,
186 NOT_ENOUGH_INPUT_BYTES);
187 decodedSize =
188 decodeRawBlock(inputBase, input, blockSize, outputBase, output,
189 outputLimit);
190 input += blockSize;
191 break;
192 case RLE_BLOCK:
193 verify(inputAddress + 1 <= inputLimit, input,
194 NOT_ENOUGH_INPUT_BYTES);
195 decodedSize =
196 decodeRleBlock(blockSize, inputBase, input, outputBase, output,
197 outputLimit);
198 input += 1;
199 break;
200 case COMPRESSED_BLOCK:
201 verify(inputAddress + blockSize <= inputLimit, input,
202 NOT_ENOUGH_INPUT_BYTES);
203 decodedSize =
204 decodeCompressedBlock(inputBase, input, blockSize, outputBase,
205 output, outputLimit,
206 frameHeader.windowSize, outputAddress);
207 input += blockSize;
208 break;
209 default:
210 throw fail(input, "Invalid block type");
211 }
212
213 output += decodedSize;
214 } while (!lastBlock);
215
216 if (frameHeader.hasChecksum) {
217 final int decodedFrameSize = output - outputStart;
218
219 final long hash =
220 XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
221
222 final int checksum = getInt(inputBase, input);
223 if (checksum != (int) hash) {
224 throw new MalformedInputException(input, String.format(
225 "Bad checksum. Expected: %s, actual: %s",
226 Integer.toHexString(checksum), Integer.toHexString((int) hash)));
227 }
228
229 input += SIZE_OF_INT;
230 }
231 }
232
233 return output - outputAddress;
234 }
235
236 private void reset() {
237 previousOffsets[0] = 1;
238 previousOffsets[1] = 4;
239 previousOffsets[2] = 8;
240
241 currentLiteralsLengthTable = null;
242 currentOffsetCodesTable = null;
243 currentMatchLengthTable = null;
244 }
245
246 private static int decodeRawBlock(final byte[] inputBase,
247 final int inputAddress, final int blockSize,
248 final byte[] outputBase,
249 final int outputAddress,
250 final int outputLimit) {
251 verify(outputAddress + blockSize <= outputLimit, inputAddress,
252 OUTPUT_BUFFER_TOO_SMALL);
253
254 copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize);
255 return blockSize;
256 }
257
258 private static int decodeRleBlock(final int size, final byte[] inputBase,
259 final int inputAddress,
260 final byte[] outputBase,
261 final int outputAddress,
262 final int outputLimit) {
263 verify(outputAddress + size <= outputLimit, inputAddress,
264 OUTPUT_BUFFER_TOO_SMALL);
265
266 int output = outputAddress;
267 final long value = inputBase[inputAddress] & 0xFFL;
268
269 int remaining = size;
270 if (remaining >= SIZE_OF_LONG) {
271 final long packed =
272 value | (value << 8) | (value << 16) | (value << 24) | (value << 32) |
273 (value << 40) | (value << 48) | (value << 56);
274
275 do {
276 putLong(outputBase, output, packed);
277 output += SIZE_OF_LONG;
278 remaining -= SIZE_OF_LONG;
279 } while (remaining >= SIZE_OF_LONG);
280 }
281
282 for (int i = 0; i < remaining; i++) {
283 outputBase[output] = (byte) value;
284 output++;
285 }
286
287 return size;
288 }
289
290 private int decodeCompressedBlock(final byte[] inputBase,
291 final int inputAddress, final int blockSize,
292 final byte[] outputBase,
293 final int outputAddress,
294 final int outputLimit, final int windowSize,
295 final int outputAbsoluteBaseAddress) {
296 final int inputLimit = inputAddress + blockSize;
297 int input = inputAddress;
298
299 verify(blockSize <= MAX_BLOCK_SIZE, input,
300 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
301 verify(blockSize >= MIN_BLOCK_SIZE, input,
302 "Compressed block size too small");
303
304
305 final int literalsBlockType = inputBase[input] & 0x3;
306
307 switch (literalsBlockType) {
308 case RAW_LITERALS_BLOCK: {
309 input += decodeRawLiterals(inputBase, input, inputLimit);
310 break;
311 }
312 case RLE_LITERALS_BLOCK: {
313 input += decodeRleLiterals(inputBase, input, blockSize);
314 break;
315 }
316 case TREELESS_LITERALS_BLOCK:
317 verify(huffman.isLoaded(), input, "Dictionary is corrupted");
318 case COMPRESSED_LITERALS_BLOCK: {
319 input += decodeCompressedLiterals(inputBase, input, blockSize,
320 literalsBlockType);
321 break;
322 }
323 default:
324 throw fail(input, "Invalid literals block encoding type");
325 }
326
327 verify(windowSize <= MAX_WINDOW_SIZE, input,
328 "Window size too large (not yet supported)");
329
330 return decompressSequences(inputBase, input, inputAddress + blockSize,
331 outputBase, outputAddress, outputLimit,
332 literalsBase, literalsAddress, literalsLimit,
333 outputAbsoluteBaseAddress);
334 }
335
336 private int decompressSequences(final byte[] inputBase,
337 final int inputAddress, final int inputLimit,
338 final byte[] outputBase,
339 final int outputAddress,
340 final int outputLimit,
341 final byte[] literalsBase,
342 final int literalsAddress,
343 final int literalsLimit,
344 final int outputAbsoluteBaseAddress) {
345 final int fastOutputLimit = outputLimit - SIZE_OF_LONG;
346 final int fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
347
348 int input = inputAddress;
349 int output = outputAddress;
350
351 int literalsInput = literalsAddress;
352
353 final int size = inputLimit - inputAddress;
354 verify(size >= MIN_SEQUENCES_SIZE, input, NOT_ENOUGH_INPUT_BYTES);
355
356
357 int sequenceCount = inputBase[input++] & 0xFF;
358 if (sequenceCount != 0) {
359 if (sequenceCount == 255) {
360 verify(input + SIZE_OF_SHORT <= inputLimit, input,
361 NOT_ENOUGH_INPUT_BYTES);
362 sequenceCount =
363 (getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES;
364 input += SIZE_OF_SHORT;
365 } else if (sequenceCount > 127) {
366 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
367 sequenceCount =
368 ((sequenceCount - 128) << 8) + (inputBase[input++] & 0xFF);
369 }
370
371 verify(input + SIZE_OF_INT <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
372
373 final byte type = inputBase[input++];
374
375 final int literalsLengthType = (type & 0xFF) >>> 6;
376 final int offsetCodesType = (type >>> 4) & 0x3;
377 final int matchLengthType = (type >>> 2) & 0x3;
378
379 input = computeLiteralsTable(literalsLengthType, inputBase, input,
380 inputLimit);
381 input =
382 computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
383 input = computeMatchLengthTable(matchLengthType, inputBase, input,
384 inputLimit);
385
386
387 final BitInputStream.Initializer initializer =
388 new BitInputStream.Initializer(inputBase, input, inputLimit);
389 initializer.initialize();
390 int bitsConsumed = initializer.getBitsConsumed();
391 long bits = initializer.getBits();
392 int currentAddress = initializer.getCurrentAddress();
393
394 final FiniteStateEntropy.Table currentLiteralsLengthTable1 =
395 this.currentLiteralsLengthTable;
396 final FiniteStateEntropy.Table currentOffsetCodesTable1 =
397 this.currentOffsetCodesTable;
398 final FiniteStateEntropy.Table currentMatchLengthTable1 =
399 this.currentMatchLengthTable;
400
401 int literalsLengthState = (int) peekBits(bitsConsumed, bits,
402 currentLiteralsLengthTable1.log2Size);
403 bitsConsumed += currentLiteralsLengthTable1.log2Size;
404
405 int offsetCodesState =
406 (int) peekBits(bitsConsumed, bits, currentOffsetCodesTable1.log2Size);
407 bitsConsumed += currentOffsetCodesTable1.log2Size;
408
409 int matchLengthState =
410 (int) peekBits(bitsConsumed, bits, currentMatchLengthTable1.log2Size);
411 bitsConsumed += currentMatchLengthTable1.log2Size;
412
413 final int[] previousOffsets1 = this.previousOffsets;
414
415 final byte[] literalsLengthNumbersOfBits =
416 currentLiteralsLengthTable1.numberOfBits;
417 final int[] literalsLengthNewStates =
418 currentLiteralsLengthTable1.newState;
419 final byte[] literalsLengthSymbols = currentLiteralsLengthTable1.symbol;
420
421 final byte[] matchLengthNumbersOfBits =
422 currentMatchLengthTable1.numberOfBits;
423 final int[] matchLengthNewStates = currentMatchLengthTable1.newState;
424 final byte[] matchLengthSymbols = currentMatchLengthTable1.symbol;
425
426 final byte[] offsetCodesNumbersOfBits =
427 currentOffsetCodesTable1.numberOfBits;
428 final int[] offsetCodesNewStates = currentOffsetCodesTable1.newState;
429 final byte[] offsetCodesSymbols = currentOffsetCodesTable1.symbol;
430
431 while (sequenceCount > 0) {
432 sequenceCount--;
433
434 final BitInputStream.Loader loader =
435 new BitInputStream.Loader(inputBase, input, currentAddress, bits,
436 bitsConsumed);
437 loader.load();
438 bitsConsumed = loader.getBitsConsumed();
439 bits = loader.getBits();
440 currentAddress = loader.getCurrentAddress();
441 if (loader.isOverflow()) {
442 verify(sequenceCount == 0, input, "Not all sequences were consumed");
443 break;
444 }
445
446
447 final int literalsLengthCode =
448 literalsLengthSymbols[literalsLengthState];
449 final int matchLengthCode = matchLengthSymbols[matchLengthState];
450 final int offsetCode = offsetCodesSymbols[offsetCodesState];
451
452 final int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
453 final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
454
455 int offset = OFFSET_CODES_BASE[offsetCode];
456 if (offsetCode > 0) {
457 offset += peekBits(bitsConsumed, bits, offsetCode);
458 bitsConsumed += offsetCode;
459 }
460
461 if (offsetCode <= 1) {
462 if (literalsLengthCode == 0) {
463 offset++;
464 }
465
466 if (offset != 0) {
467 int temp;
468 if (offset == 3) {
469 temp = previousOffsets1[0] - 1;
470 } else {
471 temp = previousOffsets1[offset];
472 }
473
474 if (temp == 0) {
475 temp = 1;
476 }
477
478 if (offset != 1) {
479 previousOffsets1[2] = previousOffsets1[1];
480 }
481 previousOffsets1[1] = previousOffsets1[0];
482 previousOffsets1[0] = temp;
483
484 offset = temp;
485 } else {
486 offset = previousOffsets1[0];
487 }
488 } else {
489 previousOffsets1[2] = previousOffsets1[1];
490 previousOffsets1[1] = previousOffsets1[0];
491 previousOffsets1[0] = offset;
492 }
493
494 int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
495 if (matchLengthCode > 31) {
496 matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
497 bitsConsumed += matchLengthBits;
498 }
499
500 int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
501 if (literalsLengthCode > 15) {
502 literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
503 bitsConsumed += literalsLengthBits;
504 }
505
506 final int totalBits = literalsLengthBits + matchLengthBits + offsetCode;
507 if (totalBits > 64 - 7 -
508 (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG +
509 OFFSET_TABLE_LOG)) {
510 final BitInputStream.Loader loader1 =
511 new BitInputStream.Loader(inputBase, input, currentAddress, bits,
512 bitsConsumed);
513 loader1.load();
514
515 bitsConsumed = loader1.getBitsConsumed();
516 bits = loader1.getBits();
517 currentAddress = loader1.getCurrentAddress();
518 }
519
520 int numberOfBits;
521
522 numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
523 literalsLengthState =
524 (int) (literalsLengthNewStates[literalsLengthState] +
525 peekBits(bitsConsumed, bits, numberOfBits));
526 bitsConsumed += numberOfBits;
527
528 numberOfBits = matchLengthNumbersOfBits[matchLengthState];
529 matchLengthState = (int) (matchLengthNewStates[matchLengthState] +
530 peekBits(bitsConsumed, bits,
531 numberOfBits));
532 bitsConsumed += numberOfBits;
533
534 numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
535 offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] +
536 peekBits(bitsConsumed, bits,
537 numberOfBits));
538 bitsConsumed += numberOfBits;
539
540 final int literalOutputLimit = output + literalsLength;
541 final int matchOutputLimit = literalOutputLimit + matchLength;
542
543 verify(matchOutputLimit <= outputLimit, input, OUTPUT_BUFFER_TOO_SMALL);
544 final int literalEnd = literalsInput + literalsLength;
545 verify(literalEnd <= literalsLimit, input, INPUT_IS_CORRUPTED);
546
547 final int matchAddress = literalOutputLimit - offset;
548 verify(matchAddress >= outputAbsoluteBaseAddress, input,
549 INPUT_IS_CORRUPTED);
550
551 if (literalOutputLimit > fastOutputLimit) {
552 executeLastSequence(outputBase, output, literalOutputLimit,
553 matchOutputLimit, fastOutputLimit, literalsInput,
554 matchAddress);
555 } else {
556
557
558 output = copyLiterals(outputBase, literalsBase, output, literalsInput,
559 literalOutputLimit);
560 copyMatch(outputBase, fastOutputLimit, output, offset,
561 matchOutputLimit, matchAddress, matchLength,
562 fastMatchOutputLimit);
563 }
564 output = matchOutputLimit;
565 literalsInput = literalEnd;
566 }
567 }
568
569
570 output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output,
571 literalsInput);
572
573 return output - outputAddress;
574 }
575
576 private int copyLastLiteral(final byte[] outputBase,
577 final byte[] literalsBase,
578 final int literalsLimit, int output,
579 final int literalsInput) {
580 final int lastLiteralsSize = literalsLimit - literalsInput;
581 copyMemory(literalsBase, literalsInput, outputBase, output,
582 lastLiteralsSize);
583 output += lastLiteralsSize;
584 return output;
585 }
586
587 private void copyMatch(final byte[] outputBase, final int fastOutputLimit,
588 int output, final int offset,
589 final int matchOutputLimit, int matchAddress,
590 int matchLength, final int fastMatchOutputLimit) {
591 matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
592 output += SIZE_OF_LONG;
593 matchLength -= SIZE_OF_LONG;
594
595 copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit,
596 matchAddress, matchLength, fastMatchOutputLimit);
597 }
598
599 private void copyMatchTail(final byte[] outputBase, final int fastOutputLimit,
600 int output, final int matchOutputLimit,
601 int matchAddress, final int matchLength,
602 final int fastMatchOutputLimit) {
603
604
605
606
607 if (matchOutputLimit < fastMatchOutputLimit) {
608 int copied = 0;
609 do {
610 putLong(outputBase, output, getLong(outputBase, matchAddress));
611 output += SIZE_OF_LONG;
612 matchAddress += SIZE_OF_LONG;
613 copied += SIZE_OF_LONG;
614 } while (copied < matchLength);
615 } else {
616 while (output < fastOutputLimit) {
617 putLong(outputBase, output, getLong(outputBase, matchAddress));
618 matchAddress += SIZE_OF_LONG;
619 output += SIZE_OF_LONG;
620 }
621
622 while (output < matchOutputLimit) {
623 outputBase[output++] = outputBase[matchAddress++];
624 }
625 }
626 }
627
628 private int copyMatchHead(final byte[] outputBase, final int output,
629 final int offset, int matchAddress) {
630
631 if (offset < 8) {
632
633 final int increment32 = DEC_32_TABLE[offset];
634 final int decrement64 = DEC_64_TABLE[offset];
635
636 outputBase[output] = outputBase[matchAddress];
637 outputBase[output + 1] = outputBase[matchAddress + 1];
638 outputBase[output + 2] = outputBase[matchAddress + 2];
639 outputBase[output + 3] = outputBase[matchAddress + 3];
640 matchAddress += increment32;
641
642 putInt(outputBase, output + 4, getInt(outputBase, matchAddress));
643 matchAddress -= decrement64;
644 } else {
645 putLong(outputBase, output, getLong(outputBase, matchAddress));
646 matchAddress += SIZE_OF_LONG;
647 }
648 return matchAddress;
649 }
650
651 private int copyLiterals(final byte[] outputBase, final byte[] literalsBase,
652 int output, final int literalsInput,
653 final int literalOutputLimit) {
654 int literalInput = literalsInput;
655 do {
656 putLong(outputBase, output, getLong(literalsBase, literalInput));
657 output += SIZE_OF_LONG;
658 literalInput += SIZE_OF_LONG;
659 } while (output < literalOutputLimit);
660 output = literalOutputLimit;
661 return output;
662 }
663
664 private int computeMatchLengthTable(final int matchLengthType,
665 final byte[] inputBase, int input,
666 final int inputLimit) {
667 switch (matchLengthType) {
668 case SEQUENCE_ENCODING_RLE:
669 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
670
671 final byte value = inputBase[input++];
672 verify(value <= MAX_MATCH_LENGTH_SYMBOL, input,
673 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
674
675 FseTableReader.initializeRleTable(matchLengthTable, value);
676 currentMatchLengthTable = matchLengthTable;
677 break;
678 case SEQUENCE_ENCODING_BASIC:
679 currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
680 break;
681 case SEQUENCE_ENCODING_REPEAT:
682 verify(currentMatchLengthTable != null, input,
683 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
684 break;
685 case SEQUENCE_ENCODING_COMPRESSED:
686 input +=
687 fse.readFseTable(matchLengthTable, inputBase, input, inputLimit,
688 MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
689 currentMatchLengthTable = matchLengthTable;
690 break;
691 default:
692 throw fail(input, "Invalid match length encoding type");
693 }
694 return input;
695 }
696
697 private int computeOffsetsTable(final int offsetCodesType,
698 final byte[] inputBase, int input,
699 final int inputLimit) {
700 switch (offsetCodesType) {
701 case SEQUENCE_ENCODING_RLE:
702 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
703
704 final byte value = inputBase[input++];
705 verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input,
706 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
707
708 FseTableReader.initializeRleTable(offsetCodesTable, value);
709 currentOffsetCodesTable = offsetCodesTable;
710 break;
711 case SEQUENCE_ENCODING_BASIC:
712 currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
713 break;
714 case SEQUENCE_ENCODING_REPEAT:
715 verify(currentOffsetCodesTable != null, input,
716 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
717 break;
718 case SEQUENCE_ENCODING_COMPRESSED:
719 input +=
720 fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit,
721 DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
722 currentOffsetCodesTable = offsetCodesTable;
723 break;
724 default:
725 throw fail(input, "Invalid offset code encoding type");
726 }
727 return input;
728 }
729
730 private int computeLiteralsTable(final int literalsLengthType,
731 final byte[] inputBase, int input,
732 final int inputLimit) {
733 switch (literalsLengthType) {
734 case SEQUENCE_ENCODING_RLE:
735 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
736
737 final byte value = inputBase[input++];
738 verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input,
739 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
740
741 FseTableReader.initializeRleTable(literalsLengthTable, value);
742 currentLiteralsLengthTable = literalsLengthTable;
743 break;
744 case SEQUENCE_ENCODING_BASIC:
745 currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
746 break;
747 case SEQUENCE_ENCODING_REPEAT:
748 verify(currentLiteralsLengthTable != null, input,
749 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
750 break;
751 case SEQUENCE_ENCODING_COMPRESSED:
752 input +=
753 fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit,
754 MAX_LITERALS_LENGTH_SYMBOL,
755 LITERAL_LENGTH_TABLE_LOG);
756 currentLiteralsLengthTable = literalsLengthTable;
757 break;
758 default:
759 throw fail(input, "Invalid literals length encoding type");
760 }
761 return input;
762 }
763
764 private void executeLastSequence(final byte[] outputBase, int output,
765 final int literalOutputLimit,
766 final int matchOutputLimit,
767 final int fastOutputLimit, int literalInput,
768 int matchAddress) {
769
770 if (output < fastOutputLimit) {
771
772 do {
773 putLong(outputBase, output, getLong(literalsBase, literalInput));
774 output += SIZE_OF_LONG;
775 literalInput += SIZE_OF_LONG;
776 } while (output < fastOutputLimit);
777
778 literalInput -= output - fastOutputLimit;
779 output = fastOutputLimit;
780 }
781
782 while (output < literalOutputLimit) {
783 outputBase[output] = literalsBase[literalInput];
784 output++;
785 literalInput++;
786 }
787
788
789 while (output < matchOutputLimit) {
790 outputBase[output] = outputBase[matchAddress];
791 output++;
792 matchAddress++;
793 }
794 }
795
796 private int decodeCompressedLiterals(final byte[] inputBase,
797 final int inputAddress,
798 final int blockSize,
799 final int literalsBlockType) {
800 int input = inputAddress;
801 verify(blockSize >= 5, input, NOT_ENOUGH_INPUT_BYTES);
802
803
804 final int compressedSize;
805 final int uncompressedSize;
806 boolean singleStream = false;
807 final int headerSize;
808 final int type = (inputBase[input] >> 2) & 0x3;
809 switch (type) {
810 case 0:
811 singleStream = true;
812 case 1: {
813 final int header = getInt(inputBase, input);
814
815 headerSize = 3;
816 uncompressedSize = (header >>> 4) & mask(10);
817 compressedSize = (header >>> 14) & mask(10);
818 break;
819 }
820 case 2: {
821 final int header = getInt(inputBase, input);
822
823 headerSize = 4;
824 uncompressedSize = (header >>> 4) & mask(14);
825 compressedSize = (header >>> 18) & mask(14);
826 break;
827 }
828 case 3: {
829
830 final long header = inputBase[input] & 0xFF |
831 (getInt(inputBase, input + 1) & 0xFFFFFFFFL) << 8;
832
833 headerSize = 5;
834 uncompressedSize = (int) ((header >>> 4) & mask(18));
835 compressedSize = (int) ((header >>> 22) & mask(18));
836 break;
837 }
838 default:
839 throw fail(input, "Invalid literals header size type");
840 }
841
842 verify(uncompressedSize <= MAX_BLOCK_SIZE, input,
843 "Block exceeds maximum size");
844 verify(headerSize + compressedSize <= blockSize, input, INPUT_IS_CORRUPTED);
845
846 input += headerSize;
847
848 final int inputLimit = input + compressedSize;
849 if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
850 input += huffman.readTable(inputBase, input, compressedSize);
851 }
852
853 literalsBase = literals;
854 literalsAddress = 0;
855 literalsLimit = uncompressedSize;
856
857 if (singleStream) {
858 huffman.decodeSingleStream(inputBase, input, inputLimit, literals,
859 literalsAddress, literalsLimit);
860 } else {
861 huffman.decode4Streams(inputBase, input, inputLimit, literals,
862 literalsAddress, literalsLimit);
863 }
864
865 return headerSize + compressedSize;
866 }
867
868 private int decodeRleLiterals(final byte[] inputBase, final int inputAddress,
869 final int blockSize) {
870 int input = inputAddress;
871 final int outputSize;
872
873 final int type = (inputBase[input] >> 2) & 0x3;
874 switch (type) {
875 case 0:
876 case 2:
877 outputSize = (inputBase[input] & 0xFF) >>> 3;
878 input++;
879 break;
880 case 1:
881 outputSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
882 input += 2;
883 break;
884 case 3:
885
886 verify(blockSize >= SIZE_OF_INT, input, NOT_ENOUGH_INPUT_BYTES);
887 outputSize = (getInt(inputBase, input) & 0xFFFFFF) >>> 4;
888 input += 3;
889 break;
890 default:
891 throw fail(input, "Invalid RLE literals header encoding type");
892 }
893
894 verify(outputSize <= MAX_BLOCK_SIZE, input,
895 "Output exceeds maximum block size");
896
897 final byte value = inputBase[input++];
898 Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
899
900 literalsBase = literals;
901 literalsAddress = 0;
902 literalsLimit = outputSize;
903
904 return input - inputAddress;
905 }
906
907 private int decodeRawLiterals(final byte[] inputBase, final int inputAddress,
908 final int inputLimit) {
909 int input = inputAddress;
910 final int type = (inputBase[input] >> 2) & 0x3;
911
912 final int literalSize;
913 switch (type) {
914 case 0:
915 case 2:
916 literalSize = (inputBase[input] & 0xFF) >>> 3;
917 input++;
918 break;
919 case 1:
920 literalSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
921 input += 2;
922 break;
923 case 3:
924
925 final int header = ((inputBase[input] & 0xFF) |
926 ((getShort(inputBase, input + 1) & 0xFFFF) << 8));
927
928 literalSize = header >>> 4;
929 input += 3;
930 break;
931 default:
932 throw fail(input, "Invalid raw literals header encoding type");
933 }
934
935 verify(input + literalSize <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
936
937
938
939 if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
940 literalsBase = literals;
941 literalsAddress = 0;
942 literalsLimit = literalSize;
943
944 copyMemory(inputBase, input, literals, literalsAddress, literalSize);
945 Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
946 } else {
947 literalsBase = inputBase;
948 literalsAddress = input;
949 literalsLimit = literalsAddress + literalSize;
950 }
951 input += literalSize;
952
953 return input - inputAddress;
954 }
955
956 static FrameHeader readFrameHeader(final byte[] inputBase,
957 final int inputAddress,
958 final int inputLimit) {
959 int input = inputAddress;
960 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
961
962 final int frameHeaderDescriptor = inputBase[input++] & 0xFF;
963 final boolean singleSegment = (frameHeaderDescriptor & 0x20) != 0;
964 final int dictionaryDescriptor = frameHeaderDescriptor & 0x3;
965 final int contentSizeDescriptor = frameHeaderDescriptor >>> 6;
966
967 final int headerSize = 1 + (singleSegment? 0 : 1) +
968 (dictionaryDescriptor == 0? 0 :
969 (1 << (dictionaryDescriptor - 1))) +
970 (contentSizeDescriptor == 0? (singleSegment? 1 : 0) :
971 (1 << contentSizeDescriptor));
972
973 verify(headerSize <= inputLimit - inputAddress, input,
974 NOT_ENOUGH_INPUT_BYTES);
975
976
977 int windowSize = -1;
978 if (!singleSegment) {
979 final int windowDescriptor = inputBase[input++] & 0xFF;
980 final int exponent = windowDescriptor >>> 3;
981 final int mantissa = windowDescriptor & 0x7;
982
983 final int base = 1 << (MIN_WINDOW_LOG + exponent);
984 windowSize = base + (base / 8) * mantissa;
985 }
986
987
988 int dictionaryId = -1;
989 switch (dictionaryDescriptor) {
990 case 1:
991 dictionaryId = inputBase[input] & 0xFF;
992 input += SIZE_OF_BYTE;
993 break;
994 case 2:
995 dictionaryId = getShort(inputBase, input) & 0xFFFF;
996 input += SIZE_OF_SHORT;
997 break;
998 case 3:
999 dictionaryId = getInt(inputBase, input);
1000 input += SIZE_OF_INT;
1001 break;
1002 }
1003 verify(dictionaryId == -1, input, "Custom dictionaries not supported");
1004
1005
1006 int contentSize = -1;
1007 switch (contentSizeDescriptor) {
1008 case 0:
1009 if (singleSegment) {
1010 contentSize = inputBase[input] & 0xFF;
1011 input += SIZE_OF_BYTE;
1012 }
1013 break;
1014 case 1:
1015 contentSize = getShort(inputBase, input) & 0xFFFF;
1016 contentSize += 256;
1017 input += SIZE_OF_SHORT;
1018 break;
1019 case 2:
1020 contentSize = getInt(inputBase, input);
1021 input += SIZE_OF_INT;
1022 break;
1023 case 3:
1024 contentSize = (int) getLong(inputBase, input);
1025 input += SIZE_OF_LONG;
1026 break;
1027 }
1028
1029 final boolean hasChecksum = (frameHeaderDescriptor & 0x4) != 0;
1030
1031 return new FrameHeader(input - inputAddress, windowSize, contentSize,
1032 dictionaryId, hasChecksum);
1033 }
1034
1035 public static int getDecompressedSize(final byte[] inputBase,
1036 final int inputAddress,
1037 final int inputLimit) {
1038 int input = inputAddress;
1039 input += verifyMagic(inputBase, input, inputLimit);
1040 return readFrameHeader(inputBase, input, inputLimit).contentSize;
1041 }
1042
1043 static int verifyMagic(final byte[] inputBase, final int inputAddress,
1044 final int inputLimit) {
1045 verify(inputLimit - inputAddress >= 4, inputAddress,
1046 NOT_ENOUGH_INPUT_BYTES);
1047
1048 final int magic = getInt(inputBase, inputAddress);
1049 if (magic != MAGIC_NUMBER) {
1050 if (magic == V07_MAGIC_NUMBER) {
1051 throw new MalformedInputException(inputAddress,
1052 "Data encoded in unsupported ZSTD v0.7 format");
1053 }
1054 throw new MalformedInputException(inputAddress, "Invalid magic prefix: " +
1055 Integer.toHexString(
1056 magic));
1057 }
1058
1059 return SIZE_OF_INT;
1060 }
1061 }