diff --git a/src_cpp/torchBridge/nifppNerlTensorTorch.h b/src_cpp/torchBridge/nifppNerlTensorTorch.h index bb1e97f9..d8ffb905 100644 --- a/src_cpp/torchBridge/nifppNerlTensorTorch.h +++ b/src_cpp/torchBridge/nifppNerlTensorTorch.h @@ -19,6 +19,8 @@ namespace nifpp // Declarations template int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info); template int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype); + template void make_tensor(ErlNifEnv *env , nifpp::TERM &ret_bin_term, TorchTensor &tensor); + // Definitions template int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info) @@ -86,4 +88,30 @@ namespace nifpp std::memcpy(tensor.data_ptr(),bin.data + skip_dims_bytes, sizeof(BasicType)*tensor.numel()); } + template void make_tensor(ErlNifEnv *env , nifpp::TERM &ret_bin_term, TorchTensor &tensor) + { + std::vector dims; + dims.resize(DIMS_TOTAL); + for (int dim=0; dim < DIMS_TOTAL; dim++) + { + if (dim < tensor.sizes().Length()) + { + dims[dim] = static_cast(tensor.sizes()[dim]); + } + else + { + dims[dim] = 1; + } + } + size_t dims_size = DIMS_TOTAL * sizeof(BasicType); + size_t data_size = tensor.numel() * sizeof(BasicType); + + nifpp::binary nifpp_bin(dims_size + data_size); + + std::memcpy(nifpp_bin.data, dims.data(), dims_size); + std::memcpy(nifpp_bin.data + dims_size, tensor.data_ptr(), data_size); + + ret_bin_term = nifpp:make(env, nifpp_bin); + } + } \ No newline at end of file