diff --git a/src_cpp/torchBridge/NerlWorkerTorch.h b/src_cpp/torchBridge/NerlWorkerTorch.h
index 49574f23..c6ccd096 100644
--- a/src_cpp/torchBridge/NerlWorkerTorch.h
+++ b/src_cpp/torchBridge/NerlWorkerTorch.h
@@ -3,9 +3,9 @@
 #include <string>
 #include <memory>
-#include "../opennn/opennn/opennn.h"
 #include "../common/nerlWorker.h"
 #include "worker_definitions_ag.h"
+#include "nifppNerlTensorTorch.h"
 
 namespace nerlnet
diff --git a/src_cpp/torchBridge/nerltensorTorchDefs.h b/src_cpp/torchBridge/nerltensorTorchDefs.h
index 1c1450aa..3970cfbb 100644
--- a/src_cpp/torchBridge/nerltensorTorchDefs.h
+++ b/src_cpp/torchBridge/nerltensorTorchDefs.h
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <torch/torch.h>
 #include <vector>
diff --git a/src_cpp/torchBridge/nifppNerlTensorTorch.h b/src_cpp/torchBridge/nifppNerlTensorTorch.h
index fd18fe53..e7f81ab6 100644
--- a/src_cpp/torchBridge/nifppNerlTensorTorch.h
+++ b/src_cpp/torchBridge/nifppNerlTensorTorch.h
@@ -1,7 +1,8 @@
 #pragma once
 
-#include "nerltensorTorchDefs.h"
 #include "nifpp.h"
+#include "nerltensorTorchDefs.h"
+
 namespace nifpp
 {
@@ -113,7 +114,7 @@ namespace nifpp
         std::memcpy(nifpp_bin.data, dims.data(), dims_size);
         std::memcpy(nifpp_bin.data + dims_size, tensor.data_ptr(), data_size);
 
-        ret_bin_term = nifpp:make(env, nifpp_bin);
+        ret_bin_term = nifpp::make(env, nifpp_bin);
     }
 }
\ No newline at end of file
diff --git a/src_cpp/torchBridge/torchNIF.cpp b/src_cpp/torchBridge/torchNIF.cpp
index c6e7f96e..be6c9144 100644
--- a/src_cpp/torchBridge/torchNIF.cpp
+++ b/src_cpp/torchBridge/torchNIF.cpp
@@ -1 +1,13 @@
-#include "torchNIF.h"
\ No newline at end of file
+#include "torchNIF.h"
+
+
+void* train_threaded_function(void* args)
+{
+    return nullptr; // TODO: implement training; falling off the end of a non-void function is UB
+}
+
+
+void* predict_threaded_function(void* args)
+{
+    return nullptr; // TODO: implement prediction; falling off the end of a non-void function is UB
+}
\ No newline at end of file
diff --git a/src_cpp/torchBridge/torchNIF.h b/src_cpp/torchBridge/torchNIF.h
index 408786e8..b127a6ef 100644
--- a/src_cpp/torchBridge/torchNIF.h
+++ b/src_cpp/torchBridge/torchNIF.h
@@ -1,5 +1,72 @@
 #pragma once
 
-#include "nifppNerlTensorTorch.h"
-#include "nifpp.h"
+#include "NerlWorkerTorch.h"
+#include "bridgeController.h"
+#include <chrono>
+
+class dirty_thread_args
+{
+public:
+    long int mid; // model id
+    std::shared_ptr<torch::Tensor> data;
+    nifpp::str_atom return_tensor_type; // holds the type of tensor should be returned
+    std::chrono::high_resolution_clock::time_point start_time;
+
+
+    ErlNifTid tid;
+    ErlNifPid pid;
+};
+
+void* train_threaded_function(void* args);
+void* predict_threaded_function(void* args);
+
+/*
+train_nif function is called by NIF from Erlang.
+It creates a TorchTensor from input data and calls the threaded train function
+*/
+static ERL_NIF_TERM train_nif(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
+{
+    std::shared_ptr<dirty_thread_args>* p_thread_args_ptr = new std::shared_ptr<dirty_thread_args>(std::make_shared<dirty_thread_args>());
+    std::shared_ptr<dirty_thread_args> thread_args_ptr = *p_thread_args_ptr;
+
+    thread_args_ptr->start_time = std::chrono::high_resolution_clock::now();
+
+    enum{ARG_MODEL_ID, ARG_NERLTENSOR, ARG_NERLTENSOR_TYPE};
+    nifpp::str_atom tensor_type;
+
+    nifpp::get_throws(env, argv[ARG_NERLTENSOR_TYPE], tensor_type);
+    assert(tensor_type == "float"); // add support for other types
+    thread_args_ptr->return_tensor_type = tensor_type;
+
+    nifpp::get_throws(env, argv[ARG_MODEL_ID], thread_args_ptr->mid);
+
+    thread_args_ptr->data = std::make_shared<torch::Tensor>(); // must allocate before dereferencing below — default shared_ptr is null
+    nifpp::get_nerltensor(env, argv[ARG_NERLTENSOR], *(thread_args_ptr->data), torch::kFloat32);
+
+    // TODO: spawn train_threaded_function via enif_thread_create and delete p_thread_args_ptr
+    // when the thread completes (currently leaked).
+    return enif_make_atom(env, "ok"); // non-void NIF must return a term; falling off the end is UB
+}
+/*
+predict_nif function is called by NIF from Erlang.
+It creates a TorchTensor from input data and calls the threaded predict function
+*/
+static ERL_NIF_TERM predict_nif(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
+{
+    // TODO: implement prediction dispatch.
+    return enif_make_atom(env, "ok"); // non-void NIF must return a term; falling off the end is UB
+}
+
+static ErlNifFunc nif_funcs[] =
+{
+    {"get_active_models_ids_list",0, get_active_models_ids_list_nif},
+    {"train_nif", 3 , train_nif},
+    {"predict_nif", 3 , predict_nif},
+    // {"get_weights_nif",1, get_weights_nif},
+    // {"set_weights_nif",3, set_weights_nif},
+    // {"encode_nif",2, encode_nif},
+    // {"decode_nif",2, decode_nif},
+    // {"nerltensor_sum_nif",3, nerltensor_sum_nif},
+    // {"nerltensor_scalar_multiplication_nif",3,nerltensor_scalar_multiplication_nif},
+    // // nerlworker functions
+    // {"new_nerlworker_nif", 13, new_nerlworker_nif},
+    // {"test_nerlworker_nif", 13, test_nerlworker_nif},
+    // {"update_nerlworker_train_params_nif", 6, update_nerlworker_train_params_nif},
+    // {"remove_nerlworker_nif", 1, remove_nerlworker_nif},
+    // {"get_distributed_system_train_labels_count_nif", 1, get_distributed_system_train_labels_count_nif}
+};