Skip to content

Commit

Permalink
Add DTLS support
Browse files Browse the repository at this point in the history
  • Loading branch information
lulf committed Oct 10, 2022
1 parent ef4212d commit 2fd2189
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/packet_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{BytesCodec, Framed};

pub trait PacketStream:
futures::Sink<Bytes, Error = StdError> + futures::Stream<Item = Result<BytesMut>> + Unpin
futures::Sink<Bytes, Error = StdError> + futures::Stream<Item = Result<BytesMut>> + Unpin + Send
{
}

pub(crate) struct FramedPacketStream<T>(pub(crate) Framed<T, BytesCodec>)
where
T: AsyncRead + AsyncWrite + Unpin;

impl<T> PacketStream for FramedPacketStream<T> where T: AsyncRead + AsyncWrite + Unpin {}
impl<T> PacketStream for FramedPacketStream<T> where T: AsyncRead + AsyncWrite + Unpin + Send {}

impl<T> futures::Sink<Bytes> for FramedPacketStream<T>
where
Expand Down
23 changes: 19 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use super::packet_stream::*;
use crate::udp::io::UdpIo;
use crate::udp::stream::UdpStream;

use std::io::Result;
use core::pin::Pin;
use openssl::ssl::{Ssl, SslContext};
use std::io::{Error as StdError, ErrorKind, Result};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tokio_openssl::SslStream;
use tokio_util::codec::{BytesCodec, Decoder};

pub struct Server {
stop: Option<oneshot::Sender<()>>,
Expand All @@ -23,9 +27,20 @@ impl Server {
}
}

pub async fn accept(&mut self) -> Result<UdpStream> {
pub async fn accept(&mut self, tls: Option<SslContext>) -> Result<Box<dyn PacketStream>> {
match self.accept_rx.recv().await {
Some(s) => s,
Some(s) => {
let s = s?;
if let Some(ctx) = &tls {
let mut dtls = SslStream::new(Ssl::new(&ctx)?, s)?;
Pin::new(&mut dtls).accept().await.map_err(|_| {
StdError::new(ErrorKind::ConnectionReset, "Error during TLS handshake")
})?;
Ok(Box::new(FramedPacketStream(BytesCodec::new().framed(dtls))))
} else {
Ok(Box::new(FramedPacketStream(BytesCodec::new().framed(s))))
}
}
None => Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"Acceptor closed",
Expand Down
25 changes: 14 additions & 11 deletions src/udp/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,6 @@ impl UdpIo {
let mut rx_done = self.rx_done.lock().await;
loop {
tokio::select! {
done = rx_done.recv() => {
if let Some(peer) = done {
let mut peers = self.peers.lock().await;
let _ = peers.remove(&peer);
}
}
_ = &mut stop => {
return Ok(());
}
inbound = self.socket.recv_from(&mut buf) => {
match inbound {
Ok((size, src)) => {
Expand All @@ -82,7 +73,8 @@ impl UdpIo {
if let Some(acceptor) = &mut acceptor {
let (tx_in, rx_in) = mpsc::channel(10);
let tx_out = self.tx_out.clone();
let udp = UdpStream::new(src, tx_out, rx_in);
let tx_done = self.tx_done.clone();
let udp = UdpStream::new(src, tx_out, rx_in, tx_done);
let r = if udp.is_ok() {
Some(v.insert(tx_in))
} else {
Expand All @@ -109,7 +101,8 @@ impl UdpIo {
}
}
}
outbound = rx_out.recv() => match outbound {
outbound = rx_out.recv() => {
match outbound {
Some((dest, data)) => {
match self.socket.send_to(&data[..], &dest).await {
Ok(_) => {}
Expand All @@ -122,6 +115,16 @@ impl UdpIo {
None => {
return Ok(())
}
}
}
done = rx_done.recv() => {
if let Some(peer) = done {
let mut peers = self.peers.lock().await;
let _ = peers.remove(&peer);
}
}
_ = &mut stop => {
return Ok(());
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/udp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytes::Bytes;
use futures::SinkExt;

use core::pin::Pin;
use core::task::{Context, Poll};
Expand Down Expand Up @@ -98,10 +99,12 @@ impl AsyncWrite for UdpStream {
}

fn poll_flush(
self: Pin<&mut Self>,
_: &mut std::task::Context<'_>,
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Poll::Ready(Ok(()))
self.tx
.poll_flush_unpin(cx)
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "Flush error"))
}

fn poll_shutdown(
Expand Down
12 changes: 12 additions & 0 deletions tests/certs/ca-cert.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIB2jCCAYGgAwIBAgIUBOC0qHOiuv2QsAge8U8RtafIVsswCgYIKoZIzj0EAwIw
QjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UECgwT
RGVmYXVsdCBDb21wYW55IEx0ZDAgFw0yMjEwMTAxMzM5MjhaGA8yMDUyMTEyMTEz
MzkyOFowQjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoG
A1UECgwTRGVmYXVsdCBDb21wYW55IEx0ZDBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABDXPSIiPvhzCp3+sA9XGmp3kBooiaugAdlfDswTLz/6o4lj1Sn4tpS3wcgIv
zigt4T3ZPjpiyO5zT+P0hSPMLAejUzBRMB0GA1UdDgQWBBS4SYx0zrP3yAA7wwXd
iPoV6fPZTjAfBgNVHSMEGDAWgBS4SYx0zrP3yAA7wwXdiPoV6fPZTjAPBgNVHRMB
Af8EBTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIHy24tB8pvAiUilxbJ6TPqsrECb4
vED7+27k8XohHBokAiBxHKFaANnfRQbVam924piktNXNbuhNEXsMmZFIS3bEPw==
-----END CERTIFICATE-----
5 changes: 5 additions & 0 deletions tests/certs/ca-key.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgOMI3UIBcOKLV7mDX
vU2uhtPfid80iX1g6DyoZ7ekIMmhRANCAAQ1z0iIj74cwqd/rAPVxpqd5AaKImro
AHZXw7MEy8/+qOJY9Up+LaUt8HICL84oLeE92T46Ysjuc0/j9IUjzCwH
-----END PRIVATE KEY-----
14 changes: 14 additions & 0 deletions tests/certs/server-cert.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICKDCCAc6gAwIBAgIUaptPaaO7FrO+ER4qLqOns4SiCoswCgYIKoZIzj0EAwIw
QjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UECgwT
RGVmYXVsdCBDb21wYW55IEx0ZDAgFw0yMjEwMTAxMzM5MjlaGA8yMDUyMTEyMTEz
MzkyOVowcjELMAkGA1UEBhMCTk8xDjAMBgNVBAgMBUhhbWFyMQ4wDAYDVQQHDAVI
YW1hcjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRUwEwYDVQQLDAxIb2xzZXRi
YWtrZW4xEjAQBgNVBAMMCTEyNy4wLjAuMTBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABMrludMdhxnYWA+hMGVBZBkctvDQtJNr5/f+nu+5R4hSN55aRygzLQIOe3cj
rmnyS5D72m+Y31jKC9P8FPy3/8yjcDBuMB8GA1UdIwQYMBaAFLhJjHTOs/fIADvD
Bd2I+hXp89lOMAkGA1UdEwQCMAAwCwYDVR0PBAQDAgTwMBQGA1UdEQQNMAuCCWxv
Y2FsaG9zdDAdBgNVHQ4EFgQUhVW31o5frrZoYqV7xZqEnNiYKe4wCgYIKoZIzj0E
AwIDSAAwRQIgCkr4VgZ9TWvxLzUuTnzcjZ14FKESp8e5lkgbMwAc1hoCIQCXt+kg
35L2/0F3h+kDKKT3drkR5huYHnx++ds9RKF2tg==
-----END CERTIFICATE-----
5 changes: 5 additions & 0 deletions tests/certs/server-key.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg+N1TvmZPYJn+Zr/H
MnAA+Tj9E3d80dfBkMi0771MO5ChRANCAATK5bnTHYcZ2FgPoTBlQWQZHLbw0LST
a+f3/p7vuUeIUjeeWkcoMy0CDnt3I65p8kuQ+9pvmN9YygvT/BT8t//M
-----END PRIVATE KEY-----
95 changes: 84 additions & 11 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use futures::SinkExt;
use futures::StreamExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use openssl::ssl::{Ssl, SslContext, SslFiletype, SslMethod};
use std::path::{Path, PathBuf};
use tokio::net::UdpSocket;
use tokio::sync::oneshot;

use tokio_dtls_stream::Client;
use tokio_dtls_stream::Server;

Expand Down Expand Up @@ -41,26 +44,96 @@ async fn test_plain_server() {
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let saddr = server.local_addr().unwrap();

let (sig, mut stop) = oneshot::channel();
let mut server = Server::new(server);
let s = tokio::spawn(async move {
let mut c = server.accept().await.unwrap();
let mut buf = [0; 2048];
match c.read(&mut buf).await {
Ok(len) => {
assert!(c.write(&buf[..len]).await.is_ok());
}
Err(e) => {
assert!(false, "Error while receiving data: {:?}", e);
let mut c = server.accept(None).await.unwrap();
loop {
tokio::select! {
_ = &mut stop => {
break;
}
r = c.next() => match r {
Some(Ok(rx)) => {
assert!(c.send(rx.into()).await.is_ok());
}
Some(Err(e)) => {
assert!(false, "Error while receiving data: {:?}", e);
}
_ => {
assert!(false, "Stream closed unexpectedly");
}
}
}
}
});

let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let mut rx = [0; 4];
client.send_to(b"PING", saddr).await.unwrap();
let (len, from) = client.recv_from(&mut rx[..]).await.unwrap();
let tx_fut = client.send_to(b"PING", saddr);
let rx_fut = client.recv_from(&mut rx[..]);

let (rxr, _) = tokio::join!(rx_fut, tx_fut);

let (len, from) = rxr.unwrap();

assert_eq!(4, len);
assert_eq!(saddr, from);
sig.send(()).unwrap();
s.await.unwrap();

assert_eq!(b"PING", &rx[..]);
}

#[tokio::test]
async fn test_dtls() {
let base = env!("CARGO_MANIFEST_DIR");
let key: PathBuf = [base, "tests", "certs", "server-key.pem"].iter().collect();
let cert: PathBuf = [base, "tests", "certs", "server-cert.pem"].iter().collect();
let ca: PathBuf = [base, "tests", "certs", "ca-cert.pem"].iter().collect();

let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let saddr = server.local_addr().unwrap();

let mut ctx = SslContext::builder(SslMethod::dtls()).unwrap();
ctx.set_private_key_file(key, SslFiletype::PEM).unwrap();
ctx.set_certificate_chain_file(cert).unwrap();
ctx.set_ca_file(ca).unwrap();
ctx.check_private_key().unwrap();
let ctx = ctx.build();

let client = Client::new(client);
let mut server = Server::new(server);
let (sig, mut stop) = oneshot::channel();

let c = ctx.clone();
let s = tokio::spawn(async move {
let mut c = server.accept(Some(c)).await.unwrap();
loop {
tokio::select! {
_ = &mut stop => {
break;
}
r = c.next() => match r {
Some(Ok(rx)) => {
assert!(c.send(rx.into()).await.is_ok());
}
Some(Err(e)) => {
assert!(false, "Error while receiving data: {:?}", e);
}
_ => {
assert!(false, "Stream closed unexpectedly");
}
}
}
}
});

let mut stream = client.connect(saddr, Some(ctx)).await.unwrap();
stream.send("PING".into()).await.unwrap();
let rx = stream.next().await.unwrap().unwrap();
sig.send(()).unwrap();
s.await.unwrap();

assert_eq!(b"PING", &rx[..]);
Expand Down

0 comments on commit 2fd2189

Please sign in to comment.