ZstdFrameDecompressor.java
/*
* This file is part of Waarp Project (named also Waarp or GG).
*
* Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
* tags. See the COPYRIGHT.txt in the distribution for a full listing of
* individual contributors.
*
* All Waarp Project is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at your
* option) any later version.
*
* Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with
* Waarp . If not, see <http://www.gnu.org/licenses/>.
*/
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.waarp.compress.zstdsafe;
import org.waarp.compress.MalformedInputException;
import java.util.Arrays;
import static org.waarp.compress.zstdsafe.BitInputStream.*;
import static org.waarp.compress.zstdsafe.Constants.*;
import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
import static org.waarp.compress.zstdsafe.Util.*;
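// Safe (byte[] based) decompressor for zstd frames: checks the magic number,
// parses the frame header, decodes raw, RLE and compressed blocks and verifies
// the optional content checksum.
//
// A minimal usage sketch (illustrative only; the class is package-private, and
// getDecompressedSize() only helps when the frame declares its content size,
// otherwise it returns -1):
//
//   byte[] compressed = ...; // one complete zstd frame
//   int size =
//       ZstdFrameDecompressor.getDecompressedSize(compressed, 0, compressed.length);
//   byte[] output = new byte[size];
//   int written = new ZstdFrameDecompressor()
//       .decompress(compressed, 0, compressed.length, output, 0, output.length);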
class ZstdFrameDecompressor {
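// Fix-up tables used by copyMatchHead() when the match offset is smaller than
// 8: DEC_32_TABLE adjusts the source position before the second 4-byte copy
// and DEC_64_TABLE re-aligns it afterwards so that the following 8-byte copies
// read from the right place (mirroring the reference decoder's tables).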
private static final int[] DEC_32_TABLE = { 4, 1, 2, 1, 4, 4, 4, 4 };
private static final int[] DEC_64_TABLE = { 0, 0, 0, -1, 0, 1, 2, 3 };
private static final int V07_MAGIC_NUMBER = 0xFD2FB527;
private static final int MAX_WINDOW_SIZE = 1 << 23;
private static final int[] LITERALS_LENGTH_BASE = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24,
28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000,
0x4000, 0x8000, 0x10000
};
private static final int[] MATCH_LENGTH_BASE = {
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47,
51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003,
0x4003, 0x8003, 0x10003
};
private static final int[] OFFSET_CODES_BASE = {
0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD,
0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD,
0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD,
0xFFFFFFD
};
private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE =
new FiniteStateEntropy.Table(6, new int[] {
0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0,
32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0,
0, 48, 16, 32, 32, 32, 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32,
0, 0, 0, 0
}, new byte[] {
0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26,
27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23,
25, 25, 26, 28, 30, 0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20,
21, 23, 24, 35, 34, 33, 32
}, new byte[] {
4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4,
4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5,
5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6
});
private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE =
new FiniteStateEntropy.Table(5, new int[] {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0,
0, 16, 0, 0, 0, 0, 0, 0, 0
}, new byte[] {
0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4,
8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24
}, new byte[] {
5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5,
5, 4, 5, 5, 5, 5, 5, 5, 5
});
private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE =
new FiniteStateEntropy.Table(6, new int[] {
0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16,
0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48,
16, 32, 32, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
}, new byte[] {
0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39,
41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34,
36, 38, 40, 42, 44, 1, 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29,
52, 51, 50, 49, 48, 47, 46
}, new byte[] {
6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4,
4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4,
5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
});
public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
public static final String OUTPUT_BUFFER_TOO_SMALL =
"Output buffer too small";
public static final String EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT =
"Expected match length table to be present";
public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
public static final String VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE =
"Value exceeds expected maximum value";
// extra space to allow for long-at-a-time copy
private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG];
// current buffer containing literals
private byte[] literalsBase;
private int literalsAddress;
private int literalsLimit;
private final int[] previousOffsets = new int[3];
private final FiniteStateEntropy.Table literalsLengthTable =
new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
private final FiniteStateEntropy.Table offsetCodesTable =
new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
private final FiniteStateEntropy.Table matchLengthTable =
new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);
private FiniteStateEntropy.Table currentLiteralsLengthTable;
private FiniteStateEntropy.Table currentOffsetCodesTable;
private FiniteStateEntropy.Table currentMatchLengthTable;
private final Huffman huffman = new Huffman();
private final FseTableReader fse = new FseTableReader();
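// Decompresses every frame found between inputAddress and inputLimit into the
// output buffer: for each frame the magic number is checked, the frame header
// is read, blocks are decoded until the "last block" flag is set and, if the
// frame carries a checksum, the low 32 bits of the XxHash64 of the decoded
// data are verified. Returns the number of bytes written at outputAddress.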
public int decompress(final byte[] inputBase, final int inputAddress,
final int inputLimit, final byte[] outputBase,
final int outputAddress, final int outputLimit) {
if (outputAddress == outputLimit) {
return 0;
}
int input = inputAddress;
int output = outputAddress;
while (input < inputLimit) {
reset();
final int outputStart = output;
input += verifyMagic(inputBase, input, inputLimit);
final FrameHeader frameHeader =
readFrameHeader(inputBase, input, inputLimit);
input += frameHeader.headerSize;
boolean lastBlock;
do {
verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input,
NOT_ENOUGH_INPUT_BYTES);
// read block header
final int header = getInt(inputBase, input) & 0xFFFFFF;
input += SIZE_OF_BLOCK_HEADER;
lastBlock = (header & 1) != 0;
final int blockType = (header >>> 1) & 0x3;
final int blockSize = (header >>> 3) & 0x1FFFFF; // 21 bits
final int decodedSize;
switch (blockType) {
case RAW_BLOCK:
verify(inputAddress + blockSize <= inputLimit, input,
NOT_ENOUGH_INPUT_BYTES);
decodedSize =
decodeRawBlock(inputBase, input, blockSize, outputBase, output,
outputLimit);
input += blockSize;
break;
case RLE_BLOCK:
verify(inputAddress + 1 <= inputLimit, input,
NOT_ENOUGH_INPUT_BYTES);
decodedSize =
decodeRleBlock(blockSize, inputBase, input, outputBase, output,
outputLimit);
input += 1;
break;
case COMPRESSED_BLOCK:
verify(inputAddress + blockSize <= inputLimit, input,
NOT_ENOUGH_INPUT_BYTES);
decodedSize =
decodeCompressedBlock(inputBase, input, blockSize, outputBase,
output, outputLimit,
frameHeader.windowSize, outputAddress);
input += blockSize;
break;
default:
throw fail(input, "Invalid block type");
}
output += decodedSize;
} while (!lastBlock);
if (frameHeader.hasChecksum) {
final int decodedFrameSize = output - outputStart;
final long hash =
XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
final int checksum = getInt(inputBase, input);
if (checksum != (int) hash) {
throw new MalformedInputException(input, String.format(
"Bad checksum. Expected: %s, actual: %s",
Integer.toHexString(checksum), Integer.toHexString((int) hash)));
}
input += SIZE_OF_INT;
}
}
return output - outputAddress;
}
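// Per-frame state: the zstd format mandates repeated offsets of 1, 4 and 8 at
// the start of a frame, and the sequence decoding tables must be chosen anew
// by the first compressed block.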
private void reset() {
previousOffsets[0] = 1;
previousOffsets[1] = 4;
previousOffsets[2] = 8;
currentLiteralsLengthTable = null;
currentOffsetCodesTable = null;
currentMatchLengthTable = null;
}
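// A raw block stores its content uncompressed: copy blockSize bytes verbatim.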
private static int decodeRawBlock(final byte[] inputBase,
final int inputAddress, final int blockSize,
final byte[] outputBase,
final int outputAddress,
final int outputLimit) {
verify(outputAddress + blockSize <= outputLimit, inputAddress,
OUTPUT_BUFFER_TOO_SMALL);
copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize);
return blockSize;
}
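// An RLE block regenerates `size` copies of a single byte: the byte is
// replicated into a long so the output can be filled 8 bytes at a time, with a
// byte-wise loop for the remainder.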
private static int decodeRleBlock(final int size, final byte[] inputBase,
final int inputAddress,
final byte[] outputBase,
final int outputAddress,
final int outputLimit) {
verify(outputAddress + size <= outputLimit, inputAddress,
OUTPUT_BUFFER_TOO_SMALL);
int output = outputAddress;
final long value = inputBase[inputAddress] & 0xFFL;
int remaining = size;
if (remaining >= SIZE_OF_LONG) {
final long packed =
value | (value << 8) | (value << 16) | (value << 24) | (value << 32) |
(value << 40) | (value << 48) | (value << 56);
do {
putLong(outputBase, output, packed);
output += SIZE_OF_LONG;
remaining -= SIZE_OF_LONG;
} while (remaining >= SIZE_OF_LONG);
}
for (int i = 0; i < remaining; i++) {
outputBase[output] = (byte) value;
output++;
}
return size;
}
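// A compressed block is a literals section followed by a sequences section:
// the literals are decoded first (into the internal buffer, or referenced
// directly in the input for raw literals), then the sequences interleave those
// literals with matches copied from the already produced output.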
private int decodeCompressedBlock(final byte[] inputBase,
final int inputAddress, final int blockSize,
final byte[] outputBase,
final int outputAddress,
final int outputLimit, final int windowSize,
final int outputAbsoluteBaseAddress) {
final int inputLimit = inputAddress + blockSize;
int input = inputAddress;
verify(blockSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size");
verify(blockSize >= MIN_BLOCK_SIZE, input,
"Compressed block size too small");
// decode literals
final int literalsBlockType = inputBase[input] & 0x3;
switch (literalsBlockType) {
case RAW_LITERALS_BLOCK: {
input += decodeRawLiterals(inputBase, input, inputLimit);
break;
}
case RLE_LITERALS_BLOCK: {
input += decodeRleLiterals(inputBase, input, blockSize);
break;
}
case TREELESS_LITERALS_BLOCK:
verify(huffman.isLoaded(), input, "Dictionary is corrupted");
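// fall through: a treeless block reuses the Huffman table loaded by a
// previous compressed literals block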
case COMPRESSED_LITERALS_BLOCK: {
input += decodeCompressedLiterals(inputBase, input, blockSize,
literalsBlockType);
break;
}
default:
throw fail(input, "Invalid literals block encoding type");
}
verify(windowSize <= MAX_WINDOW_SIZE, input,
"Window size too large (not yet supported)");
return decompressSequences(inputBase, input, inputAddress + blockSize,
outputBase, outputAddress, outputLimit,
literalsBase, literalsAddress, literalsLimit,
outputAbsoluteBaseAddress);
}
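// Decodes the sequences section. The header gives the sequence count (1 to 3
// bytes) and one byte whose bit pairs select the encoding mode of the
// literals-length, offset and match-length tables. The bitstream is then read
// backwards and every sequence yields (literalsLength, offset, matchLength):
// literals are copied from the literals buffer and the match is copied from
// `offset` bytes back in the output, small offset values being resolved
// through the repeated-offset history.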
private int decompressSequences(final byte[] inputBase,
final int inputAddress, final int inputLimit,
final byte[] outputBase,
final int outputAddress,
final int outputLimit,
final byte[] literalsBase,
final int literalsAddress,
final int literalsLimit,
final int outputAbsoluteBaseAddress) {
final int fastOutputLimit = outputLimit - SIZE_OF_LONG;
final int fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
int input = inputAddress;
int output = outputAddress;
int literalsInput = literalsAddress;
final int size = inputLimit - inputAddress;
verify(size >= MIN_SEQUENCES_SIZE, input, NOT_ENOUGH_INPUT_BYTES);
// decode header
int sequenceCount = inputBase[input++] & 0xFF;
if (sequenceCount != 0) {
if (sequenceCount == 255) {
verify(input + SIZE_OF_SHORT <= inputLimit, input,
NOT_ENOUGH_INPUT_BYTES);
sequenceCount =
(getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES;
input += SIZE_OF_SHORT;
} else if (sequenceCount > 127) {
verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
sequenceCount =
((sequenceCount - 128) << 8) + (inputBase[input++] & 0xFF);
}
verify(input + SIZE_OF_INT <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
final byte type = inputBase[input++];
final int literalsLengthType = (type & 0xFF) >>> 6;
final int offsetCodesType = (type >>> 4) & 0x3;
final int matchLengthType = (type >>> 2) & 0x3;
input = computeLiteralsTable(literalsLengthType, inputBase, input,
inputLimit);
input =
computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
input = computeMatchLengthTable(matchLengthType, inputBase, input,
inputLimit);
// decompress sequences
final BitInputStream.Initializer initializer =
new BitInputStream.Initializer(inputBase, input, inputLimit);
initializer.initialize();
int bitsConsumed = initializer.getBitsConsumed();
long bits = initializer.getBits();
int currentAddress = initializer.getCurrentAddress();
final FiniteStateEntropy.Table currentLiteralsLengthTable1 =
this.currentLiteralsLengthTable;
final FiniteStateEntropy.Table currentOffsetCodesTable1 =
this.currentOffsetCodesTable;
final FiniteStateEntropy.Table currentMatchLengthTable1 =
this.currentMatchLengthTable;
int literalsLengthState = (int) peekBits(bitsConsumed, bits,
currentLiteralsLengthTable1.log2Size);
bitsConsumed += currentLiteralsLengthTable1.log2Size;
int offsetCodesState =
(int) peekBits(bitsConsumed, bits, currentOffsetCodesTable1.log2Size);
bitsConsumed += currentOffsetCodesTable1.log2Size;
int matchLengthState =
(int) peekBits(bitsConsumed, bits, currentMatchLengthTable1.log2Size);
bitsConsumed += currentMatchLengthTable1.log2Size;
final int[] previousOffsets1 = this.previousOffsets;
final byte[] literalsLengthNumbersOfBits =
currentLiteralsLengthTable1.numberOfBits;
final int[] literalsLengthNewStates =
currentLiteralsLengthTable1.newState;
final byte[] literalsLengthSymbols = currentLiteralsLengthTable1.symbol;
final byte[] matchLengthNumbersOfBits =
currentMatchLengthTable1.numberOfBits;
final int[] matchLengthNewStates = currentMatchLengthTable1.newState;
final byte[] matchLengthSymbols = currentMatchLengthTable1.symbol;
final byte[] offsetCodesNumbersOfBits =
currentOffsetCodesTable1.numberOfBits;
final int[] offsetCodesNewStates = currentOffsetCodesTable1.newState;
final byte[] offsetCodesSymbols = currentOffsetCodesTable1.symbol;
while (sequenceCount > 0) {
sequenceCount--;
final BitInputStream.Loader loader =
new BitInputStream.Loader(inputBase, input, currentAddress, bits,
bitsConsumed);
loader.load();
bitsConsumed = loader.getBitsConsumed();
bits = loader.getBits();
currentAddress = loader.getCurrentAddress();
if (loader.isOverflow()) {
verify(sequenceCount == 0, input, "Not all sequences were consumed");
break;
}
// decode sequence
final int literalsLengthCode =
literalsLengthSymbols[literalsLengthState];
final int matchLengthCode = matchLengthSymbols[matchLengthState];
final int offsetCode = offsetCodesSymbols[offsetCodesState];
final int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
int offset = OFFSET_CODES_BASE[offsetCode];
if (offsetCode > 0) {
offset += peekBits(bitsConsumed, bits, offsetCode);
bitsConsumed += offsetCode;
}
if (offsetCode <= 1) {
if (literalsLengthCode == 0) {
offset++;
}
if (offset != 0) {
int temp;
if (offset == 3) {
temp = previousOffsets1[0] - 1;
} else {
temp = previousOffsets1[offset];
}
if (temp == 0) {
temp = 1;
}
if (offset != 1) {
previousOffsets1[2] = previousOffsets1[1];
}
previousOffsets1[1] = previousOffsets1[0];
previousOffsets1[0] = temp;
offset = temp;
} else {
offset = previousOffsets1[0];
}
} else {
previousOffsets1[2] = previousOffsets1[1];
previousOffsets1[1] = previousOffsets1[0];
previousOffsets1[0] = offset;
}
int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
if (matchLengthCode > 31) {
matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
bitsConsumed += matchLengthBits;
}
int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
if (literalsLengthCode > 15) {
literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
bitsConsumed += literalsLengthBits;
}
final int totalBits = literalsLengthBits + matchLengthBits + offsetCode;
if (totalBits > 64 - 7 -
(LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG +
OFFSET_TABLE_LOG)) {
final BitInputStream.Loader loader1 =
new BitInputStream.Loader(inputBase, input, currentAddress, bits,
bitsConsumed);
loader1.load();
bitsConsumed = loader1.getBitsConsumed();
bits = loader1.getBits();
currentAddress = loader1.getCurrentAddress();
}
int numberOfBits;
numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
literalsLengthState =
(int) (literalsLengthNewStates[literalsLengthState] +
peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
bitsConsumed += numberOfBits;
numberOfBits = matchLengthNumbersOfBits[matchLengthState];
matchLengthState = (int) (matchLengthNewStates[matchLengthState] +
peekBits(bitsConsumed, bits,
numberOfBits)); // <= 9 bits
bitsConsumed += numberOfBits;
numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] +
peekBits(bitsConsumed, bits,
numberOfBits)); // <= 8 bits
bitsConsumed += numberOfBits;
final int literalOutputLimit = output + literalsLength;
final int matchOutputLimit = literalOutputLimit + matchLength;
verify(matchOutputLimit <= outputLimit, input, OUTPUT_BUFFER_TOO_SMALL);
final int literalEnd = literalsInput + literalsLength;
verify(literalEnd <= literalsLimit, input, INPUT_IS_CORRUPTED);
final int matchAddress = literalOutputLimit - offset;
verify(matchAddress >= outputAbsoluteBaseAddress, input,
INPUT_IS_CORRUPTED);
if (literalOutputLimit > fastOutputLimit) {
executeLastSequence(outputBase, output, literalOutputLimit,
matchOutputLimit, fastOutputLimit, literalsInput,
matchAddress);
} else {
// copy literals. literalOutputLimit <= fastOutputLimit, so we can copy
// long at a time with over-copy
output = copyLiterals(outputBase, literalsBase, output, literalsInput,
literalOutputLimit);
copyMatch(outputBase, fastOutputLimit, output, offset,
matchOutputLimit, matchAddress, matchLength,
fastMatchOutputLimit);
}
output = matchOutputLimit;
literalsInput = literalEnd;
}
}
// last literal segment
output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output,
literalsInput);
return output - outputAddress;
}
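// Copies the literals remaining after the last sequence to the output.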
private int copyLastLiteral(final byte[] outputBase,
final byte[] literalsBase,
final int literalsLimit, int output,
final int literalsInput) {
final int lastLiteralsSize = literalsLimit - literalsInput;
copyMemory(literalsBase, literalsInput, outputBase, output,
lastLiteralsSize);
output += lastLiteralsSize;
return output;
}
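// Copies one match: the first 8 bytes go through copyMatchHead(), which
// untangles overlapping copies for offsets smaller than 8, the rest through
// copyMatchTail().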
private void copyMatch(final byte[] outputBase, final int fastOutputLimit,
int output, final int offset,
final int matchOutputLimit, int matchAddress,
int matchLength, final int fastMatchOutputLimit) {
matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
output += SIZE_OF_LONG;
matchLength -= SIZE_OF_LONG; // first 8 bytes copied above
copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit,
matchAddress, matchLength, fastMatchOutputLimit);
}
private void copyMatchTail(final byte[] outputBase, final int fastOutputLimit,
int output, final int matchOutputLimit,
int matchAddress, final int matchLength,
final int fastMatchOutputLimit) {
// fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be
// passed in so that it can be computed once for the whole invocation of
// decompressSequences; otherwise we would just compute it here.
// If matchOutputLimit is < fastMatchOutputLimit, we know that even after the
// head (8 bytes) has been copied, the output pointer will be within
// fastOutputLimit, so it is safe to copy blindly before checking the limit
// condition.
if (matchOutputLimit < fastMatchOutputLimit) {
int copied = 0;
do {
putLong(outputBase, output, getLong(outputBase, matchAddress));
output += SIZE_OF_LONG;
matchAddress += SIZE_OF_LONG;
copied += SIZE_OF_LONG;
} while (copied < matchLength);
} else {
while (output < fastOutputLimit) {
putLong(outputBase, output, getLong(outputBase, matchAddress));
matchAddress += SIZE_OF_LONG;
output += SIZE_OF_LONG;
}
while (output < matchOutputLimit) {
outputBase[output++] = outputBase[matchAddress++];
}
}
}
private int copyMatchHead(final byte[] outputBase, final int output,
final int offset, int matchAddress) {
// copy match
if (offset < 8) {
// 8 bytes apart so that we can copy long-at-a-time below
final int increment32 = DEC_32_TABLE[offset];
final int decrement64 = DEC_64_TABLE[offset];
outputBase[output] = outputBase[matchAddress];
outputBase[output + 1] = outputBase[matchAddress + 1];
outputBase[output + 2] = outputBase[matchAddress + 2];
outputBase[output + 3] = outputBase[matchAddress + 3];
matchAddress += increment32;
putInt(outputBase, output + 4, getInt(outputBase, matchAddress));
matchAddress -= decrement64;
} else {
putLong(outputBase, output, getLong(outputBase, matchAddress));
matchAddress += SIZE_OF_LONG;
}
return matchAddress;
}
private int copyLiterals(final byte[] outputBase, final byte[] literalsBase,
int output, final int literalsInput,
final int literalOutputLimit) {
int literalInput = literalsInput;
do {
putLong(outputBase, output, getLong(literalsBase, literalInput));
output += SIZE_OF_LONG;
literalInput += SIZE_OF_LONG;
} while (output < literalOutputLimit);
output = literalOutputLimit; // correction in case we over-copied
return output;
}
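// The three compute*Table methods below select the FSE decoding table for one
// symbol type according to its 2-bit encoding mode: RLE (a single symbol),
// basic (the predefined table from the zstd specification), repeat (reuse the
// table of the previous compressed block) or compressed (an FSE table
// description read from the input).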
private int computeMatchLengthTable(final int matchLengthType,
final byte[] inputBase, int input,
final int inputLimit) {
switch (matchLengthType) {
case SEQUENCE_ENCODING_RLE:
verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
final byte value = inputBase[input++];
verify(value <= MAX_MATCH_LENGTH_SYMBOL, input,
VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
FseTableReader.initializeRleTable(matchLengthTable, value);
currentMatchLengthTable = matchLengthTable;
break;
case SEQUENCE_ENCODING_BASIC:
currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
break;
case SEQUENCE_ENCODING_REPEAT:
verify(currentMatchLengthTable != null, input,
EXPECTED_MATCH_LENGTH_TABLE_TO_BE_PRESENT);
break;
case SEQUENCE_ENCODING_COMPRESSED:
input +=
fse.readFseTable(matchLengthTable, inputBase, input, inputLimit,
MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
currentMatchLengthTable = matchLengthTable;
break;
default:
throw fail(input, "Invalid match length encoding type");
}
return input;
}
private int computeOffsetsTable(final int offsetCodesType,
final byte[] inputBase, int input,
final int inputLimit) {
switch (offsetCodesType) {
case SEQUENCE_ENCODING_RLE:
verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
final byte value = inputBase[input++];
verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input,
VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
FseTableReader.initializeRleTable(offsetCodesTable, value);
currentOffsetCodesTable = offsetCodesTable;
break;
case SEQUENCE_ENCODING_BASIC:
currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
break;
case SEQUENCE_ENCODING_REPEAT:
verify(currentOffsetCodesTable != null, input,
"Expected offset codes table to be present");
break;
case SEQUENCE_ENCODING_COMPRESSED:
input +=
fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit,
DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
currentOffsetCodesTable = offsetCodesTable;
break;
default:
throw fail(input, "Invalid offset code encoding type");
}
return input;
}
private int computeLiteralsTable(final int literalsLengthType,
final byte[] inputBase, int input,
final int inputLimit) {
switch (literalsLengthType) {
case SEQUENCE_ENCODING_RLE:
verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
final byte value = inputBase[input++];
verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input,
VALUE_EXCEEDS_EXPECTED_MAXIMUM_VALUE);
FseTableReader.initializeRleTable(literalsLengthTable, value);
currentLiteralsLengthTable = literalsLengthTable;
break;
case SEQUENCE_ENCODING_BASIC:
currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
break;
case SEQUENCE_ENCODING_REPEAT:
verify(currentLiteralsLengthTable != null, input,
"Expected literals length table to be present");
break;
case SEQUENCE_ENCODING_COMPRESSED:
input +=
fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit,
MAX_LITERALS_LENGTH_SYMBOL,
LITERAL_LENGTH_TABLE_LOG);
currentLiteralsLengthTable = literalsLengthTable;
break;
default:
throw fail(input, "Invalid literals length encoding type");
}
return input;
}
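// Slow path for a sequence that ends close to the output limit: 8-byte wild
// copies are only used while they cannot overrun, then the literals and the
// match are finished byte by byte.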
private void executeLastSequence(final byte[] outputBase, int output,
final int literalOutputLimit,
final int matchOutputLimit,
final int fastOutputLimit, int literalInput,
int matchAddress) {
// copy literals
if (output < fastOutputLimit) {
// wild copy
do {
putLong(outputBase, output, getLong(literalsBase, literalInput));
output += SIZE_OF_LONG;
literalInput += SIZE_OF_LONG;
} while (output < fastOutputLimit);
literalInput -= output - fastOutputLimit;
output = fastOutputLimit;
}
while (output < literalOutputLimit) {
outputBase[output] = literalsBase[literalInput];
output++;
literalInput++;
}
// copy match
while (output < matchOutputLimit) {
outputBase[output] = outputBase[matchAddress];
output++;
matchAddress++;
}
}
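// Huffman-compressed literals: the size-format bits select a 3, 4 or 5 byte
// header packing the regenerated and compressed sizes; size format 0 uses a
// single Huffman stream, the others use four streams. Unless the block is
// treeless, a Huffman table description precedes the compressed streams.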
private int decodeCompressedLiterals(final byte[] inputBase,
final int inputAddress,
final int blockSize,
final int literalsBlockType) {
int input = inputAddress;
verify(blockSize >= 5, input, NOT_ENOUGH_INPUT_BYTES);
// compressed
final int compressedSize;
final int uncompressedSize;
boolean singleStream = false;
final int headerSize;
final int type = (inputBase[input] >> 2) & 0x3;
switch (type) {
case 0:
singleStream = true;
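// fall through: size formats 0 and 1 share the 3-byte header layout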
case 1: {
final int header = getInt(inputBase, input);
headerSize = 3;
uncompressedSize = (header >>> 4) & mask(10);
compressedSize = (header >>> 14) & mask(10);
break;
}
case 2: {
final int header = getInt(inputBase, input);
headerSize = 4;
uncompressedSize = (header >>> 4) & mask(14);
compressedSize = (header >>> 18) & mask(14);
break;
}
case 3: {
// read 5 little-endian bytes
final long header = inputBase[input] & 0xFF |
(getInt(inputBase, input + 1) & 0xFFFFFFFFL) << 8;
headerSize = 5;
uncompressedSize = (int) ((header >>> 4) & mask(18));
compressedSize = (int) ((header >>> 22) & mask(18));
break;
}
default:
throw fail(input, "Invalid literals header size type");
}
verify(uncompressedSize <= MAX_BLOCK_SIZE, input,
"Block exceeds maximum size");
verify(headerSize + compressedSize <= blockSize, input, INPUT_IS_CORRUPTED);
input += headerSize;
final int inputLimit = input + compressedSize;
if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
input += huffman.readTable(inputBase, input, compressedSize);
}
literalsBase = literals;
literalsAddress = 0;
literalsLimit = uncompressedSize;
if (singleStream) {
huffman.decodeSingleStream(inputBase, input, inputLimit, literals,
literalsAddress, literalsLimit);
} else {
huffman.decode4Streams(inputBase, input, inputLimit, literals,
literalsAddress, literalsLimit);
}
return headerSize + compressedSize;
}
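// RLE literals: the header encodes only the regenerated size (1, 2 or 3 bytes
// depending on the size format) and the single payload byte that follows is
// replicated into the literals buffer.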
private int decodeRleLiterals(final byte[] inputBase, final int inputAddress,
final int blockSize) {
int input = inputAddress;
final int outputSize;
final int type = (inputBase[input] >> 2) & 0x3;
switch (type) {
case 0:
case 2:
outputSize = (inputBase[input] & 0xFF) >>> 3;
input++;
break;
case 1:
outputSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
input += 2;
break;
case 3:
// we need at least 4 bytes (3 for the header, 1 for the payload)
verify(blockSize >= SIZE_OF_INT, input, NOT_ENOUGH_INPUT_BYTES);
outputSize = (getInt(inputBase, input) & 0xFFFFFF) >>> 4;
input += 3;
break;
default:
throw fail(input, "Invalid RLE literals header encoding type");
}
verify(outputSize <= MAX_BLOCK_SIZE, input,
"Output exceeds maximum block size");
final byte value = inputBase[input++];
Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
literalsBase = literals;
literalsAddress = 0;
literalsLimit = outputSize;
return input - inputAddress;
}
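// Raw (stored) literals: after a 1, 2 or 3 byte size header the literals
// follow verbatim, so when possible they are referenced in place instead of
// being copied.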
private int decodeRawLiterals(final byte[] inputBase, final int inputAddress,
final int inputLimit) {
int input = inputAddress;
final int type = (inputBase[input] >> 2) & 0x3;
final int literalSize;
switch (type) {
case 0:
case 2:
literalSize = (inputBase[input] & 0xFF) >>> 3;
input++;
break;
case 1:
literalSize = (getShort(inputBase, input) & 0xFFFF) >>> 4;
input += 2;
break;
case 3:
// read 3 little-endian bytes
final int header = ((inputBase[input] & 0xFF) |
((getShort(inputBase, input + 1) & 0xFFFF) << 8));
literalSize = header >>> 4;
input += 3;
break;
default:
throw fail(input, "Invalid raw literals header encoding type");
}
verify(input + literalSize <= inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
// Set literals pointer to [input, literalSize], but only if we can copy
// 8 bytes at a time during sequence decoding. Otherwise, copy literals into
// a buffer that is big enough to guarantee that.
if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
literalsBase = literals;
literalsAddress = 0;
literalsLimit = literalSize;
copyMemory(inputBase, input, literals, literalsAddress, literalSize);
Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
} else {
literalsBase = inputBase;
literalsAddress = input;
literalsLimit = literalsAddress + literalSize;
}
input += literalSize;
return input - inputAddress;
}
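// Parses the frame header that follows the magic number. The descriptor byte
// announces a single-segment flag, a checksum flag and the widths of the
// optional dictionary id and content size fields; unless the frame is
// single-segment, a window descriptor byte (5-bit exponent added to
// MIN_WINDOW_LOG, 3-bit mantissa) gives the window size.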
static FrameHeader readFrameHeader(final byte[] inputBase,
final int inputAddress,
final int inputLimit) {
int input = inputAddress;
verify(input < inputLimit, input, NOT_ENOUGH_INPUT_BYTES);
final int frameHeaderDescriptor = inputBase[input++] & 0xFF;
final boolean singleSegment = (frameHeaderDescriptor & 0x20) != 0;
final int dictionaryDescriptor = frameHeaderDescriptor & 0x3;
final int contentSizeDescriptor = frameHeaderDescriptor >>> 6;
final int headerSize = 1 + (singleSegment? 0 : 1) +
(dictionaryDescriptor == 0? 0 :
(1 << (dictionaryDescriptor - 1))) +
(contentSizeDescriptor == 0? (singleSegment? 1 : 0) :
(1 << contentSizeDescriptor));
verify(headerSize <= inputLimit - inputAddress, input,
NOT_ENOUGH_INPUT_BYTES);
// decode window size
int windowSize = -1;
if (!singleSegment) {
final int windowDescriptor = inputBase[input++] & 0xFF;
final int exponent = windowDescriptor >>> 3;
final int mantissa = windowDescriptor & 0x7;
final int base = 1 << (MIN_WINDOW_LOG + exponent);
windowSize = base + (base / 8) * mantissa;
}
// decode dictionary id
int dictionaryId = -1;
switch (dictionaryDescriptor) {
case 1:
dictionaryId = inputBase[input] & 0xFF;
input += SIZE_OF_BYTE;
break;
case 2:
dictionaryId = getShort(inputBase, input) & 0xFFFF;
input += SIZE_OF_SHORT;
break;
case 3:
dictionaryId = getInt(inputBase, input);
input += SIZE_OF_INT;
break;
}
verify(dictionaryId == -1, input, "Custom dictionaries not supported");
// decode content size
int contentSize = -1;
switch (contentSizeDescriptor) {
case 0:
if (singleSegment) {
contentSize = inputBase[input] & 0xFF;
input += SIZE_OF_BYTE;
}
break;
case 1:
contentSize = getShort(inputBase, input) & 0xFFFF;
contentSize += 256;
input += SIZE_OF_SHORT;
break;
case 2:
contentSize = getInt(inputBase, input);
input += SIZE_OF_INT;
break;
case 3:
contentSize = (int) getLong(inputBase, input);
input += SIZE_OF_LONG;
break;
}
final boolean hasChecksum = (frameHeaderDescriptor & 0x4) != 0;
return new FrameHeader(input - inputAddress, windowSize, contentSize,
dictionaryId, hasChecksum);
}
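// Returns the content size declared in the frame header, or -1 when the frame
// does not announce it; note the value is returned as an int.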
public static int getDecompressedSize(final byte[] inputBase,
final int inputAddress,
final int inputLimit) {
int input = inputAddress;
input += verifyMagic(inputBase, input, inputLimit);
return readFrameHeader(inputBase, input, inputLimit).contentSize;
}
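// Checks the 4-byte zstd magic number, rejecting the legacy v0.7 format with a
// dedicated message, and returns the number of bytes consumed.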
static int verifyMagic(final byte[] inputBase, final int inputAddress,
final int inputLimit) {
verify(inputLimit - inputAddress >= 4, inputAddress,
NOT_ENOUGH_INPUT_BYTES);
final int magic = getInt(inputBase, inputAddress);
if (magic != MAGIC_NUMBER) {
if (magic == V07_MAGIC_NUMBER) {
throw new MalformedInputException(inputAddress,
"Data encoded in unsupported ZSTD v0.7 format");
}
throw new MalformedInputException(inputAddress, "Invalid magic prefix: " +
Integer.toHexString(
magic));
}
return SIZE_OF_INT;
}
}