|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import itertools |
|
|
|
from absl.testing import parameterized |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc |
|
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils |
|
|
|
|
|
if tf.executing_eagerly(): |
|
tf.compat.v1.disable_eager_execution() |
|
|
|
|
|
class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest): |
|
|
|
def default_encoding_stage(self): |
|
"""See base class.""" |
|
return misc.SplitBySmallValueEncodingStage() |
|
|
|
def default_input(self): |
|
"""See base class.""" |
|
return tf.random.uniform([50], minval=-1.0, maxval=1.0) |
|
|
|
@property |
|
def is_lossless(self): |
|
"""See base class.""" |
|
return False |
|
|
|
def common_asserts_for_test_data(self, data): |
|
"""See base class.""" |
|
self._assert_is_integer( |
|
data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY]) |
|
|
|
def _assert_is_integer(self, indices): |
|
"""Asserts that indices values are integers.""" |
|
assert indices.dtype == np.int32 |
|
|
|
@parameterized.parameters([tf.float32, tf.float64]) |
|
def test_input_types(self, x_dtype): |
|
|
|
x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype) |
|
threshold = 0.05 |
|
stage = misc.SplitBySmallValueEncodingStage(threshold=threshold) |
|
encode_params, decode_params = stage.get_params() |
|
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params, |
|
decode_params) |
|
test_data = test_utils.TestData(x, encoded_x, decoded_x) |
|
test_data = self.evaluate_test_data(test_data) |
|
|
|
self._assert_is_integer(test_data.encoded_x[ |
|
misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY]) |
|
|
|
|
|
expected_encoded_values = np.array([1.0, 0.1], dtype=x.dtype.as_numpy_dtype) |
|
expected_encoded_indices = np.array([0, 1], dtype=np.int32) |
|
expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.], |
|
dtype=x_dtype.as_numpy_dtype) |
|
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], |
|
expected_encoded_values) |
|
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY], |
|
expected_encoded_indices) |
|
self.assertAllEqual(test_data.decoded_x, expected_decoded_x) |
|
|
|
def test_all_zero_input_works(self): |
|
|
|
|
|
stage = misc.SplitBySmallValueEncodingStage() |
|
test_data = self.run_one_to_many_encode_decode(stage, |
|
lambda: tf.zeros([50])) |
|
|
|
self.assertAllEqual(np.zeros((50)).astype(np.float32), test_data.decoded_x) |
|
|
|
def test_all_below_threshold_works(self): |
|
|
|
|
|
stage = misc.SplitBySmallValueEncodingStage(threshold=0.1) |
|
x = tf.random.uniform([50], minval=-0.01, maxval=0.01) |
|
encode_params, decode_params = stage.get_params() |
|
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params, |
|
decode_params) |
|
test_data = test_utils.TestData(x, encoded_x, decoded_x) |
|
test_data = self.evaluate_test_data(test_data) |
|
|
|
expected_encoded_indices = np.array([], dtype=np.int32).reshape([0]) |
|
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], []) |
|
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY], |
|
expected_encoded_indices) |
|
self.assertAllEqual(test_data.decoded_x, |
|
np.zeros([50], dtype=x.dtype.as_numpy_dtype)) |
|
|
|
|
|
class DifferenceBetweenIntegersEncodingStageTest( |
|
test_utils.BaseEncodingStageTest): |
|
|
|
def default_encoding_stage(self): |
|
"""See base class.""" |
|
return misc.DifferenceBetweenIntegersEncodingStage() |
|
|
|
def default_input(self): |
|
"""See base class.""" |
|
return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64) |
|
|
|
@property |
|
def is_lossless(self): |
|
"""See base class.""" |
|
return True |
|
|
|
def common_asserts_for_test_data(self, data): |
|
"""See base class.""" |
|
self.assertAllEqual(data.x, data.decoded_x) |
|
|
|
@parameterized.parameters( |
|
itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64])) |
|
def test_with_multiple_input_shapes(self, input_dims, dtype): |
|
|
|
def x_fn(): |
|
return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype) |
|
|
|
test_data = self.run_one_to_many_encode_decode( |
|
self.default_encoding_stage(), x_fn) |
|
self.common_asserts_for_test_data(test_data) |
|
|
|
def test_empty_input_static(self): |
|
|
|
x = [] |
|
x = tf.convert_to_tensor(x, dtype=tf.int32) |
|
assert x.shape.as_list() == [0] |
|
|
|
stage = self.default_encoding_stage() |
|
encode_params, decode_params = stage.get_params() |
|
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params, |
|
decode_params) |
|
|
|
test_data = self.evaluate_test_data( |
|
test_utils.TestData(x, encoded_x, decoded_x)) |
|
self.common_asserts_for_test_data(test_data) |
|
|
|
def test_empty_input_dynamic(self): |
|
|
|
|
|
y = tf.zeros((10,)) |
|
indices = tf.compat.v2.where(tf.abs(y) > 1e-8) |
|
x = tf.gather_nd(y, indices) |
|
x = tf.cast(x, tf.int32) |
|
assert x.shape.as_list() == [None] |
|
stage = self.default_encoding_stage() |
|
encode_params, decode_params = stage.get_params() |
|
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params, |
|
decode_params) |
|
|
|
test_data = self.evaluate_test_data( |
|
test_utils.TestData(x, encoded_x, decoded_x)) |
|
assert test_data.x.shape == (0,) |
|
assert test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape == (0,) |
|
assert test_data.decoded_x.shape == (0,) |
|
|
|
@parameterized.parameters([tf.bool, tf.float32]) |
|
def test_encode_unsupported_type_raises(self, dtype): |
|
stage = self.default_encoding_stage() |
|
with self.assertRaisesRegexp(TypeError, 'Unsupported input type'): |
|
self.run_one_to_many_encode_decode( |
|
stage, lambda: tf.cast(self.default_input(), dtype)) |
|
|
|
def test_encode_unsupported_input_shape_raises(self): |
|
x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32) |
|
stage = self.default_encoding_stage() |
|
params, _ = stage.get_params() |
|
with self.assertRaisesRegexp(ValueError, 'Number of dimensions must be 1'): |
|
stage.encode(x, params) |
|
|
|
|
|
if __name__ == '__main__': |
|
tf.test.main() |
|
|