Skip to content

Commit

Permalink
Use box instead of tunnel state wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 17, 2023
1 parent eb5791d commit d668753
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 382 deletions.
132 changes: 57 additions & 75 deletions talpid-core/src/tunnel_state_machine/connected_state.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence,
EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState,
TunnelStateTransition, TunnelStateWrapper,
TunnelStateTransition,
};
use crate::{
firewall::FirewallPolicy,
Expand All @@ -27,14 +27,6 @@ use super::connecting_state::TunnelCloseEvent;
pub(crate) type TunnelEventsReceiver =
Fuse<mpsc::UnboundedReceiver<(TunnelEvent, oneshot::Sender<()>)>>;

pub struct ConnectedStateBootstrap {
pub metadata: TunnelMetadata,
pub tunnel_events: TunnelEventsReceiver,
pub tunnel_parameters: TunnelParameters,
pub tunnel_close_event: TunnelCloseEvent,
pub tunnel_close_tx: oneshot::Sender<()>,
}

/// The tunnel is up and working.
pub struct ConnectedState {
metadata: TunnelMetadata,
Expand All @@ -45,13 +37,47 @@ pub struct ConnectedState {
}

impl ConnectedState {
fn from(bootstrap: ConnectedStateBootstrap) -> Self {
ConnectedState {
metadata: bootstrap.metadata,
tunnel_events: bootstrap.tunnel_events,
tunnel_parameters: bootstrap.tunnel_parameters,
tunnel_close_event: bootstrap.tunnel_close_event,
tunnel_close_tx: bootstrap.tunnel_close_tx,
#[cfg_attr(target_os = "android", allow(unused_variables))]
pub(super) fn enter(
shared_values: &mut SharedTunnelStateValues,
metadata: TunnelMetadata,
tunnel_events: TunnelEventsReceiver,
tunnel_parameters: TunnelParameters,
tunnel_close_event: TunnelCloseEvent,
tunnel_close_tx: oneshot::Sender<()>,
) -> (Box<dyn TunnelState>, TunnelStateTransition) {
let connected_state = ConnectedState {
metadata,
tunnel_events,
tunnel_parameters,
tunnel_close_event,
tunnel_close_tx,
};

let tunnel_interface = Some(connected_state.metadata.interface.clone());
let tunnel_endpoint = talpid_types::net::TunnelEndpoint {
tunnel_interface,
..connected_state.tunnel_parameters.get_tunnel_endpoint()
};

if let Err(error) = connected_state.set_firewall_policy(shared_values) {
DisconnectingState::enter(
connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
)
} else if let Err(error) = connected_state.set_dns(shared_values) {
log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
DisconnectingState::enter(
connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetDnsError),
)
} else {
(
Box::new(connected_state),
TunnelStateTransition::Connected(tunnel_endpoint),
)
}
}

Expand Down Expand Up @@ -173,17 +199,14 @@ impl ConnectedState {
Self::reset_routes(shared_values);

EventConsequence::NewState(DisconnectingState::enter(
shared_values,
(
self.tunnel_close_tx,
self.tunnel_close_event,
after_disconnect,
),
self.tunnel_close_tx,
self.tunnel_close_event,
after_disconnect,
))
}

fn handle_commands(
self,
self: Box<Self>,
command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
Expand All @@ -199,7 +222,7 @@ impl ConnectedState {
if cfg!(target_os = "android") {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
SameState(self.into())
SameState(self)
}
}
Err(error) => self.disconnect(
Expand All @@ -212,7 +235,7 @@ impl ConnectedState {
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
SameState(self.into())
SameState(self)
}
Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
Ok(true) => {
Expand All @@ -227,7 +250,7 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
#[cfg(not(target_os = "android"))]
Ok(()) => SameState(self.into()),
Ok(()) => SameState(self),
Err(error) => {
log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
self.disconnect(
Expand All @@ -237,14 +260,14 @@ impl ConnectedState {
}
}
}
Ok(false) => SameState(self.into()),
Ok(false) => SameState(self),
Err(error_cause) => {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
}
},
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
SameState(self.into())
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
Expand All @@ -254,7 +277,7 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::IsOffline),
)
} else {
SameState(self.into())
SameState(self)
}
}
Some(TunnelCommand::Connect) => {
Expand All @@ -269,18 +292,18 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
SameState(self)
}
#[cfg(windows)]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
shared_values.split_tunnel.set_paths(&paths, result_tx);
SameState(self.into())
SameState(self)
}
}
}

fn handle_tunnel_events(
self,
self: Box<Self>,
event: Option<(TunnelEvent, oneshot::Sender<()>)>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
Expand All @@ -290,7 +313,7 @@ impl ConnectedState {
Some((TunnelEvent::Down, _)) | None => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
Some(_) => SameState(self.into()),
Some(_) => SameState(self),
}
}

Expand All @@ -315,49 +338,8 @@ impl ConnectedState {
}

impl TunnelState for ConnectedState {
type Bootstrap = ConnectedStateBootstrap;

#[cfg_attr(target_os = "android", allow(unused_variables))]
fn enter(
shared_values: &mut SharedTunnelStateValues,
bootstrap: Self::Bootstrap,
) -> (TunnelStateWrapper, TunnelStateTransition) {
let connected_state = ConnectedState::from(bootstrap);
let tunnel_interface = Some(connected_state.metadata.interface.clone());
let tunnel_endpoint = talpid_types::net::TunnelEndpoint {
tunnel_interface,
..connected_state.tunnel_parameters.get_tunnel_endpoint()
};

if let Err(error) = connected_state.set_firewall_policy(shared_values) {
DisconnectingState::enter(
shared_values,
(
connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
),
)
} else if let Err(error) = connected_state.set_dns(shared_values) {
log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
DisconnectingState::enter(
shared_values,
(
connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetDnsError),
),
)
} else {
(
TunnelStateWrapper::from(connected_state),
TunnelStateTransition::Connected(tunnel_endpoint),
)
}
}

fn handle_event(
mut self,
mut self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
Expand Down
Loading

0 comments on commit d668753

Please sign in to comment.