1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdunsafe;
35
36 import java.util.Arrays;
37
38 import static org.waarp.compress.zstdunsafe.BitInputStream.*;
39 import static org.waarp.compress.zstdunsafe.Constants.*;
40
41 class Huffman {
42 public static final int MAX_SYMBOL = 255;
43 public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
44
45 public static final int MAX_TABLE_LOG = 12;
46 public static final int MIN_TABLE_LOG = 5;
47 public static final int MAX_FSE_TABLE_LOG = 6;
48 public static final String NOT_ENOUGH_INPUT_BYTES = "Not enough input bytes";
49 public static final String INPUT_IS_CORRUPTED = "Input is corrupted";
50
51
52 private final byte[] weights = new byte[MAX_SYMBOL + 1];
53 private final int[] ranks = new int[MAX_TABLE_LOG + 1];
54
55
56 private int tableLog = -1;
57 private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
58 private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
59
60 private final FseTableReader reader = new FseTableReader();
61 private final FiniteStateEntropy.Table fseTable =
62 new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
63
64 public boolean isLoaded() {
65 return tableLog != -1;
66 }
67
68 public int readTable(final Object inputBase, final long inputAddress,
69 final int size) {
70 Arrays.fill(ranks, 0);
71 long input = inputAddress;
72
73
74 Util.verify(size > 0, input, NOT_ENOUGH_INPUT_BYTES);
75 int inputSize = UnsafeUtil.UNSAFE.getByte(inputBase, input++) & 0xFF;
76
77 final int outputSize;
78 if (inputSize >= 128) {
79 outputSize = inputSize - 127;
80 inputSize = ((outputSize + 1) / 2);
81
82 Util.verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
83 Util.verify(true, input, INPUT_IS_CORRUPTED);
84
85 for (int i = 0; i < outputSize; i += 2) {
86 final int value =
87 UnsafeUtil.UNSAFE.getByte(inputBase, input + i / 2) & 0xFF;
88 weights[i] = (byte) (value >>> 4);
89 weights[i + 1] = (byte) (value & 0xf);
90 }
91 } else {
92 Util.verify(inputSize + 1 <= size, input, NOT_ENOUGH_INPUT_BYTES);
93
94 final long inputLimit = input + inputSize;
95 input += reader.readFseTable(fseTable, inputBase, input, inputLimit,
96 FiniteStateEntropy.MAX_SYMBOL,
97 MAX_FSE_TABLE_LOG);
98 outputSize =
99 FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit,
100 weights);
101 }
102
103 int totalWeight = 0;
104 for (int i = 0; i < outputSize; i++) {
105 ranks[weights[i]]++;
106 totalWeight +=
107 (1 << weights[i]) >> 1;
108 }
109 Util.verify(totalWeight != 0, input, INPUT_IS_CORRUPTED);
110
111 tableLog = Util.highestBit(totalWeight) + 1;
112 Util.verify(tableLog <= MAX_TABLE_LOG, input, INPUT_IS_CORRUPTED);
113
114 final int total = 1 << tableLog;
115 final int rest = total - totalWeight;
116 Util.verify(Util.isPowerOf2(rest), input, INPUT_IS_CORRUPTED);
117
118 final int lastWeight = Util.highestBit(rest) + 1;
119
120 weights[outputSize] = (byte) lastWeight;
121 ranks[lastWeight]++;
122
123 final int numberOfSymbols = outputSize + 1;
124
125
126 int nextRankStart = 0;
127 for (int i = 1; i < tableLog + 1; ++i) {
128 final int current = nextRankStart;
129 nextRankStart += ranks[i] << (i - 1);
130 ranks[i] = current;
131 }
132
133 for (int n = 0; n < numberOfSymbols; n++) {
134 final int weight = weights[n];
135 final int length = (1 << weight) >> 1;
136
137 final byte symbol = (byte) n;
138 final byte numberOfBits = (byte) (tableLog + 1 - weight);
139 for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
140 symbols[i] = symbol;
141 numbersOfBits[i] = numberOfBits;
142 }
143 ranks[weight] += length;
144 }
145
146 Util.verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input,
147 INPUT_IS_CORRUPTED);
148
149 return inputSize + 1;
150 }
151
152 public void decodeSingleStream(final Object inputBase,
153 final long inputAddress, final long inputLimit,
154 final Object outputBase,
155 final long outputAddress,
156 final long outputLimit) {
157 final BitInputStream.Initializer initializer =
158 new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
159 initializer.initialize();
160
161 long bits = initializer.getBits();
162 int bitsConsumed = initializer.getBitsConsumed();
163 long currentAddress = initializer.getCurrentAddress();
164
165 final int tableLog1 = this.tableLog;
166 final byte[] numbersOfBits1 = this.numbersOfBits;
167 final byte[] symbols1 = this.symbols;
168
169
170 long output = outputAddress;
171 final long fastOutputLimit = outputLimit - 4;
172 while (output < fastOutputLimit) {
173 final BitInputStream.Loader loader =
174 new BitInputStream.Loader(inputBase, inputAddress, currentAddress,
175 bits, bitsConsumed);
176 final boolean done = loader.load();
177 bits = loader.getBits();
178 bitsConsumed = loader.getBitsConsumed();
179 currentAddress = loader.getCurrentAddress();
180 if (done) {
181 break;
182 }
183
184 bitsConsumed =
185 decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog1,
186 numbersOfBits1, symbols1);
187 bitsConsumed =
188 decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog1,
189 numbersOfBits1, symbols1);
190 bitsConsumed =
191 decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog1,
192 numbersOfBits1, symbols1);
193 bitsConsumed =
194 decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog1,
195 numbersOfBits1, symbols1);
196 output += SIZE_OF_INT;
197 }
198
199 decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits,
200 outputBase, output, outputLimit);
201 }
202
203 public void decode4Streams(final Object inputBase, final long inputAddress,
204 final long inputLimit, final Object outputBase,
205 final long outputAddress, final long outputLimit) {
206 Util.verify(inputLimit - inputAddress >= 10, inputAddress,
207 INPUT_IS_CORRUPTED);
208
209 final long start1 =
210 inputAddress + 3 * SIZE_OF_SHORT;
211 final long start2 =
212 start1 + (UnsafeUtil.UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF);
213 final long start3 = start2 + (UnsafeUtil.UNSAFE.getShort(inputBase,
214 inputAddress + 2) &
215 0xFFFF);
216 final long start4 = start3 + (UnsafeUtil.UNSAFE.getShort(inputBase,
217 inputAddress + 4) &
218 0xFFFF);
219
220 BitInputStream.Initializer initializer =
221 new BitInputStream.Initializer(inputBase, start1, start2);
222 initializer.initialize();
223 int stream1bitsConsumed = initializer.getBitsConsumed();
224 long stream1currentAddress = initializer.getCurrentAddress();
225 long stream1bits = initializer.getBits();
226
227 initializer = new BitInputStream.Initializer(inputBase, start2, start3);
228 initializer.initialize();
229 int stream2bitsConsumed = initializer.getBitsConsumed();
230 long stream2currentAddress = initializer.getCurrentAddress();
231 long stream2bits = initializer.getBits();
232
233 initializer = new BitInputStream.Initializer(inputBase, start3, start4);
234 initializer.initialize();
235 int stream3bitsConsumed = initializer.getBitsConsumed();
236 long stream3currentAddress = initializer.getCurrentAddress();
237 long stream3bits = initializer.getBits();
238
239 initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
240 initializer.initialize();
241 int stream4bitsConsumed = initializer.getBitsConsumed();
242 long stream4currentAddress = initializer.getCurrentAddress();
243 long stream4bits = initializer.getBits();
244
245 final int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4);
246
247 final long outputStart2 = outputAddress + segmentSize;
248 final long outputStart3 = outputStart2 + segmentSize;
249 final long outputStart4 = outputStart3 + segmentSize;
250
251 long output1 = outputAddress;
252 long output2 = outputStart2;
253 long output3 = outputStart3;
254 long output4 = outputStart4;
255
256 final long fastOutputLimit = outputLimit - 7;
257 final int tableLog1 = this.tableLog;
258 final byte[] numbersOfBits1 = this.numbersOfBits;
259 final byte[] symbols1 = this.symbols;
260
261 while (output4 < fastOutputLimit) {
262 stream1bitsConsumed =
263 decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed,
264 tableLog1, numbersOfBits1, symbols1);
265 stream2bitsConsumed =
266 decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed,
267 tableLog1, numbersOfBits1, symbols1);
268 stream3bitsConsumed =
269 decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed,
270 tableLog1, numbersOfBits1, symbols1);
271 stream4bitsConsumed =
272 decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed,
273 tableLog1, numbersOfBits1, symbols1);
274
275 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits,
276 stream1bitsConsumed, tableLog1,
277 numbersOfBits1, symbols1);
278 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits,
279 stream2bitsConsumed, tableLog1,
280 numbersOfBits1, symbols1);
281 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits,
282 stream3bitsConsumed, tableLog1,
283 numbersOfBits1, symbols1);
284 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits,
285 stream4bitsConsumed, tableLog1,
286 numbersOfBits1, symbols1);
287
288 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits,
289 stream1bitsConsumed, tableLog1,
290 numbersOfBits1, symbols1);
291 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits,
292 stream2bitsConsumed, tableLog1,
293 numbersOfBits1, symbols1);
294 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits,
295 stream3bitsConsumed, tableLog1,
296 numbersOfBits1, symbols1);
297 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits,
298 stream4bitsConsumed, tableLog1,
299 numbersOfBits1, symbols1);
300
301 stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits,
302 stream1bitsConsumed, tableLog1,
303 numbersOfBits1, symbols1);
304 stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits,
305 stream2bitsConsumed, tableLog1,
306 numbersOfBits1, symbols1);
307 stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits,
308 stream3bitsConsumed, tableLog1,
309 numbersOfBits1, symbols1);
310 stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits,
311 stream4bitsConsumed, tableLog1,
312 numbersOfBits1, symbols1);
313
314 output1 += SIZE_OF_INT;
315 output2 += SIZE_OF_INT;
316 output3 += SIZE_OF_INT;
317 output4 += SIZE_OF_INT;
318
319 BitInputStream.Loader loader =
320 new BitInputStream.Loader(inputBase, start1, stream1currentAddress,
321 stream1bits, stream1bitsConsumed);
322 boolean done = loader.load();
323 stream1bitsConsumed = loader.getBitsConsumed();
324 stream1bits = loader.getBits();
325 stream1currentAddress = loader.getCurrentAddress();
326
327 if (done) {
328 break;
329 }
330
331 loader =
332 new BitInputStream.Loader(inputBase, start2, stream2currentAddress,
333 stream2bits, stream2bitsConsumed);
334 done = loader.load();
335 stream2bitsConsumed = loader.getBitsConsumed();
336 stream2bits = loader.getBits();
337 stream2currentAddress = loader.getCurrentAddress();
338
339 if (done) {
340 break;
341 }
342
343 loader =
344 new BitInputStream.Loader(inputBase, start3, stream3currentAddress,
345 stream3bits, stream3bitsConsumed);
346 done = loader.load();
347 stream3bitsConsumed = loader.getBitsConsumed();
348 stream3bits = loader.getBits();
349 stream3currentAddress = loader.getCurrentAddress();
350 if (done) {
351 break;
352 }
353
354 loader =
355 new BitInputStream.Loader(inputBase, start4, stream4currentAddress,
356 stream4bits, stream4bitsConsumed);
357 done = loader.load();
358 stream4bitsConsumed = loader.getBitsConsumed();
359 stream4bits = loader.getBits();
360 stream4currentAddress = loader.getCurrentAddress();
361 if (done) {
362 break;
363 }
364 }
365
366 Util.verify(output1 <= outputStart2 && output2 <= outputStart3 &&
367 output3 <= outputStart4, inputAddress, INPUT_IS_CORRUPTED);
368
369
370 decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed,
371 stream1bits, outputBase, output1, outputStart2);
372 decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed,
373 stream2bits, outputBase, output2, outputStart3);
374 decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed,
375 stream3bits, outputBase, output3, outputStart4);
376 decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed,
377 stream4bits, outputBase, output4, outputLimit);
378 }
379
380 private void decodeTail(final Object inputBase, final long startAddress,
381 long currentAddress, int bitsConsumed, long bits,
382 final Object outputBase, long outputAddress,
383 final long outputLimit) {
384 final int tableLog1 = this.tableLog;
385 final byte[] numbersOfBits1 = this.numbersOfBits;
386 final byte[] symbols1 = this.symbols;
387
388
389 while (outputAddress < outputLimit) {
390 final BitInputStream.Loader loader =
391 new BitInputStream.Loader(inputBase, startAddress, currentAddress,
392 bits, bitsConsumed);
393 final boolean done = loader.load();
394 bitsConsumed = loader.getBitsConsumed();
395 bits = loader.getBits();
396 currentAddress = loader.getCurrentAddress();
397 if (done) {
398 break;
399 }
400
401 bitsConsumed =
402 decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
403 tableLog1, numbersOfBits1, symbols1);
404 }
405
406
407 while (outputAddress < outputLimit) {
408 bitsConsumed =
409 decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed,
410 tableLog1, numbersOfBits1, symbols1);
411 }
412
413 Util.verify(isEndOfStream(startAddress, currentAddress, bitsConsumed),
414 startAddress, "Bit stream is not fully consumed");
415 }
416
417 private static int decodeSymbol(final Object outputBase,
418 final long outputAddress,
419 final long bitContainer,
420 final int bitsConsumed, final int tableLog,
421 final byte[] numbersOfBits,
422 final byte[] symbols) {
423 final int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
424 UnsafeUtil.UNSAFE.putByte(outputBase, outputAddress, symbols[value]);
425 return bitsConsumed + numbersOfBits[value];
426 }
427 }