View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import org.waarp.compress.MalformedInputException;
37  
38  import java.util.Arrays;
39  
40  import static org.waarp.compress.zstdsafe.BitInputStream.*;
41  import static org.waarp.compress.zstdsafe.Constants.*;
42  import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
43  import static org.waarp.compress.zstdsafe.Util.*;
44  
45  class ZstdFrameDecompressor {
  // Adjustment tables indexed by the match offset (0..7). copyMatchHead uses
  // them when the match source is fewer than 8 bytes behind the output
  // position: DEC_32_TABLE[offset] is added to the source address after the
  // first 4 bytes were copied, DEC_64_TABLE[offset] is subtracted after the
  // next 4 bytes were copied.
  private static final int[] DEC_32_TABLE = { 4, 1, 2, 1, 4, 4, 4, 4 };
  private static final int[] DEC_64_TABLE = { 0, 0, 0, -1, 0, 1, 2, 3 };

  // NOTE(review): presumably the magic number of the legacy v0.7 frame format;
  // it is not referenced in the visible part of this class - confirm usage.
  private static final int V07_MAGIC_NUMBER = 0xFD2FB527;

  // Largest frame window size accepted by decodeCompressedBlock (1 << 23)
  private static final int MAX_WINDOW_SIZE = 1 << 23;

  // Base literals length for each literals length code. For codes above 15
  // decompressSequences adds extra bits read from the bit stream to this base.
  private static final int[] LITERALS_LENGTH_BASE = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24,
      28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000,
      0x4000, 0x8000, 0x10000
  };

  // Base match length for each match length code. For codes above 31
  // decompressSequences adds extra bits read from the bit stream to this base.
  private static final int[] MATCH_LENGTH_BASE = {
      3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
      23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47,
      51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003,
      0x4003, 0x8003, 0x10003
  };

  // Base offset value for each offset code. For codes above 0
  // decompressSequences adds "offset code" extra bits to this base.
  private static final int[] OFFSET_CODES_BASE = {
      0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD,
      0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD,
      0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD,
      0xFFFFFFD
  };
72  
  // Predefined literals length table, selected by computeLiteralsTable when
  // the encoding type is SEQUENCE_ENCODING_BASIC.
  // NOTE(review): the constructor arguments are presumably (log2Size,
  // newState[], symbol[], numberOfBits[]) - confirm against
  // FiniteStateEntropy.Table.
  private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE =
      new FiniteStateEntropy.Table(6, new int[] {
          0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0,
          32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0,
          0, 48, 16, 32, 32, 32, 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32,
          0, 0, 0, 0
      }, new byte[] {
          0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26,
          27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23,
          25, 25, 26, 28, 30, 0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20,
          21, 23, 24, 35, 34, 33, 32
      }, new byte[] {
          4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4,
          4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5,
          5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6
      });

  // Predefined offset codes table, selected by computeOffsetsTable when the
  // encoding type is SEQUENCE_ENCODING_BASIC.
  private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE =
      new FiniteStateEntropy.Table(5, new int[] {
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0,
          0, 16, 0, 0, 0, 0, 0, 0, 0
      }, new byte[] {
          0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4,
          8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24
      }, new byte[] {
          5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5,
          5, 4, 5, 5, 5, 5, 5, 5, 5
      });

  // Predefined match length table, selected by computeMatchLengthTable when
  // the encoding type is SEQUENCE_ENCODING_BASIC.
  private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE =
      new FiniteStateEntropy.Table(6, new int[] {
          0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
          0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48,
          16, 32, 32, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
      }, new byte[] {
          0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39,
          41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34,
          36, 38, 40, 42, 44, 1, 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29,
          52, 51, 50, 49, 48, 47, 46
      }, new byte[] {
          6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4,
          4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4,
          5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
      });
  // Failure messages passed to verify(...) by the decoding methods below.

  // The input ends before the bytes required by the current structure
  public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
  // The decoded data would not fit between the output position and its limit
  public static final String OUTPUT_BUFFER_TOO_SMALL =
      "Output buffer too small";
  // A "repeat" sequence encoding was requested but no table was set before
  public static final String EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT =
      "Expected match length table to be present";
  // The sequence data references literals or output that do not exist
  public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
  // An RLE symbol is larger than the maximum symbol of its table
  public static final String VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE =
      "Value exceeds expected maximum value";
125 
  // Scratch buffer for literals, sized for one block
  private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG];
  // extra space to allow for long-at-a-time copy

  // current buffer containing literals
  // (literalsBase/literalsAddress/literalsLimit are handed to
  // decompressSequences as the literals source of the current block)
  private byte[] literalsBase;
  private int literalsAddress;
  private int literalsLimit;

  // The three most recent match offsets; re-initialized by reset() and
  // updated while decoding sequences
  private final int[] previousOffsets = new int[3];

  // Per-instance tables that are filled when a block carries its own
  // (RLE or compressed) table description
  private final FiniteStateEntropy.Table literalsLengthTable =
      new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
  private final FiniteStateEntropy.Table offsetCodesTable =
      new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
  private final FiniteStateEntropy.Table matchLengthTable =
      new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);

  // Tables in effect for the current block: either one of the per-instance
  // tables above or one of the DEFAULT_* tables; null until a block selects
  // one (reset() clears them)
  private FiniteStateEntropy.Table currentLiteralsLengthTable;
  private FiniteStateEntropy.Table currentOffsetCodesTable;
  private FiniteStateEntropy.Table currentMatchLengthTable;

  // Helpers for Huffman-coded literals and for reading FSE table descriptions
  private final Huffman huffman = new Huffman();
  private final FseTableReader fse = new FseTableReader();
149 
150   public int decompress(final byte[] inputBase, final int inputAddress,
151                         final int inputLimit, final byte[] outputBase,
152                         final int outputAddress, final int outputLimit) {
153     if (outputAddress == outputLimit) {
154       return 0;
155     }
156 
157     int input = inputAddress;
158     int output = outputAddress;
159 
160     while (input < inputLimit) {
161       reset();
162       final int outputStart = output;
163       input += verifyMagic(inputBase, inputAddress, inputLimit);
164 
165       final FrameHeader frameHeader =
166           readFrameHeader(inputBase, input, inputLimit);
167       input += frameHeader.headerSize;
168 
169       boolean lastBlock;
170       do {
171         verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input,
172                NOT_ENOUGH_INPUT_BYTES);
173 
174         // read block header
175         final int header = getInt(inputBase, input) & 0xFFFFFF;
176         input += SIZE_OF_BLOCK_HEADER;
177 
178         lastBlock = (header & 1) != 0;
179         final int blockType = (header >>> 1) & 0x3;
180         final int blockSize = (header >>> 3) & 0x1FFFFF; // 21 bits
181 
182         final int decodedSize;
183         switch (blockType) {
184           case RAW_BLOCK:
185             verify(inputAddress + blockSize <= inputLimit, input,
186                    NOT_ENOUGH_INPUT_BYTES);
187             decodedSize =
188                 decodeRawBlock(inputBase, input, blockSize, outputBase, output,
189                                outputLimit);
190             input += blockSize;
191             break;
192           case RLE_BLOCK:
193             verify(inputAddress + 1 <= inputLimit, input,
194                    NOT_ENOUGH_INPUT_BYTES);
195             decodedSize =
196                 decodeRleBlock(blockSize, inputBase, input, outputBase, output,
197                                outputLimit);
198             input += 1;
199             break;
200           case COMPRESSED_BLOCK:
201             verify(inputAddress + blockSize <= inputLimit, input,
202                    NOT_ENOUGH_INPUT_BYTES);
203             decodedSize =
204                 decodeCompressedBlock(inputBase, input, blockSize, outputBase,
205                                       output, outputLimit,
206                                       frameHeader.windowSize, outputAddress);
207             input += blockSize;
208             break;
209           default:
210             throw fail(input, "Invalid block type");
211         }
212 
213         output += decodedSize;
214       } while (!lastBlock);
215 
216       if (frameHeader.hasChecksum) {
217         final int decodedFrameSize = output - outputStart;
218 
219         final long hash =
220             XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
221 
222         final int checksum = getInt(inputBase, input);
223         if (checksum != (int) hash) {
224           throw new MalformedInputException(input, String.format(
225               "Bad checksum. Expected: %s, actual: %s",
226               Integer.toHexString(checksum), Integer.toHexString((int) hash)));
227         }
228 
229         input += SIZE_OF_INT;
230       }
231     }
232 
233     return output - outputAddress;
234   }
235 
236   private void reset() {
237     previousOffsets[0] = 1;
238     previousOffsets[1] = 4;
239     previousOffsets[2] = 8;
240 
241     currentLiteralsLengthTable = null;
242     currentOffsetCodesTable = null;
243     currentMatchLengthTable = null;
244   }
245 
246   private static int decodeRawBlock(final byte[] inputBase,
247                                     final int inputAddress, final int blockSize,
248                                     final byte[] outputBase,
249                                     final int outputAddress,
250                                     final int outputLimit) {
251     verify(outputAddress + blockSize <= outputLimit, inputAddress,
252            OUTPUT_BUFFER_TOO_SMALL);
253 
254     copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize);
255     return blockSize;
256   }
257 
258   private static int decodeRleBlock(final int size, final byte[] inputBase,
259                                     final int inputAddress,
260                                     final byte[] outputBase,
261                                     final int outputAddress,
262                                     final int outputLimit) {
263     verify(outputAddress + size <= outputLimit, inputAddress,
264            OUTPUT_BUFFER_TOO_SMALL);
265 
266     int output = outputAddress;
267     final long value = inputBase[inputAddress] & 0xFFL;
268 
269     int remaining = size;
270     if (remaining >= SIZE_OF_LONG) {
271       final long packed =
272           value | (value << 8) | (value << 16) | (value << 24) | (value << 32) |
273           (value << 40) | (value << 48) | (value << 56);
274 
275       do {
276         putLong(outputBase, output, packed);
277         output += SIZE_OF_LONG;
278         remaining -= SIZE_OF_LONG;
279       } while (remaining >= SIZE_OF_LONG);
280     }
281 
282     for (int i = 0; i < remaining; i++) {
283       outputBase[output] = (byte) value;
284       output++;
285     }
286 
287     return size;
288   }
289 
290   private int decodeCompressedBlock(final byte[] inputBase,
291                                     final int inputAddress, final int blockSize,
292                                     final byte[] outputBase,
293                                     final int outputAddress,
294                                     final int outputLimit, final int windowSize,
295                                     final int outputAbsoluteBaseAddress) {
296     final int inputLimit = inputAddress + blockSize;
297     int input = inputAddress;
298 
299     verify(blockSize <= MAX_BLOCK_SIZE, input,
300            EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
301     verify(blockSize >= MIN_BLOCK_SIZE, input,
302            "Compressed block size too small");
303 
304     // decode literals
305     final int literalsBlockType = inputBase[input] & 0x3;
306 
307     switch (literalsBlockType) {
308       case RAW_LITERALS_BLOCK: {
309         input += decodeRawLiterals(inputBase, input, inputLimit);
310         break;
311       }
312       case RLE_LITERALS_BLOCK: {
313         input += decodeRleLiterals(inputBase, input, blockSize);
314         break;
315       }
316       case TREELESS_LITERALS_BLOCK:
317         verify(huffman.isLoaded(), input, "Dictionary is corrupted");
318       case COMPRESSED_LITERALS_BLOCK: {
319         input += decodeCompressedLiterals(inputBase, input, blockSize,
320                                           literalsBlockType);
321         break;
322       }
323       default:
324         throw fail(input, "Invalid literals block encoding type");
325     }
326 
327     verify(windowSize <= MAX_WINDOW_SIZE, input,
328            "Window size too large (not yet supported)");
329 
330     return decompressSequences(inputBase, input, inputAddress + blockSize,
331                                outputBase, outputAddress, outputLimit,
332                                literalsBase, literalsAddress, literalsLimit,
333                                outputAbsoluteBaseAddress);
334   }
335 
336   private int decompressSequences(final byte[] inputBase,
337                                   final int inputAddress, final int inputLimit,
338                                   final byte[] outputBase,
339                                   final int outputAddress,
340                                   final int outputLimit,
341                                   final byte[] literalsBase,
342                                   final int literalsAddress,
343                                   final int literalsLimit,
344                                   final int outputAbsoluteBaseAddress) {
345     final int fastOutputLimit = outputLimit - SIZE_OF_LONG;
346     final int fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
347 
348     int input = inputAddress;
349     int output = outputAddress;
350 
351     int literalsInput = literalsAddress;
352 
353     final int size = inputLimit - inputAddress;
354     verify(size >= MIN_SEQUENCES_SIZE, input, NOT_ENOUGH_INPUT_BYTES);
355 
356     // decode header
357     int sequenceCount = inputBase[input++] & 0xFF;
358     if (sequenceCount != 0) {
359       if (sequenceCount == 255) {
360         verify(input + SIZE_OF_SHORT <= inputLimit, input,
361                NOT_ENOUGH_INPUT_BYTES);
362         sequenceCount =
363             (getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES;
364         input += SIZE_OF_SHORT;
365       } else if (sequenceCount > 127) {
366         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
367         sequenceCount =
368             ((sequenceCount - 128) << 8) + (inputBase[input++] & 0xFF);
369       }
370 
371       verify(input + SIZE_OF_INT <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
372 
373       final byte type = inputBase[input++];
374 
375       final int literalsLengthType = (type & 0xFF) >>> 6;
376       final int offsetCodesType = (type >>> 4) & 0x3;
377       final int matchLengthType = (type >>> 2) & 0x3;
378 
379       input = computeLiteralsTable(literalsLengthType, inputBase, input,
380                                    inputLimit);
381       input =
382           computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
383       input = computeMatchLengthTable(matchLengthType, inputBase, input,
384                                       inputLimit);
385 
386       // decompress sequences
387       final BitInputStream.Initializer initializer =
388           new BitInputStream.Initializer(inputBase, input, inputLimit);
389       initializer.initialize();
390       int bitsConsumed = initializer.getBitsConsumed();
391       long bits = initializer.getBits();
392       int currentAddress = initializer.getCurrentAddress();
393 
394       final FiniteStateEntropy.Table currentLiteralsLengthTable1 =
395           this.currentLiteralsLengthTable;
396       final FiniteStateEntropy.Table currentOffsetCodesTable1 =
397           this.currentOffsetCodesTable;
398       final FiniteStateEntropy.Table currentMatchLengthTable1 =
399           this.currentMatchLengthTable;
400 
401       int literalsLengthState = (int) peekBits(bitsConsumed, bits,
402                                                currentLiteralsLengthTable1.log2Size);
403       bitsConsumed += currentLiteralsLengthTable1.log2Size;
404 
405       int offsetCodesState =
406           (int) peekBits(bitsConsumed, bits, currentOffsetCodesTable1.log2Size);
407       bitsConsumed += currentOffsetCodesTable1.log2Size;
408 
409       int matchLengthState =
410           (int) peekBits(bitsConsumed, bits, currentMatchLengthTable1.log2Size);
411       bitsConsumed += currentMatchLengthTable1.log2Size;
412 
413       final int[] previousOffsets1 = this.previousOffsets;
414 
415       final byte[] literalsLengthNumbersOfBits =
416           currentLiteralsLengthTable1.numberOfBits;
417       final int[] literalsLengthNewStates =
418           currentLiteralsLengthTable1.newState;
419       final byte[] literalsLengthSymbols = currentLiteralsLengthTable1.symbol;
420 
421       final byte[] matchLengthNumbersOfBits =
422           currentMatchLengthTable1.numberOfBits;
423       final int[] matchLengthNewStates = currentMatchLengthTable1.newState;
424       final byte[] matchLengthSymbols = currentMatchLengthTable1.symbol;
425 
426       final byte[] offsetCodesNumbersOfBits =
427           currentOffsetCodesTable1.numberOfBits;
428       final int[] offsetCodesNewStates = currentOffsetCodesTable1.newState;
429       final byte[] offsetCodesSymbols = currentOffsetCodesTable1.symbol;
430 
431       while (sequenceCount > 0) {
432         sequenceCount--;
433 
434         final BitInputStream.Loader loader =
435             new BitInputStream.Loader(inputBase, input, currentAddress, bits,
436                                       bitsConsumed);
437         loader.load();
438         bitsConsumed = loader.getBitsConsumed();
439         bits = loader.getBits();
440         currentAddress = loader.getCurrentAddress();
441         if (loader.isOverflow()) {
442           verify(sequenceCount == 0, input, "Not all sequences were consumed");
443           break;
444         }
445 
446         // decode sequence
447         final int literalsLengthCode =
448             literalsLengthSymbols[literalsLengthState];
449         final int matchLengthCode = matchLengthSymbols[matchLengthState];
450         final int offsetCode = offsetCodesSymbols[offsetCodesState];
451 
452         final int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
453         final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
454 
455         int offset = OFFSET_CODES_BASE[offsetCode];
456         if (offsetCode > 0) {
457           offset += peekBits(bitsConsumed, bits, offsetCode);
458           bitsConsumed += offsetCode;
459         }
460 
461         if (offsetCode <= 1) {
462           if (literalsLengthCode == 0) {
463             offset++;
464           }
465 
466           if (offset != 0) {
467             int temp;
468             if (offset == 3) {
469               temp = previousOffsets1[0] - 1;
470             } else {
471               temp = previousOffsets1[offset];
472             }
473 
474             if (temp == 0) {
475               temp = 1;
476             }
477 
478             if (offset != 1) {
479               previousOffsets1[2] = previousOffsets1[1];
480             }
481             previousOffsets1[1] = previousOffsets1[0];
482             previousOffsets1[0] = temp;
483 
484             offset = temp;
485           } else {
486             offset = previousOffsets1[0];
487           }
488         } else {
489           previousOffsets1[2] = previousOffsets1[1];
490           previousOffsets1[1] = previousOffsets1[0];
491           previousOffsets1[0] = offset;
492         }
493 
494         int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
495         if (matchLengthCode > 31) {
496           matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
497           bitsConsumed += matchLengthBits;
498         }
499 
500         int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
501         if (literalsLengthCode > 15) {
502           literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
503           bitsConsumed += literalsLengthBits;
504         }
505 
506         final int totalBits = literalsLengthBits + matchLengthBits + offsetCode;
507         if (totalBits > 64 - 7 -
508                         (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG +
509                          OFFSET_TABLE_LOG)) {
510           final BitInputStream.Loader loader1 =
511               new BitInputStream.Loader(inputBase, input, currentAddress, bits,
512                                         bitsConsumed);
513           loader1.load();
514 
515           bitsConsumed = loader1.getBitsConsumed();
516           bits = loader1.getBits();
517           currentAddress = loader1.getCurrentAddress();
518         }
519 
520         int numberOfBits;
521 
522         numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
523         literalsLengthState =
524             (int) (literalsLengthNewStates[literalsLengthState] +
525                    peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
526         bitsConsumed += numberOfBits;
527 
528         numberOfBits = matchLengthNumbersOfBits[matchLengthState];
529         matchLengthState = (int) (matchLengthNewStates[matchLengthState] +
530                                   peekBits(bitsConsumed, bits,
531                                            numberOfBits)); // <= 9 bits
532         bitsConsumed += numberOfBits;
533 
534         numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
535         offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] +
536                                   peekBits(bitsConsumed, bits,
537                                            numberOfBits)); // <= 8 bits
538         bitsConsumed += numberOfBits;
539 
540         final int literalOutputLimit = output + literalsLength;
541         final int matchOutputLimit = literalOutputLimit + matchLength;
542 
543         verify(matchOutputLimit <= outputLimit, input, OUTPUT_BUFFER_TOO_SMALL);
544         final int literalEnd = literalsInput + literalsLength;
545         verify(literalEnd <= literalsLimit, input, INPUT_IS_CORRUPTED);
546 
547         final int matchAddress = literalOutputLimit - offset;
548         verify(matchAddress >= outputAbsoluteBaseAddress, input,
549                INPUT_IS_CORRUPTED);
550 
551         if (literalOutputLimit > fastOutputLimit) {
552           executeLastSequence(outputBase, output, literalOutputLimit,
553                               matchOutputLimit, fastOutputLimit, literalsInput,
554                               matchAddress);
555         } else {
556           // copy literals. literalOutputLimit <= fastOutputLimit, so we can copy
557           // long at a time with over-copy
558           output = copyLiterals(outputBase, literalsBase, output, literalsInput,
559                                 literalOutputLimit);
560           copyMatch(outputBase, fastOutputLimit, output, offset,
561                     matchOutputLimit, matchAddress, matchLength,
562                     fastMatchOutputLimit);
563         }
564         output = matchOutputLimit;
565         literalsInput = literalEnd;
566       }
567     }
568 
569     // last literal segment
570     output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output,
571                              literalsInput);
572 
573     return output - outputAddress;
574   }
575 
576   private int copyLastLiteral(final byte[] outputBase,
577                               final byte[] literalsBase,
578                               final int literalsLimit, int output,
579                               final int literalsInput) {
580     final int lastLiteralsSize = literalsLimit - literalsInput;
581     copyMemory(literalsBase, literalsInput, outputBase, output,
582                lastLiteralsSize);
583     output += lastLiteralsSize;
584     return output;
585   }
586 
587   private void copyMatch(final byte[] outputBase, final int fastOutputLimit,
588                          int output, final int offset,
589                          final int matchOutputLimit, int matchAddress,
590                          int matchLength, final int fastMatchOutputLimit) {
591     matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
592     output += SIZE_OF_LONG;
593     matchLength -= SIZE_OF_LONG; // first 8 bytes copied above
594 
595     copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit,
596                   matchAddress, matchLength, fastMatchOutputLimit);
597   }
598 
599   private void copyMatchTail(final byte[] outputBase, final int fastOutputLimit,
600                              int output, final int matchOutputLimit,
601                              int matchAddress, final int matchLength,
602                              final int fastMatchOutputLimit) {
603     // fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be passed in so that it can be computed once for the
604     // whole invocation to decompressSequences. Otherwise, we'd just compute it here.
605     // If matchOutputLimit is < fastMatchOutputLimit, we know that even after the head (8 bytes) has been copied, the output pointer
606     // will be within fastOutputLimit, so it's safe to copy blindly before checking the limit condition
607     if (matchOutputLimit < fastMatchOutputLimit) {
608       int copied = 0;
609       do {
610         putLong(outputBase, output, getLong(outputBase, matchAddress));
611         output += SIZE_OF_LONG;
612         matchAddress += SIZE_OF_LONG;
613         copied += SIZE_OF_LONG;
614       } while (copied < matchLength);
615     } else {
616       while (output < fastOutputLimit) {
617         putLong(outputBase, output, getLong(outputBase, matchAddress));
618         matchAddress += SIZE_OF_LONG;
619         output += SIZE_OF_LONG;
620       }
621 
622       while (output < matchOutputLimit) {
623         outputBase[output++] = outputBase[matchAddress++];
624       }
625     }
626   }
627 
628   private int copyMatchHead(final byte[] outputBase, final int output,
629                             final int offset, int matchAddress) {
630     // copy match
631     if (offset < 8) {
632       // 8 bytes apart so that we can copy long-at-a-time below
633       final int increment32 = DEC_32_TABLE[offset];
634       final int decrement64 = DEC_64_TABLE[offset];
635 
636       outputBase[output] = outputBase[matchAddress];
637       outputBase[output + 1] = outputBase[matchAddress + 1];
638       outputBase[output + 2] = outputBase[matchAddress + 2];
639       outputBase[output + 3] = outputBase[matchAddress + 3];
640       matchAddress += increment32;
641 
642       putInt(outputBase, output + 4, getInt(outputBase, matchAddress));
643       matchAddress -= decrement64;
644     } else {
645       putLong(outputBase, output, getLong(outputBase, matchAddress));
646       matchAddress += SIZE_OF_LONG;
647     }
648     return matchAddress;
649   }
650 
651   private int copyLiterals(final byte[] outputBase, final byte[] literalsBase,
652                            int output, final int literalsInput,
653                            final int literalOutputLimit) {
654     int literalInput = literalsInput;
655     do {
656       putLong(outputBase, output, getLong(literalsBase, literalInput));
657       output += SIZE_OF_LONG;
658       literalInput += SIZE_OF_LONG;
659     } while (output < literalOutputLimit);
660     output = literalOutputLimit; // correction in case we over-copied
661     return output;
662   }
663 
664   private int computeMatchLengthTable(final int matchLengthType,
665                                       final byte[] inputBase, int input,
666                                       final int inputLimit) {
667     switch (matchLengthType) {
668       case SEQUENCE_ENCODING_RLE:
669         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
670 
671         final byte value = inputBase[input++];
672         verify(value <= MAX_MATCH_LENGTH_SYMBOL, input,
673                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
674 
675         FseTableReader.initializeRleTable(matchLengthTable, value);
676         currentMatchLengthTable = matchLengthTable;
677         break;
678       case SEQUENCE_ENCODING_BASIC:
679         currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
680         break;
681       case SEQUENCE_ENCODING_REPEAT:
682         verify(currentMatchLengthTable != null, input,
683                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
684         break;
685       case SEQUENCE_ENCODING_COMPRESSED:
686         input +=
687             fse.readFseTable(matchLengthTable, inputBase, input, inputLimit,
688                              MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
689         currentMatchLengthTable = matchLengthTable;
690         break;
691       default:
692         throw fail(input, "Invalid match length encoding type");
693     }
694     return input;
695   }
696 
697   private int computeOffsetsTable(final int offsetCodesType,
698                                   final byte[] inputBase, int input,
699                                   final int inputLimit) {
700     switch (offsetCodesType) {
701       case SEQUENCE_ENCODING_RLE:
702         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
703 
704         final byte value = inputBase[input++];
705         verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input,
706                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
707 
708         FseTableReader.initializeRleTable(offsetCodesTable, value);
709         currentOffsetCodesTable = offsetCodesTable;
710         break;
711       case SEQUENCE_ENCODING_BASIC:
712         currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
713         break;
714       case SEQUENCE_ENCODING_REPEAT:
715         verify(currentOffsetCodesTable != null, input,
716                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
717         break;
718       case SEQUENCE_ENCODING_COMPRESSED:
719         input +=
720             fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit,
721                              DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
722         currentOffsetCodesTable = offsetCodesTable;
723         break;
724       default:
725         throw fail(input, "Invalid offset code encoding type");
726     }
727     return input;
728   }
729 
730   private int computeLiteralsTable(final int literalsLengthType,
731                                    final byte[] inputBase, int input,
732                                    final int inputLimit) {
733     switch (literalsLengthType) {
734       case SEQUENCE_ENCODING_RLE:
735         verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
736 
737         final byte value = inputBase[input++];
738         verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input,
739                VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
740 
741         FseTableReader.initializeRleTable(literalsLengthTable, value);
742         currentLiteralsLengthTable = literalsLengthTable;
743         break;
744       case SEQUENCE_ENCODING_BASIC:
745         currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
746         break;
747       case SEQUENCE_ENCODING_REPEAT:
748         verify(currentLiteralsLengthTable != null, input,
749                EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
750         break;
751       case SEQUENCE_ENCODING_COMPRESSED:
752         input +=
753             fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit,
754                              MAX_LITERALS_LENGTH_SYMBOL,
755                              LITERAL_LENGTH_TABLE_LOG);
756         currentLiteralsLengthTable = literalsLengthTable;
757         break;
758       default:
759         throw fail(input, "Invalid literals length encoding type");
760     }
761     return input;
762   }
763 
764   private void executeLastSequence(final byte[] outputBase, int output,
765                                    final int literalOutputLimit,
766                                    final int matchOutputLimit,
767                                    final int fastOutputLimit, int literalInput,
768                                    int matchAddress) {
769     // copy literals
770     if (output < fastOutputLimit) {
771       // wild copy
772       do {
773         putLong(outputBase, output, getLong(literalsBase, literalInput));
774         output += SIZE_OF_LONG;
775         literalInput += SIZE_OF_LONG;
776       } while (output < fastOutputLimit);
777 
778       literalInput -= output - fastOutputLimit;
779       output = fastOutputLimit;
780     }
781 
782     while (output < literalOutputLimit) {
783       outputBase[output] = literalsBase[literalInput];
784       output++;
785       literalInput++;
786     }
787 
788     // copy match
789     while (output < matchOutputLimit) {
790       outputBase[output] = outputBase[matchAddress];
791       output++;
792       matchAddress++;
793     }
794   }
795 
796   private int decodeCompressedLiterals(final byte[] inputBase,
797                                        final int inputAddress,
798                                        final int blockSize,
799                                        final int literalsBlockType) {
800     int input = inputAddress;
801     verify(blockSize >= 5, input, NOT_ENOUGH_INPUT_BYTES);
802 
803     // compressed
804     final int compressedSize;
805     final int uncompressedSize;
806     boolean singleStream = false;
807     final int headerSize;
808     final int type = (inputBase[input] >> 2) & 0x3;
809     switch (type) {
810       case 0:
811         singleStream = true;
812       case 1: {
813         final int header = getInt(inputBase, input);
814 
815         headerSize = 3;
816         uncompressedSize = (header >>> 4) & mask(10);
817         compressedSize = (header >>> 14) & mask(10);
818         break;
819       }
820       case 2: {
821         final int header = getInt(inputBase, input);
822 
823         headerSize = 4;
824         uncompressedSize = (header >>> 4) & mask(14);
825         compressedSize = (header >>> 18) & mask(14);
826         break;
827       }
828       case 3: {
829         // read 5 little-endian bytes
830         final long header = inputBase[input] & 0xFF |
831                             (getInt(inputBase, input + 1) & 0xFFFFFFFFL) << 8;
832 
833         headerSize = 5;
834         uncompressedSize = (int) ((header >>> 4) & mask(18));
835         compressedSize = (int) ((header >>> 22) & mask(18));
836         break;
837       }
838       default:
839         throw fail(input, "Invalid literals header size type");
840     }
841 
842     verify(uncompressedSize <= MAX_BLOCK_SIZE, input,
843            "Block exceeds maximum size");
844     verify(headerSize + compressedSize <= blockSize, input, INPUT_IS_CORRUPTED);
845 
846     input += headerSize;
847 
848     final int inputLimit = input + compressedSize;
849     if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
850       input += huffman.readTable(inputBase, input, compressedSize);
851     }
852 
853     literalsBase = literals;
854     literalsAddress = 0;
855     literalsLimit = uncompressedSize;
856 
857     if (singleStream) {
858       huffman.decodeSingleStream(inputBase, input, inputLimit, literals,
859                                  literalsAddress, literalsLimit);
860     } else {
861       huffman.decode4Streams(inputBase, input, inputLimit, literals,
862                              literalsAddress, literalsLimit);
863     }
864 
865     return headerSize + compressedSize;
866   }
867 
868   private int decodeRleLiterals(final byte[] inputBase, final int inputAddress,
869                                 final int blockSize) {
870     int input = inputAddress;
871     final int outputSize;
872 
873     final int type = (inputBase[input] >> 2) & 0x3;
874     switch (type) {
875       case 0:
876       case 2:
877         outputSize = (inputBase[input] & 0xFF) >>> 3;
878         input++;
879         break;
880       case 1:
881         outputSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
882         input += 2;
883         break;
884       case 3:
885         // we need at least 4 bytes (3 for the header, 1 for the payload)
886         verify(blockSize >= SIZE_OF_INT, input, NOT_ENOUGH_INPUT_BYTES);
887         outputSize = (getInt(inputBase, input) & 0xFFFFFF) >>> 4;
888         input += 3;
889         break;
890       default:
891         throw fail(input, "Invalid RLE literals header encoding type");
892     }
893 
894     verify(outputSize <= MAX_BLOCK_SIZE, input,
895            "Output exceeds maximum block size");
896 
897     final byte value = inputBase[input++];
898     Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
899 
900     literalsBase = literals;
901     literalsAddress = 0;
902     literalsLimit = outputSize;
903 
904     return input - inputAddress;
905   }
906 
907   private int decodeRawLiterals(final byte[] inputBase, final int inputAddress,
908                                 final int inputLimit) {
909     int input = inputAddress;
910     final int type = (inputBase[input] >> 2) & 0x3;
911 
912     final int literalSize;
913     switch (type) {
914       case 0:
915       case 2:
916         literalSize = (inputBase[input] & 0xFF) >>> 3;
917         input++;
918         break;
919       case 1:
920         literalSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
921         input += 2;
922         break;
923       case 3:
924         // read 3 little-endian bytes
925         final int header = ((inputBase[input] & 0xFF) |
926                             ((getShort(inputBase, input + 1) & 0xFFFF) << 8));
927 
928         literalSize = header >>> 4;
929         input += 3;
930         break;
931       default:
932         throw fail(input, "Invalid raw literals header encoding type");
933     }
934 
935     verify(input + literalSize <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
936 
937     // Set literals pointer to [input, literalSize], but only if we can copy 8 bytes at a time during sequence decoding
938     // Otherwise, copy literals into buffer that's big enough to guarantee that
939     if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
940       literalsBase = literals;
941       literalsAddress = 0;
942       literalsLimit = literalSize;
943 
944       copyMemory(inputBase, input, literals, literalsAddress, literalSize);
945       Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
946     } else {
947       literalsBase = inputBase;
948       literalsAddress = input;
949       literalsLimit = literalsAddress + literalSize;
950     }
951     input += literalSize;
952 
953     return input - inputAddress;
954   }
955 
956   static FrameHeader readFrameHeader(final byte[] inputBase,
957                                      final int inputAddress,
958                                      final int inputLimit) {
959     int input = inputAddress;
960     verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
961 
962     final int frameHeaderDescriptor = inputBase[input++] & 0xFF;
963     final boolean singleSegment = (frameHeaderDescriptor & 0x20) != 0;
964     final int dictionaryDescriptor = frameHeaderDescriptor & 0x3;
965     final int contentSizeDescriptor = frameHeaderDescriptor >>> 6;
966 
967     final int headerSize = 1 + (singleSegment? 0 : 1) +
968                            (dictionaryDescriptor == 0? 0 :
969                                (1 << (dictionaryDescriptor - 1))) +
970                            (contentSizeDescriptor == 0? (singleSegment? 1 : 0) :
971                                (1 << contentSizeDescriptor));
972 
973     verify(headerSize <= inputLimit - inputAddress, input,
974            NOT_ENOUGH_INPUT_BYTES);
975 
976     // decode window size
977     int windowSize = -1;
978     if (!singleSegment) {
979       final int windowDescriptor = inputBase[input++] & 0xFF;
980       final int exponent = windowDescriptor >>> 3;
981       final int mantissa = windowDescriptor & 0x7;
982 
983       final int base = 1 << (MIN_WINDOW_LOG + exponent);
984       windowSize = base + (base / 8) * mantissa;
985     }
986 
987     // decode dictionary id
988     int dictionaryId = -1;
989     switch (dictionaryDescriptor) {
990       case 1:
991         dictionaryId = inputBase[input] & 0xFF;
992         input += SIZE_OF_BYTE;
993         break;
994       case 2:
995         dictionaryId = getShort(inputBase, input) & 0xFFFF;
996         input += SIZE_OF_SHORT;
997         break;
998       case 3:
999         dictionaryId = getInt(inputBase, input);
1000         input += SIZE_OF_INT;
1001         break;
1002     }
1003     verify(dictionaryId == -1, input, "Custom dictionaries not supported");
1004 
1005     // decode content size
1006     int contentSize = -1;
1007     switch (contentSizeDescriptor) {
1008       case 0:
1009         if (singleSegment) {
1010           contentSize = inputBase[input] & 0xFF;
1011           input += SIZE_OF_BYTE;
1012         }
1013         break;
1014       case 1:
1015         contentSize = getShort(inputBase, input) & 0xFFFF;
1016         contentSize += 256;
1017         input += SIZE_OF_SHORT;
1018         break;
1019       case 2:
1020         contentSize = getInt(inputBase, input);
1021         input += SIZE_OF_INT;
1022         break;
1023       case 3:
1024         contentSize = (int) getLong(inputBase, input);
1025         input += SIZE_OF_LONG;
1026         break;
1027     }
1028 
1029     final boolean hasChecksum = (frameHeaderDescriptor & 0x4) != 0;
1030 
1031     return new FrameHeader(input - inputAddress, windowSize, contentSize,
1032                            dictionaryId, hasChecksum);
1033   }
1034 
1035   public static int getDecompressedSize(final byte[] inputBase,
1036                                         final int inputAddress,
1037                                         final int inputLimit) {
1038     int input = inputAddress;
1039     input += verifyMagic(inputBase, input, inputLimit);
1040     return readFrameHeader(inputBase, input, inputLimit).contentSize;
1041   }
1042 
1043   static int verifyMagic(final byte[] inputBase, final int inputAddress,
1044                          final int inputLimit) {
1045     verify(inputLimit - inputAddress >= 4, inputAddress,
1046            NOT_ENOUGH_INPUT_BYTES);
1047 
1048     final int magic = getInt(inputBase, inputAddress);
1049     if (magic != MAGIC_NUMBER) {
1050       if (magic == V07_MAGIC_NUMBER) {
1051         throw new MalformedInputException(inputAddress,
1052                                           "Data encoded in unsupported ZSTD v0.7 format");
1053       }
1054       throw new MalformedInputException(inputAddress, "Invalid magic prefix: " +
1055                                                       Integer.toHexString(
1056                                                           magic));
1057     }
1058 
1059     return SIZE_OF_INT;
1060   }
1061 }