diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 1cefd30e..38563790 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -444,6 +444,7 @@ dependencies = [ "cas_types", "http 1.1.0", "lz4_flex", + "merkledb", "merklehash", "tempfile", "tracing", @@ -1433,6 +1434,8 @@ dependencies = [ "parutils", "pyo3", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] diff --git a/hf_xet/Cargo.toml b/hf_xet/Cargo.toml index 32d0989f..e35a59b8 100644 --- a/hf_xet/Cargo.toml +++ b/hf_xet/Cargo.toml @@ -17,4 +17,6 @@ pyo3 = { version = "0.20.2", features = [ data = { path = "../data" } tokio = { version = "1.36", features = ["full"] } parutils = { path = "../parutils" } +tracing = "0.1.*" +tracing-subscriber = { version = "0.3", features = ["tracing-log"] } diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index 8e3dc2b4..3524caa7 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -1,5 +1,6 @@ mod data_client; mod config; +mod log; use pyo3::{pyfunction, PyResult}; use pyo3::exceptions::PyException; @@ -82,6 +83,7 @@ impl PyPointerFile { #[pymodule] pub fn hf_xet(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + log::initialize_logging(); m.add_function(wrap_pyfunction!(upload_files, m)?)?; m.add_function(wrap_pyfunction!(download_files, m)?)?; m.add_class::()?; diff --git a/hf_xet/src/log.rs b/hf_xet/src/log.rs new file mode 100644 index 00000000..e564264d --- /dev/null +++ b/hf_xet/src/log.rs @@ -0,0 +1,22 @@ +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +/// Default log level for the library to use. Override using `RUST_LOG` env variable. +/// TODO: probably change default to warn or error before shipping. +const DEFAULT_LOG_LEVEL: &str = "info"; + +pub fn initialize_logging() { + // TODO: maybe have an env variable for writing to a log file instead of stderr + let fmt_layer = tracing_subscriber::fmt::layer() + .with_line_number(true) + .with_file(true) + .with_target(false) + .json(); + + let filter_layer = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new(DEFAULT_LOG_LEVEL)) + .unwrap_or_default(); + tracing_subscriber::registry() + .with(fmt_layer) + .with(filter_layer) + .init(); +}