View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import static org.waarp.compress.zstdunsafe.Constants.*;
37  import static org.waarp.compress.zstdunsafe.Huffman.*;
38  import static org.waarp.compress.zstdunsafe.UnsafeUtil.*;
39  import static org.waarp.compress.zstdunsafe.Util.*;
40  import static sun.misc.Unsafe.*;
41  
42  class ZstdFrameCompressor {
43    static final int MAX_FRAME_HEADER_SIZE = 14;
44  
45    private static final int CHECKSUM_FLAG = 0x4;
46    private static final int SINGLE_SEGMENT_FLAG = 0x20;
47  
48    private static final int MINIMUM_LITERALS_SIZE = 63;
49  
50    // the maximum table log allowed for literal encoding per RFC 8478, section 4.2.1
51    private static final int MAX_HUFFMAN_TABLE_LOG = 11;
52    public static final String OUTPUT_BUFFER_TOO_SMALL =
53        "Output buffer too small";
54  
55    private ZstdFrameCompressor() {
56    }
57  
58    // visible for testing
59    static int writeMagic(final Object outputBase, final long outputAddress,
60                          final long outputLimit) {
61      checkArgument(outputLimit - outputAddress >= SIZE_OF_INT,
62                    OUTPUT_BUFFER_TOO_SMALL);
63  
64      UNSAFE.putInt(outputBase, outputAddress, MAGIC_NUMBER);
65      return SIZE_OF_INT;
66    }
67  
68    // visible for testing
69    static int writeFrameHeader(final Object outputBase, final long outputAddress,
70                                final long outputLimit, final int inputSize,
71                                final int windowSize) {
72      checkArgument(outputLimit - outputAddress >= MAX_FRAME_HEADER_SIZE,
73                    OUTPUT_BUFFER_TOO_SMALL);
74  
75      long output = outputAddress;
76  
77      final int contentSizeDescriptor =
78          (inputSize >= 256? 1 : 0) + (inputSize >= 65536 + 256? 1 : 0);
79      int frameHeaderDescriptor =
80          (contentSizeDescriptor << 6) | CHECKSUM_FLAG; // dictionary ID missing
81  
82      final boolean singleSegment = windowSize >= inputSize;
83      if (singleSegment) {
84        frameHeaderDescriptor |= SINGLE_SEGMENT_FLAG;
85      }
86  
87      UNSAFE.putByte(outputBase, output, (byte) frameHeaderDescriptor);
88      output++;
89  
90      if (!singleSegment) {
91        final int base = Integer.highestOneBit(windowSize);
92  
93        final int exponent = 32 - Integer.numberOfLeadingZeros(base) - 1;
94        if (exponent < MIN_WINDOW_LOG) {
95          throw new IllegalArgumentException(
96              "Minimum window size is " + (1 << MIN_WINDOW_LOG));
97        }
98  
99        final int remainder = windowSize - base;
100       if (remainder % (base / 8) != 0) {
101         throw new IllegalArgumentException(
102             "Window size of magnitude 2^" + exponent + " must be multiple of " +
103             (base / 8));
104       }
105 
106       // mantissa is guaranteed to be between 0-7
107       final int mantissa = remainder / (base / 8);
108       final int encoded = ((exponent - MIN_WINDOW_LOG) << 3) | mantissa;
109 
110       UNSAFE.putByte(outputBase, output, (byte) encoded);
111       output++;
112     }
113 
114     switch (contentSizeDescriptor) {
115       case 0:
116         if (singleSegment) {
117           UNSAFE.putByte(outputBase, output++, (byte) inputSize);
118         }
119         break;
120       case 1:
121         UNSAFE.putShort(outputBase, output, (short) (inputSize - 256));
122         output += SIZE_OF_SHORT;
123         break;
124       case 2:
125         UNSAFE.putInt(outputBase, output, inputSize);
126         output += SIZE_OF_INT;
127         break;
128       default:
129         throw new AssertionError();
130     }
131 
132     return (int) (output - outputAddress);
133   }
134 
135   // visible for testing
136   static int writeChecksum(final Object outputBase, final long outputAddress,
137                            final long outputLimit, final Object inputBase,
138                            final long inputAddress, final long inputLimit) {
139     checkArgument(outputLimit - outputAddress >= SIZE_OF_INT,
140                   OUTPUT_BUFFER_TOO_SMALL);
141 
142     final int inputSize = (int) (inputLimit - inputAddress);
143 
144     final long hash = XxHash64.hash(0, inputBase, inputAddress, inputSize);
145 
146     UNSAFE.putInt(outputBase, outputAddress, (int) hash);
147 
148     return SIZE_OF_INT;
149   }
150 
151   public static int compress(final Object inputBase, final long inputAddress,
152                              final long inputLimit, final Object outputBase,
153                              final long outputAddress, final long outputLimit,
154                              final int compressionLevel) {
155     final int inputSize = (int) (inputLimit - inputAddress);
156 
157     final CompressionParameters parameters =
158         CompressionParameters.compute(compressionLevel, inputSize);
159 
160     long output = outputAddress;
161 
162     output += writeMagic(outputBase, output, outputLimit);
163     output += writeFrameHeader(outputBase, output, outputLimit, inputSize,
164                                1 << parameters.getWindowLog());
165     output +=
166         compressFrame(inputBase, inputAddress, inputLimit, outputBase, output,
167                       outputLimit, parameters);
168     output +=
169         writeChecksum(outputBase, output, outputLimit, inputBase, inputAddress,
170                       inputLimit);
171 
172     return (int) (output - outputAddress);
173   }
174 
175   private static int compressFrame(final Object inputBase,
176                                    final long inputAddress,
177                                    final long inputLimit,
178                                    final Object outputBase,
179                                    final long outputAddress,
180                                    final long outputLimit,
181                                    final CompressionParameters parameters) {
182     final int windowSize = 1 <<
183                            parameters.getWindowLog(); // TODO: store window size in parameters directly?
184     int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize);
185 
186     int outputSize = (int) (outputLimit - outputAddress);
187     int remaining = (int) (inputLimit - inputAddress);
188 
189     long output = outputAddress;
190     long input = inputAddress;
191 
192     final CompressionContext context =
193         new CompressionContext(parameters, inputAddress, remaining);
194 
195     do {
196       checkArgument(outputSize >= SIZE_OF_BLOCK_HEADER + MIN_BLOCK_SIZE,
197                     OUTPUT_BUFFER_TOO_SMALL);
198 
199       final int lastBlockFlag = blockSize >= remaining? 1 : 0;
200       blockSize = Math.min(blockSize, remaining);
201 
202       int compressedSize = 0;
203       if (remaining > 0) {
204         compressedSize = compressBlock(inputBase, input, blockSize, outputBase,
205                                        output + SIZE_OF_BLOCK_HEADER,
206                                        outputSize - SIZE_OF_BLOCK_HEADER,
207                                        context, parameters);
208       }
209 
210       if (compressedSize == 0) { // block is not compressible
211         checkArgument(blockSize + SIZE_OF_BLOCK_HEADER <= outputSize,
212                       "Output size too small");
213 
214         final int blockHeader =
215             lastBlockFlag | (RAW_BLOCK << 1) | (blockSize << 3);
216         put24BitLittleEndian(outputBase, output, blockHeader);
217         UNSAFE.copyMemory(inputBase, input, outputBase,
218                           output + SIZE_OF_BLOCK_HEADER, blockSize);
219         compressedSize = SIZE_OF_BLOCK_HEADER + blockSize;
220       } else {
221         final int blockHeader =
222             lastBlockFlag | (COMPRESSED_BLOCK << 1) | (compressedSize << 3);
223         put24BitLittleEndian(outputBase, output, blockHeader);
224         compressedSize += SIZE_OF_BLOCK_HEADER;
225       }
226 
227       input += blockSize;
228       remaining -= blockSize;
229       output += compressedSize;
230       outputSize -= compressedSize;
231     } while (remaining > 0);
232 
233     return (int) (output - outputAddress);
234   }
235 
236   private static int compressBlock(final Object inputBase,
237                                    final long inputAddress, final int inputSize,
238                                    final Object outputBase,
239                                    final long outputAddress,
240                                    final int outputSize,
241                                    final CompressionContext context,
242                                    final CompressionParameters parameters) {
243     if (inputSize < MIN_BLOCK_SIZE + SIZE_OF_BLOCK_HEADER + 1) {
244       //  don't even attempt compression below a certain input size
245       return 0;
246     }
247 
248     context.blockCompressionState.enforceMaxDistance(inputAddress + inputSize,
249                                                      1 <<
250                                                      parameters.getWindowLog());
251     context.sequenceStore.reset();
252 
253     final int lastLiteralsSize = parameters.getStrategy().getCompressor()
254                                            .compressBlock(inputBase,
255                                                           inputAddress,
256                                                           inputSize,
257                                                           context.sequenceStore,
258                                                           context.blockCompressionState,
259                                                           context.offsets,
260                                                           parameters);
261 
262     final long lastLiteralsAddress =
263         inputAddress + inputSize - lastLiteralsSize;
264 
265     // append [lastLiteralsAddress .. lastLiteralsSize] to sequenceStore literals buffer
266     context.sequenceStore.appendLiterals(inputBase, lastLiteralsAddress,
267                                          lastLiteralsSize);
268 
269     // convert length/offsets into codes
270     context.sequenceStore.generateCodes();
271 
272     final long outputLimit = outputAddress + outputSize;
273     long output = outputAddress;
274 
275     final int compressedLiteralsSize =
276         encodeLiterals(context.huffmanContext, parameters, outputBase, output,
277                        (int) (outputLimit - output),
278                        context.sequenceStore.literalsBuffer,
279                        context.sequenceStore.literalsLength);
280     output += compressedLiteralsSize;
281 
282     final int compressedSequencesSize =
283         SequenceEncoder.compressSequences(outputBase, output,
284                                           (int) (outputLimit - output),
285                                           context.sequenceStore,
286                                           parameters.getStrategy(),
287                                           context.sequenceEncodingContext);
288 
289     final int compressedSize = compressedLiteralsSize + compressedSequencesSize;
290     if (compressedSize == 0) {
291       // not compressible
292       return compressedSize;
293     }
294 
295     // Check compressibility
296     final int maxCompressedSize =
297         inputSize - calculateMinimumGain(inputSize, parameters.getStrategy());
298     if (compressedSize > maxCompressedSize) {
299       return 0; // not compressed
300     }
301 
302     // confirm repeated offsets and entropy tables
303     context.commit();
304 
305     return compressedSize;
306   }
307 
308   private static int encodeLiterals(final HuffmanCompressionContext context,
309                                     final CompressionParameters parameters,
310                                     final Object outputBase,
311                                     final long outputAddress,
312                                     final int outputSize, final byte[] literals,
313                                     final int literalsSize) {
314     // TODO: move this to Strategy
315     final boolean bypassCompression =
316         (parameters.getStrategy() == CompressionParameters.Strategy.FAST) &&
317         (parameters.getTargetLength() > 0);
318     if (bypassCompression || literalsSize <= MINIMUM_LITERALS_SIZE) {
319       return rawLiterals(outputBase, outputAddress, outputSize, literals,
320                          literalsSize);
321     }
322 
323     final int headerSize =
324         3 + (literalsSize >= 1024? 1 : 0) + (literalsSize >= 16384? 1 : 0);
325 
326     checkArgument(headerSize + 1 <= outputSize, OUTPUT_BUFFER_TOO_SMALL);
327 
328     final int[] counts = new int[MAX_SYMBOL_COUNT]; // TODO: preallocate
329     Histogram.count(literals, literalsSize, counts);
330     final int maxSymbol = Histogram.findMaxSymbol(counts, MAX_SYMBOL);
331     final int largestCount = Histogram.findLargestCount(counts, maxSymbol);
332 
333     final long literalsAddress = ARRAY_BYTE_BASE_OFFSET;
334     if (largestCount == literalsSize) {
335       // all bytes in input are equal
336       return rleLiterals(outputBase, outputAddress, literals, literalsSize);
337     } else if (largestCount <= (literalsSize >>> 7) + 4) {
338       // heuristic: probably not compressible enough
339       return rawLiterals(outputBase, outputAddress, outputSize, literals,
340                          literalsSize);
341     }
342 
343     final HuffmanCompressionTable previousTable = context.getPreviousTable();
344     final HuffmanCompressionTable table;
345     int serializedTableSize;
346     final boolean reuseTable;
347 
348     final boolean canReuse = previousTable.isValid(counts, maxSymbol);
349 
350     // heuristic: use existing table for small inputs if valid
351     // TODO: move to Strategy
352     final boolean preferReuse = parameters.getStrategy().ordinal() <
353                                 CompressionParameters.Strategy.LAZY.ordinal() &&
354                                 literalsSize <= 1024;
355     if (preferReuse && canReuse) {
356       table = previousTable;
357       reuseTable = true;
358       serializedTableSize = 0;
359     } else {
360       final HuffmanCompressionTable newTable = context.borrowTemporaryTable();
361 
362       newTable.initialize(counts, maxSymbol,
363                           HuffmanCompressionTable.optimalNumberOfBits(
364                               MAX_HUFFMAN_TABLE_LOG, literalsSize, maxSymbol),
365                           context.getCompressionTableWorkspace());
366 
367       serializedTableSize =
368           newTable.write(outputBase, outputAddress + headerSize,
369                          outputSize - headerSize,
370                          context.getTableWriterWorkspace());
371 
372       // Check if using previous huffman table is beneficial
373       if (canReuse && previousTable.estimateCompressedSize(counts, maxSymbol) <=
374                       serializedTableSize +
375                       newTable.estimateCompressedSize(counts, maxSymbol)) {
376         table = previousTable;
377         reuseTable = true;
378         serializedTableSize = 0;
379         context.discardTemporaryTable();
380       } else {
381         table = newTable;
382         reuseTable = false;
383       }
384     }
385 
386     final int compressedSize;
387     final boolean singleStream = literalsSize < 256;
388     if (singleStream) {
389       compressedSize = HuffmanCompressor.compressSingleStream(outputBase,
390                                                               outputAddress +
391                                                               headerSize +
392                                                               serializedTableSize,
393                                                               outputSize -
394                                                               headerSize -
395                                                               serializedTableSize,
396                                                               literals,
397                                                               literalsAddress,
398                                                               literalsSize,
399                                                               table);
400     } else {
401       compressedSize = HuffmanCompressor.compress4streams(outputBase,
402                                                           outputAddress +
403                                                           headerSize +
404                                                           serializedTableSize,
405                                                           outputSize -
406                                                           headerSize -
407                                                           serializedTableSize,
408                                                           literals,
409                                                           literalsAddress,
410                                                           literalsSize, table);
411     }
412 
413     final int totalSize = serializedTableSize + compressedSize;
414     final int minimumGain =
415         calculateMinimumGain(literalsSize, parameters.getStrategy());
416 
417     if (compressedSize == 0 || totalSize >= literalsSize - minimumGain) {
418       // incompressible or no savings
419 
420       // discard any temporary table we might have borrowed above
421       context.discardTemporaryTable();
422 
423       return rawLiterals(outputBase, outputAddress, outputSize, literals,
424                          literalsSize);
425     }
426 
427     final int encodingType =
428         reuseTable? TREELESS_LITERALS_BLOCK : COMPRESSED_LITERALS_BLOCK;
429 
430     // Build header
431     switch (headerSize) {
432       case 3: { // 2 - 2 - 10 - 10
433         final int header =
434             encodingType | ((singleStream? 0 : 1) << 2) | (literalsSize << 4) |
435             (totalSize << 14);
436         put24BitLittleEndian(outputBase, outputAddress, header);
437         break;
438       }
439       case 4: { // 2 - 2 - 14 - 14
440         final int header =
441             encodingType | (2 << 2) | (literalsSize << 4) | (totalSize << 18);
442         UNSAFE.putInt(outputBase, outputAddress, header);
443         break;
444       }
445       case 5: { // 2 - 2 - 18 - 18
446         final int header =
447             encodingType | (3 << 2) | (literalsSize << 4) | (totalSize << 22);
448         UNSAFE.putInt(outputBase, outputAddress, header);
449         UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_INT,
450                        (byte) (totalSize >>> 10));
451         break;
452       }
453       default:  // not possible : headerSize is {3,4,5}
454         throw new IllegalStateException();
455     }
456 
457     return headerSize + totalSize;
458   }
459 
460   private static int rleLiterals(final Object outputBase,
461                                  final long outputAddress,
462                                  final Object inputBase, final int inputSize) {
463     final int headerSize =
464         1 + (inputSize > 31? 1 : 0) + (inputSize > 4095? 1 : 0);
465 
466     switch (headerSize) {
467       case 1: // 2 - 1 - 5
468         UNSAFE.putByte(outputBase, outputAddress,
469                        (byte) (RLE_LITERALS_BLOCK | (inputSize << 3)));
470         break;
471       case 2: // 2 - 2 - 12
472         UNSAFE.putShort(outputBase, outputAddress,
473                         (short) (RLE_LITERALS_BLOCK | (1 << 2) |
474                                  (inputSize << 4)));
475         break;
476       case 3: // 2 - 2 - 20
477         UNSAFE.putInt(outputBase, outputAddress,
478                       RLE_LITERALS_BLOCK | 3 << 2 | inputSize << 4);
479         break;
480       default:   // impossible. headerSize is {1,2,3}
481         throw new IllegalStateException();
482     }
483 
484     UNSAFE.putByte(outputBase, outputAddress + headerSize,
485                    UNSAFE.getByte(inputBase,
486                                   sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET));
487 
488     return headerSize + 1;
489   }
490 
491   private static int calculateMinimumGain(final int inputSize,
492                                           final CompressionParameters.Strategy strategy) {
493     // TODO: move this to Strategy to avoid hardcoding a specific strategy here
494     final int minLog =
495         strategy == CompressionParameters.Strategy.BTULTRA? 7 : 6;
496     return (inputSize >>> minLog) + 2;
497   }
498 
499   private static int rawLiterals(final Object outputBase,
500                                  final long outputAddress, final int outputSize,
501                                  final Object inputBase, final int inputSize) {
502     int headerSize = 1;
503     if (inputSize >= 32) {
504       headerSize++;
505     }
506     if (inputSize >= 4096) {
507       headerSize++;
508     }
509 
510     checkArgument(inputSize + headerSize <= outputSize,
511                   OUTPUT_BUFFER_TOO_SMALL);
512 
513     switch (headerSize) {
514       case 1:
515         UNSAFE.putByte(outputBase, outputAddress,
516                        (byte) (RAW_LITERALS_BLOCK | (inputSize << 3)));
517         break;
518       case 2:
519         UNSAFE.putShort(outputBase, outputAddress,
520                         (short) (RAW_LITERALS_BLOCK | (1 << 2) |
521                                  (inputSize << 4)));
522         break;
523       case 3:
524         put24BitLittleEndian(outputBase, outputAddress,
525                              RAW_LITERALS_BLOCK | (3 << 2) | (inputSize << 4));
526         break;
527       default:
528         throw new AssertionError();
529     }
530 
531     // TODO: ensure this test is correct
532     checkArgument(inputSize + 1 <= outputSize, OUTPUT_BUFFER_TOO_SMALL);
533 
534     UNSAFE.copyMemory(inputBase, sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET,
535                       outputBase, outputAddress + headerSize, inputSize);
536 
537     return headerSize + inputSize;
538   }
539 }