View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import static org.waarp.compress.zstdsafe.BitInputStream.*;
37  import static org.waarp.compress.zstdsafe.Constants.*;
38  import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
39  import static org.waarp.compress.zstdsafe.Util.*;
40  
41  class FiniteStateEntropy {
42    public static final int MAX_SYMBOL = 255;
43    public static final int MAX_TABLE_LOG = 12;
44    public static final int MIN_TABLE_LOG = 5;
45  
46    private static final int[] REST_TO_BEAT =
47        new int[] { 0, 473195, 504333, 520860, 550000, 700000, 750000, 830000 };
48    private static final short UNASSIGNED = -2;
49    public static final String OUTPUT_BUFFER_TOO_SMALL =
50        "Output buffer too small";
51  
52    private FiniteStateEntropy() {
53    }
54  
55    public static int decompress(final FiniteStateEntropy.Table table,
56                                 final byte[] inputBase, final int inputAddress,
57                                 final int inputLimit,
58                                 final byte[] outputBuffer) {
59      final int outputAddress = 0;
60      final int outputLimit = outputAddress + outputBuffer.length;
61  
62      int output = outputAddress;
63  
64      // initialize bit stream
65      final BitInputStream.Initializer initializer =
66          new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
67      initializer.initialize();
68      int bitsConsumed = initializer.getBitsConsumed();
69      int currentAddress = initializer.getCurrentAddress();
70      long bits = initializer.getBits();
71  
72      // initialize first FSE stream
73      int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
74      bitsConsumed += table.log2Size;
75  
76      BitInputStream.Loader loader =
77          new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
78                                    bitsConsumed);
79      loader.load();
80      bits = loader.getBits();
81      bitsConsumed = loader.getBitsConsumed();
82      currentAddress = loader.getCurrentAddress();
83  
84      // initialize second FSE stream
85      int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
86      bitsConsumed += table.log2Size;
87  
88      loader =
89          new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
90                                    bitsConsumed);
91      loader.load();
92      bits = loader.getBits();
93      bitsConsumed = loader.getBitsConsumed();
94      currentAddress = loader.getCurrentAddress();
95  
96      final byte[] symbols = table.symbol;
97      final byte[] numbersOfBits = table.numberOfBits;
98      final int[] newStates = table.newState;
99  
100     // decode 4 symbols per loop
101     while (output <= outputLimit - 4) {
102       int numberOfBits;
103 
104       outputBuffer[output] = symbols[state1];
105       numberOfBits = numbersOfBits[state1];
106       state1 = (int) (newStates[state1] +
107                       peekBits(bitsConsumed, bits, numberOfBits));
108       bitsConsumed += numberOfBits;
109 
110       outputBuffer[output + 1] = symbols[state2];
111       numberOfBits = numbersOfBits[state2];
112       state2 = (int) (newStates[state2] +
113                       peekBits(bitsConsumed, bits, numberOfBits));
114       bitsConsumed += numberOfBits;
115 
116       outputBuffer[output + 2] = symbols[state1];
117       numberOfBits = numbersOfBits[state1];
118       state1 = (int) (newStates[state1] +
119                       peekBits(bitsConsumed, bits, numberOfBits));
120       bitsConsumed += numberOfBits;
121 
122       outputBuffer[output + 3] = symbols[state2];
123       numberOfBits = numbersOfBits[state2];
124       state2 = (int) (newStates[state2] +
125                       peekBits(bitsConsumed, bits, numberOfBits));
126       bitsConsumed += numberOfBits;
127 
128       output += SIZE_OF_INT;
129 
130       loader =
131           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
132                                     bits, bitsConsumed);
133       final boolean done = loader.load();
134       bitsConsumed = loader.getBitsConsumed();
135       bits = loader.getBits();
136       currentAddress = loader.getCurrentAddress();
137       if (done) {
138         break;
139       }
140     }
141 
142     while (true) {
143       verify(output <= outputLimit - 2, inputAddress,
144              "Output buffer is too small");
145       outputBuffer[output++] = symbols[state1];
146       final int numberOfBits = numbersOfBits[state1];
147       state1 = (int) (newStates[state1] +
148                       peekBits(bitsConsumed, bits, numberOfBits));
149       bitsConsumed += numberOfBits;
150 
151       loader =
152           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
153                                     bits, bitsConsumed);
154       loader.load();
155       bitsConsumed = loader.getBitsConsumed();
156       bits = loader.getBits();
157       currentAddress = loader.getCurrentAddress();
158 
159       if (loader.isOverflow()) {
160         outputBuffer[output++] = symbols[state2];
161         break;
162       }
163 
164       verify(output <= outputLimit - 2, inputAddress,
165              "Output buffer is too small");
166       outputBuffer[output++] = symbols[state2];
167       final int numberOfBits1 = numbersOfBits[state2];
168       state2 = (int) (newStates[state2] +
169                       peekBits(bitsConsumed, bits, numberOfBits1));
170       bitsConsumed += numberOfBits1;
171 
172       loader =
173           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
174                                     bits, bitsConsumed);
175       loader.load();
176       bitsConsumed = loader.getBitsConsumed();
177       bits = loader.getBits();
178       currentAddress = loader.getCurrentAddress();
179 
180       if (loader.isOverflow()) {
181         outputBuffer[output++] = symbols[state1];
182         break;
183       }
184     }
185 
186     return output - outputAddress;
187   }
188 
189   public static int compress(final byte[] outputBase, final int outputAddress,
190                              final int outputSize, final byte[] input,
191                              final int inputSize,
192                              final FseCompressionTable table) {
193     return compress(outputBase, outputAddress, outputSize, input, 0, inputSize,
194                     table);
195   }
196 
197   public static int compress(final byte[] outputBase, final int outputAddress,
198                              final int outputSize, final byte[] inputBase,
199                              final int inputAddress, int inputSize,
200                              final FseCompressionTable table) {
201     checkArgument(outputSize >= SIZE_OF_LONG, OUTPUT_BUFFER_TOO_SMALL);
202 
203     int input = inputAddress + inputSize;
204 
205     if (inputSize <= 2) {
206       return 0;
207     }
208 
209     final BitOutputStream stream =
210         new BitOutputStream(outputBase, outputAddress, outputSize);
211 
212     int state1;
213     int state2;
214 
215     if ((inputSize & 1) != 0) {
216       input--;
217       state1 = table.begin(inputBase[input]);
218 
219       input--;
220       state2 = table.begin(inputBase[input]);
221 
222       input--;
223       state1 = table.encode(stream, state1, inputBase[input]);
224 
225       stream.flush();
226     } else {
227       input--;
228       state2 = table.begin(inputBase[input]);
229 
230       input--;
231       state1 = table.begin(inputBase[input]);
232     }
233 
234     // join to mod 4
235     inputSize -= 2;
236 
237     if ((inputSize & 2) != 0) {  /* test bit 2 */
238       input--;
239       state2 = table.encode(stream, state2, inputBase[input]);
240 
241       input--;
242       state1 = table.encode(stream, state1, inputBase[input]);
243 
244       stream.flush();
245     }
246 
247     // 2 or 4 encoding per loop
248     while (input > inputAddress) {
249       input--;
250       state2 = table.encode(stream, state2, inputBase[input]);
251 
252       input--;
253       state1 = table.encode(stream, state1, inputBase[input]);
254 
255       input--;
256       state2 = table.encode(stream, state2, inputBase[input]);
257 
258       input--;
259       state1 = table.encode(stream, state1, inputBase[input]);
260 
261       stream.flush();
262     }
263 
264     table.finish(stream, state2);
265     table.finish(stream, state1);
266 
267     return stream.close();
268   }
269 
270   public static int optimalTableLog(final int maxTableLog, final int inputSize,
271                                     final int maxSymbol) {
272     if (inputSize <= 1) {
273       throw new IllegalArgumentException(); // not supported. Use RLE instead
274     }
275 
276     int result = maxTableLog;
277 
278     result = Math.min(result, Util.highestBit((inputSize - 1)) -
279                               2); // we may be able to reduce accuracy if input is small
280 
281     // Need a minimum to safely represent all symbol values
282     result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
283 
284     result = Math.max(result, MIN_TABLE_LOG);
285     result = Math.min(result, MAX_TABLE_LOG);
286 
287     return result;
288   }
289 
290   public static void normalizeCounts(final short[] normalizedCounts,
291                                      final int tableLog, final int[] counts,
292                                      final int total, final int maxSymbol) {
293     checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
294     checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
295     checkArgument(tableLog >= Util.minTableLog(total, maxSymbol),
296                   "FSE table size too small");
297 
298     final int scale = 62 - tableLog;
299     final long step = (1L << 62) / total;
300     final long vstep = 1L << (scale - 20);
301 
302     int stillToDistribute = 1 << tableLog;
303 
304     int largest = 0;
305     short largestProbability = 0;
306     final int lowThreshold = total >>> tableLog;
307 
308     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
309       if (counts[symbol] == total) {
310         throw new IllegalArgumentException(); // TODO: should have been RLE-compressed by upper layers
311       }
312       if (counts[symbol] == 0) {
313         normalizedCounts[symbol] = 0;
314         continue;
315       }
316       if (counts[symbol] <= lowThreshold) {
317         normalizedCounts[symbol] = -1;
318         stillToDistribute--;
319       } else {
320         short probability = (short) ((counts[symbol] * step) >>> scale);
321         if (probability < 8) {
322           final long restToBeat = vstep * REST_TO_BEAT[probability];
323           final long delta =
324               counts[symbol] * step - (((long) probability) << scale);
325           if (delta > restToBeat) {
326             probability++;
327           }
328         }
329         if (probability > largestProbability) {
330           largestProbability = probability;
331           largest = symbol;
332         }
333         normalizedCounts[symbol] = probability;
334         stillToDistribute -= probability;
335       }
336     }
337 
338     if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
339       // corner case. Need another normalization method
340       // TODO size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue);
341       normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
342     } else {
343       normalizedCounts[largest] += (short) stillToDistribute;
344     }
345 
346   }
347 
348   private static void normalizeCounts2(final short[] normalizedCounts,
349                                        final int tableLog, final int[] counts,
350                                        int total, final int maxSymbol) {
351     int distributed = 0;
352 
353     final int lowThreshold = total >>>
354                              tableLog; // minimum count below which frequency in the normalized table is "too small" (~ < 1)
355     int lowOne = (total * 3) >>> (tableLog +
356                                   1); // 1.5 * lowThreshold. If count in (lowThreshold, lowOne] => assign frequency 1
357 
358     for (int i = 0; i <= maxSymbol; i++) {
359       if (counts[i] == 0) {
360         normalizedCounts[i] = 0;
361       } else if (counts[i] <= lowThreshold) {
362         normalizedCounts[i] = -1;
363         distributed++;
364         total -= counts[i];
365       } else if (counts[i] <= lowOne) {
366         normalizedCounts[i] = 1;
367         distributed++;
368         total -= counts[i];
369       } else {
370         normalizedCounts[i] = UNASSIGNED;
371       }
372     }
373 
374     final int normalizationFactor = 1 << tableLog;
375     int toDistribute = normalizationFactor - distributed;
376 
377     if ((total / toDistribute) > lowOne) {
378       /* risk of rounding to zero */
379       lowOne = ((total * 3) / (toDistribute * 2));
380       for (int i = 0; i <= maxSymbol; i++) {
381         if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
382           normalizedCounts[i] = 1;
383           distributed++;
384           total -= counts[i];
385         }
386       }
387       toDistribute = normalizationFactor - distributed;
388     }
389 
390     if (distributed == maxSymbol + 1) {
391       // all values are pretty poor;
392       // probably incompressible data (should have already been detected);
393       // find max, then give all remaining points to max
394       int maxValue = 0;
395       int maxCount = 0;
396       for (int i = 0; i <= maxSymbol; i++) {
397         if (counts[i] > maxCount) {
398           maxValue = i;
399           maxCount = counts[i];
400         }
401       }
402       normalizedCounts[maxValue] += (short) toDistribute;
403       return;
404     }
405 
406     if (total == 0) {
407       // all of the symbols were low enough for the lowOne or lowThreshold
408       for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
409         if (normalizedCounts[i] > 0) {
410           toDistribute--;
411           normalizedCounts[i]++;
412         }
413       }
414       return;
415     }
416 
417     // TODO: simplify/document this code
418     final long vStepLog = 62 - tableLog;
419     final long mid = (1L << (vStepLog - 1)) - 1;
420     final long rStep = (((1L << vStepLog) * toDistribute) + mid) /
421                        total;   /* scale on remaining */
422     long tmpTotal = mid;
423     for (int i = 0; i <= maxSymbol; i++) {
424       if (normalizedCounts[i] == UNASSIGNED) {
425         final long end = tmpTotal + (counts[i] * rStep);
426         final int sStart = (int) (tmpTotal >>> vStepLog);
427         final int sEnd = (int) (end >>> vStepLog);
428         final int weight = sEnd - sStart;
429 
430         if (weight < 1) {
431           throw new AssertionError();
432         }
433         normalizedCounts[i] = (short) weight;
434         tmpTotal = end;
435       }
436     }
437 
438   }
439 
440   public static int writeNormalizedCounts(final byte[] outputBase,
441                                           final int outputAddress,
442                                           final int outputSize,
443                                           final short[] normalizedCounts,
444                                           final int maxSymbol,
445                                           final int tableLog) {
446     checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
447     checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
448 
449     int output = outputAddress;
450     final int outputLimit = outputAddress + outputSize;
451 
452     final int tableSize = 1 << tableLog;
453 
454     int bitCount = 0;
455 
456     // encode table size
457     int bitStream = (tableLog - MIN_TABLE_LOG);
458     bitCount += 4;
459 
460     int remaining = tableSize + 1; // +1 for extra accuracy
461     int threshold = tableSize;
462     int tableBitCount = tableLog + 1;
463 
464     int symbol = 0;
465 
466     boolean previousIs0 = false;
467     while (remaining > 1) {
468       if (previousIs0) {
469         // From RFC 8478, section 4.1.1:
470         //   When a symbol has a probability of zero, it is followed by a 2-bit
471         //   repeat flag.  This repeat flag tells how many probabilities of zeroes
472         //   follow the current one.  It provides a number ranging from 0 to 3.
473         //   If it is a 3, another 2-bit repeat flag follows, and so on.
474         int start = symbol;
475 
476         // find run of symbols with count 0
477         while (normalizedCounts[symbol] == 0) {
478           symbol++;
479         }
480 
481         // encode in batches if 8 repeat sequences in one shot (representing 24 symbols total)
482         while (symbol >= start + 24) {
483           start += 24;
484           bitStream |= (0xFFFF << bitCount);
485           checkArgument(output + SIZE_OF_SHORT <= outputLimit,
486                         OUTPUT_BUFFER_TOO_SMALL);
487 
488           putShort(outputBase, output, (short) bitStream);
489           output += SIZE_OF_SHORT;
490 
491           // flush now, so no need to increase bitCount by 16
492           bitStream >>>= Short.SIZE;
493         }
494 
495         // encode remaining in batches of 3 symbols
496         while (symbol >= start + 3) {
497           start += 3;
498           bitStream |= 0x3 << bitCount;
499           bitCount += 2;
500         }
501 
502         // encode tail
503         bitStream |= (symbol - start) << bitCount;
504         bitCount += 2;
505 
506         // flush bitstream if necessary
507         if (bitCount > 16) {
508           checkArgument(output + SIZE_OF_SHORT <= outputLimit,
509                         OUTPUT_BUFFER_TOO_SMALL);
510 
511           putShort(outputBase, output, (short) bitStream);
512           output += SIZE_OF_SHORT;
513 
514           bitStream >>>= Short.SIZE;
515           bitCount -= Short.SIZE;
516         }
517       }
518 
519       int count = normalizedCounts[symbol++];
520       final int max = (2 * threshold - 1) - remaining;
521       remaining -= count < 0? -count : count;
522       count++;   /* +1 for extra accuracy */
523       if (count >= threshold) {
524         count += max;
525       }
526       bitStream |= count << bitCount;
527       bitCount += tableBitCount;
528       bitCount -= (count < max? 1 : 0);
529       previousIs0 = (count == 1);
530 
531       if (remaining < 1) {
532         throw new AssertionError();
533       }
534 
535       while (remaining < threshold) {
536         tableBitCount--;
537         threshold >>= 1;
538       }
539 
540       // flush bitstream if necessary
541       if (bitCount > 16) {
542         checkArgument(output + SIZE_OF_SHORT <= outputLimit,
543                       OUTPUT_BUFFER_TOO_SMALL);
544 
545         putShort(outputBase, output, (short) bitStream);
546         output += SIZE_OF_SHORT;
547 
548         bitStream >>>= Short.SIZE;
549         bitCount -= Short.SIZE;
550       }
551     }
552 
553     // flush remaining bitstream
554     checkArgument(output + SIZE_OF_SHORT <= outputLimit,
555                   OUTPUT_BUFFER_TOO_SMALL);
556     putShort(outputBase, output, (short) bitStream);
557     output += (bitCount + 7) / 8;
558 
559     checkArgument(symbol <= maxSymbol + 1, "Error"); // TODO
560 
561     return output - outputAddress;
562   }
563 
564   public static final class Table {
565     int log2Size;
566     final int[] newState;
567     final byte[] symbol;
568     final byte[] numberOfBits;
569 
570     public Table(final int log2Capacity) {
571       final int capacity = 1 << log2Capacity;
572       newState = new int[capacity];
573       symbol = new byte[capacity];
574       numberOfBits = new byte[capacity];
575     }
576 
577     public Table(final int log2Size, final int[] newState, final byte[] symbol,
578                  final byte[] numberOfBits) {
579       final int size = 1 << log2Size;
580       if (newState.length != size || symbol.length != size ||
581           numberOfBits.length != size) {
582         throw new IllegalArgumentException(
583             "Expected arrays to match provided size");
584       }
585 
586       this.log2Size = log2Size;
587       this.newState = newState;
588       this.symbol = symbol;
589       this.numberOfBits = numberOfBits;
590     }
591   }
592 }