ZstdFrameCompressor.java

/*
 * This file is part of Waarp Project (named also Waarp or GG).
 *
 *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
 *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
 * individual contributors.
 *
 *  All Waarp Project is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along with
 * Waarp . If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.waarp.compress.zstdunsafe;

import static org.waarp.compress.zstdunsafe.Constants.*;
import static org.waarp.compress.zstdunsafe.Huffman.*;
import static org.waarp.compress.zstdunsafe.UnsafeUtil.*;
import static org.waarp.compress.zstdunsafe.Util.*;
import static sun.misc.Unsafe.*;

class ZstdFrameCompressor {
  // Output space that writeFrameHeader requires before it writes anything.
  static final int MAX_FRAME_HEADER_SIZE = 14;

  // Bit that writeFrameHeader always sets in the frame header descriptor byte.
  private static final int CHECKSUM_FLAG = 0x4;
  // Bit set in the frame header descriptor when windowSize >= inputSize.
  private static final int SINGLE_SEGMENT_FLAG = 0x20;

  // Literal sections of this size or smaller are always emitted raw by
  // encodeLiterals.
  private static final int MINIMUM_LITERALS_SIZE = 63;

  // the maximum table log allowed for literal encoding per RFC 8478, section 4.2.1
  private static final int MAX_HUFFMAN_TABLE_LOG = 11;
  // Message passed to checkArgument when an output region is too small.
  public static final String OUTPUT_BUFFER_TOO_SMALL =
      "Output buffer too small";

  /**
   * Not instantiable: this class only exposes static methods.
   */
  private ZstdFrameCompressor() {
    // intentionally empty
  }

  // visible for testing
  /**
   * Writes {@code MAGIC_NUMBER} as an int at {@code outputAddress}.
   *
   * @param outputBase base object handed to UNSAFE together with the address
   * @param outputAddress position at which the magic number is written
   * @param outputLimit end of the writable region; at least
   *     {@code SIZE_OF_INT} bytes must be available before it
   *
   * @return the number of bytes written ({@code SIZE_OF_INT})
   */
  static int writeMagic(final Object outputBase, final long outputAddress,
                        final long outputLimit) {
    final long available = outputLimit - outputAddress;
    checkArgument(available >= SIZE_OF_INT, OUTPUT_BUFFER_TOO_SMALL);

    UNSAFE.putInt(outputBase, outputAddress, MAGIC_NUMBER);
    return SIZE_OF_INT;
  }

  // visible for testing
  /**
   * Writes the frame header starting at {@code outputAddress}: one descriptor
   * byte, then a window descriptor byte unless the frame is single-segment,
   * then the content size field.
   * <p>
   * The frame is single-segment when {@code windowSize >= inputSize}. The
   * content size is stored on 1 byte (only for single-segment frames), on
   * 2 bytes as {@code inputSize - 256}, or on 4 bytes, depending on
   * {@code inputSize}.
   *
   * @param outputBase base object handed to UNSAFE together with the address
   * @param outputAddress position at which the header starts
   * @param outputLimit end of the writable region; at least
   *     {@code MAX_FRAME_HEADER_SIZE} bytes must be available before it
   * @param inputSize size of the content described by the frame
   * @param windowSize window size to encode when the frame is not
   *     single-segment
   *
   * @return the number of bytes written
   *
   * @throws IllegalArgumentException if the window size has to be encoded and
   *     is below {@code 1 << MIN_WINDOW_LOG} or is not a multiple of one
   *     eighth of its highest power of two
   */
  static int writeFrameHeader(final Object outputBase, final long outputAddress,
                              final long outputLimit, final int inputSize,
                              final int windowSize) {
    final long available = outputLimit - outputAddress;
    checkArgument(available >= MAX_FRAME_HEADER_SIZE, OUTPUT_BUFFER_TOO_SMALL);

    // 0, 1 or 2: selects the width of the content size field
    int contentSizeDescriptor = 0;
    if (inputSize >= 256) {
      contentSizeDescriptor++;
    }
    if (inputSize >= 65536 + 256) {
      contentSizeDescriptor++;
    }

    final boolean singleSegment = windowSize >= inputSize;

    // dictionary ID missing
    int descriptor = (contentSizeDescriptor << 6) | CHECKSUM_FLAG;
    if (singleSegment) {
      descriptor |= SINGLE_SEGMENT_FLAG;
    }

    long cursor = outputAddress;
    UNSAFE.putByte(outputBase, cursor, (byte) descriptor);
    cursor++;

    if (!singleSegment) {
      final int base = Integer.highestOneBit(windowSize);
      final int exponent = 31 - Integer.numberOfLeadingZeros(base);
      if (exponent < MIN_WINDOW_LOG) {
        throw new IllegalArgumentException(
            "Minimum window size is " + (1 << MIN_WINDOW_LOG));
      }

      final int step = base / 8;
      final int remainder = windowSize - base;
      if (remainder % step != 0) {
        throw new IllegalArgumentException(
            "Window size of magnitude 2^" + exponent + " must be multiple of " +
            step);
      }

      // mantissa is guaranteed to be between 0-7
      final int mantissa = remainder / step;
      final int windowDescriptor =
          ((exponent - MIN_WINDOW_LOG) << 3) | mantissa;

      UNSAFE.putByte(outputBase, cursor, (byte) windowDescriptor);
      cursor++;
    }

    if (contentSizeDescriptor == 0) {
      // a one-byte content size is only present in single-segment frames
      if (singleSegment) {
        UNSAFE.putByte(outputBase, cursor, (byte) inputSize);
        cursor++;
      }
    } else if (contentSizeDescriptor == 1) {
      UNSAFE.putShort(outputBase, cursor, (short) (inputSize - 256));
      cursor += SIZE_OF_SHORT;
    } else if (contentSizeDescriptor == 2) {
      UNSAFE.putInt(outputBase, cursor, inputSize);
      cursor += SIZE_OF_INT;
    } else {
      throw new AssertionError();
    }

    return (int) (cursor - outputAddress);
  }

  // visible for testing
  /**
   * Hashes the input range with {@code XxHash64.hash} (first argument 0) and
   * writes the low 32 bits of the result as an int at {@code outputAddress}.
   *
   * @param outputBase base object handed to UNSAFE together with the address
   * @param outputAddress position at which the checksum is written
   * @param outputLimit end of the writable region; at least
   *     {@code SIZE_OF_INT} bytes must be available before it
   * @param inputBase base object of the data to hash
   * @param inputAddress start of the data to hash
   * @param inputLimit end of the data to hash
   *
   * @return the number of bytes written ({@code SIZE_OF_INT})
   */
  static int writeChecksum(final Object outputBase, final long outputAddress,
                           final long outputLimit, final Object inputBase,
                           final long inputAddress, final long inputLimit) {
    final long available = outputLimit - outputAddress;
    checkArgument(available >= SIZE_OF_INT, OUTPUT_BUFFER_TOO_SMALL);

    final int length = (int) (inputLimit - inputAddress);
    final long digest = XxHash64.hash(0, inputBase, inputAddress, length);

    // only the low 32 bits of the 64-bit hash are stored
    UNSAFE.putInt(outputBase, outputAddress, (int) digest);
    return SIZE_OF_INT;
  }

  /**
   * Compresses the input range into a complete frame: magic number, frame
   * header, block data and checksum, written back to back from
   * {@code outputAddress}.
   *
   * @param inputBase base object of the data to compress
   * @param inputAddress start of the data to compress
   * @param inputLimit end of the data to compress
   * @param outputBase base object of the destination
   * @param outputAddress position at which the frame starts
   * @param outputLimit end of the writable region
   * @param compressionLevel level passed to
   *     {@code CompressionParameters.compute}
   *
   * @return the total number of bytes written
   */
  public static int compress(final Object inputBase, final long inputAddress,
                             final long inputLimit, final Object outputBase,
                             final long outputAddress, final long outputLimit,
                             final int compressionLevel) {
    final int inputSize = (int) (inputLimit - inputAddress);
    final CompressionParameters parameters =
        CompressionParameters.compute(compressionLevel, inputSize);

    long cursor = outputAddress;

    cursor += writeMagic(outputBase, cursor, outputLimit);

    final int windowSize = 1 << parameters.getWindowLog();
    cursor += writeFrameHeader(outputBase, cursor, outputLimit, inputSize,
                               windowSize);

    cursor += compressFrame(inputBase, inputAddress, inputLimit, outputBase,
                            cursor, outputLimit, parameters);

    cursor += writeChecksum(outputBase, cursor, outputLimit, inputBase,
                            inputAddress, inputLimit);

    return (int) (cursor - outputAddress);
  }

  /**
   * Writes the input as a sequence of blocks of at most
   * {@code min(MAX_BLOCK_SIZE, 1 << windowLog)} bytes each.
   * <p>
   * Every block starts with a header written by
   * {@code put24BitLittleEndian}: last-block flag in bit 0, block type from
   * bit 1, size from bit 3. A block for which {@code compressBlock} returns 0
   * is copied verbatim as a raw block. At least one block is always written,
   * so empty input yields a single empty raw block flagged as last.
   *
   * @param inputBase base object of the data to compress
   * @param inputAddress start of the data to compress
   * @param inputLimit end of the data to compress
   * @param outputBase base object of the destination
   * @param outputAddress position at which the first block header is written
   * @param outputLimit end of the writable region
   * @param parameters compression parameters for this frame
   *
   * @return the number of bytes written
   */
  private static int compressFrame(final Object inputBase,
                                   final long inputAddress,
                                   final long inputLimit,
                                   final Object outputBase,
                                   final long outputAddress,
                                   final long outputLimit,
                                   final CompressionParameters parameters) {
    // TODO: store window size in parameters directly?
    final int windowSize = 1 << parameters.getWindowLog();
    int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize);

    int outputBudget = (int) (outputLimit - outputAddress);
    int pending = (int) (inputLimit - inputAddress);

    long writePos = outputAddress;
    long readPos = inputAddress;

    final CompressionContext context =
        new CompressionContext(parameters, inputAddress, pending);

    while (true) {
      checkArgument(outputBudget >= SIZE_OF_BLOCK_HEADER + MIN_BLOCK_SIZE,
                    OUTPUT_BUFFER_TOO_SMALL);

      // the block that consumes everything still pending is the last one
      final int lastBlockFlag;
      if (blockSize >= pending) {
        lastBlockFlag = 1;
        blockSize = pending;
      } else {
        lastBlockFlag = 0;
      }

      final int payloadSize;
      if (pending > 0) {
        payloadSize = compressBlock(inputBase, readPos, blockSize, outputBase,
                                    writePos + SIZE_OF_BLOCK_HEADER,
                                    outputBudget - SIZE_OF_BLOCK_HEADER,
                                    context, parameters);
      } else {
        payloadSize = 0;
      }

      final int written;
      if (payloadSize != 0) {
        final int blockHeader =
            lastBlockFlag | (COMPRESSED_BLOCK << 1) | (payloadSize << 3);
        put24BitLittleEndian(outputBase, writePos, blockHeader);
        written = SIZE_OF_BLOCK_HEADER + payloadSize;
      } else {
        // block is not compressible: store it as is
        checkArgument(blockSize + SIZE_OF_BLOCK_HEADER <= outputBudget,
                      "Output size too small");

        final int blockHeader =
            lastBlockFlag | (RAW_BLOCK << 1) | (blockSize << 3);
        put24BitLittleEndian(outputBase, writePos, blockHeader);
        UNSAFE.copyMemory(inputBase, readPos, outputBase,
                          writePos + SIZE_OF_BLOCK_HEADER, blockSize);
        written = SIZE_OF_BLOCK_HEADER + blockSize;
      }

      readPos += blockSize;
      pending -= blockSize;
      writePos += written;
      outputBudget -= written;

      if (pending <= 0) {
        break;
      }
    }

    return (int) (writePos - outputAddress);
  }

  /**
   * Tries to compress one block and writes its body (literals section
   * followed by sequences section) at {@code outputAddress}.
   * <p>
   * The context is committed only when the block is kept as compressed.
   *
   * @param inputBase base object of the block data
   * @param inputAddress start of the block data
   * @param inputSize number of bytes in the block
   * @param outputBase base object of the destination
   * @param outputAddress position at which the block body is written
   * @param outputSize number of bytes available at {@code outputAddress}
   * @param context per-frame compression state
   * @param parameters compression parameters for this frame
   *
   * @return the size of the compressed block body, or 0 when the block is
   *     too small to attempt compression or does not shrink by at least the
   *     minimum gain
   */
  private static int compressBlock(final Object inputBase,
                                   final long inputAddress, final int inputSize,
                                   final Object outputBase,
                                   final long outputAddress,
                                   final int outputSize,
                                   final CompressionContext context,
                                   final CompressionParameters parameters) {
    //  don't even attempt compression below a certain input size
    if (inputSize < MIN_BLOCK_SIZE + SIZE_OF_BLOCK_HEADER + 1) {
      return 0;
    }

    final long inputEnd = inputAddress + inputSize;

    context.blockCompressionState.enforceMaxDistance(
        inputEnd, 1 << parameters.getWindowLog());
    context.sequenceStore.reset();

    final int trailingLiterals =
        parameters.getStrategy().getCompressor().compressBlock(
            inputBase, inputAddress, inputSize, context.sequenceStore,
            context.blockCompressionState, context.offsets, parameters);

    // append the literals left after the last sequence to the sequenceStore
    // literals buffer
    context.sequenceStore.appendLiterals(inputBase,
                                         inputEnd - trailingLiterals,
                                         trailingLiterals);

    // convert length/offsets into codes
    context.sequenceStore.generateCodes();

    final int literalsBytes =
        encodeLiterals(context.huffmanContext, parameters, outputBase,
                       outputAddress, outputSize,
                       context.sequenceStore.literalsBuffer,
                       context.sequenceStore.literalsLength);

    final int sequencesBytes =
        SequenceEncoder.compressSequences(outputBase,
                                          outputAddress + literalsBytes,
                                          outputSize - literalsBytes,
                                          context.sequenceStore,
                                          parameters.getStrategy(),
                                          context.sequenceEncodingContext);

    final int compressedSize = literalsBytes + sequencesBytes;

    // nothing produced, or not enough gain over the raw block: not compressed
    if (compressedSize == 0 ||
        compressedSize > inputSize - calculateMinimumGain(inputSize,
                                                          parameters.getStrategy())) {
      return 0;
    }

    // confirm repeated offsets and entropy tables
    context.commit();

    return compressedSize;
  }

  /**
   * Writes the literals section for one block at {@code outputAddress},
   * choosing between raw, RLE and Huffman-compressed encodings.
   * <p>
   * Raw output is used when compression is bypassed (FAST strategy with a
   * positive target length), when there are at most
   * {@code MINIMUM_LITERALS_SIZE} literals, when the histogram looks too flat,
   * or when Huffman coding does not save at least the minimum gain. RLE is
   * used when all literals are the same byte. Otherwise the literals are
   * Huffman coded, either with the previous table (treeless block) or with a
   * newly built table that is serialized right after the header.
   *
   * @param context Huffman state shared across blocks
   * @param parameters compression parameters for this frame
   * @param outputBase base object of the destination
   * @param outputAddress position at which the section starts
   * @param outputSize number of bytes available at {@code outputAddress}
   * @param literals array holding the literals, starting at index 0
   * @param literalsSize number of literals
   *
   * @return the number of bytes written
   */
  private static int encodeLiterals(final HuffmanCompressionContext context,
                                    final CompressionParameters parameters,
                                    final Object outputBase,
                                    final long outputAddress,
                                    final int outputSize, final byte[] literals,
                                    final int literalsSize) {
    // TODO: move this to Strategy
    final boolean bypassCompression =
        parameters.getStrategy() == CompressionParameters.Strategy.FAST &&
        parameters.getTargetLength() > 0;
    if (bypassCompression || literalsSize <= MINIMUM_LITERALS_SIZE) {
      return rawLiterals(outputBase, outputAddress, outputSize, literals,
                         literalsSize);
    }

    // 3, 4 or 5 bytes depending on how many bits the sizes need
    int headerSize = 3;
    if (literalsSize >= 1024) {
      headerSize++;
    }
    if (literalsSize >= 16384) {
      headerSize++;
    }

    checkArgument(headerSize + 1 <= outputSize, OUTPUT_BUFFER_TOO_SMALL);

    final int[] histogram = new int[MAX_SYMBOL_COUNT]; // TODO: preallocate
    Histogram.count(literals, literalsSize, histogram);
    final int maxSymbol = Histogram.findMaxSymbol(histogram, MAX_SYMBOL);
    final int largestCount = Histogram.findLargestCount(histogram, maxSymbol);

    final long literalsAddress = ARRAY_BYTE_BASE_OFFSET;
    if (largestCount == literalsSize) {
      // all bytes in input are equal
      return rleLiterals(outputBase, outputAddress, literals, literalsSize);
    }
    if (largestCount <= (literalsSize >>> 7) + 4) {
      // heuristic: probably not compressible enough
      return rawLiterals(outputBase, outputAddress, outputSize, literals,
                         literalsSize);
    }

    final HuffmanCompressionTable previousTable = context.getPreviousTable();
    final boolean canReuse = previousTable.isValid(histogram, maxSymbol);

    // heuristic: use existing table for small inputs if valid
    // TODO: move to Strategy
    final boolean preferReuse =
        parameters.getStrategy()
                  .compareTo(CompressionParameters.Strategy.LAZY) < 0 &&
        literalsSize <= 1024;

    // start from "reuse the previous table"; switch to a new table below only
    // when building one turns out to be the better choice
    HuffmanCompressionTable table = previousTable;
    boolean reuseTable = true;
    int serializedTableSize = 0;

    if (!(preferReuse && canReuse)) {
      final HuffmanCompressionTable newTable = context.borrowTemporaryTable();

      final int tableLog = HuffmanCompressionTable.optimalNumberOfBits(
          MAX_HUFFMAN_TABLE_LOG, literalsSize, maxSymbol);
      newTable.initialize(histogram, maxSymbol, tableLog,
                          context.getCompressionTableWorkspace());

      final int newTableSize =
          newTable.write(outputBase, outputAddress + headerSize,
                         outputSize - headerSize,
                         context.getTableWriterWorkspace());

      // Check if using previous huffman table is beneficial
      final boolean previousIsCheaper =
          canReuse &&
          previousTable.estimateCompressedSize(histogram, maxSymbol) <=
          newTableSize + newTable.estimateCompressedSize(histogram, maxSymbol);

      if (previousIsCheaper) {
        context.discardTemporaryTable();
      } else {
        table = newTable;
        reuseTable = false;
        serializedTableSize = newTableSize;
      }
    }

    // the encoded streams go right after the header and the serialized table
    final long streamsAddress =
        outputAddress + headerSize + serializedTableSize;
    final int streamsCapacity = outputSize - headerSize - serializedTableSize;

    final boolean singleStream = literalsSize < 256;
    final int compressedSize;
    if (singleStream) {
      compressedSize =
          HuffmanCompressor.compressSingleStream(outputBase, streamsAddress,
                                                 streamsCapacity, literals,
                                                 literalsAddress, literalsSize,
                                                 table);
    } else {
      compressedSize =
          HuffmanCompressor.compress4streams(outputBase, streamsAddress,
                                             streamsCapacity, literals,
                                             literalsAddress, literalsSize,
                                             table);
    }

    final int totalSize = serializedTableSize + compressedSize;
    final int minimumGain =
        calculateMinimumGain(literalsSize, parameters.getStrategy());

    if (compressedSize == 0 || totalSize >= literalsSize - minimumGain) {
      // incompressible or no savings

      // discard any temporary table we might have borrowed above
      context.discardTemporaryTable();

      return rawLiterals(outputBase, outputAddress, outputSize, literals,
                         literalsSize);
    }

    final int encodingType;
    if (reuseTable) {
      encodingType = TREELESS_LITERALS_BLOCK;
    } else {
      encodingType = COMPRESSED_LITERALS_BLOCK;
    }

    // Build header
    switch (headerSize) {
      case 3: { // 2 - 2 - 10 - 10
        final int streamsFlag = singleStream? 0 : 1;
        put24BitLittleEndian(outputBase, outputAddress,
                             encodingType | (streamsFlag << 2) |
                             (literalsSize << 4) | (totalSize << 14));
        break;
      }
      case 4: { // 2 - 2 - 14 - 14
        UNSAFE.putInt(outputBase, outputAddress,
                      encodingType | (2 << 2) | (literalsSize << 4) |
                      (totalSize << 18));
        break;
      }
      case 5: { // 2 - 2 - 18 - 18
        UNSAFE.putInt(outputBase, outputAddress,
                      encodingType | (3 << 2) | (literalsSize << 4) |
                      (totalSize << 22));
        // the bits of totalSize that did not fit in the int go in a 5th byte
        UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_INT,
                       (byte) (totalSize >>> 10));
        break;
      }
      default:  // not possible : headerSize is {3,4,5}
        throw new IllegalStateException();
    }

    return headerSize + totalSize;
  }

  /**
   * Writes an RLE literals section: a 1, 2 or 3 byte header carrying
   * {@code inputSize}, followed by the single repeated byte.
   * <p>
   * The repeated byte is read from {@code inputBase} at
   * {@code ARRAY_BYTE_BASE_OFFSET}, i.e. the first element when
   * {@code inputBase} is a byte array. No output capacity check is done here.
   *
   * @param outputBase base object of the destination
   * @param outputAddress position at which the section starts
   * @param inputBase object holding the literals
   * @param inputSize number of literals represented
   *
   * @return the number of bytes of the section (header size + 1)
   */
  private static int rleLiterals(final Object outputBase,
                                 final long outputAddress,
                                 final Object inputBase, final int inputSize) {
    final int headerSize;
    if (inputSize > 4095) {
      // 2 - 2 - 20; putInt also touches a 4th byte, which is overwritten by
      // the repeated byte written below
      UNSAFE.putInt(outputBase, outputAddress,
                    RLE_LITERALS_BLOCK | (3 << 2) | (inputSize << 4));
      headerSize = 3;
    } else if (inputSize > 31) {
      // 2 - 2 - 12
      UNSAFE.putShort(outputBase, outputAddress,
                      (short) (RLE_LITERALS_BLOCK | (1 << 2) |
                               (inputSize << 4)));
      headerSize = 2;
    } else {
      // 2 - 1 - 5
      UNSAFE.putByte(outputBase, outputAddress,
                     (byte) (RLE_LITERALS_BLOCK | (inputSize << 3)));
      headerSize = 1;
    }

    final byte repeated =
        UNSAFE.getByte(inputBase, sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET);
    UNSAFE.putByte(outputBase, outputAddress + headerSize, repeated);

    return headerSize + 1;
  }

  /**
   * Computes how many bytes a compressed form must save to be kept:
   * {@code (inputSize >>> minLog) + 2}, where {@code minLog} is 7 for the
   * BTULTRA strategy and 6 for every other strategy.
   *
   * @param inputSize size of the uncompressed data
   * @param strategy compression strategy in use
   *
   * @return the minimum number of bytes that must be saved
   */
  private static int calculateMinimumGain(final int inputSize,
                                          final CompressionParameters.Strategy strategy) {
    // TODO: move this to Strategy to avoid hardcoding a specific strategy here
    final int minLog;
    if (strategy == CompressionParameters.Strategy.BTULTRA) {
      minLog = 7;
    } else {
      minLog = 6;
    }
    return 2 + (inputSize >>> minLog);
  }

  /**
   * Writes a raw (uncompressed) literals section: a 1, 2 or 3 byte header
   * carrying {@code inputSize}, followed by a verbatim copy of the literals.
   * <p>
   * The header is 1 byte for sizes below 32, 2 bytes for sizes below 4096 and
   * 3 bytes otherwise. The literals are read from {@code inputBase} starting
   * at {@code ARRAY_BYTE_BASE_OFFSET}, i.e. from the first element when
   * {@code inputBase} is a byte array.
   * <p>
   * The capacity check is performed once, before anything is written, so a
   * too-small output is rejected without leaving a partial header behind.
   *
   * @param outputBase base object of the destination
   * @param outputAddress position at which the section starts
   * @param outputSize number of bytes available at {@code outputAddress}
   * @param inputBase object holding the literals
   * @param inputSize number of literals to copy
   *
   * @return the number of bytes written (header size + {@code inputSize})
   */
  private static int rawLiterals(final Object outputBase,
                                 final long outputAddress, final int outputSize,
                                 final Object inputBase, final int inputSize) {
    int headerSize = 1;
    if (inputSize >= 32) {
      headerSize++;
    }
    if (inputSize >= 4096) {
      headerSize++;
    }

    // Widen to long so that a huge inputSize cannot wrap around and slip past
    // the check. This single test also covers the former post-write
    // "inputSize + 1 <= outputSize" check, since headerSize is at least 1.
    checkArgument((long) inputSize + headerSize <= outputSize,
                  OUTPUT_BUFFER_TOO_SMALL);

    switch (headerSize) {
      case 1:
        UNSAFE.putByte(outputBase, outputAddress,
                       (byte) (RAW_LITERALS_BLOCK | (inputSize << 3)));
        break;
      case 2:
        UNSAFE.putShort(outputBase, outputAddress,
                        (short) (RAW_LITERALS_BLOCK | (1 << 2) |
                                 (inputSize << 4)));
        break;
      case 3:
        put24BitLittleEndian(outputBase, outputAddress,
                             RAW_LITERALS_BLOCK | (3 << 2) | (inputSize << 4));
        break;
      default: // not possible: headerSize is {1,2,3}
        throw new AssertionError();
    }

    UNSAFE.copyMemory(inputBase, sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET,
                      outputBase, outputAddress + headerSize, inputSize);

    return headerSize + inputSize;
  }
}