Skip to content

Commit

Permalink
Merge pull request fastmachinelearning#1006 from fastmachinelearning/…
Browse files Browse the repository at this point in the history
…pre-commit-and-keras

Fix pre-commit warning and change '.h5' to '.keras' for written output
  • Loading branch information
jmitrevs authored May 3, 2024
2 parents 46ec22b + ce33496 commit ed55394
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion hls4ml/writer/catapult_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
2 changes: 1 addition & 1 deletion hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
2 changes: 1 addition & 1 deletion hls4ml/writer/vivado_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
2 changes: 1 addition & 1 deletion test/pytest/test_weight_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def test_weight_writer(k, i, f):
print(w_paths[0])
assert len(w_paths) == 1
w_loaded = np.loadtxt(w_paths[0], delimiter=',').reshape(1, 1)
print(f'{w[0,0]:.14}', f'{w_loaded[0,0]:.14}')
print(f'{w[0, 0]:.14}', f'{w_loaded[0, 0]:.14}')
assert np.all(w == w_loaded)

0 comments on commit ed55394

Please sign in to comment.