Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement SslContextBuilder::set_private_key_method
Browse files Browse the repository at this point in the history
nox committed Aug 4, 2023
1 parent 43c57d0 commit 87fb27e
Showing 5 changed files with 503 additions and 32 deletions.
111 changes: 98 additions & 13 deletions boring/src/ssl/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#![forbid(unsafe_op_in_unsafe_fn)]

use super::{
AlpnError, ClientHello, PrivateKeyError, PrivateKeyMethod, SelectCertError, SniError, Ssl,
SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, SslSignatureAlgorithm,
SESSION_CTX_INDEX,
};
use crate::error::ErrorStack;
use crate::ffi;
use crate::x509::{X509StoreContext, X509StoreContextRef};
use foreign_types::ForeignType;
use foreign_types::ForeignTypeRef;
use libc::c_char;
@@ -12,19 +19,7 @@ use std::slice;
use std::str;
use std::sync::Arc;

use crate::error::ErrorStack;
use crate::ssl::AlpnError;
use crate::ssl::{ClientHello, SelectCertError};
use crate::ssl::{
SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
SESSION_CTX_INDEX,
};
use crate::x509::{X509StoreContext, X509StoreContextRef};

pub(super) unsafe extern "C" fn raw_verify<F>(
preverify_ok: c_int,
x509_ctx: *mut ffi::X509_STORE_CTX,
) -> c_int
pub extern "C" fn raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int
where
F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send,
{
@@ -372,3 +367,93 @@ where

callback(ssl, line);
}

pub(super) unsafe extern "C" fn raw_sign<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
signature_algorithm: u16,
in_: *const u8,
in_len: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let input = unsafe { slice::from_raw_parts(in_, in_len) };

let signature_algorithm = SslSignatureAlgorithm(signature_algorithm);

let callback = |method: &M, ssl: &mut _, output: &mut _| {
method.sign(ssl, input, signature_algorithm, output)
};

// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) }
}

pub(super) unsafe extern "C" fn raw_decrypt<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
in_: *const u8,
in_len: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let input = unsafe { slice::from_raw_parts(in_, in_len) };

let callback = |method: &M, ssl: &mut _, output: &mut _| method.decrypt(ssl, input, output);

// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) }
}

pub(super) unsafe extern "C" fn raw_complete<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback::<M>(ssl, out, out_len, max_out, M::complete) }
}

unsafe fn raw_private_key_callback<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
callback: impl FnOnce(&M, &mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyError>,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
let output = unsafe { slice::from_raw_parts_mut(out, max_out) };
let out_len = unsafe { &mut *out_len };

let ssl_context = ssl.ssl_context().to_owned();
let method = ssl_context
.ex_data(SslContext::cached_ex_index::<M>())
.expect("BUG: private key method missing");

match callback(method, ssl, output) {
Ok(written) => {
assert!(written <= max_out);

*out_len = written;

ffi::ssl_private_key_result_t::ssl_private_key_success
}
Err(err) => err.0,
}
}
95 changes: 95 additions & 0 deletions boring/src/ssl/mod.rs
Original file line number Diff line number Diff line change
@@ -1382,6 +1382,31 @@ impl SslContextBuilder {
}
}

/// Configures a custom private key method on the context.
///
/// See [`PrivateKeyMethod`] for more details.
///
/// This corresponds to [`SSL_CTX_set_private_key_method`]
///
/// [`SSL_CTX_set_private_key_method`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_private_key_method
pub fn set_private_key_method<M>(&mut self, method: M)
where
M: PrivateKeyMethod,
{
unsafe {
self.set_ex_data(SslContext::cached_ex_index::<M>(), method);

ffi::SSL_CTX_set_private_key_method(
self.as_ptr(),
&ffi::SSL_PRIVATE_KEY_METHOD {
sign: Some(callbacks::raw_sign::<M>),
decrypt: Some(callbacks::raw_decrypt::<M>),
complete: Some(callbacks::raw_complete::<M>),
},
)
}
}

/// Checks for consistency between the private key and certificate.
///
/// This corresponds to [`SSL_CTX_check_private_key`].
@@ -3649,6 +3674,76 @@ bitflags! {
}
}

/// Describes private key hooks. This is used to off-load signing operations to
/// a custom, potentially asynchronous, backend. Metadata about the key such as
/// the type and size are parsed out of the certificate.
///
/// Corresponds to [`ssl_private_key_method_st`].
///
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
pub trait PrivateKeyMethod: Send + Sync + 'static {
/// Signs the message `input` using the specified signature algorithm.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyError::RETRY)`.
///
/// The caller should arrange for the high-level operation on `ssl` to be
/// retried when the operation is completed. This will result in a call to
/// [`Self::complete`].
fn sign(
&self,
ssl: &mut SslRef,
input: &[u8],
signature_algorithm: SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, PrivateKeyError>;

/// Decrypts `input`.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyError::RETRY)`.
///
/// The caller should arrange for the high-level operation on `ssl` to be
/// retried when the operation is completed. This will result in a call to
/// [`Self::complete`].
///
/// This method only works with RSA keys and should perform a raw RSA
/// decryption operation with no padding.
// NOTE(nox): What does it mean that it is an error?
fn decrypt(
&self,
ssl: &mut SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, PrivateKeyError>;

/// Completes a pending operation.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyError::RETRY)`.
///
/// This method may be called arbitrarily many times before completion.
fn complete(&self, ssl: &mut SslRef, output: &mut [u8]) -> Result<usize, PrivateKeyError>;
}

/// An error returned from a private key method.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct PrivateKeyError(ffi::ssl_private_key_result_t);

impl PrivateKeyError {
/// A fatal error occured and the handshake should be terminated.
pub const FAILURE: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_failure);

/// The operation could not be completed and should be retried later.
pub const RETRY: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_retry);
}

use crate::ffi::{SSL_CTX_up_ref, SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server};

use crate::ffi::{DTLS_method, TLS_client_method, TLS_method, TLS_server_method};
24 changes: 11 additions & 13 deletions boring/src/ssl/test/mod.rs
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ use crate::x509::store::X509StoreBuilder;
use crate::x509::verify::X509CheckFlags;
use crate::x509::{X509Name, X509StoreContext, X509VerifyResult, X509};

mod private_key_method;
mod server;

static ROOT_CERT: &[u8] = include_bytes!("../../../test/root-ca.pem");
@@ -55,9 +56,7 @@ fn verify_untrusted() {
#[test]
fn verify_trusted() {
let server = Server::builder().build();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
let client = server.client_with_root_ca();

client.connect();
}
@@ -109,9 +108,8 @@ fn verify_untrusted_callback_override_bad() {
#[test]
fn verify_trusted_callback_override_ok() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, x509| {
@@ -125,11 +123,12 @@ fn verify_trusted_callback_override_ok() {
#[test]
fn verify_trusted_callback_override_bad() {
let mut server = Server::builder();

server.should_error();

let server = server.build();
let mut client = server.client_with_root_ca();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, _| false);
@@ -155,9 +154,8 @@ fn verify_callback_load_certs() {
#[test]
fn verify_trusted_get_error_ok() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, x509| {
@@ -697,9 +695,8 @@ fn add_extra_chain_cert() {
#[test]
fn verify_valid_hostname() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client.ctx().set_verify(SslVerifyMode::PEER);

let mut client = client.build().builder();
@@ -714,11 +711,12 @@ fn verify_valid_hostname() {
#[test]
fn verify_invalid_hostname() {
let mut server = Server::builder();

server.should_error();

let server = server.build();
let mut client = server.client_with_root_ca();

let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client.ctx().set_verify(SslVerifyMode::PEER);

let mut client = client.build().builder();
272 changes: 272 additions & 0 deletions boring/src/ssl/test/private_key_method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
use once_cell::sync::OnceCell;

use super::server::{Builder, Server};
use super::KEY;
use crate::hash::{Hasher, MessageDigest};
use crate::pkey::PKey;
use crate::rsa::Padding;
use crate::sign::{RsaPssSaltlen, Signer};
use crate::ssl::{
ErrorCode, HandshakeError, PrivateKeyError, PrivateKeyMethod, SslRef, SslSignatureAlgorithm,
};
use crate::x509::X509;
use std::cmp;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;

pub(super) struct Method {
sign: Box<
dyn Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyError>
+ Send
+ Sync
+ 'static,
>,
decrypt: Box<
dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyError>
+ Send
+ Sync
+ 'static,
>,
complete: Box<
dyn Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyError> + Send + Sync + 'static,
>,
}

impl Method {
pub(super) fn new() -> Self {
Self {
sign: Box::new(|_, _, _, _| unreachable!("called sign")),
decrypt: Box::new(|_, _, _| unreachable!("called decrypt")),
complete: Box::new(|_, _| unreachable!("called complete")),
}
}

pub(super) fn sign(
mut self,
sign: impl Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.sign = Box::new(sign);

self
}

pub(super) fn decrypt(
mut self,
decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.decrypt = Box::new(decrypt);

self
}

pub(super) fn complete(
mut self,
complete: impl Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.complete = Box::new(complete);

self
}
}

impl PrivateKeyMethod for Method {
fn sign(
&self,
ssl: &mut SslRef,
input: &[u8],
signature_algorithm: SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, PrivateKeyError> {
(self.sign)(ssl, input, signature_algorithm, output)
}

fn decrypt(
&self,
ssl: &mut SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, PrivateKeyError> {
(self.decrypt)(ssl, input, output)
}

fn complete(&self, ssl: &mut SslRef, output: &mut [u8]) -> Result<usize, PrivateKeyError> {
(self.complete)(ssl, output)
}
}

fn builder_with_private_key_method(method: Method) -> Builder {
let mut builder = Server::builder();

builder.ctx().set_private_key_method(method);

builder
}

#[test]
fn test_sign_failure() {
let called_sign = Arc::new(AtomicBool::new(false));
let called_sign_clone = called_sign.clone();

let mut builder = builder_with_private_key_method(Method::new().sign(move |_, _, _, _| {
called_sign_clone.store(true, Ordering::SeqCst);

Err(PrivateKeyError::FAILURE)
}));

builder.err_cb(|error| {
let HandshakeError::Failure(mid_handshake) = error else {
panic!("should be Failure");
};

assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});

let server = builder.build();
let client = server.client_with_root_ca();

client.connect_err();

assert!(called_sign.load(Ordering::SeqCst));
}

#[test]
fn test_sign_retry_complete_failure() {
let called_complete = Arc::new(AtomicUsize::new(0));
let called_complete_clone = called_complete.clone();

let mut builder = builder_with_private_key_method(
Method::new()
.sign(|_, _, _, _| Err(PrivateKeyError::RETRY))
.complete(move |_, _| {
let old = called_complete_clone.fetch_add(1, Ordering::SeqCst);

Err(if old == 0 {
PrivateKeyError::RETRY
} else {
PrivateKeyError::FAILURE
})
}),
);

builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};

assert!(mid_handshake.error().would_block());
assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);

let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
panic!("should be WouldBlock");
};

assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);

let HandshakeError::Failure(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
panic!("should be Failure");
};

assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});

let server = builder.build();
let client = server.client_with_root_ca();

client.connect_err();

assert_eq!(called_complete.load(Ordering::SeqCst), 2);
}

#[test]
fn test_sign_ok() {
let server = builder_with_private_key_method(Method::new().sign(
|_, input, signature_algorithm, output| {
assert_eq!(
signature_algorithm,
SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256,
);

Ok(sign_with_default_config(input, output))
},
))
.build();

let client = server.client_with_root_ca();

client.connect();
}

#[test]
fn test_sign_retry_complete_ok() {
let input_cell = Arc::new(OnceCell::new());
let input_cell_clone = input_cell.clone();

let mut builder = builder_with_private_key_method(
Method::new()
.sign(move |_, input, _, _| {
input_cell.set(input.to_owned()).unwrap();

Err(PrivateKeyError::RETRY)
})
.complete(move |_, output| {
let input = input_cell_clone.get().unwrap();

Ok(sign_with_default_config(input, output))
}),
);

builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};

let mut socket = mid_handshake.handshake().unwrap();

socket.write_all(&[0]).unwrap();
});

let server = builder.build();
let client = server.client_with_root_ca();

client.connect();
}

fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize {
let pkey = PKey::private_key_from_pem(KEY).unwrap();
let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap();

signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap();
signer
.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)
.unwrap();

signer.update(input).unwrap();

signer.sign(output).unwrap()
}
33 changes: 27 additions & 6 deletions boring/src/ssl/test/server.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,10 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::thread::{self, JoinHandle};

use crate::ssl::{Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef, SslStream};
use crate::ssl::{
MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef,
SslStream, HandshakeError,
};

pub struct Server {
handle: Option<JoinHandle<()>>,
@@ -28,6 +31,7 @@ impl Server {
ctx,
ssl_cb: Box::new(|_| {}),
io_cb: Box::new(|_| {}),
err_cb: Box::new(|_| {}),
should_error: false,
}
}
@@ -39,6 +43,14 @@ impl Server {
}
}

pub fn client_with_root_ca(&self) -> ClientBuilder {
let mut client = self.client();

client.ctx().set_ca_file("test/root-ca.pem").unwrap();

client
}

pub fn connect_tcp(&self) -> TcpStream {
TcpStream::connect(self.addr).unwrap()
}
@@ -48,6 +60,7 @@ pub struct Builder {
ctx: SslContextBuilder,
ssl_cb: Box<dyn FnMut(&mut SslRef) + Send>,
io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>,
err_cb: Box<dyn FnMut(HandshakeError<TcpStream>) + Send>,
should_error: bool,
}

@@ -70,6 +83,12 @@ impl Builder {
self.io_cb = Box::new(cb);
}

pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError<TcpStream>) + Send + 'static) {
self.should_error();

self.err_cb = Box::new(cb);
}

pub fn should_error(&mut self) {
self.should_error = true;
}
@@ -80,6 +99,7 @@ impl Builder {
let addr = socket.local_addr().unwrap();
let mut ssl_cb = self.ssl_cb;
let mut io_cb = self.io_cb;
let mut err_cb = self.err_cb;
let should_error = self.should_error;

let handle = thread::spawn(move || {
@@ -88,7 +108,7 @@ impl Builder {
ssl_cb(&mut ssl);
let r = ssl.accept(socket);
if should_error {
r.unwrap_err();
err_cb(r.unwrap_err());
} else {
let mut socket = r.unwrap();
socket.write_all(&[0]).unwrap();
@@ -124,8 +144,8 @@ impl ClientBuilder {
self.build().builder().connect()
}

pub fn connect_err(self) {
self.build().builder().connect_err();
pub fn connect_err(self) -> HandshakeError<TcpStream> {
self.build().builder().connect_err()
}
}

@@ -160,8 +180,9 @@ impl ClientSslBuilder {
s
}

pub fn connect_err(self) {
pub fn connect_err(self) -> HandshakeError<TcpStream> {
let socket = TcpStream::connect(self.addr).unwrap();
self.ssl.connect(socket).unwrap_err();

self.ssl.setup_connect(socket).handshake().unwrap_err()
}
}

0 comments on commit 87fb27e

Please sign in to comment.