diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 51a9f50f..eafc20a6 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t); impl SelectCertError { /// A fatal error occured and the handshake should be terminated. pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error); + + /// The operation could not be completed and should be retried later. + pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry); } /// Extension types, to be used with `ClientHello::get_extension`. @@ -3197,6 +3200,11 @@ impl MidHandshakeSslStream { self.stream.ssl() } + /// Returns a mutable reference to the `Ssl` of the stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.stream.ssl_mut() + } + /// Returns the underlying error which interrupted this handshake. pub fn error(&self) -> &Error { &self.error @@ -3451,6 +3459,11 @@ impl SslStream { pub fn ssl(&self) -> &SslRef { &self.ssl } + + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + &mut self.ssl + } } impl Read for SslStream { diff --git a/tokio-boring/Cargo.toml b/tokio-boring/Cargo.toml index 009dd580..88ba52d0 100644 --- a/tokio-boring/Cargo.toml +++ b/tokio-boring/Cargo.toml @@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"] [dependencies] boring = { workspace = true } boring-sys = { workspace = true } +once_cell = { workspace = true } tokio = { workspace = true } [dev-dependencies] diff --git a/tokio-boring/src/async_callbacks.rs b/tokio-boring/src/async_callbacks.rs new file mode 100644 index 00000000..b2ec4c0f --- /dev/null +++ b/tokio-boring/src/async_callbacks.rs @@ -0,0 +1,262 @@ +use boring::ex_data::Index; +use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder}; +use once_cell::sync::Lazy; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll, Waker}; + +type BoxSelectCertFuture = ExDataFuture>; + +type BoxSelectCertFinish = Box) -> Result<(), AsyncSelectCertError>>; + +/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods. +pub type BoxPrivateKeyMethodFuture = + ExDataFuture>; + +/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`]. +pub type BoxPrivateKeyMethodFinish = + Box Result>; + +type ExDataFuture = Pin + Send + Sync>>; + +pub(crate) static TASK_WAKER_INDEX: Lazy>> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy< + Index, +> = Lazy::new(|| Ssl::new_ex_index().unwrap()); + +/// Extensions to [`SslContextBuilder`]. +/// +/// This trait provides additional methods to use async callbacks with boring. +pub trait SslContextBuilderExt: private::Sealed { + /// Sets a callback that is called before most [`ClientHello`] processing + /// and before the decision whether to resume a session is made. The + /// callback may inspect the [`ClientHello`] and configure the connection. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed [`ClientHello`] to configure + /// the connection based on the computations done in the future. + /// + /// See [`SslContextBuilder::set_select_certificate_callback`] for the sync + /// setter of this callback. + fn set_async_select_certificate_callback(&mut self, callback: Init) + where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static; + + /// Configures a custom private key method on the context. + /// + /// See [`AsyncPrivateKeyMethod`] for more details. + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod); +} + +impl SslContextBuilderExt for SslContextBuilder { + fn set_async_select_certificate_callback(&mut self, callback: Init) + where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static, + { + self.set_select_certificate_callback(move |mut client_hello| { + let fut_poll_result = with_ex_data_future( + &mut client_hello, + *SELECT_CERT_FUTURE_INDEX, + ClientHello::ssl_mut, + |client_hello| { + let fut = callback(client_hello)?; + + Ok(Box::pin(async move { + Ok(Box::new(fut.await?) as BoxSelectCertFinish) + })) + }, + ); + + let fut_result = match fut_poll_result { + Poll::Ready(fut_result) => fut_result, + Poll::Pending => return Err(ssl::SelectCertError::RETRY), + }; + + let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?; + + finish(client_hello).or(Err(ssl::SelectCertError::ERROR)) + }) + } + + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) { + self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method))); + } +} + +/// A fatal error to be returned from async select certificate callbacks. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct AsyncSelectCertError; + +/// Describes async private key hooks. This is used to off-load signing +/// operations to a custom, potentially asynchronous, backend. Metadata about the +/// key such as the type and size are parsed out of the certificate. +/// +/// See [`PrivateKeyMethod`] for the sync version of those hooks. +/// +/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st +pub trait AsyncPrivateKeyMethod: Send + Sync + 'static { + /// Signs the message `input` using the specified signature algorithm. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed `ssl` and `output` + /// to finish writing the signature. + /// + /// See [`PrivateKeyMethod::sign`] for the sync version of this method. + fn sign( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + signature_algorithm: ssl::SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result; + + /// Decrypts `input`. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed `ssl` and `output` + /// to finish decrypting the input. + /// + /// See [`PrivateKeyMethod::decrypt`] for the sync version of this method. + fn decrypt( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result; +} + +/// A fatal error to be returned from async private key methods. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct AsyncPrivateKeyMethodError; + +struct AsyncPrivateKeyMethodBridge(Box); + +impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge { + fn sign( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + signature_algorithm: ssl::SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |ssl, output| { + ::sign(&*self.0, ssl, input, signature_algorithm, output) + }) + } + + fn decrypt( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |ssl, output| { + ::decrypt(&*self.0, ssl, input, output) + }) + } + + fn complete( + &self, + ssl: &mut ssl::SslRef, + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |_, _| { + // This should never be reached, if it does, that's a bug on boring's side, + // which called `complete` without having been returned to with a pending + // future from `sign` or `decrypt`. + + if cfg!(debug_assertions) { + panic!("BUG: boring called complete without a pending operation"); + } + + Err(AsyncPrivateKeyMethodError) + }) + } +} + +/// Creates and drives a private key method future. +/// +/// This is a convenience function for the three methods of impl `PrivateKeyMethod`` +/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to +/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`] +/// when the future is ready. +fn with_private_key_method( + ssl: &mut ssl::SslRef, + output: &mut [u8], + create_fut: impl FnOnce( + &mut ssl::SslRef, + &mut [u8], + ) -> Result, +) -> Result { + let fut_poll_result = with_ex_data_future( + ssl, + *SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX, + |ssl| ssl, + |ssl| create_fut(ssl, output), + ); + + let fut_result = match fut_poll_result { + Poll::Ready(fut_result) => fut_result, + Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY), + }; + + let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?; + + finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE)) +} + +/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`. +/// +/// This function won't even bother storing the future in `index` if the future +/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call. +fn with_ex_data_future( + ssl_handle: &mut H, + index: Index>>, + get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef, + create_fut: impl FnOnce(&mut H) -> Result>, E>, +) -> Poll> { + let ssl = get_ssl_mut(ssl_handle); + let waker = ssl + .ex_data(*TASK_WAKER_INDEX) + .cloned() + .flatten() + .expect("task waker should be set"); + + let mut ctx = Context::from_waker(&waker); + + match ssl.ex_data_mut(index) { + Some(fut) => { + let fut_result = ready!(fut.as_mut().poll(&mut ctx)); + + // NOTE(nox): For memory usage concerns, maybe we should implement + // a way to remove the stored future from the `Ssl` value here? + + Poll::Ready(fut_result) + } + None => { + let mut fut = create_fut(ssl_handle)?; + + match fut.as_mut().poll(&mut ctx) { + Poll::Ready(fut_result) => Poll::Ready(fut_result), + Poll::Pending => { + get_ssl_mut(ssl_handle).set_ex_data(index, fut); + + Poll::Pending + } + } + } + } +} + +mod private { + pub trait Sealed {} +} + +impl private::Sealed for SslContextBuilder {} diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index f437ee26..79a2e9b5 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -27,8 +27,14 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod async_callbacks; mod bridge; +use self::async_callbacks::TASK_WAKER_INDEX; +pub use self::async_callbacks::{ + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, + BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, SslContextBuilderExt, +}; use self::bridge::AsyncStreamBridge; /// Asynchronously performs a client-side TLS handshake over the provided stream. @@ -90,6 +96,11 @@ impl SslStream { self.0.ssl() } + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.0.ssl_mut() + } + /// Returns a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.0.get_ref().stream @@ -285,15 +296,20 @@ where let mut mid_handshake = self.0.take().expect("future polled after completion"); mid_handshake.get_mut().set_waker(Some(ctx)); + mid_handshake + .ssl_mut() + .set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone())); match mid_handshake.handshake() { Ok(mut stream) => { stream.get_mut().set_waker(None); + stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); Poll::Ready(Ok(SslStream(stream))) } Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { mid_handshake.get_mut().set_waker(None); + mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); self.0 = Some(mid_handshake); diff --git a/tokio-boring/tests/async_private_key_method.rs b/tokio-boring/tests/async_private_key_method.rs new file mode 100644 index 00000000..b39ef7d3 --- /dev/null +++ b/tokio-boring/tests/async_private_key_method.rs @@ -0,0 +1,187 @@ +use boring::hash::MessageDigest; +use boring::pkey::PKey; +use boring::rsa::Padding; +use boring::sign::{RsaPssSaltlen, Signer}; +use boring::ssl::{SslRef, SslSignatureAlgorithm}; +use futures::future; +use tokio::task::yield_now; +use tokio_boring::{ + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, BoxPrivateKeyMethodFuture, + SslContextBuilderExt, +}; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; + +#[allow(clippy::type_complexity)] +struct Method { + sign: Box< + dyn Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, + decrypt: Box< + dyn Fn( + &mut SslRef, + &[u8], + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, +} + +impl Method { + fn new() -> Self { + Self { + sign: Box::new(|_, _, _, _| unreachable!("called sign")), + decrypt: Box::new(|_, _, _| unreachable!("called decrypt")), + } + } + + fn sign( + mut self, + sign: impl Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.sign = Box::new(sign); + + self + } + + #[allow(dead_code)] + fn decrypt( + mut self, + decrypt: impl Fn( + &mut SslRef, + &[u8], + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.decrypt = Box::new(decrypt); + + self + } +} + +impl AsyncPrivateKeyMethod for Method { + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + (self.sign)(ssl, input, signature_algorithm, output) + } + + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + (self.decrypt)(ssl, input, output) + } +} + +#[tokio::test] +async fn test_sign_failure() { + with_async_private_key_method_error( + Method::new().sign(|_, _, _, _| Err(AsyncPrivateKeyMethodError)), + ) + .await; +} + +#[tokio::test] +async fn test_sign_future_failure() { + with_async_private_key_method_error( + Method::new().sign(|_, _, _, _| Ok(Box::pin(async { Err(AsyncPrivateKeyMethodError) }))), + ) + .await; +} + +#[tokio::test] +async fn test_sign_future_yield_failure() { + with_async_private_key_method_error(Method::new().sign(|_, _, _, _| { + Ok(Box::pin(async { + yield_now().await; + + Err(AsyncPrivateKeyMethodError) + })) + })) + .await; +} + +#[tokio::test] +async fn test_sign_ok() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_private_key_method(Method::new().sign( + |_, input, signature_algorithm, _| { + assert_eq!( + signature_algorithm, + SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256, + ); + + let input = input.to_owned(); + + Ok(Box::pin(async move { + Ok(Box::new(move |_: &mut SslRef, output: &mut [u8]| { + Ok(sign_with_default_config(&input, output)) + }) as Box<_>) + })) + }, + )); + }) + .await; +} + +fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize { + let pkey = PKey::private_key_from_pem(include_bytes!("key.pem")).unwrap(); + let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap(); + + signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap(); + signer + .set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH) + .unwrap(); + + signer.update(input).unwrap(); + + signer.sign(output).unwrap() +} + +async fn with_async_private_key_method_error(method: Method) { + let (stream, addr) = create_server(move |builder| { + builder.set_async_private_key_method(method); + }); + + let server = async { + let _err = stream.await.unwrap_err(); + }; + + let client = async { + let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap_err(); + }; + + future::join(server, client).await; +} diff --git a/tokio-boring/tests/async_select_certificate.rs b/tokio-boring/tests/async_select_certificate.rs new file mode 100644 index 00000000..0c8c4dac --- /dev/null +++ b/tokio-boring/tests/async_select_certificate.rs @@ -0,0 +1,96 @@ +use boring::ssl::ClientHello; +use futures::future::{self, Pending}; +use futures::Future; +use tokio::task::yield_now; +use tokio_boring::{AsyncSelectCertError, SslContextBuilderExt}; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; + +#[tokio::test] +async fn test_async_select_certificate_callback_trivial() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_select_certificate_callback(|_| { + Ok(async move { Ok(|_: ClientHello<'_>| Ok(())) }) + }); + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_yield() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_select_certificate_callback(|_| { + Ok(async move { + yield_now().await; + + Ok(|_: ClientHello<'_>| Ok(())) + }) + }); + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_return_error() { + with_async_select_certificate_callback_error::<_, Pending<_>, fn(_: ClientHello<'_>) -> _>( + |_| Err(AsyncSelectCertError), + ) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_future_error() { + with_async_select_certificate_callback_error::<_, _, fn(_: ClientHello<'_>) -> _>(|_| { + Ok(async move { Err(AsyncSelectCertError) }) + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_future_yield_error() { + with_async_select_certificate_callback_error::<_, _, fn(_: ClientHello<'_>) -> _>(|_| { + Ok(async move { + yield_now().await; + + Err(AsyncSelectCertError) + }) + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_finish_error() { + with_async_select_certificate_callback_error(|_| { + Ok(async move { + yield_now().await; + + Ok(|_: ClientHello<'_>| Err(AsyncSelectCertError)) + }) + }) + .await; +} + +async fn with_async_select_certificate_callback_error(callback: Init) +where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static, +{ + let (stream, addr) = create_server(|builder| { + builder.set_async_select_certificate_callback(callback); + }); + + let server = async { + let _err = stream.await.unwrap_err(); + }; + + let client = async { + let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap_err(); + }; + + future::join(server, client).await; +}