Skip to content

Commit

Permalink
[TORCH] train/predict nif header
Browse files Browse the repository at this point in the history
  • Loading branch information
leondavi committed Jul 24, 2024
1 parent 1a4c4a7 commit 5f688da
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
36 changes: 36 additions & 0 deletions src_cpp/torchBridge/nerltensorTorchDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <map>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include <nerltensor.h>
#include <torch/torch.h>

namespace nerlnet
{
Expand All @@ -11,4 +13,38 @@ using TorchTensor = torch::Tensor;
enum {DIMS_CASE_1D,DIMS_CASE_2D,DIMS_CASE_3D};
enum {DIMS_X_IDX,DIMS_Y_IDX,DIMS_Z_IDX,DIMS_TOTAL};

// Lookup table from the dtype name received from the Erlang side (e.g. "float32")
// to the corresponding torch scalar type. Kept in one place so the train/predict
// NIFs agree on the set of supported dtypes.
// fix: the original initializer listed nine keys twice ("bool", "uint8", "int8",
// "int16", "int32", "int64", "float16", "float32", "float64"); std::map silently
// ignores duplicate keys on insert, so the repeated entries were dead data.
const std::map<std::string, c10::ScalarType> torch_dtype_map = {
    {"float",    c10::ScalarType::Float},
    {"double",   c10::ScalarType::Double},
    {"int",      c10::ScalarType::Int},
    {"long",     c10::ScalarType::Long},
    {"bool",     c10::ScalarType::Bool},
    {"uint8",    c10::ScalarType::Byte},
    {"int8",     c10::ScalarType::Char},
    {"int16",    c10::ScalarType::Short},
    {"int32",    c10::ScalarType::Int},
    {"int64",    c10::ScalarType::Long},
    {"float16",  c10::ScalarType::Half},
    {"float32",  c10::ScalarType::Float},
    {"float64",  c10::ScalarType::Double},
    {"qint8",    c10::ScalarType::QInt8},
    {"quint8",   c10::ScalarType::QUInt8},
    {"qint32",   c10::ScalarType::QInt32},
    {"bfloat16", c10::ScalarType::BFloat16}};

// Resolves a dtype name to its torch scalar type via torch_dtype_map.
// Throws std::out_of_range for unrecognized names — the same exception type the
// previous map.at() call produced, but now with a descriptive message and a
// single lookup instead of find() + at().
// fix: the original assert() is compiled out under NDEBUG, so release builds
// relied solely on at() throwing; the explicit check works in all build modes.
inline c10::ScalarType get_torch_dtype(const std::string &dtype_str)
{
const auto dtype_it = torch_dtype_map.find(dtype_str);
if (dtype_it == torch_dtype_map.end())
{
throw std::out_of_range("get_torch_dtype: unsupported dtype string: " + dtype_str);
}
return dtype_it->second;
}

} // namespace nerlnet
8 changes: 8 additions & 0 deletions src_cpp/torchBridge/torchNIF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@

void* train_threaded_function(void* args)
{
// Takes ownership of the heap-allocated shared_ptr wrapper created by train_nif;
// copying out the shared_ptr and deleting the wrapper releases the caller's
// heap allocation while keeping the args alive for this thread's lifetime.
std::shared_ptr<dirty_thread_args>* p_thread_args_ptr = static_cast<std::shared_ptr<dirty_thread_args>*>(args);
std::shared_ptr<dirty_thread_args> thread_args_ptr = *p_thread_args_ptr;
delete p_thread_args_ptr;

// TODO implement training
return nullptr; // fix: flowing off the end of a non-void function is undefined behavior
}


void* predict_threaded_function(void* args)
{
// Takes ownership of the heap-allocated shared_ptr wrapper created by predict_nif;
// copying out the shared_ptr and deleting the wrapper releases the caller's
// heap allocation while keeping the args alive for this thread's lifetime.
std::shared_ptr<dirty_thread_args>* p_thread_args_ptr = static_cast<std::shared_ptr<dirty_thread_args>*>(args);
std::shared_ptr<dirty_thread_args> thread_args_ptr = *p_thread_args_ptr;
delete p_thread_args_ptr;

// TODO implement prediction
return nullptr; // fix: flowing off the end of a non-void function is undefined behavior
}
70 changes: 65 additions & 5 deletions src_cpp/torchBridge/torchNIF.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,86 @@ static ERL_NIF_TERM train_nif(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[
std::shared_ptr<dirty_thread_args>* p_thread_args_ptr = new std::shared_ptr<dirty_thread_args>(std::make_shared<dirty_thread_args>());
std::shared_ptr<dirty_thread_args> thread_args_ptr = *p_thread_args_ptr;

nifpp::str_atom tensor_type;
c10::ScalarType torch_dtype;

thread_args_ptr->start_time = std::chrono::high_resolution_clock::now();

enum{ARG_MODEL_ID, ARG_NERLTENSOR, ARG_NERLTENSOR_TYPE};
nifpp::str_atom tensor_type;

nifpp::get_throws(env, argv[ARG_NERLTENSOR_TYPE],tensor_type);
assert(tensor_type == "float"); // add support for other types
thread_args_ptr->return_tensor_type = tensor_type;
// extract model id
nifpp::get_throws(env, argv[ARG_MODEL_ID], thread_args_ptr->mid);
torch_dtype = nerlnet::get_torch_dtype(tensor_type);
// extract nerltensor
nifpp::get_nerltensor<float>(env, argv[ARG_NERLTENSOR], *(thread_args_ptr->data), torch_dtype);

nifpp::get_throws(env, argv[ARG_MODEL_ID], thread_args_ptr->mid);

nifpp::get_nerltensor<float>(env, argv[ARG_NERLTENSOR], *(thread_args_ptr->data), torch::kFloat32);
char* thread_name = "train_thread";
int thread_create_status = enif_thread_create(thread_name, &(thread_args_ptr->tid), train_threaded_function, (void*) p_thread_args_ptr, NULL);
void** exit_code;
if (thread_create_status != 0)
{
LogError("failed to call enif_thread_create with train_nif");
nifpp::str_atom ret_status("train_nif_error");
return nifpp::make(env, ret_status);
}
else
{
thread_create_status = enif_thread_join(thread_args_ptr->tid, exit_code );
if (thread_create_status != 0)
{
LogError("failed to join with train_nif");
nifpp::str_atom ret_status("train_nif_error");
return nifpp::make(env, ret_status);
}
}
return nifpp::make(env, "ok");
}
/*
predict_nif is called from Erlang through the NIF interface.
It extracts the model id, the input nerltensor and its dtype from argv,
hands them to predict_threaded_function on a new thread, and waits for
the thread to finish before returning "ok" (or an error atom).
*/
static ERL_NIF_TERM predict_nif(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
{
// Heap-allocated wrapper transfers shared ownership of the args into the worker
// thread, which deletes the wrapper itself (see predict_threaded_function).
std::shared_ptr<dirty_thread_args>* p_thread_args_ptr = new std::shared_ptr<dirty_thread_args>(std::make_shared<dirty_thread_args>());
std::shared_ptr<dirty_thread_args> thread_args_ptr = *p_thread_args_ptr;

nifpp::str_atom tensor_type;
c10::ScalarType torch_dtype;

thread_args_ptr->start_time = std::chrono::high_resolution_clock::now();

enum{ARG_MODEL_ID, ARG_NERLTENSOR, ARG_NERLTENSOR_TYPE};

nifpp::get_throws(env, argv[ARG_NERLTENSOR_TYPE], tensor_type);
thread_args_ptr->return_tensor_type = tensor_type;
// extract model id
nifpp::get_throws(env, argv[ARG_MODEL_ID], thread_args_ptr->mid);
torch_dtype = nerlnet::get_torch_dtype(tensor_type);
// extract nerltensor with the resolved dtype
nifpp::get_nerltensor<float>(env, argv[ARG_NERLTENSOR], *(thread_args_ptr->data), torch_dtype);

// enif_thread_create takes a non-const char*; a named buffer avoids the
// deprecated string-literal-to-char* conversion
static char thread_name[] = "predict_thread";
int thread_create_status = enif_thread_create(thread_name, &(thread_args_ptr->tid), predict_threaded_function, (void*) p_thread_args_ptr, NULL);
if (thread_create_status != 0)
{
// fix: the wrapper leaked on this path — the thread that deletes it never started
delete p_thread_args_ptr;
LogError("failed to call enif_thread_create with predict_nif");
nifpp::str_atom ret_status("predict_nif_error");
return nifpp::make(env, ret_status);
}
// fix: join previously received an UNINITIALIZED void** — pass the address of
// a real void* instead
void* exit_code = nullptr;
thread_create_status = enif_thread_join(thread_args_ptr->tid, &exit_code);
if (thread_create_status != 0)
{
LogError("failed to join with predict_nif");
nifpp::str_atom ret_status("predict_nif_error");
return nifpp::make(env, ret_status);
}
return nifpp::make(env, "ok");
}

static ErlNifFunc nif_funcs[] =
Expand Down

0 comments on commit 5f688da

Please sign in to comment.