diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme
index d92bafde07..96d011b6cc 100644
--- a/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme
+++ b/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme
@@ -48,11 +48,6 @@
BlueprintName = "AWSCognitoAuthPluginUnitTests"
ReferencedContainer = "container:">
-
-
-
-
diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift
index 6c46eceb40..59d799e963 100644
--- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift
+++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift
@@ -76,24 +76,6 @@ public struct AWSAuthCognitoSession: AuthSession,
}
-/// Internal Helpers for managing session tokens
-internal extension AWSAuthCognitoSession {
- func areTokensExpiring(in seconds: TimeInterval? = nil) -> Bool {
-
- guard let tokens = try? userPoolTokensResult.get(),
- let idTokenClaims = try? AWSAuthService().getTokenClaims(tokenString: tokens.idToken).get(),
- let accessTokenClaims = try? AWSAuthService().getTokenClaims(tokenString: tokens.idToken).get(),
- let idTokenExpiration = idTokenClaims["exp"]?.doubleValue,
- let accessTokenExpiration = accessTokenClaims["exp"]?.doubleValue else {
- return true
- }
-
- // If the session expires < X minutes return it
- return (Date(timeIntervalSince1970: idTokenExpiration).compare(Date(timeIntervalSinceNow: seconds ?? 0)) == .orderedDescending &&
- Date(timeIntervalSince1970: accessTokenExpiration).compare(Date(timeIntervalSinceNow: seconds ?? 0)) == .orderedDescending)
- }
-}
-
extension AWSAuthCognitoSession: Equatable {
public static func == (lhs: AWSAuthCognitoSession, rhs: AWSAuthCognitoSession) -> Bool {
switch (lhs.getCognitoTokens(), rhs.getCognitoTokens()) {
diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift
index af7d80f96a..c5f4daed06 100644
--- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift
+++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift
@@ -65,10 +65,10 @@ public struct AWSCognitoUserPoolTokens: AuthCognitoTokens {
case (.some(let idTokenValue), .none):
expirationDoubleValue = idTokenValue
case (.none, .none):
- expirationDoubleValue = 0
+ expirationDoubleValue = Date().timeIntervalSince1970
}
- self.expiration = Date().addingTimeInterval(TimeInterval((expirationDoubleValue ?? 0)))
+ self.expiration = Date(timeIntervalSince1970: TimeInterval(expirationDoubleValue))
}
}
diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift
index 9ded44922d..8e391355e5 100644
--- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift
+++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift
@@ -25,27 +25,6 @@ struct AuthCognitoSignedOutSessionHelper {
return authSession
}
- /// Guest/SignedOut session with any unhandled error
- ///
- /// The unhandled error is passed as identityId and aws credentials result. UserSub and Cognito Tokens will still
- /// have signOut error.
- ///
- /// - Parameter error: Unhandled error
- /// - Returns: Session will have isSignedIn = false
- private static func makeSignedOutSession(withUnhandledError error: AuthError) -> AWSAuthCognitoSession {
-
- let identityIdError = error
- let awsCredentialsError = error
-
- let tokensError = makeCognitoTokensSignedOutError()
-
- let authSession = AWSAuthCognitoSession(isSignedIn: false,
- identityIdResult: .failure(identityIdError),
- awsCredentialsResult: .failure(awsCredentialsError),
- cognitoTokensResult: .failure(tokensError))
- return authSession
- }
-
/// Guest/SignOut session when the guest access is not enabled.
/// - Returns: Session with isSignedIn = false
static func makeSessionWithNoGuestAccess() -> AWSAuthCognitoSession {
@@ -68,26 +47,6 @@ struct AuthCognitoSignedOutSessionHelper {
return authSession
}
- private static func makeOfflineSignedOutSession() -> AWSAuthCognitoSession {
- let identityIdError = AuthError.service(
- AuthPluginErrorConstants.identityIdOfflineError.errorDescription,
- AuthPluginErrorConstants.identityIdOfflineError.recoverySuggestion,
- AWSCognitoAuthError.network)
-
- let awsCredentialsError = AuthError.service(
- AuthPluginErrorConstants.awsCredentialsOfflineError.errorDescription,
- AuthPluginErrorConstants.awsCredentialsOfflineError.recoverySuggestion,
- AWSCognitoAuthError.network)
-
- let tokensError = makeCognitoTokensSignedOutError()
-
- let authSession = AWSAuthCognitoSession(isSignedIn: false,
- identityIdResult: .failure(identityIdError),
- awsCredentialsResult: .failure(awsCredentialsError),
- cognitoTokensResult: .failure(tokensError))
- return authSession
- }
-
/// Guest/SignedOut session with couldnot retreive either aws credentials or identity id.
/// - Returns: Session will have isSignedIn = false
private static func makeSignedOutSessionWithServiceIssue() -> AWSAuthCognitoSession {
@@ -109,13 +68,6 @@ struct AuthCognitoSignedOutSessionHelper {
return authSession
}
- private static func makeUserSubSignedOutError() -> AuthError {
- let userSubError = AuthError.signedOut(
- AuthPluginErrorConstants.userSubSignOutError.errorDescription,
- AuthPluginErrorConstants.userSubSignOutError.recoverySuggestion)
- return userSubError
- }
-
private static func makeCognitoTokensSignedOutError() -> AuthError {
let tokensError = AuthError.signedOut(
AuthPluginErrorConstants.cognitoTokensSignOutError.errorDescription,
diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift
index cd9760637a..9c225ec931 100644
--- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift
+++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift
@@ -22,7 +22,7 @@ class HostedUIASWebAuthenticationSession: NSObject, HostedUISessionBehavior {
callback: @escaping (Result<[URLQueryItem], HostedUIError>) -> Void) {
#if os(iOS) || os(macOS)
self.webPresentation = presentationAnchor
- let aswebAuthenticationSession = ASWebAuthenticationSession(
+ let aswebAuthenticationSession = createAuthenticationSession(
url: url,
callbackURLScheme: callbackScheme,
completionHandler: { url, error in
@@ -58,6 +58,16 @@ class HostedUIASWebAuthenticationSession: NSObject, HostedUISessionBehavior {
}
#if os(iOS) || os(macOS)
+ var authenticationSessionFactory = ASWebAuthenticationSession.init(url:callbackURLScheme:completionHandler:)
+
+ private func createAuthenticationSession(
+ url: URL,
+ callbackURLScheme: String?,
+ completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler
+ ) -> ASWebAuthenticationSession {
+ return authenticationSessionFactory(url, callbackURLScheme, completionHandler)
+ }
+
private func convertHostedUIError(_ error: Error) -> HostedUIError {
if let asWebAuthError = error as? ASWebAuthenticationSessionError {
switch asWebAuthError.code {
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift
index 06ebd109c9..1252888632 100644
--- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift
@@ -121,7 +121,122 @@ class MigrateLegacyCredentialStoreTests: XCTestCase {
await fulfillment(
of: [migrationCompletionInvoked],
+
timeout: 0.1
)
}
+
+ /// - Given: A credential store with an invalid environment
+ /// - When: The migration legacy store action is executed
+ /// - Then: An error event of type configuration is dispatched
+ func testExecute_withInvalidEnvironment_shouldDispatchError() async {
+ let expectation = expectation(description: "noEnvironment")
+ let action = MigrateLegacyCredentialStore()
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? CredentialStoreEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to no CredentialEnvironment")
+ expectation.fulfill()
+ return
+ }
+ XCTAssertEqual(error, .configuration(message: AuthPluginErrorConstants.configurationError))
+ expectation.fulfill()
+ },
+ environment: MockInvalidEnvironment()
+ )
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// - Given: A credential store with an environment that only has identity pool
+ /// - When: The migration legacy store action is executed
+ /// - Then:
+ /// - A .loadCredentialStore event with type .amplifyCredentials is dispatched
+ /// - An .identityPoolOnly credential is saved
+ func testExecute_withoutUserPool_andWithoutLoginsTokens_shouldDispatchLoadEvent() async {
+ let expectation = expectation(description: "noUserPoolTokens")
+ let action = MigrateLegacyCredentialStore()
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? CredentialStoreEvent,
+ case .loadCredentialStore(let type) = event.eventType else {
+ XCTFail("Expected .loadCredentialStore")
+ expectation.fulfill()
+ return
+ }
+ XCTAssertEqual(type, .amplifyCredentials)
+ expectation.fulfill()
+ },
+ environment: CredentialEnvironment(
+ authConfiguration: .identityPools(.testData),
+ credentialStoreEnvironment: BasicCredentialStoreEnvironment(
+ amplifyCredentialStoreFactory: {
+ MockAmplifyCredentialStoreBehavior(
+ saveCredentialHandler: { codableCredentials in
+ guard let amplifyCredentials = codableCredentials as? AmplifyCredentials,
+ case .identityPoolOnly(_, let credentials) = amplifyCredentials else {
+ XCTFail("Expected .identityPoolOnly")
+ return
+ }
+ XCTAssertFalse(credentials.sessionToken.isEmpty)
+ }
+ )
+ },
+ legacyKeychainStoreFactory: { _ in
+ MockKeychainStoreBehavior(data: "hostedUI")
+ }),
+ logger: MigrateLegacyCredentialStore.log
+ )
+ )
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// - Given: A credential store with an environment that only has identity pool
+ /// - When: The migration legacy store action is executed
+ /// - A .loadCredentialStore event with type .amplifyCredentials is dispatched
+ /// - An .identityPoolWithFederation credential is saved
+ func testExecute_withoutUserPool_andWithLoginsTokens_shouldDispatchLoadEvent() async {
+ let expectation = expectation(description: "noUserPoolTokens")
+ let action = MigrateLegacyCredentialStore()
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? CredentialStoreEvent,
+ case .loadCredentialStore(let type) = event.eventType else {
+ XCTFail("Expected .loadCredentialStore")
+ expectation.fulfill()
+ return
+ }
+ XCTAssertEqual(type, .amplifyCredentials)
+ expectation.fulfill()
+ },
+ environment: CredentialEnvironment(
+ authConfiguration: .identityPools(.testData),
+ credentialStoreEnvironment: BasicCredentialStoreEnvironment(
+ amplifyCredentialStoreFactory: {
+ MockAmplifyCredentialStoreBehavior(
+ saveCredentialHandler: { codableCredentials in
+ guard let amplifyCredentials = codableCredentials as? AmplifyCredentials,
+ case .identityPoolWithFederation(let token, _, _) = amplifyCredentials else {
+ XCTFail("Expected .identityPoolWithFederation")
+ return
+ }
+
+ XCTAssertEqual(token.token, "token")
+ XCTAssertEqual(token.provider.userPoolProviderName, "provider")
+ }
+ )
+ },
+ legacyKeychainStoreFactory: { _ in
+ let data = try! JSONEncoder().encode([
+ "provider": "token"
+ ])
+ return MockKeychainStoreBehavior(
+ data: String(decoding: data, as: UTF8.self)
+ )
+ }),
+ logger: action.log
+ )
+ )
+ await fulfillment(of: [expectation], timeout: 1)
+ }
}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift
new file mode 100644
index 0000000000..2b94f7721f
--- /dev/null
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift
@@ -0,0 +1,310 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if os(iOS) || os(macOS)
+
+@testable import AWSCognitoAuthPlugin
+import AWSCognitoIdentityProvider
+import AWSPluginsCore
+import XCTest
+
+class RefreshHostedUITokensTests: XCTestCase {
+ private let tokenResult: [String: Any] = [
+ "id_token": AWSCognitoUserPoolTokens.testData.idToken,
+ "access_token": AWSCognitoUserPoolTokens.testData.accessToken,
+ "refresh_token": AWSCognitoUserPoolTokens.testData.refreshToken,
+ "expires_in": 10
+ ]
+
+ private var hostedUIEnvironment: HostedUIEnvironment {
+ BasicHostedUIEnvironment(
+ configuration: .init(
+ clientId: "clientId",
+ oauth: .init(
+ domain: "cognitodomain",
+ scopes: ["name"],
+ signInRedirectURI: "myapp://",
+ signOutRedirectURI: "myapp://"
+ )
+ ),
+ hostedUISessionFactory: sessionFactory,
+ urlSessionFactory: urlSessionMock,
+ randomStringFactory: mockRandomString
+ )
+ }
+
+ override func setUp() {
+ let result = try! JSONSerialization.data(withJSONObject: tokenResult)
+ MockURLProtocol.requestHandler = { _ in
+ return (HTTPURLResponse(), result)
+ }
+ }
+
+ override func tearDown() {
+ MockURLProtocol.requestHandler = nil
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked with a valid response
+ /// Then: A RefreshSessionEvent.refreshIdentityInfo is dispatched
+ func testExecute_withValidResponse_shouldDispatchRefreshEvent() async {
+ let expectation = expectation(description: "refreshHostedUITokens")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case .refreshIdentityInfo(let data, _) = event.eventType else {
+ XCTFail("Failed to refresh tokens")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(data.cognitoUserPoolTokens.idToken, self.tokenResult["id_token"] as? String)
+ XCTAssertEqual(data.cognitoUserPoolTokens.accessToken, self.tokenResult["access_token"] as? String)
+ XCTAssertEqual(data.cognitoUserPoolTokens.refreshToken, self.tokenResult["refresh_token"] as? String)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked and throws a HostedUIError
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .service
+ func testExecute_withHostedUIError_shouldDispatchErrorEvent() async {
+ let expectedError = HostedUIError.serviceMessage("Something went wrong")
+ MockURLProtocol.requestHandler = { _ in
+ throw expectedError
+ }
+
+ let expectation = expectation(description: "refreshHostedUITokens")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to Service Error")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(error, .service(expectedError))
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked and returns empty data
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .service
+ func testExecute_withEmptyData_shouldDispatchErrorEvent() async {
+ MockURLProtocol.requestHandler = { _ in
+ return (HTTPURLResponse(), Data())
+ }
+
+ let expectation = expectation(description: "refreshHostedUITokens")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to Invalid Tokens")
+ expectation.fulfill()
+ return
+ }
+
+ guard case .service(let serviceError) = error else {
+ XCTFail("Expected FetchSessionError.service, got \(error)")
+ expectation.fulfill()
+ return
+ }
+
+
+ XCTAssertEqual((serviceError as NSError).code, NSPropertyListReadCorruptError)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked and returns data that is invalid for tokens
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .invalidTokens
+ func testExecute_withInvalidTokens_shouldDispatchErrorEvent() async {
+ let result: [String: Any] = [
+ "key": "value"
+ ]
+ MockURLProtocol.requestHandler = { _ in
+ return (HTTPURLResponse(), try! JSONSerialization.data(withJSONObject: result))
+ }
+
+ let expectation = expectation(description: "refreshHostedUITokens")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to Invalid Tokens")
+ expectation.fulfill()
+ return
+ }
+
+
+ XCTAssertEqual(error, .invalidTokens)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked and returns data representing an error
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .service
+ func testExecute_withErrorResponse_shouldDispatchErrorEvent() async {
+ let result: [String: Any] = [
+ "error": "Error.",
+ "error_description": "Something went wrong"
+ ]
+ MockURLProtocol.requestHandler = { _ in
+ return (HTTPURLResponse(), try! JSONSerialization.data(withJSONObject: result))
+ }
+
+ let expectation = expectation(description: "refreshHostedUITokens")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to Invalid Tokens")
+ expectation.fulfill()
+ return
+ }
+
+ guard case .service(let serviceError) = error,
+ case .serviceMessage(let errorMessage) = serviceError as? HostedUIError else {
+ XCTFail("Expected HostedUIError.serviceMessage, got \(error)")
+ expectation.fulfill()
+ return
+ }
+
+
+ XCTAssertEqual(errorMessage, "Error. Something went wrong")
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked without a HostedUIEnvironment
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .noUserPool
+ func testExecute_withoutHostedUIEnvironment_shouldDispatchErrorEvent() async {
+ let expectation = expectation(description: "noHostedUIEnvironment")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to no HostedUIEnvironment")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(error, .noUserPool)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: nil
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A RefreshHostedUITokens action
+ /// When: execute is invoked without a UserPoolEnvironment
+ /// Then: A RefreshSessionEvent.throwError is dispatched with .noUserPool
+ func testExecute_withoutUserPoolEnvironment_shouldDispatchErrorEvent() async {
+ let expectation = expectation(description: "noUserPoolEnvironment")
+ let action = RefreshHostedUITokens(existingSignedIndata: .testData)
+ action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? RefreshSessionEvent,
+ case let .throwError(error) = event.eventType else {
+ XCTFail("Expected failure due to no UserPoolEnvironment")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(error, .noUserPool)
+ expectation.fulfill()
+ },
+ environment: MockInvalidEnvironment()
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ private func identityProviderFactory() throws -> CognitoUserPoolBehavior {
+ return MockIdentityProvider(
+ mockInitiateAuthResponse: { _ in
+ return InitiateAuthOutputResponse(
+ authenticationResult: .init(
+ accessToken: "accessTokenNew",
+ expiresIn: 100,
+ idToken: "idTokenNew",
+ refreshToken: "refreshTokenNew")
+ )
+ }
+ )
+ }
+
+ private func urlSessionMock() -> URLSession {
+ let configuration = URLSessionConfiguration.ephemeral
+ configuration.protocolClasses = [MockURLProtocol.self]
+ return URLSession(configuration: configuration)
+ }
+
+ private func sessionFactory() -> HostedUISessionBehavior {
+ MockHostedUISession(result: .failure(.cancelled))
+ }
+
+ private func mockRandomString() -> RandomStringBehavior {
+ return MockRandomStringGenerator(
+ mockString: "mockString",
+ mockUUID: "mockUUID"
+ )
+ }
+}
+#endif
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift
new file mode 100644
index 0000000000..53afd8af8f
--- /dev/null
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift
@@ -0,0 +1,177 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import AWSCognitoAuthPlugin
+import AWSCognitoIdentityProvider
+@testable import AWSPluginsTestCommon
+import XCTest
+
+class VerifyDevicePasswordSRPSignatureTests: XCTestCase {
+ private var srpClient: MockSRPClientBehavior!
+
+ override func setUp() async throws {
+ MockSRPClientBehavior.reset()
+ srpClient = MockSRPClientBehavior()
+ }
+
+ override func tearDown() {
+ MockSRPClientBehavior.reset()
+ srpClient = nil
+ }
+
+ /// Given: A VerifyDevicePasswordSRP
+ /// When: signature is invoked
+ /// Then: a non-empty string is returned
+ func testSignature_withValidValues_shouldReturnSignature() async {
+ do {
+ let signature = try signature()
+ XCTAssertFalse(signature.isEmpty)
+ } catch {
+ XCTFail("Should not throw error: \(error)")
+ }
+ }
+
+ /// Given: A VerifyDevicePasswordSRP
+ /// When: signature is invoked and the srpClient throws an SRPError error when generating a shared secret
+ /// Then: a .calculation error is thrown
+ func testSignature_withSRPErrorOnSharedSecret_shouldThrowCalculationError() async {
+ srpClient.sharedSecret = .failure(SRPError.numberConversion)
+ do {
+ try signature()
+ XCTFail("Should not succeed")
+ } catch {
+ guard case .calculation(let srpError) = error as? SignInError else {
+ XCTFail("Expected SRPError.calculation, got \(error)")
+ return
+ }
+
+ XCTAssertEqual(srpError, .numberConversion)
+ }
+ }
+
+ /// Given: A VerifyDevicePasswordSRP
+ /// When: signature is invoked and the srpClient throws a non-SRPError error when generating a shared secret
+ /// Then: a .configuration error is thrown
+ func testSignature_withOtherErrorOnSharedSecret_shouldThrowCalculationError() async {
+ srpClient.sharedSecret = .failure(CancellationError())
+ do {
+ try signature()
+ XCTFail("Should not succeed")
+ } catch {
+ guard case .configuration(let message) = error as? SignInError else {
+ XCTFail("Expected SRPError.configuration, got \(error)")
+ return
+ }
+
+ XCTAssertEqual(message, "Could not calculate shared secret")
+ }
+ }
+
+ /// Given: A VerifyDevicePasswordSRP
+ /// When: signature is invoked and the srpClient throws a SRPError error when generating an authentication key
+ /// Then: a .calculation error is thrown
+ func testSignature_withSRPErrorOnAuthenticationKey_shouldThrowCalculationError() async {
+ MockSRPClientBehavior.authenticationKey = .failure(SRPError.numberConversion)
+ do {
+ try signature()
+ XCTFail("Should not succeed")
+ } catch {
+ guard case .calculation(let srpError) = error as? SignInError else {
+ XCTFail("Expected SRPError.calculation, got \(error)")
+ return
+ }
+
+ XCTAssertEqual(srpError, .numberConversion)
+ }
+ }
+
+ /// Given: A VerifyDevicePasswordSRP
+ /// When: signature is invoked and the srpClient throws a non-SRPError error when generating an authentication key
+ /// Then: a .configuration error is thrown
+ func testSignature_withOtherErrorOnAuthenticationKey_shouldThrowCalculationError() async {
+ MockSRPClientBehavior.authenticationKey = .failure(CancellationError())
+ do {
+ try signature()
+ XCTFail("Should not succeed")
+ } catch {
+ guard case .configuration(let message) = error as? SignInError else {
+ XCTFail("Expected SRPError.configuration, got \(error)")
+ return
+ }
+
+ XCTAssertEqual(message, "Could not calculate signature")
+ }
+ }
+
+ @discardableResult
+ private func signature() throws -> String {
+ let action = VerifyDevicePasswordSRP(
+ stateData: .testData,
+ authResponse: InitiateAuthOutputResponse.validTestData
+ )
+
+ return try action.signature(
+ deviceGroupKey: "deviceGroupKey",
+ deviceKey: "deviceKey",
+ deviceSecret: "deviceSecret",
+ saltHex: "saltHex",
+ secretBlock: "secretBlock".data(using: .utf8) ?? Data(),
+ serverPublicBHexString: "serverPublicBHexString",
+ srpClient: srpClient
+ )
+ }
+}
+
+private class MockSRPClientBehavior: SRPClientBehavior {
+ var kHexValue: String = "kHexValue"
+
+ static func calculateUHexValue(
+ clientPublicKeyHexValue: String,
+ serverPublicKeyHexValue: String
+ ) throws -> String {
+ return "UHexValue"
+ }
+
+ static var authenticationKey: Result = .success("AuthenticationKey".data(using: .utf8)!)
+ static func generateAuthenticationKey(
+ sharedSecretHexValue: String,
+ uHexValue: String
+ ) throws -> Data {
+ return try authenticationKey.get()
+ }
+
+ static func reset() {
+ authenticationKey = .success("AuthenticationKey".data(using: .utf8)!)
+ }
+
+ func generateClientKeyPair() -> SRPKeys {
+ return .init(
+ publicKeyHexValue: "publicKeyHexValue",
+ privateKeyHexValue: "privateKeyHexValue"
+ )
+ }
+
+ var sharedSecret: Result = .success("SharedSecret")
+ func calculateSharedSecret(
+ username: String,
+ password: String,
+ saltHexValue: String,
+ clientPrivateKeyHexValue: String,
+ clientPublicKeyHexValue: String,
+ serverPublicKeyHexValue: String
+ ) throws -> String {
+ return try sharedSecret.get()
+ }
+
+ func generateDevicePasswordVerifier(
+ deviceGroupKey: String,
+ deviceKey: String,
+ password: String
+ ) -> (salt: Data, passwordVerifier: Data) {
+ return (salt: Data(), passwordVerifier: Data())
+ }
+}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift
new file mode 100644
index 0000000000..0751703550
--- /dev/null
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift
@@ -0,0 +1,401 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import AWSCognitoAuthPlugin
+import AWSCognitoIdentityProvider
+import AWSPluginsCore
+import XCTest
+
+class ShowHostedUISignOutTests: XCTestCase {
+ private var mockHostedUIResult: Result<[URLQueryItem], HostedUIError>!
+ private var signOutRedirectURI: String!
+
+ override func setUp() {
+ signOutRedirectURI = "myapp://"
+ mockHostedUIResult = .success([.init(name: "key", value: "value")])
+ }
+
+ override func tearDown() {
+ signOutRedirectURI = nil
+ mockHostedUIResult = nil
+ }
+
+ /// Given: A ShowHostedUISignOut action with global sign out set to true
+ /// When: execute is invoked with a success result
+ /// Then: A .signOutGlobally event is dispatched with a nil error
+ func testExecute_withGlobalSignOut_andSuccessResult_shouldDispatchSignOutEvent() async {
+ let expectation = expectation(description: "showHostedUISignOut")
+ let signInData = SignedInData.testData
+ let action = ShowHostedUISignOut(
+ signOutEvent: SignOutEventData(globalSignOut: true),
+ signInData: signInData
+ )
+
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let error) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertNil(error)
+ XCTAssertEqual(data, signInData)
+ self.validateDebugInformation(signInData: signInData, action: action)
+
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action with global sign out set to false
+ /// When: execute is invoked with a success result
+ /// Then: A .revokeToken event is dispatched
+ func testExecute_withLocalSignOut_andSuccessResult_shouldDispatchSignOutEvent() async {
+ let expectation = expectation(description: "showHostedUISignOut")
+ let signInData = SignedInData.testData
+ let action = ShowHostedUISignOut(
+ signOutEvent: SignOutEventData(globalSignOut: false),
+ signInData: signInData
+ )
+
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .revokeToken(let data, let error, let globalSignOutError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.revokeToken, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertNil(error)
+ XCTAssertNil(globalSignOutError)
+ XCTAssertEqual(data, signInData)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action
+ /// When: execute is invoked but fails to create a HostedUI session
+ /// Then: A .userCancelled event is dispatched
+ func testExecute_withInvalidResult_shouldDispatchUserCancelledEvent() async {
+ mockHostedUIResult = .failure(.cancelled)
+ let signInData = SignedInData.testData
+
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+
+ let expectation = expectation(description: "showHostedUISignOut")
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent else {
+ XCTFail("Expected SignOutEvent, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(event.eventType, .userCancelled)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action
+ /// When: execute is invoked but fails to create a HostedUI session with a HostedUIError.signOutURI
+ /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error
+ func testExecute_withSignOutURIError_shouldThrowConfigurationError() async {
+ mockHostedUIResult = .failure(HostedUIError.signOutURI)
+ let signInData = SignedInData.testData
+
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+
+ let expectation = expectation(description: "showHostedUISignOut")
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let hostedUIError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ guard let hostedUIError = hostedUIError,
+ case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else {
+ XCTFail("Expected AuthError.configuration")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(errorDescription, "Could not create logout URL")
+ XCTAssertEqual(data, signInData)
+ XCTAssertNil(serviceError)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action
+ /// When: execute is invoked but fails to create a HostedUI session with a HostedUIError.invalidContext
+ /// Then: A .signOutGlobally event is dispatched with a HosterUIError.invalidState error
+ func testExecute_withInvalidContext_shouldThrowInvalidStateError() async {
+ mockHostedUIResult = .failure(HostedUIError.invalidContext)
+ let signInData = SignedInData.testData
+
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+
+ let expectation = expectation(description: "showHostedUISignOut")
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let hostedUIError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ guard let hostedUIError = hostedUIError,
+ case .invalidState(let errorDescription, let recoverySuggestion, let serviceError) = hostedUIError.error else {
+ XCTFail("Expected AuthError.invalidState")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(errorDescription, AuthPluginErrorConstants.hostedUIInvalidPresentation.errorDescription)
+ XCTAssertEqual(recoverySuggestion, AuthPluginErrorConstants.hostedUIInvalidPresentation.recoverySuggestion)
+ XCTAssertEqual(data, signInData)
+ XCTAssertNil(serviceError)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action with an invalid SignOutRedirectURI
+ /// When: execute is invoked
+ /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error
+ func testExecute_withInvalidSignOutURI_shouldThrowConfigurationError() async {
+ signOutRedirectURI = "invalidURI"
+ let signInData = SignedInData.testData
+
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+
+ let expectation = expectation(description: "showHostedUISignOut")
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let hostedUIError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ guard let hostedUIError = hostedUIError,
+ case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else {
+ XCTFail("Expected AuthError.configuration")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(errorDescription, "Callback URL could not be retrieved")
+ XCTAssertEqual(data, signInData)
+ XCTAssertNil(serviceError)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: hostedUIEnvironment
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action
+ /// When: execute is invoked with a nil HostedUIEnvironment
+ /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error
+ func testExecute_withoutHostedUIEnvironment_shouldThrowConfigurationError() async {
+ let expectation = expectation(description: "noHostedUIEnvironment")
+ let signInData = SignedInData.testData
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let hostedUIError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ guard let hostedUIError = hostedUIError,
+ case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else {
+ XCTFail("Expected AuthError.configuration")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(data, signInData)
+ XCTAssertEqual(errorDescription, AuthPluginErrorConstants.configurationError)
+ XCTAssertNil(serviceError)
+ expectation.fulfill()
+ },
+ environment: Defaults.makeDefaultAuthEnvironment(
+ userPoolFactory: identityProviderFactory,
+ hostedUIEnvironment: nil
+ )
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ /// Given: A ShowHostedUISignOut action
+ /// When: execute is invoked with an invalid environment
+ /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error
+ func testExecute_withInvalidUserPoolEnvironment_shouldThrowConfigurationError() async {
+ let expectation = expectation(description: "invalidUserPoolEnvironment")
+ let signInData = SignedInData.testData
+ let action = ShowHostedUISignOut(
+ signOutEvent: .testData,
+ signInData: signInData
+ )
+ await action.execute(
+ withDispatcher: MockDispatcher { event in
+ guard let event = event as? SignOutEvent,
+ case .signOutGlobally(let data, let hostedUIError) = event.eventType else {
+ XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)")
+ expectation.fulfill()
+ return
+ }
+
+ guard let hostedUIError = hostedUIError,
+ case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else {
+ XCTFail("Expected AuthError.configuration")
+ expectation.fulfill()
+ return
+ }
+
+ XCTAssertEqual(data, signInData)
+ XCTAssertEqual(errorDescription, AuthPluginErrorConstants.configurationError)
+ XCTAssertNil(serviceError)
+ expectation.fulfill()
+ },
+ environment: MockInvalidEnvironment()
+ )
+
+ await fulfillment(of: [expectation], timeout: 1)
+ }
+
+ private func validateDebugInformation(signInData: SignedInData, action: ShowHostedUISignOut) {
+ XCTAssertFalse(action.debugDescription.isEmpty)
+ guard let signInDataDictionary = action.debugDictionary["signInData"] as? [String: Any] else {
+ XCTFail("Expected signInData dictionary")
+ return
+ }
+ XCTAssertEqual(signInDataDictionary.count, signInData.debugDictionary.count)
+
+ for key in signInDataDictionary.keys {
+ guard let left = signInDataDictionary[key] as? any Equatable,
+ let right = signInData.debugDictionary[key] as? any Equatable else {
+ continue
+ }
+ XCTAssertTrue(left.isEqual(to: right))
+ }
+ }
+
+ private var hostedUIEnvironment: HostedUIEnvironment {
+ BasicHostedUIEnvironment(
+ configuration: .init(
+ clientId: "clientId",
+ oauth: .init(
+ domain: "cognitodomain",
+ scopes: ["name"],
+ signInRedirectURI: "myapp://",
+ signOutRedirectURI: signOutRedirectURI
+ )
+ ),
+ hostedUISessionFactory: {
+ MockHostedUISession(result: self.mockHostedUIResult)
+ },
+ urlSessionFactory: {
+ URLSession.shared
+ },
+ randomStringFactory: {
+ MockRandomStringGenerator(
+ mockString: "mockString",
+ mockUUID: "mockUUID"
+ )
+ }
+ )
+ }
+
+ private func identityProviderFactory() throws -> CognitoUserPoolBehavior {
+ return MockIdentityProvider(
+ mockInitiateAuthResponse: { _ in
+ return InitiateAuthOutputResponse(
+ authenticationResult: .init(
+ accessToken: "accessTokenNew",
+ expiresIn: 100,
+ idToken: "idTokenNew",
+ refreshToken: "refreshTokenNew")
+ )
+ }
+ )
+ }
+}
+
+private extension Equatable {
+ func isEqual(to other: any Equatable) -> Bool {
+ guard let other = other as? Self else {
+ return false
+ }
+ return self == other
+ }
+}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift
new file mode 100644
index 0000000000..63e72819b2
--- /dev/null
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift
@@ -0,0 +1,62 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import AWSCognitoAuthPlugin
+import XCTest
+
+class CognitoUserPoolASFTests: XCTestCase {
+ private var userPool: CognitoUserPoolASF!
+
+ override func setUp() {
+ userPool = CognitoUserPoolASF()
+ }
+
+ override func tearDown() {
+ userPool = nil
+ }
+
+ /// Given: A CognitoUserPoolASF
+ /// When: userContextData is invoked
+ /// Then: A non-empty string is returned
+ func testUserContextData_shouldReturnData() throws {
+ let result = try userPool.userContextData(
+ for: "TestUser",
+ deviceInfo: ASFDeviceInfo(id: "mockedDevice"),
+ appInfo: ASFAppInfo(),
+ configuration: .testData
+ )
+ XCTAssertFalse(result.isEmpty)
+ }
+
+ /// Given: A CognitoUserPoolASF
+ /// When: calculateSecretHash is invoked
+ /// Then: A non-empty string is returned
+ func testCalculateSecretHash_shouldReturnHash() throws {
+ let result = try userPool.calculateSecretHash(
+ contextJson: "contextJson",
+ clientId: "clientId"
+ )
+ XCTAssertFalse(result.isEmpty)
+ }
+
+ /// Given: A CognitoUserPoolASF
+ /// When: calculateSecretHash is invoked with a clientId that cannot be parsed
+ /// Then: A ASFError.hashKey is thrown
+ func testCalculateSecretHash_withInvalidClientId_shouldThrowHashKeyError() {
+ do {
+ let result = try userPool.calculateSecretHash(
+ contextJson: "contextJson",
+ clientId: "🕺🏼" // This string cannot be represented using .ascii, so it will throw an error
+ )
+ XCTFail("Expected ASFError.hashKey, got \(result)")
+ } catch let error as ASFError {
+ XCTAssertEqual(error, .hashKey)
+ } catch {
+ XCTFail("Expected ASFError.hashKey, for \(error)")
+ }
+ }
+}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift
index d385c628d8..acf3ec1c0d 100644
--- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift
@@ -6,143 +6,116 @@
//
import XCTest
-@testable import Amplify
+@testable import func AmplifyTestCommon.XCTAssertThrowFatalError
+import enum Amplify.JSONValue
@testable import AWSCognitoAuthPlugin
-class EscapeHatchTests: XCTestCase {
-
- let skipBrokenTests = true
-
- override func tearDown() async throws {
- await Amplify.reset()
- }
+class EscapeHatchTests: XCTestCase {
/// Test escape hatch with valid config for user pool and identity pool
///
- /// - Given: Given valid config for user pool and identity pool
+ /// - Given: A AWSCognitoAuthPlugin configured with User Pool and Identity Pool
/// - When:
- /// - I configure auth with the given configuration and call getEscapeHatch
+ /// - I call getEscapeHatch
/// - Then:
- /// - I should get back user pool and identity pool clients
+ /// - I should get back both the User Pool and Identity Pool clients
///
func testEscapeHatchWithUserPoolAndIdentityPool() throws {
- if skipBrokenTests {
- throw XCTSkip("TODO: fix this test")
- }
-
- let plugin = AWSCognitoAuthPlugin()
- try Amplify.add(plugin: plugin)
-
- let expectation = expectation(description: "Should get service")
- let categoryConfig = AuthCategoryConfiguration(plugins: [
- "awsCognitoAuthPlugin": [
- "CredentialsProvider": ["CognitoIdentity": ["Default":
- ["PoolId": "xx",
- "Region": "us-east-1"]
- ]],
- "CognitoUserPool": ["Default": [
+ let configuration: JSONValue = [
+ "CredentialsProvider": [
+ "CognitoIdentity": [
+ "Default": [
+ "PoolId": "xx",
+ "Region": "us-east-1"
+ ]
+ ]
+ ],
+ "CognitoUserPool": [
+ "Default": [
"PoolId": "xx",
"Region": "us-east-1",
"AppClientId": "xx",
- "AppClientSecret": "xx"]]
+ "AppClientSecret": "xx"
+ ]
]
- ])
- let amplifyConfig = AmplifyConfiguration(auth: categoryConfig)
- try Amplify.configure(amplifyConfig)
- let internalPlugin = try Amplify.Auth.getPlugin(
- for: "awsCognitoAuthPlugin"
- ) as! AWSCognitoAuthPlugin
- let service = internalPlugin.getEscapeHatch()
- switch service {
- case .userPool:
- XCTFail("Should return userPoolAndIdentityPool")
- case .identityPool:
- XCTFail("Should return userPoolAndIdentityPool")
- case .userPoolAndIdentityPool:
- expectation.fulfill()
+ ]
+ let plugin = AWSCognitoAuthPlugin()
+ try plugin.configure(using: configuration)
+ let escapeHatch = plugin.getEscapeHatch()
+ guard case .userPoolAndIdentityPool = escapeHatch else {
+ XCTFail("Expected .userPoolAndIdentityPool, got \(escapeHatch)")
+ return
}
- wait(for: [expectation], timeout: 1)
}
/// Test escape hatch with valid config for only identity pool
///
- /// - Given: Given valid config for only identity pool
+ /// - Given: A AWSCognitoAuthPlugin configured with only Identity Pool
/// - When:
- /// - I configure auth with the given configuration and invoke getEscapeHatch
+ /// - I call getEscapeHatch
/// - Then:
- /// - I should get back only identity pool client
+ /// - I should get back only the Identity Pool client
///
func testEscapeHatchWithOnlyIdentityPool() throws {
- if skipBrokenTests {
- throw XCTSkip("TODO: fix this test")
- }
-
- let plugin = AWSCognitoAuthPlugin()
- try Amplify.add(plugin: plugin)
-
- let categoryConfig = AuthCategoryConfiguration(plugins: [
- "awsCognitoAuthPlugin": [
- "CredentialsProvider": ["CognitoIdentity": ["Default":
- ["PoolId": "cc",
- "Region": "us-east-1"]
- ]]
+ let configuration: JSONValue = [
+ "CredentialsProvider": [
+ "CognitoIdentity": [
+ "Default": [
+ "PoolId": "xx",
+ "Region": "us-east-1"
+ ]
+ ]
]
- ])
- let amplifyConfig = AmplifyConfiguration(auth: categoryConfig)
- try Amplify.configure(amplifyConfig)
- let internalPlugin = try Amplify.Auth.getPlugin(
- for: "awsCognitoAuthPlugin"
- ) as! AWSCognitoAuthPlugin
- let service = internalPlugin.getEscapeHatch()
- switch service {
- case .userPool:
- XCTFail("Should return identityPool")
- case .userPoolAndIdentityPool:
- XCTFail("Should return identityPool")
- case .identityPool:
- print("")
+ ]
+ let plugin = AWSCognitoAuthPlugin()
+ try plugin.configure(using: configuration)
+ let escapeHatch = plugin.getEscapeHatch()
+ guard case .identityPool = escapeHatch else {
+ XCTFail("Expected .identityPool, got \(escapeHatch)")
+ return
}
}
/// Test escape hatch with valid config for only user pool
///
- /// - Given: Given valid config for only user pool
+ /// - Given: A AWSCognitoAuthPlugin configured with only User Pool
/// - When:
- /// - I configure auth with the given configuration and invoke getEscapeHatch
+ /// - I call getEscapeHatch
/// - Then:
- /// - I should get the Cognito User pool client
+ /// - I should get only the User Pool client
///
func testEscapeHatchWithOnlyUserPool() throws {
- if skipBrokenTests {
- throw XCTSkip("TODO: fix this test")
- }
-
- let plugin = AWSCognitoAuthPlugin()
- try Amplify.add(plugin: plugin)
-
- let categoryConfig = AuthCategoryConfiguration(plugins: [
- "awsCognitoAuthPlugin": [
- "CognitoUserPool": ["Default": [
+ let configuration: JSONValue = [
+ "CognitoUserPool": [
+ "Default": [
"PoolId": "xx",
"Region": "us-east-1",
"AppClientId": "xx",
- "AppClientSecret": "xx"]]
+ "AppClientSecret": "xx"
+ ]
]
- ])
- let amplifyConfig = AmplifyConfiguration(auth: categoryConfig)
- try Amplify.configure(amplifyConfig)
- let internalPlugin = try Amplify.Auth.getPlugin(
- for: "awsCognitoAuthPlugin"
- ) as! AWSCognitoAuthPlugin
- let service = internalPlugin.getEscapeHatch()
- switch service {
- case .userPool:
- break
- case .identityPool:
- XCTFail("Should return userPool")
- case .userPoolAndIdentityPool:
- XCTFail("Should return userPool")
+ ]
+ let plugin = AWSCognitoAuthPlugin()
+ try plugin.configure(using: configuration)
+ let escapeHatch = plugin.getEscapeHatch()
+ guard case .userPool = escapeHatch else {
+ XCTFail("Expected .userPool, got \(escapeHatch)")
+ return
+ }
+ }
+
+ /// Test escape hatch without a valid configuration
+ ///
+ /// - Given: A AWSCognitoAuthPlugin plugin without being configured
+ /// - When:
+ /// - I call getEscapeHatch
+ /// - Then:
+ /// - A fatalError is thrown
+ ///
+ func testEscapeHatchWithoutConfiguration() throws {
+ let plugin = AWSCognitoAuthPlugin()
+ try XCTAssertThrowFatalError {
+ _ = plugin.getEscapeHatch()
}
}
-
}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift
index fdb8862284..43db976492 100644
--- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift
@@ -26,8 +26,7 @@ class AWSAuthCognitoSessionTests: XCTestCase {
let error = AuthError.unknown("", nil)
let tokens = AWSCognitoUserPoolTokens(idToken: CognitoAuthTestHelper.buildToken(for: tokenData),
accessToken: CognitoAuthTestHelper.buildToken(for: tokenData),
- refreshToken: "refreshToken",
- expiresIn: 121)
+ refreshToken: "refreshToken")
let session = AWSAuthCognitoSession(isSignedIn: true,
identityIdResult: .failure(error),
@@ -53,8 +52,7 @@ class AWSAuthCognitoSessionTests: XCTestCase {
let error = AuthError.unknown("", nil)
let tokens = AWSCognitoUserPoolTokens(idToken: CognitoAuthTestHelper.buildToken(for: tokenData),
accessToken: CognitoAuthTestHelper.buildToken(for: tokenData),
- refreshToken: "refreshToken",
- expiresIn: 121)
+ refreshToken: "refreshToken")
let session = AWSAuthCognitoSession(isSignedIn: true,
identityIdResult: .failure(error),
@@ -65,4 +63,156 @@ class AWSAuthCognitoSessionTests: XCTestCase {
XCTAssertFalse(cognitoTokens.doesExpire())
}
+ /// Given: An AWSAuthCognitoSession with a valid AWSCognitoUserPoolTokens
+ /// When: getUserSub is invoked
+ /// Then: The "sub" from the token data should be returned
+ func testGetUserSub_shouldReturnResult() {
+ let tokenData = [
+ "sub": "1234567890",
+ "name": "John Doe",
+ "iat": "1516239022",
+ "exp": String(Date(timeIntervalSinceNow: 121).timeIntervalSince1970)
+ ]
+
+ let error = AuthError.unknown("", nil)
+ let tokens = AWSCognitoUserPoolTokens(
+ idToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ accessToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ refreshToken: "refreshToken"
+ )
+
+ let session = AWSAuthCognitoSession(
+ isSignedIn: true,
+ identityIdResult: .failure(error),
+ awsCredentialsResult: .failure(error),
+ cognitoTokensResult: .success(tokens)
+ )
+
+ guard case .success(let userSub) = session.getUserSub() else {
+ XCTFail("Unable to retrieve userSub")
+ return
+ }
+ XCTAssertEqual(userSub, "1234567890")
+ }
+
+ /// Given: An AWSAuthCognitoSession with a AWSCognitoUserPoolTokens that does not include a "sub" attribute
+ /// When: getUserSub is invoked
+ /// Then: A .failure with AuthError.unknown error is returned
+ func testGetUserSub_withoutSub_shouldReturnError() {
+ let tokenData = [
+ "name": "John Doe",
+ "iat": "1516239022",
+ "exp": String(Date(timeIntervalSinceNow: 121).timeIntervalSince1970)
+ ]
+
+ let error = AuthError.unknown("", nil)
+ let tokens = AWSCognitoUserPoolTokens(
+ idToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ accessToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ refreshToken: "refreshToken"
+ )
+
+ let session = AWSAuthCognitoSession(
+ isSignedIn: true,
+ identityIdResult: .failure(error),
+ awsCredentialsResult: .failure(error),
+ cognitoTokensResult: .success(tokens)
+ )
+
+ guard case .failure(let error) = session.getUserSub(),
+ case .unknown(let errorDescription, _) = error else {
+ XCTFail("Expected AuthError.unknown")
+ return
+ }
+
+ XCTAssertEqual(errorDescription, "Could not retreive user sub from the fetched Cognito tokens.")
+ }
+
+ /// Given: An AWSAuthCognitoSession that is signed out
+ /// When: getUserSub is invoked
+ /// Then: A .failure with AuthError.signedOut error is returned
+ func testGetUserSub_signedOut_shouldReturnError() {
+ let error = AuthError.signedOut("", "", nil)
+ let session = AWSAuthCognitoSession(
+ isSignedIn: false,
+ identityIdResult: .failure(error),
+ awsCredentialsResult: .failure(error),
+ cognitoTokensResult: .failure(error)
+ )
+
+ guard case .failure(let error) = session.getUserSub(),
+ case .signedOut(let errorDescription, let recoverySuggestion, _) = error else {
+ XCTFail("Expected AuthError.signedOut")
+ return
+ }
+
+ XCTAssertEqual(errorDescription, AuthPluginErrorConstants.userSubSignOutError.errorDescription)
+ XCTAssertEqual(recoverySuggestion, AuthPluginErrorConstants.userSubSignOutError.recoverySuggestion)
+ }
+
+ /// Given: An AWSAuthCognitoSession that has a service error
+ /// When: getUserSub is invoked
+ /// Then: A .failure with AuthError.signedOut error is returned
+ func testGetUserSub_serviceError_shouldReturnError() {
+ let serviceError = AuthError.service("Something went wrong", "Try again", nil)
+ let session = AWSAuthCognitoSession(
+ isSignedIn: false,
+ identityIdResult: .failure(serviceError),
+ awsCredentialsResult: .failure(serviceError),
+ cognitoTokensResult: .failure(serviceError)
+ )
+
+ guard case .failure(let error) = session.getUserSub() else {
+ XCTFail("Expected AuthError.signedOut")
+ return
+ }
+
+ XCTAssertEqual(error, serviceError)
+ }
+
+ /// Given: An AuthAWSCognitoCredentials and an AWSCognitoUserPoolTokens instance
+ /// When: Two AWSAuthCognitoSession are created from the same values
+ /// Then: The two AWSAuthCognitoSession are considered equal
+ func testSessionsAreEqual() {
+ let expiration = Date(timeIntervalSinceNow: 121)
+ let tokenData = [
+ "sub": "1234567890",
+ "name": "John Doe",
+ "iat": "1516239022",
+ "exp": String(expiration.timeIntervalSince1970)
+ ]
+
+ let credentials = AuthAWSCognitoCredentials(
+ accessKeyId: "accessKeyId",
+ secretAccessKey: "secretAccessKey",
+ sessionToken: "sessionToken",
+ expiration: expiration
+ )
+
+ let tokens = AWSCognitoUserPoolTokens(
+ idToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ accessToken: CognitoAuthTestHelper.buildToken(for: tokenData),
+ refreshToken: "refreshToken"
+ )
+
+ let session1 = AWSAuthCognitoSession(
+ isSignedIn: true,
+ identityIdResult: .success("identityId"),
+ awsCredentialsResult: .success(credentials),
+ cognitoTokensResult: .success(tokens)
+ )
+
+ let session2 = AWSAuthCognitoSession(
+ isSignedIn: true,
+ identityIdResult: .success("identityId"),
+ awsCredentialsResult: .success(credentials),
+ cognitoTokensResult: .success(tokens)
+ )
+
+ XCTAssertEqual(session1, session2)
+ XCTAssertEqual(session1.debugDictionary.count, session2.debugDictionary.count)
+ for key in session1.debugDictionary.keys where (key != "AWS Credentials" && key != "cognitoTokens") {
+ XCTAssertEqual(session1.debugDictionary[key] as? String, session2.debugDictionary[key] as? String)
+ }
+ }
}
diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift
new file mode 100644
index 0000000000..3909827f36
--- /dev/null
+++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift
@@ -0,0 +1,248 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if os(iOS) || os(macOS)
+import Amplify
+import AuthenticationServices
+@testable import AWSCognitoAuthPlugin
+import XCTest
+
+class HostedUIASWebAuthenticationSessionTests: XCTestCase {
+ private var session: HostedUIASWebAuthenticationSession!
+ private var factory: ASWebAuthenticationSessionFactory!
+
+ override func setUp() {
+ session = HostedUIASWebAuthenticationSession()
+ factory = ASWebAuthenticationSessionFactory()
+ session.authenticationSessionFactory = factory.createSession(url:callbackURLScheme:completionHandler:)
+ }
+
+ override func tearDown() {
+ session = nil
+ factory = nil
+ }
+
+ /// Given: A HostedUIASWebAuthenticationSession
+ /// When: showHostedUI is invoked and the session factory returns a URL with query items
+ /// Then: An array of query items should be returned
+ func testShowHostedUI_withUrlInCallback_withQueryItems_shouldReturnQueryItems() {
+ let expectation = expectation(description: "showHostedUI")
+ factory.mockedURL = createURL(queryItems: [.init(name: "name", value: "value")])
+
+ session.showHostedUI() { result in
+ do {
+ let queryItems = try result.get()
+ XCTAssertEqual(queryItems.count, 1)
+ XCTAssertEqual(queryItems.first?.name, "name")
+ XCTAssertEqual(queryItems.first?.value, "value")
+ } catch {
+ XCTFail("Expected .success(queryItems), got \(result)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A HostedUIASWebAuthenticationSession
+ /// When: showHostedUI is invoked and the session factory returns a URL without query items
+ /// Then: An empty array should be returned
+ func testShowHostedUI_withUrlInCallback_withoutQueryItems_shouldReturnEmptyQueryItems() {
+ let expectation = expectation(description: "showHostedUI")
+ factory.mockedURL = createURL()
+
+ session.showHostedUI() { result in
+ do {
+ let queryItems = try result.get()
+ XCTAssertTrue(queryItems.isEmpty)
+ } catch {
+ XCTFail("Expected .success(queryItems), got \(result)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A HostedUIASWebAuthenticationSession
+ /// When: showHostedUI is invoked and the session factory returns a URL with query items representing errors
+ /// Then: A HostedUIError.serviceMessage should be returned
+ func testShowHostedUI_withUrlInCallback_withErrorInQueryItems_shouldReturnServiceMessageError() {
+ let expectation = expectation(description: "showHostedUI")
+ factory.mockedURL = createURL(
+ queryItems: [
+ .init(name: "error", value: "Error."),
+ .init(name: "error_description", value: "Something went wrong")
+ ]
+ )
+
+ session.showHostedUI() { result in
+ do {
+ _ = try result.get()
+ XCTFail("Expected failure(.serviceMessage), got \(result)")
+ } catch let error as HostedUIError {
+ if case .serviceMessage(let message) = error {
+ XCTAssertEqual(message, "Error. Something went wrong")
+ } else {
+ XCTFail("Expected HostedUIError.serviceMessage, got \(error)")
+ }
+ } catch {
+ XCTFail("Expected HostedUIError.serviceMessage, got \(error)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A HostedUIASWebAuthenticationSession
+ /// When: showHostedUI is invoked and the session factory returns ASWebAuthenticationSessionErrors
+ /// Then: A HostedUIError corresponding to the error code should be returned
+ func testShowHostedUI_withASWebAuthenticationSessionErrors_shouldReturnRightError() {
+ let errorMap: [ASWebAuthenticationSessionError.Code: HostedUIError] = [
+ .canceledLogin: .cancelled,
+ .presentationContextNotProvided: .invalidContext,
+ .presentationContextInvalid: .invalidContext
+ ]
+
+ let errorCodes: [ASWebAuthenticationSessionError.Code] = [
+ .canceledLogin,
+ .presentationContextNotProvided,
+ .presentationContextInvalid,
+ .init(rawValue: 500)!
+ ]
+
+ for code in errorCodes {
+ factory.mockedError = ASWebAuthenticationSessionError(code)
+ let expectedError = errorMap[code] ?? .unknown
+ let expectation = expectation(description: "showHostedUI for error \(code)")
+ session.showHostedUI() { result in
+ do {
+ _ = try result.get()
+ XCTFail("Expected failure(.\(expectedError)), got \(result)")
+ } catch let error as HostedUIError {
+ XCTAssertEqual(error, expectedError)
+ } catch {
+ XCTFail("Expected HostedUIError.\(expectedError), got \(error)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+ }
+
+ /// Given: A HostedUIASWebAuthenticationSession
+ /// When: showHostedUI is invoked and the session factory returns an error
+ /// Then: A HostedUIError.unknown should be returned
+ func testShowHostedUI_withOtherError_shouldReturnUnknownError() {
+ factory.mockedError = CancellationError()
+ let expectation = expectation(description: "showHostedUI")
+ session.showHostedUI() { result in
+ do {
+ _ = try result.get()
+ XCTFail("Expected failure(.unknown), got \(result)")
+ } catch let error as HostedUIError {
+ XCTAssertEqual(error, .unknown)
+ } catch {
+ XCTFail("Expected HostedUIError.unknown, got \(error)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ private func createURL(queryItems: [URLQueryItem] = []) -> URL {
+ var components = URLComponents(string: "https://test.com")!
+ components.queryItems = queryItems
+ return components.url!
+ }
+}
+
+class ASWebAuthenticationSessionFactory {
+ var mockedURL: URL?
+ var mockedError: Error?
+
+ func createSession(
+ url URL: URL,
+ callbackURLScheme: String?,
+ completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler
+ ) -> ASWebAuthenticationSession {
+ let session = MockASWebAuthenticationSession(
+ url: URL,
+ callbackURLScheme: callbackURLScheme,
+ completionHandler: completionHandler
+ )
+ session.mockedURL = mockedURL
+ session.mockedError = mockedError
+ return session
+ }
+}
+
+class MockASWebAuthenticationSession: ASWebAuthenticationSession {
+ private var callback: ASWebAuthenticationSession.CompletionHandler
+ override init(
+ url URL: URL,
+ callbackURLScheme: String?,
+ completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler
+ ) {
+ self.callback = completionHandler
+ super.init(
+ url: URL,
+ callbackURLScheme: callbackURLScheme,
+ completionHandler: completionHandler
+ )
+ }
+
+ var mockedURL: URL? = nil
+ var mockedError: Error? = nil
+ override func start() -> Bool {
+ callback(mockedURL, mockedError)
+ return presentationContextProvider?.presentationAnchor(for: self) != nil
+ }
+}
+
+extension HostedUIASWebAuthenticationSession {
+ func showHostedUI(callback: @escaping (Result<[URLQueryItem], HostedUIError>) -> Void) {
+ showHostedUI(
+ url: URL(string: "https://test.com")!,
+ callbackScheme: "https",
+ inPrivate: false,
+ presentationAnchor: nil,
+ callback: callback)
+ }
+}
+#else
+
+@testable import AWSCognitoAuthPlugin
+import XCTest
+
+class HostedUIASWebAuthenticationSessionTests: XCTestCase {
+ func testShowHostedUI_shouldThrowServiceError() {
+ let expectation = expectation(description: "showHostedUI")
+ let session = HostedUIASWebAuthenticationSession()
+ session.showHostedUI(
+ url: URL(string: "https://test.com")!,
+ callbackScheme: "https",
+ inPrivate: false,
+ presentationAnchor: nil
+ ) { result in
+ do {
+ _ = try result.get()
+ XCTFail("Expected failure(.serviceMessage), got \(result)")
+ } catch let error as HostedUIError {
+ if case .serviceMessage(let message) = error {
+ XCTAssertEqual(message, "HostedUI is only available in iOS and macOS")
+ } else {
+ XCTFail("Expected HostedUIError.serviceMessage, got \(error)")
+ }
+ } catch {
+ XCTFail("Expected HostedUIError.serviceMessage, got \(error)")
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+}
+
+#endif
diff --git a/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift b/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift
new file mode 100644
index 0000000000..14963a0cc1
--- /dev/null
+++ b/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift
@@ -0,0 +1,105 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+import AWSClientRuntime
+import AwsCommonRuntimeKit
+import AWSPinpoint
+@testable import AWSPinpointPushNotificationsPlugin
+import ClientRuntime
+import Foundation
+import XCTest
+
+class ErrorPushNotificationsTests: XCTestCase {
+ /// Given: A NSError error
+ /// When: pushNotificationsError is invoked
+ /// Then: An .unknown error is returned
+ func testPushNotificationsError_withUnknownError_shouldReturnUnknownError() {
+ let error = NSError(domain: "MyError", code: 1234)
+ let pushNotificationsError = error.pushNotificationsError
+ switch pushNotificationsError {
+ case .unknown(let errorDescription, let underlyingError):
+ XCTAssertEqual(errorDescription, "An unknown error occurred")
+ XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription)
+ default:
+ XCTFail("Expected error of type .unknown, got \(pushNotificationsError)")
+ }
+ }
+
+ /// Given: A NSError error with a connectivity-related error code
+ /// When: pushNotificationsError is invoked
+ /// Then: A .network error is returned
+ func testPushNotificationsError_withConnectivityError_shouldReturnNetworkError() {
+ let error = NSError(domain: "ConnectivityError", code: NSURLErrorNotConnectedToInternet)
+ let pushNotificationsError = error.pushNotificationsError
+ switch pushNotificationsError {
+ case .network(let errorDescription, let recoverySuggestion, let underlyingError):
+ XCTAssertEqual(errorDescription, PushNotificationsPluginErrorConstants.deviceOffline.errorDescription)
+ XCTAssertEqual(recoverySuggestion, PushNotificationsPluginErrorConstants.deviceOffline.recoverySuggestion)
+ XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription)
+ default:
+ XCTFail("Expected error of type .network, got \(pushNotificationsError)")
+ }
+ }
+
+ /// Given: An Error defined by the SDK
+ /// When: pushNotificationsError is invoked
+ /// Then: A .service error is returned
+ func testPushNotificationError_withServiceError_shouldReturnServiceError() {
+ let errors: [(String, PushNotificationsErrorConvertible & Error)] = [
+ ("BadRequestException", BadRequestException(message: "BadRequestException")),
+ ("InternalServerErrorException", InternalServerErrorException(message: "InternalServerErrorException")),
+ ("ForbiddenException", ForbiddenException(message: "ForbiddenException")),
+ ("MethodNotAllowedException", MethodNotAllowedException(message: "MethodNotAllowedException")),
+ ("NotFoundException", NotFoundException(message: "NotFoundException")),
+ ("PayloadTooLargeException", PayloadTooLargeException(message: "PayloadTooLargeException")),
+ ("TooManyRequestsException", TooManyRequestsException(message: "TooManyRequestsException"))
+ ]
+
+ for (expectedMessage, error) in errors {
+ let pushNotificationsError = error.pushNotificationsError
+ switch pushNotificationsError {
+ case .service(let errorDescription, let recoverySuggestion, let underlyingError):
+ XCTAssertEqual(errorDescription, expectedMessage)
+ XCTAssertEqual(recoverySuggestion, PushNotificationsPluginErrorConstants.nonRetryableServiceError.recoverySuggestion)
+ XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription)
+ default:
+ XCTFail("Expected error of type .service, got \(pushNotificationsError)")
+ }
+ }
+ }
+
+ /// Given: An UnknownAWSHTTPServiceError
+ /// When: pushNotificationsError is invoked
+ /// Then: A .unknown error is returned
+ func testPushNotificationError_withUnknownAWSHTTPServiceError_shouldReturnUnknownError() {
+ let error = UnknownAWSHTTPServiceError(httpResponse: .init(body: .none, statusCode: .accepted), message: "UnknownAWSHTTPServiceError", requestID: nil, typeName: nil)
+ let pushNotificationsError = error.pushNotificationsError
+ switch pushNotificationsError {
+ case .unknown(let errorDescription, let underlyingError):
+ XCTAssertEqual(errorDescription, "UnknownAWSHTTPServiceError")
+ XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription)
+ default:
+ XCTFail("Expected error of type .unknown, got \(pushNotificationsError)")
+ }
+ }
+
+ /// Given: A CommonRunTimeError.crtError
+ /// When: pushNotificationsError is invoked
+ /// Then: A .unknown error is returned
+ func testPushNotificationError_withCommonRunTimeError_shouldReturnUnknownError() {
+ let error = CommonRunTimeError.crtError(.init(code: 12345))
+ let pushNotificationsError = error.pushNotificationsError
+ switch pushNotificationsError {
+ case .unknown(let errorDescription, let underlyingError):
+ XCTAssertEqual(errorDescription, "Unknown Error Code")
+ XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription)
+ default:
+ XCTFail("Expected error of type .unknown, got \(pushNotificationsError)")
+ }
+ }
+}
diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift
index 969fd6475f..5fd65c6ec0 100644
--- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift
+++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift
@@ -18,10 +18,10 @@ import AWSClientRuntime
/// and allows for mocking in unit tests. The methods contain no other logic other than calling the
/// same method using the AWSS3 instance.
class AWSS3Adapter: AWSS3Behavior {
- let awsS3: S3Client
+ let awsS3: S3ClientProtocol
let config: S3Client.S3ClientConfiguration
- init(_ awsS3: S3Client, config: S3Client.S3ClientConfiguration) {
+ init(_ awsS3: S3ClientProtocol, config: S3Client.S3ClientConfiguration) {
self.awsS3 = awsS3
self.config = config
}
@@ -161,7 +161,7 @@ class AWSS3Adapter: AWSS3Behavior {
/// Instance of S3 service.
/// - Returns: S3 service instance.
- func getS3() -> S3Client {
+ func getS3() -> S3ClientProtocol {
return awsS3
}
}
diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift
index 7319878805..400ca3eb6c 100644
--- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift
+++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift
@@ -35,7 +35,7 @@ protocol AWSS3Behavior {
func abortMultipartUpload(_ request: AWSS3AbortMultipartUploadRequest, completion: @escaping (Result) -> Void)
// Gets a client for AWS S3 Service.
- func getS3() -> S3Client
+ func getS3() -> S3ClientProtocol
}
diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift
index c26fbc2c79..d31b9e588e 100644
--- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift
+++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift
@@ -54,6 +54,7 @@ class AWSS3StorageService: AWSS3StorageServiceBehavior, StorageServiceProxy {
httpClientEngineProxy: HttpClientEngineProxy? = nil,
storageConfiguration: StorageConfiguration = .default,
storageTransferDatabase: StorageTransferDatabase = .default,
+ fileSystem: FileSystem = .default,
sessionConfiguration: URLSessionConfiguration? = nil,
delegateQueue: OperationQueue? = nil,
logger: Logger = storageLogger) throws {
@@ -97,7 +98,9 @@ class AWSS3StorageService: AWSS3StorageServiceBehavior, StorageServiceProxy {
self.init(authService: authService,
storageConfiguration: storageConfiguration,
storageTransferDatabase: storageTransferDatabase,
+ fileSystem: fileSystem,
sessionConfiguration: _sessionConfiguration,
+ logger: logger,
s3Client: s3Client,
preSignedURLBuilder: preSignedURLBuilder,
awsS3: awsS3,
diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift
index d335f31805..fc51016cb9 100644
--- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift
+++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift
@@ -44,14 +44,6 @@ class StorageMultipartUploadSession {
private let transferTask: StorageTransferTask
- private var contentType: String? {
- transferTask.contentType
- }
-
- private var requestHeaders: RequestHeaders? {
- transferTask.requestHeaders
- }
-
init(client: StorageMultipartUploadClient,
bucket: String,
key: String,
diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift
index 1cd7e95385..1ecf8894db 100644
--- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift
+++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift
@@ -173,10 +173,6 @@ class StorageTransferTask {
}
}
- private var cancelled: Bool {
- status == .cancelled
- }
-
var isFailed: Bool {
status == .error
}
@@ -324,7 +320,7 @@ class StorageTransferTask {
logger.warn("Unable to complete after cancelled")
return
}
- guard _status == .completed else {
+ guard _status != .completed else {
logger.warn("Task is already completed")
return
}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift
new file mode 100644
index 0000000000..4cf455e494
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift
@@ -0,0 +1,754 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+@testable import AWSS3StoragePlugin
+import AWSS3
+import XCTest
+
+class AWSS3AdapterTests: XCTestCase {
+ private var adapter: AWSS3Adapter!
+ private var awsS3: S3ClientMock!
+
+ override func setUp() {
+ awsS3 = S3ClientMock()
+ adapter = AWSS3Adapter(
+ awsS3,
+ config: try! S3Client.S3ClientConfiguration(
+ region: "us-east-1"
+ )
+ )
+ }
+
+ override func tearDown() {
+ adapter = nil
+ awsS3 = nil
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: deleteObject is invoked and the s3 client returns success
+ /// Then: A .success result is returned
+ func testDeleteObject_withSuccess_shouldSucceed() {
+ let deleteExpectation = expectation(description: "Delete Object")
+ adapter.deleteObject(.init(bucket: "bucket", key: "key")) { result in
+ XCTAssertEqual(self.awsS3.deleteObjectCount, 1)
+ guard case .success = result else {
+ XCTFail("Expected success")
+ return
+ }
+ deleteExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: deleteObject is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testDeleteObject_withError_shouldFail() {
+ let deleteExpectation = expectation(description: "Delete Object")
+ awsS3.deleteObjectResult = .failure(StorageError.keyNotFound("InvalidKey", "", "", nil))
+ adapter.deleteObject(.init(bucket: "bucket", key: "key")) { result in
+ XCTAssertEqual(self.awsS3.deleteObjectCount, 1)
+ guard case .failure(let error) = result,
+ case .keyNotFound(let key, _, _, _) = error else {
+ XCTFail("Expected StorageError.keyNotFound")
+ return
+ }
+ XCTAssertEqual(key, "InvalidKey")
+ deleteExpectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: listObjectsV2 is invoked and the s3 client returns a list of objects
+ /// Then: A .success result is returned containing the corresponding list items
+ func testListObjectsV2_withSuccess_shouldSucceed() {
+ let listExpectation = expectation(description: "List Objects")
+ awsS3.listObjectsV2Result = .success(ListObjectsV2OutputResponse(
+ contents: [
+ .init(eTag: "one", key: "prefix/key1", lastModified: .init()),
+ .init(eTag: "two", key: "prefix/key2", lastModified: .init())
+ ]
+ ))
+ adapter.listObjectsV2(.init(
+ bucket: "bucket",
+ prefix: "prefix/"
+ )) { result in
+ XCTAssertEqual(self.awsS3.listObjectsV2Count, 1)
+ guard case .success(let response) = result else {
+ XCTFail("Expected success")
+ return
+ }
+ XCTAssertEqual(response.items.count, 2)
+ XCTAssertTrue(response.items.contains(where: { $0.key == "key1" && $0.eTag == "one" }))
+ XCTAssertTrue(response.items.contains(where: { $0.key == "key2" && $0.eTag == "two" }))
+ listExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: listObjectsV2 is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testListObjectsV2_withError_shouldFail() {
+ let listExpectation = expectation(description: "List Objects")
+ awsS3.listObjectsV2Result = .failure(StorageError.accessDenied("AccessDenied", "", nil))
+ adapter.listObjectsV2(.init(
+ bucket: "bucket",
+ prefix: "prefix"
+ )) { result in
+ XCTAssertEqual(self.awsS3.listObjectsV2Count, 1)
+ guard case .failure(let error) = result,
+ case .accessDenied(let description, _, _) = error else {
+ XCTFail("Expected StorageError.accessDenied")
+ return
+ }
+ XCTAssertEqual(description, "AccessDenied")
+ listExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: createMultipartUpload is invoked and the s3 client returns a valid response
+ /// Then: A .success result is returned containing the corresponding parsed response
+ func testCreateMultipartUpload_withSuccess_shouldSucceed() {
+ let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload")
+ awsS3.createMultipartUploadResult = .success(.init(
+ bucket: "bucket",
+ key: "key",
+ uploadId: "uploadId"
+ ))
+ adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in
+ XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1)
+ guard case .success(let response) = result else {
+ XCTFail("Expected success")
+ return
+ }
+ XCTAssertEqual(response.bucket, "bucket")
+ XCTAssertEqual(response.key, "key")
+ XCTAssertEqual(response.uploadId, "uploadId")
+ createMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: createMultipartUpload is invoked and the s3 client returns an invalid response
+ /// Then: A .failure result is returned with an .uknown error
+ func testCreateMultipartUpload_withWrongResponse_shouldFail() {
+ let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload")
+ adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in
+ XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1)
+ guard case .failure(let error) = result,
+ case .unknown(let description, _) = error else {
+ XCTFail("Expected StorageError.unknown")
+ return
+ }
+ XCTAssertEqual(description, "Invalid response for creating multipart upload")
+ createMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: createMultipartUpload is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testCreateMultipartUpload_withError_shouldFail() {
+ let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload")
+ awsS3.createMultipartUploadResult = .failure(StorageError.accessDenied("AccessDenied", "", nil))
+ adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in
+ XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1)
+ guard case .failure(let error) = result,
+ case .accessDenied(let description, _, _) = error else {
+ XCTFail("Expected StorageError.accessDenied")
+ return
+ }
+ XCTAssertEqual(description, "AccessDenied")
+ createMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: listParts is invoked and the s3 client returns a valid response
+ /// Then: A .success result is returned containing the corresponding parsed response
+ func testListParts_withSuccess_shouldSucceed() {
+ let listPartsExpectation = expectation(description: "List Parts")
+ awsS3.listPartsResult = .success(.init(
+ bucket: "bucket",
+ key: "key",
+ parts: [
+ .init(eTag: "eTag1", partNumber: 1),
+ .init(eTag: "eTag2", partNumber: 2)
+ ],
+ uploadId: "uploadId"
+ ))
+ adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in
+ XCTAssertEqual(self.awsS3.listPartsCount, 1)
+ guard case .success(let response) = result else {
+ XCTFail("Expected success")
+ return
+ }
+ XCTAssertEqual(response.bucket, "bucket")
+ XCTAssertEqual(response.key, "key")
+ XCTAssertEqual(response.uploadId, "uploadId")
+ XCTAssertEqual(response.parts.count, 2)
+ listPartsExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: listParts is invoked and the s3 client returns an invalid response
+ /// Then: A .failure result is returned with an .unknown error
+ func testListParts_withWrongResponse_shouldFail() {
+ let listPartsExpectation = expectation(description: "List Parts")
+ adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in
+ XCTAssertEqual(self.awsS3.listPartsCount, 1)
+ guard case .failure(let error) = result,
+ case .unknown(let description, _) = error else {
+ XCTFail("Expected StorageError.unknown")
+ return
+ }
+ XCTAssertEqual(description, "ListParts response is invalid")
+ listPartsExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: listParts is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testListParts_withError_shouldFail() {
+ let listPartsExpectation = expectation(description: "List Parts")
+ awsS3.listPartsResult = .failure(StorageError.authError("AuthError", "", nil))
+ adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in
+ XCTAssertEqual(self.awsS3.listPartsCount, 1)
+ guard case .failure(let error) = result,
+ case .authError(let description, _, _) = error else {
+ XCTFail("Expected StorageError.authError")
+ return
+ }
+ XCTAssertEqual(description, "AuthError")
+ listPartsExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: completeMultipartUpload is invoked and the s3 client returns a valid response
+ /// Then: A .success result is returned containing the corresponding parsed response
+ func testCompleteMultipartUpload_withSuccess_shouldSucceed() {
+ let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload")
+ awsS3.completeMultipartUploadResult = .success(.init(
+ eTag: "eTag"
+ ))
+ adapter.completeMultipartUpload(.init(
+ bucket: "bucket",
+ key: "key",
+ uploadId: "uploadId",
+ parts: [.init(partNumber: 1, eTag: "eTag1"), .init(partNumber: 2, eTag: "eTag2")]
+ )) { result in
+ XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1)
+ guard case .success(let response) = result else {
+ XCTFail("Expected success")
+ return
+ }
+ XCTAssertEqual(response.bucket, "bucket")
+ XCTAssertEqual(response.key, "key")
+ XCTAssertEqual(response.eTag, "eTag")
+ completeMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: completeMultipartUpload is invoked and the s3 client returns an invalid response
+ /// Then: A .failure result is returned with .unknown error
+ func testCompleteMultipartUpload_withWrongResponse_shouldFail() {
+ let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload")
+ adapter.completeMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId", parts: [])) { result in
+ XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1)
+ guard case .failure(let error) = result,
+ case .unknown(let description, _) = error else {
+ XCTFail("Expected StorageError.unknown")
+ return
+ }
+ XCTAssertEqual(description, "Invalid response for completing multipart upload")
+ completeMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: completeMultipartUpload is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testCompleteMultipartUpload_withError_shouldFail() {
+ let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload")
+ awsS3.completeMultipartUploadResult = .failure(StorageError.authError("AuthError", "", nil))
+ adapter.completeMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId", parts: [])) { result in
+ XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1)
+ guard case .failure(let error) = result,
+ case .authError(let description, _, _) = error else {
+ XCTFail("Expected StorageError.authError")
+ return
+ }
+ XCTAssertEqual(description, "AuthError")
+ completeMultipartUploadExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: abortMultipartUpload is invoked and the s3 client returns a valid response
+ /// Then: A .success result is returned
+ func testAbortMultipartUpload_withSuccess_shouldSucceed() {
+ let abortExpectation = expectation(description: "Abort Multipart Upload")
+ adapter.abortMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId")) { result in
+ XCTAssertEqual(self.awsS3.abortMultipartUploadCount, 1)
+ guard case .success = result else {
+ XCTFail("Expected success")
+ return
+ }
+ abortExpectation.fulfill()
+ }
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: abortMultipartUpload is invoked and the s3 client returns an error
+ /// Then: A .failure result is returned
+ func testAbortMultipartUpload_withError_shouldFail() {
+ let abortExpectation = expectation(description: "Abort Multipart Upload")
+ awsS3.abortMultipartUploadResult = .failure(StorageError.keyNotFound("InvalidKey", "", "", nil))
+ adapter.abortMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId")) { result in
+ XCTAssertEqual(self.awsS3.abortMultipartUploadCount, 1)
+ guard case .failure(let error) = result,
+ case .keyNotFound(let key, _, _, _) = error else {
+ XCTFail("Expected StorageError.keyNotFound")
+ return
+ }
+ XCTAssertEqual(key, "InvalidKey")
+ abortExpectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3Adapter
+ /// When: getS3 is invoked
+ /// Then: The underlying S3ClientProtocol instance is returned
+ func testGetS3() {
+ XCTAssertTrue(adapter.getS3() is S3ClientMock)
+ }
+}
+
+private class S3ClientMock: S3ClientProtocol {
+ var deleteObjectCount = 0
+ var deleteObjectResult: Result = .success(.init())
+ func deleteObject(input: AWSS3.DeleteObjectInput) async throws -> AWSS3.DeleteObjectOutputResponse {
+ deleteObjectCount += 1
+ return try deleteObjectResult.get()
+ }
+
+ var listObjectsV2Count = 0
+ var listObjectsV2Result: Result = .success(.init())
+ func listObjectsV2(input: AWSS3.ListObjectsV2Input) async throws -> AWSS3.ListObjectsV2OutputResponse {
+ listObjectsV2Count += 1
+ return try listObjectsV2Result.get()
+ }
+
+ var createMultipartUploadCount = 0
+ var createMultipartUploadResult: Result = .success(.init())
+ func createMultipartUpload(input: AWSS3.CreateMultipartUploadInput) async throws -> AWSS3.CreateMultipartUploadOutputResponse {
+ createMultipartUploadCount += 1
+ return try createMultipartUploadResult.get()
+ }
+
+ var listPartsCount = 0
+ var listPartsResult: Result = .success(.init())
+ func listParts(input: AWSS3.ListPartsInput) async throws -> AWSS3.ListPartsOutputResponse {
+ listPartsCount += 1
+ return try listPartsResult.get()
+ }
+
+ var completeMultipartUploadCount = 0
+ var completeMultipartUploadResult: Result = .success(.init())
+ func completeMultipartUpload(input: AWSS3.CompleteMultipartUploadInput) async throws -> AWSS3.CompleteMultipartUploadOutputResponse {
+ completeMultipartUploadCount += 1
+ return try completeMultipartUploadResult.get()
+ }
+
+ var abortMultipartUploadCount = 0
+ var abortMultipartUploadResult: Result = .success(.init())
+ func abortMultipartUpload(input: AWSS3.AbortMultipartUploadInput) async throws -> AWSS3.AbortMultipartUploadOutputResponse {
+ abortMultipartUploadCount += 1
+ return try abortMultipartUploadResult.get()
+ }
+
+ func copyObject(input: AWSS3.CopyObjectInput) async throws -> AWSS3.CopyObjectOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func createBucket(input: AWSS3.CreateBucketInput) async throws -> AWSS3.CreateBucketOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucket(input: AWSS3.DeleteBucketInput) async throws -> AWSS3.DeleteBucketOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketAnalyticsConfiguration(input: AWSS3.DeleteBucketAnalyticsConfigurationInput) async throws -> AWSS3.DeleteBucketAnalyticsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketCors(input: AWSS3.DeleteBucketCorsInput) async throws -> AWSS3.DeleteBucketCorsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketEncryption(input: AWSS3.DeleteBucketEncryptionInput) async throws -> AWSS3.DeleteBucketEncryptionOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketIntelligentTieringConfiguration(input: AWSS3.DeleteBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.DeleteBucketIntelligentTieringConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketInventoryConfiguration(input: AWSS3.DeleteBucketInventoryConfigurationInput) async throws -> AWSS3.DeleteBucketInventoryConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketLifecycle(input: AWSS3.DeleteBucketLifecycleInput) async throws -> AWSS3.DeleteBucketLifecycleOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketMetricsConfiguration(input: AWSS3.DeleteBucketMetricsConfigurationInput) async throws -> AWSS3.DeleteBucketMetricsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketOwnershipControls(input: AWSS3.DeleteBucketOwnershipControlsInput) async throws -> AWSS3.DeleteBucketOwnershipControlsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketPolicy(input: AWSS3.DeleteBucketPolicyInput) async throws -> AWSS3.DeleteBucketPolicyOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketReplication(input: AWSS3.DeleteBucketReplicationInput) async throws -> AWSS3.DeleteBucketReplicationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketTagging(input: AWSS3.DeleteBucketTaggingInput) async throws -> AWSS3.DeleteBucketTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteBucketWebsite(input: AWSS3.DeleteBucketWebsiteInput) async throws -> AWSS3.DeleteBucketWebsiteOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteObjects(input: AWSS3.DeleteObjectsInput) async throws -> AWSS3.DeleteObjectsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deleteObjectTagging(input: AWSS3.DeleteObjectTaggingInput) async throws -> AWSS3.DeleteObjectTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func deletePublicAccessBlock(input: AWSS3.DeletePublicAccessBlockInput) async throws -> AWSS3.DeletePublicAccessBlockOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketAccelerateConfiguration(input: AWSS3.GetBucketAccelerateConfigurationInput) async throws -> AWSS3.GetBucketAccelerateConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketAcl(input: AWSS3.GetBucketAclInput) async throws -> AWSS3.GetBucketAclOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketAnalyticsConfiguration(input: AWSS3.GetBucketAnalyticsConfigurationInput) async throws -> AWSS3.GetBucketAnalyticsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketCors(input: AWSS3.GetBucketCorsInput) async throws -> AWSS3.GetBucketCorsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketEncryption(input: AWSS3.GetBucketEncryptionInput) async throws -> AWSS3.GetBucketEncryptionOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketIntelligentTieringConfiguration(input: AWSS3.GetBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.GetBucketIntelligentTieringConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketInventoryConfiguration(input: AWSS3.GetBucketInventoryConfigurationInput) async throws -> AWSS3.GetBucketInventoryConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketLifecycleConfiguration(input: AWSS3.GetBucketLifecycleConfigurationInput) async throws -> AWSS3.GetBucketLifecycleConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketLocation(input: AWSS3.GetBucketLocationInput) async throws -> AWSS3.GetBucketLocationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketLogging(input: AWSS3.GetBucketLoggingInput) async throws -> AWSS3.GetBucketLoggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketMetricsConfiguration(input: AWSS3.GetBucketMetricsConfigurationInput) async throws -> AWSS3.GetBucketMetricsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketNotificationConfiguration(input: AWSS3.GetBucketNotificationConfigurationInput) async throws -> AWSS3.GetBucketNotificationConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketOwnershipControls(input: AWSS3.GetBucketOwnershipControlsInput) async throws -> AWSS3.GetBucketOwnershipControlsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketPolicy(input: AWSS3.GetBucketPolicyInput) async throws -> AWSS3.GetBucketPolicyOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketPolicyStatus(input: AWSS3.GetBucketPolicyStatusInput) async throws -> AWSS3.GetBucketPolicyStatusOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketReplication(input: AWSS3.GetBucketReplicationInput) async throws -> AWSS3.GetBucketReplicationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketRequestPayment(input: AWSS3.GetBucketRequestPaymentInput) async throws -> AWSS3.GetBucketRequestPaymentOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketTagging(input: AWSS3.GetBucketTaggingInput) async throws -> AWSS3.GetBucketTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketVersioning(input: AWSS3.GetBucketVersioningInput) async throws -> AWSS3.GetBucketVersioningOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getBucketWebsite(input: AWSS3.GetBucketWebsiteInput) async throws -> AWSS3.GetBucketWebsiteOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObject(input: AWSS3.GetObjectInput) async throws -> AWSS3.GetObjectOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectAcl(input: AWSS3.GetObjectAclInput) async throws -> AWSS3.GetObjectAclOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectAttributes(input: AWSS3.GetObjectAttributesInput) async throws -> AWSS3.GetObjectAttributesOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectLegalHold(input: AWSS3.GetObjectLegalHoldInput) async throws -> AWSS3.GetObjectLegalHoldOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectLockConfiguration(input: AWSS3.GetObjectLockConfigurationInput) async throws -> AWSS3.GetObjectLockConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectRetention(input: AWSS3.GetObjectRetentionInput) async throws -> AWSS3.GetObjectRetentionOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectTagging(input: AWSS3.GetObjectTaggingInput) async throws -> AWSS3.GetObjectTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getObjectTorrent(input: AWSS3.GetObjectTorrentInput) async throws -> AWSS3.GetObjectTorrentOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func getPublicAccessBlock(input: AWSS3.GetPublicAccessBlockInput) async throws -> AWSS3.GetPublicAccessBlockOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func headBucket(input: AWSS3.HeadBucketInput) async throws -> AWSS3.HeadBucketOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func headObject(input: AWSS3.HeadObjectInput) async throws -> AWSS3.HeadObjectOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listBucketAnalyticsConfigurations(input: AWSS3.ListBucketAnalyticsConfigurationsInput) async throws -> AWSS3.ListBucketAnalyticsConfigurationsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listBucketIntelligentTieringConfigurations(input: AWSS3.ListBucketIntelligentTieringConfigurationsInput) async throws -> AWSS3.ListBucketIntelligentTieringConfigurationsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listBucketInventoryConfigurations(input: AWSS3.ListBucketInventoryConfigurationsInput) async throws -> AWSS3.ListBucketInventoryConfigurationsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listBucketMetricsConfigurations(input: AWSS3.ListBucketMetricsConfigurationsInput) async throws -> AWSS3.ListBucketMetricsConfigurationsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listBuckets(input: AWSS3.ListBucketsInput) async throws -> AWSS3.ListBucketsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listMultipartUploads(input: AWSS3.ListMultipartUploadsInput) async throws -> AWSS3.ListMultipartUploadsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listObjects(input: AWSS3.ListObjectsInput) async throws -> AWSS3.ListObjectsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func listObjectVersions(input: AWSS3.ListObjectVersionsInput) async throws -> AWSS3.ListObjectVersionsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketAccelerateConfiguration(input: AWSS3.PutBucketAccelerateConfigurationInput) async throws -> AWSS3.PutBucketAccelerateConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketAcl(input: AWSS3.PutBucketAclInput) async throws -> AWSS3.PutBucketAclOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketAnalyticsConfiguration(input: AWSS3.PutBucketAnalyticsConfigurationInput) async throws -> AWSS3.PutBucketAnalyticsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketCors(input: AWSS3.PutBucketCorsInput) async throws -> AWSS3.PutBucketCorsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketEncryption(input: AWSS3.PutBucketEncryptionInput) async throws -> AWSS3.PutBucketEncryptionOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketIntelligentTieringConfiguration(input: AWSS3.PutBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.PutBucketIntelligentTieringConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketInventoryConfiguration(input: AWSS3.PutBucketInventoryConfigurationInput) async throws -> AWSS3.PutBucketInventoryConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketLifecycleConfiguration(input: AWSS3.PutBucketLifecycleConfigurationInput) async throws -> AWSS3.PutBucketLifecycleConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketLogging(input: AWSS3.PutBucketLoggingInput) async throws -> AWSS3.PutBucketLoggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketMetricsConfiguration(input: AWSS3.PutBucketMetricsConfigurationInput) async throws -> AWSS3.PutBucketMetricsConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketNotificationConfiguration(input: AWSS3.PutBucketNotificationConfigurationInput) async throws -> AWSS3.PutBucketNotificationConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketOwnershipControls(input: AWSS3.PutBucketOwnershipControlsInput) async throws -> AWSS3.PutBucketOwnershipControlsOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketPolicy(input: AWSS3.PutBucketPolicyInput) async throws -> AWSS3.PutBucketPolicyOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketReplication(input: AWSS3.PutBucketReplicationInput) async throws -> AWSS3.PutBucketReplicationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketRequestPayment(input: AWSS3.PutBucketRequestPaymentInput) async throws -> AWSS3.PutBucketRequestPaymentOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketTagging(input: AWSS3.PutBucketTaggingInput) async throws -> AWSS3.PutBucketTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketVersioning(input: AWSS3.PutBucketVersioningInput) async throws -> AWSS3.PutBucketVersioningOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putBucketWebsite(input: AWSS3.PutBucketWebsiteInput) async throws -> AWSS3.PutBucketWebsiteOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObject(input: AWSS3.PutObjectInput) async throws -> AWSS3.PutObjectOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObjectAcl(input: AWSS3.PutObjectAclInput) async throws -> AWSS3.PutObjectAclOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObjectLegalHold(input: AWSS3.PutObjectLegalHoldInput) async throws -> AWSS3.PutObjectLegalHoldOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObjectLockConfiguration(input: AWSS3.PutObjectLockConfigurationInput) async throws -> AWSS3.PutObjectLockConfigurationOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObjectRetention(input: AWSS3.PutObjectRetentionInput) async throws -> AWSS3.PutObjectRetentionOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putObjectTagging(input: AWSS3.PutObjectTaggingInput) async throws -> AWSS3.PutObjectTaggingOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func putPublicAccessBlock(input: AWSS3.PutPublicAccessBlockInput) async throws -> AWSS3.PutPublicAccessBlockOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func restoreObject(input: AWSS3.RestoreObjectInput) async throws -> AWSS3.RestoreObjectOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func selectObjectContent(input: AWSS3.SelectObjectContentInput) async throws -> AWSS3.SelectObjectContentOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func uploadPart(input: AWSS3.UploadPartInput) async throws -> AWSS3.UploadPartOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func uploadPartCopy(input: AWSS3.UploadPartCopyInput) async throws -> AWSS3.UploadPartCopyOutputResponse {
+ fatalError("Not Implemented")
+ }
+
+ func writeGetObjectResponse(input: AWSS3.WriteGetObjectResponseInput) async throws -> AWSS3.WriteGetObjectResponseOutputResponse {
+ fatalError("Not Implemented")
+ }
+}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift
new file mode 100644
index 0000000000..25f5410d4a
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift
@@ -0,0 +1,455 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+@testable import AWSPluginsTestCommon
+@testable import AWSS3StoragePlugin
+import ClientRuntime
+import AWSS3
+import XCTest
+
+class AWSS3StorageServiceTests: XCTestCase {
+ private var service: AWSS3StorageService!
+ private var authService: MockAWSAuthService!
+ private var database: StorageTransferDatabaseMock!
+ private var task: StorageTransferTask!
+ private var fileSystem: MockFileSystem!
+
+ override func setUp() async throws {
+ authService = MockAWSAuthService()
+ database = StorageTransferDatabaseMock()
+ fileSystem = MockFileSystem()
+ task = StorageTransferTask(
+ transferType: .download(onEvent: { _ in}),
+ bucket: "bucket",
+ key: "key"
+ )
+ task.uploadId = "uploadId"
+ task.sessionTask = MockStorageSessionTask(taskIdentifier: 1)
+ database.recoverResult = .success([
+ .init(transferTask: task,
+ multipartUploads: [
+ .created(
+ uploadId: "uploadId",
+ uploadFile:UploadFile(
+ fileURL: FileSystem.default.createTemporaryFileURL(),
+ temporaryFileCreated: true,
+ size: UInt64(Bytes.megabytes(12).bytes)
+ )
+ )
+ ]
+ )
+ ])
+ service = try AWSS3StorageService(
+ authService: authService,
+ region: "region",
+ bucket: "bucket",
+ httpClientEngineProxy: MockHttpClientEngineProxy(),
+ storageTransferDatabase: database,
+ fileSystem: fileSystem,
+ logger: MockLogger()
+ )
+ }
+
+ override func tearDown() {
+ authService = nil
+ service = nil
+ database = nil
+ task = nil
+ fileSystem = nil
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: it's deallocated
+ /// Then: StorageBackgroundEventsRegistry.identifier should be set to nil
+ func testDeinit_shouldUnregisterIdentifier() {
+ XCTAssertNotNil(StorageBackgroundEventsRegistry.identifier)
+ service = nil
+ XCTAssertNil(StorageBackgroundEventsRegistry.identifier)
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: reset is invoked
+ /// Then: Its members should be set to nil
+ func testReset_shouldSetValuesToNil() {
+ service.reset()
+ XCTAssertNil(service.preSignedURLBuilder)
+ XCTAssertNil(service.awsS3)
+ XCTAssertNil(service.region)
+ XCTAssertNil(service.bucket)
+ XCTAssertTrue(service.tasks.isEmpty)
+ XCTAssertTrue(service.multipartUploadSessions.isEmpty)
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: attachEventHandlers is invoked and a .completed event is sent
+ /// Then: A .completed event is dispatched to the event handler
+ func testAttachEventHandlers() {
+ let expectation = self.expectation(description: "Attach Event Handlers")
+ service.attachEventHandlers(
+ onUpload: { event in
+ guard case .completed(_) = event else {
+ XCTFail("Expected completed")
+ return
+ }
+ expectation.fulfill()
+ }
+ )
+ XCTAssertNotNil(database.onUploadHandler)
+ database.onUploadHandler?(.completed(()))
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: register is invoked with a task
+ /// Then: The task should be added to its map of tasks
+ func testRegisterTask_shouldAddItToTasksDictionary() {
+ service.register(task: task)
+ XCTAssertEqual(service.tasks.count, 1)
+ XCTAssertNotNil(service.tasks[1])
+ }
+
+ /// Given: An AWSS3StorageService with a task in its map of tasks
+ /// When: unregister is invoked with said task
+ /// Then: The task should be removed from the map of tasks
+ func testUnregisterTask_shouldRemoveItToTasksDictionary() {
+ service.tasks = [
+ 1: task
+ ]
+ service.unregister(task: task)
+ XCTAssertTrue(service.tasks.isEmpty)
+ XCTAssertNil(service.tasks[1])
+ }
+
+ /// Given: An AWSS3StorageService with some tasks in its map of tasks
+ /// When: unregister is invoked with an identifier that is known to be mapped to a task.
+ /// Then: The task corresponding to the given identifier should be removed from the map of tasks
+ func testUnregisterTaskIdentifiers_shouldRemoveItToTasksDictionary() {
+ service.tasks = [
+ 1: task,
+ 2: task
+ ]
+ service.unregister(taskIdentifiers: [1])
+ XCTAssertEqual(service.tasks.count, 1)
+ XCTAssertNotNil(service.tasks[2])
+ XCTAssertNil(service.tasks[1])
+ }
+
+ /// Given: An AWSS3StorageService with a task in its map of tasks
+ /// When: findTask is invoked with the identifier known to be mapped to a task
+ /// Then: The task corresponding to the given identifier is returned
+ func testFindTask_shouldReturnTask() {
+ service.tasks = [
+ 1: task
+ ]
+ XCTAssertNotNil(service.findTask(taskIdentifier: 1))
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: validateParameters is invoked with an empty bucket parameter
+ /// Then: A .validation error is thrown
+ func testValidateParameters_withEmptyBucket_shouldThrowError() {
+ do {
+ try service.validateParameters(bucket: "", key: "key", accelerationModeEnabled: true)
+ XCTFail("Expected error")
+ } catch {
+ guard case .validation(let field, let description, let recovery, _) = error as? StorageError else {
+ XCTFail("Expected StorageError.validation")
+ return
+ }
+ XCTAssertEqual(field, "bucket")
+ XCTAssertEqual(description, "Invalid bucket specified.")
+ XCTAssertEqual(recovery, "Please specify a bucket name or configure the bucket property.")
+ }
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: validateParameters is invoked with an empty key parameter
+ /// Then: A .validation error is thrown
+ func testValidateParameters_withEmptyKey_shouldThrowError() {
+ do {
+ try service.validateParameters(bucket: "bucket", key: "", accelerationModeEnabled: true)
+ XCTFail("Expected error")
+ } catch {
+ guard case .validation(let field, let description, let recovery, _) = error as? StorageError else {
+ XCTFail("Expected StorageError.validation")
+ return
+ }
+ XCTAssertEqual(field, "key")
+ XCTAssertEqual(description, "Invalid key specified.")
+ XCTAssertEqual(recovery, "Please specify a key.")
+ }
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: validateParameters is invoked with valid bucket and key parameters
+ /// Then: No error is thrown
+ func testValidateParameters_withValidParams_shouldNotThrowError() {
+ do {
+ try service.validateParameters(bucket: "bucket", key: "key", accelerationModeEnabled: true)
+ } catch {
+ XCTFail("Expected success, got \(error)")
+ }
+ }
+
+ /// Given: An AWSS3StorageService
+ /// When: createTransferTask is invoked with valid parameters
+ /// Then: A task is returned with attributes matching the ones provided
+ func testCreateTransferTask_shouldReturnTask() {
+ let task = service.createTransferTask(
+ transferType: .upload(onEvent: { event in }),
+ bucket: "bucket",
+ key: "key",
+ requestHeaders: [
+ "header": "value"
+ ]
+ )
+ XCTAssertEqual(task.bucket, "bucket")
+ XCTAssertEqual(task.key, "key")
+ XCTAssertEqual(task.requestHeaders?.count, 1)
+ XCTAssertEqual(task.requestHeaders?["header"], "value")
+ guard case .upload(_) = task.transferType else {
+ XCTFail("Expected .upload transferType")
+ return
+ }
+ }
+
+ /// Given: An AWSS3StorageService with a non-completed download task
+ /// When: completeDownload is invoked for the identifier matching the task
+ /// Then: The task is marked as completed and a .completed event is dispatched
+ func testCompleteDownload_shouldReturnData() {
+ let expectation = self.expectation(description: "Complete Download")
+
+ let downloadTask = StorageTransferTask(
+ transferType: .download(onEvent: { event in
+ guard case .completed(let data) = event,
+ let data = data else {
+ XCTFail("Expected .completed event with data")
+ return
+ }
+ XCTAssertEqual(String(decoding: data, as: UTF8.self), "someFile")
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+
+ let sourceUrl = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).txt")
+ try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8)
+
+ service.tasks = [
+ 1: downloadTask
+ ]
+
+ service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl)
+ XCTAssertEqual(downloadTask.status, .completed)
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3StorageService with a non-completed download task that sets a location
+ /// When: completeDownload is invoked for the identifier matching the task
+ /// Then: The task is marked as completed and the file is moved to the expected location
+ func testCompleteDownload_withLocation_shouldMoveFileToLocation() {
+ let temporaryDirectory = FileManager.default.temporaryDirectory
+ let location = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-newFile.txt")
+
+ let downloadTask = StorageTransferTask(
+ transferType: .download(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key",
+ location: location
+ )
+
+ let sourceUrl = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-oldFile.txt")
+ try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8)
+
+ service.tasks = [
+ 1: downloadTask
+ ]
+
+ service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl)
+ XCTAssertTrue(FileManager.default.fileExists(atPath: location.path))
+ XCTAssertFalse(FileManager.default.fileExists(atPath: sourceUrl.path))
+ XCTAssertEqual(downloadTask.status, .completed)
+ }
+
+ /// Given: An AWSS3StorageService with a non-completed download task that sets a location
+ /// When: completeDownload is invoked for the identifier matching the task, but the file system fails to move the file
+ /// Then: The task is marked as error and the file is not moved to the expected location
+ func testCompleteDownload_withLocation_andError_shouldFailTask() {
+ let temporaryDirectory = FileManager.default.temporaryDirectory
+ let location = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-newFile.txt")
+
+ let downloadTask = StorageTransferTask(
+ transferType: .download(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key",
+ location: location
+ )
+
+ let sourceUrl = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-oldFile.txt")
+ try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8)
+
+ service.tasks = [
+ 1: downloadTask
+ ]
+
+ fileSystem.moveFileError = StorageError.unknown("Unable to move file", nil)
+ service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl)
+ XCTAssertFalse(FileManager.default.fileExists(atPath: location.path))
+ XCTAssertTrue(FileManager.default.fileExists(atPath: sourceUrl.path))
+ XCTAssertEqual(downloadTask.status, .error)
+ }
+
+ /// Given: An AWSS3StorageService with a non-completed upload task that sets a location
+ /// When: completeDownload is invoked for the identifier matching the task
+ /// Then: The task status is not updated and an .upload event is not dispatched
+ func testCompleteDownload_withNoDownload_shouldDoNothing() {
+ let expectation = self.expectation(description: "Complete Download")
+ expectation.isInverted = true
+
+ let uploadTask = StorageTransferTask(
+ transferType: .upload(onEvent: { event in
+ XCTFail("Should not report event")
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+
+ let sourceUrl = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).txt")
+ try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8)
+
+ service.tasks = [
+ 1: uploadTask
+ ]
+
+ service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl)
+ XCTAssertNotEqual(uploadTask.status, .completed)
+ XCTAssertNotEqual(uploadTask.status, .error)
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3StorageService that cannot create a pre signed url
+ /// When: upload is invoked
+ /// Then: A .failed event is dispatched with an .unknown error
+ func testUpload_withoutPreSignedURL_shouldSendFailEvent() {
+ let data = "someData".data(using: .utf8)!
+ let expectation = self.expectation(description: "Upload")
+ service.upload(
+ serviceKey: "key",
+ uploadSource: .data(data),
+ contentType: "application/json",
+ metadata: [:],
+ accelerate: true,
+ onEvent: { event in
+ guard case .failed(let error) = event,
+ case .unknown(let description, _) = error else {
+ XCTFail("Expected .failed event with .unknown error, got \(event)")
+ return
+ }
+ XCTAssertEqual(description, "Failed to get pre-signed URL")
+ expectation.fulfill()
+ }
+ )
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: An AWSS3StorageService that can create a pre signed url
+ /// When: upload is invoked
+ /// Then: An .initiated event is dispatched
+ func testUpload_withPreSignedURL_shouldSendInitiatedEvent() {
+ let data = "someData".data(using: .utf8)!
+ let expectation = self.expectation(description: "Upload")
+ service.preSignedURLBuilder = MockAWSS3PreSignedURLBuilder()
+ service.upload(
+ serviceKey: "key",
+ uploadSource: .data(data),
+ contentType: "application/json",
+ metadata: [:],
+ accelerate: true,
+ onEvent: { event in
+ guard case .initiated(_) = event else {
+ XCTFail("Expected .initiated event, got \(event)")
+ return
+ }
+ expectation.fulfill()
+ }
+ )
+
+ waitForExpectations(timeout: 1)
+ }
+}
+
+private class MockHttpClientEngineProxy: HttpClientEngineProxy {
+ var target: HttpClientEngine? = nil
+
+ var executeCount = 0
+ var executeRequest: SdkHttpRequest?
+ func execute(request: SdkHttpRequest) async throws -> HttpResponse {
+ executeCount += 1
+ executeRequest = request
+ return .init(body: .empty, statusCode: .accepted)
+ }
+}
+
+private class StorageTransferDatabaseMock: StorageTransferDatabase {
+
+ func prepareForBackground(completion: (() -> Void)?) {
+ completion?()
+ }
+
+ func insertTransferRequest(task: StorageTransferTask) {
+
+ }
+
+ func updateTransferRequest(task: StorageTransferTask) {
+
+ }
+
+ func removeTransferRequest(task: StorageTransferTask) {
+
+ }
+
+ func defaultTransferType(persistableTransferTask: StoragePersistableTransferTask) -> StorageTransferType? {
+ return nil
+ }
+
+ var recoverCount = 0
+ var recoverResult: Result = .failure(StorageError.unknown("Result not set", nil))
+ func recover(urlSession: StorageURLSession,
+ completionHandler: @escaping (Result) -> Void) {
+ recoverCount += 1
+ completionHandler(recoverResult)
+ }
+
+ var attachEventHandlersCount = 0
+ var onUploadHandler: AWSS3StorageServiceBehavior.StorageServiceUploadEventHandler? = nil
+ var onDownloadHandler: AWSS3StorageServiceBehavior.StorageServiceDownloadEventHandler? = nil
+ var onMultipartUploadHandler: AWSS3StorageServiceBehavior.StorageServiceMultiPartUploadEventHandler? = nil
+ func attachEventHandlers(
+ onUpload: AWSS3StorageServiceBehavior.StorageServiceUploadEventHandler?,
+ onDownload: AWSS3StorageServiceBehavior.StorageServiceDownloadEventHandler?,
+ onMultipartUpload: AWSS3StorageServiceBehavior.StorageServiceMultiPartUploadEventHandler?
+ ) {
+ attachEventHandlersCount += 1
+ onUploadHandler = onUpload
+ onDownloadHandler = onDownload
+ onMultipartUploadHandler = onMultipartUpload
+ }
+}
+
+private class MockFileSystem: FileSystem {
+ var moveFileError: Error? = nil
+ override func moveFile(from sourceFileURL: URL, to destinationURL: URL) throws {
+ if let moveFileError = moveFileError {
+ throw moveFileError
+ }
+ try super.moveFile(from: sourceFileURL, to: destinationURL)
+ }
+}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift
new file mode 100644
index 0000000000..561543f57d
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift
@@ -0,0 +1,458 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+@testable import func AmplifyTestCommon.XCTAssertThrowFatalError
+@testable import AWSS3StoragePlugin
+import AWSS3
+import XCTest
+
+class DefaultStorageMultipartUploadClientTests: XCTestCase {
+ private var defaultClient: DefaultStorageMultipartUploadClient!
+ private var serviceProxy: MockStorageServiceProxy!
+ private var session: MockStorageMultipartUploadSession!
+ private var awss3Behavior: MockAWSS3Behavior!
+ private var uploadFile: UploadFile!
+
+ override func setUp() async throws {
+ awss3Behavior = MockAWSS3Behavior()
+ serviceProxy = MockStorageServiceProxy(
+ awsS3: awss3Behavior
+ )
+ let tempFileURL = FileManager.default.temporaryDirectory
+ .appendingPathComponent(UUID().uuidString)
+ .appendingPathExtension("txt")
+ try "Hello World".write(to: tempFileURL, atomically: true, encoding: .utf8)
+ uploadFile = UploadFile(
+ fileURL: tempFileURL,
+ temporaryFileCreated: false,
+ size: 88
+ )
+ defaultClient = DefaultStorageMultipartUploadClient(
+ serviceProxy: serviceProxy,
+ bucket: "bucket",
+ key: "key",
+ uploadFile: uploadFile
+ )
+ session = MockStorageMultipartUploadSession(
+ client: client,
+ bucket: "bucket",
+ key: "key",
+ onEvent: { event in }
+ )
+ client.integrate(session: session)
+ }
+
+ private var client: StorageMultipartUploadClient! {
+ defaultClient
+ }
+
+ override func tearDown() {
+ defaultClient = nil
+ serviceProxy = nil
+ session = nil
+ awss3Behavior = nil
+ uploadFile = nil
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: createMultipartUpload is invoked and AWSS3Behavior returns .success
+ /// Then: A .created event is reported to the session and the session is registered
+ func testCreateMultipartUpload_withSuccess_shouldSucceed() throws {
+ awss3Behavior.createMultipartUploadExpectation = expectation(description: "Create Multipart Upload")
+ awss3Behavior.createMultipartUploadResult = .success(.init(
+ bucket: "bucket",
+ key: "key",
+ uploadId: "uploadId"
+ ))
+ try client.createMultipartUpload()
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.createMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 2)
+ XCTAssertEqual(session.failCount, 0)
+ if case .created(let uploadFile, let uploadId) = try XCTUnwrap(session.lastMultipartUploadEvent) {
+ XCTAssertEqual(uploadFile.fileURL, uploadFile.fileURL)
+ XCTAssertEqual(uploadId, "uploadId")
+ }
+ XCTAssertEqual(serviceProxy.registerMultipartUploadSessionCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: createMultipartUpload is invoked and AWSS3Behavior returns .failure
+ /// Then: An .unknown error is reported to the session and the session is not registered
+ func testCreateMultipartUpload_withError_shouldFail() throws {
+ awss3Behavior.createMultipartUploadExpectation = expectation(description: "Create Multipart Upload")
+ awss3Behavior.createMultipartUploadResult = .failure(.unknown("Unknown Error", nil))
+ try client.createMultipartUpload()
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.createMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 1)
+ XCTAssertEqual(session.failCount, 1)
+ if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError {
+ XCTAssertEqual(description, "Unknown Error")
+ }
+ XCTAssertEqual(serviceProxy.registerMultipartUploadSessionCount, 0)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: serviceProxy is set to nil and createMultipartUpload is invoked
+ /// Then: A fatal error is thrown
+ func testCreateMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws {
+ serviceProxy = nil
+ try XCTAssertThrowFatalError {
+ try? self.client.createMultipartUpload()
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: uploadPart is invoked with valid parts
+ /// Then: A .started event is reported to the session
+ func testUploadPart_withParts_shouldSucceed() throws {
+ session.handleUploadPartExpectation = expectation(description: "Upload Part with parts")
+
+ try client.uploadPart(
+ partNumber: 1,
+ multipartUpload: .parts(
+ uploadId: "uploadId",
+ uploadFile: uploadFile,
+ partSize: .default,
+ parts: [
+ .pending(bytes: 10),
+ .pending(bytes: 20)
+ ]
+ ),
+ subTask: .init(
+ transferType: .upload(onEvent: { event in }),
+ bucket: "bucket",
+ key: "key"
+ )
+ )
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(session.handleUploadPartCount, 1)
+ XCTAssertEqual(session.failCount, 0)
+ if case .started(let partNumber, _) = try XCTUnwrap(session.lastUploadEvent) {
+ XCTAssertEqual(partNumber, 1)
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: uploadPart is invoked with a non-existing file
+ /// Then: An error is reported to the session
+ func testUploadPart_withInvalidFile_shouldFail() throws {
+ session.failExpectation = expectation(description: "Upload Part with invalid file")
+
+ try client.uploadPart(
+ partNumber: 1,
+ multipartUpload: .parts(
+ uploadId: "uploadId",
+ uploadFile: .init(
+ fileURL: FileManager.default.temporaryDirectory.appendingPathComponent("noFile.txt"),
+ temporaryFileCreated: false,
+ size: 1024),
+ partSize: .default,
+ parts: [
+ .pending(bytes: 10),
+ .pending(bytes: 20)
+ ]
+ ),
+ subTask: .init(
+ transferType: .upload(onEvent: { event in }),
+ bucket: "bucket",
+ key: "key"
+ )
+ )
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(session.handleUploadPartCount, 0)
+ XCTAssertEqual(session.failCount, 1)
+ XCTAssertNil(session.lastUploadEvent)
+ XCTAssertNotNil(session.lastError)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: serviceProxy is set to nil and uploadPart is invoked
+ /// Then: A fatal error is thrown
+ func testUploadPart_withoutServiceProxy_shouldThrowFatalError() throws {
+ self.serviceProxy = nil
+ try XCTAssertThrowFatalError {
+ try? self.client.uploadPart(
+ partNumber: 1,
+ multipartUpload: .parts(
+ uploadId: "uploadId",
+ uploadFile: self.uploadFile,
+ partSize: .default,
+ parts: [
+ .pending(bytes: 10),
+ .pending(bytes: 20)
+ ]
+ ),
+ subTask: .init(
+ transferType: .upload(onEvent: { event in }),
+ bucket: "bucket",
+ key: "key"
+ )
+ )
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: uploadPart is invoked without parts
+ /// Then: A fatal error is thrown
+ func testUploadPart_withoutParts_shouldThrowFatalError() throws {
+ try XCTAssertThrowFatalError {
+ try? self.client.uploadPart(
+ partNumber: 1,
+ multipartUpload: .created(
+ uploadId: "uploadId",
+ uploadFile: self.uploadFile
+ ),
+ subTask: .init(
+ transferType: .upload(onEvent: { event in }),
+ bucket: "bucket",
+ key: "key"
+ )
+ )
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: completeMultipartUpload is invoked and AWSS3Behaviour returns succees
+ /// Then: A .completed event is reported to the session and the session is unregistered
+ func testCompleteMultipartUpload_withSuccess_shouldSucceed() throws {
+ awss3Behavior.completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload")
+ awss3Behavior.completeMultipartUploadResult = .success(.init(
+ bucket: "bucket",
+ key: "key",
+ eTag: "eTag"
+ ))
+ try client.completeMultipartUpload(uploadId: "uploadId")
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.completeMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 1)
+ XCTAssertEqual(session.failCount, 0)
+ if case .completed(let uploadId) = try XCTUnwrap(session.lastMultipartUploadEvent) {
+ XCTAssertEqual(uploadId, "uploadId")
+ }
+ XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: completeMultipartUpload is invoked and AWSS3Behaviour returns failure
+ /// Then: A .unknown error is reported to the session and the session is not unregistered
+ func testCompleteMultipartUpload_withError_shouldFail() throws {
+ awss3Behavior.completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload")
+ awss3Behavior.completeMultipartUploadResult = .failure(.unknown("Unknown Error", nil))
+ try client.completeMultipartUpload(uploadId: "uploadId")
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.completeMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 0)
+ XCTAssertEqual(session.failCount, 1)
+ if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError {
+ XCTAssertEqual(description, "Unknown Error")
+ }
+ XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: serviceProxy is set to nil and completeMultipartUpload is invoked
+ /// Then: A fatal error is thrown
+ func testCompleteMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws {
+ serviceProxy = nil
+ try XCTAssertThrowFatalError {
+ try? self.client.completeMultipartUpload(uploadId: "uploadId")
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: abortMultipartUpload is invoked and AWSS3Behaviour returns success
+ /// Then: An .aborted event is reported to the session and the session is unregistered
+ func testAbortMultipartUpload_withSuccess_shouldSucceed() throws {
+ awss3Behavior.abortMultipartUploadExpectation = expectation(description: "Abort Multipart Upload")
+ awss3Behavior.abortMultipartUploadResult = .success(())
+ try client.abortMultipartUpload(uploadId: "uploadId", error: CancellationError())
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.abortMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 1)
+ XCTAssertEqual(session.failCount, 0)
+ if case .aborted(let uploadId, let error) = try XCTUnwrap(session.lastMultipartUploadEvent) {
+ XCTAssertEqual(uploadId, "uploadId")
+ XCTAssertTrue(error is CancellationError)
+ }
+ XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: abortMultipartUpload is invoked and AWSS3Behaviour returns failure
+ /// Then: A .unknown error is reported to the session and the session is not unregistered
+ func testAbortMultipartUpload_withError_shouldFail() throws {
+ awss3Behavior.abortMultipartUploadExpectation = expectation(description: "Abort Multipart Upload")
+ awss3Behavior.abortMultipartUploadResult = .failure(.unknown("Unknown Error", nil))
+ try client.abortMultipartUpload(uploadId: "uploadId")
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(awss3Behavior.abortMultipartUploadCount, 1)
+ XCTAssertEqual(session.handleMultipartUploadCount, 0)
+ XCTAssertEqual(session.failCount, 1)
+ if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError {
+ XCTAssertEqual(description, "Unknown Error")
+ }
+ XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: serviceProxy is set to nil and abortMultipartUpload is invoked
+ /// Then: A fatal error is thrown
+ func testAbortMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws {
+ serviceProxy = nil
+ try XCTAssertThrowFatalError {
+ try? self.client.abortMultipartUpload(uploadId: "uploadId")
+ }
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: cancelUploadTasks is invoked with identifiers
+ /// Then: The tasks are unregistered
+ func testCancelUploadTasks_shouldSucceed() throws {
+ let cancelExpectation = expectation(description: "Cancel Upload Tasks")
+ client.cancelUploadTasks(taskIdentifiers: [0, 1,2], done: {
+ cancelExpectation.fulfill()
+ })
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(serviceProxy.unregisterTaskIdentifiersCount, 1)
+ }
+
+ /// Given: a DefaultStorageMultipartUploadClient
+ /// When: filter is invoked with some disallowed values
+ /// Then: a dictionary is returned with the disallowed values removed
+ func testFilterRequestHeaders_shouldResultFilteredHeaders() {
+ let filteredHeaders = defaultClient.filter(
+ requestHeaders: [
+ "validHeader": "validValue",
+ "x-amz-acl": "invalidValue",
+ "x-amz-tagging": "invalidValue",
+ "x-amz-storage-class": "invalidValue",
+ "x-amz-server-side-encryption": "invalidValue",
+ "x-amz-meta-invalid_one": "invalidValue",
+ "x-amz-meta-invalid_two": "invalidValue",
+ "x-amz-grant-invalid_one": "invalidvalue",
+ "x-amz-grant-invalid_two": "invalidvalue"
+ ]
+ )
+
+ XCTAssertEqual(filteredHeaders.count, 1)
+ XCTAssertEqual(filteredHeaders["validHeader"], "validValue")
+ }
+}
+
+private class MockStorageServiceProxy: StorageServiceProxy {
+ var preSignedURLBuilder: AWSS3PreSignedURLBuilderBehavior! = MockAWSS3PreSignedURLBuilder()
+ var awsS3: AWSS3Behavior!
+ var urlSession = URLSession.shared
+ var userAgent: String = ""
+ var urlRequestDelegate: URLRequestDelegate? = nil
+
+ init(awsS3: AWSS3Behavior) {
+ self.awsS3 = awsS3
+ }
+
+ func register(task: StorageTransferTask) {}
+
+ func unregister(task: StorageTransferTask) {}
+
+ var unregisterTaskIdentifiersCount = 0
+ func unregister(taskIdentifiers: [TaskIdentifier]) {
+ unregisterTaskIdentifiersCount += 1
+ }
+
+ var registerMultipartUploadSessionCount = 0
+ func register(multipartUploadSession: StorageMultipartUploadSession) {
+ registerMultipartUploadSessionCount += 1
+ }
+
+ var unregisterMultipartUploadSessionCount = 0
+ func unregister(multipartUploadSession: StorageMultipartUploadSession) {
+ unregisterMultipartUploadSessionCount += 1
+ }
+}
+
+private class MockAWSS3Behavior: AWSS3Behavior {
+ func deleteObject(_ request: AWSS3DeleteObjectRequest, completion: @escaping (Result) -> Void) {}
+
+ func listObjectsV2(_ request: AWSS3ListObjectsV2Request, completion: @escaping (Result) -> Void) {}
+
+ var createMultipartUploadCount = 0
+ var createMultipartUploadResult: Result? = nil
+ var createMultipartUploadExpectation: XCTestExpectation? = nil
+ func createMultipartUpload(_ request: CreateMultipartUploadRequest, completion: @escaping (Result) -> Void) {
+ createMultipartUploadCount += 1
+ if let result = createMultipartUploadResult {
+ completion(result)
+ }
+ createMultipartUploadExpectation?.fulfill()
+ }
+
+ var completeMultipartUploadCount = 0
+ var completeMultipartUploadResult: Result? = nil
+ var completeMultipartUploadExpectation: XCTestExpectation? = nil
+ func completeMultipartUpload(_ request: AWSS3CompleteMultipartUploadRequest, completion: @escaping (Result) -> Void) {
+ completeMultipartUploadCount += 1
+ if let result = completeMultipartUploadResult {
+ completion(result)
+ }
+ completeMultipartUploadExpectation?.fulfill()
+ }
+
+ var abortMultipartUploadCount = 0
+ var abortMultipartUploadResult: Result? = nil
+ var abortMultipartUploadExpectation: XCTestExpectation? = nil
+ func abortMultipartUpload(_ request: AWSS3AbortMultipartUploadRequest, completion: @escaping (Result) -> Void) {
+ abortMultipartUploadCount += 1
+ if let result = abortMultipartUploadResult {
+ completion(result)
+ }
+ abortMultipartUploadExpectation?.fulfill()
+ }
+
+ func getS3() -> S3ClientProtocol {
+ return MockS3Client()
+ }
+}
+
+class MockStorageMultipartUploadSession: StorageMultipartUploadSession {
+ var handleMultipartUploadCount = 0
+ var lastMultipartUploadEvent: StorageMultipartUploadEvent? = nil
+ override func handle(multipartUploadEvent: StorageMultipartUploadEvent) {
+ handleMultipartUploadCount += 1
+ lastMultipartUploadEvent = multipartUploadEvent
+ }
+
+ var handleUploadPartCount = 0
+ var lastUploadEvent: StorageUploadPartEvent? = nil
+ var handleUploadPartExpectation: XCTestExpectation? = nil
+
+ override func handle(uploadPartEvent: StorageUploadPartEvent) {
+ handleUploadPartCount += 1
+ lastUploadEvent = uploadPartEvent
+ handleUploadPartExpectation?.fulfill()
+ }
+
+ var failCount = 0
+ var lastError: Error? = nil
+ var failExpectation: XCTestExpectation? = nil
+ override func fail(error: Error) {
+ failCount += 1
+ lastError = error
+ failExpectation?.fulfill()
+ }
+}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift
new file mode 100644
index 0000000000..127c151229
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift
@@ -0,0 +1,241 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+@testable import AWSS3StoragePlugin
+import XCTest
+
+class DefaultStorageTransferDatabaseTests: XCTestCase {
+ private var database: DefaultStorageTransferDatabase!
+ private var uploadFile: UploadFile!
+ private var session: MockStorageSessionTask!
+
+ override func setUp() {
+ database = DefaultStorageTransferDatabase(
+ databaseDirectoryURL: FileManager.default.temporaryDirectory,
+ logger: MockLogger()
+ )
+ uploadFile = UploadFile(
+ fileURL: FileSystem.default.createTemporaryFileURL(),
+ temporaryFileCreated: true,
+ size: UInt64(Bytes.megabytes(12).bytes)
+ )
+ session = MockStorageSessionTask(taskIdentifier: 1)
+ }
+
+ override func tearDown() {
+ database = nil
+ uploadFile = nil
+ session = nil
+ }
+
+ /// Given: A DefaultStorageTransferDatabase
+ /// When: linkTasksWithSessions is invoked with tasks containing multipart uploads and a sessionTask, and a session
+ /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned
+ func testLinkTasksWithSessions_withMultipartUpload_shouldReturnPairs() {
+ let transferTask1 = StorageTransferTask(
+ transferType: .multiPartUpload(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key1"
+ )
+ transferTask1.sessionTask = session
+ transferTask1.multipartUpload = .created(
+ uploadId: "uploadId",
+ uploadFile: uploadFile
+ )
+
+ let transferTask2 = StorageTransferTask(
+ transferType: .multiPartUpload(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key2"
+ )
+ transferTask2.sessionTask = session
+ transferTask2.multipartUpload = .created(
+ uploadId: "uploadId",
+ uploadFile: uploadFile
+ )
+
+ let pairs = database.linkTasksWithSessions(
+ persistableTransferTasks: [
+ "taskId1": .init(task: transferTask1),
+ "taskId2": .init(task: transferTask2)
+ ],
+ sessionTasks: [
+ session
+ ]
+ )
+
+ XCTAssertEqual(pairs.count, 2)
+ XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" }))
+ XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key2" }))
+ }
+
+ /// Given: A DefaultStorageTransferDatabase
+ /// When: linkTasksWithSessions is invoked with tasks containing multipart uploads but without a sessionTask, and a session
+ /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned
+ func testLinkTasksWithSessions_withMultipartUpload_andNoSession_shouldReturnPairs() {
+ let transferTask1 = StorageTransferTask(
+ transferType: .multiPartUpload(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key1"
+ )
+ transferTask1.multipartUpload = .created(
+ uploadId: "uploadId",
+ uploadFile: uploadFile
+ )
+
+ let transferTask2 = StorageTransferTask(
+ transferType: .multiPartUpload(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key2"
+ )
+ transferTask2.multipartUpload = .created(
+ uploadId: "uploadId",
+ uploadFile: uploadFile
+ )
+
+ let pairs = database.linkTasksWithSessions(
+ persistableTransferTasks: [
+ "taskId1": .init(task: transferTask1),
+ "taskId2": .init(task: transferTask2)
+ ],
+ sessionTasks: [
+ session
+ ]
+ )
+
+ XCTAssertEqual(pairs.count, 2)
+ XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" }))
+ XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key2" }))
+ }
+
+ /// Given: A DefaultStorageTransferDatabase
+ /// When: linkTasksWithSessions is invoked with tasks containing multipart upload parts, and a session
+ /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned
+ func testLinkTasksWithSessions_withMultipartUploadPart_shouldReturnPairs() {
+ let transferTask0 = StorageTransferTask(
+ transferType: .multiPartUpload(onEvent: { _ in }),
+ bucket: "bucket",
+ key: "key1"
+ )
+ transferTask0.sessionTask = session
+ transferTask0.multipartUpload = .created(
+ uploadId: "uploadId",
+ uploadFile: uploadFile
+ )
+
+ let transferTask1 = StorageTransferTask(
+ transferType: .multiPartUploadPart(
+ uploadId: "uploadId",
+ partNumber: 1
+ ),
+ bucket: "bucket",
+ key: "key1"
+ )
+ transferTask1.sessionTask = session
+ transferTask1.uploadId = "uploadId"
+ transferTask1.multipartUpload = .parts(
+ uploadId: "uploadId",
+ uploadFile: uploadFile,
+ partSize: try! .init(fileSize: UInt64(Bytes.megabytes(6).bytes)),
+ parts: [
+ .inProgress(
+ bytes: Bytes.megabytes(6).bytes,
+ bytesTransferred: Bytes.megabytes(3).bytes,
+ taskIdentifier: 1
+ ),
+ .completed(
+ bytes: Bytes.megabytes(6).bytes,
+ eTag: "eTag")
+ ,
+ .pending(bytes: Bytes.megabytes(6).bytes)
+ ]
+ )
+ transferTask1.uploadPart = .completed(
+ bytes: Bytes.megabytes(6).bytes,
+ eTag: "eTag"
+ )
+
+ let transferTask2 = StorageTransferTask(
+ transferType: .multiPartUploadPart(
+ uploadId: "uploadId",
+ partNumber: 2
+ ),
+ bucket: "bucket",
+ key: "key1"
+ )
+ transferTask2.sessionTask = session
+ transferTask2.uploadId = "uploadId"
+ transferTask2.multipartUpload = .parts(
+ uploadId: "uploadId",
+ uploadFile: uploadFile,
+ partSize: try! .init(fileSize: UInt64(Bytes.megabytes(6).bytes)),
+ parts: [
+ .pending(bytes: Bytes.megabytes(6).bytes),
+ .pending(bytes: Bytes.megabytes(6).bytes)
+ ]
+ )
+ transferTask2.uploadPart = .inProgress(
+ bytes: Bytes.megabytes(6).bytes,
+ bytesTransferred: Bytes.megabytes(3).bytes,
+ taskIdentifier: 1
+ )
+
+ let pairs = database.linkTasksWithSessions(
+ persistableTransferTasks: [
+ "taskId0": .init(task: transferTask0),
+ "taskId1": .init(task: transferTask1),
+ "taskId2": .init(task: transferTask2)
+ ],
+ sessionTasks: [
+ session
+ ]
+ )
+
+ XCTAssertEqual(pairs.count, 3)
+ XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" }))
+ XCTAssertFalse(pairs.contains(where: { $0.transferTask.key == "key2" }))
+ }
+
+ /// Given: A DefaultStorageTransferDatabase
+ /// When: recover is invoked with a StorageURLSession that returns a session
+ /// Then: A .success is returned
+ func testLoadPersistableTasks() {
+ let urlSession = MockStorageURLSession(
+ sessionTasks: [
+ session
+ ])
+ let expectation = self.expectation(description: "Recover")
+ database.recover(urlSession: urlSession) { result in
+ guard case .success(_) = result else {
+ XCTFail("Expected success")
+ return
+ }
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A DefaultStorageTransferDatabase
+ /// When: prepareForBackground is invoked
+ /// Then: A callback is invoked
+ func testPrepareForBackground() {
+ let expectation = self.expectation(description: "Prepare for Background")
+ database.prepareForBackground() {
+ expectation.fulfill()
+ }
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: The StorageTransferDatabase Type
+ /// When: default is invoked
+ /// Then: An instance of DefaultStorageTransferDatabase is returned
+ func testDefault_shouldReturnDefaultInstance() {
+ let defaultProtocol: StorageTransferDatabase = .default
+ XCTAssertTrue(defaultProtocol is DefaultStorageTransferDatabase)
+ }
+}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift
index fa808b4ccc..dad5a0a531 100644
--- a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift
@@ -41,6 +41,44 @@ class StorageMultipartUploadSessionTests: XCTestCase {
XCTAssertFalse(session.partsFailed)
}
+ /// Given: A StorageTransferTask with a valid StorageTransferType
+ /// When: A StorageMultipartUploadSession is created from the task
+ /// Then: Its values are set correctly
+ func testSessionCreation_withTransferTask() throws {
+ let client = MockMultipartUploadClient()
+ let transferType: StorageTransferType = .multiPartUpload(onEvent: {_ in })
+ let transferTask = StorageTransferTask(
+ transferType: transferType,
+ bucket: "bucket",
+ key: "key"
+ )
+
+ let session = try XCTUnwrap(StorageMultipartUploadSession(client: client, transferTask: transferTask, multipartUpload: .none, logger: MockLogger()))
+ XCTAssertEqual(session.partsCount, 0)
+ XCTAssertEqual(session.inProgressCount, 0)
+ XCTAssertFalse(session.partsCompleted)
+ XCTAssertFalse(session.partsFailed)
+ }
+
+ /// Given: A StorageTransferTask with an invalid StorageTransferType
+ /// When: A StorageMultipartUploadSession is created from the task
+ /// Then: Its values are set correctly
+ func testSessionCreation_withTransferTask_andInvalidTransferType_shouldReturnNil() throws {
+ let client = MockMultipartUploadClient()
+ let transferType: StorageTransferType = .list(onEvent: {_ in })
+ let transferTask = StorageTransferTask(
+ transferType: transferType,
+ bucket: "bucket",
+ key: "key"
+ )
+
+ XCTAssertNil(StorageMultipartUploadSession(
+ client: client,
+ transferTask: transferTask,
+ multipartUpload: .none
+ ))
+ }
+
func testCompletedMultipartUploadSession() throws {
let initiatedExp = expectation(description: "Initiated")
let completedExp = expectation(description: "Completed")
@@ -105,7 +143,7 @@ class StorageMultipartUploadSessionTests: XCTestCase {
let client = MockMultipartUploadClient() // creates an UploadFile for the mock process
client.didCompletePartUpload = { (_, partNumber, _, _) in
if partNumber == 5 {
- closureSession?.handle(multipartUploadEvent: .aborting(error: nil))
+ closureSession?.cancel()
XCTAssertTrue(closureSession?.isAborted ?? false)
}
@@ -156,10 +194,10 @@ class StorageMultipartUploadSessionTests: XCTestCase {
if pauseCount == 0, partNumber > 5, bytesTransferred > 0 {
print("pausing on \(partNumber)")
pauseCount += 1
- closureSession?.handle(multipartUploadEvent: .pausing)
+ closureSession?.pause()
XCTAssertTrue(closureSession?.isPaused ?? false)
print("resuming on \(partNumber)")
- closureSession?.handle(multipartUploadEvent: .resuming)
+ closureSession?.resume()
XCTAssertFalse(closureSession?.isPaused ?? true)
}
}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift
new file mode 100644
index 0000000000..b8498a0387
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift
@@ -0,0 +1,363 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+@testable import Amplify
+@testable import AWSPluginsTestCommon
+@testable import AWSS3StoragePlugin
+import ClientRuntime
+import AWSS3
+import XCTest
+
+class StorageServiceSessionDelegateTests: XCTestCase {
+ private var delegate: StorageServiceSessionDelegate!
+ private var service: AWSS3StorageServiceMock!
+ private var logger: MockLogger!
+
+ override func setUp() {
+ service = try! AWSS3StorageServiceMock()
+ logger = MockLogger()
+ delegate = StorageServiceSessionDelegate(
+ identifier: "delegateTest",
+ logger: logger
+ )
+ delegate.storageService = service
+ }
+
+ override func tearDown() {
+ logger = nil
+ service = nil
+ delegate = nil
+ }
+
+ /// Given: A StorageServiceSessionDelegate
+ /// When: logURLSessionActivity is invoked with warning set to true
+ /// Then: A warn message is logged
+ func testLogURLSession_withWarningTrue_shouldLogWarning() {
+ delegate.logURLSessionActivity("message", warning: true)
+ XCTAssertEqual(logger.warnCount, 1)
+ XCTAssertEqual(logger.infoCount, 0)
+ }
+
+ /// Given: A StorageServiceSessionDelegate
+ /// When: logURLSessionActivity is invoked without setting warning
+ /// Then: An info message is logged
+ func testLogURLSession_shouldLogInfo() {
+ delegate.logURLSessionActivity("message")
+ XCTAssertEqual(logger.warnCount, 0)
+ XCTAssertEqual(logger.infoCount, 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and an identifier registered in the registry
+ /// When: the registry's handleBackgroundEvents is invoked with a matching identifier and then urlSessionDidFinishEvents is invoked
+ /// Then: The registry's continuation is triggered with true
+ func testDidFinishEvents_withMatchingIdentifiers_shouldTriggerContinuationWithTrue() async {
+ let handleEventsExpectation = self.expectation(description: "Handle Background Events")
+ let finishEventsExpectation = self.expectation(description: "Did Finish Events")
+ StorageBackgroundEventsRegistry.register(identifier: "identifier")
+ Task {
+ let result = await withCheckedContinuation { continuation in
+ StorageBackgroundEventsRegistry.handleBackgroundEvents(
+ identifier: "identifier",
+ continuation: continuation
+ )
+ handleEventsExpectation.fulfill()
+ }
+ XCTAssertTrue(result)
+ finishEventsExpectation.fulfill()
+ }
+
+ await fulfillment(of: [handleEventsExpectation], timeout: 1)
+ XCTAssertNotNil(StorageBackgroundEventsRegistry.continuation)
+ delegate.urlSessionDidFinishEvents(forBackgroundURLSession: .shared)
+ await fulfillment(of: [finishEventsExpectation], timeout: 1)
+ XCTAssertNil(StorageBackgroundEventsRegistry.continuation)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and an identifier registered in the registry
+ /// When: the registry's handleBackgroundEvents is invoked first with a matching identifier and then with a non-matching one, and after that urlSessionDidFinishEvents is invoked
+ /// Then: The registry's continuation for the non-matching identifier is triggered immediately with false, while the one for the matching identifier is triggered with true only after urlSessionDidFinishEvents is invoked
+ func testDidFinishEvents_withNonMatchingIdentifiers_shouldTriggerContinuationWithFalse() async {
+ let handleEventsMatchingExpectation = self.expectation(description: "Handle Background Events with Matching Identifiers")
+ let finishEventsExpectation = self.expectation(description: "Did Finish Events")
+ StorageBackgroundEventsRegistry.register(identifier: "identifier")
+ Task {
+ let result = await withCheckedContinuation { continuation in
+ StorageBackgroundEventsRegistry.handleBackgroundEvents(
+ identifier: "identifier",
+ continuation: continuation
+ )
+ handleEventsMatchingExpectation.fulfill()
+ }
+ XCTAssertTrue(result)
+ finishEventsExpectation.fulfill()
+ }
+
+ await fulfillment(of: [handleEventsMatchingExpectation], timeout: 1)
+ XCTAssertNotNil(StorageBackgroundEventsRegistry.continuation)
+
+ let handleEventsNonMatchingExpectation = self.expectation(description: "Handle Background Events with Matching Identifiers")
+ Task {
+ let result = await withCheckedContinuation { continuation in
+ StorageBackgroundEventsRegistry.handleBackgroundEvents(
+ identifier: "identifier2",
+ continuation: continuation
+ )
+ }
+ XCTAssertFalse(result)
+ handleEventsNonMatchingExpectation.fulfill()
+ }
+ await fulfillment(of: [handleEventsNonMatchingExpectation], timeout: 1)
+ delegate.urlSessionDidFinishEvents(forBackgroundURLSession: .shared)
+ await fulfillment(of: [finishEventsExpectation], timeout: 1)
+ XCTAssertNil(StorageBackgroundEventsRegistry.continuation)
+ }
+
+ /// Given: A StorageServiceSessionDelegate
+ /// When: didBecomeInvalidWithError is invoked with a StorageError
+ /// Then: The service's resetURLSession is invoked
+ func testDidBecomeInvalid_withError_shouldResetURLSession() {
+ delegate.urlSession(.shared, didBecomeInvalidWithError: StorageError.accessDenied("", "", nil))
+ XCTAssertEqual(service.resetURLSessionCount, 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate
+ /// When: didBecomeInvalidWithError is invoked with a nil error
+ /// Then: The service's resetURLSession is invoked
+ func testDidBecomeInvalid_withNilError_shouldResetURLSession() {
+ delegate.urlSession(.shared, didBecomeInvalidWithError: nil)
+ XCTAssertEqual(service.resetURLSessionCount, 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a StorageTransferTask with a NSError with a NSURLErrorCancelled reason
+ /// When: didComplete is invoked
+ /// Then: The task is not unregistered
+ func testDidComplete_withNSURLErrorCancelled_shouldNotCompleteTask() {
+ let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory)
+ let reasons = [
+ NSURLErrorCancelledReasonBackgroundUpdatesDisabled,
+ NSURLErrorCancelledReasonInsufficientSystemResources,
+ NSURLErrorCancelledReasonUserForceQuitApplication,
+ NSURLErrorCancelled
+ ]
+
+ for reason in reasons {
+ let expectation = self.expectation(description: "Did Complete With Error Reason \(reason)")
+ expectation.isInverted = true
+ let storageTask = StorageTransferTask(
+ transferType: .upload(onEvent: { _ in
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+ let error: Error = NSError(
+ domain: NSURLErrorDomain,
+ code: NSURLErrorCancelled,
+ userInfo: [
+ NSURLErrorBackgroundTaskCancelledReasonKey: reason
+ ]
+ )
+
+ delegate.urlSession(.shared, task: task, didCompleteWithError: error)
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(storageTask.status, .unknown)
+ XCTAssertEqual(service.unregisterCount, 0)
+ }
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a StorageTransferTask with a StorageError
+ /// When: didComplete is invoked
+ /// Then: The task status is set to error and it's unregistered
+ func testDidComplete_withError_shouldFailTask() {
+ let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory)
+ let expectation = self.expectation(description: "Did Complete With Error")
+ let storageTask = StorageTransferTask(
+ transferType: .upload(onEvent: { _ in
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+
+ delegate.urlSession(.shared, task: task, didCompleteWithError: StorageError.accessDenied("", "", nil))
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(storageTask.status, .error)
+ XCTAssertEqual(service.unregisterCount, 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .upload
+ /// When: didSendBodyData is invoked
+ /// Then: An .inProcess event is reported, with the corresponding values
+ func testDidSendBodyData_upload_shouldSendInProcessEvent() {
+ let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory)
+ let expectation = self.expectation(description: "Did Send Body Data")
+ let storageTask = StorageTransferTask(
+ transferType: .upload(onEvent: { event in
+ guard case .inProcess(let progress) = event else {
+ XCTFail("Expected .inProcess event, got \(event)")
+ return
+ }
+ XCTAssertEqual(progress.totalUnitCount, 120)
+ XCTAssertEqual(progress.completedUnitCount, 100)
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+
+ delegate.urlSession(
+ .shared,
+ task: task,
+ didSendBodyData: 10,
+ totalBytesSent: 100,
+ totalBytesExpectedToSend: 120
+ )
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .multiPartUploadPart
+ /// When: didSendBodyData is invoked
+ /// Then: A .progressUpdated event is reported to the session
+ func testDidSendBodyData_multiPartUploadPart_shouldSendInProcessEvent() {
+ let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory)
+ let storageTask = StorageTransferTask(
+ transferType: .multiPartUploadPart(
+ uploadId: "uploadId",
+ partNumber: 3
+ ),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+ let multipartSession = MockStorageMultipartUploadSession(
+ client: MockMultipartUploadClient(),
+ bucket: "bucket",
+ key: "key",
+ onEvent: { event in }
+ )
+ service.mockedMultipartUploadSession = multipartSession
+
+ delegate.urlSession(
+ .shared,
+ task: task,
+ didSendBodyData: 10,
+ totalBytesSent: 100,
+ totalBytesExpectedToSend: 120
+ )
+ XCTAssertEqual(multipartSession.handleUploadPartCount, 1)
+ guard case .progressUpdated(let partNumber, let bytesTransferred, let taskIdentifier) = multipartSession.lastUploadEvent else {
+ XCTFail("Expected .progressUpdated event")
+ return
+ }
+
+ XCTAssertEqual(partNumber, 3)
+ XCTAssertEqual(bytesTransferred, 10)
+ XCTAssertEqual(taskIdentifier, task.taskIdentifier)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .download
+ /// When: didWriteData is invoked
+ /// Then: An .inProcess event is reported, with the corresponding values
+ func testDidWriteData_shouldNotifyProgress() {
+ let task = URLSession.shared.downloadTask(with: FileManager.default.temporaryDirectory)
+ let expectation = self.expectation(description: "Did Write Data")
+ let storageTask = StorageTransferTask(
+ transferType: .download(onEvent: { event in
+ guard case .inProcess(let progress) = event else {
+ XCTFail("Expected .inProcess event, got \(event)")
+ return
+ }
+ XCTAssertEqual(progress.totalUnitCount, 300)
+ XCTAssertEqual(progress.completedUnitCount, 200)
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+
+ delegate.urlSession(
+ .shared,
+ downloadTask: task,
+ didWriteData: 15,
+ totalBytesWritten: 200,
+ totalBytesExpectedToWrite: 300
+ )
+
+ waitForExpectations(timeout: 1)
+ }
+
+ /// Given: A StorageServiceSessionDelegate and a URLSessionDownloadTask without a httpResponse
+ /// When: didFinishDownloadingTo is invoked
+ /// Then: No event is reported and the task is not completed
+ func testDiFinishDownloading_withError_shouldNotCompleteDownload() {
+ let task = URLSession.shared.downloadTask(with: FileManager.default.temporaryDirectory)
+ let expectation = self.expectation(description: "Did Finish Downloading")
+ expectation.isInverted = true
+ let storageTask = StorageTransferTask(
+ transferType: .download(onEvent: { _ in
+ expectation.fulfill()
+ }),
+ bucket: "bucket",
+ key: "key"
+ )
+ service.mockedTask = storageTask
+
+ delegate.urlSession(
+ .shared,
+ downloadTask: task,
+ didFinishDownloadingTo: FileManager.default.temporaryDirectory
+ )
+
+ waitForExpectations(timeout: 1)
+ XCTAssertEqual(service.completeDownloadCount, 0)
+ }
+}
+
+private class AWSS3StorageServiceMock: AWSS3StorageService {
+ convenience init() throws {
+ try self.init(
+ authService: MockAWSAuthService(),
+ region: "region",
+ bucket: "bucket",
+ storageTransferDatabase: MockStorageTransferDatabase()
+ )
+ }
+
+ override var identifier: String {
+ return "identifier"
+ }
+
+ var mockedTask: StorageTransferTask? = nil
+ override func findTask(taskIdentifier: TaskIdentifier) -> StorageTransferTask? {
+ return mockedTask
+ }
+
+ var resetURLSessionCount = 0
+ override func resetURLSession() {
+ resetURLSessionCount += 1
+ }
+
+ var unregisterCount = 0
+ override func unregister(task: StorageTransferTask) {
+ unregisterCount += 1
+ }
+
+ var mockedMultipartUploadSession: StorageMultipartUploadSession? = nil
+ override func findMultipartUploadSession(uploadId: UploadID) -> StorageMultipartUploadSession? {
+ return mockedMultipartUploadSession
+ }
+
+ var completeDownloadCount = 0
+ override func completeDownload(taskIdentifier: TaskIdentifier, sourceURL: URL) {
+ completeDownloadCount += 1
+ }
+}
diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift
new file mode 100644
index 0000000000..c9b96aea8d
--- /dev/null
+++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift
@@ -0,0 +1,658 @@
+//
+// Copyright Amazon.com Inc. or its affiliates.
+// All Rights Reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+import Amplify
+@testable import AWSS3StoragePlugin
+import XCTest
+
+class StorageTransferTaskTests: XCTestCase {
+
+ // MARK: - Resume tests
+ /// Given: A StorageTransferTask with a sessionTask
+ /// When: resume is invoked
+ /// Then: an .initiated event is reported and the task set to .inProgress
+ func testResume_withSessionTask_shouldCallResume_andReportInitiatedEvent() {
+ let expectation = expectation(description: ".initiated event received on resume with only sessionTask")
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { event in
+ guard case .initiated(_) = event else {
+ XCTFail("Expected .initiated, got \(event)")
+ return
+ }
+ expectation.fulfill()
+ }),
+ sessionTask: sessionTask,
+ proxyStorageTask: nil
+ )
+ XCTAssertEqual(task.status, .paused)
+
+ task.resume()
+ waitForExpectations(timeout: 0.5)
+
+ XCTAssertEqual(sessionTask.resumeCount, 1)
+ XCTAssertEqual(task.status, .inProgress)
+ }
+
+ /// Given: A StorageTransferTask with a proxyStorageTask
+ /// When: resume is invoked
+ /// Then: an .initiated event is reported and the task set to .inProgress
+ func testResume_withProxyStorageTask_shouldCallResume_andReportInitiatedEvent() {
+ let expectation = expectation(description: ".initiated event received on resume with only proxyStorageTask")
+ let sessionTask = MockSessionTask()
+ let storageTask = MockStorageTask()
+ let task = createTask(
+ transferType: .download(onEvent: { event in
+ guard case .initiated(_) = event else {
+ XCTFail("Expected .initiated, got \(event)")
+ return
+ }
+ expectation.fulfill()
+ }),
+ sessionTask: sessionTask, // Set the sessionTask to set task.status = .paused
+ proxyStorageTask: storageTask
+ )
+ task.sessionTask = nil // Remove the session task
+ XCTAssertEqual(task.status, .paused)
+
+ task.resume()
+ waitForExpectations(timeout: 0.5)
+
+ XCTAssertEqual(sessionTask.resumeCount, 0)
+ XCTAssertEqual(storageTask.resumeCount, 1)
+ XCTAssertEqual(task.status, .inProgress)
+ }
+
+ /// Given: A StorageTransferTask with a sessionTask and a proxyStorageTask
+ /// When: resume is invoked
+ /// Then: an .initiated event is reported and the task set to .inProgress
+ func testResume_withSessionTask_andProxyStorageTask_shouldCallResume_andReportInitiatedEvent() {
+ let expectation = expectation(description: ".initiated event received on resume with sessionTask and proxyStorageTask")
+ let sessionTask = MockSessionTask()
+ let storageTask = MockStorageTask()
+ let task = createTask(
+ transferType: .multiPartUpload(onEvent: { event in
+ guard case .initiated(_) = event else {
+ XCTFail("Expected .initiated, got \(event)")
+ return
+ }
+ expectation.fulfill()
+ }),
+ sessionTask: sessionTask,
+ proxyStorageTask: storageTask
+ )
+ XCTAssertEqual(task.status, .paused)
+
+ task.resume()
+ waitForExpectations(timeout: 0.5)
+
+ XCTAssertEqual(sessionTask.resumeCount, 1)
+ XCTAssertEqual(storageTask.resumeCount, 0)
+ XCTAssertEqual(task.status, .inProgress)
+ }
+
+ /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask
+ /// When: resume is invoked
+ /// Then: No event is reported and the task is not to .inProgress
+ func testResume_withoutSessionTask_withoutProxyStorateTask_shouldNotCallResume_andNotReportEvent() {
+ let expectation = expectation(description: "no event is received on resume when no sessionTask nor proxyStorageTask")
+ expectation.isInverted = true
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .multiPartUpload(onEvent: { event in
+ XCTFail("No event expected, got \(event)")
+ expectation.fulfill()
+ }),
+ sessionTask: sessionTask, // Set the sessionTask to set task.status = .paused
+ proxyStorageTask: nil
+ )
+ task.sessionTask = nil // Remove the sessionTask
+ XCTAssertEqual(task.status, .paused)
+
+ task.resume()
+ waitForExpectations(timeout: 0.5)
+
+ XCTAssertEqual(sessionTask.resumeCount, 0)
+ XCTAssertEqual(task.status, .paused)
+ }
+
+ /// Given: A StorageTransferTask with status not being paused
+ /// When: resume is invoked
+ /// Then: No event is reported and the task is not set to .inProgress
+ func testResume_withTaskNotPaused_shouldNotCallResume_andNotReportEvent() {
+ let expectation = expectation(description: "no event is received on resume when the session is not paused")
+ expectation.isInverted = true
+ let task = createTask(
+ transferType: .multiPartUpload(onEvent: { event in
+ XCTFail("No event expected, got \(event)")
+ expectation.fulfill()
+ }),
+ sessionTask: nil, // Do not set session task so task.status = .unknown
+ proxyStorageTask: nil
+ )
+ XCTAssertEqual(task.status, .unknown)
+
+ task.resume()
+ waitForExpectations(timeout: 0.5)
+
+ XCTAssertEqual(task.status, .unknown)
+ }
+
+ // MARK: - Suspend Tests
+ /// Given: A StorageTransferTask with a sessionTask
+ /// When: suspend is invoked
+ /// Then: The task is set to .paused
+ func testSuspend_withSessionTask_shouldCallSuspend() {
+ let sessionTask = MockSessionTask(state: .running)
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: nil
+ )
+ // Set the task to inProgress by setting a multiPartUpload.creating
+ task.multipartUpload = .creating
+ XCTAssertEqual(task.status, .inProgress)
+
+ task.suspend()
+
+ XCTAssertEqual(sessionTask.suspendCount, 1)
+ XCTAssertEqual(task.status, .paused)
+ }
+
+ /// Given: A StorageTransferTask with a proxyStorageTask
+ /// When: suspend is invoked
+ /// Then: The task is set to .paused
+ func testSuspend_withProxyStorageTask_shouldCallPause() {
+ let storageTask = MockStorageTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: nil,
+ proxyStorageTask: storageTask
+ )
+ // Set the task to inProgress by setting a multiPartUpload.creating
+ task.multipartUpload = .creating
+ XCTAssertEqual(task.status, .inProgress)
+
+ task.suspend()
+
+ XCTAssertEqual(storageTask.pauseCount, 1)
+ XCTAssertEqual(task.status, .paused)
+ }
+
+ /// Given: A StorageTransferTask with a sessionTask and a proxyStorageTask
+ /// When: suspend is invoked
+ /// Then: The task is set to .paused
+ func testSuspend_withSessionTask_andProxyStorageTask_shouldCallSuspend() {
+ let sessionTask = MockSessionTask(state: .running)
+ let storageTask = MockStorageTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: storageTask
+ )
+ // Set the task to inProgress by setting a multiPartUpload.creating
+ task.multipartUpload = .creating
+ XCTAssertEqual(task.status, .inProgress)
+
+ task.suspend()
+
+ XCTAssertEqual(sessionTask.suspendCount, 1)
+ XCTAssertEqual(storageTask.pauseCount, 0)
+ XCTAssertEqual(task.status, .paused)
+ }
+
+ /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask
+ /// When: suspend is invoked
+ /// Then: The task remains .inProgress
+ func testSuspend_withoutSessionTask_andWithoutProxyStorageTask_shouldDoNothing() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+ // Set the task to inProgress by setting a multiPartUpload.creating
+ task.multipartUpload = .creating
+ XCTAssertEqual(task.status, .inProgress)
+
+ task.suspend()
+
+ XCTAssertEqual(task.status, .inProgress)
+ }
+
+ /// Given: A StorageTransferTask with status completed
+ /// When: suspend is invoked
+ /// Then: The task remains completed
+ func testSuspend_withTaskNotInProgress_shouldDoNothing() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: nil
+ )
+ // Set the task to completed by setting a multiPartUpload.completed
+ task.multipartUpload = .completed(uploadId: "")
+ XCTAssertEqual(task.status, .completed)
+
+ task.suspend()
+
+ XCTAssertEqual(sessionTask.suspendCount, 0)
+ XCTAssertEqual(task.status, .completed)
+ }
+
+ /// Given: A StorageTransferTask
+ /// When: pause is invoked
+ /// Then: The task is set to .paused
+ func testPause_shouldCallSuspend() {
+ let sessionTask = MockSessionTask(state: .running)
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: nil
+ )
+ // Set the task to inProgress by setting a multiPartUpload.creating
+ task.multipartUpload = .creating
+ XCTAssertEqual(task.status, .inProgress)
+
+ task.pause()
+
+ XCTAssertEqual(sessionTask.suspendCount, 1)
+ XCTAssertEqual(task.status, .paused)
+ }
+
+ // MARK: - Cancel Tests
+ /// Given: A StorageTransferTask with a sessionTask
+ /// When: cancel is invoked
+ /// Then: The task is set to .cancelled
+ func testCancel_withSessionTask_shouldCancel() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: MockStorageTask()
+ )
+
+ // Set the task to completed by setting a multiPartUpload.completed
+ XCTAssertNotEqual(task.status, .completed)
+
+ task.cancel()
+
+ XCTAssertEqual(task.status, .cancelled)
+ XCTAssertEqual(sessionTask.cancelCount, 1)
+ XCTAssertNil(task.proxyStorageTask)
+ }
+
+ /// Given: A StorageTransferTask with a proxyStorageTask
+ /// When: cancel is invoked
+ /// Then: The task is set to .cancelled
+ func testCancel_withProxyStorageTask_shouldCancel() {
+ let storageTask = MockStorageTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: nil,
+ proxyStorageTask: storageTask
+ )
+
+ task.cancel()
+ XCTAssertEqual(task.status, .cancelled)
+ XCTAssertEqual(storageTask.cancelCount, 1)
+ XCTAssertNil(task.proxyStorageTask)
+ }
+
+ /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask
+ /// When: cancel is invoked
+ /// Then: The task is not set to .cancelled
+ func testCancel_withoutSessionTask_withoutProxyStorageTask_shouldDoNothing() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+
+ task.cancel()
+ XCTAssertNotEqual(task.status, .cancelled)
+ }
+
+ /// Given: A StorageTransferTask with status completed
+ /// When: cancel is invoked
+ /// Then: The task is not set to .cancelled
+ func testCancel_withTaskCompleted_shouldDoNothing() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: MockStorageTask()
+ )
+ // Set the task to completed by setting a multiPartUpload.completed
+ task.multipartUpload = .completed(uploadId: "")
+ XCTAssertEqual(task.status, .completed)
+
+ task.cancel()
+ XCTAssertNotEqual(task.status, .cancelled)
+ XCTAssertEqual(sessionTask.cancelCount, 0)
+ XCTAssertNotNil(task.proxyStorageTask)
+ }
+
+ // MARK: - Complete Tests
+ /// Given: A StorageTransferTask with sessionTask
+ /// When: complete is invoked
+ /// Then: The task is set to .completed
+ func testComplete_withSessionTask_shouldComplete() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: MockStorageTask()
+ )
+
+ task.complete()
+ XCTAssertEqual(task.status, .completed)
+ XCTAssertNil(task.proxyStorageTask)
+ }
+
+ /// Given: A StorageTransferTask with status cancelled
+ /// When: complete is invoked
+ /// Then: The task is remains .cancelled
+ func testComplete_withTaskCancelled_shouldDoNothing() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: nil
+ )
+ task.cancel()
+ XCTAssertEqual(task.status, .cancelled)
+
+ task.complete()
+ XCTAssertEqual(task.status, .cancelled)
+ }
+
+ /// Given: A StorageTransferTask with status completed
+ /// When: complete is invoked
+ /// Then: The task is remains .completed
+ func testComplete_withTaskCompleted_shouldDoNothing() {
+ let sessionTask = MockSessionTask()
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in }),
+ sessionTask: sessionTask,
+ proxyStorageTask: MockStorageTask()
+ )
+ // Set the task to completed by setting a multiPartUpload.completed
+ task.multipartUpload = .completed(uploadId: "")
+ XCTAssertEqual(task.status, .completed)
+
+ task.complete()
+
+ XCTAssertNotNil(task.proxyStorageTask)
+ }
+
+ // MARK: - Fail Tests
+ /// Given: A StorageTransferTask
+ /// When: fail is invoked
+ /// Then: A .failed event is reported
+ func testFail_shouldReportFailEvent() {
+ let expectation = expectation(description: ".failed event received on fail")
+ let task = createTask(
+ transferType: .upload(onEvent: { event in
+ guard case .failed(_) = event else {
+ XCTFail("Expected .failed, got \(event)")
+ return
+ }
+ expectation.fulfill()
+ }),
+ sessionTask: MockSessionTask(),
+ proxyStorageTask: MockStorageTask()
+ )
+ task.fail(error: CancellationError())
+
+ waitForExpectations(timeout: 0.5)
+ XCTAssertEqual(task.status, .error)
+ XCTAssertTrue(task.isFailed)
+ XCTAssertNil(task.proxyStorageTask)
+ }
+
+ /// Given: A StorageTransferTask with status .failed
+ /// When: fail is invoked
+ /// Then: No event is reported
+ func testFail_withFailedTask_shouldNotReportEvent() {
+ let expectation = expectation(description: "event received on fail for failed task")
+ expectation.isInverted = true
+ let task = createTask(
+ transferType: .upload(onEvent: { event in
+ XCTFail("No event expected, got \(event)")
+ expectation.fulfill()
+ }),
+ sessionTask: MockSessionTask(),
+ proxyStorageTask: MockStorageTask()
+ )
+
+ // Set the task to error by setting a multiPartUpload.failed
+ task.multipartUpload = .failed(uploadId: "", parts: nil, error: CancellationError())
+ XCTAssertEqual(task.status, .error)
+ task.fail(error: CancellationError())
+
+ waitForExpectations(timeout: 0.5)
+ XCTAssertNotNil(task.proxyStorageTask)
+ }
+
+ // MARK: - Response Tests
+ /// Given: A StorageTransferTask with a valid responseData
+ /// When: responseText is invoked
+ /// Then: A string representing the data is returned
+ func testResponseText_withValidData_shouldReturnText() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+ task.responseData = "Test".data(using: .utf8)
+
+ XCTAssertEqual(task.responseText, "Test")
+ }
+
+ /// Given: A StorageTransferTask with an invalid responseData
+ /// When: responseText is invoked
+ /// Then: nil is returned
+ func testResponseText_withInvalidData_shouldReturnNil() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+ task.responseData = Data(count: 9999)
+
+ XCTAssertNil(task.responseText)
+ }
+
+ /// Given: A StorageTransferTask with a nil responseData
+ /// When: responseText is invoked
+ /// Then: nil is returned
+ func testResponseText_withoutData_shouldReturnNil() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+ task.responseData = nil
+
+ XCTAssertNil(task.responseText)
+ }
+
+ // MARK: - PartNumber Tests
+ /// Given: A StorageTransferTask of type .multiPartUploadPart
+ /// When: partNumber is invoked
+ /// Then: The corresponding part number is returned
+ func testPartNumber_withMultipartUpload_shouldReturnPartNumber() {
+ let partNumber: PartNumber = 5
+ let task = createTask(
+ transferType: .multiPartUploadPart(uploadId: "", partNumber: partNumber),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+
+ XCTAssertEqual(task.partNumber, partNumber)
+ }
+
+ /// Given: A StorageTransferTask of type .upload
+ /// When: partNumber is invoked
+ /// Then: nil is returned
+ func testPartNumber_withOtherTransferType_shouldReturnNil() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil
+ )
+
+ XCTAssertNil(task.partNumber)
+ }
+
+ // MARK: - HTTPRequestHeaders Tests
+ /// Given: A StorageTransferTask with requestHeaders
+ /// When: URLRequest.setHTTPRequestHeaders is invoked with said task
+ /// Then: The request includes the corresponding headers
+ func testHTTPRequestHeaders_shouldSetValues() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil,
+ requestHeaders: [
+ "header1": "value1",
+ "header2": "value2"
+ ]
+ )
+
+ var request = URLRequest(url: FileManager.default.temporaryDirectory)
+ XCTAssertNil(request.allHTTPHeaderFields)
+
+ request.setHTTPRequestHeaders(transferTask: task)
+ XCTAssertEqual(request.allHTTPHeaderFields?.count, 2)
+ XCTAssertEqual(request.allHTTPHeaderFields?["header1"], "value1")
+ XCTAssertEqual(request.allHTTPHeaderFields?["header2"], "value2")
+ }
+
+ /// Given: A StorageTransferTask with nil requestHeaders
+ /// When: URLRequest.setHTTPRequestHeaders is invoked with said task
+ /// Then: The request does not adds headers
+ func testHTTPRequestHeaders_withoutHeaders_shouldDoNothing() {
+ let task = createTask(
+ transferType: .upload(onEvent: { _ in}),
+ sessionTask: nil,
+ proxyStorageTask: nil,
+ requestHeaders: nil
+ )
+
+ var request = URLRequest(url: FileManager.default.temporaryDirectory)
+ XCTAssertNil(request.allHTTPHeaderFields)
+
+ request.setHTTPRequestHeaders(transferTask: task)
+ XCTAssertNil(request.allHTTPHeaderFields)
+ }
+}
+
+extension StorageTransferTaskTests {
+ private func createTask(
+ transferType: StorageTransferType,
+ sessionTask: StorageSessionTask?,
+ proxyStorageTask: StorageTask?,
+ requestHeaders: [String: String]? = nil
+ ) -> StorageTransferTask {
+ let transferID = UUID().uuidString
+ let bucket = "BUCKET"
+ let key = UUID().uuidString
+ let task = StorageTransferTask(
+ transferID: transferID,
+ transferType: transferType,
+ bucket: bucket,
+ key: key,
+ location: nil,
+ contentType: nil,
+ requestHeaders: requestHeaders,
+ storageTransferDatabase: MockStorageTransferDatabase(),
+ logger: MockLogger()
+ )
+ task.sessionTask = sessionTask
+ task.proxyStorageTask = proxyStorageTask
+ return task
+ }
+}
+
+
+private class MockStorageTask: StorageTask {
+ var pauseCount = 0
+ func pause() {
+ pauseCount += 1
+ }
+
+ var resumeCount = 0
+ func resume() {
+ resumeCount += 1
+ }
+
+ var cancelCount = 0
+ func cancel() {
+ cancelCount += 1
+ }
+}
+
+private class MockSessionTask: StorageSessionTask {
+ let taskIdentifier: TaskIdentifier
+ let state: URLSessionTask.State
+
+ init(
+ taskIdentifier: TaskIdentifier = 1,
+ state: URLSessionTask.State = .suspended
+ ) {
+ self.taskIdentifier = taskIdentifier
+ self.state = state
+ }
+
+ var resumeCount = 0
+ func resume() {
+ resumeCount += 1
+ }
+
+ var suspendCount = 0
+ func suspend() {
+ suspendCount += 1
+ }
+
+ var cancelCount = 0
+ func cancel() {
+ cancelCount += 1
+ }
+}
+
+class MockLogger: Logger {
+ var logLevel: LogLevel = .verbose
+
+ func error(_ message: @autoclosure () -> String) {
+ print(message())
+ }
+
+ func error(error: Error) {
+ print(error)
+ }
+
+ var warnCount = 0
+ func warn(_ message: @autoclosure () -> String) {
+ print(message())
+ warnCount += 1
+ }
+
+ var infoCount = 0
+ func info(_ message: @autoclosure () -> String) {
+ print(message())
+ infoCount += 1
+ }
+
+ func debug(_ message: @autoclosure () -> String) {
+ print(message())
+ }
+
+ func verbose(_ message: @autoclosure () -> String) {
+ print(message())
+ }
+}