Skip to content

Commit

Permalink
[AUTOGEN] Add Erlang DC hrl
Browse files Browse the repository at this point in the history
- Change definition in dc files to support atoms conventions
  • Loading branch information
leondavi committed Nov 11, 2023
1 parent 9c5647f commit aacee47
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 23 deletions.
4 changes: 3 additions & 1 deletion NerlnetBuild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,12 @@ if command -v python3 >/dev/null 2>&1; then
set -e
AUTOGENERATED_WORKER_DEFINITIONS_PATH="`pwd`/src_cpp/opennnBridge/worker_definitions_ag.h"
AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL="`pwd`/src_erl/NerlnetApp/src/worker_definitions_ag.hrl"
AUTOGENERATED_DC_DEFINITIONS_PATH_HRL="`pwd`/src_erl/NerlnetApp/src/dc_definitions_ag.hrl"

echo "$NERLNET_BUILD_PREFIX Generate auto-generated files"
python3 src_py/nerlPlanner/CppHeadersExporter.py --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH #--debug
python3 src_py/nerlPlanner/ErlHeadersExporter.py --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL #--debug
python3 src_py/nerlPlanner/ErlHeadersExporter.py --gen_worker_fields_hrl --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL #--debug
python3 src_py/nerlPlanner/ErlHeadersExporter.py --gen_dc_fields_hrl --output $AUTOGENERATED_DC_DEFINITIONS_PATH_HRL #--debug
set +e
else
echo "$NERLNET_BUILD_PREFIX Python 3 is not installed"
Expand Down
44 changes: 44 additions & 0 deletions src_erl/NerlnetApp/src/dc_definitions_ag.hrl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
% This is an auto generated .hrl file
% Generated by Nerlplanner version: 1.0.0

-define(DC_KEY_NERLNET_SETTINGS_ATOM,nerlnetSettings).
-define(DC_KEY_FREQUENCY_ATOM,frequency).
-define(DC_KEY_BATCH_SIZE_ATOM,batchSize).
-define(DC_KEY_DEVICES_ATOM,devices).
-define(DC_KEY_CLIENTS_ATOM,clients).
-define(DC_KEY_WORKERS_ATOM,workers).
-define(DC_KEY_MODEL_SHA_ATOM,model_sha).
-define(DC_KEY_SOURCES_ATOM,sources).
-define(DC_KEY_ROUTERS_ATOM,routers).
-define(DC_NAME_FIELD_ATOM,name).
-define(DC_WORKER_MODEL_SHA_FIELD_ATOM,model_sha).
-define(DC_IPV4_FIELD_ATOM,ipv4).
-define(DC_PORT_FIELD_ATOM,port).
-define(DC_ARGS_FIELD_ATOM,args).
-define(DC_ENTITIES_FIELD_ATOM,entities).
-define(DC_POLICY_FIELD_ATOM,policy).
-define(DC_EPOCHS_FIELD_ATOM,epochs).
-define(DC_TYPE_FIELD_ATOM,type).
-define(DC_FREQUENCY_FIELD_ATOM,frequency).
-define(DC_WORKERS_FIELD_ATOM,workers).

-define(DC_KEY_NERLNET_SETTINGS_STR,"nerlnetSettings").
-define(DC_KEY_FREQUENCY_STR,"frequency").
-define(DC_KEY_BATCH_SIZE_STR,"batchSize").
-define(DC_KEY_DEVICES_STR,"devices").
-define(DC_KEY_CLIENTS_STR,"clients").
-define(DC_KEY_WORKERS_STR,"workers").
-define(DC_KEY_MODEL_SHA_STR,"model_sha").
-define(DC_KEY_SOURCES_STR,"sources").
-define(DC_KEY_ROUTERS_STR,"routers").
-define(DC_NAME_FIELD_STR,"name").
-define(DC_WORKER_MODEL_SHA_FIELD_STR,"model_sha").
-define(DC_IPV4_FIELD_STR,"ipv4").
-define(DC_PORT_FIELD_STR,"port").
-define(DC_ARGS_FIELD_STR,"args").
-define(DC_ENTITIES_FIELD_STR,"entities").
-define(DC_POLICY_FIELD_STR,"policy").
-define(DC_EPOCHS_FIELD_STR,"epochs").
-define(DC_TYPE_FIELD_STR,"type").
-define(DC_FREQUENCY_FIELD_STR,"frequency").
-define(DC_WORKERS_FIELD_STR,"workers").
16 changes: 8 additions & 8 deletions src_erl/NerlnetApp/src/worker_definitions_ag.hrl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
% This is an auto generated .hrl file
% Generated by Nerlplanner version: 1.0.0

-define(KEY_MODEL_TYPE,modelType).
-define(KEY_LAYER_SIZES_LIST,layersSizes).
-define(KEY_LAYER_TYPES_LIST,layerTypesList).
-define(KEY_LAYERS_FUNCTIONS,layers_functions).
-define(KEY_LOSS_METHOD,lossMethod).
-define(KEY_LEARNING_RATE,lr).
-define(KEY_EPOCHS,epochs).
-define(KEY_OPTIMIZER_TYPE,optimizer).
-define(WORKER_KEY_MODEL_TYPE,modelType).
-define(WORKER_KEY_LAYER_SIZES_LIST,layersSizes).
-define(WORKER_KEY_LAYER_TYPES_LIST,layerTypesList).
-define(WORKER_KEY_LAYERS_FUNCTIONS,layers_functions).
-define(WORKER_KEY_LOSS_METHOD,lossMethod).
-define(WORKER_KEY_LEARNING_RATE,lr).
-define(WORKER_KEY_EPOCHS,epochs).
-define(WORKER_KEY_OPTIMIZER_TYPE,optimizer).
49 changes: 44 additions & 5 deletions src_py/nerlPlanner/ErlHeadersExporter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import os
from ErlHeadersExporterDefs import *
from JsonDistributedConfigDefs import *
from JsonElementWorkerDefinitions import *
from Definitions import VERSION as NERLPLANNER_VERSION

Expand All @@ -11,6 +12,10 @@ def gen_erlang_exporter_logger(message : str):
if DEBUG:
print(f'[NERLPLANNER][AUTO_HEADER_GENERATOR][DEBUG] {message}')

def path_validator(path : str):
if os.path.dirname(path):
os.makedirs(os.path.dirname(path), exist_ok=True)

def gen_worker_fields_hrl(header_path : str, debug : bool = False):
global DEBUG
DEBUG = debug
Expand All @@ -29,13 +34,13 @@ def gen_worker_fields_hrl(header_path : str, debug : bool = False):
'KEY_LAYER_TYPES_LIST', 'KEY_LAYERS_FUNCTIONS',
'KEY_LOSS_METHOD', 'KEY_LEARNING_RATE',
'KEY_EPOCHS', 'KEY_OPTIMIZER_TYPE']
fields_list_strs = [f'WORKER_{x}' for x in fields_list_strs]

fields_list_defs = [ Definition(fields_list_strs[idx], f'{Definition.assert_not_atom(fields_list_vals[idx])}') for idx in range(len(fields_list_vals))]
[gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs]


if os.path.dirname(header_path):
os.makedirs(os.path.dirname(header_path), exist_ok=True)
path_validator(header_path)

with open(header_path, 'w') as f:
f.write(auto_generated_header.generate_code())
Expand All @@ -53,16 +58,50 @@ def gen_dc_fields_hrl(header_path : str, debug : bool = False):
nerlplanner_version = Comment(f'Generated by Nerlplanner version: {NERLPLANNER_VERSION}')
gen_erlang_exporter_logger(nerlplanner_version.generate_code())

#TODO
fields_list_vals_atoms = [KEY_NERLNET_SETTINGS, KEY_FREQUENCY, KEY_BATCH_SIZE,
KEY_DEVICES, KEY_CLIENTS, KEY_WORKERS, KEY_MODEL_SHA,
KEY_SOURCES, KEY_ROUTERS, NAME_FIELD, WORKER_MODEL_SHA_FIELD,
IPV4_FIELD, PORT_FIELD, ARGS_FIELD, ENTITIES_FIELD,
POLICY_FIELD, EPOCHS_FIELD, TYPE_FIELD, FREQUENCY_FIELD,
WORKERS_FIELD]
fields_list_vals_strs = [f'"{x}"' for x in fields_list_vals_atoms]
fields_list_strs = ['KEY_NERLNET_SETTINGS', 'KEY_FREQUENCY', 'KEY_BATCH_SIZE',
'KEY_DEVICES', 'KEY_CLIENTS', 'KEY_WORKERS', 'KEY_MODEL_SHA',
'KEY_SOURCES', 'KEY_ROUTERS', 'NAME_FIELD', 'WORKER_MODEL_SHA_FIELD',
'IPV4_FIELD', 'PORT_FIELD', 'ARGS_FIELD', 'ENTITIES_FIELD',
'POLICY_FIELD', 'EPOCHS_FIELD', 'TYPE_FIELD', 'FREQUENCY_FIELD',
'WORKERS_FIELD']
fields_list_strs_atom = [f'DC_{x}_ATOM' for x in fields_list_strs]
fields_list_strs_string = [f'DC_{x}_STR' for x in fields_list_strs]

fields_list_defs_atoms = [ Definition(fields_list_strs_atom[idx], f'{fields_list_vals_atoms[idx]}') for idx in range(len(fields_list_strs))]
[gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs_atoms]

fields_list_defs_strings = [ Definition(fields_list_strs_string[idx], f'{fields_list_vals_strs[idx]}') for idx in range(len(fields_list_strs))]
[gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs_strings]

path_validator(header_path)

with open(header_path, 'w') as f:
f.write(auto_generated_header.generate_code())
f.write(nerlplanner_version.generate_code())
f.write(EMPTY_LINE)
[f.write(x.generate_code()) for x in fields_list_defs_atoms]
f.write(EMPTY_LINE)
[f.write(x.generate_code()) for x in fields_list_defs_strings]

def main():
parser = argparse.ArgumentParser(description='Generate C++ header file for nerlPlanner')
parser.add_argument('-o', '--output', help='output header file path', required=True)
parser.add_argument('-d', '--debug', help='debug mode', action='store_true')
parser.add_argument('--gen_worker_fields_hrl', help='debug mode', action='store_true')
parser.add_argument('--gen_dc_fields_hrl', help='debug mode', action='store_true')

args = parser.parse_args()
gen_worker_fields_hrl(args.output, args.debug)
gen_dc_fields_hrl(args.output, args.debug)
if args.gen_worker_fields_hrl:
gen_worker_fields_hrl(args.output, args.debug)
if args.gen_dc_fields_hrl:
gen_dc_fields_hrl(args.output, args.debug)

if __name__=="__main__":
main()
Expand Down
19 changes: 16 additions & 3 deletions src_py/nerlPlanner/JsonDistributedConfigDefs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
# Any change of this file influences autogenerated Erlang files
# Please change version of Nerlnet Planner if this file is changed

KEY_NERLNET_SETTINGS = "NerlNetSettings"
# The following definitions are also treated as atoms in Erlang
# Definition must start with lower case letter
KEY_NERLNET_SETTINGS = "nerlnetSettings"
KEY_FREQUENCY = "frequency"
KEY_BATCH_SIZE = "batchSize"
KEY_DEVICES = "devices"
KEY_CLIENTS = "clients"
KEY_WORKERS = "workers"
KEY_MODEL_SHA = "model-sha"
KEY_MODEL_SHA = "model_sha"
KEY_SOURCES = "sources"
KEY_ROUTERS = "routers"

NAME_FIELD = "name"
WORKER_MODEL_SHA_FIELD = "model-sha"
WORKER_MODEL_SHA_FIELD = "model_sha"
IPV4_FIELD = "ipv4"
PORT_FIELD = "port"
ARGS_FIELD = "args"
ENTITIES_FIELD = "entities"
POLICY_FIELD = "policy"
EPOCHS_FIELD = "epochs"
TYPE_FIELD = "type"
FREQUENCY_FIELD = "frequency"
WORKERS_FIELD = "workers"
12 changes: 6 additions & 6 deletions tests/inputJsonsFiles/dc_test_synt_1d_2c_1s_4r_4w.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"NerlNetSettings": {
"nerlnetSettings": {
"frequency": "60",
"batchSize": "50"
},
Expand Down Expand Up @@ -65,22 +65,22 @@
"workers": [
{
"name": "w1",
"model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
"model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
},
{
"name": "w2",
"model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
"model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
},
{
"name": "w3",
"model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
"model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
},
{
"name": "w4",
"model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
"model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c"
}
],
"model-sha": {
"model_sha": {
"5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c": {
"modelType": "5",
"_doc_modelType": " approximation:1 | classification:2 | forecasting:3 | encoder_decoder:4 | nn:5 | autoencoder:6 | ae-classifier:7 | fed-client:8 | fed-server:9 |",
Expand Down

0 comments on commit aacee47

Please sign in to comment.