Fix FixedSampleInfoDataset to support multiple inputs (#1323)
The previous implementation of FixedSampleInfoDataset broke for models that expect multiple inputs with different shapes: during inference it raised a ValueError indicating a mismatch between the expected and received number of input tensors.
Instead of converting the samples directly to tensors, the new implementation separates the input tensors by their position in the input tuple, creates a dictionary mapping each tensor position to its data, builds the dataset with tf.data.Dataset.from_tensor_slices on that dictionary, and maps the dictionary outputs back to the expected tuple format.
This preserves the tensor grouping and shapes while allowing different shapes for different input layers.
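For illustration, a minimal standalone sketch of that dictionary round-trip, separate from the MCT code below; the toy shapes and the helper name to_tuples are invented for this example:

```python
import tensorflow as tf

# Toy multi-input samples: each sample is a tuple of two tensors with
# different shapes, standing in for a model with two input layers.
samples = [(tf.zeros((4, 4, 3)), tf.zeros((10,))) for _ in range(8)]
sample_info = [tf.range(8)]  # one info entry per sample

num_tensors = len(samples[0])

# Group the tensors by their position in the input tuple.
per_input = [[] for _ in range(num_tensors)]
for s in samples:
    for i, t in enumerate(s):
        per_input[i].append(t)

# A dict lets from_tensor_slices keep a separate shape per key instead of
# forcing everything into one rectangular tensor.
combined = {f'tensor_{i}': ts for i, ts in enumerate(per_input)}
combined.update({f'info_{i}': tf.convert_to_tensor(v)
                 for i, v in enumerate(sample_info)})

ds = tf.data.Dataset.from_tensor_slices(combined)

# Restore the (inputs_tuple, infos_tuple) structure consumers expect.
def to_tuples(d):
    inputs = tuple(d[f'tensor_{i}'] for i in range(num_tensors))
    infos = tuple(d[f'info_{i}'] for i in range(len(sample_info)))
    return inputs, infos

ds = ds.map(to_tuples)

for inputs, infos in ds.take(1):
    print([t.shape for t in inputs])  # [(4, 4, 3), (10,)]
```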

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored Jan 14, 2025
1 parent 1f1a1cc commit ae06a1c
Showing 4 changed files with 39 additions and 19 deletions.
29 changes: 24 additions & 5 deletions model_compression_toolkit/core/keras/data_util.py
@@ -134,11 +134,30 @@ def __init__(self, samples: Sequence, sample_info: Sequence):
        self.samples = samples
        self.sample_info = sample_info

-        # Create a TensorFlow dataset that holds (sample, sample_info) tuples
-        self.tf_dataset = tf.data.Dataset.from_tensor_slices((
-            tf.convert_to_tensor(self.samples),
-            tuple(tf.convert_to_tensor(info) for info in self.sample_info)
-        ))
+        # Get the number of tensors in each tuple (corresponds to the number of input layers the model has).
+        num_tensors = len(samples[0])
+
+        # Create one list per input layer and split each sample tuple into those lists.
+        sample_tensor_lists = [[] for _ in range(num_tensors)]
+        for s in samples:
+            for i, data_tensor in enumerate(s):
+                sample_tensor_lists[i].append(data_tensor)
+
+        # To support models whose input layers have different shapes, organize the data in a
+        # dictionary so it can be passed to tf.data.Dataset.from_tensor_slices.
+        samples_dict = {f'tensor_{i}': tensors for i, tensors in enumerate(sample_tensor_lists)}
+        info_dict = {f'info_{i}': tf.convert_to_tensor(info) for i, info in enumerate(self.sample_info)}
+        combined_dict = {**samples_dict, **info_dict}
+
+        tf_dataset = tf.data.Dataset.from_tensor_slices(combined_dict)
+
+        # Map the dataset to return tuples instead of a dict.
+        def reorganize_ds_outputs(ds_output):
+            tensors = tuple(ds_output[f'tensor_{i}'] for i in range(num_tensors))
+            infos = tuple(ds_output[f'info_{i}'] for i in range(len(sample_info)))
+            return tensors, infos
+
+        self.tf_dataset = tf_dataset.map(reorganize_ds_outputs)

    def __len__(self):
        return len(self.samples)
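A hedged usage sketch of the fixed class, mirroring the updated tests further down; the import path is assumed from the file changed in this commit:

```python
import numpy as np
import tensorflow as tf

# Assumed import path, based on the file changed in this commit.
from model_compression_toolkit.core.keras.data_util import (
    FixedSampleInfoDataset, create_tf_dataloader)

# Each sample is a tuple with one tensor per model input layer,
# and the two inputs deliberately have different shapes.
samples = [(tf.constant(np.full((3, 30, 20), i, dtype=np.float32)),
            tf.constant(np.full((10,), i + 10, dtype=np.float32)))
           for i in range(32)]
dataset = FixedSampleInfoDataset(samples, [tf.range(32)])

dataloader = create_tf_dataloader(dataset, batch_size=16)
for (x0, x1), (info,) in dataloader:
    # Batching preserves the per-input shapes restored by the map step.
    assert x0.shape == (16, 3, 30, 20)
    assert x1.shape == (16, 10)
    assert info.shape == (16,)
```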
@@ -61,6 +61,7 @@ def __init__(self,
                 quant_method=QuantizationMethod.SYMMETRIC,
                 rounding_type=RoundingType.STE,
                 per_channel=True,
+                 val_batch_size=1,
                 input_shape=(16, 16, 3),
                 hessian_weights=True,
                 log_norm_weights=True,
@@ -79,7 +80,8 @@

        super().__init__(unit_test,
                         input_shape=input_shape,
-                         num_calibration_iter=num_calibration_iter)
+                         num_calibration_iter=num_calibration_iter,
+                         val_batch_size=val_batch_size)

        self.quant_method = quant_method
        self.rounding_type = rounding_type
@@ -734,21 +734,13 @@ def test_gptq_with_sample_layer_attention(self):
        tf.config.run_functions_eagerly(True)
        kwargs = dict(per_sample=True, loss=sample_layer_attention_loss,
                      hessian_weights=True, hessian_num_samples=None,
-                      norm_scores=False, log_norm_weights=False, scaled_log_norm=False)
+                      norm_scores=False, log_norm_weights=False, scaled_log_norm=False, val_batch_size=7)
        GradientPTQTest(self, **kwargs).run_test()
        GradientPTQTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test()
        GradientPTQTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer, gradual_activation_quantization=True, **kwargs).run_test()
        GradientPTQTest(self, rounding_type=RoundingType.STE, **kwargs)
        tf.config.run_functions_eagerly(False)

-    # TODO: reuven - new experimental facade needs to be tested regardless the exporter.
-    # def test_gptq_new_exporter(self):
-    #     self.test_gptq(experimental_exporter=True)
-
-    # Comment out due to problem in Tensorflow 2.8
-    # def test_gptq_conv_group(self):
-    #     GradientPTQLearnRateZeroConvGroupTest(self).run_test()
-    #     GradientPTQWeightsUpdateConvGroupTest(self).run_test()

    def test_gptq_conv_group_dilation(self):
        # This call removes the effect of @tf.function decoration and executes the decorated function eagerly, which
15 changes: 11 additions & 4 deletions tests_pytest/keras/core/test_data_util.py
@@ -65,6 +65,8 @@ def test_create_tf_dataloader_fixed_tfdataset(self, fixed_gen):

        for i, sample in enumerate(dataset):
            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
+

        batch_size = 16
        dataloader = create_tf_dataloader(dataset, batch_size=batch_size)
@@ -86,13 +88,15 @@ def test_fixed_tfdataset_too_many_requested_samples(self, fixed_gen):
    def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):
        samples = []
        for b in list(fixed_gen()):
-            samples.extend(tf.unstack(b[0], axis=0))
-            break  # Take one batch only (since this tests a fixed, small dataset)
+            for sample in zip(tf.unstack(b[0], axis=0), tf.unstack(b[1], axis=0)):
+                samples.append(sample)
+            break  # Take one batch only (since this tests a fixed, small dataset)
        dataset = FixedSampleInfoDataset(samples, [tf.range(32)])

        for i, sample_with_info in enumerate(dataset):
            sample, info = sample_with_info
-            assert np.array_equal(sample.cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
            assert info == (i,)

        batch_size = 16
@@ -103,7 +107,8 @@ def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):

        for batch in dataloader:
            samples, additional_info = batch
-            assert samples.shape == (batch_size, 3, 30, 20)
+            assert samples[0].shape == (batch_size, 3, 30, 20)
+            assert samples[1].shape == (batch_size, 10)
            assert additional_info[0].shape == (batch_size,)

    def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
@@ -113,6 +118,7 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
        for i, sample_with_info in enumerate(dataset):
            sample, info = sample_with_info
            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
            assert info == tf.constant("some_string")

        batch_size = 16
@@ -124,5 +130,6 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
        for batch in dataloader:
            samples, additional_info = batch
            assert samples[0].shape == (batch_size, 3, 30, 20)
+            assert samples[1].shape == (batch_size, 10)
            assert additional_info.shape == (batch_size,)
            assert all(additional_info == tf.constant("some_string"))
