View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import java.util.Arrays;
37  
38  final class HuffmanCompressionTable {
39    private final short[] values;
40    private final byte[] numberOfBits;
41  
42    private int maxSymbol;
43    private int maxNumberOfBits;
44  
45    public HuffmanCompressionTable(final int capacity) {
46      this.values = new short[capacity];
47      this.numberOfBits = new byte[capacity];
48    }
49  
50    public static int optimalNumberOfBits(final int maxNumberOfBits,
51                                          final int inputSize,
52                                          final int maxSymbol) {
53      if (inputSize <= 1) {
54        throw new IllegalArgumentException(); // not supported. Use RLE instead
55      }
56  
57      int result = maxNumberOfBits;
58  
59      result = Math.min(result, Util.highestBit((inputSize - 1)) -
60                                1); // we may be able to reduce accuracy if input is small
61  
62      // Need a minimum to safely represent all symbol values
63      result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
64  
65      result =
66          Math.max(result, Huffman.MIN_TABLE_LOG); // absolute minimum for Huffman
67      result =
68          Math.min(result, Huffman.MAX_TABLE_LOG); // absolute maximum for Huffman
69  
70      return result;
71    }
72  
73    public void initialize(final int[] counts, final int maxSymbol,
74                           int maxNumberOfBits,
75                           final HuffmanCompressionTableWorkspace workspace) {
76      Util.checkArgument(maxSymbol <= Huffman.MAX_SYMBOL,
77                         "Max symbol value too large");
78  
79      workspace.reset();
80  
81      final NodeTable nodeTable = workspace.nodeTable;
82      nodeTable.reset();
83  
84      final int lastNonZero = buildTree(counts, maxSymbol, nodeTable);
85  
86      // enforce max table log
87      maxNumberOfBits =
88          setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
89      Util.checkArgument(maxNumberOfBits <= Huffman.MAX_TABLE_LOG,
90                         "Max number of bits larger than max table size");
91  
92      // populate table
93      final int symbolCount = maxSymbol + 1;
94      for (int node = 0; node < symbolCount; node++) {
95        final int symbol = nodeTable.symbols[node];
96        numberOfBits[symbol] = nodeTable.numberOfBits[node];
97      }
98  
99      final short[] entriesPerRank = workspace.entriesPerRank;
100     final short[] valuesPerRank = workspace.valuesPerRank;
101 
102     for (int n = 0; n <= lastNonZero; n++) {
103       entriesPerRank[nodeTable.numberOfBits[n]]++;
104     }
105 
106     // determine starting value per rank
107     short startingValue = 0;
108     for (int rank = maxNumberOfBits; rank > 0; rank--) {
109       valuesPerRank[rank] =
110           startingValue; // get starting value within each rank
111       startingValue += entriesPerRank[rank];
112       startingValue >>>= 1;
113     }
114 
115     for (int n = 0; n <= maxSymbol; n++) {
116       values[n] =
117           valuesPerRank[numberOfBits[n]]++; // assign value within rank, symbol order
118     }
119 
120     this.maxSymbol = maxSymbol;
121     this.maxNumberOfBits = maxNumberOfBits;
122   }
123 
124   private int buildTree(final int[] counts, final int maxSymbol,
125                         final NodeTable nodeTable) {
126     // populate the leaves of the node table from the histogram of counts
127     // in descending order by count, ascending by symbol value.
128     short current = 0;
129 
130     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
131       final int count = counts[symbol];
132 
133       // simple insertion sort
134       int position = current;
135       while (position > 1 && count > nodeTable.count[position - 1]) {
136         nodeTable.copyNode(position - 1, position);
137         position--;
138       }
139 
140       nodeTable.count[position] = count;
141       nodeTable.symbols[position] = symbol;
142 
143       current++;
144     }
145 
146     int lastNonZero = maxSymbol;
147     while (nodeTable.count[lastNonZero] == 0) {
148       lastNonZero--;
149     }
150 
151     // populate the non-leaf nodes
152     final short nonLeafStart = Huffman.MAX_SYMBOL_COUNT;
153     current = nonLeafStart;
154 
155     int currentLeaf = lastNonZero;
156 
157     // combine the two smallest leaves to create the first intermediate node
158     int currentNonLeaf = current;
159     nodeTable.count[current] =
160         nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
161     nodeTable.parents[currentLeaf] = current;
162     nodeTable.parents[currentLeaf - 1] = current;
163     current++;
164     currentLeaf -= 2;
165 
166     final int root = Huffman.MAX_SYMBOL_COUNT + lastNonZero - 1;
167 
168     // fill in sentinels
169     for (int n = current; n <= root; n++) {
170       nodeTable.count[n] = 1 << 30;
171     }
172 
173     // create parents
174     while (current <= root) {
175       final int child1;
176       if (currentLeaf >= 0 &&
177           nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
178         child1 = currentLeaf--;
179       } else {
180         child1 = currentNonLeaf++;
181       }
182 
183       final int child2;
184       if (currentLeaf >= 0 &&
185           nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
186         child2 = currentLeaf--;
187       } else {
188         child2 = currentNonLeaf++;
189       }
190 
191       nodeTable.count[current] =
192           nodeTable.count[child1] + nodeTable.count[child2];
193       nodeTable.parents[child1] = current;
194       nodeTable.parents[child2] = current;
195       current++;
196     }
197 
198     // distribute weights
199     nodeTable.numberOfBits[root] = 0;
200     for (int n = root - 1; n >= nonLeafStart; n--) {
201       final short parent = nodeTable.parents[n];
202       nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
203     }
204 
205     for (int n = 0; n <= lastNonZero; n++) {
206       final short parent = nodeTable.parents[n];
207       nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
208     }
209 
210     return lastNonZero;
211   }
212 
213   // TODO: consider encoding 2 symbols at a time
214   //   - need a table with 256x256 entries with
215   //      - the concatenated bits for the corresponding pair of symbols
216   //      - the sum of bits for the corresponding pair of symbols
217   //   - read 2 symbols at a time from the input
218   public void encodeSymbol(final BitOutputStream output, final int symbol) {
219     output.addBitsFast(values[symbol], numberOfBits[symbol]);
220   }
221 
222   public int write(final Object outputBase, final long outputAddress,
223                    final int outputSize,
224                    final HuffmanTableWriterWorkspace workspace) {
225     final byte[] weights = workspace.weights;
226 
227     long output = outputAddress;
228 
229     final int numberOfBits1 = this.maxNumberOfBits;
230     final int maxSymbol1 = this.maxSymbol;
231 
232     // convert to weights per RFC 8478 section 4.2.1
233     for (int symbol = 0; symbol < maxSymbol1; symbol++) {
234       final int bits = numberOfBits[symbol];
235 
236       if (bits == 0) {
237         weights[symbol] = 0;
238       } else {
239         weights[symbol] = (byte) (numberOfBits1 + 1 - bits);
240       }
241     }
242 
243     // attempt weights compression by FSE
244     int size = compressWeights(outputBase, output + 1, outputSize - 1, weights,
245                                maxSymbol1, workspace);
246 
247     if (maxSymbol1 > 127 && size > 127) {
248       // This should never happen. Since weights are in the range [0, 12], they can be compressed optimally to ~3.7 bits per symbol for a uniform distribution.
249       // Since maxSymbol has to be <= MAX_SYMBOL (255), this is 119 bytes + FSE headers.
250       throw new AssertionError();
251     }
252 
253     if (size != 0 && size != 1 && size < maxSymbol1 / 2) {
254       // Go with FSE only if:
255       //   - the weights are compressible
256       //   - the compressed size is better than what we'd get with the raw encoding below
257       //   - the compressed size is <= 127 bytes, which is the most that the encoding can hold for FSE-compressed weights (see RFC 8478 section 4.2.1.1). This is implied
258       //     by the maxSymbol / 2 check, since maxSymbol must be <= 255
259       UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) size);
260       return size + 1; // header + size
261     } else {
262       // Use raw encoding (4 bits per entry)
263 
264       // #entries = #symbols - 1 since last symbol is implicit. Thus, #entries = (maxSymbol + 1) - 1 = maxSymbol
265 
266       size = (maxSymbol1 + 1) / 2;  // ceil(#entries / 2)
267       Util.checkArgument(size + 1 /* header */ <= outputSize,
268                          "Output size too small"); // 2 entries per byte
269 
270       // encode number of symbols
271       // header = #entries + 127 per RFC
272       UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) (127 + maxSymbol1));
273       output++;
274 
275       weights[maxSymbol1] =
276           0; // last weight is implicit, so set to 0 so that it doesn't get encoded below
277       for (int i = 0; i < maxSymbol1; i += 2) {
278         UnsafeUtil.UNSAFE.putByte(outputBase, output,
279                                   (byte) ((weights[i] << 4) +
280                                           (weights[i + 1] & 0xFF)));
281         output++;
282       }
283 
284       return (int) (output - outputAddress);
285     }
286   }
287 
288   /**
289    * Can this table encode all symbols with non-zero count?
290    */
291   public boolean isValid(final int[] counts, final int maxSymbol) {
292     if (maxSymbol > this.maxSymbol) {
293       // some non-zero count symbols cannot be encoded by the current table
294       return false;
295     }
296 
297     for (int symbol = 0; symbol <= maxSymbol; ++symbol) {
298       if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
299         return false;
300       }
301     }
302     return true;
303   }
304 
305   public int estimateCompressedSize(final int[] counts, final int maxSymbol) {
306     int numberOfBits1 = 0;
307     for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol);
308          symbol++) {
309       numberOfBits1 += this.numberOfBits[symbol] * counts[symbol];
310     }
311 
312     return numberOfBits1 >>> 3; // convert to bytes
313   }
314 
315   // http://fastcompression.blogspot.com/2015/07/huffman-revisited-part-3-depth-limited.html
316   private static int setMaxHeight(final NodeTable nodeTable,
317                                   final int lastNonZero,
318                                   final int maxNumberOfBits,
319                                   final HuffmanCompressionTableWorkspace workspace) {
320     final int largestBits = nodeTable.numberOfBits[lastNonZero];
321 
322     if (largestBits <= maxNumberOfBits) {
323       return largestBits;   // early exit: no elements > maxNumberOfBits
324     }
325 
326     // there are several too large elements (at least >= 2)
327     int totalCost = 0;
328     final int baseCost = 1 << (largestBits - maxNumberOfBits);
329     int n = lastNonZero;
330 
331     while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
332       totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
333       nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
334       n--;
335     }  // n stops at nodeTable.numberOfBits[n + offset] <= maxNumberOfBits
336 
337     while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
338       n--;   // n ends at index of smallest symbol using < maxNumberOfBits
339     }
340 
341     // renormalize totalCost
342     totalCost >>>= (largestBits -
343                     maxNumberOfBits);  // note: totalCost is necessarily a multiple of baseCost
344 
345     // repay normalized cost
346     final int noSymbol = 0xF0F0F0F0;
347     final int[] rankLast = workspace.rankLast;
348     Arrays.fill(rankLast, noSymbol);
349 
350     // Get pos of last (smallest) symbol per rank
351     int currentNbBits = maxNumberOfBits;
352     for (int pos = n; pos >= 0; pos--) {
353       if (nodeTable.numberOfBits[pos] >= currentNbBits) {
354         continue;
355       }
356       currentNbBits = nodeTable.numberOfBits[pos];   // < maxNumberOfBits
357       rankLast[maxNumberOfBits - currentNbBits] = pos;
358     }
359 
360     while (totalCost > 0) {
361       int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
362       for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
363         final int highPosition = rankLast[numberOfBitsToDecrease];
364         final int lowPosition = rankLast[numberOfBitsToDecrease - 1];
365         if (highPosition == noSymbol) {
366           continue;
367         }
368         if (lowPosition == noSymbol) {
369           break;
370         }
371         final int highTotal = nodeTable.count[highPosition];
372         final int lowTotal = 2 * nodeTable.count[lowPosition];
373         if (highTotal <= lowTotal) {
374           break;
375         }
376       }
377 
378       // only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
379       // HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
380       while ((numberOfBitsToDecrease <= Huffman.MAX_TABLE_LOG) &&
381              (rankLast[numberOfBitsToDecrease] == noSymbol)) {
382         numberOfBitsToDecrease++;
383       }
384       totalCost -= 1 << (numberOfBitsToDecrease - 1);
385       if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
386         rankLast[numberOfBitsToDecrease - 1] =
387             rankLast[numberOfBitsToDecrease];   // this rank is no longer empty
388       }
389       nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
390       if (rankLast[numberOfBitsToDecrease] ==
391           0) {   /* special case, reached largest symbol */
392         rankLast[numberOfBitsToDecrease] = noSymbol;
393       } else {
394         rankLast[numberOfBitsToDecrease]--;
395         if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !=
396             maxNumberOfBits - numberOfBitsToDecrease) {
397           rankLast[numberOfBitsToDecrease] =
398               noSymbol;   // this rank is now empty
399         }
400       }
401     }
402 
403     while (totalCost < 0) {  // Sometimes, cost correction overshoot
404       if (rankLast[1] ==
405           noSymbol) {  /* special case : no rank 1 symbol (using maxNumberOfBits-1); let's create one from largest rank 0 (using maxNumberOfBits) */
406         while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
407           n--;
408         }
409         nodeTable.numberOfBits[n + 1]--;
410         rankLast[1] = n + 1;
411         totalCost++;
412         continue;
413       }
414       nodeTable.numberOfBits[rankLast[1] + 1]--;
415       rankLast[1]++;
416       totalCost++;
417     }
418 
419     return maxNumberOfBits;
420   }
421 
422   /**
423    * All elements within weightTable must be <= Huffman.MAX_TABLE_LOG
424    */
425   private static int compressWeights(final Object outputBase,
426                                      final long outputAddress,
427                                      final int outputSize, final byte[] weights,
428                                      final int weightsLength,
429                                      final HuffmanTableWriterWorkspace workspace) {
430     if (weightsLength <= 1) {
431       return 0; // Not compressible
432     }
433 
434     // Scan input and build symbol stats
435     final int[] counts = workspace.counts;
436     Histogram.count(weights, weightsLength, counts);
437     final int maxSymbol =
438         Histogram.findMaxSymbol(counts, Huffman.MAX_TABLE_LOG);
439     final int maxCount = Histogram.findLargestCount(counts, maxSymbol);
440 
441     if (maxCount == weightsLength) {
442       return 1; // only a single symbol in source
443     }
444     if (maxCount == 1) {
445       return 0; // each symbol present maximum once => not compressible
446     }
447 
448     final short[] normalizedCounts = workspace.normalizedCounts;
449 
450     final int tableLog =
451         FiniteStateEntropy.optimalTableLog(Huffman.MAX_FSE_TABLE_LOG,
452                                            weightsLength, maxSymbol);
453     FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts,
454                                        weightsLength, maxSymbol);
455 
456     long output = outputAddress;
457     final long outputLimit = outputAddress + outputSize;
458 
459     // Write table description header
460     final int headerSize =
461         FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize,
462                                                  normalizedCounts, maxSymbol,
463                                                  tableLog);
464     output += headerSize;
465 
466     // Compress
467     final FseCompressionTable compressionTable = workspace.fseTable;
468     compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
469     final int compressedSize = FiniteStateEntropy.compress(outputBase, output,
470                                                            (int) (outputLimit -
471                                                                   output),
472                                                            weights,
473                                                            weightsLength,
474                                                            compressionTable);
475     if (compressedSize == 0) {
476       return 0;
477     }
478     output += compressedSize;
479 
480     return (int) (output - outputAddress);
481   }
482 }