1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdsafe;
35
36 import static org.waarp.compress.zstdsafe.BitInputStream.*;
37 import static org.waarp.compress.zstdsafe.Constants.*;
38 import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
39 import static org.waarp.compress.zstdsafe.Util.*;
40
41 class FiniteStateEntropy {
42 public static final int MAX_SYMBOL = 255;
43 public static final int MAX_TABLE_LOG = 12;
44 public static final int MIN_TABLE_LOG = 5;
45
46 private static final int[] REST_TO_BEAT =
47 new int[] { 0, 473195, 504333, 520860, 550000, 700000, 750000, 830000 };
48 private static final short UNASSIGNED = -2;
49 public static final String OUTPUT_BUFFER_TOO_SMALL =
50 "Output buffer too small";
51
52 private FiniteStateEntropy() {
53 }
54
55 public static int decompress(final FiniteStateEntropy.Table table,
56 final byte[] inputBase, final int inputAddress,
57 final int inputLimit,
58 final byte[] outputBuffer) {
59 final int outputAddress = 0;
60 final int outputLimit = outputAddress + outputBuffer.length;
61
62 int output = outputAddress;
63
64
65 final BitInputStream.Initializer initializer =
66 new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
67 initializer.initialize();
68 int bitsConsumed = initializer.getBitsConsumed();
69 int currentAddress = initializer.getCurrentAddress();
70 long bits = initializer.getBits();
71
72
73 int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
74 bitsConsumed += table.log2Size;
75
76 BitInputStream.Loader loader =
77 new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
78 bitsConsumed);
79 loader.load();
80 bits = loader.getBits();
81 bitsConsumed = loader.getBitsConsumed();
82 currentAddress = loader.getCurrentAddress();
83
84
85 int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
86 bitsConsumed += table.log2Size;
87
88 loader =
89 new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
90 bitsConsumed);
91 loader.load();
92 bits = loader.getBits();
93 bitsConsumed = loader.getBitsConsumed();
94 currentAddress = loader.getCurrentAddress();
95
96 final byte[] symbols = table.symbol;
97 final byte[] numbersOfBits = table.numberOfBits;
98 final int[] newStates = table.newState;
99
100
101 while (output <= outputLimit - 4) {
102 int numberOfBits;
103
104 outputBuffer[output] = symbols[state1];
105 numberOfBits = numbersOfBits[state1];
106 state1 = (int) (newStates[state1] +
107 peekBits(bitsConsumed, bits, numberOfBits));
108 bitsConsumed += numberOfBits;
109
110 outputBuffer[output + 1] = symbols[state2];
111 numberOfBits = numbersOfBits[state2];
112 state2 = (int) (newStates[state2] +
113 peekBits(bitsConsumed, bits, numberOfBits));
114 bitsConsumed += numberOfBits;
115
116 outputBuffer[output + 2] = symbols[state1];
117 numberOfBits = numbersOfBits[state1];
118 state1 = (int) (newStates[state1] +
119 peekBits(bitsConsumed, bits, numberOfBits));
120 bitsConsumed += numberOfBits;
121
122 outputBuffer[output + 3] = symbols[state2];
123 numberOfBits = numbersOfBits[state2];
124 state2 = (int) (newStates[state2] +
125 peekBits(bitsConsumed, bits, numberOfBits));
126 bitsConsumed += numberOfBits;
127
128 output += SIZE_OF_INT;
129
130 loader =
131 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
132 bits, bitsConsumed);
133 final boolean done = loader.load();
134 bitsConsumed = loader.getBitsConsumed();
135 bits = loader.getBits();
136 currentAddress = loader.getCurrentAddress();
137 if (done) {
138 break;
139 }
140 }
141
142 while (true) {
143 verify(output <= outputLimit - 2, inputAddress,
144 "Output buffer is too small");
145 outputBuffer[output++] = symbols[state1];
146 final int numberOfBits = numbersOfBits[state1];
147 state1 = (int) (newStates[state1] +
148 peekBits(bitsConsumed, bits, numberOfBits));
149 bitsConsumed += numberOfBits;
150
151 loader =
152 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
153 bits, bitsConsumed);
154 loader.load();
155 bitsConsumed = loader.getBitsConsumed();
156 bits = loader.getBits();
157 currentAddress = loader.getCurrentAddress();
158
159 if (loader.isOverflow()) {
160 outputBuffer[output++] = symbols[state2];
161 break;
162 }
163
164 verify(output <= outputLimit - 2, inputAddress,
165 "Output buffer is too small");
166 outputBuffer[output++] = symbols[state2];
167 final int numberOfBits1 = numbersOfBits[state2];
168 state2 = (int) (newStates[state2] +
169 peekBits(bitsConsumed, bits, numberOfBits1));
170 bitsConsumed += numberOfBits1;
171
172 loader =
173 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
174 bits, bitsConsumed);
175 loader.load();
176 bitsConsumed = loader.getBitsConsumed();
177 bits = loader.getBits();
178 currentAddress = loader.getCurrentAddress();
179
180 if (loader.isOverflow()) {
181 outputBuffer[output++] = symbols[state1];
182 break;
183 }
184 }
185
186 return output - outputAddress;
187 }
188
189 public static int compress(final byte[] outputBase, final int outputAddress,
190 final int outputSize, final byte[] input,
191 final int inputSize,
192 final FseCompressionTable table) {
193 return compress(outputBase, outputAddress, outputSize, input, 0, inputSize,
194 table);
195 }
196
197 public static int compress(final byte[] outputBase, final int outputAddress,
198 final int outputSize, final byte[] inputBase,
199 final int inputAddress, int inputSize,
200 final FseCompressionTable table) {
201 checkArgument(outputSize >= SIZE_OF_LONG, OUTPUT_BUFFER_TOO_SMALL);
202
203 int input = inputAddress + inputSize;
204
205 if (inputSize <= 2) {
206 return 0;
207 }
208
209 final BitOutputStream stream =
210 new BitOutputStream(outputBase, outputAddress, outputSize);
211
212 int state1;
213 int state2;
214
215 if ((inputSize & 1) != 0) {
216 input--;
217 state1 = table.begin(inputBase[input]);
218
219 input--;
220 state2 = table.begin(inputBase[input]);
221
222 input--;
223 state1 = table.encode(stream, state1, inputBase[input]);
224
225 stream.flush();
226 } else {
227 input--;
228 state2 = table.begin(inputBase[input]);
229
230 input--;
231 state1 = table.begin(inputBase[input]);
232 }
233
234
235 inputSize -= 2;
236
237 if ((inputSize & 2) != 0) {
238 input--;
239 state2 = table.encode(stream, state2, inputBase[input]);
240
241 input--;
242 state1 = table.encode(stream, state1, inputBase[input]);
243
244 stream.flush();
245 }
246
247
248 while (input > inputAddress) {
249 input--;
250 state2 = table.encode(stream, state2, inputBase[input]);
251
252 input--;
253 state1 = table.encode(stream, state1, inputBase[input]);
254
255 input--;
256 state2 = table.encode(stream, state2, inputBase[input]);
257
258 input--;
259 state1 = table.encode(stream, state1, inputBase[input]);
260
261 stream.flush();
262 }
263
264 table.finish(stream, state2);
265 table.finish(stream, state1);
266
267 return stream.close();
268 }
269
270 public static int optimalTableLog(final int maxTableLog, final int inputSize,
271 final int maxSymbol) {
272 if (inputSize <= 1) {
273 throw new IllegalArgumentException();
274 }
275
276 int result = maxTableLog;
277
278 result = Math.min(result, Util.highestBit((inputSize - 1)) -
279 2);
280
281
282 result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
283
284 result = Math.max(result, MIN_TABLE_LOG);
285 result = Math.min(result, MAX_TABLE_LOG);
286
287 return result;
288 }
289
290 public static void normalizeCounts(final short[] normalizedCounts,
291 final int tableLog, final int[] counts,
292 final int total, final int maxSymbol) {
293 checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
294 checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
295 checkArgument(tableLog >= Util.minTableLog(total, maxSymbol),
296 "FSE table size too small");
297
298 final int scale = 62 - tableLog;
299 final long step = (1L << 62) / total;
300 final long vstep = 1L << (scale - 20);
301
302 int stillToDistribute = 1 << tableLog;
303
304 int largest = 0;
305 short largestProbability = 0;
306 final int lowThreshold = total >>> tableLog;
307
308 for (int symbol = 0; symbol <= maxSymbol; symbol++) {
309 if (counts[symbol] == total) {
310 throw new IllegalArgumentException();
311 }
312 if (counts[symbol] == 0) {
313 normalizedCounts[symbol] = 0;
314 continue;
315 }
316 if (counts[symbol] <= lowThreshold) {
317 normalizedCounts[symbol] = -1;
318 stillToDistribute--;
319 } else {
320 short probability = (short) ((counts[symbol] * step) >>> scale);
321 if (probability < 8) {
322 final long restToBeat = vstep * REST_TO_BEAT[probability];
323 final long delta =
324 counts[symbol] * step - (((long) probability) << scale);
325 if (delta > restToBeat) {
326 probability++;
327 }
328 }
329 if (probability > largestProbability) {
330 largestProbability = probability;
331 largest = symbol;
332 }
333 normalizedCounts[symbol] = probability;
334 stillToDistribute -= probability;
335 }
336 }
337
338 if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
339
340
341 normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
342 } else {
343 normalizedCounts[largest] += (short) stillToDistribute;
344 }
345
346 }
347
348 private static void normalizeCounts2(final short[] normalizedCounts,
349 final int tableLog, final int[] counts,
350 int total, final int maxSymbol) {
351 int distributed = 0;
352
353 final int lowThreshold = total >>>
354 tableLog;
355 int lowOne = (total * 3) >>> (tableLog +
356 1);
357
358 for (int i = 0; i <= maxSymbol; i++) {
359 if (counts[i] == 0) {
360 normalizedCounts[i] = 0;
361 } else if (counts[i] <= lowThreshold) {
362 normalizedCounts[i] = -1;
363 distributed++;
364 total -= counts[i];
365 } else if (counts[i] <= lowOne) {
366 normalizedCounts[i] = 1;
367 distributed++;
368 total -= counts[i];
369 } else {
370 normalizedCounts[i] = UNASSIGNED;
371 }
372 }
373
374 final int normalizationFactor = 1 << tableLog;
375 int toDistribute = normalizationFactor - distributed;
376
377 if ((total / toDistribute) > lowOne) {
378
379 lowOne = ((total * 3) / (toDistribute * 2));
380 for (int i = 0; i <= maxSymbol; i++) {
381 if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
382 normalizedCounts[i] = 1;
383 distributed++;
384 total -= counts[i];
385 }
386 }
387 toDistribute = normalizationFactor - distributed;
388 }
389
390 if (distributed == maxSymbol + 1) {
391
392
393
394 int maxValue = 0;
395 int maxCount = 0;
396 for (int i = 0; i <= maxSymbol; i++) {
397 if (counts[i] > maxCount) {
398 maxValue = i;
399 maxCount = counts[i];
400 }
401 }
402 normalizedCounts[maxValue] += (short) toDistribute;
403 return;
404 }
405
406 if (total == 0) {
407
408 for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
409 if (normalizedCounts[i] > 0) {
410 toDistribute--;
411 normalizedCounts[i]++;
412 }
413 }
414 return;
415 }
416
417
418 final long vStepLog = 62 - tableLog;
419 final long mid = (1L << (vStepLog - 1)) - 1;
420 final long rStep = (((1L << vStepLog) * toDistribute) + mid) /
421 total;
422 long tmpTotal = mid;
423 for (int i = 0; i <= maxSymbol; i++) {
424 if (normalizedCounts[i] == UNASSIGNED) {
425 final long end = tmpTotal + (counts[i] * rStep);
426 final int sStart = (int) (tmpTotal >>> vStepLog);
427 final int sEnd = (int) (end >>> vStepLog);
428 final int weight = sEnd - sStart;
429
430 if (weight < 1) {
431 throw new AssertionError();
432 }
433 normalizedCounts[i] = (short) weight;
434 tmpTotal = end;
435 }
436 }
437
438 }
439
440 public static int writeNormalizedCounts(final byte[] outputBase,
441 final int outputAddress,
442 final int outputSize,
443 final short[] normalizedCounts,
444 final int maxSymbol,
445 final int tableLog) {
446 checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
447 checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
448
449 int output = outputAddress;
450 final int outputLimit = outputAddress + outputSize;
451
452 final int tableSize = 1 << tableLog;
453
454 int bitCount = 0;
455
456
457 int bitStream = (tableLog - MIN_TABLE_LOG);
458 bitCount += 4;
459
460 int remaining = tableSize + 1;
461 int threshold = tableSize;
462 int tableBitCount = tableLog + 1;
463
464 int symbol = 0;
465
466 boolean previousIs0 = false;
467 while (remaining > 1) {
468 if (previousIs0) {
469
470
471
472
473
474 int start = symbol;
475
476
477 while (normalizedCounts[symbol] == 0) {
478 symbol++;
479 }
480
481
482 while (symbol >= start + 24) {
483 start += 24;
484 bitStream |= (0xFFFF << bitCount);
485 checkArgument(output + SIZE_OF_SHORT <= outputLimit,
486 OUTPUT_BUFFER_TOO_SMALL);
487
488 putShort(outputBase, output, (short) bitStream);
489 output += SIZE_OF_SHORT;
490
491
492 bitStream >>>= Short.SIZE;
493 }
494
495
496 while (symbol >= start + 3) {
497 start += 3;
498 bitStream |= 0x3 << bitCount;
499 bitCount += 2;
500 }
501
502
503 bitStream |= (symbol - start) << bitCount;
504 bitCount += 2;
505
506
507 if (bitCount > 16) {
508 checkArgument(output + SIZE_OF_SHORT <= outputLimit,
509 OUTPUT_BUFFER_TOO_SMALL);
510
511 putShort(outputBase, output, (short) bitStream);
512 output += SIZE_OF_SHORT;
513
514 bitStream >>>= Short.SIZE;
515 bitCount -= Short.SIZE;
516 }
517 }
518
519 int count = normalizedCounts[symbol++];
520 final int max = (2 * threshold - 1) - remaining;
521 remaining -= count < 0? -count : count;
522 count++;
523 if (count >= threshold) {
524 count += max;
525 }
526 bitStream |= count << bitCount;
527 bitCount += tableBitCount;
528 bitCount -= (count < max? 1 : 0);
529 previousIs0 = (count == 1);
530
531 if (remaining < 1) {
532 throw new AssertionError();
533 }
534
535 while (remaining < threshold) {
536 tableBitCount--;
537 threshold >>= 1;
538 }
539
540
541 if (bitCount > 16) {
542 checkArgument(output + SIZE_OF_SHORT <= outputLimit,
543 OUTPUT_BUFFER_TOO_SMALL);
544
545 putShort(outputBase, output, (short) bitStream);
546 output += SIZE_OF_SHORT;
547
548 bitStream >>>= Short.SIZE;
549 bitCount -= Short.SIZE;
550 }
551 }
552
553
554 checkArgument(output + SIZE_OF_SHORT <= outputLimit,
555 OUTPUT_BUFFER_TOO_SMALL);
556 putShort(outputBase, output, (short) bitStream);
557 output += (bitCount + 7) / 8;
558
559 checkArgument(symbol <= maxSymbol + 1, "Error");
560
561 return output - outputAddress;
562 }
563
564 public static final class Table {
565 int log2Size;
566 final int[] newState;
567 final byte[] symbol;
568 final byte[] numberOfBits;
569
570 public Table(final int log2Capacity) {
571 final int capacity = 1 << log2Capacity;
572 newState = new int[capacity];
573 symbol = new byte[capacity];
574 numberOfBits = new byte[capacity];
575 }
576
577 public Table(final int log2Size, final int[] newState, final byte[] symbol,
578 final byte[] numberOfBits) {
579 final int size = 1 << log2Size;
580 if (newState.length != size || symbol.length != size ||
581 numberOfBits.length != size) {
582 throw new IllegalArgumentException(
583 "Expected arrays to match provided size");
584 }
585
586 this.log2Size = log2Size;
587 this.newState = newState;
588 this.symbol = symbol;
589 this.numberOfBits = numberOfBits;
590 }
591 }
592 }