HuffmanCompressionTable.java
/*
* This file is part of Waarp Project (named also Waarp or GG).
*
* Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
* tags. See the COPYRIGHT.txt in the distribution for a full listing of
* individual contributors.
*
* All Waarp Project is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at your
* option) any later version.
*
* Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with
* Waarp . If not, see <http://www.gnu.org/licenses/>.
*/
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.waarp.compress.zstdunsafe;
import java.util.Arrays;
/**
 * Huffman encoding table: for every symbol it stores the code value and the
 * code length (in bits) that {@link #encodeSymbol} emits.
 * <p>
 * The table is filled by {@link #initialize} from a histogram of symbol
 * counts and can be serialized with {@link #write}.
 */
final class HuffmanCompressionTable {
// code value of each symbol, indexed by symbol (assigned in initialize)
private final short[] values;
// code length in bits of each symbol, indexed by symbol; 0 marks a symbol
// that has no code (see isValid and write)
private final byte[] numberOfBits;
// maxSymbol argument of the last initialize call
private int maxSymbol;
// maximum code length retained by the last initialize call
private int maxNumberOfBits;
/**
 * Creates an empty table with room for {@code capacity} symbols.
 *
 * @param capacity number of symbol slots to allocate for the code values
 *     and the code lengths
 */
public HuffmanCompressionTable(final int capacity) {
  numberOfBits = new byte[capacity];
  values = new short[capacity];
}
/**
 * Picks the maximum code length (table log) to use for an input.
 * <p>
 * The requested limit is first lowered for small inputs, then raised to the
 * minimum needed to represent all symbols, and finally clamped to the
 * {@code [Huffman.MIN_TABLE_LOG, Huffman.MAX_TABLE_LOG]} range.
 *
 * @param maxNumberOfBits requested upper limit
 * @param inputSize number of input symbols, must be greater than 1
 * @param maxSymbol highest symbol value present in the input
 * @return the number of bits to use
 * @throws IllegalArgumentException if {@code inputSize <= 1} (such inputs
 *     are not supported here; RLE should be used instead)
 */
public static int optimalNumberOfBits(final int maxNumberOfBits,
                                      final int inputSize,
                                      final int maxSymbol) {
  if (inputSize <= 1) {
    throw new IllegalArgumentException(); // not supported. Use RLE instead
  }
  // accuracy can be reduced when the input is small
  final int smallInputLimit = Util.highestBit(inputSize - 1) - 1;
  final int lowered = Math.min(maxNumberOfBits, smallInputLimit);
  // but never below what is required to represent all symbol values
  final int symbolFloor = Util.minTableLog(inputSize, maxSymbol);
  final int raised = Math.max(lowered, symbolFloor);
  // absolute bounds for Huffman: minimum first, then maximum
  final int atLeastMinimum = Math.max(raised, Huffman.MIN_TABLE_LOG);
  return Math.min(atLeastMinimum, Huffman.MAX_TABLE_LOG);
}
/**
 * Rebuilds this table from a histogram of symbol counts.
 * <p>
 * A Huffman tree is built in the workspace node table, its height is limited
 * to {@code maxNumberOfBits}, and the resulting code lengths are turned into
 * code values assigned rank by rank (one rank per code length), in symbol
 * order within a rank.
 *
 * @param counts occurrences of each symbol, read for {@code 0..maxSymbol}
 * @param maxSymbol highest symbol to include; must not exceed
 *     {@code Huffman.MAX_SYMBOL}
 * @param maxNumberOfBits requested limit for the code length
 * @param workspace scratch structures; reset and overwritten by this call
 */
public void initialize(final int[] counts, final int maxSymbol,
                       int maxNumberOfBits,
                       final HuffmanCompressionTableWorkspace workspace) {
  Util.checkArgument(maxSymbol <= Huffman.MAX_SYMBOL,
                     "Max symbol value too large");

  workspace.reset();
  final NodeTable nodes = workspace.nodeTable;
  nodes.reset();

  final int lastUsedNode = buildTree(counts, maxSymbol, nodes);

  // enforce max table log
  maxNumberOfBits =
      setMaxHeight(nodes, lastUsedNode, maxNumberOfBits, workspace);
  Util.checkArgument(maxNumberOfBits <= Huffman.MAX_TABLE_LOG,
                     "Max number of bits larger than max table size");

  // copy the code lengths from node order into symbol order
  for (int node = 0; node <= maxSymbol; node++) {
    numberOfBits[nodes.symbols[node]] = nodes.numberOfBits[node];
  }

  // count how many used nodes share each code length
  final short[] nodesPerLength = workspace.entriesPerRank;
  final short[] nextValuePerLength = workspace.valuesPerRank;
  for (int node = 0; node <= lastUsedNode; node++) {
    nodesPerLength[nodes.numberOfBits[node]]++;
  }

  // first code value of each length, from the longest codes to the shortest
  short firstValue = 0;
  for (int length = maxNumberOfBits; length >= 1; length--) {
    nextValuePerLength[length] = firstValue;
    firstValue += nodesPerLength[length];
    firstValue >>>= 1;
  }

  // hand out consecutive values within each length, in symbol order
  for (int symbol = 0; symbol <= maxSymbol; symbol++) {
    final int length = numberOfBits[symbol];
    values[symbol] = nextValuePerLength[length]++;
  }

  this.maxNumberOfBits = maxNumberOfBits;
  this.maxSymbol = maxSymbol;
}
/**
 * Builds the Huffman tree for {@code counts} inside {@code nodeTable} and
 * stores the resulting depth (code length) of every node.
 * <p>
 * Leaves occupy indexes {@code 0..maxSymbol}, sorted by descending count
 * (ties keep ascending symbol order). Internal nodes are appended starting at
 * index {@code Huffman.MAX_SYMBOL_COUNT}; the last one created is the root.
 * <p>
 * Precondition: {@code maxSymbol >= 1}. Inputs made of one single symbol are
 * expected to be handled by the caller.
 *
 * @param counts occurrences of each symbol, read for {@code 0..maxSymbol}
 * @param maxSymbol highest symbol to include in the tree
 * @param nodeTable node storage, expected to have been reset by the caller
 * @return index of the last leaf that takes part in the tree
 */
private int buildTree(final int[] counts, final int maxSymbol,
                      final NodeTable nodeTable) {
  // populate the leaves of the node table from the histogram of counts
  // in descending order by count, ascending by symbol value.
  short current = 0;
  for (int symbol = 0; symbol <= maxSymbol; symbol++) {
    final int count = counts[symbol];
    // simple insertion sort; the scan must be allowed to reach slot 0,
    // otherwise symbol 0 stays first whatever its count and the leaves are
    // not sorted, which makes the tree built below sub-optimal
    int position = current;
    while (position > 0 && count > nodeTable.count[position - 1]) {
      nodeTable.copyNode(position - 1, position);
      position--;
    }
    nodeTable.count[position] = count;
    nodeTable.symbols[position] = symbol;
    current++;
  }

  // skip trailing leaves with a zero count, but always keep two leaves so
  // that the first intermediate node below can be formed even when only one
  // symbol has a non-zero count
  int lastNonZero = maxSymbol;
  while (lastNonZero > 1 && nodeTable.count[lastNonZero] == 0) {
    lastNonZero--;
  }

  // populate the non-leaf nodes
  final short nonLeafStart = Huffman.MAX_SYMBOL_COUNT;
  current = nonLeafStart;
  int currentLeaf = lastNonZero;

  // combine the two smallest leaves to create the first intermediate node
  int currentNonLeaf = current;
  nodeTable.count[current] =
      nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
  nodeTable.parents[currentLeaf] = current;
  nodeTable.parents[currentLeaf - 1] = current;
  current++;
  currentLeaf -= 2;

  // (lastNonZero + 1) leaves need lastNonZero internal nodes
  final int root = Huffman.MAX_SYMBOL_COUNT + lastNonZero - 1;

  // fill in sentinels: a not-yet-created internal node must never look
  // smaller than a leaf in the comparisons below
  for (int n = current; n <= root; n++) {
    nodeTable.count[n] = 1 << 30;
  }

  // create parents: each round joins the two smallest remaining nodes, taken
  // from the leaves (walked from smallest count) or from the internal nodes
  // (walked in creation order)
  while (current <= root) {
    final int child1;
    if (currentLeaf >= 0 &&
        nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
      child1 = currentLeaf--;
    } else {
      child1 = currentNonLeaf++;
    }
    final int child2;
    if (currentLeaf >= 0 &&
        nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
      child2 = currentLeaf--;
    } else {
      child2 = currentNonLeaf++;
    }
    nodeTable.count[current] =
        nodeTable.count[child1] + nodeTable.count[child2];
    nodeTable.parents[child1] = current;
    nodeTable.parents[child2] = current;
    current++;
  }

  // distribute weights: depth of a node = depth of its parent + 1.
  // Internal nodes are visited from the root downwards so that a parent is
  // always resolved before its children.
  nodeTable.numberOfBits[root] = 0;
  for (int n = root - 1; n >= nonLeafStart; n--) {
    final short parent = nodeTable.parents[n];
    nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
  }
  for (int n = 0; n <= lastNonZero; n++) {
    final short parent = nodeTable.parents[n];
    nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
  }
  return lastNonZero;
}
// TODO: consider encoding 2 symbols at a time
// - need a table with 256x256 entries with
// - the concatenated bits for the corresponding pair of symbols
// - the sum of bits for the corresponding pair of symbols
// - read 2 symbols at a time from the input
/**
 * Appends the code of {@code symbol} to {@code output}.
 *
 * @param output bit stream receiving the code
 * @param symbol symbol to encode; used as an index into this table
 */
public void encodeSymbol(final BitOutputStream output, final int symbol) {
  final short code = values[symbol];
  final byte length = numberOfBits[symbol];
  output.addBitsFast(code, length);
}
/**
 * Serializes the table description (the symbol weights) to the output.
 * <p>
 * The weights are FSE-compressed when that is smaller than the raw layout;
 * otherwise they are written raw, two 4-bit weights per byte. The weight of
 * the last symbol is implicit and never written.
 *
 * @param outputBase base object of the output memory
 * @param outputAddress address of the first byte to write (the header byte)
 * @param outputSize number of bytes available at {@code outputAddress}
 * @param workspace scratch buffers used to compute and compress the weights
 * @return number of bytes written, header byte included
 */
public int write(final Object outputBase, final long outputAddress,
                 final int outputSize,
                 final HuffmanTableWriterWorkspace workspace) {
  final byte[] weights = workspace.weights;
  final int tableLog = this.maxNumberOfBits;
  // #entries = #symbols - 1 since last symbol is implicit. Thus, #entries = (maxSymbol + 1) - 1 = maxSymbol
  final int entries = this.maxSymbol;

  // convert to weights per RFC 8478 section 4.2.1
  for (int symbol = 0; symbol < entries; symbol++) {
    final int bits = numberOfBits[symbol];
    weights[symbol] = (byte) (bits == 0 ? 0 : tableLog + 1 - bits);
  }

  // attempt weights compression by FSE, leaving room for the header byte
  final int compressed =
      compressWeights(outputBase, outputAddress + 1, outputSize - 1, weights,
                      entries, workspace);
  if (entries > 127 && compressed > 127) {
    // This should never happen. Since weights are in the range [0, 12], they can be compressed optimally to ~3.7 bits per symbol for a uniform distribution.
    // Since maxSymbol has to be <= MAX_SYMBOL (255), this is 119 bytes + FSE headers.
    throw new AssertionError();
  }

  // Go with FSE only if:
  // - the weights are compressible
  // - the compressed size is better than what we'd get with the raw encoding below
  // - the compressed size is <= 127 bytes, which is the most that the encoding can hold for FSE-compressed weights (see RFC 8478 section 4.2.1.1). This is implied
  //   by the maxSymbol / 2 check, since maxSymbol must be <= 255
  if (!(compressed == 0 || compressed == 1) && compressed < entries / 2) {
    UnsafeUtil.UNSAFE.putByte(outputBase, outputAddress, (byte) compressed);
    return compressed + 1; // header + size
  }

  // Use raw encoding (4 bits per entry)
  final int rawSize = (entries + 1) / 2; // ceil(#entries / 2)
  Util.checkArgument(rawSize + 1 /* header */ <= outputSize,
                     "Output size too small"); // 2 entries per byte

  // encode number of symbols
  // header = #entries + 127 per RFC
  UnsafeUtil.UNSAFE.putByte(outputBase, outputAddress, (byte) (127 + entries));

  // last weight is implicit, so set to 0 so that it doesn't get encoded below
  weights[entries] = 0;
  long cursor = outputAddress + 1;
  for (int i = 0; i < entries; i += 2) {
    final int packed = (weights[i] << 4) + (weights[i + 1] & 0xFF);
    UnsafeUtil.UNSAFE.putByte(outputBase, cursor, (byte) packed);
    cursor++;
  }
  return (int) (cursor - outputAddress);
}
/**
 * Tells whether this table can encode every symbol that has a non-zero
 * count.
 *
 * @param counts occurrences of each symbol, read for {@code 0..maxSymbol}
 * @param maxSymbol highest symbol to check
 * @return {@code false} when {@code maxSymbol} is beyond the symbols covered
 *     by this table or when a counted symbol has no code (zero bits),
 *     {@code true} otherwise
 */
public boolean isValid(final int[] counts, final int maxSymbol) {
  // symbols above this.maxSymbol cannot be encoded by the current table
  boolean encodable = maxSymbol <= this.maxSymbol;
  for (int symbol = 0; encodable && symbol <= maxSymbol; symbol++) {
    // a symbol is fine when it is unused or when it owns a code
    encodable = counts[symbol] == 0 || numberOfBits[symbol] != 0;
  }
  return encodable;
}
/**
 * Estimates the size of the input once encoded with this table.
 * <p>
 * Only the symbols known to both the histogram and this table are taken into
 * account.
 *
 * @param counts occurrences of each symbol
 * @param maxSymbol highest symbol of the histogram
 * @return the estimate in bytes (total bit count divided by 8, rounded down)
 */
public int estimateCompressedSize(final int[] counts, final int maxSymbol) {
  final int lastSymbol = Math.min(maxSymbol, this.maxSymbol);
  int totalBits = 0;
  int symbol = 0;
  while (symbol <= lastSymbol) {
    totalBits += this.numberOfBits[symbol] * counts[symbol];
    symbol++;
  }
  return totalBits >>> 3; // convert to bytes
}
/**
 * Limits the code lengths held in {@code nodeTable.numberOfBits} to at most
 * {@code maxNumberOfBits}, editing the lengths of the leaves in place.
 * <p>
 * Lengths above the limit are cut down to the limit, which is accounted for
 * as a "cost". The cost is then repaid by lengthening the codes of other
 * leaves, and a possible overshoot is corrected by shortening codes again.
 * <p>
 * Algorithm described at
 * http://fastcompression.blogspot.com/2015/07/huffman-revisited-part-3-depth-limited.html
 * <p>
 * NOTE(review): the index walks below appear to rely on the leaves being
 * ordered so that code lengths never decrease with the index - confirm
 * against the code that builds the tree.
 *
 * @param nodeTable nodes whose {@code numberOfBits} are adjusted; their
 *     {@code count} values are only read
 * @param lastNonZero index of the last leaf to consider; its code length is
 *     taken as the largest one in use
 * @param maxNumberOfBits maximum code length allowed
 * @param workspace provides the {@code rankLast} scratch array, which is
 *     overwritten
 * @return the largest code length in use when it is already within the
 *     limit, {@code maxNumberOfBits} otherwise
 */
private static int setMaxHeight(final NodeTable nodeTable,
final int lastNonZero,
final int maxNumberOfBits,
final HuffmanCompressionTableWorkspace workspace) {
final int largestBits = nodeTable.numberOfBits[lastNonZero];
if (largestBits <= maxNumberOfBits) {
return largestBits; // early exit: no elements > maxNumberOfBits
}
// there are several too large elements (at least >= 2)
int totalCost = 0;
// cost of one leaf at maxNumberOfBits, expressed in units of a leaf at largestBits
final int baseCost = 1 << (largestBits - maxNumberOfBits);
int n = lastNonZero;
// truncate every too-long code to the limit and accumulate the cost of doing so
while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
n--;
} // n stops at nodeTable.numberOfBits[n + offset] <= maxNumberOfBits
while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
n--; // n ends at index of smallest symbol using < maxNumberOfBits
}
// renormalize totalCost
totalCost >>>= (largestBits -
maxNumberOfBits); // note: totalCost is necessarily a multiple of baseCost
// repay normalized cost
// marker stored in rankLast for a rank that has no symbol
final int noSymbol = 0xF0F0F0F0;
// rankLast[r] = index of the last leaf whose code length is maxNumberOfBits - r
final int[] rankLast = workspace.rankLast;
Arrays.fill(rankLast, noSymbol);
// Get pos of last (smallest) symbol per rank
int currentNbBits = maxNumberOfBits;
for (int pos = n; pos >= 0; pos--) {
if (nodeTable.numberOfBits[pos] >= currentNbBits) {
continue;
}
currentNbBits = nodeTable.numberOfBits[pos]; // < maxNumberOfBits
rankLast[maxNumberOfBits - currentNbBits] = pos;
}
// each round lengthens one code by one bit, which repays 1 << (rank - 1) units of cost
while (totalCost > 0) {
// start from the highest rank whose repayment does not exceed the remaining cost
int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
// go down the ranks while this rank is empty, or while its last symbol has a
// count larger than twice the count of the last symbol of the rank below
for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
final int highPosition = rankLast[numberOfBitsToDecrease];
final int lowPosition = rankLast[numberOfBitsToDecrease - 1];
if (highPosition == noSymbol) {
continue;
}
if (lowPosition == noSymbol) {
break;
}
final int highTotal = nodeTable.count[highPosition];
final int lowTotal = 2 * nodeTable.count[lowPosition];
if (highTotal <= lowTotal) {
break;
}
}
// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
while ((numberOfBitsToDecrease <= Huffman.MAX_TABLE_LOG) &&
(rankLast[numberOfBitsToDecrease] == noSymbol)) {
numberOfBitsToDecrease++;
}
totalCost -= 1 << (numberOfBitsToDecrease - 1);
// the chosen symbol moves to the rank just below (one more bit)
if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
rankLast[numberOfBitsToDecrease - 1] =
rankLast[numberOfBitsToDecrease]; // this rank is no longer empty
}
nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
// make rankLast point to the previous symbol of the rank, if there is one
if (rankLast[numberOfBitsToDecrease] ==
0) { /* special case, reached largest symbol */
rankLast[numberOfBitsToDecrease] = noSymbol;
} else {
rankLast[numberOfBitsToDecrease]--;
if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !=
maxNumberOfBits - numberOfBitsToDecrease) {
rankLast[numberOfBitsToDecrease] =
noSymbol; // this rank is now empty
}
}
}
// each round shortens one maxNumberOfBits code by one bit, giving back one unit of cost
while (totalCost < 0) { // Sometimes, cost correction overshoot
if (rankLast[1] ==
noSymbol) { /* special case : no rank 1 symbol (using maxNumberOfBits-1); let's create one from largest rank 0 (using maxNumberOfBits) */
while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
n--;
}
nodeTable.numberOfBits[n + 1]--;
rankLast[1] = n + 1;
totalCost++;
continue;
}
// the symbol right after the last rank 1 symbol joins rank 1
nodeTable.numberOfBits[rankLast[1] + 1]--;
rankLast[1]++;
totalCost++;
}
return maxNumberOfBits;
}
/**
 * Compresses the Huffman weights with FSE: a normalized-count header
 * followed by the encoded weights.
 * <p>
 * All elements within {@code weights} must be <= Huffman.MAX_TABLE_LOG.
 *
 * @param outputBase base object of the output memory
 * @param outputAddress address where the compressed weights start
 * @param outputSize number of bytes available at {@code outputAddress}
 * @param weights weights to compress
 * @param weightsLength number of entries of {@code weights} to compress
 * @param workspace scratch buffers (histogram, normalized counts, FSE table)
 * @return the number of bytes written; {@code 0} when the weights are not
 *     compressible and {@code 1} when they all share one value (nothing is
 *     written in these two cases detected before encoding)
 */
private static int compressWeights(final Object outputBase,
                                   final long outputAddress,
                                   final int outputSize, final byte[] weights,
                                   final int weightsLength,
                                   final HuffmanTableWriterWorkspace workspace) {
  if (weightsLength <= 1) {
    return 0; // Not compressible
  }

  // Scan input and build symbol stats
  final int[] histogram = workspace.counts;
  Histogram.count(weights, weightsLength, histogram);
  final int highestWeight =
      Histogram.findMaxSymbol(histogram, Huffman.MAX_TABLE_LOG);
  final int largestCount =
      Histogram.findLargestCount(histogram, highestWeight);
  if (largestCount == weightsLength) {
    return 1; // only a single symbol in source
  }
  if (largestCount == 1) {
    return 0; // each symbol present maximum once => not compressible
  }

  final short[] normalized = workspace.normalizedCounts;
  final int tableLog =
      FiniteStateEntropy.optimalTableLog(Huffman.MAX_FSE_TABLE_LOG,
                                         weightsLength, highestWeight);
  FiniteStateEntropy.normalizeCounts(normalized, tableLog, histogram,
                                     weightsLength, highestWeight);

  // Write table description header
  final int headerSize =
      FiniteStateEntropy.writeNormalizedCounts(outputBase, outputAddress,
                                               outputSize, normalized,
                                               highestWeight, tableLog);

  // Compress into whatever room is left after the header
  final FseCompressionTable fseTable = workspace.fseTable;
  fseTable.initialize(normalized, highestWeight, tableLog);
  final int payloadSize =
      FiniteStateEntropy.compress(outputBase, outputAddress + headerSize,
                                  outputSize - headerSize, weights,
                                  weightsLength, fseTable);
  if (payloadSize == 0) {
    return 0;
  }
  return headerSize + payloadSize;
}
}