Refactor FixedSampleInfoDataset to Support Multi-Input Models #1323

Merged (1 commit) on Jan 14, 2025
29 changes: 24 additions & 5 deletions model_compression_toolkit/core/keras/data_util.py
@@ -134,11 +134,30 @@ def __init__(self, samples: Sequence, sample_info: Sequence):
        self.samples = samples
        self.sample_info = sample_info

-        # Create a TensorFlow dataset that holds (sample, sample_info) tuples
-        self.tf_dataset = tf.data.Dataset.from_tensor_slices((
-            tf.convert_to_tensor(self.samples),
-            tuple(tf.convert_to_tensor(info) for info in self.sample_info)
-        ))
+        # Get the number of tensors in each sample tuple (corresponds to the number of input layers the model has)
+        num_tensors = len(samples[0])
+
+        # Create one list per input layer and split each sample tuple across those lists
+        sample_tensor_lists = [[] for _ in range(num_tensors)]
+        for s in samples:
+            for i, data_tensor in enumerate(s):
+                sample_tensor_lists[i].append(data_tensor)
+
+        # To handle models whose input layers have different shapes, first organize the data in a
+        # dictionary so that tf.data.Dataset.from_tensor_slices can consume it
+        samples_dict = {f'tensor_{i}': tensors for i, tensors in enumerate(sample_tensor_lists)}
+        info_dict = {f'info_{i}': tf.convert_to_tensor(info) for i, info in enumerate(self.sample_info)}
+        combined_dict = {**samples_dict, **info_dict}
+
+        tf_dataset = tf.data.Dataset.from_tensor_slices(combined_dict)
+
+        # Map the dataset to return tuples instead of a dict
+        def reorganize_ds_outputs(ds_output):
+            tensors = tuple(ds_output[f'tensor_{i}'] for i in range(num_tensors))
+            infos = tuple(ds_output[f'info_{i}'] for i in range(len(sample_info)))
+            return tensors, infos
+
+        self.tf_dataset = tf_dataset.map(reorganize_ds_outputs)

    def __len__(self):
        return len(self.samples)
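For context, the refactor leans on tf.data.Dataset.from_tensor_slices accepting a dict of equally sized tensor stacks even when the per-input shapes differ, then mapping each element back to tuples. A minimal standalone sketch of that pattern (shapes and key names are illustrative, not taken from the PR):

import tensorflow as tf

# Four samples for a hypothetical two-input model: input A is (30, 20), input B is (10,)
samples = [(tf.fill((30, 20), float(i)), tf.fill((10,), float(i))) for i in range(4)]
sample_info = [tf.range(4)]  # one info tensor with an entry per sample

num_tensors = len(samples[0])
samples_dict = {f'tensor_{i}': [s[i] for s in samples] for i in range(num_tensors)}
info_dict = {f'info_{i}': tf.convert_to_tensor(info) for i, info in enumerate(sample_info)}

# Slice the combined dict, then map each element back to (inputs, infos) tuples
ds = tf.data.Dataset.from_tensor_slices({**samples_dict, **info_dict})
ds = ds.map(lambda d: (tuple(d[f'tensor_{i}'] for i in range(num_tensors)),
                       tuple(d[f'info_{i}'] for i in range(len(sample_info)))))

for inputs, infos in ds.take(1):
    print([t.shape for t in inputs])  # two inputs with different shapes: (30, 20) and (10,)

Mapping back to tuples keeps the element structure consistent with the single-input case, so downstream batching code does not need to change.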
@@ -62,6 +62,7 @@ def __init__(self,
                 quant_method=QuantizationMethod.SYMMETRIC,
                 rounding_type=RoundingType.STE,
                 per_channel=True,
+                 val_batch_size=1,
                 input_shape=(16, 16, 3),
                 hessian_weights=True,
                 log_norm_weights=True,
@@ -80,7 +81,8 @@

        super().__init__(unit_test,
                         input_shape=input_shape,
-                         num_calibration_iter=num_calibration_iter)
+                         num_calibration_iter=num_calibration_iter,
+                         val_batch_size=val_batch_size)

        self.quant_method = quant_method
        self.rounding_type = rounding_type
@@ -734,21 +734,13 @@ def test_gptq_with_sample_layer_attention(self):
        tf.config.run_functions_eagerly(True)
        kwargs = dict(per_sample=True, loss=sample_layer_attention_loss,
                      hessian_weights=True, hessian_num_samples=None,
-                      norm_scores=False, log_norm_weights=False, scaled_log_norm=False)
+                      norm_scores=False, log_norm_weights=False, scaled_log_norm=False, val_batch_size=7)
        GradientPTQTest(self, **kwargs).run_test()
        GradientPTQTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test()
        GradientPTQTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer, gradual_activation_quantization=True, **kwargs).run_test()
        GradientPTQTest(self, rounding_type=RoundingType.STE, **kwargs)
        tf.config.run_functions_eagerly(False)

-    # TODO: reuven - new experimental facade needs to be tested regardless the exporter.
-    # def test_gptq_new_exporter(self):
-    #     self.test_gptq(experimental_exporter=True)
-
-    # Comment out due to problem in Tensorflow 2.8
-    # def test_gptq_conv_group(self):
-    #     GradientPTQLearnRateZeroConvGroupTest(self).run_test()
-    #     GradientPTQWeightsUpdateConvGroupTest(self).run_test()

    def test_gptq_conv_group_dilation(self):
        # This call removes the effect of @tf.function decoration and executes the decorated function eagerly, which
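As background for the comment above, a minimal illustration of what tf.config.run_functions_eagerly toggles (the function below is generic, not part of the PR):

import tensorflow as tf

@tf.function
def double(x):
    print("running Python body")  # executes on every call in eager mode, only on tracing in graph mode
    return x * 2

tf.config.run_functions_eagerly(True)   # @tf.function bodies now execute eagerly, which eases debugging
double(tf.constant(2))
double(tf.constant(3))                  # prints again: no graph is compiled or cached
tf.config.run_functions_eagerly(False)  # restore normal graph compilation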
15 changes: 11 additions & 4 deletions tests_pytest/keras/core/test_data_util.py
@@ -65,6 +65,8 @@ def test_create_tf_dataloader_fixed_tfdataset(self, fixed_gen):

        for i, sample in enumerate(dataset):
            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
+

        batch_size = 16
        dataloader = create_tf_dataloader(dataset, batch_size=batch_size)
@@ -86,13 +88,15 @@ def test_fixed_tfdataset_too_many_requested_samples(self, fixed_gen):
    def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):
        samples = []
        for b in list(fixed_gen()):
-            samples.extend(tf.unstack(b[0], axis=0))
-            break  # Take one batch only (since this tests a fixed, small dataset)
+            for sample in zip(tf.unstack(b[0], axis=0), tf.unstack(b[1], axis=0)):
+                samples.append(sample)
+            break  # Take one batch only (since this tests a fixed, small dataset)
        dataset = FixedSampleInfoDataset(samples, [tf.range(32)])

        for i, sample_with_info in enumerate(dataset):
            sample, info = sample_with_info
-            assert np.array_equal(sample.cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
            assert info == (i,)

        batch_size = 16
@@ -103,7 +107,8 @@ def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):

        for batch in dataloader:
            samples, additional_info = batch
-            assert samples.shape == (batch_size, 3, 30, 20)
+            assert samples[0].shape == (batch_size, 3, 30, 20)
+            assert samples[1].shape == (batch_size, 10)
            assert additional_info[0].shape == (batch_size,)

    def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
@@ -113,6 +118,7 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen
        for i, sample_with_info in enumerate(dataset):
            sample, info = sample_with_info
            assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+            assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
            assert info == tf.constant("some_string")

        batch_size = 16
@@ -124,5 +130,6 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen
        for batch in dataloader:
            samples, additional_info = batch
            assert samples[0].shape == (batch_size, 3, 30, 20)
+            assert samples[1].shape == (batch_size, 10)
            assert additional_info.shape == (batch_size,)
            assert all(additional_info == tf.constant("some_string"))
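The shape assertions in these tests rely on tf.data batching nested structures component-wise. A tiny self-contained check of that behavior (the shapes mirror the tests; everything else is illustrative):

import tensorflow as tf

# One dataset element is a 2-tuple of model inputs plus a 1-tuple of per-sample info
ds = tf.data.Dataset.from_tensor_slices(
    ((tf.zeros((32, 3, 30, 20)), tf.zeros((32, 10))),  # two inputs, 32 samples each
     (tf.range(32),)))                                  # one info tensor

for inputs, infos in ds.batch(16).take(1):
    assert inputs[0].shape == (16, 3, 30, 20)
    assert inputs[1].shape == (16, 10)
    assert infos[0].shape == (16,)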