FiniteStateEntropy.java
/*
* This file is part of Waarp Project (named also Waarp or GG).
*
* Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
* tags. See the COPYRIGHT.txt in the distribution for a full listing of
* individual contributors.
*
* All Waarp Project is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at your
* option) any later version.
*
* Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with
* Waarp . If not, see <http://www.gnu.org/licenses/>.
*/
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.waarp.compress.zstdunsafe;
import static org.waarp.compress.zstdunsafe.BitInputStream.*;
import static sun.misc.Unsafe.*;
/**
 * Finite State Entropy (FSE) coding routines used by the zstd codec.
 *
 * <p>This is a stateless holder of static helpers:
 * <ul>
 * <li>{@code decompress} decodes an FSE bit stream with a prebuilt
 * {@link Table};</li>
 * <li>{@code compress} encodes bytes with an {@code FseCompressionTable};</li>
 * <li>{@code optimalTableLog}, {@code normalizeCounts} and
 * {@code writeNormalizedCounts} pick a table size, turn raw symbol counts into
 * a normalized distribution, and serialize that distribution.</li>
 * </ul>
 * Input and output regions are accessed through {@code UnsafeUtil.UNSAFE}
 * using (base object, address) pairs, so callers must supply valid regions.
 */
class FiniteStateEntropy {
  /** Largest symbol value (255, the range of an unsigned byte). */
  public static final int MAX_SYMBOL = 255;
  /** Largest accepted table log, i.e. tables of at most {@code 1 << 12} slots. */
  public static final int MAX_TABLE_LOG = 12;
  /**
   * Smallest accepted table log; {@code writeNormalizedCounts} stores the
   * table log as a 4-bit value relative to this constant.
   */
  public static final int MIN_TABLE_LOG = 5;
  /**
   * Rounding thresholds used by {@code normalizeCounts}: a computed probability
   * {@code p < 8} is rounded up to {@code p + 1} when its scaled fractional part
   * exceeds {@code vstep * REST_TO_BEAT[p]}.
   */
  private static final int[] REST_TO_BEAT =
      new int[] { 0, 473195, 504333, 520860, 550000, 700000, 750000, 830000 };
  /**
   * Sentinel used by {@code normalizeCounts2} for symbols whose weight is
   * assigned later, by its final proportional pass.
   */
  private static final short UNASSIGNED = -2;
  /** Error message used when an output region cannot hold the result. */
  public static final String OUTPUT_BUFFER_TOO_SMALL =
      "Output buffer too small";

  /** Static helpers only; not instantiable. */
  private FiniteStateEntropy() {
  }

  /**
   * Decodes an FSE bit stream into {@code outputBuffer}.
   *
   * <p>Two decoder states are read from the stream, each using
   * {@code table.log2Size} bits. They are used alternately, so consecutive
   * output bytes come from interleaved states. Symbols are decoded four at a
   * time while at least four output bytes remain and {@code Loader.load()}
   * has not returned {@code true}. After that, symbols are decoded one at a
   * time until the loader reports overflow. At that point the symbol of the
   * other state is emitted as the final byte.
   *
   * @param table decoding table (symbol, bit count and next-state base per
   *     state)
   * @param inputBase base object for Unsafe-style addressing of the input
   * @param inputAddress start address of the compressed bit stream
   * @param inputLimit end address of the compressed bit stream
   * @param outputBuffer destination; bytes are written from index 0
   *
   * @return number of bytes written into {@code outputBuffer}
   */
  public static int decompress(final FiniteStateEntropy.Table table,
                               final Object inputBase, final long inputAddress,
                               final long inputLimit,
                               final byte[] outputBuffer) {
    final long outputAddress = ARRAY_BYTE_BASE_OFFSET;
    final long outputLimit = outputAddress + outputBuffer.length;
    long output = outputAddress;
    // initialize bit stream
    final BitInputStream.Initializer initializer =
        new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
    initializer.initialize();
    int bitsConsumed = initializer.getBitsConsumed();
    long currentAddress = initializer.getCurrentAddress();
    long bits = initializer.getBits();
    // initialize first FSE stream
    int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
    bitsConsumed += table.log2Size;
    BitInputStream.Loader loader =
        new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
                                  bitsConsumed);
    loader.load();
    bits = loader.getBits();
    bitsConsumed = loader.getBitsConsumed();
    currentAddress = loader.getCurrentAddress();
    // initialize second FSE stream
    int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
    bitsConsumed += table.log2Size;
    loader =
        new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
                                  bitsConsumed);
    loader.load();
    bits = loader.getBits();
    bitsConsumed = loader.getBitsConsumed();
    currentAddress = loader.getCurrentAddress();
    final byte[] symbols = table.symbol;
    final byte[] numbersOfBits = table.numberOfBits;
    final int[] newStates = table.newState;
    // decode 4 symbols per loop (state1, state2, state1, state2), refilling
    // the bit container only once per iteration
    while (output <= outputLimit - 4) {
      int numberOfBits;
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output, symbols[state1]);
      numberOfBits = numbersOfBits[state1];
      state1 = (int) (newStates[state1] +
                      peekBits(bitsConsumed, bits, numberOfBits));
      bitsConsumed += numberOfBits;
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 1, symbols[state2]);
      numberOfBits = numbersOfBits[state2];
      state2 = (int) (newStates[state2] +
                      peekBits(bitsConsumed, bits, numberOfBits));
      bitsConsumed += numberOfBits;
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 2, symbols[state1]);
      numberOfBits = numbersOfBits[state1];
      state1 = (int) (newStates[state1] +
                      peekBits(bitsConsumed, bits, numberOfBits));
      bitsConsumed += numberOfBits;
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 3, symbols[state2]);
      numberOfBits = numbersOfBits[state2];
      state2 = (int) (newStates[state2] +
                      peekBits(bitsConsumed, bits, numberOfBits));
      bitsConsumed += numberOfBits;
      output += Constants.SIZE_OF_INT;
      loader =
          new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
                                    bits, bitsConsumed);
      final boolean done = loader.load();
      bitsConsumed = loader.getBitsConsumed();
      bits = loader.getBits();
      currentAddress = loader.getCurrentAddress();
      if (done) {
        break;
      }
    }
    // Tail: one symbol at a time. Each step requires room for two bytes
    // because the overflow branch emits one extra symbol from the other state.
    while (true) {
      Util.verify(output <= outputLimit - 2, inputAddress,
                  "Output buffer is too small");
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
      final int numberOfBits = numbersOfBits[state1];
      state1 = (int) (newStates[state1] +
                      peekBits(bitsConsumed, bits, numberOfBits));
      bitsConsumed += numberOfBits;
      loader =
          new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
                                    bits, bitsConsumed);
      loader.load();
      bitsConsumed = loader.getBitsConsumed();
      bits = loader.getBits();
      currentAddress = loader.getCurrentAddress();
      if (loader.isOverflow()) {
        UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
        break;
      }
      Util.verify(output <= outputLimit - 2, inputAddress,
                  "Output buffer is too small");
      UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
      final int numberOfBits1 = numbersOfBits[state2];
      state2 = (int) (newStates[state2] +
                      peekBits(bitsConsumed, bits, numberOfBits1));
      bitsConsumed += numberOfBits1;
      loader =
          new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
                                    bits, bitsConsumed);
      loader.load();
      bitsConsumed = loader.getBitsConsumed();
      bits = loader.getBits();
      currentAddress = loader.getCurrentAddress();
      if (loader.isOverflow()) {
        UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
        break;
      }
    }
    return (int) (output - outputAddress);
  }

  /**
   * Encodes the first {@code inputSize} bytes of a heap {@code byte[]}.
   * Delegates to the address-based overload with
   * {@code ARRAY_BYTE_BASE_OFFSET} as the input address.
   *
   * @param outputBase base object of the output region
   * @param outputAddress start address of the output region
   * @param outputSize capacity of the output region in bytes
   * @param input bytes to encode
   * @param inputSize number of bytes of {@code input} to encode
   * @param table compression table to encode with
   *
   * @return the value returned by the address-based overload
   */
  public static int compress(final Object outputBase, final long outputAddress,
                             final int outputSize, final byte[] input,
                             final int inputSize,
                             final FseCompressionTable table) {
    return compress(outputBase, outputAddress, outputSize, input,
                    ARRAY_BYTE_BASE_OFFSET, inputSize, table);
  }

  /**
   * Encodes {@code inputSize} bytes starting at {@code inputAddress}.
   *
   * <p>Bytes are consumed from last to first and alternate between two
   * encoder states. The bit stream is flushed after every group of encodes.
   * Both final states are then written with {@code table.finish}, and the
   * stream is closed.
   *
   * @param outputBase base object of the output region
   * @param outputAddress start address of the output region
   * @param outputSize capacity of the output region; must be at least
   *     {@code Constants.SIZE_OF_LONG} (checked with {@code Util.checkArgument})
   * @param inputBase base object of the input region
   * @param inputAddress start address of the input region
   * @param inputSize number of bytes to encode
   * @param table compression table to encode with
   *
   * @return 0 when {@code inputSize <= 2} (nothing is written); otherwise the
   *     value of {@code BitOutputStream.close()}. NOTE(review): presumably
   *     the compressed size, but confirm against BitOutputStream.
   */
  public static int compress(final Object outputBase, final long outputAddress,
                             final int outputSize, final Object inputBase,
                             final long inputAddress, int inputSize,
                             final FseCompressionTable table) {
    Util.checkArgument(outputSize >= Constants.SIZE_OF_LONG,
                       OUTPUT_BUFFER_TOO_SMALL);
    // encoding walks backwards from the end of the input
    long input = inputAddress + inputSize;
    if (inputSize <= 2) {
      return 0;
    }
    final BitOutputStream stream =
        new BitOutputStream(outputBase, outputAddress, outputSize);
    int state1;
    int state2;
    if ((inputSize & 1) != 0) {
      input--;
      state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state1 = table.encode(stream, state1,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      stream.flush();
    } else {
      input--;
      state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
    }
    // join to mod 4: consume two more symbols when needed so the number left
    // for the main loop is a multiple of four. For odd sizes,
    // (inputSize - 2) and the true remainder (inputSize - 3) share bit 1.
    inputSize -= 2;
    if ((inputSize & 2) != 0) { /* test bit 2 */
      input--;
      state2 = table.encode(stream, state2,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state1 = table.encode(stream, state1,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      stream.flush();
    }
    // 4 encodings per loop, alternating states, one flush per iteration
    while (input > inputAddress) {
      input--;
      state2 = table.encode(stream, state2,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state1 = table.encode(stream, state1,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state2 = table.encode(stream, state2,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      input--;
      state1 = table.encode(stream, state1,
                            UnsafeUtil.UNSAFE.getByte(inputBase, input));
      stream.flush();
    }
    table.finish(stream, state2);
    table.finish(stream, state1);
    return stream.close();
  }

  /**
   * Chooses a table log for a histogram.
   *
   * <p>Starts from {@code maxTableLog}, reduces it for small inputs
   * ({@code highestBit(inputSize - 1) - 2}), and raises it to
   * {@code Util.minTableLog(inputSize, maxSymbol)}. The result is finally
   * clamped to [{@link #MIN_TABLE_LOG}, {@link #MAX_TABLE_LOG}].
   *
   * @param maxTableLog upper bound requested by the caller
   * @param inputSize number of symbols in the input
   * @param maxSymbol largest symbol value present
   *
   * @return chosen table log
   *
   * @throws IllegalArgumentException if {@code inputSize <= 1} (such input
   *     should use RLE instead)
   */
  public static int optimalTableLog(final int maxTableLog, final int inputSize,
                                    final int maxSymbol) {
    if (inputSize <= 1) {
      throw new IllegalArgumentException(); // not supported. Use RLE instead
    }
    int result = maxTableLog;
    result = Math.min(result, Util.highestBit((inputSize - 1)) -
                              2); // we may be able to reduce accuracy if input is small
    // Need a minimum to safely represent all symbol values
    result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
    result = Math.max(result, MIN_TABLE_LOG);
    result = Math.min(result, MAX_TABLE_LOG);
    return result;
  }

  /**
   * Normalizes raw symbol counts into a distribution over
   * {@code 1 << tableLog} table slots, written into {@code normalizedCounts}.
   *
   * <p>Symbols with a zero count get {@code 0}. Symbols whose count is at most
   * {@code total >>> tableLog} get {@code -1}, a low-probability marker that
   * still takes one slot. Every other symbol gets a proportional share, and
   * shares below 8 are rounded using {@code REST_TO_BEAT}. Leftover slots are
   * added to the most probable symbol, and excess slots are taken from it.
   * When the excess is at least half of that symbol's share,
   * {@code normalizeCounts2} recomputes the whole distribution instead.
   *
   * @param normalizedCounts output; entries {@code 0..maxSymbol} are written
   * @param tableLog log2 of the number of slots; must lie in
   *     [{@link #MIN_TABLE_LOG}, {@link #MAX_TABLE_LOG}] and be at least
   *     {@code Util.minTableLog(total, maxSymbol)} (checked via
   *     {@code Util.checkArgument})
   * @param counts occurrence count per symbol
   * @param total NOTE(review): assumed to be the sum of
   *     {@code counts[0..maxSymbol]}; not verified here
   * @param maxSymbol largest symbol index to process
   *
   * @throws IllegalArgumentException if one symbol accounts for the whole
   *     {@code total} (that input should have been RLE-encoded)
   */
  public static void normalizeCounts(final short[] normalizedCounts,
                                     final int tableLog, final int[] counts,
                                     final int total, final int maxSymbol) {
    Util.checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
    Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
    Util.checkArgument(tableLog >= Util.minTableLog(total, maxSymbol),
                       "FSE table size too small");
    // fixed-point: count * step >>> scale == count * 2^tableLog / total
    final long scale = 62L - tableLog;
    final long step = (1L << 62) / total;
    final long vstep = 1L << (scale - 20);
    int stillToDistribute = 1 << tableLog;
    int largest = 0;
    short largestProbability = 0;
    final int lowThreshold = total >>> tableLog;
    for (int symbol = 0; symbol <= maxSymbol; symbol++) {
      if (counts[symbol] == total) {
        throw new IllegalArgumentException(); // TODO: should have been RLE-compressed by upper layers
      }
      if (counts[symbol] == 0) {
        normalizedCounts[symbol] = 0;
        continue;
      }
      if (counts[symbol] <= lowThreshold) {
        normalizedCounts[symbol] = -1;
        stillToDistribute--;
      } else {
        short probability = (short) ((counts[symbol] * step) >>> scale);
        if (probability < 8) {
          // round small probabilities up only when the remainder is large
          final long restToBeat = vstep * REST_TO_BEAT[probability];
          final long delta =
              counts[symbol] * step - (((long) probability) << scale);
          if (delta > restToBeat) {
            probability++;
          }
        }
        if (probability > largestProbability) {
          largestProbability = probability;
          largest = symbol;
        }
        normalizedCounts[symbol] = probability;
        stillToDistribute -= probability;
      }
    }
    if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
      // corner case. Need another normalization method
      // TODO size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue);
      normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
    } else {
      normalizedCounts[largest] += (short) stillToDistribute;
    }
  }

  /**
   * Fallback normalization used by {@code normalizeCounts} when its simple
   * rounding over-allocates slots.
   *
   * <p>Symbols with very small counts are fixed first, at {@code -1} (count
   * at most {@code lowThreshold}) or {@code 1} (count at most
   * {@code lowOne}). Those symbols are removed from {@code total}. The
   * remaining slots are then split proportionally among the other symbols.
   *
   * <p>{@code lowOne} is computed in {@code long} arithmetic because
   * {@code total * 3} overflows an {@code int} once {@code total} exceeds
   * {@code Integer.MAX_VALUE / 3}. That overflow would produce a wrong,
   * possibly negative, threshold. For smaller totals the result is identical
   * to the former {@code int} computation.
   */
  private static void normalizeCounts2(final short[] normalizedCounts,
                                       final int tableLog, final int[] counts,
                                       int total, final int maxSymbol) {
    int distributed = 0;
    // minimum count below which frequency in the normalized table is "too small" (~ < 1)
    final int lowThreshold = total >>> tableLog;
    // 1.5 * lowThreshold. If count in (lowThreshold, lowOne] => assign frequency 1
    long lowOne = (total * 3L) >>> (tableLog + 1);
    for (int i = 0; i <= maxSymbol; i++) {
      if (counts[i] == 0) {
        normalizedCounts[i] = 0;
      } else if (counts[i] <= lowThreshold) {
        normalizedCounts[i] = -1;
        distributed++;
        total -= counts[i];
      } else if (counts[i] <= lowOne) {
        normalizedCounts[i] = 1;
        distributed++;
        total -= counts[i];
      } else {
        normalizedCounts[i] = UNASSIGNED;
      }
    }
    final int normalizationFactor = 1 << tableLog;
    int toDistribute = normalizationFactor - distributed;
    if ((total / toDistribute) > lowOne) {
      /* risk of rounding to zero */
      lowOne = (total * 3L) / (toDistribute * 2L);
      for (int i = 0; i <= maxSymbol; i++) {
        if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
          normalizedCounts[i] = 1;
          distributed++;
          total -= counts[i];
        }
      }
      toDistribute = normalizationFactor - distributed;
    }
    if (distributed == maxSymbol + 1) {
      // all values are pretty poor
      // probably incompressible data (should have already been detected);
      // find max, then give all remaining points to max
      int maxValue = 0;
      int maxCount = 0;
      for (int i = 0; i <= maxSymbol; i++) {
        if (counts[i] > maxCount) {
          maxValue = i;
          maxCount = counts[i];
        }
      }
      normalizedCounts[maxValue] += (short) toDistribute;
      return;
    }
    if (total == 0) {
      // all of the symbols were low enough for the lowOne or lowThreshold;
      // hand out the remaining slots round-robin to symbols with a positive
      // count
      for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
        if (normalizedCounts[i] > 0) {
          toDistribute--;
          normalizedCounts[i]++;
        }
      }
      return;
    }
    // Proportional pass in fixed point (vStepLog fractional bits): each
    // unassigned symbol's weight is the number of slot boundaries crossed by
    // the running total, which keeps the cumulative rounding error bounded.
    // TODO: simplify/document this code
    final long vStepLog = 62 - tableLog;
    final long mid = (1L << (vStepLog - 1)) - 1;
    final long rStep = (((1L << vStepLog) * toDistribute) + mid) /
                       total; /* scale on remaining */
    long tmpTotal = mid;
    for (int i = 0; i <= maxSymbol; i++) {
      if (normalizedCounts[i] == UNASSIGNED) {
        final long end = tmpTotal + (counts[i] * rStep);
        final int sStart = (int) (tmpTotal >>> vStepLog);
        final int sEnd = (int) (end >>> vStepLog);
        final int weight = sEnd - sStart;
        if (weight < 1) {
          throw new AssertionError();
        }
        normalizedCounts[i] = (short) weight;
        tmpTotal = end;
      }
    }
  }

  /**
   * Serializes a normalized count table in the FSE table-description format
   * (RFC 8478, section 4.1.1).
   *
   * <p>The stream starts with {@code tableLog - MIN_TABLE_LOG} in 4 bits. Then
   * comes one variable-width value per symbol, {@code count + 1}, so
   * {@code -1} is stored as 0 and 0 as 1. After a zero count, the run of
   * following zero counts is stored as 2-bit repeat flags. Bits are flushed
   * to the output 16 at a time.
   *
   * @param outputBase base object of the output region
   * @param outputAddress start address of the output region
   * @param outputSize capacity of the output region in bytes
   * @param normalizedCounts counts as produced by {@code normalizeCounts}
   * @param maxSymbol largest symbol index expected
   * @param tableLog table log the counts were normalized to; must lie in
   *     [{@link #MIN_TABLE_LOG}, {@link #MAX_TABLE_LOG}]
   *
   * @return number of header bytes produced. The final flush stores a whole
   *     {@code short}, so one byte past the returned length (but inside
   *     {@code outputSize}) may also be overwritten.
   */
  public static int writeNormalizedCounts(final Object outputBase,
                                          final long outputAddress,
                                          final int outputSize,
                                          final short[] normalizedCounts,
                                          final int maxSymbol,
                                          final int tableLog) {
    Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
    Util.checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
    long output = outputAddress;
    final long outputLimit = outputAddress + outputSize;
    final int tableSize = 1 << tableLog;
    int bitCount = 0;
    // encode table size
    int bitStream = (tableLog - MIN_TABLE_LOG);
    bitCount += 4;
    int remaining = tableSize + 1; // +1 for extra accuracy
    int threshold = tableSize;
    int tableBitCount = tableLog + 1;
    int symbol = 0;
    boolean previousIs0 = false;
    while (remaining > 1) {
      if (previousIs0) {
        // From RFC 8478, section 4.1.1:
        // When a symbol has a probability of zero, it is followed by a 2-bit
        // repeat flag. This repeat flag tells how many probabilities of zeroes
        // follow the current one. It provides a number ranging from 0 to 3.
        // If it is a 3, another 2-bit repeat flag follows, and so on.
        int start = symbol;
        // find run of symbols with count 0
        while (normalizedCounts[symbol] == 0) {
          symbol++;
        }
        // encode in batches of 8 repeat flags (0xffff = eight 2-bit "3"
        // flags, representing 24 symbols total)
        while (symbol >= start + 24) {
          start += 24;
          bitStream |= (0xffff << bitCount);
          Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
                             OUTPUT_BUFFER_TOO_SMALL);
          UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
          output += Constants.SIZE_OF_SHORT;
          // flush now, so no need to increase bitCount by 16
          bitStream >>>= Short.SIZE;
        }
        // encode remaining in batches of 3 symbols
        while (symbol >= start + 3) {
          start += 3;
          bitStream |= 0x3 << bitCount;
          bitCount += 2;
        }
        // encode tail
        bitStream |= (symbol - start) << bitCount;
        bitCount += 2;
        // flush bitstream if necessary
        if (bitCount > 16) {
          Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
                             OUTPUT_BUFFER_TOO_SMALL);
          UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
          output += Constants.SIZE_OF_SHORT;
          bitStream >>>= Short.SIZE;
          bitCount -= Short.SIZE;
        }
      }
      int count = normalizedCounts[symbol++];
      // values below 'max' fit in one bit less than tableBitCount
      final int max = (2 * threshold - 1) - remaining;
      remaining -= count < 0? -count : count;
      count++; /* +1 for extra accuracy */
      if (count >= threshold) {
        count += max;
      }
      bitStream |= count << bitCount;
      bitCount += tableBitCount;
      bitCount -= (count < max? 1 : 0);
      previousIs0 = (count == 1);
      if (remaining < 1) {
        throw new AssertionError();
      }
      // shrink the field width as the number of remaining slots drops
      while (remaining < threshold) {
        tableBitCount--;
        threshold >>= 1;
      }
      // flush bitstream if necessary
      if (bitCount > 16) {
        Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
                           OUTPUT_BUFFER_TOO_SMALL);
        UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
        output += Constants.SIZE_OF_SHORT;
        bitStream >>>= Short.SIZE;
        bitCount -= Short.SIZE;
      }
    }
    // flush remaining bitstream: writes two bytes, but only the bytes that
    // actually hold bits are counted in the returned length
    Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
                       OUTPUT_BUFFER_TOO_SMALL);
    UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
    output += (bitCount + 7) / 8;
    Util.checkArgument(symbol <= maxSymbol + 1, "Error"); // TODO
    return (int) (output - outputAddress);
  }

  /**
   * FSE decoding table, indexed by decoder state. For each state it holds the
   * symbol to emit, the number of bits to read, and the base value that those
   * bits are added to in order to form the next state.
   */
  public static final class Table {
    /** Base of the next state, per state. */
    final int[] newState;
    /** Symbol emitted in each state. */
    final byte[] symbol;
    /** Number of bits read to compute the next state, per state. */
    final byte[] numberOfBits;
    /** log2 of the number of states; also the bit width of an initial state. */
    int log2Size;

    /**
     * Allocates empty arrays for {@code 1 << log2Capacity} states.
     * {@code log2Size} is left at 0. NOTE(review): presumably set by the code
     * that fills the table.
     *
     * @param log2Capacity log2 of the number of states to allocate
     */
    public Table(final int log2Capacity) {
      final int capacity = 1 << log2Capacity;
      newState = new int[capacity];
      symbol = new byte[capacity];
      numberOfBits = new byte[capacity];
    }

    /**
     * Wraps pre-filled arrays, each of which must have exactly
     * {@code 1 << log2Size} entries.
     *
     * @param log2Size log2 of the number of states
     * @param newState next-state base per state
     * @param symbol symbol per state
     * @param numberOfBits bits to read per state
     *
     * @throws IllegalArgumentException if any array length differs from
     *     {@code 1 << log2Size}
     */
    public Table(final int log2Size, final int[] newState, final byte[] symbol,
                 final byte[] numberOfBits) {
      final int size = 1 << log2Size;
      if (newState.length != size || symbol.length != size ||
          numberOfBits.length != size) {
        throw new IllegalArgumentException(
            "Expected arrays to match provided size");
      }
      this.log2Size = log2Size;
      this.newState = newState;
      this.symbol = symbol;
      this.numberOfBits = numberOfBits;
    }
  }
}