From 2fd21891ab48316c73bb3f421d4204c5b8772b65 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Mon, 10 Oct 2022 15:46:46 +0200 Subject: [PATCH] Add DTLS support --- src/packet_stream.rs | 4 +- src/server.rs | 23 +++++++-- src/udp/io.rs | 25 +++++----- src/udp/stream.rs | 9 ++-- tests/certs/ca-cert.pem | 12 +++++ tests/certs/ca-key.pem | 5 ++ tests/certs/server-cert.pem | 14 ++++++ tests/certs/server-key.pem | 5 ++ tests/integration_test.rs | 95 ++++++++++++++++++++++++++++++++----- 9 files changed, 161 insertions(+), 31 deletions(-) create mode 100644 tests/certs/ca-cert.pem create mode 100644 tests/certs/ca-key.pem create mode 100644 tests/certs/server-cert.pem create mode 100644 tests/certs/server-key.pem diff --git a/src/packet_stream.rs b/src/packet_stream.rs index 44197b1..d069872 100644 --- a/src/packet_stream.rs +++ b/src/packet_stream.rs @@ -6,7 +6,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{BytesCodec, Framed}; pub trait PacketStream: - futures::Sink + futures::Stream> + Unpin + futures::Sink + futures::Stream> + Unpin + Send { } @@ -14,7 +14,7 @@ pub(crate) struct FramedPacketStream(pub(crate) Framed) where T: AsyncRead + AsyncWrite + Unpin; -impl PacketStream for FramedPacketStream where T: AsyncRead + AsyncWrite + Unpin {} +impl PacketStream for FramedPacketStream where T: AsyncRead + AsyncWrite + Unpin + Send {} impl futures::Sink for FramedPacketStream where diff --git a/src/server.rs b/src/server.rs index 44409ed..09cc42c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,9 +1,13 @@ +use super::packet_stream::*; use crate::udp::io::UdpIo; use crate::udp::stream::UdpStream; - -use std::io::Result; +use core::pin::Pin; +use openssl::ssl::{Ssl, SslContext}; +use std::io::{Error as StdError, ErrorKind, Result}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot}; +use tokio_openssl::SslStream; +use tokio_util::codec::{BytesCodec, Decoder}; pub struct Server { stop: Option>, @@ -23,9 +27,20 @@ impl Server { } } - pub async fn accept(&mut self) -> Result { + pub async fn accept(&mut self, tls: Option) -> Result> { match self.accept_rx.recv().await { - Some(s) => s, + Some(s) => { + let s = s?; + if let Some(ctx) = &tls { + let mut dtls = SslStream::new(Ssl::new(&ctx)?, s)?; + Pin::new(&mut dtls).accept().await.map_err(|_| { + StdError::new(ErrorKind::ConnectionReset, "Error during TLS handshake") + })?; + Ok(Box::new(FramedPacketStream(BytesCodec::new().framed(dtls)))) + } else { + Ok(Box::new(FramedPacketStream(BytesCodec::new().framed(s)))) + } + } None => Err(std::io::Error::new( std::io::ErrorKind::ConnectionReset, "Acceptor closed", diff --git a/src/udp/io.rs b/src/udp/io.rs index 467522e..e145ecf 100644 --- a/src/udp/io.rs +++ b/src/udp/io.rs @@ -61,15 +61,6 @@ impl UdpIo { let mut rx_done = self.rx_done.lock().await; loop { tokio::select! { - done = rx_done.recv() => { - if let Some(peer) = done { - let mut peers = self.peers.lock().await; - let _ = peers.remove(&peer); - } - } - _ = &mut stop => { - return Ok(()); - } inbound = self.socket.recv_from(&mut buf) => { match inbound { Ok((size, src)) => { @@ -82,7 +73,8 @@ impl UdpIo { if let Some(acceptor) = &mut acceptor { let (tx_in, rx_in) = mpsc::channel(10); let tx_out = self.tx_out.clone(); - let udp = UdpStream::new(src, tx_out, rx_in); + let tx_done = self.tx_done.clone(); + let udp = UdpStream::new(src, tx_out, rx_in, tx_done); let r = if udp.is_ok() { Some(v.insert(tx_in)) } else { @@ -109,7 +101,8 @@ impl UdpIo { } } } - outbound = rx_out.recv() => match outbound { + outbound = rx_out.recv() => { + match outbound { Some((dest, data)) => { match self.socket.send_to(&data[..], &dest).await { Ok(_) => {} @@ -122,6 +115,16 @@ impl UdpIo { None => { return Ok(()) } + } + } + done = rx_done.recv() => { + if let Some(peer) = done { + let mut peers = self.peers.lock().await; + let _ = peers.remove(&peer); + } + } + _ = &mut stop => { + return Ok(()); } } } diff --git a/src/udp/stream.rs b/src/udp/stream.rs index 20fcc5e..9b0e36f 100644 --- a/src/udp/stream.rs +++ b/src/udp/stream.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use futures::SinkExt; use core::pin::Pin; use core::task::{Context, Poll}; @@ -98,10 +99,12 @@ impl AsyncWrite for UdpStream { } fn poll_flush( - self: Pin<&mut Self>, - _: &mut std::task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, ) -> Poll> { - Poll::Ready(Ok(())) + self.tx + .poll_flush_unpin(cx) + .map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "Flush error")) } fn poll_shutdown( diff --git a/tests/certs/ca-cert.pem b/tests/certs/ca-cert.pem new file mode 100644 index 0000000..a5f1e57 --- /dev/null +++ b/tests/certs/ca-cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIB2jCCAYGgAwIBAgIUBOC0qHOiuv2QsAge8U8RtafIVsswCgYIKoZIzj0EAwIw +QjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UECgwT +RGVmYXVsdCBDb21wYW55IEx0ZDAgFw0yMjEwMTAxMzM5MjhaGA8yMDUyMTEyMTEz +MzkyOFowQjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoG +A1UECgwTRGVmYXVsdCBDb21wYW55IEx0ZDBZMBMGByqGSM49AgEGCCqGSM49AwEH +A0IABDXPSIiPvhzCp3+sA9XGmp3kBooiaugAdlfDswTLz/6o4lj1Sn4tpS3wcgIv +zigt4T3ZPjpiyO5zT+P0hSPMLAejUzBRMB0GA1UdDgQWBBS4SYx0zrP3yAA7wwXd +iPoV6fPZTjAfBgNVHSMEGDAWgBS4SYx0zrP3yAA7wwXdiPoV6fPZTjAPBgNVHRMB +Af8EBTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIHy24tB8pvAiUilxbJ6TPqsrECb4 +vED7+27k8XohHBokAiBxHKFaANnfRQbVam924piktNXNbuhNEXsMmZFIS3bEPw== +-----END CERTIFICATE----- diff --git a/tests/certs/ca-key.pem b/tests/certs/ca-key.pem new file mode 100644 index 0000000..03e3e45 --- /dev/null +++ b/tests/certs/ca-key.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgOMI3UIBcOKLV7mDX +vU2uhtPfid80iX1g6DyoZ7ekIMmhRANCAAQ1z0iIj74cwqd/rAPVxpqd5AaKImro +AHZXw7MEy8/+qOJY9Up+LaUt8HICL84oLeE92T46Ysjuc0/j9IUjzCwH +-----END PRIVATE KEY----- diff --git a/tests/certs/server-cert.pem b/tests/certs/server-cert.pem new file mode 100644 index 0000000..4e75082 --- /dev/null +++ b/tests/certs/server-cert.pem @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICKDCCAc6gAwIBAgIUaptPaaO7FrO+ER4qLqOns4SiCoswCgYIKoZIzj0EAwIw +QjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UECgwT +RGVmYXVsdCBDb21wYW55IEx0ZDAgFw0yMjEwMTAxMzM5MjlaGA8yMDUyMTEyMTEz +MzkyOVowcjELMAkGA1UEBhMCTk8xDjAMBgNVBAgMBUhhbWFyMQ4wDAYDVQQHDAVI +YW1hcjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRUwEwYDVQQLDAxIb2xzZXRi +YWtrZW4xEjAQBgNVBAMMCTEyNy4wLjAuMTBZMBMGByqGSM49AgEGCCqGSM49AwEH +A0IABMrludMdhxnYWA+hMGVBZBkctvDQtJNr5/f+nu+5R4hSN55aRygzLQIOe3cj +rmnyS5D72m+Y31jKC9P8FPy3/8yjcDBuMB8GA1UdIwQYMBaAFLhJjHTOs/fIADvD +Bd2I+hXp89lOMAkGA1UdEwQCMAAwCwYDVR0PBAQDAgTwMBQGA1UdEQQNMAuCCWxv +Y2FsaG9zdDAdBgNVHQ4EFgQUhVW31o5frrZoYqV7xZqEnNiYKe4wCgYIKoZIzj0E +AwIDSAAwRQIgCkr4VgZ9TWvxLzUuTnzcjZ14FKESp8e5lkgbMwAc1hoCIQCXt+kg +35L2/0F3h+kDKKT3drkR5huYHnx++ds9RKF2tg== +-----END CERTIFICATE----- diff --git a/tests/certs/server-key.pem b/tests/certs/server-key.pem new file mode 100644 index 0000000..65cf620 --- /dev/null +++ b/tests/certs/server-key.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg+N1TvmZPYJn+Zr/H +MnAA+Tj9E3d80dfBkMi0771MO5ChRANCAATK5bnTHYcZ2FgPoTBlQWQZHLbw0LST +a+f3/p7vuUeIUjeeWkcoMy0CDnt3I65p8kuQ+9pvmN9YygvT/BT8t//M +-----END PRIVATE KEY----- diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 129bc81..480f6dc 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,7 +1,10 @@ use futures::SinkExt; use futures::StreamExt; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use openssl::ssl::{Ssl, SslContext, SslFiletype, SslMethod}; +use std::path::{Path, PathBuf}; use tokio::net::UdpSocket; +use tokio::sync::oneshot; + use tokio_dtls_stream::Client; use tokio_dtls_stream::Server; @@ -41,26 +44,96 @@ async fn test_plain_server() { let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); let saddr = server.local_addr().unwrap(); + let (sig, mut stop) = oneshot::channel(); let mut server = Server::new(server); let s = tokio::spawn(async move { - let mut c = server.accept().await.unwrap(); - let mut buf = [0; 2048]; - match c.read(&mut buf).await { - Ok(len) => { - assert!(c.write(&buf[..len]).await.is_ok()); - } - Err(e) => { - assert!(false, "Error while receiving data: {:?}", e); + let mut c = server.accept(None).await.unwrap(); + loop { + tokio::select! { + _ = &mut stop => { + break; + } + r = c.next() => match r { + Some(Ok(rx)) => { + assert!(c.send(rx.into()).await.is_ok()); + } + Some(Err(e)) => { + assert!(false, "Error while receiving data: {:?}", e); + } + _ => { + assert!(false, "Stream closed unexpectedly"); + } + } } } }); let client = UdpSocket::bind("127.0.0.1:0").await.unwrap(); let mut rx = [0; 4]; - client.send_to(b"PING", saddr).await.unwrap(); - let (len, from) = client.recv_from(&mut rx[..]).await.unwrap(); + let tx_fut = client.send_to(b"PING", saddr); + let rx_fut = client.recv_from(&mut rx[..]); + + let (rxr, _) = tokio::join!(rx_fut, tx_fut); + + let (len, from) = rxr.unwrap(); + assert_eq!(4, len); assert_eq!(saddr, from); + sig.send(()).unwrap(); + s.await.unwrap(); + + assert_eq!(b"PING", &rx[..]); +} + +#[tokio::test] +async fn test_dtls() { + let base = env!("CARGO_MANIFEST_DIR"); + let key: PathBuf = [base, "tests", "certs", "server-key.pem"].iter().collect(); + let cert: PathBuf = [base, "tests", "certs", "server-cert.pem"].iter().collect(); + let ca: PathBuf = [base, "tests", "certs", "ca-cert.pem"].iter().collect(); + + let client = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let saddr = server.local_addr().unwrap(); + + let mut ctx = SslContext::builder(SslMethod::dtls()).unwrap(); + ctx.set_private_key_file(key, SslFiletype::PEM).unwrap(); + ctx.set_certificate_chain_file(cert).unwrap(); + ctx.set_ca_file(ca).unwrap(); + ctx.check_private_key().unwrap(); + let ctx = ctx.build(); + + let client = Client::new(client); + let mut server = Server::new(server); + let (sig, mut stop) = oneshot::channel(); + + let c = ctx.clone(); + let s = tokio::spawn(async move { + let mut c = server.accept(Some(c)).await.unwrap(); + loop { + tokio::select! { + _ = &mut stop => { + break; + } + r = c.next() => match r { + Some(Ok(rx)) => { + assert!(c.send(rx.into()).await.is_ok()); + } + Some(Err(e)) => { + assert!(false, "Error while receiving data: {:?}", e); + } + _ => { + assert!(false, "Stream closed unexpectedly"); + } + } + } + } + }); + + let mut stream = client.connect(saddr, Some(ctx)).await.unwrap(); + stream.send("PING".into()).await.unwrap(); + let rx = stream.next().await.unwrap().unwrap(); + sig.send(()).unwrap(); s.await.unwrap(); assert_eq!(b"PING", &rx[..]);