Skip to content

Commit

Permalink
Add functionality to use granularity option also for pytorch models (f…
Browse files Browse the repository at this point in the history
…astmachinelearning#1051)

* allow granularity options in pytorch parser

* pre-commit

* [pre-commit.ci] auto fixes from pre-commit hooks

* add torch to setup?

* add torch to setup2?

* add torch to setup3?

* add torch to requirements

* fix failing pytest

* adapat new batchnorm pytests to changes in interface

* addressing comments from Vladimir and Jovan

* remvoving torch from requirements

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
JanFSchulte and pre-commit-ci[bot] authored Aug 27, 2024
1 parent 2898ab2 commit 2cb6fe1
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 100 deletions.
9 changes: 4 additions & 5 deletions hls4ml/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401
from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401
from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler

# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401
from hls4ml.model import ModelGraph
from hls4ml.utils.config import create_config
from hls4ml.utils.symbolic_utils import LUTFunction
Expand Down Expand Up @@ -238,7 +240,6 @@ def convert_from_keras_model(

def convert_from_pytorch_model(
model,
input_shape,
output_dir='my-hls-test',
project_name='myproject',
input_data_tb=None,
Expand All @@ -251,7 +252,6 @@ def convert_from_pytorch_model(
Args:
model: PyTorch model to convert.
input_shape (list): The shape of the input tensor. First element is the batch size, needs to be None
output_dir (str, optional): Output directory of the generated HLS project. Defaults to 'my-hls-test'.
project_name (str, optional): Name of the HLS project. Defaults to 'myproject'.
input_data_tb (str, optional): String representing the path of input data in .npy or .dat format that will be
Expand Down Expand Up @@ -293,17 +293,16 @@ def convert_from_pytorch_model(
config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs)

config['PytorchModel'] = model
config['InputShape'] = input_shape
config['InputData'] = input_data_tb
config['OutputPredictions'] = output_data_tb
config['HLSConfig'] = {}

if hls_config is None:
hls_config = {}

model_config = hls_config.get('Model', None)
model_config = hls_config.get('Model')
config['HLSConfig']['Model'] = _check_model_config(model_config)

config['InputShape'] = hls_config.get('InputShape')
_check_hls_config(config, hls_config)

return pytorch_to_hls(config)
Expand Down
44 changes: 31 additions & 13 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def decorator(function):
# ----------------------------------------------------------------


def pytorch_to_hls(config):
def parse_pytorch_model(config, verbose=True):
"""Convert PyTorch model to hls4ml ModelGraph.
Args:
Expand All @@ -118,14 +118,15 @@ def pytorch_to_hls(config):
# This is a list of dictionaries to hold all the layer info we need to generate HLS
layer_list = []

print('Interpreting Model ...')

if verbose:
print('Interpreting Model ...')
reader = PyTorchFileReader(config) if isinstance(config['PytorchModel'], str) else PyTorchModelReader(config)
if type(reader.input_shape) is tuple:
input_shapes = [list(reader.input_shape)]
else:
input_shapes = list(reader.input_shape)
input_shapes = [list(shape) for shape in input_shapes]
# first element needs to 'None' as placeholder for the batch size, insert it if not present
input_shapes = [[None] + list(shape) if shape[0] is not None else list(shape) for shape in input_shapes]

model = reader.torch_model

Expand All @@ -151,7 +152,8 @@ def pytorch_to_hls(config):
output_shape = None

# Loop through layers
print('Topology:')
if verbose:
print('Topology:')
layer_counter = 0

n_inputs = 0
Expand Down Expand Up @@ -226,13 +228,14 @@ def pytorch_to_hls(config):
pytorch_class, layer_name, input_names, input_shapes, node, class_object, reader, config
)

print(
'Layer name: {}, layer type: {}, input shape: {}'.format(
layer['name'],
layer['class_name'],
input_shapes,
if verbose:
print(
'Layer name: {}, layer type: {}, input shape: {}'.format(
layer['name'],
layer['class_name'],
input_shapes,
)
)
)
layer_list.append(layer)

assert output_shape is not None
Expand Down Expand Up @@ -288,7 +291,12 @@ def pytorch_to_hls(config):
operation, layer_name, input_names, input_shapes, node, None, reader, config
)

print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes))
if verbose:
print(
'Layer name: {}, layer type: {}, input shape: {}'.format(
layer['name'], layer['class_name'], input_shapes
)
)
layer_list.append(layer)

assert output_shape is not None
Expand Down Expand Up @@ -342,7 +350,12 @@ def pytorch_to_hls(config):
operation, layer_name, input_names, input_shapes, node, None, reader, config
)

print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes))
if verbose:
print(
'Layer name: {}, layer type: {}, input shape: {}'.format(
layer['name'], layer['class_name'], input_shapes
)
)
layer_list.append(layer)

assert output_shape is not None
Expand All @@ -351,6 +364,11 @@ def pytorch_to_hls(config):
if len(input_layers) == 0:
input_layers = None

return layer_list, input_layers


def pytorch_to_hls(config):
layer_list, input_layers = parse_pytorch_model(config)
print('Creating HLS model')
hls_model = ModelGraph(config, layer_list, inputs=input_layers)
return hls_model
79 changes: 79 additions & 0 deletions hls4ml/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def make_layer_config(layer):

def config_from_pytorch_model(
model,
input_shape,
granularity='model',
backend=None,
default_precision='ap_fixed<16,6>',
Expand All @@ -284,6 +285,7 @@ def config_from_pytorch_model(
Args:
model: PyTorch model
input_shape (tuple or list of tuples): The shape of the input tensor, excluding the batch size.
granularity (str, optional): Granularity of the created config. Defaults to 'model'.
Can be set to 'model', 'type' and 'layer'.
Expand Down Expand Up @@ -321,6 +323,83 @@ def config_from_pytorch_model(
model_config['Strategy'] = 'Latency'

config['Model'] = model_config
config['PytorchModel'] = model
if not (isinstance(input_shape, tuple) or (isinstance(input_shape, list) and isinstance(input_shape[0], tuple))):
raise Exception('Input shape must be tuple (single input) or list of tuples (multiple inputs)')
config['InputShape'] = input_shape

if granularity.lower() not in ['model', 'type', 'name']:
raise Exception(
f'Invalid configuration granularity specified, expected "model", "type" or "name" got "{granularity}"'
)

if backend is not None:
backend = hls4ml.backends.get_backend(backend)

from hls4ml.converters.pytorch_to_hls import parse_pytorch_model

(
layer_list,
_,
) = parse_pytorch_model(config, verbose=False)

def make_layer_config(layer):
cls_name = layer['class_name']
if 'config' in layer.keys():
if 'activation' in layer['config'].keys():
if layer['config']['activation'] == 'softmax':
cls_name = 'Softmax'

layer_cls = hls4ml.model.layers.layer_map[cls_name]
if backend is not None:
layer_cls = backend.create_layer_class(layer_cls)

layer_config = {}

config_attrs = [a for a in layer_cls.expected_attributes if a.configurable]
for attr in config_attrs:
if isinstance(attr, hls4ml.model.attributes.TypeAttribute):
precision_cfg = layer_config.setdefault('Precision', {})
name = attr.name
if name.endswith('_t'):
name = name[:-2]
if attr.default is None:
precision_cfg[name] = default_precision
else:
precision_cfg[name] = str(attr.default)
elif attr.name == 'reuse_factor':
layer_config[attr.config_name] = default_reuse_factor
else:
if attr.default is not None:
layer_config[attr.config_name] = attr.default

if layer['class_name'] == 'Input':
dtype = layer['config']['dtype']
if dtype.startswith('int') or dtype.startswith('uint'):
typename = dtype[: dtype.index('int') + 3]
width = int(dtype[dtype.index('int') + 3 :])
layer_config['Precision']['result'] = f'ap_{typename}<{width}>'
# elif bool, q[u]int, ...

return layer_config

if granularity.lower() == 'type':
type_config = {}
for layer in layer_list:
if layer['class_name'] in type_config:
continue
layer_config = make_layer_config(layer)
type_config[layer['class_name']] = layer_config

config['LayerType'] = type_config

elif granularity.lower() == 'name':
name_config = {}
for layer in layer_list:
layer_config = make_layer_config(layer)
name_config[layer['name']] = layer_config

config['LayerName'] = name_config

return config

Expand Down
33 changes: 22 additions & 11 deletions test/pytest/test_backend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
convert_fn = hls4ml.converters.convert_from_keras_model
else:
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.ReLU())
config = hls4ml.utils.config_from_pytorch_model(model)
config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(None, 1))
convert_fn = hls4ml.converters.convert_from_pytorch_model

if clock_unc is not None:
Expand All @@ -42,16 +42,27 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
test_dir = f'hls4mlprj_backend_config_{framework}_{backend}_part_{part}_period_{clock_period}_unc_{unc_str}'
output_dir = test_root_path / test_dir

hls_model = convert_fn(
model,
input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in keras converer
hls_config=config,
output_dir=str(output_dir),
backend=backend,
part=part,
clock_period=clock_period,
clock_uncertainty=clock_unc,
)
if framework == "keras":
hls_model = convert_fn(
model,
input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in keras converer
hls_config=config,
output_dir=str(output_dir),
backend=backend,
part=part,
clock_period=clock_period,
clock_uncertainty=clock_unc,
)
else:
hls_model = convert_fn(
model,
hls_config=config,
output_dir=str(output_dir),
backend=backend,
part=part,
clock_period=clock_period,
clock_uncertainty=clock_unc,
)

hls_model.write()

Expand Down
15 changes: 10 additions & 5 deletions test/pytest/test_batchnorm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ def test_batchnorm(data, backend, io_type):

default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'

config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name')
config = hls4ml.utils.config_from_pytorch_model(
model, (in_shape,), default_precision=default_precision, granularity='name'
)
output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
hls_model = hls4ml.converters.convert_from_pytorch_model(
model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
)
hls_model.compile()

Expand Down Expand Up @@ -94,17 +96,20 @@ def test_batchnorm_fusion(fusion_data, backend, io_type):
# We do not have an implementation of a transpose for io_stream, need to transpose inputs and outputs outside of hls4ml
if io_type == 'io_stream':
fusion_data = np.ascontiguousarray(fusion_data.transpose(0, 2, 1))
config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='internal', transpose_outputs=False)
config = hls4ml.utils.config_from_pytorch_model(
model, (n_in, size_in_height), channels_last_conversion='internal', transpose_outputs=False
)
else:
config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='full', transpose_outputs=True)
config = hls4ml.utils.config_from_pytorch_model(
model, (n_in, size_in_height), channels_last_conversion='full', transpose_outputs=True
)

config['Model']['Strategy'] = 'Resource'

# conversion
output_dir = str(test_root_path / f'hls4mlprj_block_{backend}_{io_type}')
hls_model = hls4ml.converters.convert_from_pytorch_model(
model,
(None, n_in, size_in_height),
hls_config=config,
output_dir=output_dir,
backend=backend,
Expand Down
8 changes: 5 additions & 3 deletions test/pytest/test_merge_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ def test_merge(merge_op, io_type, backend):
model = MergeModule(merge_op)
model.eval()

batch_input_shape = (None,) + input_shape
config = hls4ml.utils.config_from_pytorch_model(
model, default_precision='ap_fixed<32,16>', channels_last_conversion="internal", transpose_outputs=False
model,
[input_shape, input_shape],
default_precision='ap_fixed<32,16>',
channels_last_conversion="internal",
transpose_outputs=False,
)
output_dir = str(test_root_path / f'hls4mlprj_merge_pytorch_{merge_op}_{backend}_{io_type}')
hls_model = hls4ml.converters.convert_from_pytorch_model(
model,
[batch_input_shape, batch_input_shape],
hls_config=config,
output_dir=output_dir,
io_type=io_type,
Expand Down
Loading

0 comments on commit 2cb6fe1

Please sign in to comment.