View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import java.util.Arrays;
37  
38  import static org.waarp.compress.zstdsafe.BitInputStream.*;
39  import static org.waarp.compress.zstdsafe.Constants.*;
40  import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
41  import static org.waarp.compress.zstdsafe.Util.*;
42  
43  class Huffman {
44    public static final int MAX_SYMBOL = 255;
45    public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
46  
47    public static final int MAX_TABLE_LOG = 12;
48    public static final int MIN_TABLE_LOG = 5;
49    public static final int MAX_FSE_TABLE_LOG = 6;
50    public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
51    public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
52  
53    // stats
54    private final byte[] weights = new byte[MAX_SYMBOL + 1];
55    private final int[] ranks = new int[MAX_TABLE_LOG + 1];
56  
57    // table
58    private int tableLog = -1;
59    private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
60    private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
61  
62    private final FseTableReader reader = new FseTableReader();
63    private final FiniteStateEntropy.Table fseTable =
64        new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
65  
66    public boolean isLoaded() {
67      return tableLog != -1;
68    }
69  
70    public int readTable(final byte[] inputBase, final int inputAddress,
71                         final int size) {
72      Arrays.fill(ranks, 0);
73      int input = inputAddress;
74  
75      // read table header
76      verify(size > 0, input, NOT_ENOUGH_INPUT_BYTES);
77      int inputSize = inputBase[input++] & 0xFF;
78  
79      final int outputSize;
80      if (inputSize >= 128) {
81        outputSize = inputSize - 127;
82        inputSize = ((outputSize + 1) / 2);
83  
84        verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
85        verify(outputSize <= MAX_SYMBOL + 1, input, INPUT_IS_CORRUPTED);
86  
87        for (int i = 0; i < outputSize; i += 2) {
88          final int value = inputBase[input + i / 2] & 0xFF;
89          weights[i] = (byte) (value >>> 4);
90          weights[i + 1] = (byte) (value & 0xF);
91        }
92      } else {
93        verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
94  
95        final int inputLimit = input + inputSize;
96        input += reader.readFseTable(fseTable, inputBase, input, inputLimit,
97                                     FiniteStateEntropy.MAX_SYMBOL,
98                                     MAX_FSE_TABLE_LOG);
99        outputSize =
100           FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit,
101                                         weights);
102     }
103 
104     int totalWeight = 0;
105     for (int i = 0; i < outputSize; i++) {
106       ranks[weights[i]]++;
107       totalWeight +=
108           (1 << weights[i]) >> 1;   // TODO same as 1 << (weights[n] - 1)?
109     }
110     verify(totalWeight != 0, input, INPUT_IS_CORRUPTED);
111 
112     tableLog = Util.highestBit(totalWeight) + 1;
113     verify(tableLog <= MAX_TABLE_LOG, input, INPUT_IS_CORRUPTED);
114 
115     final int total = 1 << tableLog;
116     final int rest = total - totalWeight;
117     verify(isPowerOf2(rest), input, INPUT_IS_CORRUPTED);
118 
119     final int lastWeight = Util.highestBit(rest) + 1;
120 
121     weights[outputSize] = (byte) lastWeight;
122     ranks[lastWeight]++;
123 
124     final int numberOfSymbols = outputSize + 1;
125 
126     // populate table
127     int nextRankStart = 0;
128     for (int i = 1; i < tableLog + 1; ++i) {
129       final int current = nextRankStart;
130       nextRankStart += ranks[i] << (i - 1);
131       ranks[i] = current;
132     }
133 
134     for (int n = 0; n < numberOfSymbols; n++) {
135       final int weight = weights[n];
136       final int length = (1 << weight) >> 1;  // TODO: 1 << (weight - 1) ??
137 
138       final byte symbol = (byte) n;
139       final byte numberOfBits = (byte) (tableLog + 1 - weight);
140       for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
141         symbols[i] = symbol;
142         numbersOfBits[i] = numberOfBits;
143       }
144       ranks[weight] += length;
145     }
146 
147     verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, INPUT_IS_CORRUPTED);
148 
149     return inputSize + 1;
150   }
151 
152   public void decodeSingleStream(final byte[] inputBase, final int inputAddress,
153                                  final int inputLimit, final byte[] outputBase,
154                                  final int outputAddress,
155                                  final int outputLimit) {
156     final BitInputStream.Initializer initializer =
157         new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
158     initializer.initialize();
159 
160     long bits = initializer.getBits();
161     int bitsConsumed = initializer.getBitsConsumed();
162     int currentAddress = initializer.getCurrentAddress();
163 
164     final int tableLog1 = this.tableLog;
165     final byte[] numbersOfBits1 = this.numbersOfBits;
166     final byte[] symbols1 = this.symbols;
167 
168     // 4 symbols at a time
169     int output = outputAddress;
170     final int fastOutputLimit = outputLimit - 4;
171     while (output < fastOutputLimit) {
172       final BitInputStream.Loader loader =
173           new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
174                                     bits, bitsConsumed);
175       final boolean done = loader.load();
176       bits = loader.getBits();
177       bitsConsumed = loader.getBitsConsumed();
178       currentAddress = loader.getCurrentAddress();
179       if (done) {
180         break;
181       }
182 
183       bitsConsumed =
184           decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog1,
185                        numbersOfBits1, symbols1);
186       bitsConsumed =
187           decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog1,
188                        numbersOfBits1, symbols1);
189       bitsConsumed =
190           decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog1,
191                        numbersOfBits1, symbols1);
192       bitsConsumed =
193           decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog1,
194                        numbersOfBits1, symbols1);
195       output += SIZE_OF_INT;
196     }
197 
198     decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits,
199                outputBase, output, outputLimit);
200   }
201 
202   public void decode4Streams(final byte[] inputBase, final int inputAddress,
203                              final int inputLimit, final byte[] outputBase,
204                              final int outputAddress, final int outputLimit) {
205     verify(inputLimit - inputAddress >= 10, inputAddress,
206            INPUT_IS_CORRUPTED); // jump table + 1 byte per stream
207 
208     final int start1 =
209         inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below
210     final int start2 = start1 + (getShort(inputBase, inputAddress) & 0xFFFF);
211     final int start3 =
212         start2 + (getShort(inputBase, inputAddress + 2) & 0xFFFF);
213     final int start4 =
214         start3 + (getShort(inputBase, inputAddress + 4) & 0xFFFF);
215 
216     BitInputStream.Initializer initializer =
217         new BitInputStream.Initializer(inputBase, start1, start2);
218     initializer.initialize();
219     int stream1bitsConsumed = initializer.getBitsConsumed();
220     int stream1currentAddress = initializer.getCurrentAddress();
221     long stream1bits = initializer.getBits();
222 
223     initializer = new BitInputStream.Initializer(inputBase, start2, start3);
224     initializer.initialize();
225     int stream2bitsConsumed = initializer.getBitsConsumed();
226     int stream2currentAddress = initializer.getCurrentAddress();
227     long stream2bits = initializer.getBits();
228 
229     initializer = new BitInputStream.Initializer(inputBase, start3, start4);
230     initializer.initialize();
231     int stream3bitsConsumed = initializer.getBitsConsumed();
232     int stream3currentAddress = initializer.getCurrentAddress();
233     long stream3bits = initializer.getBits();
234 
235     initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
236     initializer.initialize();
237     int stream4bitsConsumed = initializer.getBitsConsumed();
238     int stream4currentAddress = initializer.getCurrentAddress();
239     long stream4bits = initializer.getBits();
240 
241     final int segmentSize = (outputLimit - outputAddress + 3) / 4;
242 
243     final int outputStart2 = outputAddress + segmentSize;
244     final int outputStart3 = outputStart2 + segmentSize;
245     final int outputStart4 = outputStart3 + segmentSize;
246 
247     int output1 = outputAddress;
248     int output2 = outputStart2;
249     int output3 = outputStart3;
250     int output4 = outputStart4;
251 
252     final int fastOutputLimit = outputLimit - 7;
253     final int tableLog1 = this.tableLog;
254     final byte[] numbersOfBits1 = this.numbersOfBits;
255     final byte[] symbols1 = this.symbols;
256 
257     while (output4 < fastOutputLimit) {
258       stream1bitsConsumed =
259           decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed,
260                        tableLog1, numbersOfBits1, symbols1);
261       stream2bitsConsumed =
262           decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed,
263                        tableLog1, numbersOfBits1, symbols1);
264       stream3bitsConsumed =
265           decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed,
266                        tableLog1, numbersOfBits1, symbols1);
267       stream4bitsConsumed =
268           decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed,
269                        tableLog1, numbersOfBits1, symbols1);
270 
271       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits,
272                                          stream1bitsConsumed, tableLog1,
273                                          numbersOfBits1, symbols1);
274       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits,
275                                          stream2bitsConsumed, tableLog1,
276                                          numbersOfBits1, symbols1);
277       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits,
278                                          stream3bitsConsumed, tableLog1,
279                                          numbersOfBits1, symbols1);
280       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits,
281                                          stream4bitsConsumed, tableLog1,
282                                          numbersOfBits1, symbols1);
283 
284       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits,
285                                          stream1bitsConsumed, tableLog1,
286                                          numbersOfBits1, symbols1);
287       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits,
288                                          stream2bitsConsumed, tableLog1,
289                                          numbersOfBits1, symbols1);
290       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits,
291                                          stream3bitsConsumed, tableLog1,
292                                          numbersOfBits1, symbols1);
293       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits,
294                                          stream4bitsConsumed, tableLog1,
295                                          numbersOfBits1, symbols1);
296 
297       stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits,
298                                          stream1bitsConsumed, tableLog1,
299                                          numbersOfBits1, symbols1);
300       stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits,
301                                          stream2bitsConsumed, tableLog1,
302                                          numbersOfBits1, symbols1);
303       stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits,
304                                          stream3bitsConsumed, tableLog1,
305                                          numbersOfBits1, symbols1);
306       stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits,
307                                          stream4bitsConsumed, tableLog1,
308                                          numbersOfBits1, symbols1);
309 
310       output1 += SIZE_OF_INT;
311       output2 += SIZE_OF_INT;
312       output3 += SIZE_OF_INT;
313       output4 += SIZE_OF_INT;
314 
315       BitInputStream.Loader loader =
316           new BitInputStream.Loader(inputBase, start1, stream1currentAddress,
317                                     stream1bits, stream1bitsConsumed);
318       boolean done = loader.load();
319       stream1bitsConsumed = loader.getBitsConsumed();
320       stream1bits = loader.getBits();
321       stream1currentAddress = loader.getCurrentAddress();
322 
323       if (done) {
324         break;
325       }
326 
327       loader =
328           new BitInputStream.Loader(inputBase, start2, stream2currentAddress,
329                                     stream2bits, stream2bitsConsumed);
330       done = loader.load();
331       stream2bitsConsumed = loader.getBitsConsumed();
332       stream2bits = loader.getBits();
333       stream2currentAddress = loader.getCurrentAddress();
334 
335       if (done) {
336         break;
337       }
338 
339       loader =
340           new BitInputStream.Loader(inputBase, start3, stream3currentAddress,
341                                     stream3bits, stream3bitsConsumed);
342       done = loader.load();
343       stream3bitsConsumed = loader.getBitsConsumed();
344       stream3bits = loader.getBits();
345       stream3currentAddress = loader.getCurrentAddress();
346       if (done) {
347         break;
348       }
349 
350       loader =
351           new BitInputStream.Loader(inputBase, start4, stream4currentAddress,
352                                     stream4bits, stream4bitsConsumed);
353       done = loader.load();
354       stream4bitsConsumed = loader.getBitsConsumed();
355       stream4bits = loader.getBits();
356       stream4currentAddress = loader.getCurrentAddress();
357       if (done) {
358         break;
359       }
360     }
361 
362     verify(output1 <= outputStart2 && output2 <= outputStart3 &&
363            output3 <= outputStart4, inputAddress, INPUT_IS_CORRUPTED);
364 
365     /// finish streams one by one
366     decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed,
367                stream1bits, outputBase, output1, outputStart2);
368     decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed,
369                stream2bits, outputBase, output2, outputStart3);
370     decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed,
371                stream3bits, outputBase, output3, outputStart4);
372     decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed,
373                stream4bits, outputBase, output4, outputLimit);
374   }
375 
376   private void decodeTail(final byte[] inputBase, final int startAddress,
377                           int currentAddress, int bitsConsumed, long bits,
378                           final byte[] outputBase, int outputAddress,
379                           final int outputLimit) {
380     final int tableLog1 = this.tableLog;
381     final byte[] numbersOfBits1 = this.numbersOfBits;
382     final byte[] symbols1 = this.symbols;
383 
384     // closer to the end
385     while (outputAddress < outputLimit) {
386       final BitInputStream.Loader loader =
387           new BitInputStream.Loader(inputBase, startAddress, currentAddress,
388                                     bits, bitsConsumed);
389       final boolean done = loader.load();
390       bitsConsumed = loader.getBitsConsumed();
391       bits = loader.getBits();
392       currentAddress = loader.getCurrentAddress();
393       if (done) {
394         break;
395       }
396 
397       bitsConsumed =
398           decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
399                        tableLog1, numbersOfBits1, symbols1);
400     }
401 
402     // not more data in bit stream, so no need to reload
403     while (outputAddress < outputLimit) {
404       bitsConsumed =
405           decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
406                        tableLog1, numbersOfBits1, symbols1);
407     }
408 
409     verify(isEndOfStream(startAddress, currentAddress, bitsConsumed),
410            startAddress, "Bit stream is not fully consumed");
411   }
412 
413   private static int decodeSymbol(final byte[] outputBase,
414                                   final int outputAddress,
415                                   final long bitContainer,
416                                   final int bitsConsumed, final int tableLog,
417                                   final byte[] numbersOfBits,
418                                   final byte[] symbols) {
419     final int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
420     outputBase[outputAddress] = symbols[value];
421     return bitsConsumed + numbersOfBits[value];
422   }
423 }