Skip to content

Commit

Permalink
Deadlock fix. (#140)
Browse files Browse the repository at this point in the history
Fix for issue where: 
- multiple python threads call into download_files.  
- Thread 1 begins to initializes the runtime, including logging. 
- Thread 2 acquires the GIL but then blocks on the init lock on the
runtime.
- Thread 1 hits the logging section, which requires the GIL, and is
blocked.

Fix is to: 
- Initialize logging and runtime on module load.  
- Delay initializing logging until after runtime write lock is released
for the case where CTRL-C forces the runtime to be re-initialized.
  • Loading branch information
hoytak authored Jan 14, 2025
1 parent 32cf083 commit e691905
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 50 deletions.
61 changes: 32 additions & 29 deletions hf_xet/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion hf_xet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ impl PyPointerFile {
}

#[pymodule]
pub fn hf_xet(m: &Bound<'_, PyModule>) -> PyResult<()> {
pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(upload_files, m)?)?;
m.add_function(wrap_pyfunction!(download_files, m)?)?;
m.add_class::<PyPointerFile>()?;

// Init the threadpool
runtime::init_threadpool(py)?;
Ok(())
}
24 changes: 19 additions & 5 deletions hf_xet/src/log.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::env;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use pyo3::Python;
use tracing_subscriber::filter::FilterFn;
Expand All @@ -8,7 +8,7 @@ use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
use utils::ThreadPool;

use crate::log_buffer::{get_telemetry_task, LogBufferLayer, TELEMETRY_PRE_ALLOC_BYTES};
use crate::log_buffer::{get_telemetry_task, LogBufferLayer, TelemetryTaskInfo, TELEMETRY_PRE_ALLOC_BYTES};

/// Default log level for the library to use. Override using `RUST_LOG` env variable.
#[cfg(not(debug_assertions))]
Expand All @@ -17,7 +17,7 @@ const DEFAULT_LOG_LEVEL: &str = "warn";
#[cfg(debug_assertions)]
const DEFAULT_LOG_LEVEL: &str = "info";

pub fn initialize_logging(py: Python, runtime: Arc<ThreadPool>) {
fn init_global_logging(py: Python) -> Option<TelemetryTaskInfo> {
let fmt_layer = tracing_subscriber::fmt::layer()
.with_line_number(true)
.with_file(true)
Expand All @@ -30,10 +30,11 @@ pub fn initialize_logging(py: Python, runtime: Arc<ThreadPool>) {

if env::var("HF_HUB_DISABLE_TELEMETRY").as_deref() == Ok("1") {
tracing_subscriber::registry().with(fmt_layer).with(filter_layer).init();
None
} else {
let telemetry_buffer_layer = LogBufferLayer::new(py, TELEMETRY_PRE_ALLOC_BYTES);
let telemetry_task =
get_telemetry_task(telemetry_buffer_layer.buffer.clone(), telemetry_buffer_layer.stats.clone());
let telemetry_task_info: TelemetryTaskInfo =
(telemetry_buffer_layer.buffer.clone(), telemetry_buffer_layer.stats.clone());

let telemetry_filter_layer =
telemetry_buffer_layer.with_filter(FilterFn::new(|meta| meta.target() == "client_telemetry"));
Expand All @@ -44,6 +45,19 @@ pub fn initialize_logging(py: Python, runtime: Arc<ThreadPool>) {
.with(telemetry_filter_layer)
.init();

Some(telemetry_task_info)
}
}

pub fn initialize_runtime_logging(py: Python, runtime: Arc<ThreadPool>) {
static GLOBAL_TELEMETRY_TASK_INFO: OnceLock<Option<TelemetryTaskInfo>> = OnceLock::new();

// First get or init the global logging componenents.
let telemetry_task_info = GLOBAL_TELEMETRY_TASK_INFO.get_or_init(move || init_global_logging(py));

// Spawn the telemetry logging.
if let Some(ref tti) = telemetry_task_info {
let telemetry_task = get_telemetry_task(tti.clone());
let _telemetry_task = runtime.spawn(telemetry_task);
}
}
5 changes: 4 additions & 1 deletion hf_xet/src/log_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ pub fn get_telemetry_endpoint() -> Option<String> {
}))
}

pub async fn get_telemetry_task(log_buffer: Arc<Mutex<BipBuffer<u8>>>, log_stats: Arc<LogBufferStats>) {
pub type TelemetryTaskInfo = (Arc<Mutex<BipBuffer<u8>>>, Arc<LogBufferStats>);

pub async fn get_telemetry_task(telemetry_task_info: TelemetryTaskInfo) {
let (log_buffer, log_stats) = telemetry_task_info;
let client = reqwest::Client::new();
let telemetry_url = format!("{}/{}", get_telemetry_endpoint().unwrap_or_default(), TELEMETRY_SUFFIX);

Expand Down
44 changes: 30 additions & 14 deletions hf_xet/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,7 @@ fn signal_check_background_loop() {
}
}

// This function initializes the runtime if not present, otherwise returns the existing one.
fn get_threadpool(py: Python) -> PyResult<Arc<ThreadPool>> {
{
// First try a read lock to see if it's already initialized.
let guard = MULTITHREADED_RUNTIME.read().unwrap();
if let Some(ref existing) = *guard {
return Ok(existing.clone());
}
}

pub fn init_threadpool(py: Python) -> PyResult<Arc<ThreadPool>> {
// Need to initialize. Upgrade to write lock.
let mut guard = MULTITHREADED_RUNTIME.write().unwrap();

Expand All @@ -121,9 +112,6 @@ fn get_threadpool(py: Python) -> PyResult<Arc<ThreadPool>> {
// Create a new Tokio runtime.
let runtime = Arc::new(ThreadPool::new().map_err(convert_multithreading_error)?);

// Initialize the logging
log::initialize_logging(py, runtime.clone());

// Check the signal handler
check_sigint_handler()?;

Expand All @@ -133,10 +121,38 @@ fn get_threadpool(py: Python) -> PyResult<Arc<ThreadPool>> {
// Spawn a background non-tokio thread to check the sigint flag.
std::thread::spawn(move || signal_check_background_loop());

// Return the handle to use to run tasks.
// Drop the guard and initialize the logging.
//
// We want to drop this first is that multiple threads entering this runtime
// may cause a deadlock if the thread that has the GIL tries to acquire the runtime,
// but then the logging expects the GIL in order to initialize it properly.
//
// In most cases, this will done on module initialization; however, after CTRL-C, the runtime is
// initialized lazily and so putting this here avoids the deadlock (and possibly some info! or other
// error statements may not be sent to python if the other thread continues ahead of the logging
// being initialized.)
drop(guard);

// Initialize the logging
log::initialize_runtime_logging(py, runtime.clone());

// Return the runtime
Ok(runtime)
}

// This function initializes the runtime if not present, otherwise returns the existing one.
fn get_threadpool(py: Python) -> PyResult<Arc<ThreadPool>> {
// First try a read lock to see if it's already initialized.
{
let guard = MULTITHREADED_RUNTIME.read().unwrap();
if let Some(ref existing) = *guard {
return Ok(existing.clone());
}
}
// Init and return
init_threadpool(py)
}

fn convert_multithreading_error(e: MultithreadedRuntimeError) -> PyErr {
PyRuntimeError::new_err(format!("Xet Runtime Error: {}", e))
}
Expand Down

0 comments on commit e691905

Please sign in to comment.