-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodule.py
61 lines (56 loc) · 2.39 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import tensorflow.compat.v1 as tf
import tensorflow_compression as tfc
class RoundingEntropyBottleneck(tfc.EntropyBottleneck):
    """Entropy bottleneck whose ``"noise"`` mode uses a rounding surrogate.

    Overrides ``_quantize`` of ``tfc.EntropyBottleneck`` so that, in
    ``"noise"`` mode, the output is produced by one of four quantization
    approximations selected with ``approx``:

    * ``"STE-Q"``: deterministic rounding relative to the medians, with a
      straight-through (identity) gradient w.r.t. ``inputs``.
    * ``"St-Q"``: stochastic rounding; rounds up with probability equal to
      the fractional part of ``inputs - medians``.
    * ``"SGA-Q"``: stochastic rounding whose round-up probability is a
      two-way softmax over ``-atanh(distance) / tau``, where ``distance`` is
      the distance to the lower / upper integer.
    * ``"U-Q"``: rounding with a shared uniform dither in ``[-0.5, 0.5)``
      that is added before rounding and subtracted afterwards.

    All four variants pass gradients straight through to ``inputs``.

    Args:
        init_scale: Forwarded unchanged to ``tfc.EntropyBottleneck``.
        filters: Forwarded unchanged to ``tfc.EntropyBottleneck``.
        data_format: Forwarded unchanged to ``tfc.EntropyBottleneck``.
        approx: Name of the approximation used in ``"noise"`` mode; one of
            ``"STE-Q"``, ``"St-Q"``, ``"SGA-Q"``, ``"U-Q"``.
        tau: Temperature dividing the ``atanh`` terms in ``"SGA-Q"``. Stored
            as ``self.tau``; must be non-zero. Defaults to the previously
            hard-coded value ``0.5``.
        **kwargs: Forwarded unchanged to ``tfc.EntropyBottleneck``.

    Raises:
        ValueError: If ``approx`` is not one of the supported names.
    """

    # Supported values for ``approx`` (kept ordered for a stable error message).
    _APPROXIMATIONS = ("STE-Q", "St-Q", "SGA-Q", "U-Q")

    def __init__(
        self,
        init_scale=10,
        filters=(3, 3, 3),
        data_format="channels_last",
        approx="STE-Q",
        tau=0.5,
        **kwargs
    ):
        # Explicit exception instead of ``assert`` so the check survives ``-O``.
        if approx not in self._APPROXIMATIONS:
            raise ValueError(
                "approx must be one of {}, got {!r}".format(
                    self._APPROXIMATIONS, approx
                )
            )
        super(RoundingEntropyBottleneck, self).__init__(
            init_scale=init_scale, filters=filters, data_format=data_format, **kwargs
        )
        self.approx = approx
        self.tau = tau

    def _quantize(self, inputs, mode):
        """Quantize ``inputs`` relative to the medians according to ``mode``.

        Args:
            inputs: Tensor to quantize. NOTE(review): must be compatible with
                ``self.dtype`` and the sliced medians; the ``"U-Q"`` branch
                indexes four axes, so it needs rank >= 4 there.
            mode: ``"noise"`` (apply the surrogate selected by
                ``self.approx``; presumably the training path of the parent
                class, confirm against ``tfc.EntropyBottleneck``),
                ``"dequantize"`` (rounded values with the medians added
                back), or ``"symbols"`` (rounded offsets as ``int32``).

        Returns:
            A tensor of dtype ``self.dtype`` for ``"noise"`` and
            ``"dequantize"``, or ``tf.int32`` for ``"symbols"``.

        Raises:
            NotImplementedError: In ``"noise"`` mode when ``self.approx`` has
                been set to an unsupported value after construction.
            AssertionError: If ``mode`` is none of the three known modes.
        """
        half = tf.constant(0.5, dtype=self.dtype)
        # ``_get_input_dims`` and ``_medians`` come from the parent class (not
        # visible here); the slices presumably align the medians with the
        # input layout so that they broadcast against ``inputs``.
        _, _, _, input_slices = self._get_input_dims()
        medians = self._medians[input_slices]

        # Round-to-nearest of (inputs - medians), written as floor(x + 0.5).
        outputs = tf.math.floor(inputs + (half - medians))
        outputs = tf.cast(outputs, self.dtype)

        if mode == "dequantize":
            return outputs + medians
        if mode != "noise":
            assert mode == "symbols", mode
            return tf.cast(outputs, tf.int32)

        # ---- mode == "noise": pick the rounding surrogate ----
        if self.approx == "STE-Q":
            # Forward value is the hard-rounded result; gradient is identity.
            return tf.stop_gradient(outputs + medians - inputs) + inputs

        if self.approx in ("St-Q", "SGA-Q"):
            lower = tf.floor(inputs - medians)
            # Fractional part: distance from the lower integer.
            frac = (inputs - medians) - lower
            if self.approx == "St-Q":
                prob_up = frac
            else:
                # Unnormalised weights: the closer a candidate integer is,
                # the smaller atanh(distance) and the larger its weight.
                weight_down = tf.exp(-tf.atanh(frac) / self.tau)
                weight_up = tf.exp(-tf.atanh(1 - frac) / self.tau)
                prob_up = weight_up / (weight_down + weight_up)
            # Sample and cast in the layer dtype (was hard-coded float32),
            # so non-float32 layers do not hit a dtype mismatch below.
            draws = tf.random.uniform(tf.shape(prob_up), dtype=self.dtype)
            round_up = tf.cast(prob_up >= draws, self.dtype)
            outputs = lower + round_up
            return tf.stop_gradient(outputs + medians - inputs) + inputs

        if self.approx == "U-Q":
            # One dither value per index of the first axis, broadcast over the
            # next three axes (shape (N, 1, 1, 1) for 4-D inputs). The full
            # draw followed by a slice is kept as in the original so the
            # behaviour is unchanged; only the dtype is made explicit.
            noise = tf.random.uniform(
                tf.shape(inputs), -half, half, dtype=self.dtype
            )[:, 0:1, 0:1, 0:1]
            # NOTE: this branch does not use the medians.
            outputs = tf.round(inputs + noise) - noise
            return tf.stop_gradient(outputs - inputs) + inputs

        raise NotImplementedError(
            "Unsupported approximation: {!r}".format(self.approx)
        )