Skip to content

Commit

Permalink
add option to not write tar.gz for oneAPI and Quartus (fastmachinelea…
Browse files Browse the repository at this point in the history
…rning#1189)

* add option to not write tar.gz for oneAPI and Quartus

* fix the docsting for initial config

* ignore extra parameters in oneAPI initial config
  • Loading branch information
jmitrevs authored Feb 8, 2025
1 parent 4d23e9f commit e690afe
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
19 changes: 18 additions & 1 deletion hls4ml/backends/oneapi/oneapi_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,30 @@ def get_default_flow(self):
def get_writer_flow(self):
return self._writer_flow

def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel'):
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', write_tar=False, **_):
"""Create initial configuration of the oneAPI backend.
Args:
part (str, optional): The FPGA part to be used. Defaults to 'Arria10'.
clock_period (int, optional): The clock period. Defaults to 5.
io_type (str, optional): Type of implementation used. One of
'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
write_tar (bool, optional): If True, compresses the output directory into a .tar.gz file. Defaults to False.
Returns:
dict: initial configuration.
"""

config = {}

config['Part'] = part if part is not None else 'Arria10'
config['ClockPeriod'] = clock_period
config['IOType'] = io_type
config['HLSConfig'] = {}
config['WriterConfig'] = {
# TODO: add namespace
'WriteTar': write_tar,
}

return config

Expand Down
5 changes: 4 additions & 1 deletion hls4ml/backends/quartus/quartus_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,16 @@ def get_default_flow(self):
def get_writer_flow(self):
return self._writer_flow

def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', **_):
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', write_tar=False, **_):
config = {}

config['Part'] = part if part is not None else 'Arria10'
config['ClockPeriod'] = clock_period if clock_period is not None else 5
config['IOType'] = io_type if io_type is not None else 'io_parallel'
config['HLSConfig'] = {}
config['WriterConfig'] = {
'WriteTar': write_tar,
}

return config

Expand Down
8 changes: 6 additions & 2 deletions hls4ml/writer/oneapi_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,8 +955,12 @@ def write_tar(self, model):
model (ModelGraph): the hls4ml model.
"""

with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)
if model.config.get_writer_config().get('WriteTar', False):
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)

def write_hls(self, model):
print('Writing HLS project')
Expand Down
8 changes: 6 additions & 2 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,8 +1345,12 @@ def write_tar(self, model):
model (ModelGraph): the hls4ml model.
"""

with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)
if model.config.get_writer_config().get('WriteTar', False):
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)

def write_hls(self, model):
print('Writing HLS project')
Expand Down

0 comments on commit e690afe

Please sign in to comment.