View Javadoc
1   /*
2    * This file is part of Waarp Project (named also Waarp or GG).
3    *
4    *  Copyright (c) 2019, Waarp SAS, and individual contributors by the @author
5    *  tags. See the COPYRIGHT.txt in the distribution for a full listing of
6    * individual contributors.
7    *
8    *  All Waarp Project is free software: you can redistribute it and/or
9    * modify it under the terms of the GNU General Public License as published by
10   * the Free Software Foundation, either version 3 of the License, or (at your
11   * option) any later version.
12   *
13   * Waarp is distributed in the hope that it will be useful, but WITHOUT ANY
14   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15   * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
16   *
17   *  You should have received a copy of the GNU General Public License along with
18   * Waarp . If not, see <http://www.gnu.org/licenses/>.
19   */
20  
21  /*
22   * Licensed under the Apache License, Version 2.0 (the "License");
23   * you may not use this file except in compliance with the License.
24   * You may obtain a copy of the License at
25   *
26   *     http://www.apache.org/licenses/LICENSE-2.0
27   *
28   * Unless required by applicable law or agreed to in writing, software
29   * distributed under the License is distributed on an "AS IS" BASIS,
30   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31   * See the License for the specific language governing permissions and
32   * limitations under the License.
33   */
34  package org.waarp.compress.zstdsafe;
35  
36  import static org.waarp.compress.zstdsafe.FiniteStateEntropy.*;
37  import static org.waarp.compress.zstdsafe.UnsafeUtil.*;
38  import static org.waarp.compress.zstdsafe.Util.*;
39  
40  class FseTableReader {
41    private final short[] nextSymbol = new short[MAX_SYMBOL + 1];
42    private final short[] normalizedCounters = new short[MAX_SYMBOL + 1];
43  
44    public int readFseTable(final FiniteStateEntropy.Table table,
45                            final byte[] inputBase, final int inputAddress,
46                            final int inputLimit, int maxSymbol,
47                            final int maxTableLog) {
48      // read table headers
49      int input = inputAddress;
50      verify(inputLimit - inputAddress >= 4, input, "Not enough input bytes");
51  
52      int threshold;
53      int symbolNumber = 0;
54      boolean previousIsZero = false;
55  
56      int bitStream = getInt(inputBase, input);
57  
58      final int tableLog = (bitStream & 0xF) + MIN_TABLE_LOG;
59  
60      int numberOfBits = tableLog + 1;
61      bitStream >>>= 4;
62      int bitCount = 4;
63  
64      verify(tableLog <= maxTableLog, input,
65             "FSE table size exceeds maximum allowed size");
66  
67      int remaining = (1 << tableLog) + 1;
68      threshold = 1 << tableLog;
69  
70      while (remaining > 1 && symbolNumber <= maxSymbol) {
71        if (previousIsZero) {
72          int n0 = symbolNumber;
73          while ((bitStream & 0xFFFF) == 0xFFFF) {
74            n0 += 24;
75            if (input < inputLimit - 5) {
76              input += 2;
77              bitStream = (getInt(inputBase, input) >>> bitCount);
78            } else {
79              // end of bit stream
80              bitStream >>>= 16;
81              bitCount += 16;
82            }
83          }
84          while ((bitStream & 3) == 3) {
85            n0 += 3;
86            bitStream >>>= 2;
87            bitCount += 2;
88          }
89          n0 += bitStream & 3;
90          bitCount += 2;
91  
92          verify(n0 <= maxSymbol, input, "Symbol larger than max value");
93  
94          while (symbolNumber < n0) {
95            normalizedCounters[symbolNumber++] = 0;
96          }
97          if ((input <= inputLimit - 7) ||
98              (input + (bitCount >>> 3) <= inputLimit - 4)) {
99            input += bitCount >>> 3;
100           bitCount &= 7;
101           bitStream = getInt(inputBase, input) >>> bitCount;
102         } else {
103           bitStream >>>= 2;
104         }
105       }
106 
107       final short max = (short) ((2 * threshold - 1) - remaining);
108       short count;
109 
110       if ((bitStream & (threshold - 1)) < max) {
111         count = (short) (bitStream & (threshold - 1));
112         bitCount += numberOfBits - 1;
113       } else {
114         count = (short) (bitStream & (2 * threshold - 1));
115         if (count >= threshold) {
116           count -= max;
117         }
118         bitCount += numberOfBits;
119       }
120       count--;  // extra accuracy
121 
122       remaining -= Math.abs(count);
123       normalizedCounters[symbolNumber++] = count;
124       previousIsZero = count == 0;
125       while (remaining < threshold) {
126         numberOfBits--;
127         threshold >>>= 1;
128       }
129 
130       if ((input <= inputLimit - 7) ||
131           (input + (bitCount >> 3) <= inputLimit - 4)) {
132         input += bitCount >>> 3;
133         bitCount &= 7;
134       } else {
135         bitCount -= 8 * (inputLimit - 4 - input);
136         input = inputLimit - 4;
137       }
138       bitStream = getInt(inputBase, input) >>> (bitCount & 31);
139     }
140 
141     verify(remaining == 1 && bitCount <= 32, input, "Input is corrupted");
142 
143     maxSymbol = symbolNumber - 1;
144     verify(maxSymbol <= MAX_SYMBOL, input,
145            "Max symbol value too large (too many symbols for FSE)");
146 
147     input += (bitCount + 7) >> 3;
148 
149     // populate decoding table
150     final int symbolCount = maxSymbol + 1;
151     final int tableSize = 1 << tableLog;
152     int highThreshold = tableSize - 1;
153 
154     table.log2Size = tableLog;
155 
156     for (byte symbol = 0; symbol < symbolCount; symbol++) {
157       if (normalizedCounters[symbol] == -1) {
158         table.symbol[highThreshold--] = symbol;
159         nextSymbol[symbol] = 1;
160       } else {
161         nextSymbol[symbol] = normalizedCounters[symbol];
162       }
163     }
164 
165     final int position =
166         FseCompressionTable.spreadSymbols(normalizedCounters, maxSymbol,
167                                           tableSize, highThreshold,
168                                           table.symbol);
169 
170     // position must reach all cells once, otherwise normalizedCounter is incorrect
171     verify(position == 0, input, "Input is corrupted");
172 
173     for (int i = 0; i < tableSize; i++) {
174       final byte symbol = table.symbol[i];
175       final short nextState = nextSymbol[symbol]++;
176       table.numberOfBits[i] = (byte) (tableLog - highestBit(nextState));
177       table.newState[i] =
178           (short) ((nextState << table.numberOfBits[i]) - tableSize);
179     }
180 
181     return input - inputAddress;
182   }
183 
184   public static void initializeRleTable(final FiniteStateEntropy.Table table,
185                                         final byte value) {
186     table.log2Size = 0;
187     table.symbol[0] = value;
188     table.newState[0] = 0;
189     table.numberOfBits[0] = 0;
190   }
191 }