1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdunsafe;
35
36 import java.util.Arrays;
37
38 final class HuffmanCompressionTable {
/** Code value of each symbol, indexed by symbol. */
private final short[] values;

/**
 * Code length in bits of each symbol, indexed by symbol. A length of zero
 * is treated by this class as "symbol has no code".
 */
private final byte[] numberOfBits;

/** Largest symbol passed to the most recent {@code initialize} call. */
private int maxSymbol;

/** Longest code length in use after the most recent {@code initialize} call. */
private int maxNumberOfBits;

/**
 * Creates an empty table with room for {@code capacity} symbols.
 *
 * @param capacity number of entries allocated for both the code values and
 *                 the code lengths
 */
public HuffmanCompressionTable(final int capacity) {
  values = new short[capacity];
  numberOfBits = new byte[capacity];
}
49
50 public static int optimalNumberOfBits(final int maxNumberOfBits,
51 final int inputSize,
52 final int maxSymbol) {
53 if (inputSize <= 1) {
54 throw new IllegalArgumentException();
55 }
56
57 int result = maxNumberOfBits;
58
59 result = Math.min(result, Util.highestBit((inputSize - 1)) -
60 1);
61
62
63 result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
64
65 result =
66 Math.max(result, Huffman.MIN_TABLE_LOG);
67 result =
68 Math.min(result, Huffman.MAX_TABLE_LOG);
69
70 return result;
71 }
72
73 public void initialize(final int[] counts, final int maxSymbol,
74 int maxNumberOfBits,
75 final HuffmanCompressionTableWorkspace workspace) {
76 Util.checkArgument(maxSymbol <= Huffman.MAX_SYMBOL,
77 "Max symbol value too large");
78
79 workspace.reset();
80
81 final NodeTable nodeTable = workspace.nodeTable;
82 nodeTable.reset();
83
84 final int lastNonZero = buildTree(counts, maxSymbol, nodeTable);
85
86
87 maxNumberOfBits =
88 setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
89 Util.checkArgument(maxNumberOfBits <= Huffman.MAX_TABLE_LOG,
90 "Max number of bits larger than max table size");
91
92
93 final int symbolCount = maxSymbol + 1;
94 for (int node = 0; node < symbolCount; node++) {
95 final int symbol = nodeTable.symbols[node];
96 numberOfBits[symbol] = nodeTable.numberOfBits[node];
97 }
98
99 final short[] entriesPerRank = workspace.entriesPerRank;
100 final short[] valuesPerRank = workspace.valuesPerRank;
101
102 for (int n = 0; n <= lastNonZero; n++) {
103 entriesPerRank[nodeTable.numberOfBits[n]]++;
104 }
105
106
107 short startingValue = 0;
108 for (int rank = maxNumberOfBits; rank > 0; rank--) {
109 valuesPerRank[rank] =
110 startingValue;
111 startingValue += entriesPerRank[rank];
112 startingValue >>>= 1;
113 }
114
115 for (int n = 0; n <= maxSymbol; n++) {
116 values[n] =
117 valuesPerRank[numberOfBits[n]]++;
118 }
119
120 this.maxSymbol = maxSymbol;
121 this.maxNumberOfBits = maxNumberOfBits;
122 }
123
124 private int buildTree(final int[] counts, final int maxSymbol,
125 final NodeTable nodeTable) {
126
127
128 short current = 0;
129
130 for (int symbol = 0; symbol <= maxSymbol; symbol++) {
131 final int count = counts[symbol];
132
133
134 int position = current;
135 while (position > 1 && count > nodeTable.count[position - 1]) {
136 nodeTable.copyNode(position - 1, position);
137 position--;
138 }
139
140 nodeTable.count[position] = count;
141 nodeTable.symbols[position] = symbol;
142
143 current++;
144 }
145
146 int lastNonZero = maxSymbol;
147 while (nodeTable.count[lastNonZero] == 0) {
148 lastNonZero--;
149 }
150
151
152 final short nonLeafStart = Huffman.MAX_SYMBOL_COUNT;
153 current = nonLeafStart;
154
155 int currentLeaf = lastNonZero;
156
157
158 int currentNonLeaf = current;
159 nodeTable.count[current] =
160 nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
161 nodeTable.parents[currentLeaf] = current;
162 nodeTable.parents[currentLeaf - 1] = current;
163 current++;
164 currentLeaf -= 2;
165
166 final int root = Huffman.MAX_SYMBOL_COUNT + lastNonZero - 1;
167
168
169 for (int n = current; n <= root; n++) {
170 nodeTable.count[n] = 1 << 30;
171 }
172
173
174 while (current <= root) {
175 final int child1;
176 if (currentLeaf >= 0 &&
177 nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
178 child1 = currentLeaf--;
179 } else {
180 child1 = currentNonLeaf++;
181 }
182
183 final int child2;
184 if (currentLeaf >= 0 &&
185 nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
186 child2 = currentLeaf--;
187 } else {
188 child2 = currentNonLeaf++;
189 }
190
191 nodeTable.count[current] =
192 nodeTable.count[child1] + nodeTable.count[child2];
193 nodeTable.parents[child1] = current;
194 nodeTable.parents[child2] = current;
195 current++;
196 }
197
198
199 nodeTable.numberOfBits[root] = 0;
200 for (int n = root - 1; n >= nonLeafStart; n--) {
201 final short parent = nodeTable.parents[n];
202 nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
203 }
204
205 for (int n = 0; n <= lastNonZero; n++) {
206 final short parent = nodeTable.parents[n];
207 nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
208 }
209
210 return lastNonZero;
211 }
212
213
214
215
216
217
218 public void encodeSymbol(final BitOutputStream output, final int symbol) {
219 output.addBitsFast(values[symbol], numberOfBits[symbol]);
220 }
221
222 public int write(final Object outputBase, final long outputAddress,
223 final int outputSize,
224 final HuffmanTableWriterWorkspace workspace) {
225 final byte[] weights = workspace.weights;
226
227 long output = outputAddress;
228
229 final int numberOfBits1 = this.maxNumberOfBits;
230 final int maxSymbol1 = this.maxSymbol;
231
232
233 for (int symbol = 0; symbol < maxSymbol1; symbol++) {
234 final int bits = numberOfBits[symbol];
235
236 if (bits == 0) {
237 weights[symbol] = 0;
238 } else {
239 weights[symbol] = (byte) (numberOfBits1 + 1 - bits);
240 }
241 }
242
243
244 int size = compressWeights(outputBase, output + 1, outputSize - 1, weights,
245 maxSymbol1, workspace);
246
247 if (maxSymbol1 > 127 && size > 127) {
248
249
250 throw new AssertionError();
251 }
252
253 if (size != 0 && size != 1 && size < maxSymbol1 / 2) {
254
255
256
257
258
259 UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) size);
260 return size + 1;
261 } else {
262
263
264
265
266 size = (maxSymbol1 + 1) / 2;
267 Util.checkArgument(size + 1 <= outputSize,
268 "Output size too small");
269
270
271
272 UnsafeUtil.UNSAFE.putByte(outputBase, output, (byte) (127 + maxSymbol1));
273 output++;
274
275 weights[maxSymbol1] =
276 0;
277 for (int i = 0; i < maxSymbol1; i += 2) {
278 UnsafeUtil.UNSAFE.putByte(outputBase, output,
279 (byte) ((weights[i] << 4) +
280 (weights[i + 1] & 0xFF)));
281 output++;
282 }
283
284 return (int) (output - outputAddress);
285 }
286 }
287
288
289
290
291 public boolean isValid(final int[] counts, final int maxSymbol) {
292 if (maxSymbol > this.maxSymbol) {
293
294 return false;
295 }
296
297 for (int symbol = 0; symbol <= maxSymbol; ++symbol) {
298 if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
299 return false;
300 }
301 }
302 return true;
303 }
304
305 public int estimateCompressedSize(final int[] counts, final int maxSymbol) {
306 int numberOfBits1 = 0;
307 for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol);
308 symbol++) {
309 numberOfBits1 += this.numberOfBits[symbol] * counts[symbol];
310 }
311
312 return numberOfBits1 >>> 3;
313 }
314
315
316 private static int setMaxHeight(final NodeTable nodeTable,
317 final int lastNonZero,
318 final int maxNumberOfBits,
319 final HuffmanCompressionTableWorkspace workspace) {
320 final int largestBits = nodeTable.numberOfBits[lastNonZero];
321
322 if (largestBits <= maxNumberOfBits) {
323 return largestBits;
324 }
325
326
327 int totalCost = 0;
328 final int baseCost = 1 << (largestBits - maxNumberOfBits);
329 int n = lastNonZero;
330
331 while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
332 totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
333 nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
334 n--;
335 }
336
337 while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
338 n--;
339 }
340
341
342 totalCost >>>= (largestBits -
343 maxNumberOfBits);
344
345
346 final int noSymbol = 0xF0F0F0F0;
347 final int[] rankLast = workspace.rankLast;
348 Arrays.fill(rankLast, noSymbol);
349
350
351 int currentNbBits = maxNumberOfBits;
352 for (int pos = n; pos >= 0; pos--) {
353 if (nodeTable.numberOfBits[pos] >= currentNbBits) {
354 continue;
355 }
356 currentNbBits = nodeTable.numberOfBits[pos];
357 rankLast[maxNumberOfBits - currentNbBits] = pos;
358 }
359
360 while (totalCost > 0) {
361 int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
362 for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
363 final int highPosition = rankLast[numberOfBitsToDecrease];
364 final int lowPosition = rankLast[numberOfBitsToDecrease - 1];
365 if (highPosition == noSymbol) {
366 continue;
367 }
368 if (lowPosition == noSymbol) {
369 break;
370 }
371 final int highTotal = nodeTable.count[highPosition];
372 final int lowTotal = 2 * nodeTable.count[lowPosition];
373 if (highTotal <= lowTotal) {
374 break;
375 }
376 }
377
378
379
380 while ((numberOfBitsToDecrease <= Huffman.MAX_TABLE_LOG) &&
381 (rankLast[numberOfBitsToDecrease] == noSymbol)) {
382 numberOfBitsToDecrease++;
383 }
384 totalCost -= 1 << (numberOfBitsToDecrease - 1);
385 if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
386 rankLast[numberOfBitsToDecrease - 1] =
387 rankLast[numberOfBitsToDecrease];
388 }
389 nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
390 if (rankLast[numberOfBitsToDecrease] ==
391 0) {
392 rankLast[numberOfBitsToDecrease] = noSymbol;
393 } else {
394 rankLast[numberOfBitsToDecrease]--;
395 if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !=
396 maxNumberOfBits - numberOfBitsToDecrease) {
397 rankLast[numberOfBitsToDecrease] =
398 noSymbol;
399 }
400 }
401 }
402
403 while (totalCost < 0) {
404 if (rankLast[1] ==
405 noSymbol) {
406 while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
407 n--;
408 }
409 nodeTable.numberOfBits[n + 1]--;
410 rankLast[1] = n + 1;
411 totalCost++;
412 continue;
413 }
414 nodeTable.numberOfBits[rankLast[1] + 1]--;
415 rankLast[1]++;
416 totalCost++;
417 }
418
419 return maxNumberOfBits;
420 }
421
422
423
424
425 private static int compressWeights(final Object outputBase,
426 final long outputAddress,
427 final int outputSize, final byte[] weights,
428 final int weightsLength,
429 final HuffmanTableWriterWorkspace workspace) {
430 if (weightsLength <= 1) {
431 return 0;
432 }
433
434
435 final int[] counts = workspace.counts;
436 Histogram.count(weights, weightsLength, counts);
437 final int maxSymbol =
438 Histogram.findMaxSymbol(counts, Huffman.MAX_TABLE_LOG);
439 final int maxCount = Histogram.findLargestCount(counts, maxSymbol);
440
441 if (maxCount == weightsLength) {
442 return 1;
443 }
444 if (maxCount == 1) {
445 return 0;
446 }
447
448 final short[] normalizedCounts = workspace.normalizedCounts;
449
450 final int tableLog =
451 FiniteStateEntropy.optimalTableLog(Huffman.MAX_FSE_TABLE_LOG,
452 weightsLength, maxSymbol);
453 FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts,
454 weightsLength, maxSymbol);
455
456 long output = outputAddress;
457 final long outputLimit = outputAddress + outputSize;
458
459
460 final int headerSize =
461 FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize,
462 normalizedCounts, maxSymbol,
463 tableLog);
464 output += headerSize;
465
466
467 final FseCompressionTable compressionTable = workspace.fseTable;
468 compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
469 final int compressedSize = FiniteStateEntropy.compress(outputBase, output,
470 (int) (outputLimit -
471 output),
472 weights,
473 weightsLength,
474 compressionTable);
475 if (compressedSize == 0) {
476 return 0;
477 }
478 output += compressedSize;
479
480 return (int) (output - outputAddress);
481 }
482 }