Skip to content

Commit

Permalink
Ensure atomicity between (re)connection attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrej Mihajlov authored and buggmagnet committed Oct 13, 2023
1 parent 127bd50 commit 40f7b0d
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 33 deletions.
22 changes: 4 additions & 18 deletions ios/PacketTunnelCore/Actor/Actor+ConnectionMonitoring.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,14 @@ extension PacketTunnelActor {
}
}

/// Increment connection attempt counter and reconnect the tunnel.
/// Tell the tunnel to reconnect providing the correct reason to ensure that the attempt counter is incremented before reconnect.
private func onHandleConnectionRecovery() async {
switch state {
case var .connecting(connState):
connState.incrementAttemptCount()
state = .connecting(connState)

case var .reconnecting(connState):
connState.incrementAttemptCount()
state = .reconnecting(connState)

case var .connected(connState):
connState.incrementAttemptCount()
state = .connected(connState)
case .connecting, .reconnecting, .connected:
commandChannel.send(.reconnect(.random, reason: .connectionLoss))

case .initial, .disconnected, .disconnecting, .error:
// Explicit return to prevent reconnecting the tunnel.
return
break
}

// Tunnel monitor should already be paused at this point so don't stop it to avoid a reset of its internal
// counters.
commandChannel.send(.reconnect(.random, stopTunnelMonitor: false))
}
}
46 changes: 35 additions & 11 deletions ios/PacketTunnelCore/Actor/Actor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ public actor PacketTunnelActor {
case .stop:
await stop()

case let .reconnect(nextRelay, stopTunnelMonitor):
await reconnect(to: nextRelay, shouldStopTunnelMonitor: stopTunnelMonitor)
case let .reconnect(nextRelay, reason):
await reconnect(to: nextRelay, reason: reason)

case let .error(reason):
await setErrorStateInternal(with: reason)
Expand Down Expand Up @@ -173,16 +173,22 @@ extension PacketTunnelActor {

- Parameters:
- nextRelay: next relay to connect to
- shouldStopTunnelMonitor: whether tunnel monitor should be stopped
- reason: reason for reconnect
*/
private func reconnect(to nextRelay: NextRelay, shouldStopTunnelMonitor: Bool) async {
private func reconnect(to nextRelay: NextRelay, reason: ReconnectReason) async {
do {
switch state {
case .connecting, .connected, .reconnecting, .error:
if shouldStopTunnelMonitor {
switch reason {
case .connectionLoss:
// Tunnel monitor is already paused at this point. Avoid calling stop() to prevent the reset of
// internal state
break
case .userInitiated:
tunnelMonitor.stop()
}
try await tryStart(nextRelay: nextRelay)

try await tryStart(nextRelay: nextRelay, reason: reason)

case .disconnected, .disconnecting, .initial:
break
Expand All @@ -205,12 +211,14 @@ extension PacketTunnelActor {
- Start tunnel monitor.
- Reactivate default path observation (disabled when configuring tunnel adapter)

- Parameter nextRelay: which relay should be selected next.
- Parameters:
- nextRelay: which relay should be selected next.
- reason: reason for reconnect
*/
private func tryStart(nextRelay: NextRelay = .random) async throws {
private func tryStart(nextRelay: NextRelay = .random, reason: ReconnectReason = .userInitiated) async throws {
let settings: Settings = try settingsReader.read()

guard let connectionState = try makeConnectionState(nextRelay: nextRelay, settings: settings),
guard let connectionState = try makeConnectionState(nextRelay: nextRelay, settings: settings, reason: reason),
let targetState = state.targetStateForReconnect else { return }

let activeKey: PrivateKey
Expand Down Expand Up @@ -261,10 +269,15 @@ extension PacketTunnelActor {
- Parameters:
- nextRelay: relay preference that should be used when selecting next relay.
- settings: current settings
- reason: reason for reconnect

- Returns: New connection state or `nil` if current state is at or past `.disconnecting` phase.
*/
private func makeConnectionState(nextRelay: NextRelay, settings: Settings) throws -> ConnectionState? {
private func makeConnectionState(
nextRelay: NextRelay,
settings: Settings,
reason: ReconnectReason
) throws -> ConnectionState? {
let relayConstraints = settings.relayConstraints
let privateKey = settings.privateKey

Expand All @@ -284,7 +297,18 @@ extension PacketTunnelActor {
connectionAttemptCount: 0
)

case var .connecting(connState), var .connected(connState), var .reconnecting(connState):
case var .connecting(connState), var .reconnecting(connState):
switch reason {
case .connectionLoss:
// Increment attempt counter when reconnection is requested due to connectivity loss.
connState.incrementAttemptCount()
case .userInitiated:
break
}
// Explicit fallthrough
fallthrough

case var .connected(connState):
connState.selectedRelay = try selectRelay(
nextRelay: nextRelay,
relayConstraints: relayConstraints,
Expand Down
4 changes: 1 addition & 3 deletions ios/PacketTunnelCore/Actor/Command.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ enum Command {
case stop

/// Reconnect tunnel.
/// `stopTunnelMonitor = false` is only used when tunnel monitor is paused in response to connectivity loss and shouldn't be stopped explicitly,
/// as this would reset its internal counters.
case reconnect(NextRelay, stopTunnelMonitor: Bool = true)
case reconnect(NextRelay, reason: ReconnectReason = .userInitiated)

/// Enter blocked state.
case error(BlockedStateReason)
Expand Down
10 changes: 10 additions & 0 deletions ios/PacketTunnelCore/Actor/State.swift
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,13 @@ public enum NextRelay: Equatable, Codable {
/// Use pre-selected relay.
case preSelected(SelectedRelay)
}

/// Describes the reason for reconnection request.
public enum ReconnectReason {
/// Initiated by user.
case userInitiated

/// Initiated by tunnel monitor due to loss of connectivity.
/// Actor will increment the connection attempt counter before picking next relay.
case connectionLoss
}
94 changes: 94 additions & 0 deletions ios/PacketTunnelCoreTests/ActorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,98 @@ final class ActorTests: XCTestCase {

await fulfillment(of: allExpectations, timeout: 1, enforceOrder: true)
}

/**
Each subsequent connection attempt should produce a single change to `state` containing the incremented attempt counter and new relay.

.connecting (attempt: 0) → .connecting (attempt: 1) → .connecting (attempt: 2) → ...
*/
func testConnectionAttemptTransition() async throws {
let tunnelMonitor = TunnelMonitorStub { _, _ in }
let actor = PacketTunnelActor.mock(tunnelMonitor: tunnelMonitor)
let connectingStateExpectation = expectation(description: "Expect connecting state")
connectingStateExpectation.expectedFulfillmentCount = 5

var nextAttemptCount: UInt = 0
stateSink = await actor.$state
.receive(on: DispatchQueue.main)
.sink { newState in
switch newState {
case .initial:
break

case let .connecting(connState):
XCTAssertEqual(connState.connectionAttemptCount, nextAttemptCount)
nextAttemptCount += 1
connectingStateExpectation.fulfill()

if nextAttemptCount < connectingStateExpectation.expectedFulfillmentCount {
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))
}

default:
XCTFail("Received invalid state: \(newState.name).")
}
}

self.actor = actor

actor.start(options: StartOptions(launchSource: .app))

await fulfillment(of: [connectingStateExpectation], timeout: 1)
}

/**
Each subsequent re-connection attempt should produce a single change to `state` containing the incremented attempt counter and new relay.

.reconnecting (attempt: 0) → .reconnecting (attempt: 1) → .reconnecting (attempt: 2) → ...
*/
func testReconnectionAttemptTransition() async throws {
let tunnelMonitor = TunnelMonitorStub { _, _ in }
let actor = PacketTunnelActor.mock(tunnelMonitor: tunnelMonitor)
let connectingStateExpectation = expectation(description: "Expect connecting state")
let connectedStateExpectation = expectation(description: "Expect connected state")
let reconnectingStateExpectation = expectation(description: "Expect reconnecting state")
reconnectingStateExpectation.expectedFulfillmentCount = 5

var nextAttemptCount: UInt = 0
stateSink = await actor.$state
.receive(on: DispatchQueue.main)
.sink { newState in
switch newState {
case .initial:
break

case .connecting:
connectingStateExpectation.fulfill()
tunnelMonitor.dispatch(.connectionEstablished, after: .milliseconds(10))

case .connected:
connectedStateExpectation.fulfill()
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))

case let .reconnecting(connState):
XCTAssertEqual(connState.connectionAttemptCount, nextAttemptCount)
nextAttemptCount += 1
reconnectingStateExpectation.fulfill()

if nextAttemptCount < reconnectingStateExpectation.expectedFulfillmentCount {
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))
}

default:
XCTFail("Received invalid state: \(newState.name).")
}
}

self.actor = actor

actor.start(options: StartOptions(launchSource: .app))

await fulfillment(
of: [connectingStateExpectation, connectedStateExpectation, reconnectingStateExpectation],
timeout: 1,
enforceOrder: true
)
}
}
2 changes: 1 addition & 1 deletion ios/PacketTunnelCoreTests/Mocks/TunnelMonitorStub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TunnelMonitorStub: TunnelMonitorProtocol {

func onSleep() {}

private func dispatch(_ event: TunnelMonitorEvent, after delay: DispatchTimeInterval = .never) {
func dispatch(_ event: TunnelMonitorEvent, after delay: DispatchTimeInterval = .never) {
if case .never = delay {
onEvent?(event)
} else {
Expand Down

0 comments on commit 40f7b0d

Please sign in to comment.