View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import java.util.Arrays;
37  
38  import static org.waarp.compress.zstdsafe.Huffman.*;
39  import static org.waarp.compress.zstdsafe.Util.*;
40  
41  final class HuffmanCompressionTable {
42    private final short[] values;
43    private final byte[] numberOfBits;
44  
45    private int maxSymbol;
46    private int maxNumberOfBits;
47  
48    public HuffmanCompressionTable(final int capacity) {
49      this.values = new short[capacity];
50      this.numberOfBits = new byte[capacity];
51    }
52  
53    public static int optimalNumberOfBits(final int maxNumberOfBits,
54                                          final int inputSize,
55                                          final int maxSymbol) {
56      if (inputSize <= 1) {
57        throw new IllegalArgumentException(); // not supported. Use RLE instead
58      }
59  
60      int result = maxNumberOfBits;
61  
62      result = Math.min(result, Util.highestBit((inputSize - 1)) -
63                                1); // we may be able to reduce accuracy if input is small
64  
65      // Need a minimum to safely represent all symbol values
66      result = Math.max(result, minTableLog(inputSize, maxSymbol));
67  
68      result = Math.max(result, MIN_TABLE_LOG); // absolute minimum for Huffman
69      result = Math.min(result, MAX_TABLE_LOG); // absolute maximum for Huffman
70  
71      return result;
72    }
73  
74    public void initialize(final int[] counts, final int maxSymbol,
75                           int maxNumberOfBits,
76                           final HuffmanCompressionTableWorkspace workspace) {
77      checkArgument(maxSymbol <= MAX_SYMBOL, "Max symbol value too large");
78  
79      workspace.reset();
80  
81      final NodeTable nodeTable = workspace.nodeTable;
82      nodeTable.reset();
83  
84      final int lastNonZero = buildTree(counts, maxSymbol, nodeTable);
85  
86      // enforce max table log
87      maxNumberOfBits =
88          setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
89      checkArgument(maxNumberOfBits <= MAX_TABLE_LOG,
90                    "Max number of bits larger than max table size");
91  
92      // populate table
93      final int symbolCount = maxSymbol + 1;
94      for (int node = 0; node < symbolCount; node++) {
95        final int symbol = nodeTable.symbols[node];
96        numberOfBits[symbol] = nodeTable.numberOfBits[node];
97      }
98  
99      final short[] entriesPerRank = workspace.entriesPerRank;
100     final short[] valuesPerRank = workspace.valuesPerRank;
101 
102     for (int n = 0; n <= lastNonZero; n++) {
103       entriesPerRank[nodeTable.numberOfBits[n]]++;
104     }
105 
106     // determine starting value per rank
107     short startingValue = 0;
108     for (int rank = maxNumberOfBits; rank > 0; rank--) {
109       valuesPerRank[rank] =
110           startingValue; // get starting value within each rank
111       startingValue += entriesPerRank[rank];
112       startingValue >>>= 1;
113     }
114 
115     for (int n = 0; n <= maxSymbol; n++) {
116       values[n] =
117           valuesPerRank[numberOfBits[n]]++; // assign value within rank, symbol order
118     }
119 
120     this.maxSymbol = maxSymbol;
121     this.maxNumberOfBits = maxNumberOfBits;
122   }
123 
124   private int buildTree(final int[] counts, final int maxSymbol,
125                         final NodeTable nodeTable) {
126     // populate the leaves of the node table from the histogram of counts
127     // in descending order by count, ascending by symbol value.
128     short current = 0;
129 
130     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
131       final int count = counts[symbol];
132 
133       // simple insertion sort
134       int position = current;
135       while (position > 1 && count > nodeTable.count[position - 1]) {
136         nodeTable.copyNode(position - 1, position);
137         position--;
138       }
139 
140       nodeTable.count[position] = count;
141       nodeTable.symbols[position] = symbol;
142 
143       current++;
144     }
145 
146     int lastNonZero = maxSymbol;
147     while (nodeTable.count[lastNonZero] == 0) {
148       lastNonZero--;
149     }
150 
151     // populate the non-leaf nodes
152     final short nonLeafStart = MAX_SYMBOL_COUNT;
153     current = nonLeafStart;
154 
155     int currentLeaf = lastNonZero;
156 
157     // combine the two smallest leaves to create the first intermediate node
158     int currentNonLeaf = current;
159     nodeTable.count[current] =
160         nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
161     nodeTable.parents[currentLeaf] = current;
162     nodeTable.parents[currentLeaf - 1] = current;
163     current++;
164     currentLeaf -= 2;
165 
166     final int root = MAX_SYMBOL_COUNT + lastNonZero - 1;
167 
168     // fill in sentinels
169     for (int n = current; n <= root; n++) {
170       nodeTable.count[n] = 1 << 30;
171     }
172 
173     // create parents
174     while (current <= root) {
175       final int child1;
176       if (currentLeaf >= 0 &&
177           nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
178         child1 = currentLeaf--;
179       } else {
180         child1 = currentNonLeaf++;
181       }
182 
183       final int child2;
184       if (currentLeaf >= 0 &&
185           nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
186         child2 = currentLeaf--;
187       } else {
188         child2 = currentNonLeaf++;
189       }
190 
191       nodeTable.count[current] =
192           nodeTable.count[child1] + nodeTable.count[child2];
193       nodeTable.parents[child1] = current;
194       nodeTable.parents[child2] = current;
195       current++;
196     }
197 
198     // distribute weights
199     nodeTable.numberOfBits[root] = 0;
200     for (int n = root - 1; n >= nonLeafStart; n--) {
201       final short parent = nodeTable.parents[n];
202       nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
203     }
204 
205     for (int n = 0; n <= lastNonZero; n++) {
206       final short parent = nodeTable.parents[n];
207       nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
208     }
209 
210     return lastNonZero;
211   }
212 
213   // TODO: consider encoding 2 symbols at a time
214   //   - need a table with 256x256 entries with
215   //      - the concatenated bits for the corresponding pair of symbols
216   //      - the sum of bits for the corresponding pair of symbols
217   //   - read 2 symbols at a time from the input
218   public void encodeSymbol(final BitOutputStream output, final int symbol) {
219     output.addBitsFast(values[symbol], numberOfBits[symbol]);
220   }
221 
222   public int write(final byte[] outputBase, final int outputAddress,
223                    final int outputSize,
224                    final HuffmanTableWriterWorkspace workspace) {
225     final byte[] weights = workspace.weights;
226 
227     int output = outputAddress;
228 
229     final int maxNumberOfBits1 = this.maxNumberOfBits;
230     final int maxSymbol1 = this.maxSymbol;
231 
232     // convert to weights per RFC 8478 section 4.2.1
233     for (int symbol = 0; symbol < maxSymbol1; symbol++) {
234       final int bits = numberOfBits[symbol];
235 
236       if (bits == 0) {
237         weights[symbol] = 0;
238       } else {
239         weights[symbol] = (byte) (maxNumberOfBits1 + 1 - bits);
240       }
241     }
242 
243     // attempt weights compression by FSE
244     int size = compressWeights(outputBase, output + 1, outputSize - 1, weights,
245                                maxSymbol1, workspace);
246 
247     if (maxSymbol1 > 127 && size > 127) {
248       // This should never happen. Since weights are in the range [0, 12], they can be compressed optimally to ~3.7 bits per symbol for a uniform distribution.
249       // Since maxSymbol has to be <= MAX_SYMBOL (255), this is 119 bytes + FSE headers.
250       throw new AssertionError();
251     }
252 
253     if (size != 0 && size != 1 && size < maxSymbol1 / 2) {
254       // Go with FSE only if:
255       //   - the weights are compressible
256       //   - the compressed size is better than what we'd get with the raw encoding below
257       //   - the compressed size is <= 127 bytes, which is the most that the encoding can hold for FSE-compressed weights (see RFC 8478 section 4.2.1.1). This is implied
258       //     by the maxSymbol / 2 check, since maxSymbol must be <= 255
259       outputBase[output] = (byte) size;
260       return size + 1; // header + size
261     } else {
262       // Use raw encoding (4 bits per entry)
263 
264       // #entries = #symbols - 1 since last symbol is implicit. Thus, #entries = (maxSymbol + 1) - 1 = maxSymbol
265 
266       size = (maxSymbol1 + 1) / 2;  // ceil(#entries / 2)
267       checkArgument(size + 1 /* header */ <= outputSize,
268                     "Output size too small"); // 2 entries per byte
269 
270       // encode number of symbols
271       // header = #entries + 127 per RFC
272       outputBase[output] = (byte) (127 + maxSymbol1);
273       output++;
274 
275       weights[maxSymbol1] =
276           0; // last weight is implicit, so set to 0 so that it doesn't get encoded below
277       for (int i = 0; i < maxSymbol1; i += 2) {
278         outputBase[output] = (byte) ((weights[i] << 4) + weights[i + 1]);
279         output++;
280       }
281 
282       return output - outputAddress;
283     }
284   }
285 
286   /**
287    * Can this table encode all symbols with non-zero count?
288    */
289   public boolean isValid(final int[] counts, final int maxSymbol) {
290     if (maxSymbol > this.maxSymbol) {
291       // some non-zero count symbols cannot be encoded by the current table
292       return false;
293     }
294 
295     for (int symbol = 0; symbol <= maxSymbol; ++symbol) {
296       if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
297         return false;
298       }
299     }
300     return true;
301   }
302 
303   public int estimateCompressedSize(final int[] counts, final int maxSymbol) {
304     int numberOfBits = 0;
305     for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol);
306          symbol++) {
307       numberOfBits += this.numberOfBits[symbol] * counts[symbol];
308     }
309 
310     return numberOfBits >>> 3; // convert to bytes
311   }
312 
313   // http://fastcompression.blogspot.com/2015/07/huffman-revisited-part-3-depth-limited.html
314   private static int setMaxHeight(final NodeTable nodeTable,
315                                   final int lastNonZero,
316                                   final int maxNumberOfBits,
317                                   final HuffmanCompressionTableWorkspace workspace) {
318     final int largestBits = nodeTable.numberOfBits[lastNonZero];
319 
320     if (largestBits <= maxNumberOfBits) {
321       return largestBits;   // early exit: no elements > maxNumberOfBits
322     }
323 
324     // there are several too large elements (at least >= 2)
325     int totalCost = 0;
326     final int baseCost = 1 << (largestBits - maxNumberOfBits);
327     int n = lastNonZero;
328 
329     while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
330       totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
331       nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
332       n--;
333     }  // n stops at nodeTable.numberOfBits[n + offset] <= maxNumberOfBits
334 
335     while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
336       n--;   // n ends at index of smallest symbol using < maxNumberOfBits
337     }
338 
339     // renormalize totalCost
340     totalCost >>>= (largestBits -
341                     maxNumberOfBits);  // note: totalCost is necessarily a multiple of baseCost
342 
343     // repay normalized cost
344     final int noSymbol = 0xF0F0F0F0;
345     final int[] rankLast = workspace.rankLast;
346     Arrays.fill(rankLast, noSymbol);
347 
348     // Get pos of last (smallest) symbol per rank
349     int currentNbBits = maxNumberOfBits;
350     for (int pos = n; pos >= 0; pos--) {
351       if (nodeTable.numberOfBits[pos] >= currentNbBits) {
352         continue;
353       }
354       currentNbBits = nodeTable.numberOfBits[pos];   // < maxNumberOfBits
355       rankLast[maxNumberOfBits - currentNbBits] = pos;
356     }
357 
358     while (totalCost > 0) {
359       int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
360       for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
361         final int highPosition = rankLast[numberOfBitsToDecrease];
362         final int lowPosition = rankLast[numberOfBitsToDecrease - 1];
363         if (highPosition == noSymbol) {
364           continue;
365         }
366         if (lowPosition == noSymbol) {
367           break;
368         }
369         final int highTotal = nodeTable.count[highPosition];
370         final int lowTotal = 2 * nodeTable.count[lowPosition];
371         if (highTotal <= lowTotal) {
372           break;
373         }
374       }
375 
376       // only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
377       // HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
378       while ((numberOfBitsToDecrease <= MAX_TABLE_LOG) &&
379              (rankLast[numberOfBitsToDecrease] == noSymbol)) {
380         numberOfBitsToDecrease++;
381       }
382       totalCost -= 1 << (numberOfBitsToDecrease - 1);
383       if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
384         rankLast[numberOfBitsToDecrease - 1] =
385             rankLast[numberOfBitsToDecrease];   // this rank is no longer empty
386       }
387       nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
388       if (rankLast[numberOfBitsToDecrease] ==
389           0) {   /* special case, reached largest symbol */
390         rankLast[numberOfBitsToDecrease] = noSymbol;
391       } else {
392         rankLast[numberOfBitsToDecrease]--;
393         if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !=
394             maxNumberOfBits - numberOfBitsToDecrease) {
395           rankLast[numberOfBitsToDecrease] =
396               noSymbol;   // this rank is now empty
397         }
398       }
399     }
400 
401     while (totalCost < 0) {  // Sometimes, cost correction overshoot
402       if (rankLast[1] ==
403           noSymbol) {  /* special case : no rank 1 symbol (using maxNumberOfBits-1); let's create one from largest rank 0 (using maxNumberOfBits) */
404         while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
405           n--;
406         }
407         nodeTable.numberOfBits[n + 1]--;
408         rankLast[1] = n + 1;
409         totalCost++;
410         continue;
411       }
412       nodeTable.numberOfBits[rankLast[1] + 1]--;
413       rankLast[1]++;
414       totalCost++;
415     }
416 
417     return maxNumberOfBits;
418   }
419 
420   /**
421    * All elements within weightTable must be <= Huffman.MAX_TABLE_LOG
422    */
423   private static int compressWeights(final byte[] outputBase,
424                                      final int outputAddress,
425                                      final int outputSize, final byte[] weights,
426                                      final int weightsLength,
427                                      final HuffmanTableWriterWorkspace workspace) {
428     if (weightsLength <= 1) {
429       return 0; // Not compressible
430     }
431 
432     // Scan input and build symbol stats
433     final int[] counts = workspace.counts;
434     Histogram.count(weights, weightsLength, counts);
435     final int maxSymbol = Histogram.findMaxSymbol(counts, MAX_TABLE_LOG);
436     final int maxCount = Histogram.findLargestCount(counts, maxSymbol);
437 
438     if (maxCount == weightsLength) {
439       return 1; // only a single symbol in source
440     }
441     if (maxCount == 1) {
442       return 0; // each symbol present maximum once => not compressible
443     }
444 
445     final short[] normalizedCounts = workspace.normalizedCounts;
446 
447     final int tableLog =
448         FiniteStateEntropy.optimalTableLog(MAX_FSE_TABLE_LOG, weightsLength,
449                                            maxSymbol);
450     FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts,
451                                        weightsLength, maxSymbol);
452 
453     int output = outputAddress;
454     final int outputLimit = outputAddress + outputSize;
455 
456     // Write table description header
457     final int headerSize =
458         FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize,
459                                                  normalizedCounts, maxSymbol,
460                                                  tableLog);
461     output += headerSize;
462 
463     // Compress
464     final FseCompressionTable compressionTable = workspace.fseTable;
465     compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
466     final int compressedSize =
467         FiniteStateEntropy.compress(outputBase, output, outputLimit - output,
468                                     weights, weightsLength, compressionTable);
469     if (compressedSize == 0) {
470       return 0;
471     }
472     output += compressedSize;
473 
474     return output - outputAddress;
475   }
476 }