1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package org.waarp.compress.zstdunsafe;
35
36
37
38
39
40
41
42
43 class BitInputStream {
44 private BitInputStream() {
45 }
46
47 public static boolean isEndOfStream(final long startAddress,
48 final long currentAddress,
49 final int bitsConsumed) {
50 return startAddress == currentAddress && bitsConsumed == Long.SIZE;
51 }
52
53 static long readTail(final Object inputBase, final long inputAddress,
54 final int inputSize) {
55 long bits = UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress) & 0xFF;
56
57 switch (inputSize) {
58 case 7:
59 bits |=
60 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 6) & 0xFFL) <<
61 48;
62 case 6:
63 bits |=
64 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 5) & 0xFFL) <<
65 40;
66 case 5:
67 bits |=
68 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 4) & 0xFFL) <<
69 32;
70 case 4:
71 bits |=
72 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 3) & 0xFFL) <<
73 24;
74 case 3:
75 bits |=
76 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 2) & 0xFFL) <<
77 16;
78 case 2:
79 bits |=
80 (UnsafeUtil.UNSAFE.getByte(inputBase, inputAddress + 1) & 0xFFL) <<
81 8;
82 }
83
84 return bits;
85 }
86
87
88
89
90 public static long peekBits(final int bitsConsumed, final long bitContainer,
91 final int numberOfBits) {
92 return (((bitContainer << bitsConsumed) >>> 1) >>> (63 - numberOfBits));
93 }
94
95
96
97
98
99
100 public static long peekBitsFast(final int bitsConsumed,
101 final long bitContainer,
102 final int numberOfBits) {
103 return ((bitContainer << bitsConsumed) >>> (64 - numberOfBits));
104 }
105
106 static class Initializer {
107 private final Object inputBase;
108 private final long startAddress;
109 private final long endAddress;
110 private long bits;
111 private long currentAddress;
112 private int bitsConsumed;
113
114 public Initializer(final Object inputBase, final long startAddress,
115 final long endAddress) {
116 this.inputBase = inputBase;
117 this.startAddress = startAddress;
118 this.endAddress = endAddress;
119 }
120
121 public long getBits() {
122 return bits;
123 }
124
125 public long getCurrentAddress() {
126 return currentAddress;
127 }
128
129 public int getBitsConsumed() {
130 return bitsConsumed;
131 }
132
133 public void initialize() {
134 Util.verify(endAddress - startAddress >= 1, startAddress,
135 "Bitstream is empty");
136
137 final int lastByte =
138 UnsafeUtil.UNSAFE.getByte(inputBase, endAddress - 1) & 0xFF;
139 Util.verify(lastByte != 0, endAddress, "Bitstream end mark not present");
140
141 bitsConsumed = Constants.SIZE_OF_LONG - Util.highestBit(lastByte);
142
143 final int inputSize = (int) (endAddress - startAddress);
144 if (inputSize >= Constants.SIZE_OF_LONG) {
145 currentAddress = endAddress - Constants.SIZE_OF_LONG;
146 bits = UnsafeUtil.UNSAFE.getLong(inputBase, currentAddress);
147 } else {
148 currentAddress = startAddress;
149 bits = readTail(inputBase, startAddress, inputSize);
150
151 bitsConsumed += (Constants.SIZE_OF_LONG - inputSize) * 8;
152 }
153 }
154 }
155
156 static final class Loader {
157 private final Object inputBase;
158 private final long startAddress;
159 private long bits;
160 private long currentAddress;
161 private int bitsConsumed;
162 private boolean overflow;
163
164 public Loader(final Object inputBase, final long startAddress,
165 final long currentAddress, final long bits,
166 final int bitsConsumed) {
167 this.inputBase = inputBase;
168 this.startAddress = startAddress;
169 this.bits = bits;
170 this.currentAddress = currentAddress;
171 this.bitsConsumed = bitsConsumed;
172 }
173
174 public long getBits() {
175 return bits;
176 }
177
178 public long getCurrentAddress() {
179 return currentAddress;
180 }
181
182 public int getBitsConsumed() {
183 return bitsConsumed;
184 }
185
186 public boolean isOverflow() {
187 return overflow;
188 }
189
190 public boolean load() {
191 if (bitsConsumed > 64) {
192 overflow = true;
193 return true;
194 } else if (currentAddress == startAddress) {
195 return true;
196 }
197
198 int bytes = bitsConsumed >>> 3;
199 if (currentAddress >= startAddress + Constants.SIZE_OF_LONG) {
200 if (bytes > 0) {
201 currentAddress -= bytes;
202 bits = UnsafeUtil.UNSAFE.getLong(inputBase, currentAddress);
203 }
204 bitsConsumed &= 0x7;
205 } else if (currentAddress - bytes < startAddress) {
206 bytes = (int) (currentAddress - startAddress);
207 currentAddress = startAddress;
208 bitsConsumed -= bytes * Constants.SIZE_OF_LONG;
209 bits = UnsafeUtil.UNSAFE.getLong(inputBase, startAddress);
210 return true;
211 } else {
212 currentAddress -= bytes;
213 bitsConsumed -= bytes * Constants.SIZE_OF_LONG;
214 bits = UnsafeUtil.UNSAFE.getLong(inputBase, currentAddress);
215 }
216
217 return false;
218 }
219 }
220 }