View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import static org.waarp.compress.zstdsafe.Constants.*;
37  import static org.waarp.compress.zstdsafe.Huffman.*;
38  import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
39  import static org.waarp.compress.zstdsafe.Util.*;
40  
41  class ZstdFrameCompressor {
42    static final int MAX_FRAME_HEADER_SIZE = 14;
43  
44    private static final int CHECKSUM_FLAG = 0x4;
45    private static final int SINGLE_SEGMENT_FLAG = 0x20;
46  
47    private static final int MINIMUM_LITERALS_SIZE = 63;
48  
49    // the maximum table log allowed for literal encoding per RFC 8478, section 4.2.1
50    private static final int MAX_HUFFMAN_TABLE_LOG = 11;
51    public static final String OUTPUT_BUFFER_TOO_SMALL =
52        "Output buffer too small";
53  
54    private ZstdFrameCompressor() {
55    }
56  
57    // visible for testing
58    static int writeMagic(final byte[] outputBase, final int outputAddress,
59                          final int outputLimit) {
60      checkArgument(outputLimit - outputAddress >= SIZE_OF_INT,
61                    OUTPUT_BUFFER_TOO_SMALL);
62  
63      putInt(outputBase, outputAddress, MAGIC_NUMBER);
64      return SIZE_OF_INT;
65    }
66  
67    // visible for testing
68    static int writeFrameHeader(final byte[] outputBase, final int outputAddress,
69                                final int outputLimit, final int inputSize,
70                                final int windowSize) {
71      checkArgument(outputLimit - outputAddress >= MAX_FRAME_HEADER_SIZE,
72                    OUTPUT_BUFFER_TOO_SMALL);
73  
74      int output = outputAddress;
75  
76      final int contentSizeDescriptor =
77          (inputSize >= 256? 1 : 0) + (inputSize >= 65536 + 256? 1 : 0);
78      int frameHeaderDescriptor =
79          (contentSizeDescriptor << 6) | CHECKSUM_FLAG; // dictionary ID missing
80  
81      final boolean singleSegment = windowSize >= inputSize;
82      if (singleSegment) {
83        frameHeaderDescriptor |= SINGLE_SEGMENT_FLAG;
84      }
85  
86      outputBase[output] = (byte) frameHeaderDescriptor;
87      output++;
88  
89      if (!singleSegment) {
90        final int base = Integer.highestOneBit(windowSize);
91  
92        final int exponent = 32 - Integer.numberOfLeadingZeros(base) - 1;
93        if (exponent < MIN_WINDOW_LOG) {
94          throw new IllegalArgumentException(
95              "Minimum window size is " + (1 << MIN_WINDOW_LOG));
96        }
97  
98        final int remainder = windowSize - base;
99        if (remainder % (base / 8) != 0) {
100         throw new IllegalArgumentException(
101             "Window size of magnitude 2^" + exponent + " must be multiple of " +
102             (base / 8));
103       }
104 
105       // mantissa is guaranteed to be between 0-7
106       final int mantissa = remainder / (base / 8);
107       final int encoded = ((exponent - MIN_WINDOW_LOG) << 3) | mantissa;
108 
109       outputBase[output] = (byte) encoded;
110       output++;
111     }
112 
113     switch (contentSizeDescriptor) {
114       case 0:
115         if (singleSegment) {
116           outputBase[output++] = (byte) inputSize;
117         }
118         break;
119       case 1:
120         putShort(outputBase, output, (short) (inputSize - 256));
121         output += SIZE_OF_SHORT;
122         break;
123       case 2:
124         putInt(outputBase, output, inputSize);
125         output += SIZE_OF_INT;
126         break;
127       default:
128         throw new AssertionError();
129     }
130 
131     return output - outputAddress;
132   }
133 
134   // visible for testing
135   static int writeChecksum(final byte[] outputBase, final int outputAddress,
136                            final int outputLimit, final byte[] inputBase,
137                            final int inputAddress, final int inputLimit) {
138     checkArgument(outputLimit - outputAddress >= SIZE_OF_INT,
139                   OUTPUT_BUFFER_TOO_SMALL);
140 
141     final int inputSize = inputLimit - inputAddress;
142 
143     final long hash = XxHash64.hash(0, inputBase, inputAddress, inputSize);
144 
145     putInt(outputBase, outputAddress, (int) hash);
146 
147     return SIZE_OF_INT;
148   }
149 
150   public static int compress(final byte[] inputBase, final int inputAddress,
151                              final int inputLimit, final byte[] outputBase,
152                              final int outputAddress, final int outputLimit,
153                              final int compressionLevel) {
154     final int inputSize = inputLimit - inputAddress;
155 
156     final CompressionParameters parameters =
157         CompressionParameters.compute(compressionLevel, inputSize);
158 
159     int output = outputAddress;
160 
161     output += writeMagic(outputBase, output, outputLimit);
162     output += writeFrameHeader(outputBase, output, outputLimit, inputSize,
163                                1 << parameters.getWindowLog());
164     output +=
165         compressFrame(inputBase, inputAddress, inputLimit, outputBase, output,
166                       outputLimit, parameters);
167     output +=
168         writeChecksum(outputBase, output, outputLimit, inputBase, inputAddress,
169                       inputLimit);
170 
171     return output - outputAddress;
172   }
173 
174   private static int compressFrame(final byte[] inputBase,
175                                    final int inputAddress, final int inputLimit,
176                                    final byte[] outputBase,
177                                    final int outputAddress,
178                                    final int outputLimit,
179                                    final CompressionParameters parameters) {
180     final int windowSize = 1 <<
181                            parameters.getWindowLog(); // TODO: store window size in parameters directly?
182     int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize);
183 
184     int outputSize = outputLimit - outputAddress;
185     int remaining = inputLimit - inputAddress;
186 
187     int output = outputAddress;
188     int input = inputAddress;
189 
190     final CompressionContext context =
191         new CompressionContext(parameters, inputAddress, remaining);
192 
193     do {
194       checkArgument(outputSize >= SIZE_OF_BLOCK_HEADER + MIN_BLOCK_SIZE,
195                     OUTPUT_BUFFER_TOO_SMALL);
196 
197       final int lastBlockFlag = blockSize >= remaining? 1 : 0;
198       blockSize = Math.min(blockSize, remaining);
199 
200       int compressedSize = 0;
201       if (remaining > 0) {
202         compressedSize = compressBlock(inputBase, input, blockSize, outputBase,
203                                        output + SIZE_OF_BLOCK_HEADER,
204                                        outputSize - SIZE_OF_BLOCK_HEADER,
205                                        context, parameters);
206       }
207 
208       if (compressedSize == 0) { // block is not compressible
209         checkArgument(blockSize + SIZE_OF_BLOCK_HEADER <= outputSize,
210                       "Output size too small");
211 
212         final int blockHeader =
213             lastBlockFlag | (RAW_BLOCK << 1) | (blockSize << 3);
214         put24BitLittleEndian(outputBase, output, blockHeader);
215         copyMemory(inputBase, input, outputBase, output + SIZE_OF_BLOCK_HEADER,
216                    blockSize);
217         compressedSize = SIZE_OF_BLOCK_HEADER + blockSize;
218       } else {
219         final int blockHeader =
220             lastBlockFlag | (COMPRESSED_BLOCK << 1) | (compressedSize << 3);
221         put24BitLittleEndian(outputBase, output, blockHeader);
222         compressedSize += SIZE_OF_BLOCK_HEADER;
223       }
224 
225       input += blockSize;
226       remaining -= blockSize;
227       output += compressedSize;
228       outputSize -= compressedSize;
229     } while (remaining > 0);
230 
231     return output - outputAddress;
232   }
233 
234   private static int compressBlock(final byte[] inputBase,
235                                    final int inputAddress, final int inputSize,
236                                    final byte[] outputBase,
237                                    final int outputAddress,
238                                    final int outputSize,
239                                    final CompressionContext context,
240                                    final CompressionParameters parameters) {
241     if (inputSize < MIN_BLOCK_SIZE + SIZE_OF_BLOCK_HEADER + 1) {
242       //  don't even attempt compression below a certain input size
243       return 0;
244     }
245 
246     context.blockCompressionState.enforceMaxDistance(inputAddress + inputSize,
247                                                      1 <<
248                                                      parameters.getWindowLog());
249     context.sequenceStore.reset();
250 
251     final int lastLiteralsSize = parameters.getStrategy().getCompressor()
252                                            .compressBlock(inputBase,
253                                                           inputAddress,
254                                                           inputSize,
255                                                           context.sequenceStore,
256                                                           context.blockCompressionState,
257                                                           context.offsets,
258                                                           parameters);
259 
260     final int lastLiteralsAddress = inputAddress + inputSize - lastLiteralsSize;
261 
262     // append [lastLiteralsAddress .. lastLiteralsSize] to sequenceStore literals buffer
263     context.sequenceStore.appendLiterals(inputBase, lastLiteralsAddress,
264                                          lastLiteralsSize);
265 
266     // convert length/offsets into codes
267     context.sequenceStore.generateCodes();
268 
269     final int outputLimit = outputAddress + outputSize;
270     int output = outputAddress;
271 
272     final int compressedLiteralsSize =
273         encodeLiterals(context.huffmanContext, parameters, outputBase, output,
274                        outputLimit - output,
275                        context.sequenceStore.literalsBuffer,
276                        context.sequenceStore.literalsLength);
277     output += compressedLiteralsSize;
278 
279     final int compressedSequencesSize =
280         SequenceEncoder.compressSequences(outputBase, output,
281                                           outputLimit - output,
282                                           context.sequenceStore,
283                                           parameters.getStrategy(),
284                                           context.sequenceEncodingContext);
285 
286     final int compressedSize = compressedLiteralsSize + compressedSequencesSize;
287     if (compressedSize == 0) {
288       // not compressible
289       return compressedSize;
290     }
291 
292     // Check compressibility
293     final int maxCompressedSize =
294         inputSize - calculateMinimumGain(inputSize, parameters.getStrategy());
295     if (compressedSize > maxCompressedSize) {
296       return 0; // not compressed
297     }
298 
299     // confirm repeated offsets and entropy tables
300     context.commit();
301 
302     return compressedSize;
303   }
304 
305   private static int encodeLiterals(final HuffmanCompressionContext context,
306                                     final CompressionParameters parameters,
307                                     final byte[] outputBase,
308                                     final int outputAddress,
309                                     final int outputSize, final byte[] literals,
310                                     final int literalsSize) {
311     // TODO: move this to Strategy
312     final boolean bypassCompression =
313         (parameters.getStrategy() == CompressionParameters.Strategy.FAST) &&
314         (parameters.getTargetLength() > 0);
315     if (bypassCompression || literalsSize <= MINIMUM_LITERALS_SIZE) {
316       return rawLiterals(outputBase, outputAddress, outputSize, literals,
317                          literalsSize);
318     }
319 
320     final int headerSize =
321         3 + (literalsSize >= 1024? 1 : 0) + (literalsSize >= 16384? 1 : 0);
322 
323     checkArgument(headerSize + 1 <= outputSize, OUTPUT_BUFFER_TOO_SMALL);
324 
325     final int[] counts = new int[MAX_SYMBOL_COUNT]; // TODO: preallocate
326     Histogram.count(literals, literalsSize, counts);
327     final int maxSymbol = Histogram.findMaxSymbol(counts, MAX_SYMBOL);
328     final int largestCount = Histogram.findLargestCount(counts, maxSymbol);
329 
330     final int literalsAddress = 0;
331     if (largestCount == literalsSize) {
332       // all bytes in input are equal
333       return rleLiterals(outputBase, outputAddress, literals, literalsSize);
334     } else if (largestCount <= (literalsSize >>> 7) + 4) {
335       // heuristic: probably not compressible enough
336       return rawLiterals(outputBase, outputAddress, outputSize, literals,
337                          literalsSize);
338     }
339 
340     final HuffmanCompressionTable previousTable = context.getPreviousTable();
341     final HuffmanCompressionTable table;
342     int serializedTableSize;
343     final boolean reuseTable;
344 
345     final boolean canReuse = previousTable.isValid(counts, maxSymbol);
346 
347     // heuristic: use existing table for small inputs if valid
348     // TODO: move to Strategy
349     final boolean preferReuse = parameters.getStrategy().ordinal() <
350                                 CompressionParameters.Strategy.LAZY.ordinal() &&
351                                 literalsSize <= 1024;
352     if (preferReuse && canReuse) {
353       table = previousTable;
354       reuseTable = true;
355       serializedTableSize = 0;
356     } else {
357       final HuffmanCompressionTable newTable = context.borrowTemporaryTable();
358 
359       newTable.initialize(counts, maxSymbol,
360                           HuffmanCompressionTable.optimalNumberOfBits(
361                               MAX_HUFFMAN_TABLE_LOG, literalsSize, maxSymbol),
362                           context.getCompressionTableWorkspace());
363 
364       serializedTableSize =
365           newTable.write(outputBase, outputAddress + headerSize,
366                          outputSize - headerSize,
367                          context.getTableWriterWorkspace());
368 
369       // Check if using previous huffman table is beneficial
370       if (canReuse && previousTable.estimateCompressedSize(counts, maxSymbol) <=
371                       serializedTableSize +
372                       newTable.estimateCompressedSize(counts, maxSymbol)) {
373         table = previousTable;
374         reuseTable = true;
375         serializedTableSize = 0;
376         context.discardTemporaryTable();
377       } else {
378         table = newTable;
379         reuseTable = false;
380       }
381     }
382 
383     final int compressedSize;
384     final boolean singleStream = literalsSize < 256;
385     if (singleStream) {
386       compressedSize = HuffmanCompressor.compressSingleStream(outputBase,
387                                                               outputAddress +
388                                                               headerSize +
389                                                               serializedTableSize,
390                                                               outputSize -
391                                                               headerSize -
392                                                               serializedTableSize,
393                                                               literals,
394                                                               literalsAddress,
395                                                               literalsSize,
396                                                               table);
397     } else {
398       compressedSize = HuffmanCompressor.compress4streams(outputBase,
399                                                           outputAddress +
400                                                           headerSize +
401                                                           serializedTableSize,
402                                                           outputSize -
403                                                           headerSize -
404                                                           serializedTableSize,
405                                                           literals,
406                                                           literalsAddress,
407                                                           literalsSize, table);
408     }
409 
410     final int totalSize = serializedTableSize + compressedSize;
411     final int minimumGain =
412         calculateMinimumGain(literalsSize, parameters.getStrategy());
413 
414     if (compressedSize == 0 || totalSize >= literalsSize - minimumGain) {
415       // incompressible or no savings
416 
417       // discard any temporary table we might have borrowed above
418       context.discardTemporaryTable();
419 
420       return rawLiterals(outputBase, outputAddress, outputSize, literals,
421                          literalsSize);
422     }
423 
424     final int encodingType =
425         reuseTable? TREELESS_LITERALS_BLOCK : COMPRESSED_LITERALS_BLOCK;
426 
427     // Build header
428     switch (headerSize) {
429       case 3: { // 2 - 2 - 10 - 10
430         final int header =
431             encodingType | ((singleStream? 0 : 1) << 2) | (literalsSize << 4) |
432             (totalSize << 14);
433         put24BitLittleEndian(outputBase, outputAddress, header);
434         break;
435       }
436       case 4: { // 2 - 2 - 14 - 14
437         final int header =
438             encodingType | (2 << 2) | (literalsSize << 4) | (totalSize << 18);
439         putInt(outputBase, outputAddress, header);
440         break;
441       }
442       case 5: { // 2 - 2 - 18 - 18
443         final int header =
444             encodingType | (3 << 2) | (literalsSize << 4) | (totalSize << 22);
445         putInt(outputBase, outputAddress, header);
446         outputBase[outputAddress + SIZE_OF_INT] = (byte) (totalSize >>> 10);
447         break;
448       }
449       default:  // not possible : headerSize is {3,4,5}
450         throw new IllegalStateException();
451     }
452 
453     return headerSize + totalSize;
454   }
455 
456   private static int rleLiterals(final byte[] outputBase,
457                                  final int outputAddress,
458                                  final byte[] inputBase, final int inputSize) {
459     final int headerSize =
460         1 + (inputSize > 31? 1 : 0) + (inputSize > 4095? 1 : 0);
461 
462     switch (headerSize) {
463       case 1: // 2 - 1 - 5
464         outputBase[outputAddress] =
465             (byte) (RLE_LITERALS_BLOCK | (inputSize << 3));
466         break;
467       case 2: // 2 - 2 - 12
468         putShort(outputBase, outputAddress,
469                  (short) (RLE_LITERALS_BLOCK | (1 << 2) | (inputSize << 4)));
470         break;
471       case 3: // 2 - 2 - 20
472         putInt(outputBase, outputAddress,
473                RLE_LITERALS_BLOCK | 3 << 2 | inputSize << 4);
474         break;
475       default:   // impossible. headerSize is {1,2,3}
476         throw new IllegalStateException();
477     }
478 
479     outputBase[outputAddress + headerSize] = inputBase[0];
480 
481     return headerSize + 1;
482   }
483 
484   private static int calculateMinimumGain(final int inputSize,
485                                           final CompressionParameters.Strategy strategy) {
486     // TODO: move this to Strategy to avoid hardcoding a specific strategy here
487     final int minLog =
488         strategy == CompressionParameters.Strategy.BTULTRA? 7 : 6;
489     return (inputSize >>> minLog) + 2;
490   }
491 
492   private static int rawLiterals(final byte[] outputBase,
493                                  final int outputAddress, final int outputSize,
494                                  final byte[] inputBase, final int inputSize) {
495     int headerSize = 1;
496     if (inputSize >= 32) {
497       headerSize++;
498     }
499     if (inputSize >= 4096) {
500       headerSize++;
501     }
502 
503     checkArgument(inputSize + headerSize <= outputSize,
504                   OUTPUT_BUFFER_TOO_SMALL);
505 
506     switch (headerSize) {
507       case 1:
508         outputBase[outputAddress] =
509             (byte) (RAW_LITERALS_BLOCK | (inputSize << 3));
510         break;
511       case 2:
512         putShort(outputBase, outputAddress,
513                  (short) (RAW_LITERALS_BLOCK | (1 << 2) | (inputSize << 4)));
514         break;
515       case 3:
516         put24BitLittleEndian(outputBase, outputAddress,
517                              RAW_LITERALS_BLOCK | (3 << 2) | (inputSize << 4));
518         break;
519       default:
520         throw new AssertionError();
521     }
522 
523     // TODO: ensure this test is correct
524     checkArgument(inputSize + 1 <= outputSize, OUTPUT_BUFFER_TOO_SMALL);
525 
526     copyMemory(inputBase, 0, outputBase, outputAddress + headerSize, inputSize);
527 
528     return headerSize + inputSize;
529   }
530 }