diff --git a/docs/advanced/extension.rst b/docs/advanced/extension.rst index ad86051d82..b6acc4ac6a 100644 --- a/docs/advanced/extension.rst +++ b/docs/advanced/extension.rst @@ -35,6 +35,10 @@ For concreteness, let's say our custom layer ``KReverse`` is implemented in Kera def call(self, inputs): return tf.reverse(inputs, axis=[-1]) + def get_config(self): + return super().get_config() + +Make sure you define a ``get_config()`` method for your custom layer as this is needed for correct parsing. We can define the equivalent layer in hls4ml ``HReverse``, which inherits from ``hls4ml.model.layers.Layer``. .. code-block:: Python diff --git a/test/pytest/test_extensions.py b/test/pytest/test_extensions.py index 0820a58c7c..bf5c7e2981 100644 --- a/test/pytest/test_extensions.py +++ b/test/pytest/test_extensions.py @@ -19,6 +19,10 @@ def __init__(self): def call(self, inputs): return tf.reverse(inputs, axis=[-1]) + def get_config(self): + # Breaks serialization and parsing in hls4ml if not defined + return super().get_config() + # hls4ml layer implementation class HReverse(hls4ml.model.layers.Layer):