View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import org.waarp.compress.MalformedInputException;
37  
38  import java.util.Arrays;
39  
40  import static org.waarp.compress.zstdunsafe.BitInputStream.*;
41  import static org.waarp.compress.zstdunsafe.Constants.*;
42  import static org.waarp.compress.zstdunsafe.UnsafeUtil.*;
43  import static org.waarp.compress.zstdunsafe.Util.*;
44  import static sun.misc.Unsafe.*;
45  
46  class ZstdFrameDecompressor {
  /**
   * Source-address adjustments used by copyMatchHead for matches whose
   * offset is below 8, i.e. whose source overlaps the bytes being written.
   * After four single-byte copies, the source is advanced by
   * DEC_32_TABLE[offset] for a 4-byte copy, then moved back by
   * DEC_64_TABLE[offset] before the 8-byte tail copies. Indexed by offset;
   * entry 0 is presumably never used since decoded offsets are >= 1.
   */
  private static final int[] DEC_32_TABLE = { 4, 1, 2, 1, 4, 4, 4, 4 };
  private static final int[] DEC_64_TABLE = { 0, 0, 0, -1, 0, 1, 2, 3 };

  // NOTE(review): presumably the magic number of the legacy zstd v0.7 frame
  // format (going by its name); it is not referenced in this part of the file.
  private static final int V07_MAGIC_NUMBER = 0xFD2FB527;

  /**
   * Largest frame window size (8 MiB) accepted by decodeCompressedBlock;
   * frames declaring a larger window are rejected as "not yet supported".
   */
  private static final int MAX_WINDOW_SIZE = 1 << 23;

  /**
   * Baseline literals length for each literals length code. For codes above
   * 15, LITERALS_LENGTH_BITS[code] extra bits are read and added to it.
   */
  private static final int[] LITERALS_LENGTH_BASE = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24,
      28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000,
      0x4000, 0x8000, 0x10000
  };

  /**
   * Baseline match length for each match length code. For codes above 31,
   * MATCH_LENGTH_BITS[code] extra bits are read and added to it.
   */
  private static final int[] MATCH_LENGTH_BASE = {
      3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
      23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47,
      51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003,
      0x4003, 0x8003, 0x10003
  };

  /**
   * Baseline value for each offset code. Offset code N reads N extra bits
   * that are added to the baseline. For codes 0 and 1 the result is then
   * interpreted as a reference into the repeat-offset history
   * (see decompressSequences).
   */
  private static final int[] OFFSET_CODES_BASE = {
      0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD,
      0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD,
      0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD,
      0xFFFFFFD
  };
73  
  // Default (predefined) decoding tables. They are selected when a sequences
  // section uses SEQUENCE_ENCODING_BASIC for the corresponding symbol type.
  // NOTE(review): judging from the values, the constructor arguments appear
  // to be (log2 table size, newState[], symbol[], numberOfBits[]); confirm
  // against FiniteStateEntropy.Table. These presumably encode the predefined
  // distributions from the Zstandard format specification.

  /** Default literals length decoding table (table log 6). */
  private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE =
      new FiniteStateEntropy.Table(6, new int[] {
          0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0,
          32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0,
          0, 48, 16, 32, 32, 32, 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32,
          0, 0, 0, 0
      }, new byte[] {
          0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26,
          27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23,
          25, 25, 26, 28, 30, 0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20,
          21, 23, 24, 35, 34, 33, 32
      }, new byte[] {
          4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4,
          4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5,
          5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6
      });

  /** Default offset codes decoding table (table log 5). */
  private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE =
      new FiniteStateEntropy.Table(5, new int[] {
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0,
          0, 16, 0, 0, 0, 0, 0, 0, 0
      }, new byte[] {
          0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4,
          8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24
      }, new byte[] {
          5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5,
          5, 4, 5, 5, 5, 5, 5, 5, 5
      });

  /** Default match length decoding table (table log 6). */
  private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE =
      new FiniteStateEntropy.Table(6, new int[] {
          0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
          0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48,
          16, 32, 32, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
      }, new byte[] {
          0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39,
          41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34,
          36, 38, 40, 42, 44, 1, 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29,
          52, 51, 50, 49, 48, 47, 46
      }, new byte[] {
          6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4,
          4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4,
          5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
      });
  // Error messages passed to verify()/fail() by the decoding routines.
  public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
  public static final String OUTPUT_BUFFER_TOO_SMALL =
      "Output buffer too small";
  public static final String EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT =
      "Expected match length table to be present";
  public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
  public static final String VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE =
      "Value exceeds expected maximum value";

  // Scratch buffer for decoded literals, with SIZE_OF_LONG extra bytes so
  // literals can be copied a long at a time (over-copying past the end).
  // NOTE(review): not referenced in the visible part of this file; presumably
  // filled by the literal decoders defined further down.
  private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG];

  // Current buffer containing the decoded literals of the block being
  // decoded: base object plus [literalsAddress, literalsLimit) range. It is
  // read by decodeCompressedBlock after the literals section is decoded.
  private Object literalsBase;
  private long literalsAddress;
  private long literalsLimit;

  // Repeat-offset history; reset() restores it to {1, 4, 8} for every frame.
  private final int[] previousOffsets = new int[3];

  // Reusable tables filled when a block uses the RLE or COMPRESSED mode.
  private final FiniteStateEntropy.Table literalsLengthTable =
      new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
  private final FiniteStateEntropy.Table offsetCodesTable =
      new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
  private final FiniteStateEntropy.Table matchLengthTable =
      new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);

  // Tables in effect for the current block (either one of the reusable
  // tables above or a DEFAULT_* table). They are kept across blocks for the
  // REPEAT mode, and are null right after reset().
  private FiniteStateEntropy.Table currentLiteralsLengthTable;
  private FiniteStateEntropy.Table currentOffsetCodesTable;
  private FiniteStateEntropy.Table currentMatchLengthTable;

  private final Huffman huffman = new Huffman();
  private final FseTableReader fse = new FseTableReader();
150 
151   public int decompress(final Object inputBase, final long inputAddress,
152                         final long inputLimit, final Object outputBase,
153                         final long outputAddress, final long outputLimit) {
154     if (outputAddress == outputLimit) {
155       return 0;
156     }
157 
158     long input = inputAddress;
159     long output = outputAddress;
160 
161     while (input < inputLimit) {
162       reset();
163       final long outputStart = output;
164       input += verifyMagic(inputBase, inputAddress, inputLimit);
165 
166       final FrameHeader frameHeader =
167           readFrameHeader(inputBase, input, inputLimit);
168       input += frameHeader.headerSize;
169 
170       boolean lastBlock;
171       do {
172         verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input,
173                NOT_ENOUGH_INPUT_BYTES);
174 
175         // read block header
176         final int header = UNSAFE.getInt(inputBase, input) & 0xFFFFFF;
177         input += SIZE_OF_BLOCK_HEADER;
178 
179         lastBlock = (header & 1) != 0;
180         final int blockType = (header >>> 1) & 0x3;
181         final int blockSize = (header >>> 3) & 0x1FFFFF; // 21 bits
182 
183         final int decodedSize;
184         switch (blockType) {
185           case RAW_BLOCK:
186             verify(inputAddress + blockSize <= inputLimit, input,
187                    NOT_ENOUGH_INPUT_BYTES);
188             decodedSize =
189                 decodeRawBlock(inputBase, input, blockSize, outputBase, output,
190                                outputLimit);
191             input += blockSize;
192             break;
193           case RLE_BLOCK:
194             verify(inputAddress + 1 <= inputLimit, input,
195                    NOT_ENOUGH_INPUT_BYTES);
196             decodedSize =
197                 decodeRleBlock(blockSize, inputBase, input, outputBase, output,
198                                outputLimit);
199             input += 1;
200             break;
201           case COMPRESSED_BLOCK:
202             verify(inputAddress + blockSize <= inputLimit, input,
203                    NOT_ENOUGH_INPUT_BYTES);
204             decodedSize =
205                 decodeCompressedBlock(inputBase, input, blockSize, outputBase,
206                                       output, outputLimit,
207                                       frameHeader.windowSize, outputAddress);
208             input += blockSize;
209             break;
210           default:
211             throw fail(input, "Invalid block type");
212         }
213 
214         output += decodedSize;
215       } while (!lastBlock);
216 
217       if (frameHeader.hasChecksum) {
218         final int decodedFrameSize = (int) (output - outputStart);
219 
220         final long hash =
221             XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
222 
223         final int checksum = UNSAFE.getInt(inputBase, input);
224         if (checksum != (int) hash) {
225           throw new MalformedInputException(input, String.format(
226               "Bad checksum. Expected: %s, actual: %s",
227               Integer.toHexString(checksum), Integer.toHexString((int) hash)));
228         }
229 
230         input += SIZE_OF_INT;
231       }
232     }
233 
234     return (int) (output - outputAddress);
235   }
236 
237   private void reset() {
238     previousOffsets[0] = 1;
239     previousOffsets[1] = 4;
240     previousOffsets[2] = 8;
241 
242     currentLiteralsLengthTable = null;
243     currentOffsetCodesTable = null;
244     currentMatchLengthTable = null;
245   }
246 
247   private static int decodeRawBlock(final Object inputBase,
248                                     final long inputAddress,
249                                     final int blockSize,
250                                     final Object outputBase,
251                                     final long outputAddress,
252                                     final long outputLimit) {
253     verify(outputAddress + blockSize <= outputLimit, inputAddress,
254            OUTPUT_BUFFER_TOO_SMALL);
255 
256     UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress,
257                       blockSize);
258     return blockSize;
259   }
260 
261   private static int decodeRleBlock(final int size, final Object inputBase,
262                                     final long inputAddress,
263                                     final Object outputBase,
264                                     final long outputAddress,
265                                     final long outputLimit) {
266     verify(outputAddress + size <= outputLimit, inputAddress,
267            OUTPUT_BUFFER_TOO_SMALL);
268 
269     long output = outputAddress;
270     final long value = UNSAFE.getByte(inputBase, inputAddress) & 0xFFL;
271 
272     int remaining = size;
273     if (remaining >= SIZE_OF_LONG) {
274       final long packed =
275           value | (value << 8) | (value << 16) | (value << 24) | (value << 32) |
276           (value << 40) | (value << 48) | (value << 56);
277 
278       do {
279         UNSAFE.putLong(outputBase, output, packed);
280         output += SIZE_OF_LONG;
281         remaining -= SIZE_OF_LONG;
282       } while (remaining >= SIZE_OF_LONG);
283     }
284 
285     for (int i = 0; i < remaining; i++) {
286       UNSAFE.putByte(outputBase, output, (byte) value);
287       output++;
288     }
289 
290     return size;
291   }
292 
293   private int decodeCompressedBlock(final Object inputBase,
294                                     final long inputAddress,
295                                     final int blockSize,
296                                     final Object outputBase,
297                                     final long outputAddress,
298                                     final long outputLimit,
299                                     final int windowSize,
300                                     final long outputAbsoluteBaseAddress) {
301     final long inputLimit = inputAddress + blockSize;
302     long input = inputAddress;
303 
304     verify(blockSize <= MAX_BLOCK_SIZE, input,
305            EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
306     verify(blockSize >= MIN_BLOCK_SIZE, input,
307            "Compressed block size too small");
308 
309     // decode literals
310     final int literalsBlockType = UNSAFE.getByte(inputBase, input) & 0x3;
311 
312     switch (literalsBlockType) {
313       case RAW_LITERALS_BLOCK: {
314         input += decodeRawLiterals(inputBase, input, inputLimit);
315         break;
316       }
317       case RLE_LITERALS_BLOCK: {
318         input += decodeRleLiterals(inputBase, input, blockSize);
319         break;
320       }
321       case TREELESS_LITERALS_BLOCK:
322         verify(huffman.isLoaded(), input, "Dictionary is corrupted");
323       case COMPRESSED_LITERALS_BLOCK: {
324         input += decodeCompressedLiterals(inputBase, input, blockSize,
325                                           literalsBlockType);
326         break;
327       }
328       default:
329         throw fail(input, "Invalid literals block encoding type");
330     }
331 
332     verify(windowSize <= MAX_WINDOW_SIZE, input,
333            "Window size too large (not yet supported)");
334 
335     return decompressSequences(inputBase, input, inputAddress + blockSize,
336                                outputBase, outputAddress, outputLimit,
337                                literalsBase, literalsAddress, literalsLimit,
338                                outputAbsoluteBaseAddress);
339   }
340 
341   private int decompressSequences(final Object inputBase,
342                                   final long inputAddress,
343                                   final long inputLimit,
344                                   final Object outputBase,
345                                   final long outputAddress,
346                                   final long outputLimit,
347                                   final Object literalsBase,
348                                   final long literalsAddress,
349                                   final long literalsLimit,
350                                   final long outputAbsoluteBaseAddress) {
351     final long fastOutputLimit = outputLimit - SIZE_OF_LONG;
352     final long fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
353 
354     long input = inputAddress;
355     long output = outputAddress;
356 
357     long literalsInput = literalsAddress;
358 
359     final int size = (int) (inputLimit - inputAddress);
360     verify(size >= MIN_SEQUENCES_SIZE, input, NOT_ENOUGH_INPUT_BYTES);
361 
362     // decode header
363     int sequenceCount = UNSAFE.getByte(inputBase, input++) & 0xFF;
364     if (sequenceCount != 0) {
365       if (sequenceCount == 255) {
366         verify(input + SIZE_OF_SHORT <= inputLimit, input,
367                NOT_ENOUGH_INPUT_BYTES);
368         sequenceCount = (UNSAFE.getShort(inputBase, input) & 0xFFFF) +
369                         LONG_NUMBER_OF_SEQUENCES;
370         input += SIZE_OF_SHORT;
371       } else if (sequenceCount > 127) {
372         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
373         sequenceCount = ((sequenceCount - 128) << 8) +
374                         (UNSAFE.getByte(inputBase, input++) & 0xFF);
375       }
376 
377       verify(input + SIZE_OF_INT <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
378 
379       final byte type = UNSAFE.getByte(inputBase, input++);
380 
381       final int literalsLengthType = (type & 0xFF) >>> 6;
382       final int offsetCodesType = (type >>> 4) & 0x3;
383       final int matchLengthType = (type >>> 2) & 0x3;
384 
385       input = computeLiteralsTable(literalsLengthType, inputBase, input,
386                                    inputLimit);
387       input =
388           computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
389       input = computeMatchLengthTable(matchLengthType, inputBase, input,
390                                       inputLimit);
391 
392       // decompress sequences
393       final BitInputStream.Initializer initializer =
394           new BitInputStream.Initializer(inputBase, input, inputLimit);
395       initializer.initialize();
396       int bitsConsumed = initializer.getBitsConsumed();
397       long bits = initializer.getBits();
398       long currentAddress = initializer.getCurrentAddress();
399 
400       final FiniteStateEntropy.Table literalsLengthTable1 =
401           this.currentLiteralsLengthTable;
402       final FiniteStateEntropy.Table offsetCodesTable1 =
403           this.currentOffsetCodesTable;
404       final FiniteStateEntropy.Table matchLengthTable1 =
405           this.currentMatchLengthTable;
406 
407       int literalsLengthState =
408           (int) peekBits(bitsConsumed, bits, literalsLengthTable1.log2Size);
409       bitsConsumed += literalsLengthTable1.log2Size;
410 
411       int offsetCodesState =
412           (int) peekBits(bitsConsumed, bits, offsetCodesTable1.log2Size);
413       bitsConsumed += offsetCodesTable1.log2Size;
414 
415       int matchLengthState =
416           (int) peekBits(bitsConsumed, bits, matchLengthTable1.log2Size);
417       bitsConsumed += matchLengthTable1.log2Size;
418 
419       final int[] previousOffsets1 = this.previousOffsets;
420 
421       final byte[] literalsLengthNumbersOfBits =
422           literalsLengthTable1.numberOfBits;
423       final int[] literalsLengthNewStates = literalsLengthTable1.newState;
424       final byte[] literalsLengthSymbols = literalsLengthTable1.symbol;
425 
426       final byte[] matchLengthNumbersOfBits = matchLengthTable1.numberOfBits;
427       final int[] matchLengthNewStates = matchLengthTable1.newState;
428       final byte[] matchLengthSymbols = matchLengthTable1.symbol;
429 
430       final byte[] offsetCodesNumbersOfBits = offsetCodesTable1.numberOfBits;
431       final int[] offsetCodesNewStates = offsetCodesTable1.newState;
432       final byte[] offsetCodesSymbols = offsetCodesTable1.symbol;
433 
434       while (sequenceCount > 0) {
435         sequenceCount--;
436 
437         final BitInputStream.Loader loader =
438             new BitInputStream.Loader(inputBase, input, currentAddress, bits,
439                                       bitsConsumed);
440         loader.load();
441         bitsConsumed = loader.getBitsConsumed();
442         bits = loader.getBits();
443         currentAddress = loader.getCurrentAddress();
444         if (loader.isOverflow()) {
445           verify(sequenceCount == 0, input, "Not all sequences were consumed");
446           break;
447         }
448 
449         // decode sequence
450         final int literalsLengthCode =
451             literalsLengthSymbols[literalsLengthState];
452         final int matchLengthCode = matchLengthSymbols[matchLengthState];
453         final int offsetCode = offsetCodesSymbols[offsetCodesState];
454 
455         final int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
456         final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
457 
458         int offset = OFFSET_CODES_BASE[offsetCode];
459         if (offsetCode > 0) {
460           offset += peekBits(bitsConsumed, bits, offsetCode);
461           bitsConsumed += offsetCode;
462         }
463 
464         if (offsetCode <= 1) {
465           if (literalsLengthCode == 0) {
466             offset++;
467           }
468 
469           if (offset != 0) {
470             int temp;
471             if (offset == 3) {
472               temp = previousOffsets1[0] - 1;
473             } else {
474               temp = previousOffsets1[offset];
475             }
476 
477             if (temp == 0) {
478               temp = 1;
479             }
480 
481             if (offset != 1) {
482               previousOffsets1[2] = previousOffsets1[1];
483             }
484             previousOffsets1[1] = previousOffsets1[0];
485             previousOffsets1[0] = temp;
486 
487             offset = temp;
488           } else {
489             offset = previousOffsets1[0];
490           }
491         } else {
492           previousOffsets1[2] = previousOffsets1[1];
493           previousOffsets1[1] = previousOffsets1[0];
494           previousOffsets1[0] = offset;
495         }
496 
497         int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
498         if (matchLengthCode > 31) {
499           matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
500           bitsConsumed += matchLengthBits;
501         }
502 
503         int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
504         if (literalsLengthCode > 15) {
505           literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
506           bitsConsumed += literalsLengthBits;
507         }
508 
509         final int totalBits = literalsLengthBits + matchLengthBits + offsetCode;
510         if (totalBits > 64 - 7 -
511                         (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG +
512                          OFFSET_TABLE_LOG)) {
513           final BitInputStream.Loader loader1 =
514               new BitInputStream.Loader(inputBase, input, currentAddress, bits,
515                                         bitsConsumed);
516           loader1.load();
517 
518           bitsConsumed = loader1.getBitsConsumed();
519           bits = loader1.getBits();
520           currentAddress = loader1.getCurrentAddress();
521         }
522 
523         int numberOfBits;
524 
525         numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
526         literalsLengthState =
527             (int) (literalsLengthNewStates[literalsLengthState] +
528                    peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
529         bitsConsumed += numberOfBits;
530 
531         numberOfBits = matchLengthNumbersOfBits[matchLengthState];
532         matchLengthState = (int) (matchLengthNewStates[matchLengthState] +
533                                   peekBits(bitsConsumed, bits,
534                                            numberOfBits)); // <= 9 bits
535         bitsConsumed += numberOfBits;
536 
537         numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
538         offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] +
539                                   peekBits(bitsConsumed, bits,
540                                            numberOfBits)); // <= 8 bits
541         bitsConsumed += numberOfBits;
542 
543         final long literalOutputLimit = output + literalsLength;
544         final long matchOutputLimit = literalOutputLimit + matchLength;
545 
546         verify(matchOutputLimit <= outputLimit, input, OUTPUT_BUFFER_TOO_SMALL);
547         final long literalEnd = literalsInput + literalsLength;
548         verify(literalEnd <= literalsLimit, input, INPUT_IS_CORRUPTED);
549 
550         final long matchAddress = literalOutputLimit - offset;
551         verify(matchAddress >= outputAbsoluteBaseAddress, input,
552                INPUT_IS_CORRUPTED);
553 
554         if (literalOutputLimit > fastOutputLimit) {
555           executeLastSequence(outputBase, output, literalOutputLimit,
556                               matchOutputLimit, fastOutputLimit, literalsInput,
557                               matchAddress);
558         } else {
559           // copy literals. literalOutputLimit <= fastOutputLimit, so we can copy
560           // long at a time with over-copy
561           output = copyLiterals(outputBase, literalsBase, output, literalsInput,
562                                 literalOutputLimit);
563           copyMatch(outputBase, fastOutputLimit, output, offset,
564                     matchOutputLimit, matchAddress, matchLength,
565                     fastMatchOutputLimit);
566         }
567         output = matchOutputLimit;
568         literalsInput = literalEnd;
569       }
570     }
571 
572     // last literal segment
573     output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output,
574                              literalsInput);
575 
576     return (int) (output - outputAddress);
577   }
578 
579   private long copyLastLiteral(final Object outputBase,
580                                final Object literalsBase,
581                                final long literalsLimit, long output,
582                                final long literalsInput) {
583     final long lastLiteralsSize = literalsLimit - literalsInput;
584     UNSAFE.copyMemory(literalsBase, literalsInput, outputBase, output,
585                       lastLiteralsSize);
586     output += lastLiteralsSize;
587     return output;
588   }
589 
590   private void copyMatch(final Object outputBase, final long fastOutputLimit,
591                          long output, final int offset,
592                          final long matchOutputLimit, long matchAddress,
593                          int matchLength, final long fastMatchOutputLimit) {
594     matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
595     output += SIZE_OF_LONG;
596     matchLength -= SIZE_OF_LONG; // first 8 bytes copied above
597 
598     copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit,
599                   matchAddress, matchLength, fastMatchOutputLimit);
600   }
601 
602   private void copyMatchTail(final Object outputBase,
603                              final long fastOutputLimit, long output,
604                              final long matchOutputLimit, long matchAddress,
605                              final int matchLength,
606                              final long fastMatchOutputLimit) {
607     // fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be passed in so that it can be computed once for the
608     // whole invocation to decompressSequences. Otherwise, we'd just compute it here.
609     // If matchOutputLimit is < fastMatchOutputLimit, we know that even after the head (8 bytes) has been copied, the output pointer
610     // will be within fastOutputLimit, so it's safe to copy blindly before checking the limit condition
611     if (matchOutputLimit < fastMatchOutputLimit) {
612       int copied = 0;
613       do {
614         UNSAFE.putLong(outputBase, output,
615                        UNSAFE.getLong(outputBase, matchAddress));
616         output += SIZE_OF_LONG;
617         matchAddress += SIZE_OF_LONG;
618         copied += SIZE_OF_LONG;
619       } while (copied < matchLength);
620     } else {
621       while (output < fastOutputLimit) {
622         UNSAFE.putLong(outputBase, output,
623                        UNSAFE.getLong(outputBase, matchAddress));
624         matchAddress += SIZE_OF_LONG;
625         output += SIZE_OF_LONG;
626       }
627 
628       while (output < matchOutputLimit) {
629         UNSAFE.putByte(outputBase, output++,
630                        UNSAFE.getByte(outputBase, matchAddress++));
631       }
632     }
633   }
634 
635   private long copyMatchHead(final Object outputBase, final long output,
636                              final int offset, long matchAddress) {
637     // copy match
638     if (offset < 8) {
639       // 8 bytes apart so that we can copy long-at-a-time below
640       final int increment32 = DEC_32_TABLE[offset];
641       final int decrement64 = DEC_64_TABLE[offset];
642 
643       UNSAFE.putByte(outputBase, output,
644                      UNSAFE.getByte(outputBase, matchAddress));
645       UNSAFE.putByte(outputBase, output + 1,
646                      UNSAFE.getByte(outputBase, matchAddress + 1));
647       UNSAFE.putByte(outputBase, output + 2,
648                      UNSAFE.getByte(outputBase, matchAddress + 2));
649       UNSAFE.putByte(outputBase, output + 3,
650                      UNSAFE.getByte(outputBase, matchAddress + 3));
651       matchAddress += increment32;
652 
653       UNSAFE.putInt(outputBase, output + 4,
654                     UNSAFE.getInt(outputBase, matchAddress));
655       matchAddress -= decrement64;
656     } else {
657       UNSAFE.putLong(outputBase, output,
658                      UNSAFE.getLong(outputBase, matchAddress));
659       matchAddress += SIZE_OF_LONG;
660     }
661     return matchAddress;
662   }
663 
664   private long copyLiterals(final Object outputBase, final Object literalsBase,
665                             long output, final long literalsInput,
666                             final long literalOutputLimit) {
667     long literalInput = literalsInput;
668     do {
669       UNSAFE.putLong(outputBase, output,
670                      UNSAFE.getLong(literalsBase, literalInput));
671       output += SIZE_OF_LONG;
672       literalInput += SIZE_OF_LONG;
673     } while (output < literalOutputLimit);
674     output = literalOutputLimit; // correction in case we over-copied
675     return output;
676   }
677 
678   private long computeMatchLengthTable(final int matchLengthType,
679                                        final Object inputBase, long input,
680                                        final long inputLimit) {
681     switch (matchLengthType) {
682       case SEQUENCE_ENCODING_RLE:
683         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
684 
685         final byte value = UNSAFE.getByte(inputBase, input++);
686         verify(value <= MAX_MATCH_LENGTH_SYMBOL, input,
687                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
688 
689         FseTableReader.initializeRleTable(matchLengthTable, value);
690         currentMatchLengthTable = matchLengthTable;
691         break;
692       case SEQUENCE_ENCODING_BASIC:
693         currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
694         break;
695       case SEQUENCE_ENCODING_REPEAT:
696         verify(currentMatchLengthTable != null, input,
697                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
698         break;
699       case SEQUENCE_ENCODING_COMPRESSED:
700         input +=
701             fse.readFseTable(matchLengthTable, inputBase, input, inputLimit,
702                              MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
703         currentMatchLengthTable = matchLengthTable;
704         break;
705       default:
706         throw fail(input, "Invalid match length encoding type");
707     }
708     return input;
709   }
710 
711   private long computeOffsetsTable(final int offsetCodesType,
712                                    final Object inputBase, long input,
713                                    final long inputLimit) {
714     switch (offsetCodesType) {
715       case SEQUENCE_ENCODING_RLE:
716         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
717 
718         final byte value = UNSAFE.getByte(inputBase, input++);
719         verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input,
720                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
721 
722         FseTableReader.initializeRleTable(offsetCodesTable, value);
723         currentOffsetCodesTable = offsetCodesTable;
724         break;
725       case SEQUENCE_ENCODING_BASIC:
726         currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
727         break;
728       case SEQUENCE_ENCODING_REPEAT:
729         verify(currentOffsetCodesTable != null, input,
730                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
731         break;
732       case SEQUENCE_ENCODING_COMPRESSED:
733         input +=
734             fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit,
735                              DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
736         currentOffsetCodesTable = offsetCodesTable;
737         break;
738       default:
739         throw fail(input, "Invalid offset code encoding type");
740     }
741     return input;
742   }
743 
744   private long computeLiteralsTable(final int literalsLengthType,
745                                     final Object inputBase, long input,
746                                     final long inputLimit) {
747     switch (literalsLengthType) {
748       case SEQUENCE_ENCODING_RLE:
749         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
750 
751         final byte value = UNSAFE.getByte(inputBase, input++);
752         verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input,
753                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
754 
755         FseTableReader.initializeRleTable(literalsLengthTable, value);
756         currentLiteralsLengthTable = literalsLengthTable;
757         break;
758       case SEQUENCE_ENCODING_BASIC:
759         currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
760         break;
761       case SEQUENCE_ENCODING_REPEAT:
762         verify(currentLiteralsLengthTable != null, input,
763                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
764         break;
765       case SEQUENCE_ENCODING_COMPRESSED:
766         input +=
767             fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit,
768                              MAX_LITERALS_LENGTH_SYMBOL,
769                              LITERAL_LENGTH_TABLE_LOG);
770         currentLiteralsLengthTable = literalsLengthTable;
771         break;
772       default:
773         throw fail(input, "Invalid literals length encoding type");
774     }
775     return input;
776   }
777 
778   private void executeLastSequence(final Object outputBase, long output,
779                                    final long literalOutputLimit,
780                                    final long matchOutputLimit,
781                                    final long fastOutputLimit,
782                                    long literalInput, long matchAddress) {
783     // copy literals
784     if (output < fastOutputLimit) {
785       // wild copy
786       do {
787         UNSAFE.putLong(outputBase, output,
788                        UNSAFE.getLong(literalsBase, literalInput));
789         output += SIZE_OF_LONG;
790         literalInput += SIZE_OF_LONG;
791       } while (output < fastOutputLimit);
792 
793       literalInput -= output - fastOutputLimit;
794       output = fastOutputLimit;
795     }
796 
797     while (output < literalOutputLimit) {
798       UNSAFE.putByte(outputBase, output,
799                      UNSAFE.getByte(literalsBase, literalInput));
800       output++;
801       literalInput++;
802     }
803 
804     // copy match
805     while (output < matchOutputLimit) {
806       UNSAFE.putByte(outputBase, output,
807                      UNSAFE.getByte(outputBase, matchAddress));
808       output++;
809       matchAddress++;
810     }
811   }
812 
  /**
   * Decodes a Huffman-compressed literals section into the {@code literals}
   * buffer and points {@code literalsBase}/{@code literalsAddress}/
   * {@code literalsLimit} at the decoded bytes.
   *
   * <p>The 3-, 4- or 5-byte section header carries the regenerated
   * (uncompressed) size and the compressed size; its layout is chosen by bits
   * 2-3 of the first byte. Header type 0 selects a single Huffman stream; the
   * other types use four streams.
   *
   * @param inputBase base object for {@code UNSAFE} reads
   * @param inputAddress start of the literals section header
   * @param blockSize size of the enclosing block; the checks below treat it as
   *     the number of readable bytes from {@code inputAddress}
   * @param literalsBlockType literals block type; for
   *     {@code TREELESS_LITERALS_BLOCK} no Huffman table is read here and the
   *     one already held by {@code huffman} is used (presumably from a previous
   *     block)
   * @return number of input bytes consumed (header plus compressed payload)
   */
  private int decodeCompressedLiterals(final Object inputBase,
                                       final long inputAddress,
                                       final int blockSize,
                                       final int literalsBlockType) {
    long input = inputAddress;
    // the largest header is 5 bytes; this also makes the 4-byte getInt reads
    // below stay within the block
    verify(blockSize >= 5, input, NOT_ENOUGH_INPUT_BYTES);

    // compressed literals: decode the size header
    final int compressedSize;
    final int uncompressedSize;
    boolean singleStream = false;
    final int headerSize;
    final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
    switch (type) {
      case 0:
        singleStream = true;
        // intentional fall-through: type 0 shares type 1's 3-byte layout
      case 1: {
        // reads 4 bytes but only the low 24 bits (3-byte header) are used
        final int header = UNSAFE.getInt(inputBase, input);

        headerSize = 3;
        uncompressedSize = (header >>> 4) & mask(10);
        compressedSize = (header >>> 14) & mask(10);
        break;
      }
      case 2: {
        // 4-byte header: two 14-bit sizes after the 4 type bits
        final int header = UNSAFE.getInt(inputBase, input);

        headerSize = 4;
        uncompressedSize = (header >>> 4) & mask(14);
        compressedSize = (header >>> 18) & mask(14);
        break;
      }
      case 3: {
        // read 5 little-endian bytes
        final long header = UNSAFE.getByte(inputBase, input) & 0xFF |
                            (UNSAFE.getInt(inputBase, input + 1) &
                             0xFFFFFFFFL) << 8;

        headerSize = 5;
        uncompressedSize = (int) ((header >>> 4) & mask(18));
        compressedSize = (int) ((header >>> 22) & mask(18));
        break;
      }
      default:
        // unreachable (type is masked to 2 bits); keeps the finals above
        // definitely assigned
        throw fail(input, "Invalid literals header size type");
    }

    verify(uncompressedSize <= MAX_BLOCK_SIZE, input,
           "Block exceeds maximum size");
    verify(headerSize + compressedSize <= blockSize, input, INPUT_IS_CORRUPTED);

    input += headerSize;

    // the compressed payload (optional Huffman table + streams) ends here
    final long inputLimit = input + compressedSize;
    if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
      input += huffman.readTable(inputBase, input, compressedSize);
    }

    literalsBase = literals;
    literalsAddress = ARRAY_BYTE_BASE_OFFSET;
    literalsLimit = ARRAY_BYTE_BASE_OFFSET + uncompressedSize;

    if (singleStream) {
      huffman.decodeSingleStream(inputBase, input, inputLimit, literals,
                                 literalsAddress, literalsLimit);
    } else {
      huffman.decode4Streams(inputBase, input, inputLimit, literals,
                             literalsAddress, literalsLimit);
    }

    return headerSize + compressedSize;
  }
885 
886   private int decodeRleLiterals(final Object inputBase, final long inputAddress,
887                                 final int blockSize) {
888     long input = inputAddress;
889     final int outputSize;
890 
891     final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
892     switch (type) {
893       case 0:
894       case 2:
895         outputSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
896         input++;
897         break;
898       case 1:
899         outputSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
900         input += 2;
901         break;
902       case 3:
903         // we need at least 4 bytes (3 for the header, 1 for the payload)
904         verify(blockSize >= SIZE_OF_INT, input, NOT_ENOUGH_INPUT_BYTES);
905         outputSize = (UNSAFE.getInt(inputBase, input) & 0xFFFFFF) >>> 4;
906         input += 3;
907         break;
908       default:
909         throw fail(input, "Invalid RLE literals header encoding type");
910     }
911 
912     verify(outputSize <= MAX_BLOCK_SIZE, input,
913            "Output exceeds maximum block size");
914 
915     final byte value = UNSAFE.getByte(inputBase, input++);
916     Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
917 
918     literalsBase = literals;
919     literalsAddress = ARRAY_BYTE_BASE_OFFSET;
920     literalsLimit = ARRAY_BYTE_BASE_OFFSET + outputSize;
921 
922     return (int) (input - inputAddress);
923   }
924 
925   private int decodeRawLiterals(final Object inputBase, final long inputAddress,
926                                 final long inputLimit) {
927     long input = inputAddress;
928     final int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0x3;
929 
930     final int literalSize;
931     switch (type) {
932       case 0:
933       case 2:
934         literalSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
935         input++;
936         break;
937       case 1:
938         literalSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
939         input += 2;
940         break;
941       case 3:
942         // read 3 little-endian bytes
943         final int header = ((UNSAFE.getByte(inputBase, input) & 0xFF) |
944                             ((UNSAFE.getShort(inputBase, input + 1) & 0xFFFF) <<
945                              8));
946 
947         literalSize = header >>> 4;
948         input += 3;
949         break;
950       default:
951         throw fail(input, "Invalid raw literals header encoding type");
952     }
953 
954     verify(input + literalSize <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
955 
956     // Set literals pointer to [input, literalSize], but only if we can copy 8 bytes at a time during sequence decoding
957     // Otherwise, copy literals into buffer that's big enough to guarantee that
958     if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
959       literalsBase = literals;
960       literalsAddress = ARRAY_BYTE_BASE_OFFSET;
961       literalsLimit = ARRAY_BYTE_BASE_OFFSET + literalSize;
962 
963       UNSAFE.copyMemory(inputBase, input, literals, literalsAddress,
964                         literalSize);
965       Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
966     } else {
967       literalsBase = inputBase;
968       literalsAddress = input;
969       literalsLimit = literalsAddress + literalSize;
970     }
971     input += literalSize;
972 
973     return (int) (input - inputAddress);
974   }
975 
  /**
   * Parses a ZSTD frame header (the bytes following the magic number).
   *
   * <p>The first byte is the frame header descriptor. It is followed by an
   * optional window descriptor, an optional dictionary id, and an optional
   * content size, whose widths are selected by the descriptor bits.
   *
   * @param inputBase base object for {@code UNSAFE} reads
   * @param inputAddress start of the frame header descriptor byte
   * @param inputLimit exclusive end of readable input
   * @return the parsed header: header length in bytes, window size
   *     ({@code -1} for single-segment frames), content size ({@code -1} when
   *     absent), dictionary id ({@code -1} when absent) and checksum flag
   */
  static FrameHeader readFrameHeader(final Object inputBase,
                                     final long inputAddress,
                                     final long inputLimit) {
    long input = inputAddress;
    verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);

    final int frameHeaderDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
    final boolean singleSegment = (frameHeaderDescriptor & 0x20) != 0;
    final int dictionaryDescriptor = frameHeaderDescriptor & 0x3;
    final int contentSizeDescriptor = frameHeaderDescriptor >>> 6;

    // descriptor byte + window byte (absent for single-segment frames)
    // + dictionary id (0/1/2/4 bytes) + content size (2/4/8 bytes, or 1/0 when
    // the descriptor is 0, depending on single-segment)
    final int headerSize = 1 + (singleSegment? 0 : 1) +
                           (dictionaryDescriptor == 0? 0 :
                               (1 << (dictionaryDescriptor - 1))) +
                           (contentSizeDescriptor == 0? (singleSegment? 1 : 0) :
                               (1 << contentSizeDescriptor));

    // the whole header must be readable before any of the fields below are read
    verify(headerSize <= inputLimit - inputAddress, input,
           NOT_ENOUGH_INPUT_BYTES);

    // decode window size
    int windowSize = -1;
    if (!singleSegment) {
      final int windowDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
      final int exponent = windowDescriptor >>> 3;
      final int mantissa = windowDescriptor & 0x7;

      // NOTE(review): exponent can be up to 31; if MIN_WINDOW_LOG + exponent
      // reaches 31 or more, this int shift overflows or wraps (Java masks the
      // shift distance) and windowSize becomes meaningless. Confirm that
      // callers validate windowSize.
      final int base = 1 << (MIN_WINDOW_LOG + exponent);
      windowSize = base + (base / 8) * mantissa;
    }

    // decode dictionary id
    long dictionaryId = -1;
    switch (dictionaryDescriptor) {
      case 1:
        dictionaryId = UNSAFE.getByte(inputBase, input) & 0xFF;
        input += SIZE_OF_BYTE;
        break;
      case 2:
        dictionaryId = UNSAFE.getShort(inputBase, input) & 0xFFFF;
        input += SIZE_OF_SHORT;
        break;
      case 3:
        dictionaryId = UNSAFE.getInt(inputBase, input) & 0xFFFFFFFFL;
        input += SIZE_OF_INT;
        break;
    }
    // any frame that encodes a dictionary id field (even one of value 0) is
    // rejected
    verify(dictionaryId == -1, input, "Custom dictionaries not supported");

    // decode content size
    long contentSize = -1;
    switch (contentSizeDescriptor) {
      case 0:
        // a 1-byte content size is present only for single-segment frames
        if (singleSegment) {
          contentSize = UNSAFE.getByte(inputBase, input) & 0xFF;
          input += SIZE_OF_BYTE;
        }
        break;
      case 1:
        // the 2-byte form stores (content size - 256)
        contentSize = UNSAFE.getShort(inputBase, input) & 0xFFFF;
        contentSize += 256;
        input += SIZE_OF_SHORT;
        break;
      case 2:
        contentSize = UNSAFE.getInt(inputBase, input) & 0xFFFFFFFFL;
        input += SIZE_OF_INT;
        break;
      case 3:
        // NOTE(review): a value with the top bit set reads as negative here
        // and would be indistinguishable from the -1 "unknown" marker's sign;
        // confirm callers handle this
        contentSize = UNSAFE.getLong(inputBase, input);
        input += SIZE_OF_LONG;
        break;
    }

    final boolean hasChecksum = (frameHeaderDescriptor & 0x4) != 0;

    return new FrameHeader(input - inputAddress, windowSize, contentSize,
                           dictionaryId, hasChecksum);
  }
1054 
1055   public static long getDecompressedSize(final Object inputBase,
1056                                          final long inputAddress,
1057                                          final long inputLimit) {
1058     long input = inputAddress;
1059     input += verifyMagic(inputBase, input, inputLimit);
1060     return readFrameHeader(inputBase, input, inputLimit).contentSize;
1061   }
1062 
1063   static int verifyMagic(final Object inputBase, final long inputAddress,
1064                          final long inputLimit) {
1065     verify(inputLimit - inputAddress >= 4, inputAddress,
1066            NOT_ENOUGH_INPUT_BYTES);
1067 
1068     final int magic = UNSAFE.getInt(inputBase, inputAddress);
1069     if (magic != MAGIC_NUMBER) {
1070       if (magic == V07_MAGIC_NUMBER) {
1071         throw new MalformedInputException(inputAddress,
1072                                           "Data encoded in unsupported ZSTD v0.7 format");
1073       }
1074       throw new MalformedInputException(inputAddress, "Invalid magic prefix: " +
1075                                                       Integer.toHexString(
1076                                                           magic));
1077     }
1078 
1079     return SIZE_OF_INT;
1080   }
1081 }