View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import java.util.Arrays;
37  
38  import static org.waarp.compress.zstdunsafe.BitInputStream.*;
39  import static org.waarp.compress.zstdunsafe.Constants.*;
40  
41  class Huffman {
42    public static final int MAX_SYMBOL = 255;
43    public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
44  
45    public static final int MAX_TABLE_LOG = 12;
46    public static final int MIN_TABLE_LOG = 5;
47    public static final int MAX_FSE_TABLE_LOG = 6;
48    public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
49    public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
50  
51    // stats
52    private final byte[] weights = new byte[MAX_SYMBOL + 1];
53    private final int[] ranks = new int[MAX_TABLE_LOG + 1];
54  
55    // table
56    private int tableLog = -1;
57    private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
58    private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
59  
60    private final FseTableReader reader = new FseTableReader();
61    private final FiniteStateEntropy.Table fseTable =
62        new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
63  
64    public boolean isLoaded() {
65      return tableLog != -1;
66    }
67  
68    public int readTable(final Object inputBase, final long inputAddress,
69                         final int size) {
70      Arrays.fill(ranks, 0);
71      long input = inputAddress;
72  
73      // read table header
74      Util.verify(size > 0, input, NOT_ENOUGH_INPUT_BYTES);
75      int inputSize = UnsafeUtil.UNSAFE.getByte(inputBase, input++) & 0xFF;
76  
77      final int outputSize;
78      if (inputSize >= 128) {
79        outputSize = inputSize - 127;
80        inputSize = ((outputSize + 1) / 2);
81  
82        Util.verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
83        Util.verify(true, input, INPUT_IS_CORRUPTED);
84  
85        for (int i = 0; i < outputSize; i += 2) {
86          final int value =
87              UnsafeUtil.UNSAFE.getByte(inputBase, input + i / 2) & 0xFF;
88          weights[i] = (byte) (value >>> 4);
89          weights[i + 1] = (byte) (value & 0xf);
90        }
91      } else {
92        Util.verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
93  
94        final long inputLimit = input + inputSize;
95        input += reader.readFseTable(fseTable, inputBase, input, inputLimit,
96                                     FiniteStateEntropy.MAX_SYMBOL,
97                                     MAX_FSE_TABLE_LOG);
98        outputSize =
99            FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit,
100                                         weights);
101     }
102 
103     int totalWeight = 0;
104     for (int i = 0; i < outputSize; i++) {
105       ranks[weights[i]]++;
106       totalWeight +=
107           (1 << weights[i]) >> 1;   // TODO same as 1 << (weights[n] - 1)?
108     }
109     Util.verify(totalWeight != 0, input, INPUT_IS_CORRUPTED);
110 
111     tableLog = Util.highestBit(totalWeight) + 1;
112     Util.verify(tableLog <= MAX_TABLE_LOG, input, INPUT_IS_CORRUPTED);
113 
114     final int total = 1 << tableLog;
115     final int rest = total - totalWeight;
116     Util.verify(Util.isPowerOf2(rest), input, INPUT_IS_CORRUPTED);
117 
118     final int lastWeight = Util.highestBit(rest) + 1;
119 
120     weights[outputSize] = (byte) lastWeight;
121     ranks[lastWeight]++;
122 
123     final int numberOfSymbols = outputSize + 1;
124 
125     // populate table
126     int nextRankStart = 0;
127     for (int i = 1; i < tableLog + 1; ++i) {
128       final int current = nextRankStart;
129       nextRankStart += ranks[i] << (i - 1);
130       ranks[i] = current;
131     }
132 
133     for (int n = 0; n < numberOfSymbols; n++) {
134       final int weight = weights[n];
135       final int length = (1 << weight) >> 1;  // TODO: 1 << (weight - 1) ??
136 
137       final byte symbol = (byte) n;
138       final byte numberOfBits = (byte) (tableLog + 1 - weight);
139       for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
140         symbols[i] = symbol;
141         numbersOfBits[i] = numberOfBits;
142       }
143       ranks[weight] += length;
144     }
145 
146     Util.verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input,
147                 INPUT_IS_CORRUPTED);
148 
149     return inputSize + 1;
150   }
151 
152   public void decodeSingleStream(final Object inputBase,
153                                  final long inputAddress, final long inputLimit,
154                                  final Object outputBase,
155                                  final long outputAddress,
156                                  final long outputLimit) {
157     final BitInputStream.Initializer initializer =
158         new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
159     initializer.initialize();
160 
161     long bits = initializer.getBits();
162     int bitsConsumed = initializer.getBitsConsumed();
163     long currentAddress = initializer.getCurrentAddress();
164 
165     final int tableLog1 = this.tableLog;
166     final byte[] numbersOfBits1 = this.numbersOfBits;
167     final byte[] symbols1 = this.symbols;
168 
169     // 4 symbols at a time
170     long output = outputAddress;
171     final long fastOutputLimit = outputLimit - 4;
172     while (output < fastOutputLimit) {
173       final BitInputStream.Loader loader =
174           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
175                                     bits, bitsConsumed);
176       final boolean done = loader.load();
177       bits = loader.getBits();
178       bitsConsumed = loader.getBitsConsumed();
179       currentAddress = loader.getCurrentAddress();
180       if (done) {
181         break;
182       }
183 
184       bitsConsumed =
185           decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog1,
186                        numbersOfBits1, symbols1);
187       bitsConsumed =
188           decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog1,
189                        numbersOfBits1, symbols1);
190       bitsConsumed =
191           decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog1,
192                        numbersOfBits1, symbols1);
193       bitsConsumed =
194           decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog1,
195                        numbersOfBits1, symbols1);
196       output += SIZE_OF_INT;
197     }
198 
199     decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits,
200                outputBase, output, outputLimit);
201   }
202 
203   public void decode4Streams(final Object inputBase, final long inputAddress,
204                              final long inputLimit, final Object outputBase,
205                              final long outputAddress, final long outputLimit) {
206     Util.verify(inputLimit - inputAddress >= 10, inputAddress,
207                 INPUT_IS_CORRUPTED); // jump table + 1 byte per stream
208 
209     final long start1 =
210         inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below
211     final long start2 =
212         start1 + (UnsafeUtil.UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF);
213     final long start3 = start2 + (UnsafeUtil.UNSAFE.getShort(inputBase,
214                                                              inputAddress + 2) &
215                                   0xFFFF);
216     final long start4 = start3 + (UnsafeUtil.UNSAFE.getShort(inputBase,
217                                                              inputAddress + 4) &
218                                   0xFFFF);
219 
220     BitInputStream.Initializer initializer =
221         new BitInputStream.Initializer(inputBase, start1, start2);
222     initializer.initialize();
223     int stream1bitsConsumed = initializer.getBitsConsumed();
224     long stream1currentAddress = initializer.getCurrentAddress();
225     long stream1bits = initializer.getBits();
226 
227     initializer = new BitInputStream.Initializer(inputBase, start2, start3);
228     initializer.initialize();
229     int stream2bitsConsumed = initializer.getBitsConsumed();
230     long stream2currentAddress = initializer.getCurrentAddress();
231     long stream2bits = initializer.getBits();
232 
233     initializer = new BitInputStream.Initializer(inputBase, start3, start4);
234     initializer.initialize();
235     int stream3bitsConsumed = initializer.getBitsConsumed();
236     long stream3currentAddress = initializer.getCurrentAddress();
237     long stream3bits = initializer.getBits();
238 
239     initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
240     initializer.initialize();
241     int stream4bitsConsumed = initializer.getBitsConsumed();
242     long stream4currentAddress = initializer.getCurrentAddress();
243     long stream4bits = initializer.getBits();
244 
245     final int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4);
246 
247     final long outputStart2 = outputAddress + segmentSize;
248     final long outputStart3 = outputStart2 + segmentSize;
249     final long outputStart4 = outputStart3 + segmentSize;
250 
251     long output1 = outputAddress;
252     long output2 = outputStart2;
253     long output3 = outputStart3;
254     long output4 = outputStart4;
255 
256     final long fastOutputLimit = outputLimit - 7;
257     final int tableLog1 = this.tableLog;
258     final byte[] numbersOfBits1 = this.numbersOfBits;
259     final byte[] symbols1 = this.symbols;
260 
261     while (output4 < fastOutputLimit) {
262       stream1bitsConsumed =
263           decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed,
264                        tableLog1, numbersOfBits1, symbols1);
265       stream2bitsConsumed =
266           decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed,
267                        tableLog1, numbersOfBits1, symbols1);
268       stream3bitsConsumed =
269           decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed,
270                        tableLog1, numbersOfBits1, symbols1);
271       stream4bitsConsumed =
272           decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed,
273                        tableLog1, numbersOfBits1, symbols1);
274 
275       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits,
276                                          stream1bitsConsumed, tableLog1,
277                                          numbersOfBits1, symbols1);
278       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits,
279                                          stream2bitsConsumed, tableLog1,
280                                          numbersOfBits1, symbols1);
281       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits,
282                                          stream3bitsConsumed, tableLog1,
283                                          numbersOfBits1, symbols1);
284       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits,
285                                          stream4bitsConsumed, tableLog1,
286                                          numbersOfBits1, symbols1);
287 
288       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits,
289                                          stream1bitsConsumed, tableLog1,
290                                          numbersOfBits1, symbols1);
291       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits,
292                                          stream2bitsConsumed, tableLog1,
293                                          numbersOfBits1, symbols1);
294       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits,
295                                          stream3bitsConsumed, tableLog1,
296                                          numbersOfBits1, symbols1);
297       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits,
298                                          stream4bitsConsumed, tableLog1,
299                                          numbersOfBits1, symbols1);
300 
301       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits,
302                                          stream1bitsConsumed, tableLog1,
303                                          numbersOfBits1, symbols1);
304       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits,
305                                          stream2bitsConsumed, tableLog1,
306                                          numbersOfBits1, symbols1);
307       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits,
308                                          stream3bitsConsumed, tableLog1,
309                                          numbersOfBits1, symbols1);
310       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits,
311                                          stream4bitsConsumed, tableLog1,
312                                          numbersOfBits1, symbols1);
313 
314       output1 += SIZE_OF_INT;
315       output2 += SIZE_OF_INT;
316       output3 += SIZE_OF_INT;
317       output4 += SIZE_OF_INT;
318 
319       BitInputStream.Loader loader =
320           new BitInputStream.Loader(inputBase, start1, stream1currentAddress,
321                                     stream1bits, stream1bitsConsumed);
322       boolean done = loader.load();
323       stream1bitsConsumed = loader.getBitsConsumed();
324       stream1bits = loader.getBits();
325       stream1currentAddress = loader.getCurrentAddress();
326 
327       if (done) {
328         break;
329       }
330 
331       loader =
332           new BitInputStream.Loader(inputBase, start2, stream2currentAddress,
333                                     stream2bits, stream2bitsConsumed);
334       done = loader.load();
335       stream2bitsConsumed = loader.getBitsConsumed();
336       stream2bits = loader.getBits();
337       stream2currentAddress = loader.getCurrentAddress();
338 
339       if (done) {
340         break;
341       }
342 
343       loader =
344           new BitInputStream.Loader(inputBase, start3, stream3currentAddress,
345                                     stream3bits, stream3bitsConsumed);
346       done = loader.load();
347       stream3bitsConsumed = loader.getBitsConsumed();
348       stream3bits = loader.getBits();
349       stream3currentAddress = loader.getCurrentAddress();
350       if (done) {
351         break;
352       }
353 
354       loader =
355           new BitInputStream.Loader(inputBase, start4, stream4currentAddress,
356                                     stream4bits, stream4bitsConsumed);
357       done = loader.load();
358       stream4bitsConsumed = loader.getBitsConsumed();
359       stream4bits = loader.getBits();
360       stream4currentAddress = loader.getCurrentAddress();
361       if (done) {
362         break;
363       }
364     }
365 
366     Util.verify(output1 <= outputStart2 && output2 <= outputStart3 &&
367                 output3 <= outputStart4, inputAddress, INPUT_IS_CORRUPTED);
368 
369     /// finish streams one by one
370     decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed,
371                stream1bits, outputBase, output1, outputStart2);
372     decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed,
373                stream2bits, outputBase, output2, outputStart3);
374     decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed,
375                stream3bits, outputBase, output3, outputStart4);
376     decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed,
377                stream4bits, outputBase, output4, outputLimit);
378   }
379 
380   private void decodeTail(final Object inputBase, final long startAddress,
381                           long currentAddress, int bitsConsumed, long bits,
382                           final Object outputBase, long outputAddress,
383                           final long outputLimit) {
384     final int tableLog1 = this.tableLog;
385     final byte[] numbersOfBits1 = this.numbersOfBits;
386     final byte[] symbols1 = this.symbols;
387 
388     // closer to the end
389     while (outputAddress < outputLimit) {
390       final BitInputStream.Loader loader =
391           new BitInputStream.Loader(inputBase, startAddress, currentAddress,
392                                     bits, bitsConsumed);
393       final boolean done = loader.load();
394       bitsConsumed = loader.getBitsConsumed();
395       bits = loader.getBits();
396       currentAddress = loader.getCurrentAddress();
397       if (done) {
398         break;
399       }
400 
401       bitsConsumed =
402           decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
403                        tableLog1, numbersOfBits1, symbols1);
404     }
405 
406     // not more data in bit stream, so no need to reload
407     while (outputAddress < outputLimit) {
408       bitsConsumed =
409           decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
410                        tableLog1, numbersOfBits1, symbols1);
411     }
412 
413     Util.verify(isEndOfStream(startAddress, currentAddress, bitsConsumed),
414                 startAddress, "Bit stream is not fully consumed");
415   }
416 
417   private static int decodeSymbol(final Object outputBase,
418                                   final long outputAddress,
419                                   final long bitContainer,
420                                   final int bitsConsumed, final int tableLog,
421                                   final byte[] numbersOfBits,
422                                   final byte[] symbols) {
423     final int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
424     UnsafeUtil.UNSAFE.putByte(outputBase, outputAddress, symbols[value]);
425     return bitsConsumed + numbersOfBits[value];
426   }
427 }