|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Integration tests for CQAT (cluster-preserving quantization-aware training) and PCQAT (sparsity- and cluster-preserving QAT)."""
|
from absl.testing import parameterized |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.clustering.keras import cluster |
|
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config |
|
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster |
|
from tensorflow_model_optimization.python.core.keras.compat import keras |
|
from tensorflow_model_optimization.python.core.quantization.keras import quantize |
|
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import ( |
|
default_8bit_cluster_preserve_quantize_scheme,) |
|
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import ( |
|
strip_clustering_cqat,) |
|
|
|
|
|
# Convenience alias so test models can be built as `layers.Dense(...)` etc.
layers = keras.layers
|
|
|
|
|
class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for cluster-preserving QAT (CQAT) and PCQAT.

  Each test builds a small model, clusters (and optionally prunes) it,
  applies the cluster-preserving quantization scheme, trains briefly, and
  then checks that the number of unique weights (clusters) and/or the
  sparsity level survived quantization-aware training.
  """

  def setUp(self):
    """Sets default clustering parameters shared by several tests."""
    super(ClusterPreserveIntegrationTest, self).setUp()
    self.cluster_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR
    }

  def compile_and_fit(self, model):
    """Compiles and fits `model` on one small random batch (one epoch)."""
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.fit(
        np.random.rand(20, 10),
        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
        batch_size=20,
    )

  def _get_number_of_unique_weights(self, stripped_model, layer_nr,
                                    weight_name):
    """Returns the count of unique values in the named weight of a layer.

    Args:
      stripped_model: model after clustering/CQAT wrappers were stripped.
      layer_nr: index of the layer to inspect in `stripped_model.layers`.
      weight_name: substring ('kernel', 'bias', ...) identifying the weight.
    """
    layer = stripped_model.layers[layer_nr]
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      # Quantized layers keep their weights in trainable_weights; pick the
      # last one whose name matches (assumes exactly one match in practice).
      for weight_item in layer.trainable_weights:
        if weight_name in weight_item.name:
          weight = weight_item
    else:
      weight = getattr(layer, weight_name)
    weights_as_list = weight.numpy().flatten()
    nr_of_unique_weights = len(set(weights_as_list))
    return nr_of_unique_weights

  def _get_sparsity(self, model):
    """Returns a list with the fraction of zero weights of each kernel."""
    sparsity_list = []
    for layer in model.layers:
      for weights in layer.trainable_weights:
        if 'kernel' in weights.name:
          np_weights = keras.backend.get_value(weights)
          sparsity = 1.0 - np.count_nonzero(np_weights) / float(
              np_weights.size)
          sparsity_list.append(sparsity)

    return sparsity_list

  def _get_clustered_model(self, preserve_sparsity):
    """Cluster the (sparse) model and return clustered_model."""
    tf.random.set_seed(1)
    original_model = keras.Sequential([
        layers.Dense(5, activation='softmax', input_shape=(10,)),
        layers.Flatten(),
    ])

    # Manually introduce sparsity: zero out the first two rows of the Dense
    # kernel so sparsity-preserving clustering has something to preserve.
    if preserve_sparsity:
      first_layer_weights = original_model.layers[0].get_weights()
      first_layer_weights[0][:][0:2] = 0.0
      original_model.layers[0].set_weights(first_layer_weights)

    # Sparsity-aware clustering (preserve_sparsity is always on here; the
    # `preserve_sparsity` argument only controls whether zeros were injected).
    clustering_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
        'preserve_sparsity': True
    }

    clustered_model = experimental_cluster.cluster_weights(
        original_model, **clustering_params)

    return clustered_model

  def _get_conv_model(self,
                      nr_of_channels,
                      data_format=None,
                      kernel_size=(3, 3)):
    """Returns functional model with Conv2D layer."""
    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
    # Add the channel axis in the position matching `data_format`.
    shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
    x = keras.layers.Reshape(shape)(inp)
    x = keras.layers.Conv2D(
        filters=nr_of_channels,
        kernel_size=kernel_size,
        data_format=data_format,
        activation='relu',
    )(x)
    x = keras.layers.MaxPool2D(2, 2)(x)
    out = keras.layers.Flatten()(x)
    model = keras.Model(inputs=inp, outputs=out)
    return model

  def _compile_and_fit_conv_model(self, model, nr_epochs=1):
    """Compile and fit conv model from _get_conv_model."""
    x_train = np.random.uniform(size=(500, 32, 32))
    y_train = np.random.randint(low=0, high=1024, size=(500,))
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )

    model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1)

    return model

  def _get_conv_clustered_model(self,
                                nr_of_channels,
                                nr_of_clusters,
                                data_format,
                                preserve_sparsity,
                                kernel_size=(3, 3)):
    """Returns clustered per channel model with Conv2D layer."""
    tf.random.set_seed(42)
    model = self._get_conv_model(nr_of_channels, data_format, kernel_size)

    if preserve_sparsity:
      # The Conv2D layer is expected at index 2 (Input, Reshape, Conv2D).
      assert model.layers[2].name == 'conv2d'

      conv_layer_weights = model.layers[2].get_weights()
      shape = conv_layer_weights[0].shape
      conv_layer_weights_flatten = conv_layer_weights[0].flatten()

      # Zero out just over half of the kernel to create >50% sparsity.
      nr_elems = len(conv_layer_weights_flatten)
      conv_layer_weights_flatten[0:1 + nr_elems // 2] = 0.0
      pruned_conv_layer_weights = tf.reshape(conv_layer_weights_flatten, shape)
      conv_layer_weights[0] = pruned_conv_layer_weights
      model.layers[2].set_weights(conv_layer_weights)

    clustering_params = {
        'number_of_clusters':
            nr_of_clusters,
        'cluster_centroids_init':
            cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
        'cluster_per_channel':
            True,
        'preserve_sparsity':
            preserve_sparsity
    }

    clustered_model = experimental_cluster.cluster_weights(model,
                                                           **clustering_params)
    clustered_model = self._compile_and_fit_conv_model(clustered_model)

    return clustered_model

  def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
    """PCQAT training on the input model.

    Args:
      preserve_sparsity: whether the quantize scheme should also preserve
        sparsity (PCQAT) in addition to clusters (CQAT).
      quant_aware_annotate_model: model already annotated for quantization.

    Returns:
      Tuple of (per-kernel sparsity list, number of unique kernel weights)
      measured on the stripped PCQAT model.
    """
    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))

    self.compile_and_fit(quant_aware_model)

    stripped_pcqat_model = strip_clustering_cqat(quant_aware_model)

    # Check the unique weights of the Dense layer (index 1: InputLayer is 0).
    num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
        stripped_pcqat_model, 1, 'kernel')

    sparsity_pcqat = self._get_sparsity(stripped_pcqat_model)

    return sparsity_pcqat, num_of_unique_weights_pcqat

  def testEndToEndClusterPreserve(self):
    """Runs CQAT end to end and whole model is quantized."""
    original_model = keras.Sequential(
        [layers.Dense(5, activation='softmax', input_shape=(10,))]
    )
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model))

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)
    stripped_cqat_model = strip_clustering_cqat(quant_aware_model)

    # CQAT should not change the number of clusters: the stripped CQAT model
    # (layer 1, after the added quantize layer) must match the clustered one.
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 1, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndClusterPreservePerLayer(self):
    """Runs CQAT end to end and model is quantized per layers."""
    original_model = keras.Sequential([
        layers.Dense(5, activation='relu', input_shape=(10,)),
        layers.Dense(5, activation='softmax', input_shape=(10,)),
    ])
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 1, 'kernel')

    def apply_quantization_to_dense(layer):
      # Annotate every Dense layer for quantization; leave others untouched.
      if isinstance(layer, keras.layers.Dense):
        return quantize.quantize_annotate_layer(layer)
      return layer

    quant_aware_annotate_model = keras.models.clone_model(
        clustered_model,
        clone_function=apply_quantization_to_dense,
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)
    stripped_cqat_model = strip_clustering_cqat(
        quant_aware_model)

    # The second Dense layer sits at index 2 after quantization wrapping;
    # its cluster count must be preserved by CQAT.
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 2, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndClusterPreserveOneLayer(self):
    """Runs CQAT end to end and model is quantized only for a single layer."""
    original_model = keras.Sequential([
        layers.Dense(5, activation='relu', input_shape=(10,)),
        layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
    ])
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 1, 'kernel')

    def apply_quantization_to_dense(layer):
      # Only the Dense layer explicitly named 'qat' is annotated.
      if isinstance(layer, keras.layers.Dense):
        if layer.name == 'qat':
          return quantize.quantize_annotate_layer(layer)
      return layer

    quant_aware_annotate_model = keras.models.clone_model(
        clustered_model,
        clone_function=apply_quantization_to_dense,
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)

    stripped_cqat_model = strip_clustering_cqat(
        quant_aware_model)

    # The quantized 'qat' layer is at index 1; its cluster count must be
    # preserved by CQAT.
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 1, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndPruneClusterPreserveQAT(self):
    """Runs PCQAT end to end when we quantize the whole model."""
    preserve_sparsity = True
    clustered_model = self._get_clustered_model(preserve_sparsity)
    # Keep a handle on the Dense layer's weights before fine-tuning.
    first_layer_weights = clustered_model.layers[0].weights[1]
    stripped_model_before_tuning = cluster.strip_clustering(
        clustered_model)
    nr_of_unique_weights_before = self._get_number_of_unique_weights(
        stripped_model_before_tuning, 0, 'kernel')

    self.compile_and_fit(clustered_model)

    stripped_model_clustered = cluster.strip_clustering(clustered_model)
    weights_after_tuning = stripped_model_clustered.layers[0].kernel
    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        stripped_model_clustered, 0, 'kernel')

    # Fine-tuning the clustered model must not change the number of
    # unique weights (clusters).
    self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)

    # The zeroed rows injected by _get_clustered_model must survive
    # sparsity-aware clustering and fine-tuning unchanged.
    self.assertTrue(
        np.array_equal(first_layer_weights[:][0:2],
                       weights_after_tuning[:][0:2]))

    # Baseline sparsity going into PCQAT.
    sparsity_pruning = self._get_sparsity(stripped_model_clustered)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model_clustered)
    )

    # PCQAT: with preserve_sparsity on, sparsity must not decrease and
    # the cluster count must be preserved.
    preserve_sparsity = True
    sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
        preserve_sparsity, quant_aware_annotate_model)
    self.assertAllGreaterEqual(np.array(sparsity_pcqat),
                               sparsity_pruning[0])
    self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)

  def testEndToEndClusterPreserveQATClusteredPerChannel(
      self, data_format='channels_last'):
    """Runs CQAT end to end for the model that is clustered per channel."""

    nr_of_channels = 12
    nr_of_clusters = 4

    clustered_model = self._get_conv_clustered_model(
        nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=False)
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer is expected at index 2 (Input, Reshape, Conv2D).
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    # Sentinel so the value below is clearly unset until the kernel is found.
    nr_unique_weights = -1

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        nr_unique_weights = len(np.unique(weight.numpy()))
        # Per-channel clustering: at most nr_of_clusters per channel.
        self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    # Train longer (3 epochs) to give quantization a chance to drift weights.
    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # CQAT may merge but never split clusters, so the total number of
    # unique weights cannot grow. Conv layer is at index 3 after the
    # quantize layer is inserted.
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)

    # Verify per-channel structure is intact: every output channel holds
    # exactly nr_of_clusters distinct values.
    layer = stripped_cqat_model.layers[layer_nr]
    weight_to_check = None
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      for weight_item in layer.trainable_weights:
        if 'kernel' in weight_item.name:
          weight_to_check = weight_item

    assert weight_to_check is not None

    for i in range(nr_of_channels):
      nr_unique_weights_per_channel = len(
          np.unique(weight_to_check[:, :, :, i]))
      assert nr_unique_weights_per_channel == nr_of_clusters

  def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
    """Runs PCQAT end to end for the model that is clustered per channel."""

    nr_of_channels = 12
    nr_of_clusters = 4

    clustered_model = self._get_conv_clustered_model(
        nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True)
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer is expected at index 2 (Input, Reshape, Conv2D).
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    # Sentinel so the value below is clearly unset until the kernel is found.
    nr_unique_weights = -1

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        nr_unique_weights = len(np.unique(weight.numpy()))
        # Per-channel clustering: at most nr_of_clusters per channel.
        self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)

    # The model was pruned to >50% sparsity in _get_conv_clustered_model;
    # record it as the baseline PCQAT must not fall below.
    control_sparsity = self._get_sparsity(stripped_model)
    self.assertGreater(control_sparsity[0], 0.5)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # PCQAT may merge but never split clusters. Conv layer is at index 3
    # after the quantize layer is inserted.
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)

    # Verify per-channel structure is intact: every output channel holds
    # exactly nr_of_clusters distinct values.
    layer = stripped_cqat_model.layers[layer_nr]
    weight_to_check = None
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      for weight_item in layer.trainable_weights:
        if 'kernel' in weight_item.name:
          weight_to_check = weight_item

    assert weight_to_check is not None

    for i in range(nr_of_channels):
      nr_unique_weights_per_channel = len(
          np.unique(weight_to_check[:, :, :, i]))
      assert nr_unique_weights_per_channel == nr_of_clusters

    # Sparsity check: cqat_sparsity is expected to stay at the control
    # level (assertLessEqual guards against it exceeding the baseline).
    cqat_sparsity = self._get_sparsity(stripped_cqat_model)
    self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])

  def testEndToEndPCQATClusteredPerChannelConv2d1x1(self,
                                                    data_format='channels_last'
                                                    ):
    """Runs PCQAT for model containing a 1x1 Conv2D.

    (with insufficient number of weights per channel).

    Args:
      data_format: Format of input data.
    """
    nr_of_channels = 12
    nr_of_clusters = 4

    # A 1x1 kernel leaves each channel with fewer weights than clusters,
    # so clustering is expected to warn and skip the layer.
    with self.assertWarnsRegex(Warning,
                               r'Layer conv2d does not have enough weights'):
      clustered_model = self._get_conv_clustered_model(
          nr_of_channels,
          nr_of_clusters,
          data_format,
          preserve_sparsity=True,
          kernel_size=(1, 1))
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer is expected at index 2 (Input, Reshape, Conv2D).
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        # Since the layer was skipped, the kernel keeps its (pruned)
        # original values rather than nr_of_channels * nr_of_clusters ones.
        nr_original_weights = len(np.unique(weight.numpy()))
        self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters)

        # Sanity-check that each channel indeed has at most nr_of_clusters
        # weights available, which is why clustering bailed out.
        for channel in range(nr_of_channels):
          channel_weights = (
              weight[:, channel, :, :]
              if data_format == 'channels_first' else weight[:, :, :, channel])
          nr_channel_weights = len(channel_weights)
          self.assertGreater(nr_channel_weights, 0)
          self.assertLessEqual(nr_channel_weights, nr_of_clusters)

    # The model was pruned to >50% sparsity; record it as the baseline.
    control_sparsity = self._get_sparsity(stripped_model)
    self.assertGreater(control_sparsity[0], 0.5)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model))

    # quantize_apply should likewise warn that the unclustered layer gets
    # no cluster preservation.
    with self.assertWarnsRegex(
        Warning, r'No clustering performed on layer quant_conv2d'):
      quant_aware_model = quantize.quantize_apply(
          quant_aware_annotate_model,
          scheme=default_8bit_cluster_preserve_quantize_scheme
          .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # Since no clustering was applied, PCQAT must leave the unique-weight
    # count exactly as it was. Conv layer is at index 3 after the quantize
    # layer is inserted.
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertEqual(num_of_unique_weights_cqat, nr_original_weights)

    cqat_sparsity = self._get_sparsity(stripped_cqat_model)
    self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])

  def testPassingNonPrunedModelToPCQAT(self):
    """Runs PCQAT as CQAT if the input model is not pruned."""
    preserve_sparsity = False
    clustered_model = self._get_clustered_model(preserve_sparsity)

    clustered_model = cluster.strip_clustering(clustered_model)
    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')

    # Even though the scheme below asks for sparsity preservation (True),
    # there is no sparsity to preserve, so PCQAT should degrade gracefully
    # to CQAT and still preserve clusters.
    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(True))

    self.compile_and_fit(quant_aware_model)
    stripped_pcqat_model = strip_clustering_cqat(
        quant_aware_model)

    # Cluster count must be unchanged by the PCQAT pass.
    num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
        stripped_pcqat_model, 1, 'kernel')
    self.assertAllEqual(nr_of_unique_weights_after,
                        num_of_unique_weights_pcqat)

  @parameterized.parameters((0.), (2.))
  def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights):
    """If pruned_clustered_model has uniform weights, it won't break PCQAT."""
    preserve_sparsity = True
    original_model = keras.Sequential([
        layers.Dense(5, activation='softmax', input_shape=(10,)),
        layers.Flatten(),
    ])

    # Force the whole Dense kernel to a single constant (0.0 or 2.0) — a
    # degenerate single-cluster case PCQAT must handle without crashing.
    first_layer_weights = original_model.layers[0].get_weights()
    first_layer_weights[0][:] = uniform_weights
    original_model.layers[0].set_weights(first_layer_weights)

    clustering_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
        'preserve_sparsity': True
    }

    clustered_model = experimental_cluster.cluster_weights(
        original_model, **clustering_params)
    clustered_model = cluster.strip_clustering(clustered_model)

    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')
    sparsity_pruning = self._get_sparsity(clustered_model)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    # PCQAT must keep both sparsity and the (degenerate) cluster count.
    sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
        preserve_sparsity, quant_aware_annotate_model)
    self.assertAllGreaterEqual(np.array(sparsity_pcqat),
                               sparsity_pruning[0])
    self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)

  def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self):
    """PCQAT zero centroid masks stay the same and trainable variables are updating between epochs."""
    preserve_sparsity = True
    clustered_model = self._get_clustered_model(preserve_sparsity)
    clustered_model = cluster.strip_clustering(clustered_model)

    # Apply PCQAT to the stripped clustered model.
    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(True))

    quant_aware_model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )

    class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback):
      """Check the updates of trainable variables and centroid masks."""

      def on_epoch_begin(self, batch, logs=None):
        # Snapshot PCQAT state at the start of the epoch. NOTE(review):
        # this reaches into TFMOT's private `_weight_vars` structure and
        # a hard-coded weights index — brittle against TFMOT internals.
        vars_dictionary = self.model.layers[1]._weight_vars[0][2]
        self.centroid_mask = vars_dictionary['centroids_mask']
        self.zero_centroid_index_begin = np.where(
            self.centroid_mask == 0)[0]

        # Snapshot trainable variables to compare against at epoch end.
        self.layer_kernel = (
            self.model.layers[1].weights[3].numpy()
        )
        self.original_weight = vars_dictionary['ori_weights_vars_tf'].numpy()
        self.centroids = vars_dictionary['cluster_centroids_tf'].numpy()

      def on_epoch_end(self, batch, logs=None):
        # The zero-centroid mask must be identical at epoch end:
        # sparsity-preserving training never moves the zero centroid.
        vars_dictionary = self.model.layers[1]._weight_vars[0][2]
        self.zero_centroid_index_end = np.where(
            vars_dictionary['centroids_mask'] == 0)[0]
        assert np.array_equal(
            self.zero_centroid_index_begin,
            self.zero_centroid_index_end
        )

        # Meanwhile the trainable variables (kernel, original weights,
        # centroids) must all have been updated by training.
        assert not np.array_equal(
            self.layer_kernel,
            self.model.layers[1].weights[3].numpy()
        )
        assert not np.array_equal(
            self.original_weight,
            vars_dictionary['ori_weights_vars_tf'].numpy()
        )
        assert not np.array_equal(
            self.centroids,
            vars_dictionary['cluster_centroids_tf'].numpy()
        )

    # Train for several epochs; the callback asserts the invariants at
    # every epoch boundary.
    quant_aware_model.fit(
        np.random.rand(20, 10),
        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
        steps_per_epoch=5,
        epochs=3,
        callbacks=[CheckCentroidsAndTrainableVarsCallback()],
    )
|
|
|
|
|
# Standard TF test entry point: discovers and runs the test methods above.
if __name__ == '__main__':
  tf.test.main()
|
|