Skip to content

Commit

Permalink
Add verbose flag to FakelyQuantKerasExporter to avoid confusion in tflite export
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 28, 2023
1 parent 5306a8d commit b7df1b6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,21 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
def __init__(self,
model: keras.models.Model,
is_layer_exportable_fn: Callable,
save_model_path: str):
save_model_path: str,
verbose: bool = True):
"""
Args:
model: Model to export.
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
save_model_path: Path to save the exported model.
verbose: Whether to log information about the export process or not.
"""

super().__init__(model,
is_layer_exportable_fn,
save_model_path)
self._verbose = verbose

def export(self) -> Dict[str, type]:
"""
Expand Down Expand Up @@ -138,7 +141,8 @@ def _unwrap_quantize_wrapper(layer: Layer):
if self.exported_model is None:
Logger.critical(f'Exporter can not save model as it is not exported') # pragma: no cover

Logger.info(f'Exporting FQ Keras model to: {self.save_model_path}')
if self._verbose:
Logger.info(f'Exporting FQ Keras model to: {self.save_model_path}')

keras.models.save_model(self.exported_model, self.save_model_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def export(self):
# Use Keras exporter to quantize model's weights before converting it to TFLite.
# Since exporter saves the model, we use a tmp path for saving, and then we delete it automatically.
with tempfile.NamedTemporaryFile(suffix=TMP_KERAS_EXPORT_FORMAT) as tmp_file:
custom_objects = FakelyQuantKerasExporter(self.model,
self.is_layer_exportable_fn,
tmp_file.name).export()
FakelyQuantKerasExporter(self.model,
self.is_layer_exportable_fn,
tmp_file.name,
verbose=False).export()

model = keras_load_quantized_model(tmp_file.name)

Expand Down

0 comments on commit b7df1b6

Please sign in to comment.