Skip to content

Commit

Permalink
add: VOICEVOX CORE用の初期化経路を構築する
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Jun 25, 2024
1 parent b6c41c6 commit 7642ebc
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 2 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ strip = true
codegen-units = 1

[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]

Expand Down Expand Up @@ -80,9 +80,16 @@ vitis = [ "voicevox-ort-sys/vitis" ]
cann = [ "voicevox-ort-sys/cann" ]
qnn = [ "voicevox-ort-sys/qnn" ]

# 動的ライブラリの読み込みから`OrtEnv`の作成までを、VOICEVOX独自の方法で行えるようにする。
#
# ortとしての通常の初期化の経路は禁止される。
__init-for-voicevox = []

[dependencies]
anyhow = "1.0"
ndarray = { version = "0.15", optional = true }
thiserror = "1.0"
once_cell = "1.19.0"
voicevox-ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" }
libloading = { version = "0.8", optional = true }

Expand All @@ -101,7 +108,6 @@ js-sys = "0.3"
web-sys = "0.3"

[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
image = "0.25"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
Expand Down
1 change: 1 addition & 0 deletions ort-sys/VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1.17.3
13 changes: 13 additions & 0 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
copy_libraries(&lib_dir.join("lib"), &out_dir);
}

let our_version = include_str!("./VERSION_NUMBER");
let their_version = fs::read_to_string(lib_dir.join("VERSION_NUMBER")).unwrap_or_else(|e| panic!("`VERSION_NUMBER`を読めませんでした: {e}"));
assert_eq!(our_version.trim_end(), their_version.trim_end(), "`VERSION_NUMBER`が異なります");

(lib_dir, true)
}
#[cfg(not(feature = "download-binaries"))]
Expand Down Expand Up @@ -421,6 +425,15 @@ fn real_main(link: bool) {
}

fn main() {
if cfg!(feature = "download-binaries") {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
fs::write(
out_dir.join("downloaded_version.rs"),
format!("#[macro_export] macro_rules! downloaded_version(() => ({:?}));", include_str!("./VERSION_NUMBER").trim_end())
)
.unwrap();
}

if env::var("DOCS_RS").is_ok() {
return;
}
Expand Down
2 changes: 2 additions & 0 deletions ort-sys/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pub mod dirs;

include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs"));
6 changes: 6 additions & 0 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,17 @@ impl EnvironmentBuilder {

/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
// drop global reference to previous environment
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
drop(env_arc);
}
self.commit_()
}

pub(crate) fn commit_(self) -> Result<()> {
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
Expand Down
207 changes: 207 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ use std::{
};

pub use ort_sys as sys;
#[cfg(feature = "download-binaries")]
pub use ort_sys::downloaded_version;

#[cfg(feature = "load-dynamic")]
pub use self::environment::init_from;
Expand Down Expand Up @@ -73,6 +75,38 @@ pub use self::value::{
ValueRefMut, ValueType, ValueTypeMarker
};

/// このクレートの`load-dynamic`が有効化されていなければコンパイルエラー。
#[cfg(feature = "load-dynamic")]
#[macro_export]
macro_rules! assert_load_dynamic_is_enabled {
($_:literal $(,)?) => {};
}

/// このクレートの`load-dynamic`が有効化されていなければコンパイルエラー。
#[cfg(not(feature = "load-dynamic"))]
#[macro_export]
macro_rules! assert_load_dynamic_is_enabled {
($msg:literal $(,)?) => {
::std::compile_error!($msg);
};
}

/// このクレートの`load-dynamic`が無効化されていなければコンパイルエラー。
#[cfg(feature = "load-dynamic")]
#[macro_export]
macro_rules! assert_load_dynamic_is_disabled {
($msg:literal $(,)?) => {
::std::compile_error!($msg);
};
}

/// このクレートの`load-dynamic`が無効化されていなければコンパイルエラー。
#[cfg(not(feature = "load-dynamic"))]
#[macro_export]
macro_rules! assert_load_dynamic_is_disabled {
($_:literal $(,)?) => {};
}

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
Expand Down Expand Up @@ -100,6 +134,9 @@ pub(crate) static G_ORT_LIB: OnceLock<Arc<libloading::Library>> = OnceLock::new(

#[cfg(feature = "load-dynamic")]
pub(crate) fn dylib_path() -> &'static String {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
G_ORT_DYLIB_PATH.get_or_init(|| {
let path = match std::env::var("ORT_DYLIB_PATH") {
Ok(s) if !s.is_empty() => s,
Expand All @@ -116,6 +153,13 @@ pub(crate) fn dylib_path() -> &'static String {

#[cfg(feature = "load-dynamic")]
pub(crate) fn lib_handle() -> &'static libloading::Library {
#[cfg(feature = "__init-for-voicevox")]
if true {
return &G_ENV_FOR_VOICEVOX
.get()
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません")
.dylib;
}
G_ORT_LIB.get_or_init(|| {
// resolve path relative to executable
let path: std::path::PathBuf = dylib_path().into();
Expand All @@ -135,6 +179,155 @@ pub(crate) fn lib_handle() -> &'static libloading::Library {
})
}

#[cfg(feature = "__init-for-voicevox")]
static G_ENV_FOR_VOICEVOX: once_cell::sync::OnceCell<EnvHandle> = once_cell::sync::OnceCell::new();

#[cfg(feature = "__init-for-voicevox")]
static G_ORT_API_FOR_ENV_BUILD: std::sync::Mutex<Option<AssertSendSync<NonNull<ort_sys::OrtApi>>>> = std::sync::Mutex::new(None);

#[cfg(feature = "__init-for-voicevox")]
#[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))]
#[derive(Debug)]
pub struct EnvHandle {
_env: std::sync::Arc<Environment>,
api: AssertSendSync<NonNull<ort_sys::OrtApi>>,
#[cfg(feature = "load-dynamic")]
dylib: libloading::Library
}

#[cfg(feature = "__init-for-voicevox")]
impl EnvHandle {
/// インスタンスが既に作られているならそれを得る。
///
/// 作られていなければ`None`。
pub fn get() -> Option<&'static Self> {
G_ENV_FOR_VOICEVOX.get()
}
}

#[cfg(feature = "__init-for-voicevox")]
#[derive(Clone, Copy, Debug)]
struct AssertSendSync<T>(T);

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Send for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Sync for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))))]
pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use anyhow::bail;
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let (dylib, api) = unsafe {
let dylib = libloading::Library::new(filename)?;

// この下にある`api()`のものをできるだけ真似る

let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = dylib
.get(b"OrtGetApiBase")
.expect("`OrtGetApiBase` must be present in ONNX Runtime dylib");
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());

let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'");

let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
match lib_minor_version.cmp(&MINOR_VERSION) {
// TODO: libvoicevox_onnxruntimeを使うようになったらこのメッセージは不要
std::cmp::Ordering::Less if filename == "onnxruntime.dll" => {
bail!(r"バージョン{version_string}のonnxruntime.dllが解決されました。Windows\System32下にある古いonnxruntime.dllかもしれません");
}
std::cmp::Ordering::Less => bail!(
"ort 2.0 is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.{MINOR_VERSION}.x', but got \
'{version_string}'",
filename.to_string_lossy(),
),
std::cmp::Ordering::Greater => tracing::warn!(
"ort 2.0 may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return \
'1.{MINOR_VERSION}.x', but got '{version_string}'",
filename.to_string_lossy(),
),
std::cmp::Ordering::Equal => {}
};

let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api = get_api(ORT_API_VERSION);
(dylib, api)
};
let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました")));

let _env = create_env(api, tp_options)?;

Ok(EnvHandle { _env, api, dylib })
})
}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", any(doc, not(feature = "load-dynamic"))))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", not(feature = "load-dynamic")))))]
pub fn try_init(tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let api = unsafe {
// この下にある`api()`のものをできるだけ真似る
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
assert_ne!(base, ptr::null());
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
get_api(ORT_API_VERSION)
};
let api = NonNull::new(api.cast_mut())
.unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました。おそらく1.{MINOR_VERSION}より古いものがリンクされています"));
let api = AssertSendSync(api);

let _env = create_env(api, tp_options)?;

Ok(EnvHandle { _env, api })
})
}

#[cfg(feature = "__init-for-voicevox")]
fn create_env(
api: AssertSendSync<NonNull<ort_sys::OrtApi>>,
tp_options: Option<EnvironmentGlobalThreadPoolOptions>
) -> anyhow::Result<std::sync::Arc<Environment>> {
*G_ORT_API_FOR_ENV_BUILD.lock().unwrap_or_else(|e| panic!("{e}")) = Some(api);
let _unset_api = UnsetOrtApi;

let mut env = EnvironmentBuilder::default().with_name(env!("CARGO_PKG_NAME"));
if let Some(tp_options) = tp_options {
env = env.with_global_thread_pool(tp_options);
}
env.commit_()?;

return Ok(get_environment().expect("失敗しないはず").clone());

struct UnsetOrtApi;

impl Drop for UnsetOrtApi {
fn drop(&mut self) {
if let Ok(mut api) = G_ORT_API_FOR_ENV_BUILD.lock() {
*api = None;
}
}
}
}

pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::new();

/// Returns a pointer to the global [`ort_sys::OrtApi`] object.
Expand All @@ -144,6 +337,20 @@ pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::ne
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn api() -> NonNull<ort_sys::OrtApi> {
#[cfg(feature = "__init-for-voicevox")]
if true {
return G_ENV_FOR_VOICEVOX
.get()
.map(|&EnvHandle { api: AssertSendSync(api), .. }| api)
.or_else(|| {
G_ORT_API_FOR_ENV_BUILD
.lock()
.unwrap_or_else(|e| panic!("{e}"))
.as_ref()
.map(|&AssertSendSync(api)| api)
})
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません");
}
unsafe {
NonNull::new_unchecked(
G_ORT_API
Expand Down

0 comments on commit 7642ebc

Please sign in to comment.