From ae06a1c125d7a22ed98fd789af43440f7cc4d191 Mon Sep 17 00:00:00 2001 From: Reuven <44209964+reuvenperetz@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:41:17 +0200 Subject: [PATCH] Fix FixedSampleInfoDataset to support multiple inputs (#1323) The previous implementation of FixedSampleInfoDataset caused issues when working with models that expected multiple inputs with different shapes. Specifically, during inference it would raise a ValueError indicating a mismatch between the expected and received number of input tensors. Instead of directly converting samples to tensors, we now separate the input tensors by their position in the input tuple, create a dictionary mapping each tensor position to its corresponding data, use tf.data.Dataset.from_tensor_slices with the dictionary structure and map the dictionary outputs back to the expected tuple format. This preserves the tensor grouping and shapes while allowing different shapes for different input layers. Co-authored-by: reuvenp --- .../core/keras/data_util.py | 29 +++++++++++++++---- .../feature_networks/gptq/gptq_test.py | 4 ++- .../test_features_runner.py | 10 +------ tests_pytest/keras/core/test_data_util.py | 15 +++++++--- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/model_compression_toolkit/core/keras/data_util.py b/model_compression_toolkit/core/keras/data_util.py index daa678ca8..7f0accb8c 100644 --- a/model_compression_toolkit/core/keras/data_util.py +++ b/model_compression_toolkit/core/keras/data_util.py @@ -134,11 +134,30 @@ def __init__(self, samples: Sequence, sample_info: Sequence): self.samples = samples self.sample_info = sample_info - # Create a TensorFlow dataset that holds (sample, sample_info) tuples - self.tf_dataset = tf.data.Dataset.from_tensor_slices(( - tf.convert_to_tensor(self.samples), - tuple(tf.convert_to_tensor(info) for info in self.sample_info) - )) + # Get the number of tensors in each tuple (corresponds to the number of input layers the model has) + num_tensors = len(samples[0]) + + # Create separate lists: one for each input layer and separate the tuples into lists + sample_tensor_lists = [[] for _ in range(num_tensors)] + for s in samples: + for i, data_tensor in enumerate(s): + sample_tensor_lists[i].append(data_tensor) + + # In order to deal with models that have different input shapes for different layers, we need first to + # organize the data in a dictionary in order to use tf.data.Dataset.from_tensor_slices + samples_dict = {f'tensor_{i}': tensors for i, tensors in enumerate(sample_tensor_lists)} + info_dict = {f'info_{i}': tf.convert_to_tensor(info) for i, info in enumerate(self.sample_info)} + combined_dict = {**samples_dict, **info_dict} + + tf_dataset = tf.data.Dataset.from_tensor_slices(combined_dict) + + # Map the dataset to return tuples instead of dict + def reorganize_ds_outputs(ds_output): + tensors = tuple(ds_output[f'tensor_{i}'] for i in range(num_tensors)) + infos = tuple(ds_output[f'info_{i}'] for i in range(len(sample_info))) + return tensors, infos + + self.tf_dataset = tf_dataset.map(reorganize_ds_outputs) def __len__(self): return len(self.samples) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py index e4e16feaf..eab4fd2be 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py @@ -61,6 +61,7 @@ def __init__(self, quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, + val_batch_size=1, input_shape=(16, 16, 3), hessian_weights=True, log_norm_weights=True, @@ -79,7 +80,8 @@ def __init__(self, super().__init__(unit_test, input_shape=input_shape, - num_calibration_iter=num_calibration_iter) + num_calibration_iter=num_calibration_iter, + val_batch_size=val_batch_size) self.quant_method = quant_method self.rounding_type = rounding_type diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 3d9988897..1193f2513 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -734,21 +734,13 @@ def test_gptq_with_sample_layer_attention(self): tf.config.run_functions_eagerly(True) kwargs = dict(per_sample=True, loss=sample_layer_attention_loss, hessian_weights=True, hessian_num_samples=None, - norm_scores=False, log_norm_weights=False, scaled_log_norm=False) + norm_scores=False, log_norm_weights=False, scaled_log_norm=False, val_batch_size=7) GradientPTQTest(self, **kwargs).run_test() GradientPTQTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test() GradientPTQTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer, gradual_activation_quantization=True, **kwargs).run_test() GradientPTQTest(self, rounding_type=RoundingType.STE, **kwargs) tf.config.run_functions_eagerly(False) - # TODO: reuven - new experimental facade needs to be tested regardless the exporter. - # def test_gptq_new_exporter(self): - # self.test_gptq(experimental_exporter=True) - - # Comment out due to problem in Tensorflow 2.8 - # def test_gptq_conv_group(self): - # GradientPTQLearnRateZeroConvGroupTest(self).run_test() - # GradientPTQWeightsUpdateConvGroupTest(self).run_test() def test_gptq_conv_group_dilation(self): # This call removes the effect of @tf.function decoration and executes the decorated function eagerly, which diff --git a/tests_pytest/keras/core/test_data_util.py b/tests_pytest/keras/core/test_data_util.py index 1eb9ee9b7..1cb6a46bf 100644 --- a/tests_pytest/keras/core/test_data_util.py +++ b/tests_pytest/keras/core/test_data_util.py @@ -65,6 +65,8 @@ def test_create_tf_dataloader_fixed_tfdataset(self, fixed_gen): for i, sample in enumerate(dataset): assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i)) + assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10)) + batch_size = 16 dataloader = create_tf_dataloader(dataset, batch_size=batch_size) @@ -86,13 +88,15 @@ def test_fixed_tfdataset_too_many_requested_samples(self, fixed_gen): def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen): samples = [] for b in list(fixed_gen()): - samples.extend(tf.unstack(b[0], axis=0)) - break # Take one batch only (since this tests fixed,small dataset) + for sample in zip(tf.unstack(b[0], axis=0), tf.unstack(b[1], axis=0)): + samples.append(sample) + break # Take one batch only (since this tests fixed,small dataset) dataset = FixedSampleInfoDataset(samples, [tf.range(32)]) for i, sample_with_info in enumerate(dataset): sample, info = sample_with_info - assert np.array_equal(sample.cpu().numpy(), np.full((3, 30, 20), i)) + assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i)) + assert np.array_equal(sample[1].cpu().numpy(), np.full((10, ), i+10)) assert info == (i,) batch_size = 16 @@ -103,7 +107,8 @@ def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen): for batch in dataloader: samples, additional_info = batch - assert samples.shape == (batch_size, 3, 30, 20) + assert samples[0].shape == (batch_size, 3, 30, 20) + assert samples[1].shape == (batch_size, 10) assert additional_info[0].shape == (batch_size,) def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen): @@ -113,6 +118,7 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen for i, sample_with_info in enumerate(dataset): sample, info = sample_with_info assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i)) + assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10)) assert info == tf.constant("some_string") batch_size = 16 @@ -124,5 +130,6 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen for batch in dataloader: samples, additional_info = batch assert samples[0].shape == (batch_size, 3, 30, 20) + assert samples[1].shape == (batch_size, 10) assert additional_info.shape == (batch_size,) assert all(additional_info == tf.constant("some_string")) \ No newline at end of file