1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdunsafe;
35
36 import org.waarp.compress.MalformedInputException;
37
38 import java.util.Arrays;
39
40 import static org.waarp.compress.zstdunsafe.BitInputStream.*;
41 import static org.waarp.compress.zstdunsafe.Constants.*;
42 import static org.waarp.compress.zstdunsafe.UnsafeUtil.*;
43 import static org.waarp.compress.zstdunsafe.Util.*;
44 import static sun.misc.Unsafe.*;
45
46 class ZstdFrameDecompressor {
47 private static final int[] DEC_32_TABLE = { 4, 1, 2, 1, 4, 4, 4, 4 };
48 private static final int[] DEC_64_TABLE = { 0, 0, 0, -1, 0, 1, 2, 3 };
49
50 private static final int V07_MAGIC_NUMBER = 0xFD2FB527;
51
52 private static final int MAX_WINDOW_SIZE = 1 << 23;
53
54 private static final int[] LITERALS_LENGTH_BASE = {
55 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24,
56 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000,
57 0x4000, 0x8000, 0x10000
58 };
59
60 private static final int[] MATCH_LENGTH_BASE = {
61 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
62 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47,
63 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003,
64 0x4003, 0x8003, 0x10003
65 };
66
67 private static final int[] OFFSET_CODES_BASE = {
68 0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD,
69 0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD,
70 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD,
71 0xFFFFFFD
72 };
73
74 private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE =
75 new FiniteStateEntropy.Table(6, new int[] {
76 0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0,
77 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0,
78 0, 48, 16, 32, 32, 32, 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32,
79 0, 0, 0, 0
80 }, new byte[] {
81 0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26,
82 27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23,
83 25, 25, 26, 28, 30, 0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20,
84 21, 23, 24, 35, 34, 33, 32
85 }, new byte[] {
86 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4,
87 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5,
88 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6
89 });
90
91 private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE =
92 new FiniteStateEntropy.Table(5, new int[] {
93 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0,
94 0, 16, 0, 0, 0, 0, 0, 0, 0
95 }, new byte[] {
96 0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4,
97 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24
98 }, new byte[] {
99 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5,
100 5, 4, 5, 5, 5, 5, 5, 5, 5
101 });
102
103 private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE =
104 new FiniteStateEntropy.Table(6, new int[] {
105 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
106 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48,
107 16, 32, 32, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
108 }, new byte[] {
109 0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39,
110 41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34,
111 36, 38, 40, 42, 44, 1, 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29,
112 52, 51, 50, 49, 48, 47, 46
113 }, new byte[] {
114 6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4,
115 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4,
116 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
117 });
118 public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
119 public static final String OUTPUT_BUFFER_TOO_SMALL =
120 "Output buffer too small";
121 public static final String EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT =
122 "Expected match length table to be present";
123 public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
124 public static final String VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE =
125 "Value exceeds expected maximum value";
126
127 private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG];
128
129
130
131 private Object literalsBase;
132 private long literalsAddress;
133 private long literalsLimit;
134
135 private final int[] previousOffsets = new int[3];
136
137 private final FiniteStateEntropy.Table literalsLengthTable =
138 new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
139 private final FiniteStateEntropy.Table offsetCodesTable =
140 new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
141 private final FiniteStateEntropy.Table matchLengthTable =
142 new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);
143
144 private FiniteStateEntropy.Table currentLiteralsLengthTable;
145 private FiniteStateEntropy.Table currentOffsetCodesTable;
146 private FiniteStateEntropy.Table currentMatchLengthTable;
147
148 private final Huffman huffman = new Huffman();
149 private final FseTableReader fse = new FseTableReader();
150
151 public int decompress(final Object inputBase, final long inputAddress,
152 final long inputLimit, final Object outputBase,
153 final long outputAddress, final long outputLimit) {
154 if (outputAddress == outputLimit) {
155 return 0;
156 }
157
158 long input = inputAddress;
159 long output = outputAddress;
160
161 while (input < inputLimit) {
162 reset();
163 final long outputStart = output;
164 input += verifyMagic(inputBase, inputAddress, inputLimit);
165
166 final FrameHeader frameHeader =
167 readFrameHeader(inputBase, input, inputLimit);
168 input += frameHeader.headerSize;
169
170 boolean lastBlock;
171 do {
172 verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input,
173 NOT_ENOUGH_INPUT_BYTES);
174
175
176 final int header = UNSAFE.getInt(inputBase, input) & 0xFFFFFF;
177 input += SIZE_OF_BLOCK_HEADER;
178
179 lastBlock = (header & 1) != 0;
180 final int blockType = (header >>> 1) & 0x3;
181 final int blockSize = (header >>> 3) & 0x1FFFFF;
182
183 final int decodedSize;
184 switch (blockType) {
185 case RAW_BLOCK:
186 verify(inputAddress + blockSize <= inputLimit, input,
187 NOT_ENOUGH_INPUT_BYTES);
188 decodedSize =
189 decodeRawBlock(inputBase, input, blockSize, outputBase, output,
190 outputLimit);
191 input += blockSize;
192 break;
193 case RLE_BLOCK:
194 verify(inputAddress + 1 <= inputLimit, input,
195 NOT_ENOUGH_INPUT_BYTES);
196 decodedSize =
197 decodeRleBlock(blockSize, inputBase, input, outputBase, output,
198 outputLimit);
199 input += 1;
200 break;
201 case COMPRESSED_BLOCK:
202 verify(inputAddress + blockSize <= inputLimit, input,
203 NOT_ENOUGH_INPUT_BYTES);
204 decodedSize =
205 decodeCompressedBlock(inputBase, input, blockSize, outputBase,
206 output, outputLimit,
207 frameHeader.windowSize, outputAddress);
208 input += blockSize;
209 break;
210 default:
211 throw fail(input, "Invalid block type");
212 }
213
214 output += decodedSize;
215 } while (!lastBlock);
216
217 if (frameHeader.hasChecksum) {
218 final int decodedFrameSize = (int) (output - outputStart);
219
220 final long hash =
221 XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
222
223 final int checksum = UNSAFE.getInt(inputBase, input);
224 if (checksum != (int) hash) {
225 throw new MalformedInputException(input, String.format(
226 "Bad checksum. Expected: %s, actual: %s",
227 Integer.toHexString(checksum), Integer.toHexString((int) hash)));
228 }
229
230 input += SIZE_OF_INT;
231 }
232 }
233
234 return (int) (output - outputAddress);
235 }
236
237 private void reset() {
238 previousOffsets[0] = 1;
239 previousOffsets[1] = 4;
240 previousOffsets[2] = 8;
241
242 currentLiteralsLengthTable = null;
243 currentOffsetCodesTable = null;
244 currentMatchLengthTable = null;
245 }
246
247 private static int decodeRawBlock(final Object inputBase,
248 final long inputAddress,
249 final int blockSize,
250 final Object outputBase,
251 final long outputAddress,
252 final long outputLimit) {
253 verify(outputAddress + blockSize <= outputLimit, inputAddress,
254 OUTPUT_BUFFER_TOO_SMALL);
255
256 UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress,
257 blockSize);
258 return blockSize;
259 }
260
261 private static int decodeRleBlock(final int size, final Object inputBase,
262 final long inputAddress,
263 final Object outputBase,
264 final long outputAddress,
265 final long outputLimit) {
266 verify(outputAddress + size <= outputLimit, inputAddress,
267 OUTPUT_BUFFER_TOO_SMALL);
268
269 long output = outputAddress;
270 final long value = UNSAFE.getByte(inputBase, inputAddress) & 0xFFL;
271
272 int remaining = size;
273 if (remaining >= SIZE_OF_LONG) {
274 final long packed =
275 value | (value << 8) | (value << 16) | (value << 24) | (value << 32) |
276 (value << 40) | (value << 48) | (value << 56);
277
278 do {
279 UNSAFE.putLong(outputBase, output, packed);
280 output += SIZE_OF_LONG;
281 remaining -= SIZE_OF_LONG;
282 } while (remaining >= SIZE_OF_LONG);
283 }
284
285 for (int i = 0; i < remaining; i++) {
286 UNSAFE.putByte(outputBase, output, (byte) value);
287 output++;
288 }
289
290 return size;
291 }
292
293 private int decodeCompressedBlock(final Object inputBase,
294 final long inputAddress,
295 final int blockSize,
296 final Object outputBase,
297 final long outputAddress,
298 final long outputLimit,
299 final int windowSize,
300 final long outputAbsoluteBaseAddress) {
301 final long inputLimit = inputAddress + blockSize;
302 long input = inputAddress;
303
304 verify(blockSize <= MAX_BLOCK_SIZE, input,
305 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
306 verify(blockSize >= MIN_BLOCK_SIZE, input,
307 "Compressed block size too small");
308
309
310 final int literalsBlockType = UNSAFE.getByte(inputBase, input) & 0x3;
311
312 switch (literalsBlockType) {
313 case RAW_LITERALS_BLOCK: {
314 input += decodeRawLiterals(inputBase, input, inputLimit);
315 break;
316 }
317 case RLE_LITERALS_BLOCK: {
318 input += decodeRleLiterals(inputBase, input, blockSize);
319 break;
320 }
321 case TREELESS_LITERALS_BLOCK:
322 verify(huffman.isLoaded(), input, "Dictionary is corrupted");
323 case COMPRESSED_LITERALS_BLOCK: {
324 input += decodeCompressedLiterals(inputBase, input, blockSize,
325 literalsBlockType);
326 break;
327 }
328 default:
329 throw fail(input, "Invalid literals block encoding type");
330 }
331
332 verify(windowSize <= MAX_WINDOW_SIZE, input,
333 "Window size too large (not yet supported)");
334
335 return decompressSequences(inputBase, input, inputAddress + blockSize,
336 outputBase, outputAddress, outputLimit,
337 literalsBase, literalsAddress, literalsLimit,
338 outputAbsoluteBaseAddress);
339 }
340
341 private int decompressSequences(final Object inputBase,
342 final long inputAddress,
343 final long inputLimit,
344 final Object outputBase,
345 final long outputAddress,
346 final long outputLimit,
347 final Object literalsBase,
348 final long literalsAddress,
349 final long literalsLimit,
350 final long outputAbsoluteBaseAddress) {
351 final long fastOutputLimit = outputLimit - SIZE_OF_LONG;
352 final long fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
353
354 long input = inputAddress;
355 long output = outputAddress;
356
357 long literalsInput = literalsAddress;
358
359 final int size = (int) (inputLimit - inputAddress);
360 verify(size >= MIN_SEQUENCES_SIZE, input, NOT_ENOUGH_INPUT_BYTES);
361
362
363 int sequenceCount = UNSAFE.getByte(inputBase, input++) & 0xFF;
364 if (sequenceCount != 0) {
365 if (sequenceCount == 255) {
366 verify(input + SIZE_OF_SHORT <= inputLimit, input,
367 NOT_ENOUGH_INPUT_BYTES);
368 sequenceCount = (UNSAFE.getShort(inputBase, input) & 0xFFFF) +
369 LONG_NUMBER_OF_SEQUENCES;
370 input += SIZE_OF_SHORT;
371 } else if (sequenceCount > 127) {
372 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
373 sequenceCount = ((sequenceCount - 128) << 8) +
374 (UNSAFE.getByte(inputBase, input++) & 0xFF);
375 }
376
377 verify(input + SIZE_OF_INT <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
378
379 final byte type = UNSAFE.getByte(inputBase, input++);
380
381 final int literalsLengthType = (type & 0xFF) >>> 6;
382 final int offsetCodesType = (type >>> 4) & 0x3;
383 final int matchLengthType = (type >>> 2) & 0x3;
384
385 input = computeLiteralsTable(literalsLengthType, inputBase, input,
386 inputLimit);
387 input =
388 computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
389 input = computeMatchLengthTable(matchLengthType, inputBase, input,
390 inputLimit);
391
392
393 final BitInputStream.Initializer initializer =
394 new BitInputStream.Initializer(inputBase, input, inputLimit);
395 initializer.initialize();
396 int bitsConsumed = initializer.getBitsConsumed();
397 long bits = initializer.getBits();
398 long currentAddress = initializer.getCurrentAddress();
399
400 final FiniteStateEntropy.Table literalsLengthTable1 =
401 this.currentLiteralsLengthTable;
402 final FiniteStateEntropy.Table offsetCodesTable1 =
403 this.currentOffsetCodesTable;
404 final FiniteStateEntropy.Table matchLengthTable1 =
405 this.currentMatchLengthTable;
406
407 int literalsLengthState =
408 (int) peekBits(bitsConsumed, bits, literalsLengthTable1.log2Size);
409 bitsConsumed += literalsLengthTable1.log2Size;
410
411 int offsetCodesState =
412 (int) peekBits(bitsConsumed, bits, offsetCodesTable1.log2Size);
413 bitsConsumed += offsetCodesTable1.log2Size;
414
415 int matchLengthState =
416 (int) peekBits(bitsConsumed, bits, matchLengthTable1.log2Size);
417 bitsConsumed += matchLengthTable1.log2Size;
418
419 final int[] previousOffsets1 = this.previousOffsets;
420
421 final byte[] literalsLengthNumbersOfBits =
422 literalsLengthTable1.numberOfBits;
423 final int[] literalsLengthNewStates = literalsLengthTable1.newState;
424 final byte[] literalsLengthSymbols = literalsLengthTable1.symbol;
425
426 final byte[] matchLengthNumbersOfBits = matchLengthTable1.numberOfBits;
427 final int[] matchLengthNewStates = matchLengthTable1.newState;
428 final byte[] matchLengthSymbols = matchLengthTable1.symbol;
429
430 final byte[] offsetCodesNumbersOfBits = offsetCodesTable1.numberOfBits;
431 final int[] offsetCodesNewStates = offsetCodesTable1.newState;
432 final byte[] offsetCodesSymbols = offsetCodesTable1.symbol;
433
434 while (sequenceCount > 0) {
435 sequenceCount--;
436
437 final BitInputStream.Loader loader =
438 new BitInputStream.Loader(inputBase, input, currentAddress, bits,
439 bitsConsumed);
440 loader.load();
441 bitsConsumed = loader.getBitsConsumed();
442 bits = loader.getBits();
443 currentAddress = loader.getCurrentAddress();
444 if (loader.isOverflow()) {
445 verify(sequenceCount == 0, input, "Not all sequences were consumed");
446 break;
447 }
448
449
450 final int literalsLengthCode =
451 literalsLengthSymbols[literalsLengthState];
452 final int matchLengthCode = matchLengthSymbols[matchLengthState];
453 final int offsetCode = offsetCodesSymbols[offsetCodesState];
454
455 final int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
456 final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
457
458 int offset = OFFSET_CODES_BASE[offsetCode];
459 if (offsetCode > 0) {
460 offset += peekBits(bitsConsumed, bits, offsetCode);
461 bitsConsumed += offsetCode;
462 }
463
464 if (offsetCode <= 1) {
465 if (literalsLengthCode == 0) {
466 offset++;
467 }
468
469 if (offset != 0) {
470 int temp;
471 if (offset == 3) {
472 temp = previousOffsets1[0] - 1;
473 } else {
474 temp = previousOffsets1[offset];
475 }
476
477 if (temp == 0) {
478 temp = 1;
479 }
480
481 if (offset != 1) {
482 previousOffsets1[2] = previousOffsets1[1];
483 }
484 previousOffsets1[1] = previousOffsets1[0];
485 previousOffsets1[0] = temp;
486
487 offset = temp;
488 } else {
489 offset = previousOffsets1[0];
490 }
491 } else {
492 previousOffsets1[2] = previousOffsets1[1];
493 previousOffsets1[1] = previousOffsets1[0];
494 previousOffsets1[0] = offset;
495 }
496
497 int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
498 if (matchLengthCode > 31) {
499 matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
500 bitsConsumed += matchLengthBits;
501 }
502
503 int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
504 if (literalsLengthCode > 15) {
505 literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
506 bitsConsumed += literalsLengthBits;
507 }
508
509 final int totalBits = literalsLengthBits + matchLengthBits + offsetCode;
510 if (totalBits > 64 - 7 -
511 (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG +
512 OFFSET_TABLE_LOG)) {
513 final BitInputStream.Loader loader1 =
514 new BitInputStream.Loader(inputBase, input, currentAddress, bits,
515 bitsConsumed);
516 loader1.load();
517
518 bitsConsumed = loader1.getBitsConsumed();
519 bits = loader1.getBits();
520 currentAddress = loader1.getCurrentAddress();
521 }
522
523 int numberOfBits;
524
525 numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
526 literalsLengthState =
527 (int) (literalsLengthNewStates[literalsLengthState] +
528 peekBits(bitsConsumed, bits, numberOfBits));
529 bitsConsumed += numberOfBits;
530
531 numberOfBits = matchLengthNumbersOfBits[matchLengthState];
532 matchLengthState = (int) (matchLengthNewStates[matchLengthState] +
533 peekBits(bitsConsumed, bits,
534 numberOfBits));
535 bitsConsumed += numberOfBits;
536
537 numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
538 offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] +
539 peekBits(bitsConsumed, bits,
540 numberOfBits));
541 bitsConsumed += numberOfBits;
542
543 final long literalOutputLimit = output + literalsLength;
544 final long matchOutputLimit = literalOutputLimit + matchLength;
545
546 verify(matchOutputLimit <= outputLimit, input, OUTPUT_BUFFER_TOO_SMALL);
547 final long literalEnd = literalsInput + literalsLength;
548 verify(literalEnd <= literalsLimit, input, INPUT_IS_CORRUPTED);
549
550 final long matchAddress = literalOutputLimit - offset;
551 verify(matchAddress >= outputAbsoluteBaseAddress, input,
552 INPUT_IS_CORRUPTED);
553
554 if (literalOutputLimit > fastOutputLimit) {
555 executeLastSequence(outputBase, output, literalOutputLimit,
556 matchOutputLimit, fastOutputLimit, literalsInput,
557 matchAddress);
558 } else {
559
560
561 output = copyLiterals(outputBase, literalsBase, output, literalsInput,
562 literalOutputLimit);
563 copyMatch(outputBase, fastOutputLimit, output, offset,
564 matchOutputLimit, matchAddress, matchLength,
565 fastMatchOutputLimit);
566 }
567 output = matchOutputLimit;
568 literalsInput = literalEnd;
569 }
570 }
571
572
573 output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output,
574 literalsInput);
575
576 return (int) (output - outputAddress);
577 }
578
579 private long copyLastLiteral(final Object outputBase,
580 final Object literalsBase,
581 final long literalsLimit, long output,
582 final long literalsInput) {
583 final long lastLiteralsSize = literalsLimit - literalsInput;
584 UNSAFE.copyMemory(literalsBase, literalsInput, outputBase, output,
585 lastLiteralsSize);
586 output += lastLiteralsSize;
587 return output;
588 }
589
590 private void copyMatch(final Object outputBase, final long fastOutputLimit,
591 long output, final int offset,
592 final long matchOutputLimit, long matchAddress,
593 int matchLength, final long fastMatchOutputLimit) {
594 matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
595 output += SIZE_OF_LONG;
596 matchLength -= SIZE_OF_LONG;
597
598 copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit,
599 matchAddress, matchLength, fastMatchOutputLimit);
600 }
601
602 private void copyMatchTail(final Object outputBase,
603 final long fastOutputLimit, long output,
604 final long matchOutputLimit, long matchAddress,
605 final int matchLength,
606 final long fastMatchOutputLimit) {
607
608
609
610
611 if (matchOutputLimit < fastMatchOutputLimit) {
612 int copied = 0;
613 do {
614 UNSAFE.putLong(outputBase, output,
615 UNSAFE.getLong(outputBase, matchAddress));
616 output += SIZE_OF_LONG;
617 matchAddress += SIZE_OF_LONG;
618 copied += SIZE_OF_LONG;
619 } while (copied < matchLength);
620 } else {
621 while (output < fastOutputLimit) {
622 UNSAFE.putLong(outputBase, output,
623 UNSAFE.getLong(outputBase, matchAddress));
624 matchAddress += SIZE_OF_LONG;
625 output += SIZE_OF_LONG;
626 }
627
628 while (output < matchOutputLimit) {
629 UNSAFE.putByte(outputBase, output++,
630 UNSAFE.getByte(outputBase, matchAddress++));
631 }
632 }
633 }
634
635 private long copyMatchHead(final Object outputBase, final long output,
636 final int offset, long matchAddress) {
637
638 if (offset < 8) {
639
640 final int increment32 = DEC_32_TABLE[offset];
641 final int decrement64 = DEC_64_TABLE[offset];
642
643 UNSAFE.putByte(outputBase, output,
644 UNSAFE.getByte(outputBase, matchAddress));
645 UNSAFE.putByte(outputBase, output + 1,
646 UNSAFE.getByte(outputBase, matchAddress + 1));
647 UNSAFE.putByte(outputBase, output + 2,
648 UNSAFE.getByte(outputBase, matchAddress + 2));
649 UNSAFE.putByte(outputBase, output + 3,
650 UNSAFE.getByte(outputBase, matchAddress + 3));
651 matchAddress += increment32;
652
653 UNSAFE.putInt(outputBase, output + 4,
654 UNSAFE.getInt(outputBase, matchAddress));
655 matchAddress -= decrement64;
656 } else {
657 UNSAFE.putLong(outputBase, output,
658 UNSAFE.getLong(outputBase, matchAddress));
659 matchAddress += SIZE_OF_LONG;
660 }
661 return matchAddress;
662 }
663
664 private long copyLiterals(final Object outputBase, final Object literalsBase,
665 long output, final long literalsInput,
666 final long literalOutputLimit) {
667 long literalInput = literalsInput;
668 do {
669 UNSAFE.putLong(outputBase, output,
670 UNSAFE.getLong(literalsBase, literalInput));
671 output += SIZE_OF_LONG;
672 literalInput += SIZE_OF_LONG;
673 } while (output < literalOutputLimit);
674 output = literalOutputLimit;
675 return output;
676 }
677
678 private long computeMatchLengthTable(final int matchLengthType,
679 final Object inputBase, long input,
680 final long inputLimit) {
681 switch (matchLengthType) {
682 case SEQUENCE_ENCODING_RLE:
683 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
684
685 final byte value = UNSAFE.getByte(inputBase, input++);
686 verify(value <= MAX_MATCH_LENGTH_SYMBOL, input,
687 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
688
689 FseTableReader.initializeRleTable(matchLengthTable, value);
690 currentMatchLengthTable = matchLengthTable;
691 break;
692 case SEQUENCE_ENCODING_BASIC:
693 currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
694 break;
695 case SEQUENCE_ENCODING_REPEAT:
696 verify(currentMatchLengthTable != null, input,
697 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
698 break;
699 case SEQUENCE_ENCODING_COMPRESSED:
700 input +=
701 fse.readFseTable(matchLengthTable, inputBase, input, inputLimit,
702 MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
703 currentMatchLengthTable = matchLengthTable;
704 break;
705 default:
706 throw fail(input, "Invalid match length encoding type");
707 }
708 return input;
709 }
710
711 private long computeOffsetsTable(final int offsetCodesType,
712 final Object inputBase, long input,
713 final long inputLimit) {
714 switch (offsetCodesType) {
715 case SEQUENCE_ENCODING_RLE:
716 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
717
718 final byte value = UNSAFE.getByte(inputBase, input++);
719 verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input,
720 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
721
722 FseTableReader.initializeRleTable(offsetCodesTable, value);
723 currentOffsetCodesTable = offsetCodesTable;
724 break;
725 case SEQUENCE_ENCODING_BASIC:
726 currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
727 break;
728 case SEQUENCE_ENCODING_REPEAT:
729 verify(currentOffsetCodesTable != null, input,
730 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
731 break;
732 case SEQUENCE_ENCODING_COMPRESSED:
733 input +=
734 fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit,
735 DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
736 currentOffsetCodesTable = offsetCodesTable;
737 break;
738 default:
739 throw fail(input, "Invalid offset code encoding type");
740 }
741 return input;
742 }
743
744 private long computeLiteralsTable(final int literalsLengthType,
745 final Object inputBase, long input,
746 final long inputLimit) {
747 switch (literalsLengthType) {
748 case SEQUENCE_ENCODING_RLE:
749 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
750
751 final byte value = UNSAFE.getByte(inputBase, input++);
752 verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input,
753 VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
754
755 FseTableReader.initializeRleTable(literalsLengthTable, value);
756 currentLiteralsLengthTable = literalsLengthTable;
757 break;
758 case SEQUENCE_ENCODING_BASIC:
759 currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
760 break;
761 case SEQUENCE_ENCODING_REPEAT:
762 verify(currentLiteralsLengthTable != null, input,
763 EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
764 break;
765 case SEQUENCE_ENCODING_COMPRESSED:
766 input +=
767 fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit,
768 MAX_LITERALS_LENGTH_SYMBOL,
769 LITERAL_LENGTH_TABLE_LOG);
770 currentLiteralsLengthTable = literalsLengthTable;
771 break;
772 default:
773 throw fail(input, "Invalid literals length encoding type");
774 }
775 return input;
776 }
777
778 private void executeLastSequence(final Object outputBase, long output,
779 final long literalOutputLimit,
780 final long matchOutputLimit,
781 final long fastOutputLimit,
782 long literalInput, long matchAddress) {
783
784 if (output < fastOutputLimit) {
785
786 do {
787 UNSAFE.putLong(outputBase, output,
788 UNSAFE.getLong(literalsBase, literalInput));
789 output += SIZE_OF_LONG;
790 literalInput += SIZE_OF_LONG;
791 } while (output < fastOutputLimit);
792
793 literalInput -= output - fastOutputLimit;
794 output = fastOutputLimit;
795 }
796
797 while (output < literalOutputLimit) {
798 UNSAFE.putByte(outputBase, output,
799 UNSAFE.getByte(literalsBase, literalInput));
800 output++;
801 literalInput++;
802 }
803
804
805 while (output < matchOutputLimit) {
806 UNSAFE.putByte(outputBase, output,
807 UNSAFE.getByte(outputBase, matchAddress));
808 output++;
809 matchAddress++;
810 }
811 }
812
813 private int decodeCompressedLiterals(final Object inputBase,
814 final long inputAddress,
815 final int blockSize,
816 final int literalsBlockType) {
817 long input = inputAddress;
818 verify(blockSize >= 5, input, NOT_ENOUGH_INPUT_BYTES);
819
820
821 final int compressedSize;
822 final int uncompressedSize;
823 boolean singleStream = false;
824 final int headerSize;
825 final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
826 switch (type) {
827 case 0:
828 singleStream = true;
829 case 1: {
830 final int header = UNSAFE.getInt(inputBase, input);
831
832 headerSize = 3;
833 uncompressedSize = (header >>> 4) & mask(10);
834 compressedSize = (header >>> 14) & mask(10);
835 break;
836 }
837 case 2: {
838 final int header = UNSAFE.getInt(inputBase, input);
839
840 headerSize = 4;
841 uncompressedSize = (header >>> 4) & mask(14);
842 compressedSize = (header >>> 18) & mask(14);
843 break;
844 }
845 case 3: {
846
847 final long header = UNSAFE.getByte(inputBase, input) & 0xFF |
848 (UNSAFE.getInt(inputBase, input + 1) &
849 0xFFFFFFFFL) << 8;
850
851 headerSize = 5;
852 uncompressedSize = (int) ((header >>> 4) & mask(18));
853 compressedSize = (int) ((header >>> 22) & mask(18));
854 break;
855 }
856 default:
857 throw fail(input, "Invalid literals header size type");
858 }
859
860 verify(uncompressedSize <= MAX_BLOCK_SIZE, input,
861 "Block exceeds maximum size");
862 verify(headerSize + compressedSize <= blockSize, input, INPUT_IS_CORRUPTED);
863
864 input += headerSize;
865
866 final long inputLimit = input + compressedSize;
867 if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
868 input += huffman.readTable(inputBase, input, compressedSize);
869 }
870
871 literalsBase = literals;
872 literalsAddress = ARRAY_BYTE_BASE_OFFSET;
873 literalsLimit = ARRAY_BYTE_BASE_OFFSET + uncompressedSize;
874
875 if (singleStream) {
876 huffman.decodeSingleStream(inputBase, input, inputLimit, literals,
877 literalsAddress, literalsLimit);
878 } else {
879 huffman.decode4Streams(inputBase, input, inputLimit, literals,
880 literalsAddress, literalsLimit);
881 }
882
883 return headerSize + compressedSize;
884 }
885
886 private int decodeRleLiterals(final Object inputBase, final long inputAddress,
887 final int blockSize) {
888 long input = inputAddress;
889 final int outputSize;
890
891 final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
892 switch (type) {
893 case 0:
894 case 2:
895 outputSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
896 input++;
897 break;
898 case 1:
899 outputSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
900 input += 2;
901 break;
902 case 3:
903
904 verify(blockSize >= SIZE_OF_INT, input, NOT_ENOUGH_INPUT_BYTES);
905 outputSize = (UNSAFE.getInt(inputBase, input) & 0xFFFFFF) >>> 4;
906 input += 3;
907 break;
908 default:
909 throw fail(input, "Invalid RLE literals header encoding type");
910 }
911
912 verify(outputSize <= MAX_BLOCK_SIZE, input,
913 "Output exceeds maximum block size");
914
915 final byte value = UNSAFE.getByte(inputBase, input++);
916 Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
917
918 literalsBase = literals;
919 literalsAddress = ARRAY_BYTE_BASE_OFFSET;
920 literalsLimit = ARRAY_BYTE_BASE_OFFSET + outputSize;
921
922 return (int) (input - inputAddress);
923 }
924
925 private int decodeRawLiterals(final Object inputBase, final long inputAddress,
926 final long inputLimit) {
927 long input = inputAddress;
928 final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
929
930 final int literalSize;
931 switch (type) {
932 case 0:
933 case 2:
934 literalSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
935 input++;
936 break;
937 case 1:
938 literalSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
939 input += 2;
940 break;
941 case 3:
942
943 final int header = ((UNSAFE.getByte(inputBase, input) & 0xFF) |
944 ((UNSAFE.getShort(inputBase, input + 1) & 0xFFFF) <<
945 8));
946
947 literalSize = header >>> 4;
948 input += 3;
949 break;
950 default:
951 throw fail(input, "Invalid raw literals header encoding type");
952 }
953
954 verify(input + literalSize <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
955
956
957
958 if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
959 literalsBase = literals;
960 literalsAddress = ARRAY_BYTE_BASE_OFFSET;
961 literalsLimit = ARRAY_BYTE_BASE_OFFSET + literalSize;
962
963 UNSAFE.copyMemory(inputBase, input, literals, literalsAddress,
964 literalSize);
965 Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
966 } else {
967 literalsBase = inputBase;
968 literalsAddress = input;
969 literalsLimit = literalsAddress + literalSize;
970 }
971 input += literalSize;
972
973 return (int) (input - inputAddress);
974 }
975
976 static FrameHeader readFrameHeader(final Object inputBase,
977 final long inputAddress,
978 final long inputLimit) {
979 long input = inputAddress;
980 verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
981
982 final int frameHeaderDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
983 final boolean singleSegment = (frameHeaderDescriptor & 0x20) != 0;
984 final int dictionaryDescriptor = frameHeaderDescriptor & 0x3;
985 final int contentSizeDescriptor = frameHeaderDescriptor >>> 6;
986
987 final int headerSize = 1 + (singleSegment? 0 : 1) +
988 (dictionaryDescriptor == 0? 0 :
989 (1 << (dictionaryDescriptor - 1))) +
990 (contentSizeDescriptor == 0? (singleSegment? 1 : 0) :
991 (1 << contentSizeDescriptor));
992
993 verify(headerSize <= inputLimit - inputAddress, input,
994 NOT_ENOUGH_INPUT_BYTES);
995
996
997 int windowSize = -1;
998 if (!singleSegment) {
999 final int windowDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
1000 final int exponent = windowDescriptor >>> 3;
1001 final int mantissa = windowDescriptor & 0x7;
1002
1003 final int base = 1 << (MIN_WINDOW_LOG + exponent);
1004 windowSize = base + (base / 8) * mantissa;
1005 }
1006
1007
1008 long dictionaryId = -1;
1009 switch (dictionaryDescriptor) {
1010 case 1:
1011 dictionaryId = UNSAFE.getByte(inputBase, input) & 0xFF;
1012 input += SIZE_OF_BYTE;
1013 break;
1014 case 2:
1015 dictionaryId = UNSAFE.getShort(inputBase, input) & 0xFFFF;
1016 input += SIZE_OF_SHORT;
1017 break;
1018 case 3:
1019 dictionaryId = UNSAFE.getInt(inputBase, input) & 0xFFFFFFFFL;
1020 input += SIZE_OF_INT;
1021 break;
1022 }
1023 verify(dictionaryId == -1, input, "Custom dictionaries not supported");
1024
1025
1026 long contentSize = -1;
1027 switch (contentSizeDescriptor) {
1028 case 0:
1029 if (singleSegment) {
1030 contentSize = UNSAFE.getByte(inputBase, input) & 0xFF;
1031 input += SIZE_OF_BYTE;
1032 }
1033 break;
1034 case 1:
1035 contentSize = UNSAFE.getShort(inputBase, input) & 0xFFFF;
1036 contentSize += 256;
1037 input += SIZE_OF_SHORT;
1038 break;
1039 case 2:
1040 contentSize = UNSAFE.getInt(inputBase, input) & 0xFFFFFFFFL;
1041 input += SIZE_OF_INT;
1042 break;
1043 case 3:
1044 contentSize = UNSAFE.getLong(inputBase, input);
1045 input += SIZE_OF_LONG;
1046 break;
1047 }
1048
1049 final boolean hasChecksum = (frameHeaderDescriptor & 0x4) != 0;
1050
1051 return new FrameHeader(input - inputAddress, windowSize, contentSize,
1052 dictionaryId, hasChecksum);
1053 }
1054
1055 public static long getDecompressedSize(final Object inputBase,
1056 final long inputAddress,
1057 final long inputLimit) {
1058 long input = inputAddress;
1059 input += verifyMagic(inputBase, input, inputLimit);
1060 return readFrameHeader(inputBase, input, inputLimit).contentSize;
1061 }
1062
1063 static int verifyMagic(final Object inputBase, final long inputAddress,
1064 final long inputLimit) {
1065 verify(inputLimit - inputAddress >= 4, inputAddress,
1066 NOT_ENOUGH_INPUT_BYTES);
1067
1068 final int magic = UNSAFE.getInt(inputBase, inputAddress);
1069 if (magic != MAGIC_NUMBER) {
1070 if (magic == V07_MAGIC_NUMBER) {
1071 throw new MalformedInputException(inputAddress,
1072 "Data encoded in unsupported ZSTD v0.7 format");
1073 }
1074 throw new MalformedInputException(inputAddress, "Invalid magic prefix: " +
1075 Integer.toHexString(
1076 magic));
1077 }
1078
1079 return SIZE_OF_INT;
1080 }
1081 }