1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdunsafe;
35
36 import static org.waarp.compress.zstdunsafe.BitInputStream.*;
37 import static sun.misc.Unsafe.*;
38
39 class FiniteStateEntropy {
40 public static final int MAX_SYMBOL = 255;
41 public static final int MAX_TABLE_LOG = 12;
42 public static final int MIN_TABLE_LOG = 5;
43
44 private static final int[] REST_TO_BEAT =
45 new int[] { 0, 473195, 504333, 520860, 550000, 700000, 750000, 830000 };
46 private static final short UNASSIGNED = -2;
47 public static final String OUTPUT_BUFFER_TOO_SMALL =
48 "Output buffer too small";
49
50 private FiniteStateEntropy() {
51 }
52
53 public static int decompress(final FiniteStateEntropy.Table table,
54 final Object inputBase, final long inputAddress,
55 final long inputLimit,
56 final byte[] outputBuffer) {
57 final long outputAddress = ARRAY_BYTE_BASE_OFFSET;
58 final long outputLimit = outputAddress + outputBuffer.length;
59
60 long output = outputAddress;
61
62
63 final BitInputStream.Initializer initializer =
64 new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
65 initializer.initialize();
66 int bitsConsumed = initializer.getBitsConsumed();
67 long currentAddress = initializer.getCurrentAddress();
68 long bits = initializer.getBits();
69
70
71 int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
72 bitsConsumed += table.log2Size;
73
74 BitInputStream.Loader loader =
75 new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
76 bitsConsumed);
77 loader.load();
78 bits = loader.getBits();
79 bitsConsumed = loader.getBitsConsumed();
80 currentAddress = loader.getCurrentAddress();
81
82
83 int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
84 bitsConsumed += table.log2Size;
85
86 loader =
87 new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
88 bitsConsumed);
89 loader.load();
90 bits = loader.getBits();
91 bitsConsumed = loader.getBitsConsumed();
92 currentAddress = loader.getCurrentAddress();
93
94 final byte[] symbols = table.symbol;
95 final byte[] numbersOfBits = table.numberOfBits;
96 final int[] newStates = table.newState;
97
98
99 while (output <= outputLimit - 4) {
100 int numberOfBits;
101
102 UnsafeUtil.UNSAFE.putByte(outputBuffer, output, symbols[state1]);
103 numberOfBits = numbersOfBits[state1];
104 state1 = (int) (newStates[state1] +
105 peekBits(bitsConsumed, bits, numberOfBits));
106 bitsConsumed += numberOfBits;
107
108 UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 1, symbols[state2]);
109 numberOfBits = numbersOfBits[state2];
110 state2 = (int) (newStates[state2] +
111 peekBits(bitsConsumed, bits, numberOfBits));
112 bitsConsumed += numberOfBits;
113
114 UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 2, symbols[state1]);
115 numberOfBits = numbersOfBits[state1];
116 state1 = (int) (newStates[state1] +
117 peekBits(bitsConsumed, bits, numberOfBits));
118 bitsConsumed += numberOfBits;
119
120 UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 3, symbols[state2]);
121 numberOfBits = numbersOfBits[state2];
122 state2 = (int) (newStates[state2] +
123 peekBits(bitsConsumed, bits, numberOfBits));
124 bitsConsumed += numberOfBits;
125
126 output += Constants.SIZE_OF_INT;
127
128 loader =
129 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
130 bits, bitsConsumed);
131 final boolean done = loader.load();
132 bitsConsumed = loader.getBitsConsumed();
133 bits = loader.getBits();
134 currentAddress = loader.getCurrentAddress();
135 if (done) {
136 break;
137 }
138 }
139
140 while (true) {
141 Util.verify(output <= outputLimit - 2, inputAddress,
142 "Output buffer is too small");
143 UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
144 final int numberOfBits = numbersOfBits[state1];
145 state1 = (int) (newStates[state1] +
146 peekBits(bitsConsumed, bits, numberOfBits));
147 bitsConsumed += numberOfBits;
148
149 loader =
150 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
151 bits, bitsConsumed);
152 loader.load();
153 bitsConsumed = loader.getBitsConsumed();
154 bits = loader.getBits();
155 currentAddress = loader.getCurrentAddress();
156
157 if (loader.isOverflow()) {
158 UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
159 break;
160 }
161
162 Util.verify(output <= outputLimit - 2, inputAddress,
163 "Output buffer is too small");
164 UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
165 final int numberOfBits1 = numbersOfBits[state2];
166 state2 = (int) (newStates[state2] +
167 peekBits(bitsConsumed, bits, numberOfBits1));
168 bitsConsumed += numberOfBits1;
169
170 loader =
171 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
172 bits, bitsConsumed);
173 loader.load();
174 bitsConsumed = loader.getBitsConsumed();
175 bits = loader.getBits();
176 currentAddress = loader.getCurrentAddress();
177
178 if (loader.isOverflow()) {
179 UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
180 break;
181 }
182 }
183
184 return (int) (output - outputAddress);
185 }
186
187 public static int compress(final Object outputBase, final long outputAddress,
188 final int outputSize, final byte[] input,
189 final int inputSize,
190 final FseCompressionTable table) {
191 return compress(outputBase, outputAddress, outputSize, input,
192 ARRAY_BYTE_BASE_OFFSET, inputSize, table);
193 }
194
195 public static int compress(final Object outputBase, final long outputAddress,
196 final int outputSize, final Object inputBase,
197 final long inputAddress, int inputSize,
198 final FseCompressionTable table) {
199 Util.checkArgument(outputSize >= Constants.SIZE_OF_LONG,
200 OUTPUT_BUFFER_TOO_SMALL);
201
202 long input = inputAddress + inputSize;
203
204 if (inputSize <= 2) {
205 return 0;
206 }
207
208 final BitOutputStream stream =
209 new BitOutputStream(outputBase, outputAddress, outputSize);
210
211 int state1;
212 int state2;
213
214 if ((inputSize & 1) != 0) {
215 input--;
216 state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
217
218 input--;
219 state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
220
221 input--;
222 state1 = table.encode(stream, state1,
223 UnsafeUtil.UNSAFE.getByte(inputBase, input));
224
225 stream.flush();
226 } else {
227 input--;
228 state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
229
230 input--;
231 state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
232 }
233
234
235 inputSize -= 2;
236
237 if ((inputSize & 2) != 0) {
238 input--;
239 state2 = table.encode(stream, state2,
240 UnsafeUtil.UNSAFE.getByte(inputBase, input));
241
242 input--;
243 state1 = table.encode(stream, state1,
244 UnsafeUtil.UNSAFE.getByte(inputBase, input));
245
246 stream.flush();
247 }
248
249
250 while (input > inputAddress) {
251 input--;
252 state2 = table.encode(stream, state2,
253 UnsafeUtil.UNSAFE.getByte(inputBase, input));
254
255 input--;
256 state1 = table.encode(stream, state1,
257 UnsafeUtil.UNSAFE.getByte(inputBase, input));
258
259 input--;
260 state2 = table.encode(stream, state2,
261 UnsafeUtil.UNSAFE.getByte(inputBase, input));
262
263 input--;
264 state1 = table.encode(stream, state1,
265 UnsafeUtil.UNSAFE.getByte(inputBase, input));
266
267 stream.flush();
268 }
269
270 table.finish(stream, state2);
271 table.finish(stream, state1);
272
273 return stream.close();
274 }
275
276 public static int optimalTableLog(final int maxTableLog, final int inputSize,
277 final int maxSymbol) {
278 if (inputSize <= 1) {
279 throw new IllegalArgumentException();
280 }
281
282 int result = maxTableLog;
283
284 result = Math.min(result, Util.highestBit((inputSize - 1)) -
285 2);
286
287
288 result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
289
290 result = Math.max(result, MIN_TABLE_LOG);
291 result = Math.min(result, MAX_TABLE_LOG);
292
293 return result;
294 }
295
296 public static void normalizeCounts(final short[] normalizedCounts,
297 final int tableLog, final int[] counts,
298 final int total, final int maxSymbol) {
299 Util.checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
300 Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
301 Util.checkArgument(tableLog >= Util.minTableLog(total, maxSymbol),
302 "FSE table size too small");
303
304 final long scale = 62L - tableLog;
305 final long step = (1L << 62) / total;
306 final long vstep = 1L << (scale - 20);
307
308 int stillToDistribute = 1 << tableLog;
309
310 int largest = 0;
311 short largestProbability = 0;
312 final int lowThreshold = total >>> tableLog;
313
314 for (int symbol = 0; symbol <= maxSymbol; symbol++) {
315 if (counts[symbol] == total) {
316 throw new IllegalArgumentException();
317 }
318 if (counts[symbol] == 0) {
319 normalizedCounts[symbol] = 0;
320 continue;
321 }
322 if (counts[symbol] <= lowThreshold) {
323 normalizedCounts[symbol] = -1;
324 stillToDistribute--;
325 } else {
326 short probability = (short) ((counts[symbol] * step) >>> scale);
327 if (probability < 8) {
328 final long restToBeat = vstep * REST_TO_BEAT[probability];
329 final long delta =
330 counts[symbol] * step - (((long) probability) << scale);
331 if (delta > restToBeat) {
332 probability++;
333 }
334 }
335 if (probability > largestProbability) {
336 largestProbability = probability;
337 largest = symbol;
338 }
339 normalizedCounts[symbol] = probability;
340 stillToDistribute -= probability;
341 }
342 }
343
344 if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
345
346
347 normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
348 } else {
349 normalizedCounts[largest] += (short) stillToDistribute;
350 }
351
352 }
353
354 private static void normalizeCounts2(final short[] normalizedCounts,
355 final int tableLog, final int[] counts,
356 int total, final int maxSymbol) {
357 int distributed = 0;
358
359 final int lowThreshold = total >>>
360 tableLog;
361 int lowOne = (total * 3) >>> (tableLog +
362 1);
363
364 for (int i = 0; i <= maxSymbol; i++) {
365 if (counts[i] == 0) {
366 normalizedCounts[i] = 0;
367 } else if (counts[i] <= lowThreshold) {
368 normalizedCounts[i] = -1;
369 distributed++;
370 total -= counts[i];
371 } else if (counts[i] <= lowOne) {
372 normalizedCounts[i] = 1;
373 distributed++;
374 total -= counts[i];
375 } else {
376 normalizedCounts[i] = UNASSIGNED;
377 }
378 }
379
380 final int normalizationFactor = 1 << tableLog;
381 int toDistribute = normalizationFactor - distributed;
382
383 if ((total / toDistribute) > lowOne) {
384
385 lowOne = ((total * 3) / (toDistribute * 2));
386 for (int i = 0; i <= maxSymbol; i++) {
387 if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
388 normalizedCounts[i] = 1;
389 distributed++;
390 total -= counts[i];
391 }
392 }
393 toDistribute = normalizationFactor - distributed;
394 }
395
396 if (distributed == maxSymbol + 1) {
397
398
399
400 int maxValue = 0;
401 int maxCount = 0;
402 for (int i = 0; i <= maxSymbol; i++) {
403 if (counts[i] > maxCount) {
404 maxValue = i;
405 maxCount = counts[i];
406 }
407 }
408 normalizedCounts[maxValue] += (short) toDistribute;
409 return;
410 }
411
412 if (total == 0) {
413
414 for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
415 if (normalizedCounts[i] > 0) {
416 toDistribute--;
417 normalizedCounts[i]++;
418 }
419 }
420 return;
421 }
422
423
424 final long vStepLog = 62 - tableLog;
425 final long mid = (1L << (vStepLog - 1)) - 1;
426 final long rStep = (((1L << vStepLog) * toDistribute) + mid) /
427 total;
428 long tmpTotal = mid;
429 for (int i = 0; i <= maxSymbol; i++) {
430 if (normalizedCounts[i] == UNASSIGNED) {
431 final long end = tmpTotal + (counts[i] * rStep);
432 final int sStart = (int) (tmpTotal >>> vStepLog);
433 final int sEnd = (int) (end >>> vStepLog);
434 final int weight = sEnd - sStart;
435
436 if (weight < 1) {
437 throw new AssertionError();
438 }
439 normalizedCounts[i] = (short) weight;
440 tmpTotal = end;
441 }
442 }
443
444 }
445
446 public static int writeNormalizedCounts(final Object outputBase,
447 final long outputAddress,
448 final int outputSize,
449 final short[] normalizedCounts,
450 final int maxSymbol,
451 final int tableLog) {
452 Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
453 Util.checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
454
455 long output = outputAddress;
456 final long outputLimit = outputAddress + outputSize;
457
458 final int tableSize = 1 << tableLog;
459
460 int bitCount = 0;
461
462
463 int bitStream = (tableLog - MIN_TABLE_LOG);
464 bitCount += 4;
465
466 int remaining = tableSize + 1;
467 int threshold = tableSize;
468 int tableBitCount = tableLog + 1;
469
470 int symbol = 0;
471
472 boolean previousIs0 = false;
473 while (remaining > 1) {
474 if (previousIs0) {
475
476
477
478
479
480 int start = symbol;
481
482
483 while (normalizedCounts[symbol] == 0) {
484 symbol++;
485 }
486
487
488 while (symbol >= start + 24) {
489 start += 24;
490 bitStream |= (0xffff << bitCount);
491 Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
492 OUTPUT_BUFFER_TOO_SMALL);
493
494 UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
495 output += Constants.SIZE_OF_SHORT;
496
497
498 bitStream >>>= Short.SIZE;
499 }
500
501
502 while (symbol >= start + 3) {
503 start += 3;
504 bitStream |= 0x3 << bitCount;
505 bitCount += 2;
506 }
507
508
509 bitStream |= (symbol - start) << bitCount;
510 bitCount += 2;
511
512
513 if (bitCount > 16) {
514 Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
515 OUTPUT_BUFFER_TOO_SMALL);
516
517 UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
518 output += Constants.SIZE_OF_SHORT;
519
520 bitStream >>>= Short.SIZE;
521 bitCount -= Short.SIZE;
522 }
523 }
524
525 int count = normalizedCounts[symbol++];
526 final int max = (2 * threshold - 1) - remaining;
527 remaining -= count < 0? -count : count;
528 count++;
529 if (count >= threshold) {
530 count += max;
531 }
532 bitStream |= count << bitCount;
533 bitCount += tableBitCount;
534 bitCount -= (count < max? 1 : 0);
535 previousIs0 = (count == 1);
536
537 if (remaining < 1) {
538 throw new AssertionError();
539 }
540
541 while (remaining < threshold) {
542 tableBitCount--;
543 threshold >>= 1;
544 }
545
546
547 if (bitCount > 16) {
548 Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
549 OUTPUT_BUFFER_TOO_SMALL);
550
551 UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
552 output += Constants.SIZE_OF_SHORT;
553
554 bitStream >>>= Short.SIZE;
555 bitCount -= Short.SIZE;
556 }
557 }
558
559
560 Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
561 OUTPUT_BUFFER_TOO_SMALL);
562 UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
563 output += (bitCount + 7) / 8;
564
565 Util.checkArgument(symbol <= maxSymbol + 1, "Error");
566
567 return (int) (output - outputAddress);
568 }
569
570 public static final class Table {
571 final int[] newState;
572 final byte[] symbol;
573 final byte[] numberOfBits;
574 int log2Size;
575
576 public Table(final int log2Capacity) {
577 final int capacity = 1 << log2Capacity;
578 newState = new int[capacity];
579 symbol = new byte[capacity];
580 numberOfBits = new byte[capacity];
581 }
582
583 public Table(final int log2Size, final int[] newState, final byte[] symbol,
584 final byte[] numberOfBits) {
585 final int size = 1 << log2Size;
586 if (newState.length != size || symbol.length != size ||
587 numberOfBits.length != size) {
588 throw new IllegalArgumentException(
589 "Expected arrays to match provided size");
590 }
591
592 this.log2Size = log2Size;
593 this.newState = newState;
594 this.symbol = symbol;
595 this.numberOfBits = numberOfBits;
596 }
597 }
598 }