View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import static org.waarp.compress.zstdsafe.FiniteStateEntropy.*;
37  
38  class FseCompressionTable {
39    private final short[] nextState;
40    private final int[] deltaNumberOfBits;
41    private final int[] deltaFindState;
42  
43    private int log2Size;
44  
45    public FseCompressionTable(final int maxTableLog, final int maxSymbol) {
46      nextState = new short[1 << maxTableLog];
47      deltaNumberOfBits = new int[maxSymbol + 1];
48      deltaFindState = new int[maxSymbol + 1];
49    }
50  
51    public static FseCompressionTable newInstance(final short[] normalizedCounts,
52                                                  final int maxSymbol,
53                                                  final int tableLog) {
54      final FseCompressionTable result =
55          new FseCompressionTable(tableLog, maxSymbol);
56      result.initialize(normalizedCounts, maxSymbol, tableLog);
57  
58      return result;
59    }
60  
61    public void initializeRleTable(final int symbol) {
62      log2Size = 0;
63  
64      nextState[0] = 0;
65      nextState[1] = 0;
66  
67      deltaFindState[symbol] = 0;
68      deltaNumberOfBits[symbol] = 0;
69    }
70  
71    public void initialize(final short[] normalizedCounts, final int maxSymbol,
72                           final int tableLog) {
73      final int tableSize = 1 << tableLog;
74  
75      final byte[] table = new byte[tableSize]; // TODO: allocate in workspace
76      int highThreshold = tableSize - 1;
77  
78      // TODO: make sure FseCompressionTable has enough size
79      log2Size = tableLog;
80  
81      // For explanations on how to distribute symbol values over the table:
82      // http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html
83  
84      // symbol start positions
85      final int[] cumulative =
86          new int[MAX_SYMBOL + 2]; // TODO: allocate in workspace
87      cumulative[0] = 0;
88      for (int i = 1; i <= maxSymbol + 1; i++) {
89        if (normalizedCounts[i - 1] == -1) {  // Low probability symbol
90          cumulative[i] = cumulative[i - 1] + 1;
91          table[highThreshold--] = (byte) (i - 1);
92        } else {
93          cumulative[i] = cumulative[i - 1] + normalizedCounts[i - 1];
94        }
95      }
96      cumulative[maxSymbol + 1] = tableSize + 1;
97  
98      // Spread symbols
99      final int position =
100         spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold,
101                       table);
102 
103     if (position != 0) {
104       throw new AssertionError("Spread symbols failed");
105     }
106 
107     // Build table
108     for (int i = 0; i < tableSize; i++) {
109       final byte symbol = table[i];
110       nextState[cumulative[symbol]++] = (short) (tableSize +
111                                                  i);  /* TableU16 : sorted by symbol order; gives next state value */
112     }
113 
114     // Build symbol transformation table
115     int total = 0;
116     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
117       switch (normalizedCounts[symbol]) {
118         case 0:
119           deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize;
120           break;
121         case -1:
122         case 1:
123           deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize;
124           deltaFindState[symbol] = total - 1;
125           total++;
126           break;
127         default:
128           final int maxBitsOut =
129               tableLog - Util.highestBit(normalizedCounts[symbol] - 1);
130           final int minStatePlus = normalizedCounts[symbol] << maxBitsOut;
131           deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus;
132           deltaFindState[symbol] = total - normalizedCounts[symbol];
133           total += normalizedCounts[symbol];
134           break;
135       }
136     }
137   }
138 
139   public int begin(final byte symbol) {
140     final int outputBits = (deltaNumberOfBits[symbol] + (1 << 15)) >>> 16;
141     final int base =
142         ((outputBits << 16) - deltaNumberOfBits[symbol]) >>> outputBits;
143     return nextState[base + deltaFindState[symbol]];
144   }
145 
146   public int encode(final BitOutputStream stream, final int state,
147                     final int symbol) {
148     final int outputBits = (state + deltaNumberOfBits[symbol]) >>> 16;
149     stream.addBits(state, outputBits);
150     return nextState[(state >>> outputBits) + deltaFindState[symbol]];
151   }
152 
153   public void finish(final BitOutputStream stream, final int state) {
154     stream.addBits(state, log2Size);
155     stream.flush();
156   }
157 
158   private static int calculateStep(final int tableSize) {
159     return (tableSize >>> 1) + (tableSize >>> 3) + 3;
160   }
161 
162   public static int spreadSymbols(final short[] normalizedCounters,
163                                   final int maxSymbolValue, final int tableSize,
164                                   final int highThreshold,
165                                   final byte[] symbols) {
166     final int mask = tableSize - 1;
167     final int step = calculateStep(tableSize);
168 
169     int position = 0;
170     for (byte symbol = 0; symbol <= maxSymbolValue; symbol++) {
171       for (int i = 0; i < normalizedCounters[symbol]; i++) {
172         symbols[position] = symbol;
173         do {
174           position = (position + step) & mask;
175         } while (position > highThreshold);
176       }
177     }
178     return position;
179   }
180 }