Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reusing existing connection for requests #5704

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions ios/MullvadREST/Transport/AccessMethodIterator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ class AccessMethodIterator {
private let dataSource: AccessMethodRepositoryDataSource

private var index = 0
private var enabledConfigurations: [PersistentAccessMethod] = []
private var cancellables = Set<Combine.AnyCancellable>()

private var enabledConfigurations: [PersistentAccessMethod] {
dataSource.fetchAll().filter { $0.isEnabled }
}

private var lastReachableApiAccessId: UUID {
lastReachableApiAccessCache.id
}
Expand All @@ -31,9 +34,8 @@ class AccessMethodIterator {

self.dataSource
.publisher
.sink { [weak self] configurations in
.sink { [weak self] _ in
guard let self else { return }
self.enabledConfigurations = configurations.filter { $0.isEnabled }
self.refreshCacheIfNeeded()
}
.store(in: &cancellables)
Expand All @@ -45,7 +47,7 @@ class AccessMethodIterator {
index = firstIndex
} else {
/// When `firstIndex` is `nil`, that means the current configuration is not valid anymore
/// Invalidating cache by replacing the `current` to the next enabled access method
/// Invalidating cache by replacing the `current` to the next enabled access method
lastReachableApiAccessCache.id = pick().id
}
}
Expand All @@ -57,14 +59,15 @@ class AccessMethodIterator {
}

func pick() -> PersistentAccessMethod {
if enabledConfigurations.isEmpty {
/// Returning `Default` strategy when all is disabled
let configurations = enabledConfigurations
if configurations.isEmpty {
/// Returning `Default` strategy when all is disabled
return dataSource.directAccess
} else {
/// Picking the next `Enabled` configuration in order they are added
/// And starting from the beginning when it reaches end
let circularIndex = index % enabledConfigurations.count
return enabledConfigurations[circularIndex]
let circularIndex = index % configurations.count
return configurations[circularIndex]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Foundation
import MullvadTypes
import Network

public struct ShadowsocksConfiguration: Codable {
public struct ShadowsocksConfiguration: Codable, Equatable {
public let address: AnyIPAddress
public let port: UInt16
public let password: String
Expand Down
2 changes: 1 addition & 1 deletion ios/MullvadREST/Transport/Socks5/Socks5Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import MullvadTypes

/// Socks5 configuration.
/// - See: ``URLSessionSocks5Transport``
public struct Socks5Configuration {
public struct Socks5Configuration: Equatable {
/// The socks proxy endpoint.
public var proxyEndpoint: AnyIPEndpoint

Expand Down
46 changes: 29 additions & 17 deletions ios/MullvadREST/Transport/TransportProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public final class TransportProvider: RESTTransportProvider {
private let addressCache: REST.AddressCache
private var transportStrategy: TransportStrategy
private var currentTransport: RESTTransport?
private var currentTransportType: TransportStrategy.Transport
private let parallelRequestsMutex = NSLock()

public init(
Expand All @@ -25,6 +26,7 @@ public final class TransportProvider: RESTTransportProvider {
self.urlSessionTransport = urlSessionTransport
self.addressCache = addressCache
self.transportStrategy = transportStrategy
self.currentTransportType = transportStrategy.connectionTransport()
}

public func makeTransport() -> RESTTransport? {
Expand Down Expand Up @@ -62,26 +64,36 @@ public final class TransportProvider: RESTTransportProvider {
///
/// - Returns: A `RESTTransport` object to make a connection
private func makeTransportInner() -> RESTTransport? {
switch transportStrategy.connectionTransport() {
case .direct:
currentTransport = urlSessionTransport
case let .shadowsocks(configuration):
currentTransport = ShadowsocksTransport(
urlSession: urlSessionTransport.urlSession,
configuration: configuration,
addressCache: addressCache
)
case let .socks5(configuration):
currentTransport = URLSessionSocks5Transport(
urlSession: urlSessionTransport.urlSession,
configuration: configuration,
addressCache: addressCache
)
case .none:
currentTransport = nil
if currentTransport == nil || shouldNotReuseCurrentTransport {
currentTransportType = transportStrategy.connectionTransport()
switch currentTransportType {
case .direct:
currentTransport = urlSessionTransport
case let .shadowsocks(configuration):
currentTransport = ShadowsocksTransport(
urlSession: urlSessionTransport.urlSession,
configuration: configuration,
addressCache: addressCache
)
case let .socks5(configuration):
currentTransport = URLSessionSocks5Transport(
urlSession: urlSessionTransport.urlSession,
configuration: configuration,
addressCache: addressCache
)
case .none:
currentTransport = nil
}
}
return currentTransport
}

/// The `Main` allows modifications to access methods through the UI.
/// The `TransportProvider` relies on a `CurrentTransport` value set during build time or network error.
/// To ensure both process `Packet Tunnel` and `Main` uses the latest changes, the `TransportProvider` compares the `transportType` with the latest value in the cache and reuse it if it's still valid .
private var shouldNotReuseCurrentTransport: Bool {
currentTransportType != transportStrategy.connectionTransport()
}
}

private extension URLError {
Expand Down
15 changes: 14 additions & 1 deletion ios/MullvadREST/Transport/TransportStrategy.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import MullvadTypes

public class TransportStrategy: Equatable {
/// The different transports suggested by the strategy
public enum Transport {
public enum Transport: Equatable {
/// Connecting a direct connection
case direct

Expand All @@ -25,6 +25,19 @@ public class TransportStrategy: Equatable {

/// Failing to retrive transport
case none

public static func == (lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case(.direct, .direct), (.none, .none):
return true
case let (.shadowsocks(lhsConfiguration), .shadowsocks(rhsConfiguration)):
return lhsConfiguration == rhsConfiguration
case let (.socks5(lhsConfiguration), .socks5(rhsConfiguration)):
return lhsConfiguration == rhsConfiguration
default:
return false
}
}
}

private let shadowsocksLoader: ShadowsocksLoaderProtocol
Expand Down
5 changes: 4 additions & 1 deletion ios/MullvadRESTTests/AccessMethodRepositoryStub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import MullvadSettings
typealias PersistentAccessMethod = MullvadSettings.PersistentAccessMethod
struct AccessMethodRepositoryStub: AccessMethodRepositoryDataSource {
var directAccess: MullvadSettings.PersistentAccessMethod

var publisher: AnyPublisher<[MullvadSettings.PersistentAccessMethod], Never> {
passthroughSubject.eraseToAnyPublisher()
}
Expand All @@ -23,4 +22,8 @@ struct AccessMethodRepositoryStub: AccessMethodRepositoryDataSource {
directAccess = accessMethods.first(where: { $0.kind == .direct })!
passthroughSubject.send(accessMethods)
}

func fetchAll() -> [MullvadSettings.PersistentAccessMethod] {
passthroughSubject.value
}
}
17 changes: 0 additions & 17 deletions ios/MullvadRESTTests/TransportStrategyTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -238,23 +238,6 @@ class TransportStrategyTests: XCTestCase {
}
}

extension TransportStrategy.Transport: Equatable {
public static func == (lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case(.direct, .direct), (.none, .none):
return true
case let (.shadowsocks(config1), .shadowsocks(config2)):
return config1.port == config2.port && config1.cipher == config2.cipher && config1.password == config2
.password
case let (.socks5(config1), .socks5(config2)):
return config1.proxyEndpoint == config2.proxyEndpoint && config1.username == config2.username && config1
.password == config2.password
default:
return false
}
}
}

private enum IOError: Error {
case fileNotFound
}
9 changes: 5 additions & 4 deletions ios/MullvadSettings/AccessMethodRepositoryProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ public protocol AccessMethodRepositoryDataSource {
/// Publisher that propagates a snapshot of persistent store upon modifications.
var publisher: AnyPublisher<[PersistentAccessMethod], Never> { get }

/// - Returns: the default strategy.
var directAccess: PersistentAccessMethod { get }

/// Fetch all access method from the persistent store.
/// - Returns: an array of all persistent access method.
func fetchAll() -> [PersistentAccessMethod]
}

public protocol AccessMethodRepositoryProtocol: AccessMethodRepositoryDataSource {
Expand All @@ -34,10 +39,6 @@ public protocol AccessMethodRepositoryProtocol: AccessMethodRepositoryDataSource
/// - Returns: a persistent access method model upon success, otherwise `nil`.
func fetch(by id: UUID) -> PersistentAccessMethod?

/// Fetch all access method from the persistent store.
/// - Returns: an array of all persistent access method.
func fetchAll() -> [PersistentAccessMethod]

/// Refreshes the storage with default values.
func reloadWithDefaultsAfterDataRemoval()
}
Loading