View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdunsafe;
35  
36  class FseCompressionTable {
37    private final short[] nextState;
38    private final int[] deltaNumberOfBits;
39    private final int[] deltaFindState;
40  
41    private int log2Size;
42  
43    public FseCompressionTable(final int maxTableLog, final int maxSymbol) {
44      nextState = new short[1 << maxTableLog];
45      deltaNumberOfBits = new int[maxSymbol + 1];
46      deltaFindState = new int[maxSymbol + 1];
47    }
48  
49    public static FseCompressionTable newInstance(final short[] normalizedCounts,
50                                                  final int maxSymbol,
51                                                  final int tableLog) {
52      final FseCompressionTable result =
53          new FseCompressionTable(tableLog, maxSymbol);
54      result.initialize(normalizedCounts, maxSymbol, tableLog);
55  
56      return result;
57    }
58  
59    public void initializeRleTable(final int symbol) {
60      log2Size = 0;
61  
62      nextState[0] = 0;
63      nextState[1] = 0;
64  
65      deltaFindState[symbol] = 0;
66      deltaNumberOfBits[symbol] = 0;
67    }
68  
69    public void initialize(final short[] normalizedCounts, final int maxSymbol,
70                           final int tableLog) {
71      final int tableSize = 1 << tableLog;
72  
73      final byte[] table = new byte[tableSize]; // TODO: allocate in workspace
74      int highThreshold = tableSize - 1;
75  
76      // TODO: make sure FseCompressionTable has enough size
77      log2Size = tableLog;
78  
79      // For explanations on how to distribute symbol values over the table:
80      // http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html
81  
82      // symbol start positions
83      final int[] cumulative = new int[FiniteStateEntropy.MAX_SYMBOL +
84                                       2]; // TODO: allocate in workspace
85      cumulative[0] = 0;
86      for (int i = 1; i <= maxSymbol + 1; i++) {
87        if (normalizedCounts[i - 1] == -1) {  // Low probability symbol
88          cumulative[i] = cumulative[i - 1] + 1;
89          table[highThreshold--] = (byte) (i - 1);
90        } else {
91          cumulative[i] = cumulative[i - 1] + normalizedCounts[i - 1];
92        }
93      }
94      cumulative[maxSymbol + 1] = tableSize + 1;
95  
96      // Spread symbols
97      final int position =
98          spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold,
99                        table);
100 
101     if (position != 0) {
102       throw new AssertionError("Spread symbols failed");
103     }
104 
105     // Build table
106     for (int i = 0; i < tableSize; i++) {
107       final byte symbol = table[i];
108       nextState[cumulative[symbol]++] = (short) (tableSize +
109                                                  i);  /* TableU16 : sorted by symbol order; gives next state value */
110     }
111 
112     // Build symbol transformation table
113     int total = 0;
114     for (int symbol = 0; symbol <= maxSymbol; symbol++) {
115       switch (normalizedCounts[symbol]) {
116         case 0:
117           deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize;
118           break;
119         case -1:
120         case 1:
121           deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize;
122           deltaFindState[symbol] = total - 1;
123           total++;
124           break;
125         default:
126           final int maxBitsOut =
127               tableLog - Util.highestBit(normalizedCounts[symbol] - 1);
128           final int minStatePlus = normalizedCounts[symbol] << maxBitsOut;
129           deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus;
130           deltaFindState[symbol] = total - normalizedCounts[symbol];
131           total += normalizedCounts[symbol];
132           break;
133       }
134     }
135   }
136 
137   public int begin(final byte symbol) {
138     final int outputBits = (deltaNumberOfBits[symbol] + (1 << 15)) >>> 16;
139     final int base =
140         ((outputBits << 16) - deltaNumberOfBits[symbol]) >>> outputBits;
141     return nextState[base + deltaFindState[symbol]];
142   }
143 
144   public int encode(final BitOutputStream stream, final int state,
145                     final int symbol) {
146     final int outputBits = (state + deltaNumberOfBits[symbol]) >>> 16;
147     stream.addBits(state, outputBits);
148     return nextState[(state >>> outputBits) + deltaFindState[symbol]];
149   }
150 
151   public void finish(final BitOutputStream stream, final int state) {
152     stream.addBits(state, log2Size);
153     stream.flush();
154   }
155 
156   private static int calculateStep(final int tableSize) {
157     return (tableSize >>> 1) + (tableSize >>> 3) + 3;
158   }
159 
160   public static int spreadSymbols(final short[] normalizedCounters,
161                                   final int maxSymbolValue, final int tableSize,
162                                   final int highThreshold,
163                                   final byte[] symbols) {
164     final int mask = tableSize - 1;
165     final int step = calculateStep(tableSize);
166 
167     int position = 0;
168     for (byte symbol = 0; symbol <= maxSymbolValue; symbol++) {
169       for (int i = 0; i < normalizedCounters[symbol]; i++) {
170         symbols[position] = symbol;
171         do {
172           position = (position + step) & mask;
173         } while (position > highThreshold);
174       }
175     }
176     return position;
177   }
178 }