View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  import static org.waarp.compress.zstdunsafe.Constants.*;
37  import static org.waarp.compress.zstdunsafe.FiniteStateEntropy.*;
38  import static org.waarp.compress.zstdunsafe.UnsafeUtil.*;
39  import static org.waarp.compress.zstdunsafe.Util.*;
40  
41  class SequenceEncoder {
42    private static final int DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG = 6;
43    private static final short[] DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS = {
44        4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
45        3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1
46    };
47  
48    private static final int DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG = 6;
49    private static final short[] DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS = {
50        1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
51        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1,
52        -1, -1, -1, -1
53    };
54  
55    private static final int DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG = 5;
56    private static final short[] DEFAULT_OFFSET_NORMALIZED_COUNTS = {
57        1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
58        -1, -1, -1, -1, -1
59    };
60  
61    private static final FseCompressionTable DEFAULT_LITERAL_LENGTHS_TABLE =
62        FseCompressionTable.newInstance(DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS,
63                                        MAX_LITERALS_LENGTH_SYMBOL,
64                                        DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG);
65    private static final FseCompressionTable DEFAULT_MATCH_LENGTHS_TABLE =
66        FseCompressionTable.newInstance(DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS,
67                                        MAX_MATCH_LENGTH_SYMBOL,
68                                        DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG);
69    private static final FseCompressionTable DEFAULT_OFFSETS_TABLE =
70        FseCompressionTable.newInstance(DEFAULT_OFFSET_NORMALIZED_COUNTS,
71                                        DEFAULT_MAX_OFFSET_CODE_SYMBOL,
72                                        DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG);
73    public static final String NOT_YET_IMPLEMENTED = "not yet implemented";
74  
75    private SequenceEncoder() {
76    }
77  
78    public static int compressSequences(final Object outputBase,
79                                        final long outputAddress,
80                                        final int outputSize,
81                                        final SequenceStore sequences,
82                                        final CompressionParameters.Strategy strategy,
83                                        final SequenceEncodingContext workspace) {
84      long output = outputAddress;
85      final long outputLimit = outputAddress + outputSize;
86  
87      checkArgument(outputLimit - output >
88                    3 /* max sequence count Size */ + 1 /* encoding type flags */,
89                    "Output buffer too small");
90  
91      final int sequenceCount = sequences.sequenceCount;
92      if (sequenceCount < 0x7F) {
93        UNSAFE.putByte(outputBase, output, (byte) sequenceCount);
94        output++;
95      } else if (sequenceCount < LONG_NUMBER_OF_SEQUENCES) {
96        UNSAFE.putByte(outputBase, output, (byte) (sequenceCount >>> 8 | 0x80));
97        UNSAFE.putByte(outputBase, output + 1, (byte) sequenceCount);
98        output += SIZE_OF_SHORT;
99      } else {
100       UNSAFE.putByte(outputBase, output, (byte) 0xFF);
101       output++;
102       UNSAFE.putShort(outputBase, output,
103                       (short) (sequenceCount - LONG_NUMBER_OF_SEQUENCES));
104       output += SIZE_OF_SHORT;
105     }
106 
107     if (sequenceCount == 0) {
108       return (int) (output - outputAddress);
109     }
110 
111     // flags for FSE encoding type
112     final long headerAddress = output++;
113 
114     int maxSymbol;
115     int largestCount;
116 
117     // literal lengths
118     final int[] counts = workspace.counts;
119     Histogram.count(sequences.literalLengthCodes, sequenceCount,
120                     workspace.counts);
121     maxSymbol = Histogram.findMaxSymbol(counts, MAX_LITERALS_LENGTH_SYMBOL);
122     largestCount = Histogram.findLargestCount(counts, maxSymbol);
123 
124     final int literalsLengthEncodingType =
125         selectEncodingType(largestCount, sequenceCount,
126                            DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG, true,
127                            strategy);
128 
129     final FseCompressionTable literalLengthTable;
130     switch (literalsLengthEncodingType) {
131       case SEQUENCE_ENCODING_RLE:
132         UNSAFE.putByte(outputBase, output, sequences.literalLengthCodes[0]);
133         output++;
134         workspace.literalLengthTable.initializeRleTable(maxSymbol);
135         literalLengthTable = workspace.literalLengthTable;
136         break;
137       case SEQUENCE_ENCODING_BASIC:
138         literalLengthTable = DEFAULT_LITERAL_LENGTHS_TABLE;
139         break;
140       case SEQUENCE_ENCODING_COMPRESSED:
141         output +=
142             buildCompressionTable(workspace.literalLengthTable, outputBase,
143                                   output, outputLimit, sequenceCount,
144                                   LITERAL_LENGTH_TABLE_LOG,
145                                   sequences.literalLengthCodes,
146                                   workspace.counts, maxSymbol,
147                                   workspace.normalizedCounts);
148         literalLengthTable = workspace.literalLengthTable;
149         break;
150       default:
151         throw new UnsupportedOperationException(NOT_YET_IMPLEMENTED);
152     }
153 
154     // offsets
155     Histogram.count(sequences.offsetCodes, sequenceCount, workspace.counts);
156     maxSymbol = Histogram.findMaxSymbol(counts, MAX_OFFSET_CODE_SYMBOL);
157     largestCount = Histogram.findLargestCount(counts, maxSymbol);
158 
159     // We can only use the basic table if max <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, otherwise the offsets are too large .
160     final boolean defaultAllowed = maxSymbol < DEFAULT_MAX_OFFSET_CODE_SYMBOL;
161 
162     final int offsetEncodingType =
163         selectEncodingType(largestCount, sequenceCount,
164                            DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG, defaultAllowed,
165                            strategy);
166 
167     final FseCompressionTable offsetCodeTable;
168     switch (offsetEncodingType) {
169       case SEQUENCE_ENCODING_RLE:
170         UNSAFE.putByte(outputBase, output, sequences.offsetCodes[0]);
171         output++;
172         workspace.offsetCodeTable.initializeRleTable(maxSymbol);
173         offsetCodeTable = workspace.offsetCodeTable;
174         break;
175       case SEQUENCE_ENCODING_BASIC:
176         offsetCodeTable = DEFAULT_OFFSETS_TABLE;
177         break;
178       case SEQUENCE_ENCODING_COMPRESSED:
179         output +=
180             buildCompressionTable(workspace.offsetCodeTable, outputBase, output,
181                                   output + outputSize, sequenceCount,
182                                   OFFSET_TABLE_LOG, sequences.offsetCodes,
183                                   workspace.counts, maxSymbol,
184                                   workspace.normalizedCounts);
185         offsetCodeTable = workspace.offsetCodeTable;
186         break;
187       default:
188         throw new UnsupportedOperationException(NOT_YET_IMPLEMENTED);
189     }
190 
191     // match lengths
192     Histogram.count(sequences.matchLengthCodes, sequenceCount,
193                     workspace.counts);
194     maxSymbol = Histogram.findMaxSymbol(counts, MAX_MATCH_LENGTH_SYMBOL);
195     largestCount = Histogram.findLargestCount(counts, maxSymbol);
196 
197     final int matchLengthEncodingType =
198         selectEncodingType(largestCount, sequenceCount,
199                            DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG, true,
200                            strategy);
201 
202     final FseCompressionTable matchLengthTable;
203     switch (matchLengthEncodingType) {
204       case SEQUENCE_ENCODING_RLE:
205         UNSAFE.putByte(outputBase, output, sequences.matchLengthCodes[0]);
206         output++;
207         workspace.matchLengthTable.initializeRleTable(maxSymbol);
208         matchLengthTable = workspace.matchLengthTable;
209         break;
210       case SEQUENCE_ENCODING_BASIC:
211         matchLengthTable = DEFAULT_MATCH_LENGTHS_TABLE;
212         break;
213       case SEQUENCE_ENCODING_COMPRESSED:
214         output += buildCompressionTable(workspace.matchLengthTable, outputBase,
215                                         output, outputLimit, sequenceCount,
216                                         MATCH_LENGTH_TABLE_LOG,
217                                         sequences.matchLengthCodes,
218                                         workspace.counts, maxSymbol,
219                                         workspace.normalizedCounts);
220         matchLengthTable = workspace.matchLengthTable;
221         break;
222       default:
223         throw new UnsupportedOperationException(NOT_YET_IMPLEMENTED);
224     }
225 
226     // flags
227     UNSAFE.putByte(outputBase, headerAddress,
228                    (byte) ((literalsLengthEncodingType << 6) |
229                            (offsetEncodingType << 4) |
230                            (matchLengthEncodingType << 2)));
231 
232     output += encodeSequences(outputBase, output, outputLimit, matchLengthTable,
233                               offsetCodeTable, literalLengthTable, sequences);
234 
235     return (int) (output - outputAddress);
236   }
237 
238   private static int buildCompressionTable(final FseCompressionTable table,
239                                            final Object outputBase,
240                                            final long output,
241                                            final long outputLimit,
242                                            int sequenceCount,
243                                            final int maxTableLog,
244                                            final byte[] codes,
245                                            final int[] counts,
246                                            final int maxSymbol,
247                                            final short[] normalizedCounts) {
248     final int tableLog = optimalTableLog(maxTableLog, sequenceCount, maxSymbol);
249 
250     // this is a minor optimization. The last symbol is embedded in the initial FSE state, so it's not part of the bitstream. We can omit it from the
251     // statistics (but only if its count is > 1). This makes the statistics a tiny bit more accurate.
252     if (counts[codes[sequenceCount - 1]] > 1) {
253       counts[codes[sequenceCount - 1]]--;
254       sequenceCount--;
255     }
256 
257     FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts,
258                                        sequenceCount, maxSymbol);
259     table.initialize(normalizedCounts, maxSymbol, tableLog);
260 
261     return FiniteStateEntropy.writeNormalizedCounts(outputBase, output,
262                                                     (int) (outputLimit -
263                                                            output),
264                                                     normalizedCounts, maxSymbol,
265                                                     tableLog); // TODO: pass outputLimit directly
266   }
267 
268   private static int encodeSequences(final Object outputBase, final long output,
269                                      final long outputLimit,
270                                      final FseCompressionTable matchLengthTable,
271                                      final FseCompressionTable offsetsTable,
272                                      final FseCompressionTable literalLengthTable,
273                                      final SequenceStore sequences) {
274     final byte[] matchLengthCodes = sequences.matchLengthCodes;
275     final byte[] offsetCodes = sequences.offsetCodes;
276     final byte[] literalLengthCodes = sequences.literalLengthCodes;
277 
278     final BitOutputStream blockStream =
279         new BitOutputStream(outputBase, output, (int) (outputLimit - output));
280 
281     final int sequenceCount = sequences.sequenceCount;
282 
283     // first symbols
284     int matchLengthState =
285         matchLengthTable.begin(matchLengthCodes[sequenceCount - 1]);
286     int offsetState = offsetsTable.begin(offsetCodes[sequenceCount - 1]);
287     int literalLengthState =
288         literalLengthTable.begin(literalLengthCodes[sequenceCount - 1]);
289 
290     blockStream.addBits(sequences.literalLengths[sequenceCount - 1],
291                         LITERALS_LENGTH_BITS[literalLengthCodes[sequenceCount -
292                                                                 1]]);
293     blockStream.addBits(sequences.matchLengths[sequenceCount - 1],
294                         MATCH_LENGTH_BITS[matchLengthCodes[sequenceCount - 1]]);
295     blockStream.addBits(sequences.offsets[sequenceCount - 1],
296                         offsetCodes[sequenceCount - 1]);
297     blockStream.flush();
298 
299     if (sequenceCount >= 2) {
300       for (int n = sequenceCount - 2; n >= 0; n--) {
301         final byte literalLengthCode = literalLengthCodes[n];
302         final byte offsetCode = offsetCodes[n];
303         final byte matchLengthCode = matchLengthCodes[n];
304 
305         final int literalLengthBits = LITERALS_LENGTH_BITS[literalLengthCode];
306         final int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
307 
308         // (7)
309         offsetState =
310             offsetsTable.encode(blockStream, offsetState, offsetCode); // 15
311         matchLengthState =
312             matchLengthTable.encode(blockStream, matchLengthState,
313                                     matchLengthCode); // 24
314         literalLengthState =
315             literalLengthTable.encode(blockStream, literalLengthState,
316                                       literalLengthCode); // 33
317 
318         if (((int) offsetCode + matchLengthBits + literalLengthBits >= 64 - 7 -
319                                                                        (LITERAL_LENGTH_TABLE_LOG +
320                                                                         MATCH_LENGTH_TABLE_LOG +
321                                                                         OFFSET_TABLE_LOG))) {
322           blockStream.flush();                                /* (7)*/
323         }
324 
325         blockStream.addBits(sequences.literalLengths[n], literalLengthBits);
326         if (((literalLengthBits + matchLengthBits) > 24)) {
327           blockStream.flush();
328         }
329 
330         blockStream.addBits(sequences.matchLengths[n], matchLengthBits);
331         if (((int) offsetCode + matchLengthBits + literalLengthBits > 56)) {
332           blockStream.flush();
333         }
334 
335         blockStream.addBits(sequences.offsets[n], offsetCode); // 31
336         blockStream.flush(); // (7)
337       }
338     }
339 
340     matchLengthTable.finish(blockStream, matchLengthState);
341     offsetsTable.finish(blockStream, offsetState);
342     literalLengthTable.finish(blockStream, literalLengthState);
343 
344     final int streamSize = blockStream.close();
345     checkArgument(streamSize > 0, "Output buffer too small");
346 
347     return streamSize;
348   }
349 
350   private static int selectEncodingType(final int largestCount,
351                                         final int sequenceCount,
352                                         final int defaultNormalizedCountsLog,
353                                         final boolean isDefaultTableAllowed,
354                                         final CompressionParameters.Strategy strategy) {
355     if (largestCount == sequenceCount) { // => all entries are equal
356       if (isDefaultTableAllowed && sequenceCount <= 2) {
357         /* Prefer set_basic over set_rle when there are 2 or fewer symbols,
358          * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
359          * If basic encoding isn't possible, always choose RLE.
360          */
361         return SEQUENCE_ENCODING_BASIC;
362       }
363 
364       return SEQUENCE_ENCODING_RLE;
365     }
366 
367     if (strategy.ordinal() <
368         CompressionParameters.Strategy.LAZY.ordinal()) { // TODO: more robust check. Maybe encapsulate in strategy objects
369       if (isDefaultTableAllowed) {
370         final int factor =
371             10 - strategy.ordinal(); // TODO more robust. Move it to strategy
372         final int baseLog = 3;
373         final long minNumberOfSequences =
374             ((1L << defaultNormalizedCountsLog) * factor) >>
375             baseLog;  /* 28-36 for offset, 56-72 for lengths */
376 
377         if ((sequenceCount < minNumberOfSequences) || (largestCount <
378                                                        (sequenceCount >>
379                                                         (defaultNormalizedCountsLog -
380                                                          1)))) {
381           /* The format allows default tables to be repeated, but it isn't useful.
382            * When using simple heuristics to select encoding type, we don't want
383            * to confuse these tables with dictionaries. When running more careful
384            * analysis, we don't need to waste time checking both repeating tables
385            * and default tables.
386            */
387           return SEQUENCE_ENCODING_BASIC;
388         }
389       }
390     } else {
391       // TODO implement when other strategies are supported
392       throw new UnsupportedOperationException(NOT_YET_IMPLEMENTED);
393     }
394 
395     return SEQUENCE_ENCODING_COMPRESSED;
396   }
397 }