View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import static org.waarp.compress.zstdunsafe.BitInputStream.*;
37  import static sun.misc.Unsafe.*;
38  
39  class FiniteStateEntropy {
40    public static final int MAX_SYMBOL = 255;
41    public static final int MAX_TABLE_LOG = 12;
42    public static final int MIN_TABLE_LOG = 5;
43  
44    private static final int[] REST_TO_BEAT =
45        new int[] { 0, 473195, 504333, 520860, 550000, 700000, 750000, 830000 };
46    private static final short UNASSIGNED = -2;
47    public static final String OUTPUT_BUFFER_TOO_SMALL =
48        "Output buffer too small";
49  
50    private FiniteStateEntropy() {
51    }
52  
53    public static int decompress(final FiniteStateEntropy.Table table,
54                                 final Object inputBase, final long inputAddress,
55                                 final long inputLimit,
56                                 final byte[] outputBuffer) {
57      final long outputAddress = ARRAY_BYTE_BASE_OFFSET;
58      final long outputLimit = outputAddress + outputBuffer.length;
59  
60      long output = outputAddress;
61  
62      // initialize bit stream
63      final BitInputStream.Initializer initializer =
64          new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
65      initializer.initialize();
66      int bitsConsumed = initializer.getBitsConsumed();
67      long currentAddress = initializer.getCurrentAddress();
68      long bits = initializer.getBits();
69  
70      // initialize first FSE stream
71      int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
72      bitsConsumed += table.log2Size;
73  
74      BitInputStream.Loader loader =
75          new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
76                                    bitsConsumed);
77      loader.load();
78      bits = loader.getBits();
79      bitsConsumed = loader.getBitsConsumed();
80      currentAddress = loader.getCurrentAddress();
81  
82      // initialize second FSE stream
83      int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
84      bitsConsumed += table.log2Size;
85  
86      loader =
87          new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits,
88                                    bitsConsumed);
89      loader.load();
90      bits = loader.getBits();
91      bitsConsumed = loader.getBitsConsumed();
92      currentAddress = loader.getCurrentAddress();
93  
94      final byte[] symbols = table.symbol;
95      final byte[] numbersOfBits = table.numberOfBits;
96      final int[] newStates = table.newState;
97  
98      // decode 4 symbols per loop
99      while (output <= outputLimit - 4) {
100       int numberOfBits;
101 
102       UnsafeUtil.UNSAFE.putByte(outputBuffer, output, symbols[state1]);
103       numberOfBits = numbersOfBits[state1];
104       state1 = (int) (newStates[state1] +
105                       peekBits(bitsConsumed, bits, numberOfBits));
106       bitsConsumed += numberOfBits;
107 
108       UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 1, symbols[state2]);
109       numberOfBits = numbersOfBits[state2];
110       state2 = (int) (newStates[state2] +
111                       peekBits(bitsConsumed, bits, numberOfBits));
112       bitsConsumed += numberOfBits;
113 
114       UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 2, symbols[state1]);
115       numberOfBits = numbersOfBits[state1];
116       state1 = (int) (newStates[state1] +
117                       peekBits(bitsConsumed, bits, numberOfBits));
118       bitsConsumed += numberOfBits;
119 
120       UnsafeUtil.UNSAFE.putByte(outputBuffer, output + 3, symbols[state2]);
121       numberOfBits = numbersOfBits[state2];
122       state2 = (int) (newStates[state2] +
123                       peekBits(bitsConsumed, bits, numberOfBits));
124       bitsConsumed += numberOfBits;
125 
126       output += Constants.SIZE_OF_INT;
127 
128       loader =
129           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
130                                     bits, bitsConsumed);
131       final boolean done = loader.load();
132       bitsConsumed = loader.getBitsConsumed();
133       bits = loader.getBits();
134       currentAddress = loader.getCurrentAddress();
135       if (done) {
136         break;
137       }
138     }
139 
140     while (true) {
141       Util.verify(output <= outputLimit - 2, inputAddress,
142                   "Output buffer is too small");
143       UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
144       final int numberOfBits = numbersOfBits[state1];
145       state1 = (int) (newStates[state1] +
146                       peekBits(bitsConsumed, bits, numberOfBits));
147       bitsConsumed += numberOfBits;
148 
149       loader =
150           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
151                                     bits, bitsConsumed);
152       loader.load();
153       bitsConsumed = loader.getBitsConsumed();
154       bits = loader.getBits();
155       currentAddress = loader.getCurrentAddress();
156 
157       if (loader.isOverflow()) {
158         UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
159         break;
160       }
161 
162       Util.verify(output <= outputLimit - 2, inputAddress,
163                   "Output buffer is too small");
164       UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state2]);
165       final int numberOfBits1 = numbersOfBits[state2];
166       state2 = (int) (newStates[state2] +
167                       peekBits(bitsConsumed, bits, numberOfBits1));
168       bitsConsumed += numberOfBits1;
169 
170       loader =
171           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
172                                     bits, bitsConsumed);
173       loader.load();
174       bitsConsumed = loader.getBitsConsumed();
175       bits = loader.getBits();
176       currentAddress = loader.getCurrentAddress();
177 
178       if (loader.isOverflow()) {
179         UnsafeUtil.UNSAFE.putByte(outputBuffer, output++, symbols[state1]);
180         break;
181       }
182     }
183 
184     return (int) (output - outputAddress);
185   }
186 
187   public static int compress(final Object outputBase, final long outputAddress,
188                              final int outputSize, final byte[] input,
189                              final int inputSize,
190                              final FseCompressionTable table) {
191     return compress(outputBase, outputAddress, outputSize, input,
192                     ARRAY_BYTE_BASE_OFFSET, inputSize, table);
193   }
194 
195   public static int compress(final Object outputBase, final long outputAddress,
196                              final int outputSize, final Object inputBase,
197                              final long inputAddress, int inputSize,
198                              final FseCompressionTable table) {
199     Util.checkArgument(outputSize >= Constants.SIZE_OF_LONG,
200                        OUTPUT_BUFFER_TOO_SMALL);
201 
202     long input = inputAddress + inputSize;
203 
204     if (inputSize <= 2) {
205       return 0;
206     }
207 
208     final BitOutputStream stream =
209         new BitOutputStream(outputBase, outputAddress, outputSize);
210 
211     int state1;
212     int state2;
213 
214     if ((inputSize & 1) != 0) {
215       input--;
216       state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
217 
218       input--;
219       state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
220 
221       input--;
222       state1 = table.encode(stream, state1,
223                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
224 
225       stream.flush();
226     } else {
227       input--;
228       state2 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
229 
230       input--;
231       state1 = table.begin(UnsafeUtil.UNSAFE.getByte(inputBase, input));
232     }
233 
234     // join to mod 4
235     inputSize -= 2;
236 
237     if ((inputSize & 2) != 0) {  /* test bit 2 */
238       input--;
239       state2 = table.encode(stream, state2,
240                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
241 
242       input--;
243       state1 = table.encode(stream, state1,
244                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
245 
246       stream.flush();
247     }
248 
249     // 2 or 4 encoding per loop
250     while (input > inputAddress) {
251       input--;
252       state2 = table.encode(stream, state2,
253                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
254 
255       input--;
256       state1 = table.encode(stream, state1,
257                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
258 
259       input--;
260       state2 = table.encode(stream, state2,
261                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
262 
263       input--;
264       state1 = table.encode(stream, state1,
265                             UnsafeUtil.UNSAFE.getByte(inputBase, input));
266 
267       stream.flush();
268     }
269 
270     table.finish(stream, state2);
271     table.finish(stream, state1);
272 
273     return stream.close();
274   }
275 
276   public static int optimalTableLog(final int maxTableLog, final int inputSize,
277                                     final int maxSymbol) {
278     if (inputSize <= 1) {
279       throw new IllegalArgumentException(); // not supported. Use RLE instead
280     }
281 
282     int result = maxTableLog;
283 
284     result = Math.min(result, Util.highestBit((inputSize - 1)) -
285                               2); // we may be able to reduce accuracy if input is small
286 
287     // Need a minimum to safely represent all symbol values
288     result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
289 
290     result = Math.max(result, MIN_TABLE_LOG);
291     result = Math.min(result, MAX_TABLE_LOG);
292 
293     return result;
294   }
295 
296   public static void normalizeCounts(final short[] normalizedCounts,
297                                      final int tableLog, final int[] counts,
298                                      final int total, final int maxSymbol) {
299     Util.checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
300     Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
301     Util.checkArgument(tableLog >= Util.minTableLog(total, maxSymbol),
302                        "FSE table size too small");
303 
304     final long scale = 62L - tableLog;
305     final long step = (1L << 62) / total;
306     final long vstep = 1L << (scale - 20);
307 
308     int stillToDistribute = 1 << tableLog;
309 
310     int largest = 0;
311     short largestProbability = 0;
312     final int lowThreshold = total >>> tableLog;
313 
314     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
315       if (counts[symbol] == total) {
316         throw new IllegalArgumentException(); // TODO: should have been RLE-compressed by upper layers
317       }
318       if (counts[symbol] == 0) {
319         normalizedCounts[symbol] = 0;
320         continue;
321       }
322       if (counts[symbol] <= lowThreshold) {
323         normalizedCounts[symbol] = -1;
324         stillToDistribute--;
325       } else {
326         short probability = (short) ((counts[symbol] * step) >>> scale);
327         if (probability < 8) {
328           final long restToBeat = vstep * REST_TO_BEAT[probability];
329           final long delta =
330               counts[symbol] * step - (((long) probability) << scale);
331           if (delta > restToBeat) {
332             probability++;
333           }
334         }
335         if (probability > largestProbability) {
336           largestProbability = probability;
337           largest = symbol;
338         }
339         normalizedCounts[symbol] = probability;
340         stillToDistribute -= probability;
341       }
342     }
343 
344     if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
345       // corner case. Need another normalization method
346       // TODO size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue);
347       normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
348     } else {
349       normalizedCounts[largest] += (short) stillToDistribute;
350     }
351 
352   }
353 
354   private static void normalizeCounts2(final short[] normalizedCounts,
355                                        final int tableLog, final int[] counts,
356                                        int total, final int maxSymbol) {
357     int distributed = 0;
358 
359     final int lowThreshold = total >>>
360                              tableLog; // minimum count below which frequency in the normalized table is "too small" (~ < 1)
361     int lowOne = (total * 3) >>> (tableLog +
362                                   1); // 1.5 * lowThreshold. If count in (lowThreshold, lowOne] => assign frequency 1
363 
364     for (int i = 0; i <= maxSymbol; i++) {
365       if (counts[i] == 0) {
366         normalizedCounts[i] = 0;
367       } else if (counts[i] <= lowThreshold) {
368         normalizedCounts[i] = -1;
369         distributed++;
370         total -= counts[i];
371       } else if (counts[i] <= lowOne) {
372         normalizedCounts[i] = 1;
373         distributed++;
374         total -= counts[i];
375       } else {
376         normalizedCounts[i] = UNASSIGNED;
377       }
378     }
379 
380     final int normalizationFactor = 1 << tableLog;
381     int toDistribute = normalizationFactor - distributed;
382 
383     if ((total / toDistribute) > lowOne) {
384       /* risk of rounding to zero */
385       lowOne = ((total * 3) / (toDistribute * 2));
386       for (int i = 0; i <= maxSymbol; i++) {
387         if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
388           normalizedCounts[i] = 1;
389           distributed++;
390           total -= counts[i];
391         }
392       }
393       toDistribute = normalizationFactor - distributed;
394     }
395 
396     if (distributed == maxSymbol + 1) {
397       // all values are pretty poor
398       // probably incompressible data (should have already been detected);
399       // find max, then give all remaining points to max
400       int maxValue = 0;
401       int maxCount = 0;
402       for (int i = 0; i <= maxSymbol; i++) {
403         if (counts[i] > maxCount) {
404           maxValue = i;
405           maxCount = counts[i];
406         }
407       }
408       normalizedCounts[maxValue] += (short) toDistribute;
409       return;
410     }
411 
412     if (total == 0) {
413       // all of the symbols were low enough for the lowOne or lowThreshold
414       for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
415         if (normalizedCounts[i] > 0) {
416           toDistribute--;
417           normalizedCounts[i]++;
418         }
419       }
420       return;
421     }
422 
423     // TODO: simplify/document this code
424     final long vStepLog = 62 - tableLog;
425     final long mid = (1L << (vStepLog - 1)) - 1;
426     final long rStep = (((1L << vStepLog) * toDistribute) + mid) /
427                        total;   /* scale on remaining */
428     long tmpTotal = mid;
429     for (int i = 0; i <= maxSymbol; i++) {
430       if (normalizedCounts[i] == UNASSIGNED) {
431         final long end = tmpTotal + (counts[i] * rStep);
432         final int sStart = (int) (tmpTotal >>> vStepLog);
433         final int sEnd = (int) (end >>> vStepLog);
434         final int weight = sEnd - sStart;
435 
436         if (weight < 1) {
437           throw new AssertionError();
438         }
439         normalizedCounts[i] = (short) weight;
440         tmpTotal = end;
441       }
442     }
443 
444   }
445 
446   public static int writeNormalizedCounts(final Object outputBase,
447                                           final long outputAddress,
448                                           final int outputSize,
449                                           final short[] normalizedCounts,
450                                           final int maxSymbol,
451                                           final int tableLog) {
452     Util.checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
453     Util.checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
454 
455     long output = outputAddress;
456     final long outputLimit = outputAddress + outputSize;
457 
458     final int tableSize = 1 << tableLog;
459 
460     int bitCount = 0;
461 
462     // encode table size
463     int bitStream = (tableLog - MIN_TABLE_LOG);
464     bitCount += 4;
465 
466     int remaining = tableSize + 1; // +1 for extra accuracy
467     int threshold = tableSize;
468     int tableBitCount = tableLog + 1;
469 
470     int symbol = 0;
471 
472     boolean previousIs0 = false;
473     while (remaining > 1) {
474       if (previousIs0) {
475         // From RFC 8478, section 4.1.1:
476         //   When a symbol has a probability of zero, it is followed by a 2-bit
477         //   repeat flag.  This repeat flag tells how many probabilities of zeroes
478         //   follow the current one.  It provides a number ranging from 0 to 3.
479         //   If it is a 3, another 2-bit repeat flag follows, and so on.
480         int start = symbol;
481 
482         // find run of symbols with count 0
483         while (normalizedCounts[symbol] == 0) {
484           symbol++;
485         }
486 
487         // encode in batches if 8 repeat sequences in one shot (representing 24 symbols total)
488         while (symbol >= start + 24) {
489           start += 24;
490           bitStream |= (0xffff << bitCount);
491           Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
492                              OUTPUT_BUFFER_TOO_SMALL);
493 
494           UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
495           output += Constants.SIZE_OF_SHORT;
496 
497           // flush now, so no need to increase bitCount by 16
498           bitStream >>>= Short.SIZE;
499         }
500 
501         // encode remaining in batches of 3 symbols
502         while (symbol >= start + 3) {
503           start += 3;
504           bitStream |= 0x3 << bitCount;
505           bitCount += 2;
506         }
507 
508         // encode tail
509         bitStream |= (symbol - start) << bitCount;
510         bitCount += 2;
511 
512         // flush bitstream if necessary
513         if (bitCount > 16) {
514           Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
515                              OUTPUT_BUFFER_TOO_SMALL);
516 
517           UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
518           output += Constants.SIZE_OF_SHORT;
519 
520           bitStream >>>= Short.SIZE;
521           bitCount -= Short.SIZE;
522         }
523       }
524 
525       int count = normalizedCounts[symbol++];
526       final int max = (2 * threshold - 1) - remaining;
527       remaining -= count < 0? -count : count;
528       count++;   /* +1 for extra accuracy */
529       if (count >= threshold) {
530         count += max;
531       }
532       bitStream |= count << bitCount;
533       bitCount += tableBitCount;
534       bitCount -= (count < max? 1 : 0);
535       previousIs0 = (count == 1);
536 
537       if (remaining < 1) {
538         throw new AssertionError();
539       }
540 
541       while (remaining < threshold) {
542         tableBitCount--;
543         threshold >>= 1;
544       }
545 
546       // flush bitstream if necessary
547       if (bitCount > 16) {
548         Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
549                            OUTPUT_BUFFER_TOO_SMALL);
550 
551         UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
552         output += Constants.SIZE_OF_SHORT;
553 
554         bitStream >>>= Short.SIZE;
555         bitCount -= Short.SIZE;
556       }
557     }
558 
559     // flush remaining bitstream
560     Util.checkArgument(output + Constants.SIZE_OF_SHORT <= outputLimit,
561                        OUTPUT_BUFFER_TOO_SMALL);
562     UnsafeUtil.UNSAFE.putShort(outputBase, output, (short) bitStream);
563     output += (bitCount + 7) / 8;
564 
565     Util.checkArgument(symbol <= maxSymbol + 1, "Error"); // TODO
566 
567     return (int) (output - outputAddress);
568   }
569 
570   public static final class Table {
571     final int[] newState;
572     final byte[] symbol;
573     final byte[] numberOfBits;
574     int log2Size;
575 
576     public Table(final int log2Capacity) {
577       final int capacity = 1 << log2Capacity;
578       newState = new int[capacity];
579       symbol = new byte[capacity];
580       numberOfBits = new byte[capacity];
581     }
582 
583     public Table(final int log2Size, final int[] newState, final byte[] symbol,
584                  final byte[] numberOfBits) {
585       final int size = 1 << log2Size;
586       if (newState.length != size || symbol.length != size ||
587           numberOfBits.length != size) {
588         throw new IllegalArgumentException(
589             "Expected arrays to match provided size");
590       }
591 
592       this.log2Size = log2Size;
593       this.newState = newState;
594       this.symbol = symbol;
595       this.numberOfBits = numberOfBits;
596     }
597   }
598 }