1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdsafe;
35
36 import java.util.Arrays;
37
38 import static org.waarp.compress.zstdsafe.Huffman.*;
39 import static org.waarp.compress.zstdsafe.Util.*;
40
41 final class HuffmanCompressionTable {
/** Code word of each symbol, indexed by symbol value (filled by {@code initialize}). */
private final short[] values;
/**
 * Code length in bits of each symbol, indexed by symbol value. A length of
 * 0 is treated by {@code isValid} as "this symbol has no code".
 */
private final byte[] numberOfBits;

/** Largest symbol value passed to the most recent {@code initialize} call. */
private int maxSymbol;
/** Code-length limit actually in effect after the most recent {@code initialize} call. */
private int maxNumberOfBits;

/**
 * Creates an empty table with room for {@code capacity} symbols.
 *
 * @param capacity number of slots allocated for code words and code
 *     lengths; symbols used later must be smaller than this value
 */
public HuffmanCompressionTable(final int capacity) {
  values = new short[capacity];
  numberOfBits = new byte[capacity];
}
52
53 public static int optimalNumberOfBits(final int maxNumberOfBits,
54 final int inputSize,
55 final int maxSymbol) {
56 if (inputSize <= 1) {
57 throw new IllegalArgumentException();
58 }
59
60 int result = maxNumberOfBits;
61
62 result = Math.min(result, Util.highestBit((inputSize - 1)) -
63 1);
64
65
66 result = Math.max(result, minTableLog(inputSize, maxSymbol));
67
68 result = Math.max(result, MIN_TABLE_LOG);
69 result = Math.min(result, MAX_TABLE_LOG);
70
71 return result;
72 }
73
74 public void initialize(final int[] counts, final int maxSymbol,
75 int maxNumberOfBits,
76 final HuffmanCompressionTableWorkspace workspace) {
77 checkArgument(maxSymbol <= MAX_SYMBOL, "Max symbol value too large");
78
79 workspace.reset();
80
81 final NodeTable nodeTable = workspace.nodeTable;
82 nodeTable.reset();
83
84 final int lastNonZero = buildTree(counts, maxSymbol, nodeTable);
85
86
87 maxNumberOfBits =
88 setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
89 checkArgument(maxNumberOfBits <= MAX_TABLE_LOG,
90 "Max number of bits larger than max table size");
91
92
93 final int symbolCount = maxSymbol + 1;
94 for (int node = 0; node < symbolCount; node++) {
95 final int symbol = nodeTable.symbols[node];
96 numberOfBits[symbol] = nodeTable.numberOfBits[node];
97 }
98
99 final short[] entriesPerRank = workspace.entriesPerRank;
100 final short[] valuesPerRank = workspace.valuesPerRank;
101
102 for (int n = 0; n <= lastNonZero; n++) {
103 entriesPerRank[nodeTable.numberOfBits[n]]++;
104 }
105
106
107 short startingValue = 0;
108 for (int rank = maxNumberOfBits; rank > 0; rank--) {
109 valuesPerRank[rank] =
110 startingValue;
111 startingValue += entriesPerRank[rank];
112 startingValue >>>= 1;
113 }
114
115 for (int n = 0; n <= maxSymbol; n++) {
116 values[n] =
117 valuesPerRank[numberOfBits[n]]++;
118 }
119
120 this.maxSymbol = maxSymbol;
121 this.maxNumberOfBits = maxNumberOfBits;
122 }
123
124 private int buildTree(final int[] counts, final int maxSymbol,
125 final NodeTable nodeTable) {
126
127
128 short current = 0;
129
130 for (int symbol = 0; symbol <= maxSymbol; symbol++) {
131 final int count = counts[symbol];
132
133
134 int position = current;
135 while (position > 1 && count > nodeTable.count[position - 1]) {
136 nodeTable.copyNode(position - 1, position);
137 position--;
138 }
139
140 nodeTable.count[position] = count;
141 nodeTable.symbols[position] = symbol;
142
143 current++;
144 }
145
146 int lastNonZero = maxSymbol;
147 while (nodeTable.count[lastNonZero] == 0) {
148 lastNonZero--;
149 }
150
151
152 final short nonLeafStart = MAX_SYMBOL_COUNT;
153 current = nonLeafStart;
154
155 int currentLeaf = lastNonZero;
156
157
158 int currentNonLeaf = current;
159 nodeTable.count[current] =
160 nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
161 nodeTable.parents[currentLeaf] = current;
162 nodeTable.parents[currentLeaf - 1] = current;
163 current++;
164 currentLeaf -= 2;
165
166 final int root = MAX_SYMBOL_COUNT + lastNonZero - 1;
167
168
169 for (int n = current; n <= root; n++) {
170 nodeTable.count[n] = 1 << 30;
171 }
172
173
174 while (current <= root) {
175 final int child1;
176 if (currentLeaf >= 0 &&
177 nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
178 child1 = currentLeaf--;
179 } else {
180 child1 = currentNonLeaf++;
181 }
182
183 final int child2;
184 if (currentLeaf >= 0 &&
185 nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
186 child2 = currentLeaf--;
187 } else {
188 child2 = currentNonLeaf++;
189 }
190
191 nodeTable.count[current] =
192 nodeTable.count[child1] + nodeTable.count[child2];
193 nodeTable.parents[child1] = current;
194 nodeTable.parents[child2] = current;
195 current++;
196 }
197
198
199 nodeTable.numberOfBits[root] = 0;
200 for (int n = root - 1; n >= nonLeafStart; n--) {
201 final short parent = nodeTable.parents[n];
202 nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
203 }
204
205 for (int n = 0; n <= lastNonZero; n++) {
206 final short parent = nodeTable.parents[n];
207 nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
208 }
209
210 return lastNonZero;
211 }
212
213
214
215
216
217
218 public void encodeSymbol(final BitOutputStream output, final int symbol) {
219 output.addBitsFast(values[symbol], numberOfBits[symbol]);
220 }
221
222 public int write(final byte[] outputBase, final int outputAddress,
223 final int outputSize,
224 final HuffmanTableWriterWorkspace workspace) {
225 final byte[] weights = workspace.weights;
226
227 int output = outputAddress;
228
229 final int maxNumberOfBits1 = this.maxNumberOfBits;
230 final int maxSymbol1 = this.maxSymbol;
231
232
233 for (int symbol = 0; symbol < maxSymbol1; symbol++) {
234 final int bits = numberOfBits[symbol];
235
236 if (bits == 0) {
237 weights[symbol] = 0;
238 } else {
239 weights[symbol] = (byte) (maxNumberOfBits1 + 1 - bits);
240 }
241 }
242
243
244 int size = compressWeights(outputBase, output + 1, outputSize - 1, weights,
245 maxSymbol1, workspace);
246
247 if (maxSymbol1 > 127 && size > 127) {
248
249
250 throw new AssertionError();
251 }
252
253 if (size != 0 && size != 1 && size < maxSymbol1 / 2) {
254
255
256
257
258
259 outputBase[output] = (byte) size;
260 return size + 1;
261 } else {
262
263
264
265
266 size = (maxSymbol1 + 1) / 2;
267 checkArgument(size + 1 <= outputSize,
268 "Output size too small");
269
270
271
272 outputBase[output] = (byte) (127 + maxSymbol1);
273 output++;
274
275 weights[maxSymbol1] =
276 0;
277 for (int i = 0; i < maxSymbol1; i += 2) {
278 outputBase[output] = (byte) ((weights[i] << 4) + weights[i + 1]);
279 output++;
280 }
281
282 return output - outputAddress;
283 }
284 }
285
286
287
288
289 public boolean isValid(final int[] counts, final int maxSymbol) {
290 if (maxSymbol > this.maxSymbol) {
291
292 return false;
293 }
294
295 for (int symbol = 0; symbol <= maxSymbol; ++symbol) {
296 if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
297 return false;
298 }
299 }
300 return true;
301 }
302
303 public int estimateCompressedSize(final int[] counts, final int maxSymbol) {
304 int numberOfBits = 0;
305 for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol);
306 symbol++) {
307 numberOfBits += this.numberOfBits[symbol] * counts[symbol];
308 }
309
310 return numberOfBits >>> 3;
311 }
312
313
314 private static int setMaxHeight(final NodeTable nodeTable,
315 final int lastNonZero,
316 final int maxNumberOfBits,
317 final HuffmanCompressionTableWorkspace workspace) {
318 final int largestBits = nodeTable.numberOfBits[lastNonZero];
319
320 if (largestBits <= maxNumberOfBits) {
321 return largestBits;
322 }
323
324
325 int totalCost = 0;
326 final int baseCost = 1 << (largestBits - maxNumberOfBits);
327 int n = lastNonZero;
328
329 while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
330 totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
331 nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
332 n--;
333 }
334
335 while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
336 n--;
337 }
338
339
340 totalCost >>>= (largestBits -
341 maxNumberOfBits);
342
343
344 final int noSymbol = 0xF0F0F0F0;
345 final int[] rankLast = workspace.rankLast;
346 Arrays.fill(rankLast, noSymbol);
347
348
349 int currentNbBits = maxNumberOfBits;
350 for (int pos = n; pos >= 0; pos--) {
351 if (nodeTable.numberOfBits[pos] >= currentNbBits) {
352 continue;
353 }
354 currentNbBits = nodeTable.numberOfBits[pos];
355 rankLast[maxNumberOfBits - currentNbBits] = pos;
356 }
357
358 while (totalCost > 0) {
359 int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
360 for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
361 final int highPosition = rankLast[numberOfBitsToDecrease];
362 final int lowPosition = rankLast[numberOfBitsToDecrease - 1];
363 if (highPosition == noSymbol) {
364 continue;
365 }
366 if (lowPosition == noSymbol) {
367 break;
368 }
369 final int highTotal = nodeTable.count[highPosition];
370 final int lowTotal = 2 * nodeTable.count[lowPosition];
371 if (highTotal <= lowTotal) {
372 break;
373 }
374 }
375
376
377
378 while ((numberOfBitsToDecrease <= MAX_TABLE_LOG) &&
379 (rankLast[numberOfBitsToDecrease] == noSymbol)) {
380 numberOfBitsToDecrease++;
381 }
382 totalCost -= 1 << (numberOfBitsToDecrease - 1);
383 if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
384 rankLast[numberOfBitsToDecrease - 1] =
385 rankLast[numberOfBitsToDecrease];
386 }
387 nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
388 if (rankLast[numberOfBitsToDecrease] ==
389 0) {
390 rankLast[numberOfBitsToDecrease] = noSymbol;
391 } else {
392 rankLast[numberOfBitsToDecrease]--;
393 if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !=
394 maxNumberOfBits - numberOfBitsToDecrease) {
395 rankLast[numberOfBitsToDecrease] =
396 noSymbol;
397 }
398 }
399 }
400
401 while (totalCost < 0) {
402 if (rankLast[1] ==
403 noSymbol) {
404 while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
405 n--;
406 }
407 nodeTable.numberOfBits[n + 1]--;
408 rankLast[1] = n + 1;
409 totalCost++;
410 continue;
411 }
412 nodeTable.numberOfBits[rankLast[1] + 1]--;
413 rankLast[1]++;
414 totalCost++;
415 }
416
417 return maxNumberOfBits;
418 }
419
420
421
422
423 private static int compressWeights(final byte[] outputBase,
424 final int outputAddress,
425 final int outputSize, final byte[] weights,
426 final int weightsLength,
427 final HuffmanTableWriterWorkspace workspace) {
428 if (weightsLength <= 1) {
429 return 0;
430 }
431
432
433 final int[] counts = workspace.counts;
434 Histogram.count(weights, weightsLength, counts);
435 final int maxSymbol = Histogram.findMaxSymbol(counts, MAX_TABLE_LOG);
436 final int maxCount = Histogram.findLargestCount(counts, maxSymbol);
437
438 if (maxCount == weightsLength) {
439 return 1;
440 }
441 if (maxCount == 1) {
442 return 0;
443 }
444
445 final short[] normalizedCounts = workspace.normalizedCounts;
446
447 final int tableLog =
448 FiniteStateEntropy.optimalTableLog(MAX_FSE_TABLE_LOG, weightsLength,
449 maxSymbol);
450 FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts,
451 weightsLength, maxSymbol);
452
453 int output = outputAddress;
454 final int outputLimit = outputAddress + outputSize;
455
456
457 final int headerSize =
458 FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize,
459 normalizedCounts, maxSymbol,
460 tableLog);
461 output += headerSize;
462
463
464 final FseCompressionTable compressionTable = workspace.fseTable;
465 compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
466 final int compressedSize =
467 FiniteStateEntropy.compress(outputBase, output, outputLimit - output,
468 weights, weightsLength, compressionTable);
469 if (compressedSize == 0) {
470 return 0;
471 }
472 output += compressedSize;
473
474 return output - outputAddress;
475 }
476 }