From 52b114aaa054fed970db7dcd2c86f333789dc529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Wed, 14 Dec 2022 23:13:17 +0100 Subject: [PATCH] Add burst guard to IP registration ioctl --- talpid-core/src/split_tunnel/windows/mod.rs | 78 ++++++++++++++++++--- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index cdbe60fc600d..8d5ae1c7f673 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -5,7 +5,10 @@ mod volume_monitor; mod windows; use crate::{tunnel::TunnelMetadata, tunnel_state_machine::TunnelCommand}; -use futures::channel::{mpsc, oneshot}; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; use std::{ collections::HashMap, convert::TryFrom, @@ -764,18 +767,27 @@ impl Drop for SplitTunnel { } struct SplitTunnelDefaultRouteChangeHandlerContext { - request_tx: RequestTx, + tx: mpsc::UnboundedSender, + abort_handle: tokio::task::JoinHandle<()>, pub addresses: InterfaceAddresses, } +impl Drop for SplitTunnelDefaultRouteChangeHandlerContext { + fn drop(&mut self) { + self.abort_handle.abort(); + } +} + impl SplitTunnelDefaultRouteChangeHandlerContext { pub fn new( request_tx: RequestTx, tunnel_ipv4: Option, tunnel_ipv6: Option, ) -> Self { + let (tx, abort_handle) = Self::create_burst_guard(request_tx); SplitTunnelDefaultRouteChangeHandlerContext { - request_tx, + tx, + abort_handle, addresses: InterfaceAddresses { tunnel_ipv4, tunnel_ipv6, @@ -785,6 +797,60 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { } } + fn create_burst_guard( + request_tx: RequestTx, + ) -> ( + mpsc::UnboundedSender, + tokio::task::JoinHandle<()>, + ) { + let (tx, mut rx) = mpsc::unbounded(); + + let send_request = move |addresses| { + if request_tx + .send((Request::RegisterIps(addresses), None)) + .is_err() + { + log::error!("Split tunnel request thread is down"); + } + }; + + let abort_handle = tokio::spawn(async move { + const GRACE_PERIOD: Duration = Duration::from_secs(5); + const MAX_PERIOD: Duration = Duration::from_secs(10); + + while let Some(mut addresses) = rx.next().await { + let initial_time = tokio::time::Instant::now(); + loop { + if initial_time.elapsed() >= MAX_PERIOD { + send_request(addresses); + break; + } + + let next = rx.next(); + let delay = tokio::time::sleep(GRACE_PERIOD); + futures::pin_mut!(delay); + + match futures::future::select(next, delay).await { + futures::future::Either::Left((Some(new_addresses), _)) => { + // TODO: combine? + addresses = new_addresses; + continue; + } + futures::future::Either::Left((None, _)) => { + // Return from function + return; + } + futures::future::Either::Right((..)) => { + send_request(addresses); + break; + } + } + } + } + }); + (tx, abort_handle) + } + pub fn initialize_internet_addresses(&mut self) -> Result<(), Error> { // Identify IP address that gives us Internet access let internet_ipv4 = get_best_default_route(AddressFamily::Ipv4) @@ -872,11 +938,7 @@ fn split_tunnel_default_route_change_handler( return; } - if ctx - .request_tx - .send((Request::RegisterIps(ctx.addresses.clone()), None)) - .is_err() - { + if ctx.tx.unbounded_send(ctx.addresses.clone()).is_err() { log::error!("Split tunnel request thread is down"); } }