diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme index d92bafde07..96d011b6cc 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/AWSCognitoAuthPlugin.xcscheme @@ -48,11 +48,6 @@ BlueprintName = "AWSCognitoAuthPluginUnitTests" ReferencedContainer = "container:"> - - - - diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift index 6c46eceb40..59d799e963 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSAuthCognitoSession.swift @@ -76,24 +76,6 @@ public struct AWSAuthCognitoSession: AuthSession, } -/// Internal Helpers for managing session tokens -internal extension AWSAuthCognitoSession { - func areTokensExpiring(in seconds: TimeInterval? = nil) -> Bool { - - guard let tokens = try? userPoolTokensResult.get(), - let idTokenClaims = try? AWSAuthService().getTokenClaims(tokenString: tokens.idToken).get(), - let accessTokenClaims = try? AWSAuthService().getTokenClaims(tokenString: tokens.idToken).get(), - let idTokenExpiration = idTokenClaims["exp"]?.doubleValue, - let accessTokenExpiration = accessTokenClaims["exp"]?.doubleValue else { - return true - } - - // If the session expires < X minutes return it - return (Date(timeIntervalSince1970: idTokenExpiration).compare(Date(timeIntervalSinceNow: seconds ?? 0)) == .orderedDescending && - Date(timeIntervalSince1970: accessTokenExpiration).compare(Date(timeIntervalSinceNow: seconds ?? 0)) == .orderedDescending) - } -} - extension AWSAuthCognitoSession: Equatable { public static func == (lhs: AWSAuthCognitoSession, rhs: AWSAuthCognitoSession) -> Bool { switch (lhs.getCognitoTokens(), rhs.getCognitoTokens()) { diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift index af7d80f96a..c5f4daed06 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AWSCognitoUserPoolTokens.swift @@ -65,10 +65,10 @@ public struct AWSCognitoUserPoolTokens: AuthCognitoTokens { case (.some(let idTokenValue), .none): expirationDoubleValue = idTokenValue case (.none, .none): - expirationDoubleValue = 0 + expirationDoubleValue = Date().timeIntervalSince1970 } - self.expiration = Date().addingTimeInterval(TimeInterval((expirationDoubleValue ?? 0))) + self.expiration = Date(timeIntervalSince1970: TimeInterval(expirationDoubleValue)) } } diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift index 9ded44922d..8e391355e5 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Helpers/AuthCognitoSignedOutSessionHelper.swift @@ -25,27 +25,6 @@ struct AuthCognitoSignedOutSessionHelper { return authSession } - /// Guest/SignedOut session with any unhandled error - /// - /// The unhandled error is passed as identityId and aws credentials result. UserSub and Cognito Tokens will still - /// have signOut error. - /// - /// - Parameter error: Unhandled error - /// - Returns: Session will have isSignedIn = false - private static func makeSignedOutSession(withUnhandledError error: AuthError) -> AWSAuthCognitoSession { - - let identityIdError = error - let awsCredentialsError = error - - let tokensError = makeCognitoTokensSignedOutError() - - let authSession = AWSAuthCognitoSession(isSignedIn: false, - identityIdResult: .failure(identityIdError), - awsCredentialsResult: .failure(awsCredentialsError), - cognitoTokensResult: .failure(tokensError)) - return authSession - } - /// Guest/SignOut session when the guest access is not enabled. /// - Returns: Session with isSignedIn = false static func makeSessionWithNoGuestAccess() -> AWSAuthCognitoSession { @@ -68,26 +47,6 @@ struct AuthCognitoSignedOutSessionHelper { return authSession } - private static func makeOfflineSignedOutSession() -> AWSAuthCognitoSession { - let identityIdError = AuthError.service( - AuthPluginErrorConstants.identityIdOfflineError.errorDescription, - AuthPluginErrorConstants.identityIdOfflineError.recoverySuggestion, - AWSCognitoAuthError.network) - - let awsCredentialsError = AuthError.service( - AuthPluginErrorConstants.awsCredentialsOfflineError.errorDescription, - AuthPluginErrorConstants.awsCredentialsOfflineError.recoverySuggestion, - AWSCognitoAuthError.network) - - let tokensError = makeCognitoTokensSignedOutError() - - let authSession = AWSAuthCognitoSession(isSignedIn: false, - identityIdResult: .failure(identityIdError), - awsCredentialsResult: .failure(awsCredentialsError), - cognitoTokensResult: .failure(tokensError)) - return authSession - } - /// Guest/SignedOut session with couldnot retreive either aws credentials or identity id. /// - Returns: Session will have isSignedIn = false private static func makeSignedOutSessionWithServiceIssue() -> AWSAuthCognitoSession { @@ -109,13 +68,6 @@ struct AuthCognitoSignedOutSessionHelper { return authSession } - private static func makeUserSubSignedOutError() -> AuthError { - let userSubError = AuthError.signedOut( - AuthPluginErrorConstants.userSubSignOutError.errorDescription, - AuthPluginErrorConstants.userSubSignOutError.recoverySuggestion) - return userSubError - } - private static func makeCognitoTokensSignedOutError() -> AuthError { let tokensError = AuthError.signedOut( AuthPluginErrorConstants.cognitoTokensSignOutError.errorDescription, diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift index cd9760637a..9c225ec931 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/HostedUI/HostedUIASWebAuthenticationSession.swift @@ -22,7 +22,7 @@ class HostedUIASWebAuthenticationSession: NSObject, HostedUISessionBehavior { callback: @escaping (Result<[URLQueryItem], HostedUIError>) -> Void) { #if os(iOS) || os(macOS) self.webPresentation = presentationAnchor - let aswebAuthenticationSession = ASWebAuthenticationSession( + let aswebAuthenticationSession = createAuthenticationSession( url: url, callbackURLScheme: callbackScheme, completionHandler: { url, error in @@ -58,6 +58,16 @@ class HostedUIASWebAuthenticationSession: NSObject, HostedUISessionBehavior { } #if os(iOS) || os(macOS) + var authenticationSessionFactory = ASWebAuthenticationSession.init(url:callbackURLScheme:completionHandler:) + + private func createAuthenticationSession( + url: URL, + callbackURLScheme: String?, + completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler + ) -> ASWebAuthenticationSession { + return authenticationSessionFactory(url, callbackURLScheme, completionHandler) + } + private func convertHostedUIError(_ error: Error) -> HostedUIError { if let asWebAuthError = error as? ASWebAuthenticationSessionError { switch asWebAuthError.code { diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift index 06ebd109c9..1252888632 100644 --- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/CredentialStore/MigrateLegacyCredentialStoreTests.swift @@ -121,7 +121,122 @@ class MigrateLegacyCredentialStoreTests: XCTestCase { await fulfillment( of: [migrationCompletionInvoked], + timeout: 0.1 ) } + + /// - Given: A credential store with an invalid environment + /// - When: The migration legacy store action is executed + /// - Then: An error event of type configuration is dispatched + func testExecute_withInvalidEnvironment_shouldDispatchError() async { + let expectation = expectation(description: "noEnvironment") + let action = MigrateLegacyCredentialStore() + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? CredentialStoreEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to no CredentialEnvironment") + expectation.fulfill() + return + } + XCTAssertEqual(error, .configuration(message: AuthPluginErrorConstants.configurationError)) + expectation.fulfill() + }, + environment: MockInvalidEnvironment() + ) + await fulfillment(of: [expectation], timeout: 1) + } + + /// - Given: A credential store with an environment that only has identity pool + /// - When: The migration legacy store action is executed + /// - Then: + /// - A .loadCredentialStore event with type .amplifyCredentials is dispatched + /// - An .identityPoolOnly credential is saved + func testExecute_withoutUserPool_andWithoutLoginsTokens_shouldDispatchLoadEvent() async { + let expectation = expectation(description: "noUserPoolTokens") + let action = MigrateLegacyCredentialStore() + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? CredentialStoreEvent, + case .loadCredentialStore(let type) = event.eventType else { + XCTFail("Expected .loadCredentialStore") + expectation.fulfill() + return + } + XCTAssertEqual(type, .amplifyCredentials) + expectation.fulfill() + }, + environment: CredentialEnvironment( + authConfiguration: .identityPools(.testData), + credentialStoreEnvironment: BasicCredentialStoreEnvironment( + amplifyCredentialStoreFactory: { + MockAmplifyCredentialStoreBehavior( + saveCredentialHandler: { codableCredentials in + guard let amplifyCredentials = codableCredentials as? AmplifyCredentials, + case .identityPoolOnly(_, let credentials) = amplifyCredentials else { + XCTFail("Expected .identityPoolOnly") + return + } + XCTAssertFalse(credentials.sessionToken.isEmpty) + } + ) + }, + legacyKeychainStoreFactory: { _ in + MockKeychainStoreBehavior(data: "hostedUI") + }), + logger: MigrateLegacyCredentialStore.log + ) + ) + await fulfillment(of: [expectation], timeout: 1) + } + + /// - Given: A credential store with an environment that only has identity pool + /// - When: The migration legacy store action is executed + /// - A .loadCredentialStore event with type .amplifyCredentials is dispatched + /// - An .identityPoolWithFederation credential is saved + func testExecute_withoutUserPool_andWithLoginsTokens_shouldDispatchLoadEvent() async { + let expectation = expectation(description: "noUserPoolTokens") + let action = MigrateLegacyCredentialStore() + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? CredentialStoreEvent, + case .loadCredentialStore(let type) = event.eventType else { + XCTFail("Expected .loadCredentialStore") + expectation.fulfill() + return + } + XCTAssertEqual(type, .amplifyCredentials) + expectation.fulfill() + }, + environment: CredentialEnvironment( + authConfiguration: .identityPools(.testData), + credentialStoreEnvironment: BasicCredentialStoreEnvironment( + amplifyCredentialStoreFactory: { + MockAmplifyCredentialStoreBehavior( + saveCredentialHandler: { codableCredentials in + guard let amplifyCredentials = codableCredentials as? AmplifyCredentials, + case .identityPoolWithFederation(let token, _, _) = amplifyCredentials else { + XCTFail("Expected .identityPoolWithFederation") + return + } + + XCTAssertEqual(token.token, "token") + XCTAssertEqual(token.provider.userPoolProviderName, "provider") + } + ) + }, + legacyKeychainStoreFactory: { _ in + let data = try! JSONEncoder().encode([ + "provider": "token" + ]) + return MockKeychainStoreBehavior( + data: String(decoding: data, as: UTF8.self) + ) + }), + logger: action.log + ) + ) + await fulfillment(of: [expectation], timeout: 1) + } } diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift new file mode 100644 index 0000000000..2b94f7721f --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshHostedUITokensTests.swift @@ -0,0 +1,310 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if os(iOS) || os(macOS) + +@testable import AWSCognitoAuthPlugin +import AWSCognitoIdentityProvider +import AWSPluginsCore +import XCTest + +class RefreshHostedUITokensTests: XCTestCase { + private let tokenResult: [String: Any] = [ + "id_token": AWSCognitoUserPoolTokens.testData.idToken, + "access_token": AWSCognitoUserPoolTokens.testData.accessToken, + "refresh_token": AWSCognitoUserPoolTokens.testData.refreshToken, + "expires_in": 10 + ] + + private var hostedUIEnvironment: HostedUIEnvironment { + BasicHostedUIEnvironment( + configuration: .init( + clientId: "clientId", + oauth: .init( + domain: "cognitodomain", + scopes: ["name"], + signInRedirectURI: "myapp://", + signOutRedirectURI: "myapp://" + ) + ), + hostedUISessionFactory: sessionFactory, + urlSessionFactory: urlSessionMock, + randomStringFactory: mockRandomString + ) + } + + override func setUp() { + let result = try! JSONSerialization.data(withJSONObject: tokenResult) + MockURLProtocol.requestHandler = { _ in + return (HTTPURLResponse(), result) + } + } + + override func tearDown() { + MockURLProtocol.requestHandler = nil + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked with a valid response + /// Then: A RefreshSessionEvent.refreshIdentityInfo is dispatched + func testExecute_withValidResponse_shouldDispatchRefreshEvent() async { + let expectation = expectation(description: "refreshHostedUITokens") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case .refreshIdentityInfo(let data, _) = event.eventType else { + XCTFail("Failed to refresh tokens") + expectation.fulfill() + return + } + + XCTAssertEqual(data.cognitoUserPoolTokens.idToken, self.tokenResult["id_token"] as? String) + XCTAssertEqual(data.cognitoUserPoolTokens.accessToken, self.tokenResult["access_token"] as? String) + XCTAssertEqual(data.cognitoUserPoolTokens.refreshToken, self.tokenResult["refresh_token"] as? String) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked and throws a HostedUIError + /// Then: A RefreshSessionEvent.throwError is dispatched with .service + func testExecute_withHostedUIError_shouldDispatchErrorEvent() async { + let expectedError = HostedUIError.serviceMessage("Something went wrong") + MockURLProtocol.requestHandler = { _ in + throw expectedError + } + + let expectation = expectation(description: "refreshHostedUITokens") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to Service Error") + expectation.fulfill() + return + } + + XCTAssertEqual(error, .service(expectedError)) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked and returns empty data + /// Then: A RefreshSessionEvent.throwError is dispatched with .service + func testExecute_withEmptyData_shouldDispatchErrorEvent() async { + MockURLProtocol.requestHandler = { _ in + return (HTTPURLResponse(), Data()) + } + + let expectation = expectation(description: "refreshHostedUITokens") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to Invalid Tokens") + expectation.fulfill() + return + } + + guard case .service(let serviceError) = error else { + XCTFail("Expected FetchSessionError.service, got \(error)") + expectation.fulfill() + return + } + + + XCTAssertEqual((serviceError as NSError).code, NSPropertyListReadCorruptError) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked and returns data that is invalid for tokens + /// Then: A RefreshSessionEvent.throwError is dispatched with .invalidTokens + func testExecute_withInvalidTokens_shouldDispatchErrorEvent() async { + let result: [String: Any] = [ + "key": "value" + ] + MockURLProtocol.requestHandler = { _ in + return (HTTPURLResponse(), try! JSONSerialization.data(withJSONObject: result)) + } + + let expectation = expectation(description: "refreshHostedUITokens") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to Invalid Tokens") + expectation.fulfill() + return + } + + + XCTAssertEqual(error, .invalidTokens) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked and returns data representing an error + /// Then: A RefreshSessionEvent.throwError is dispatched with .service + func testExecute_withErrorResponse_shouldDispatchErrorEvent() async { + let result: [String: Any] = [ + "error": "Error.", + "error_description": "Something went wrong" + ] + MockURLProtocol.requestHandler = { _ in + return (HTTPURLResponse(), try! JSONSerialization.data(withJSONObject: result)) + } + + let expectation = expectation(description: "refreshHostedUITokens") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to Invalid Tokens") + expectation.fulfill() + return + } + + guard case .service(let serviceError) = error, + case .serviceMessage(let errorMessage) = serviceError as? HostedUIError else { + XCTFail("Expected HostedUIError.serviceMessage, got \(error)") + expectation.fulfill() + return + } + + + XCTAssertEqual(errorMessage, "Error. Something went wrong") + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked without a HostedUIEnvironment + /// Then: A RefreshSessionEvent.throwError is dispatched with .noUserPool + func testExecute_withoutHostedUIEnvironment_shouldDispatchErrorEvent() async { + let expectation = expectation(description: "noHostedUIEnvironment") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to no HostedUIEnvironment") + expectation.fulfill() + return + } + + XCTAssertEqual(error, .noUserPool) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: nil + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A RefreshHostedUITokens action + /// When: execute is invoked without a UserPoolEnvironment + /// Then: A RefreshSessionEvent.throwError is dispatched with .noUserPool + func testExecute_withoutUserPoolEnvironment_shouldDispatchErrorEvent() async { + let expectation = expectation(description: "noUserPoolEnvironment") + let action = RefreshHostedUITokens(existingSignedIndata: .testData) + action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? RefreshSessionEvent, + case let .throwError(error) = event.eventType else { + XCTFail("Expected failure due to no UserPoolEnvironment") + expectation.fulfill() + return + } + + XCTAssertEqual(error, .noUserPool) + expectation.fulfill() + }, + environment: MockInvalidEnvironment() + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + private func identityProviderFactory() throws -> CognitoUserPoolBehavior { + return MockIdentityProvider( + mockInitiateAuthResponse: { _ in + return InitiateAuthOutputResponse( + authenticationResult: .init( + accessToken: "accessTokenNew", + expiresIn: 100, + idToken: "idTokenNew", + refreshToken: "refreshTokenNew") + ) + } + ) + } + + private func urlSessionMock() -> URLSession { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + return URLSession(configuration: configuration) + } + + private func sessionFactory() -> HostedUISessionBehavior { + MockHostedUISession(result: .failure(.cancelled)) + } + + private func mockRandomString() -> RandomStringBehavior { + return MockRandomStringGenerator( + mockString: "mockString", + mockUUID: "mockUUID" + ) + } +} +#endif diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift new file mode 100644 index 0000000000..53afd8af8f --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/InitiateAuthSRP/VerifyDevicePasswordSRPSignatureTests.swift @@ -0,0 +1,177 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import AWSCognitoAuthPlugin +import AWSCognitoIdentityProvider +@testable import AWSPluginsTestCommon +import XCTest + +class VerifyDevicePasswordSRPSignatureTests: XCTestCase { + private var srpClient: MockSRPClientBehavior! + + override func setUp() async throws { + MockSRPClientBehavior.reset() + srpClient = MockSRPClientBehavior() + } + + override func tearDown() { + MockSRPClientBehavior.reset() + srpClient = nil + } + + /// Given: A VerifyDevicePasswordSRP + /// When: signature is invoked + /// Then: a non-empty string is returned + func testSignature_withValidValues_shouldReturnSignature() async { + do { + let signature = try signature() + XCTAssertFalse(signature.isEmpty) + } catch { + XCTFail("Should not throw error: \(error)") + } + } + + /// Given: A VerifyDevicePasswordSRP + /// When: signature is invoked and the srpClient throws an SRPError error when generating a shared secret + /// Then: a .calculation error is thrown + func testSignature_withSRPErrorOnSharedSecret_shouldThrowCalculationError() async { + srpClient.sharedSecret = .failure(SRPError.numberConversion) + do { + try signature() + XCTFail("Should not succeed") + } catch { + guard case .calculation(let srpError) = error as? SignInError else { + XCTFail("Expected SRPError.calculation, got \(error)") + return + } + + XCTAssertEqual(srpError, .numberConversion) + } + } + + /// Given: A VerifyDevicePasswordSRP + /// When: signature is invoked and the srpClient throws a non-SRPError error when generating a shared secret + /// Then: a .configuration error is thrown + func testSignature_withOtherErrorOnSharedSecret_shouldThrowCalculationError() async { + srpClient.sharedSecret = .failure(CancellationError()) + do { + try signature() + XCTFail("Should not succeed") + } catch { + guard case .configuration(let message) = error as? SignInError else { + XCTFail("Expected SRPError.configuration, got \(error)") + return + } + + XCTAssertEqual(message, "Could not calculate shared secret") + } + } + + /// Given: A VerifyDevicePasswordSRP + /// When: signature is invoked and the srpClient throws a SRPError error when generating an authentication key + /// Then: a .calculation error is thrown + func testSignature_withSRPErrorOnAuthenticationKey_shouldThrowCalculationError() async { + MockSRPClientBehavior.authenticationKey = .failure(SRPError.numberConversion) + do { + try signature() + XCTFail("Should not succeed") + } catch { + guard case .calculation(let srpError) = error as? SignInError else { + XCTFail("Expected SRPError.calculation, got \(error)") + return + } + + XCTAssertEqual(srpError, .numberConversion) + } + } + + /// Given: A VerifyDevicePasswordSRP + /// When: signature is invoked and the srpClient throws a non-SRPError error when generating an authentication key + /// Then: a .configuration error is thrown + func testSignature_withOtherErrorOnAuthenticationKey_shouldThrowCalculationError() async { + MockSRPClientBehavior.authenticationKey = .failure(CancellationError()) + do { + try signature() + XCTFail("Should not succeed") + } catch { + guard case .configuration(let message) = error as? SignInError else { + XCTFail("Expected SRPError.configuration, got \(error)") + return + } + + XCTAssertEqual(message, "Could not calculate signature") + } + } + + @discardableResult + private func signature() throws -> String { + let action = VerifyDevicePasswordSRP( + stateData: .testData, + authResponse: InitiateAuthOutputResponse.validTestData + ) + + return try action.signature( + deviceGroupKey: "deviceGroupKey", + deviceKey: "deviceKey", + deviceSecret: "deviceSecret", + saltHex: "saltHex", + secretBlock: "secretBlock".data(using: .utf8) ?? Data(), + serverPublicBHexString: "serverPublicBHexString", + srpClient: srpClient + ) + } +} + +private class MockSRPClientBehavior: SRPClientBehavior { + var kHexValue: String = "kHexValue" + + static func calculateUHexValue( + clientPublicKeyHexValue: String, + serverPublicKeyHexValue: String + ) throws -> String { + return "UHexValue" + } + + static var authenticationKey: Result = .success("AuthenticationKey".data(using: .utf8)!) + static func generateAuthenticationKey( + sharedSecretHexValue: String, + uHexValue: String + ) throws -> Data { + return try authenticationKey.get() + } + + static func reset() { + authenticationKey = .success("AuthenticationKey".data(using: .utf8)!) + } + + func generateClientKeyPair() -> SRPKeys { + return .init( + publicKeyHexValue: "publicKeyHexValue", + privateKeyHexValue: "privateKeyHexValue" + ) + } + + var sharedSecret: Result = .success("SharedSecret") + func calculateSharedSecret( + username: String, + password: String, + saltHexValue: String, + clientPrivateKeyHexValue: String, + clientPublicKeyHexValue: String, + serverPublicKeyHexValue: String + ) throws -> String { + return try sharedSecret.get() + } + + func generateDevicePasswordVerifier( + deviceGroupKey: String, + deviceKey: String, + password: String + ) -> (salt: Data, passwordVerifier: Data) { + return (salt: Data(), passwordVerifier: Data()) + } +} diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift new file mode 100644 index 0000000000..0751703550 --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/SignOut/ShowHostedUISignOutTests.swift @@ -0,0 +1,401 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import AWSCognitoAuthPlugin +import AWSCognitoIdentityProvider +import AWSPluginsCore +import XCTest + +class ShowHostedUISignOutTests: XCTestCase { + private var mockHostedUIResult: Result<[URLQueryItem], HostedUIError>! + private var signOutRedirectURI: String! + + override func setUp() { + signOutRedirectURI = "myapp://" + mockHostedUIResult = .success([.init(name: "key", value: "value")]) + } + + override func tearDown() { + signOutRedirectURI = nil + mockHostedUIResult = nil + } + + /// Given: A ShowHostedUISignOut action with global sign out set to true + /// When: execute is invoked with a success result + /// Then: A .signOutGlobally event is dispatched with a nil error + func testExecute_withGlobalSignOut_andSuccessResult_shouldDispatchSignOutEvent() async { + let expectation = expectation(description: "showHostedUISignOut") + let signInData = SignedInData.testData + let action = ShowHostedUISignOut( + signOutEvent: SignOutEventData(globalSignOut: true), + signInData: signInData + ) + + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let error) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + XCTAssertNil(error) + XCTAssertEqual(data, signInData) + self.validateDebugInformation(signInData: signInData, action: action) + + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action with global sign out set to false + /// When: execute is invoked with a success result + /// Then: A .revokeToken event is dispatched + func testExecute_withLocalSignOut_andSuccessResult_shouldDispatchSignOutEvent() async { + let expectation = expectation(description: "showHostedUISignOut") + let signInData = SignedInData.testData + let action = ShowHostedUISignOut( + signOutEvent: SignOutEventData(globalSignOut: false), + signInData: signInData + ) + + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .revokeToken(let data, let error, let globalSignOutError) = event.eventType else { + XCTFail("Expected SignOutEvent.revokeToken, got \(event)") + expectation.fulfill() + return + } + + XCTAssertNil(error) + XCTAssertNil(globalSignOutError) + XCTAssertEqual(data, signInData) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action + /// When: execute is invoked but fails to create a HostedUI session + /// Then: A .userCancelled event is dispatched + func testExecute_withInvalidResult_shouldDispatchUserCancelledEvent() async { + mockHostedUIResult = .failure(.cancelled) + let signInData = SignedInData.testData + + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + + let expectation = expectation(description: "showHostedUISignOut") + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent else { + XCTFail("Expected SignOutEvent, got \(event)") + expectation.fulfill() + return + } + + XCTAssertEqual(event.eventType, .userCancelled) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action + /// When: execute is invoked but fails to create a HostedUI session with a HostedUIError.signOutURI + /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error + func testExecute_withSignOutURIError_shouldThrowConfigurationError() async { + mockHostedUIResult = .failure(HostedUIError.signOutURI) + let signInData = SignedInData.testData + + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + + let expectation = expectation(description: "showHostedUISignOut") + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let hostedUIError) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + guard let hostedUIError = hostedUIError, + case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else { + XCTFail("Expected AuthError.configuration") + expectation.fulfill() + return + } + + XCTAssertEqual(errorDescription, "Could not create logout URL") + XCTAssertEqual(data, signInData) + XCTAssertNil(serviceError) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action + /// When: execute is invoked but fails to create a HostedUI session with a HostedUIError.invalidContext + /// Then: A .signOutGlobally event is dispatched with a HosterUIError.invalidState error + func testExecute_withInvalidContext_shouldThrowInvalidStateError() async { + mockHostedUIResult = .failure(HostedUIError.invalidContext) + let signInData = SignedInData.testData + + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + + let expectation = expectation(description: "showHostedUISignOut") + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let hostedUIError) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + guard let hostedUIError = hostedUIError, + case .invalidState(let errorDescription, let recoverySuggestion, let serviceError) = hostedUIError.error else { + XCTFail("Expected AuthError.invalidState") + expectation.fulfill() + return + } + + XCTAssertEqual(errorDescription, AuthPluginErrorConstants.hostedUIInvalidPresentation.errorDescription) + XCTAssertEqual(recoverySuggestion, AuthPluginErrorConstants.hostedUIInvalidPresentation.recoverySuggestion) + XCTAssertEqual(data, signInData) + XCTAssertNil(serviceError) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action with an invalid SignOutRedirectURI + /// When: execute is invoked + /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error + func testExecute_withInvalidSignOutURI_shouldThrowConfigurationError() async { + signOutRedirectURI = "invalidURI" + let signInData = SignedInData.testData + + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + + let expectation = expectation(description: "showHostedUISignOut") + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let hostedUIError) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + guard let hostedUIError = hostedUIError, + case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else { + XCTFail("Expected AuthError.configuration") + expectation.fulfill() + return + } + + XCTAssertEqual(errorDescription, "Callback URL could not be retrieved") + XCTAssertEqual(data, signInData) + XCTAssertNil(serviceError) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: hostedUIEnvironment + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action + /// When: execute is invoked with a nil HostedUIEnvironment + /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error + func testExecute_withoutHostedUIEnvironment_shouldThrowConfigurationError() async { + let expectation = expectation(description: "noHostedUIEnvironment") + let signInData = SignedInData.testData + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let hostedUIError) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + guard let hostedUIError = hostedUIError, + case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else { + XCTFail("Expected AuthError.configuration") + expectation.fulfill() + return + } + + XCTAssertEqual(data, signInData) + XCTAssertEqual(errorDescription, AuthPluginErrorConstants.configurationError) + XCTAssertNil(serviceError) + expectation.fulfill() + }, + environment: Defaults.makeDefaultAuthEnvironment( + userPoolFactory: identityProviderFactory, + hostedUIEnvironment: nil + ) + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + /// Given: A ShowHostedUISignOut action + /// When: execute is invoked with an invalid environment + /// Then: A .signOutGlobally event is dispatched with a HosterUIError.configuration error + func testExecute_withInvalidUserPoolEnvironment_shouldThrowConfigurationError() async { + let expectation = expectation(description: "invalidUserPoolEnvironment") + let signInData = SignedInData.testData + let action = ShowHostedUISignOut( + signOutEvent: .testData, + signInData: signInData + ) + await action.execute( + withDispatcher: MockDispatcher { event in + guard let event = event as? SignOutEvent, + case .signOutGlobally(let data, let hostedUIError) = event.eventType else { + XCTFail("Expected SignOutEvent.signOutGlobally, got \(event)") + expectation.fulfill() + return + } + + guard let hostedUIError = hostedUIError, + case .configuration(let errorDescription, _, let serviceError) = hostedUIError.error else { + XCTFail("Expected AuthError.configuration") + expectation.fulfill() + return + } + + XCTAssertEqual(data, signInData) + XCTAssertEqual(errorDescription, AuthPluginErrorConstants.configurationError) + XCTAssertNil(serviceError) + expectation.fulfill() + }, + environment: MockInvalidEnvironment() + ) + + await fulfillment(of: [expectation], timeout: 1) + } + + private func validateDebugInformation(signInData: SignedInData, action: ShowHostedUISignOut) { + XCTAssertFalse(action.debugDescription.isEmpty) + guard let signInDataDictionary = action.debugDictionary["signInData"] as? [String: Any] else { + XCTFail("Expected signInData dictionary") + return + } + XCTAssertEqual(signInDataDictionary.count, signInData.debugDictionary.count) + + for key in signInDataDictionary.keys { + guard let left = signInDataDictionary[key] as? any Equatable, + let right = signInData.debugDictionary[key] as? any Equatable else { + continue + } + XCTAssertTrue(left.isEqual(to: right)) + } + } + + private var hostedUIEnvironment: HostedUIEnvironment { + BasicHostedUIEnvironment( + configuration: .init( + clientId: "clientId", + oauth: .init( + domain: "cognitodomain", + scopes: ["name"], + signInRedirectURI: "myapp://", + signOutRedirectURI: signOutRedirectURI + ) + ), + hostedUISessionFactory: { + MockHostedUISession(result: self.mockHostedUIResult) + }, + urlSessionFactory: { + URLSession.shared + }, + randomStringFactory: { + MockRandomStringGenerator( + mockString: "mockString", + mockUUID: "mockUUID" + ) + } + ) + } + + private func identityProviderFactory() throws -> CognitoUserPoolBehavior { + return MockIdentityProvider( + mockInitiateAuthResponse: { _ in + return InitiateAuthOutputResponse( + authenticationResult: .init( + accessToken: "accessTokenNew", + expiresIn: 100, + idToken: "idTokenNew", + refreshToken: "refreshTokenNew") + ) + } + ) + } +} + +private extension Equatable { + func isEqual(to other: any Equatable) -> Bool { + guard let other = other as? Self else { + return false + } + return self == other + } +} diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift new file mode 100644 index 0000000000..63e72819b2 --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/CognitoASFTests/CognitoUserPoolASFTests.swift @@ -0,0 +1,62 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import AWSCognitoAuthPlugin +import XCTest + +class CognitoUserPoolASFTests: XCTestCase { + private var userPool: CognitoUserPoolASF! + + override func setUp() { + userPool = CognitoUserPoolASF() + } + + override func tearDown() { + userPool = nil + } + + /// Given: A CognitoUserPoolASF + /// When: userContextData is invoked + /// Then: A non-empty string is returned + func testUserContextData_shouldReturnData() throws { + let result = try userPool.userContextData( + for: "TestUser", + deviceInfo: ASFDeviceInfo(id: "mockedDevice"), + appInfo: ASFAppInfo(), + configuration: .testData + ) + XCTAssertFalse(result.isEmpty) + } + + /// Given: A CognitoUserPoolASF + /// When: calculateSecretHash is invoked + /// Then: A non-empty string is returned + func testCalculateSecretHash_shouldReturnHash() throws { + let result = try userPool.calculateSecretHash( + contextJson: "contextJson", + clientId: "clientId" + ) + XCTAssertFalse(result.isEmpty) + } + + /// Given: A CognitoUserPoolASF + /// When: calculateSecretHash is invoked with a clientId that cannot be parsed + /// Then: A ASFError.hashKey is thrown + func testCalculateSecretHash_withInvalidClientId_shouldThrowHashKeyError() { + do { + let result = try userPool.calculateSecretHash( + contextJson: "contextJson", + clientId: "🕺🏼" // This string cannot be represented using .ascii, so it will throw an error + ) + XCTFail("Expected ASFError.hashKey, got \(result)") + } catch let error as ASFError { + XCTAssertEqual(error, .hashKey) + } catch { + XCTFail("Expected ASFError.hashKey, for \(error)") + } + } +} diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift index d385c628d8..acf3ec1c0d 100644 --- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/EscapeHatchTests.swift @@ -6,143 +6,116 @@ // import XCTest -@testable import Amplify +@testable import func AmplifyTestCommon.XCTAssertThrowFatalError +import enum Amplify.JSONValue @testable import AWSCognitoAuthPlugin -class EscapeHatchTests: XCTestCase { - - let skipBrokenTests = true - - override func tearDown() async throws { - await Amplify.reset() - } +class EscapeHatchTests: XCTestCase { /// Test escape hatch with valid config for user pool and identity pool /// - /// - Given: Given valid config for user pool and identity pool + /// - Given: A AWSCognitoAuthPlugin configured with User Pool and Identity Pool /// - When: - /// - I configure auth with the given configuration and call getEscapeHatch + /// - I call getEscapeHatch /// - Then: - /// - I should get back user pool and identity pool clients + /// - I should get back both the User Pool and Identity Pool clients /// func testEscapeHatchWithUserPoolAndIdentityPool() throws { - if skipBrokenTests { - throw XCTSkip("TODO: fix this test") - } - - let plugin = AWSCognitoAuthPlugin() - try Amplify.add(plugin: plugin) - - let expectation = expectation(description: "Should get service") - let categoryConfig = AuthCategoryConfiguration(plugins: [ - "awsCognitoAuthPlugin": [ - "CredentialsProvider": ["CognitoIdentity": ["Default": - ["PoolId": "xx", - "Region": "us-east-1"] - ]], - "CognitoUserPool": ["Default": [ + let configuration: JSONValue = [ + "CredentialsProvider": [ + "CognitoIdentity": [ + "Default": [ + "PoolId": "xx", + "Region": "us-east-1" + ] + ] + ], + "CognitoUserPool": [ + "Default": [ "PoolId": "xx", "Region": "us-east-1", "AppClientId": "xx", - "AppClientSecret": "xx"]] + "AppClientSecret": "xx" + ] ] - ]) - let amplifyConfig = AmplifyConfiguration(auth: categoryConfig) - try Amplify.configure(amplifyConfig) - let internalPlugin = try Amplify.Auth.getPlugin( - for: "awsCognitoAuthPlugin" - ) as! AWSCognitoAuthPlugin - let service = internalPlugin.getEscapeHatch() - switch service { - case .userPool: - XCTFail("Should return userPoolAndIdentityPool") - case .identityPool: - XCTFail("Should return userPoolAndIdentityPool") - case .userPoolAndIdentityPool: - expectation.fulfill() + ] + let plugin = AWSCognitoAuthPlugin() + try plugin.configure(using: configuration) + let escapeHatch = plugin.getEscapeHatch() + guard case .userPoolAndIdentityPool = escapeHatch else { + XCTFail("Expected .userPoolAndIdentityPool, got \(escapeHatch)") + return } - wait(for: [expectation], timeout: 1) } /// Test escape hatch with valid config for only identity pool /// - /// - Given: Given valid config for only identity pool + /// - Given: A AWSCognitoAuthPlugin configured with only Identity Pool /// - When: - /// - I configure auth with the given configuration and invoke getEscapeHatch + /// - I call getEscapeHatch /// - Then: - /// - I should get back only identity pool client + /// - I should get back only the Identity Pool client /// func testEscapeHatchWithOnlyIdentityPool() throws { - if skipBrokenTests { - throw XCTSkip("TODO: fix this test") - } - - let plugin = AWSCognitoAuthPlugin() - try Amplify.add(plugin: plugin) - - let categoryConfig = AuthCategoryConfiguration(plugins: [ - "awsCognitoAuthPlugin": [ - "CredentialsProvider": ["CognitoIdentity": ["Default": - ["PoolId": "cc", - "Region": "us-east-1"] - ]] + let configuration: JSONValue = [ + "CredentialsProvider": [ + "CognitoIdentity": [ + "Default": [ + "PoolId": "xx", + "Region": "us-east-1" + ] + ] ] - ]) - let amplifyConfig = AmplifyConfiguration(auth: categoryConfig) - try Amplify.configure(amplifyConfig) - let internalPlugin = try Amplify.Auth.getPlugin( - for: "awsCognitoAuthPlugin" - ) as! AWSCognitoAuthPlugin - let service = internalPlugin.getEscapeHatch() - switch service { - case .userPool: - XCTFail("Should return identityPool") - case .userPoolAndIdentityPool: - XCTFail("Should return identityPool") - case .identityPool: - print("") + ] + let plugin = AWSCognitoAuthPlugin() + try plugin.configure(using: configuration) + let escapeHatch = plugin.getEscapeHatch() + guard case .identityPool = escapeHatch else { + XCTFail("Expected .identityPool, got \(escapeHatch)") + return } } /// Test escape hatch with valid config for only user pool /// - /// - Given: Given valid config for only user pool + /// - Given: A AWSCognitoAuthPlugin configured with only User Pool /// - When: - /// - I configure auth with the given configuration and invoke getEscapeHatch + /// - I call getEscapeHatch /// - Then: - /// - I should get the Cognito User pool client + /// - I should get only the User Pool client /// func testEscapeHatchWithOnlyUserPool() throws { - if skipBrokenTests { - throw XCTSkip("TODO: fix this test") - } - - let plugin = AWSCognitoAuthPlugin() - try Amplify.add(plugin: plugin) - - let categoryConfig = AuthCategoryConfiguration(plugins: [ - "awsCognitoAuthPlugin": [ - "CognitoUserPool": ["Default": [ + let configuration: JSONValue = [ + "CognitoUserPool": [ + "Default": [ "PoolId": "xx", "Region": "us-east-1", "AppClientId": "xx", - "AppClientSecret": "xx"]] + "AppClientSecret": "xx" + ] ] - ]) - let amplifyConfig = AmplifyConfiguration(auth: categoryConfig) - try Amplify.configure(amplifyConfig) - let internalPlugin = try Amplify.Auth.getPlugin( - for: "awsCognitoAuthPlugin" - ) as! AWSCognitoAuthPlugin - let service = internalPlugin.getEscapeHatch() - switch service { - case .userPool: - break - case .identityPool: - XCTFail("Should return userPool") - case .userPoolAndIdentityPool: - XCTFail("Should return userPool") + ] + let plugin = AWSCognitoAuthPlugin() + try plugin.configure(using: configuration) + let escapeHatch = plugin.getEscapeHatch() + guard case .userPool = escapeHatch else { + XCTFail("Expected .userPool, got \(escapeHatch)") + return + } + } + + /// Test escape hatch without a valid configuration + /// + /// - Given: A AWSCognitoAuthPlugin plugin without being configured + /// - When: + /// - I call getEscapeHatch + /// - Then: + /// - A fatalError is thrown + /// + func testEscapeHatchWithoutConfiguration() throws { + let plugin = AWSCognitoAuthPlugin() + try XCTAssertThrowFatalError { + _ = plugin.getEscapeHatch() } } - } diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift index fdb8862284..43db976492 100644 --- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/AWSAuthCognitoSessionTests.swift @@ -26,8 +26,7 @@ class AWSAuthCognitoSessionTests: XCTestCase { let error = AuthError.unknown("", nil) let tokens = AWSCognitoUserPoolTokens(idToken: CognitoAuthTestHelper.buildToken(for: tokenData), accessToken: CognitoAuthTestHelper.buildToken(for: tokenData), - refreshToken: "refreshToken", - expiresIn: 121) + refreshToken: "refreshToken") let session = AWSAuthCognitoSession(isSignedIn: true, identityIdResult: .failure(error), @@ -53,8 +52,7 @@ class AWSAuthCognitoSessionTests: XCTestCase { let error = AuthError.unknown("", nil) let tokens = AWSCognitoUserPoolTokens(idToken: CognitoAuthTestHelper.buildToken(for: tokenData), accessToken: CognitoAuthTestHelper.buildToken(for: tokenData), - refreshToken: "refreshToken", - expiresIn: 121) + refreshToken: "refreshToken") let session = AWSAuthCognitoSession(isSignedIn: true, identityIdResult: .failure(error), @@ -65,4 +63,156 @@ class AWSAuthCognitoSessionTests: XCTestCase { XCTAssertFalse(cognitoTokens.doesExpire()) } + /// Given: An AWSAuthCognitoSession with a valid AWSCognitoUserPoolTokens + /// When: getUserSub is invoked + /// Then: The "sub" from the token data should be returned + func testGetUserSub_shouldReturnResult() { + let tokenData = [ + "sub": "1234567890", + "name": "John Doe", + "iat": "1516239022", + "exp": String(Date(timeIntervalSinceNow: 121).timeIntervalSince1970) + ] + + let error = AuthError.unknown("", nil) + let tokens = AWSCognitoUserPoolTokens( + idToken: CognitoAuthTestHelper.buildToken(for: tokenData), + accessToken: CognitoAuthTestHelper.buildToken(for: tokenData), + refreshToken: "refreshToken" + ) + + let session = AWSAuthCognitoSession( + isSignedIn: true, + identityIdResult: .failure(error), + awsCredentialsResult: .failure(error), + cognitoTokensResult: .success(tokens) + ) + + guard case .success(let userSub) = session.getUserSub() else { + XCTFail("Unable to retrieve userSub") + return + } + XCTAssertEqual(userSub, "1234567890") + } + + /// Given: An AWSAuthCognitoSession with a AWSCognitoUserPoolTokens that does not include a "sub" attribute + /// When: getUserSub is invoked + /// Then: A .failure with AuthError.unknown error is returned + func testGetUserSub_withoutSub_shouldReturnError() { + let tokenData = [ + "name": "John Doe", + "iat": "1516239022", + "exp": String(Date(timeIntervalSinceNow: 121).timeIntervalSince1970) + ] + + let error = AuthError.unknown("", nil) + let tokens = AWSCognitoUserPoolTokens( + idToken: CognitoAuthTestHelper.buildToken(for: tokenData), + accessToken: CognitoAuthTestHelper.buildToken(for: tokenData), + refreshToken: "refreshToken" + ) + + let session = AWSAuthCognitoSession( + isSignedIn: true, + identityIdResult: .failure(error), + awsCredentialsResult: .failure(error), + cognitoTokensResult: .success(tokens) + ) + + guard case .failure(let error) = session.getUserSub(), + case .unknown(let errorDescription, _) = error else { + XCTFail("Expected AuthError.unknown") + return + } + + XCTAssertEqual(errorDescription, "Could not retreive user sub from the fetched Cognito tokens.") + } + + /// Given: An AWSAuthCognitoSession that is signed out + /// When: getUserSub is invoked + /// Then: A .failure with AuthError.signedOut error is returned + func testGetUserSub_signedOut_shouldReturnError() { + let error = AuthError.signedOut("", "", nil) + let session = AWSAuthCognitoSession( + isSignedIn: false, + identityIdResult: .failure(error), + awsCredentialsResult: .failure(error), + cognitoTokensResult: .failure(error) + ) + + guard case .failure(let error) = session.getUserSub(), + case .signedOut(let errorDescription, let recoverySuggestion, _) = error else { + XCTFail("Expected AuthError.signedOut") + return + } + + XCTAssertEqual(errorDescription, AuthPluginErrorConstants.userSubSignOutError.errorDescription) + XCTAssertEqual(recoverySuggestion, AuthPluginErrorConstants.userSubSignOutError.recoverySuggestion) + } + + /// Given: An AWSAuthCognitoSession that has a service error + /// When: getUserSub is invoked + /// Then: A .failure with AuthError.signedOut error is returned + func testGetUserSub_serviceError_shouldReturnError() { + let serviceError = AuthError.service("Something went wrong", "Try again", nil) + let session = AWSAuthCognitoSession( + isSignedIn: false, + identityIdResult: .failure(serviceError), + awsCredentialsResult: .failure(serviceError), + cognitoTokensResult: .failure(serviceError) + ) + + guard case .failure(let error) = session.getUserSub() else { + XCTFail("Expected AuthError.signedOut") + return + } + + XCTAssertEqual(error, serviceError) + } + + /// Given: An AuthAWSCognitoCredentials and an AWSCognitoUserPoolTokens instance + /// When: Two AWSAuthCognitoSession are created from the same values + /// Then: The two AWSAuthCognitoSession are considered equal + func testSessionsAreEqual() { + let expiration = Date(timeIntervalSinceNow: 121) + let tokenData = [ + "sub": "1234567890", + "name": "John Doe", + "iat": "1516239022", + "exp": String(expiration.timeIntervalSince1970) + ] + + let credentials = AuthAWSCognitoCredentials( + accessKeyId: "accessKeyId", + secretAccessKey: "secretAccessKey", + sessionToken: "sessionToken", + expiration: expiration + ) + + let tokens = AWSCognitoUserPoolTokens( + idToken: CognitoAuthTestHelper.buildToken(for: tokenData), + accessToken: CognitoAuthTestHelper.buildToken(for: tokenData), + refreshToken: "refreshToken" + ) + + let session1 = AWSAuthCognitoSession( + isSignedIn: true, + identityIdResult: .success("identityId"), + awsCredentialsResult: .success(credentials), + cognitoTokensResult: .success(tokens) + ) + + let session2 = AWSAuthCognitoSession( + isSignedIn: true, + identityIdResult: .success("identityId"), + awsCredentialsResult: .success(credentials), + cognitoTokensResult: .success(tokens) + ) + + XCTAssertEqual(session1, session2) + XCTAssertEqual(session1.debugDictionary.count, session2.debugDictionary.count) + for key in session1.debugDictionary.keys where (key != "AWS Credentials" && key != "cognitoTokens") { + XCTAssertEqual(session1.debugDictionary[key] as? String, session2.debugDictionary[key] as? String) + } + } } diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift new file mode 100644 index 0000000000..3909827f36 --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/HostedUIASWebAuthenticationSessionTests.swift @@ -0,0 +1,248 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if os(iOS) || os(macOS) +import Amplify +import AuthenticationServices +@testable import AWSCognitoAuthPlugin +import XCTest + +class HostedUIASWebAuthenticationSessionTests: XCTestCase { + private var session: HostedUIASWebAuthenticationSession! + private var factory: ASWebAuthenticationSessionFactory! + + override func setUp() { + session = HostedUIASWebAuthenticationSession() + factory = ASWebAuthenticationSessionFactory() + session.authenticationSessionFactory = factory.createSession(url:callbackURLScheme:completionHandler:) + } + + override func tearDown() { + session = nil + factory = nil + } + + /// Given: A HostedUIASWebAuthenticationSession + /// When: showHostedUI is invoked and the session factory returns a URL with query items + /// Then: An array of query items should be returned + func testShowHostedUI_withUrlInCallback_withQueryItems_shouldReturnQueryItems() { + let expectation = expectation(description: "showHostedUI") + factory.mockedURL = createURL(queryItems: [.init(name: "name", value: "value")]) + + session.showHostedUI() { result in + do { + let queryItems = try result.get() + XCTAssertEqual(queryItems.count, 1) + XCTAssertEqual(queryItems.first?.name, "name") + XCTAssertEqual(queryItems.first?.value, "value") + } catch { + XCTFail("Expected .success(queryItems), got \(result)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: A HostedUIASWebAuthenticationSession + /// When: showHostedUI is invoked and the session factory returns a URL without query items + /// Then: An empty array should be returned + func testShowHostedUI_withUrlInCallback_withoutQueryItems_shouldReturnEmptyQueryItems() { + let expectation = expectation(description: "showHostedUI") + factory.mockedURL = createURL() + + session.showHostedUI() { result in + do { + let queryItems = try result.get() + XCTAssertTrue(queryItems.isEmpty) + } catch { + XCTFail("Expected .success(queryItems), got \(result)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: A HostedUIASWebAuthenticationSession + /// When: showHostedUI is invoked and the session factory returns a URL with query items representing errors + /// Then: A HostedUIError.serviceMessage should be returned + func testShowHostedUI_withUrlInCallback_withErrorInQueryItems_shouldReturnServiceMessageError() { + let expectation = expectation(description: "showHostedUI") + factory.mockedURL = createURL( + queryItems: [ + .init(name: "error", value: "Error."), + .init(name: "error_description", value: "Something went wrong") + ] + ) + + session.showHostedUI() { result in + do { + _ = try result.get() + XCTFail("Expected failure(.serviceMessage), got \(result)") + } catch let error as HostedUIError { + if case .serviceMessage(let message) = error { + XCTAssertEqual(message, "Error. Something went wrong") + } else { + XCTFail("Expected HostedUIError.serviceMessage, got \(error)") + } + } catch { + XCTFail("Expected HostedUIError.serviceMessage, got \(error)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: A HostedUIASWebAuthenticationSession + /// When: showHostedUI is invoked and the session factory returns ASWebAuthenticationSessionErrors + /// Then: A HostedUIError corresponding to the error code should be returned + func testShowHostedUI_withASWebAuthenticationSessionErrors_shouldReturnRightError() { + let errorMap: [ASWebAuthenticationSessionError.Code: HostedUIError] = [ + .canceledLogin: .cancelled, + .presentationContextNotProvided: .invalidContext, + .presentationContextInvalid: .invalidContext + ] + + let errorCodes: [ASWebAuthenticationSessionError.Code] = [ + .canceledLogin, + .presentationContextNotProvided, + .presentationContextInvalid, + .init(rawValue: 500)! + ] + + for code in errorCodes { + factory.mockedError = ASWebAuthenticationSessionError(code) + let expectedError = errorMap[code] ?? .unknown + let expectation = expectation(description: "showHostedUI for error \(code)") + session.showHostedUI() { result in + do { + _ = try result.get() + XCTFail("Expected failure(.\(expectedError)), got \(result)") + } catch let error as HostedUIError { + XCTAssertEqual(error, expectedError) + } catch { + XCTFail("Expected HostedUIError.\(expectedError), got \(error)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + } + + /// Given: A HostedUIASWebAuthenticationSession + /// When: showHostedUI is invoked and the session factory returns an error + /// Then: A HostedUIError.unknown should be returned + func testShowHostedUI_withOtherError_shouldReturnUnknownError() { + factory.mockedError = CancellationError() + let expectation = expectation(description: "showHostedUI") + session.showHostedUI() { result in + do { + _ = try result.get() + XCTFail("Expected failure(.unknown), got \(result)") + } catch let error as HostedUIError { + XCTAssertEqual(error, .unknown) + } catch { + XCTFail("Expected HostedUIError.unknown, got \(error)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + private func createURL(queryItems: [URLQueryItem] = []) -> URL { + var components = URLComponents(string: "https://test.com")! + components.queryItems = queryItems + return components.url! + } +} + +class ASWebAuthenticationSessionFactory { + var mockedURL: URL? + var mockedError: Error? + + func createSession( + url URL: URL, + callbackURLScheme: String?, + completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler + ) -> ASWebAuthenticationSession { + let session = MockASWebAuthenticationSession( + url: URL, + callbackURLScheme: callbackURLScheme, + completionHandler: completionHandler + ) + session.mockedURL = mockedURL + session.mockedError = mockedError + return session + } +} + +class MockASWebAuthenticationSession: ASWebAuthenticationSession { + private var callback: ASWebAuthenticationSession.CompletionHandler + override init( + url URL: URL, + callbackURLScheme: String?, + completionHandler: @escaping ASWebAuthenticationSession.CompletionHandler + ) { + self.callback = completionHandler + super.init( + url: URL, + callbackURLScheme: callbackURLScheme, + completionHandler: completionHandler + ) + } + + var mockedURL: URL? = nil + var mockedError: Error? = nil + override func start() -> Bool { + callback(mockedURL, mockedError) + return presentationContextProvider?.presentationAnchor(for: self) != nil + } +} + +extension HostedUIASWebAuthenticationSession { + func showHostedUI(callback: @escaping (Result<[URLQueryItem], HostedUIError>) -> Void) { + showHostedUI( + url: URL(string: "https://test.com")!, + callbackScheme: "https", + inPrivate: false, + presentationAnchor: nil, + callback: callback) + } +} +#else + +@testable import AWSCognitoAuthPlugin +import XCTest + +class HostedUIASWebAuthenticationSessionTests: XCTestCase { + func testShowHostedUI_shouldThrowServiceError() { + let expectation = expectation(description: "showHostedUI") + let session = HostedUIASWebAuthenticationSession() + session.showHostedUI( + url: URL(string: "https://test.com")!, + callbackScheme: "https", + inPrivate: false, + presentationAnchor: nil + ) { result in + do { + _ = try result.get() + XCTFail("Expected failure(.serviceMessage), got \(result)") + } catch let error as HostedUIError { + if case .serviceMessage(let message) = error { + XCTAssertEqual(message, "HostedUI is only available in iOS and macOS") + } else { + XCTFail("Expected HostedUIError.serviceMessage, got \(error)") + } + } catch { + XCTFail("Expected HostedUIError.serviceMessage, got \(error)") + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } +} + +#endif diff --git a/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift b/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift new file mode 100644 index 0000000000..14963a0cc1 --- /dev/null +++ b/AmplifyPlugins/Notifications/Push/Tests/AWSPinpointPushNotificationsPluginUnitTests/ErrorPushNotificationsTests.swift @@ -0,0 +1,105 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +import AWSClientRuntime +import AwsCommonRuntimeKit +import AWSPinpoint +@testable import AWSPinpointPushNotificationsPlugin +import ClientRuntime +import Foundation +import XCTest + +class ErrorPushNotificationsTests: XCTestCase { + /// Given: A NSError error + /// When: pushNotificationsError is invoked + /// Then: An .unknown error is returned + func testPushNotificationsError_withUnknownError_shouldReturnUnknownError() { + let error = NSError(domain: "MyError", code: 1234) + let pushNotificationsError = error.pushNotificationsError + switch pushNotificationsError { + case .unknown(let errorDescription, let underlyingError): + XCTAssertEqual(errorDescription, "An unknown error occurred") + XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription) + default: + XCTFail("Expected error of type .unknown, got \(pushNotificationsError)") + } + } + + /// Given: A NSError error with a connectivity-related error code + /// When: pushNotificationsError is invoked + /// Then: A .network error is returned + func testPushNotificationsError_withConnectivityError_shouldReturnNetworkError() { + let error = NSError(domain: "ConnectivityError", code: NSURLErrorNotConnectedToInternet) + let pushNotificationsError = error.pushNotificationsError + switch pushNotificationsError { + case .network(let errorDescription, let recoverySuggestion, let underlyingError): + XCTAssertEqual(errorDescription, PushNotificationsPluginErrorConstants.deviceOffline.errorDescription) + XCTAssertEqual(recoverySuggestion, PushNotificationsPluginErrorConstants.deviceOffline.recoverySuggestion) + XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription) + default: + XCTFail("Expected error of type .network, got \(pushNotificationsError)") + } + } + + /// Given: An Error defined by the SDK + /// When: pushNotificationsError is invoked + /// Then: A .service error is returned + func testPushNotificationError_withServiceError_shouldReturnServiceError() { + let errors: [(String, PushNotificationsErrorConvertible & Error)] = [ + ("BadRequestException", BadRequestException(message: "BadRequestException")), + ("InternalServerErrorException", InternalServerErrorException(message: "InternalServerErrorException")), + ("ForbiddenException", ForbiddenException(message: "ForbiddenException")), + ("MethodNotAllowedException", MethodNotAllowedException(message: "MethodNotAllowedException")), + ("NotFoundException", NotFoundException(message: "NotFoundException")), + ("PayloadTooLargeException", PayloadTooLargeException(message: "PayloadTooLargeException")), + ("TooManyRequestsException", TooManyRequestsException(message: "TooManyRequestsException")) + ] + + for (expectedMessage, error) in errors { + let pushNotificationsError = error.pushNotificationsError + switch pushNotificationsError { + case .service(let errorDescription, let recoverySuggestion, let underlyingError): + XCTAssertEqual(errorDescription, expectedMessage) + XCTAssertEqual(recoverySuggestion, PushNotificationsPluginErrorConstants.nonRetryableServiceError.recoverySuggestion) + XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription) + default: + XCTFail("Expected error of type .service, got \(pushNotificationsError)") + } + } + } + + /// Given: An UnknownAWSHTTPServiceError + /// When: pushNotificationsError is invoked + /// Then: A .unknown error is returned + func testPushNotificationError_withUnknownAWSHTTPServiceError_shouldReturnUnknownError() { + let error = UnknownAWSHTTPServiceError(httpResponse: .init(body: .none, statusCode: .accepted), message: "UnknownAWSHTTPServiceError", requestID: nil, typeName: nil) + let pushNotificationsError = error.pushNotificationsError + switch pushNotificationsError { + case .unknown(let errorDescription, let underlyingError): + XCTAssertEqual(errorDescription, "UnknownAWSHTTPServiceError") + XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription) + default: + XCTFail("Expected error of type .unknown, got \(pushNotificationsError)") + } + } + + /// Given: A CommonRunTimeError.crtError + /// When: pushNotificationsError is invoked + /// Then: A .unknown error is returned + func testPushNotificationError_withCommonRunTimeError_shouldReturnUnknownError() { + let error = CommonRunTimeError.crtError(.init(code: 12345)) + let pushNotificationsError = error.pushNotificationsError + switch pushNotificationsError { + case .unknown(let errorDescription, let underlyingError): + XCTAssertEqual(errorDescription, "Unknown Error Code") + XCTAssertEqual(error.localizedDescription, underlyingError?.localizedDescription) + default: + XCTFail("Expected error of type .unknown, got \(pushNotificationsError)") + } + } +} diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift index 969fd6475f..5fd65c6ec0 100644 --- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift +++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Adapter.swift @@ -18,10 +18,10 @@ import AWSClientRuntime /// and allows for mocking in unit tests. The methods contain no other logic other than calling the /// same method using the AWSS3 instance. class AWSS3Adapter: AWSS3Behavior { - let awsS3: S3Client + let awsS3: S3ClientProtocol let config: S3Client.S3ClientConfiguration - init(_ awsS3: S3Client, config: S3Client.S3ClientConfiguration) { + init(_ awsS3: S3ClientProtocol, config: S3Client.S3ClientConfiguration) { self.awsS3 = awsS3 self.config = config } @@ -161,7 +161,7 @@ class AWSS3Adapter: AWSS3Behavior { /// Instance of S3 service. /// - Returns: S3 service instance. - func getS3() -> S3Client { + func getS3() -> S3ClientProtocol { return awsS3 } } diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift index 7319878805..400ca3eb6c 100644 --- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift +++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Dependency/AWSS3Behavior.swift @@ -35,7 +35,7 @@ protocol AWSS3Behavior { func abortMultipartUpload(_ request: AWSS3AbortMultipartUploadRequest, completion: @escaping (Result) -> Void) // Gets a client for AWS S3 Service. - func getS3() -> S3Client + func getS3() -> S3ClientProtocol } diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift index c26fbc2c79..d31b9e588e 100644 --- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift +++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Service/Storage/AWSS3StorageService.swift @@ -54,6 +54,7 @@ class AWSS3StorageService: AWSS3StorageServiceBehavior, StorageServiceProxy { httpClientEngineProxy: HttpClientEngineProxy? = nil, storageConfiguration: StorageConfiguration = .default, storageTransferDatabase: StorageTransferDatabase = .default, + fileSystem: FileSystem = .default, sessionConfiguration: URLSessionConfiguration? = nil, delegateQueue: OperationQueue? = nil, logger: Logger = storageLogger) throws { @@ -97,7 +98,9 @@ class AWSS3StorageService: AWSS3StorageServiceBehavior, StorageServiceProxy { self.init(authService: authService, storageConfiguration: storageConfiguration, storageTransferDatabase: storageTransferDatabase, + fileSystem: fileSystem, sessionConfiguration: _sessionConfiguration, + logger: logger, s3Client: s3Client, preSignedURLBuilder: preSignedURLBuilder, awsS3: awsS3, diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift index d335f31805..fc51016cb9 100644 --- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift +++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageMultipartUploadSession.swift @@ -44,14 +44,6 @@ class StorageMultipartUploadSession { private let transferTask: StorageTransferTask - private var contentType: String? { - transferTask.contentType - } - - private var requestHeaders: RequestHeaders? { - transferTask.requestHeaders - } - init(client: StorageMultipartUploadClient, bucket: String, key: String, diff --git a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift index 1cd7e95385..1ecf8894db 100644 --- a/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift +++ b/AmplifyPlugins/Storage/Sources/AWSS3StoragePlugin/Support/Internal/StorageTransferTask.swift @@ -173,10 +173,6 @@ class StorageTransferTask { } } - private var cancelled: Bool { - status == .cancelled - } - var isFailed: Bool { status == .error } @@ -324,7 +320,7 @@ class StorageTransferTask { logger.warn("Unable to complete after cancelled") return } - guard _status == .completed else { + guard _status != .completed else { logger.warn("Task is already completed") return } diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift new file mode 100644 index 0000000000..4cf455e494 --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Dependency/AWSS3AdapterTests.swift @@ -0,0 +1,754 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +@testable import AWSS3StoragePlugin +import AWSS3 +import XCTest + +class AWSS3AdapterTests: XCTestCase { + private var adapter: AWSS3Adapter! + private var awsS3: S3ClientMock! + + override func setUp() { + awsS3 = S3ClientMock() + adapter = AWSS3Adapter( + awsS3, + config: try! S3Client.S3ClientConfiguration( + region: "us-east-1" + ) + ) + } + + override func tearDown() { + adapter = nil + awsS3 = nil + } + + /// Given: An AWSS3Adapter + /// When: deleteObject is invoked and the s3 client returns success + /// Then: A .success result is returned + func testDeleteObject_withSuccess_shouldSucceed() { + let deleteExpectation = expectation(description: "Delete Object") + adapter.deleteObject(.init(bucket: "bucket", key: "key")) { result in + XCTAssertEqual(self.awsS3.deleteObjectCount, 1) + guard case .success = result else { + XCTFail("Expected success") + return + } + deleteExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: deleteObject is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testDeleteObject_withError_shouldFail() { + let deleteExpectation = expectation(description: "Delete Object") + awsS3.deleteObjectResult = .failure(StorageError.keyNotFound("InvalidKey", "", "", nil)) + adapter.deleteObject(.init(bucket: "bucket", key: "key")) { result in + XCTAssertEqual(self.awsS3.deleteObjectCount, 1) + guard case .failure(let error) = result, + case .keyNotFound(let key, _, _, _) = error else { + XCTFail("Expected StorageError.keyNotFound") + return + } + XCTAssertEqual(key, "InvalidKey") + deleteExpectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: listObjectsV2 is invoked and the s3 client returns a list of objects + /// Then: A .success result is returned containing the corresponding list items + func testListObjectsV2_withSuccess_shouldSucceed() { + let listExpectation = expectation(description: "List Objects") + awsS3.listObjectsV2Result = .success(ListObjectsV2OutputResponse( + contents: [ + .init(eTag: "one", key: "prefix/key1", lastModified: .init()), + .init(eTag: "two", key: "prefix/key2", lastModified: .init()) + ] + )) + adapter.listObjectsV2(.init( + bucket: "bucket", + prefix: "prefix/" + )) { result in + XCTAssertEqual(self.awsS3.listObjectsV2Count, 1) + guard case .success(let response) = result else { + XCTFail("Expected success") + return + } + XCTAssertEqual(response.items.count, 2) + XCTAssertTrue(response.items.contains(where: { $0.key == "key1" && $0.eTag == "one" })) + XCTAssertTrue(response.items.contains(where: { $0.key == "key2" && $0.eTag == "two" })) + listExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: listObjectsV2 is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testListObjectsV2_withError_shouldFail() { + let listExpectation = expectation(description: "List Objects") + awsS3.listObjectsV2Result = .failure(StorageError.accessDenied("AccessDenied", "", nil)) + adapter.listObjectsV2(.init( + bucket: "bucket", + prefix: "prefix" + )) { result in + XCTAssertEqual(self.awsS3.listObjectsV2Count, 1) + guard case .failure(let error) = result, + case .accessDenied(let description, _, _) = error else { + XCTFail("Expected StorageError.accessDenied") + return + } + XCTAssertEqual(description, "AccessDenied") + listExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: createMultipartUpload is invoked and the s3 client returns a valid response + /// Then: A .success result is returned containing the corresponding parsed response + func testCreateMultipartUpload_withSuccess_shouldSucceed() { + let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload") + awsS3.createMultipartUploadResult = .success(.init( + bucket: "bucket", + key: "key", + uploadId: "uploadId" + )) + adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in + XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1) + guard case .success(let response) = result else { + XCTFail("Expected success") + return + } + XCTAssertEqual(response.bucket, "bucket") + XCTAssertEqual(response.key, "key") + XCTAssertEqual(response.uploadId, "uploadId") + createMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: createMultipartUpload is invoked and the s3 client returns an invalid response + /// Then: A .failure result is returned with an .uknown error + func testCreateMultipartUpload_withWrongResponse_shouldFail() { + let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload") + adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in + XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1) + guard case .failure(let error) = result, + case .unknown(let description, _) = error else { + XCTFail("Expected StorageError.unknown") + return + } + XCTAssertEqual(description, "Invalid response for creating multipart upload") + createMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: createMultipartUpload is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testCreateMultipartUpload_withError_shouldFail() { + let createMultipartUploadExpectation = expectation(description: "Create Multipart Upload") + awsS3.createMultipartUploadResult = .failure(StorageError.accessDenied("AccessDenied", "", nil)) + adapter.createMultipartUpload(.init(bucket: "bucket", key: "key")) { result in + XCTAssertEqual(self.awsS3.createMultipartUploadCount, 1) + guard case .failure(let error) = result, + case .accessDenied(let description, _, _) = error else { + XCTFail("Expected StorageError.accessDenied") + return + } + XCTAssertEqual(description, "AccessDenied") + createMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: listParts is invoked and the s3 client returns a valid response + /// Then: A .success result is returned containing the corresponding parsed response + func testListParts_withSuccess_shouldSucceed() { + let listPartsExpectation = expectation(description: "List Parts") + awsS3.listPartsResult = .success(.init( + bucket: "bucket", + key: "key", + parts: [ + .init(eTag: "eTag1", partNumber: 1), + .init(eTag: "eTag2", partNumber: 2) + ], + uploadId: "uploadId" + )) + adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in + XCTAssertEqual(self.awsS3.listPartsCount, 1) + guard case .success(let response) = result else { + XCTFail("Expected success") + return + } + XCTAssertEqual(response.bucket, "bucket") + XCTAssertEqual(response.key, "key") + XCTAssertEqual(response.uploadId, "uploadId") + XCTAssertEqual(response.parts.count, 2) + listPartsExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: listParts is invoked and the s3 client returns an invalid response + /// Then: A .failure result is returned with an .unknown error + func testListParts_withWrongResponse_shouldFail() { + let listPartsExpectation = expectation(description: "List Parts") + adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in + XCTAssertEqual(self.awsS3.listPartsCount, 1) + guard case .failure(let error) = result, + case .unknown(let description, _) = error else { + XCTFail("Expected StorageError.unknown") + return + } + XCTAssertEqual(description, "ListParts response is invalid") + listPartsExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: listParts is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testListParts_withError_shouldFail() { + let listPartsExpectation = expectation(description: "List Parts") + awsS3.listPartsResult = .failure(StorageError.authError("AuthError", "", nil)) + adapter.listParts(bucket: "bucket", key: "key", uploadId: "uploadId") { result in + XCTAssertEqual(self.awsS3.listPartsCount, 1) + guard case .failure(let error) = result, + case .authError(let description, _, _) = error else { + XCTFail("Expected StorageError.authError") + return + } + XCTAssertEqual(description, "AuthError") + listPartsExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: completeMultipartUpload is invoked and the s3 client returns a valid response + /// Then: A .success result is returned containing the corresponding parsed response + func testCompleteMultipartUpload_withSuccess_shouldSucceed() { + let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload") + awsS3.completeMultipartUploadResult = .success(.init( + eTag: "eTag" + )) + adapter.completeMultipartUpload(.init( + bucket: "bucket", + key: "key", + uploadId: "uploadId", + parts: [.init(partNumber: 1, eTag: "eTag1"), .init(partNumber: 2, eTag: "eTag2")] + )) { result in + XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1) + guard case .success(let response) = result else { + XCTFail("Expected success") + return + } + XCTAssertEqual(response.bucket, "bucket") + XCTAssertEqual(response.key, "key") + XCTAssertEqual(response.eTag, "eTag") + completeMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: completeMultipartUpload is invoked and the s3 client returns an invalid response + /// Then: A .failure result is returned with .unknown error + func testCompleteMultipartUpload_withWrongResponse_shouldFail() { + let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload") + adapter.completeMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId", parts: [])) { result in + XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1) + guard case .failure(let error) = result, + case .unknown(let description, _) = error else { + XCTFail("Expected StorageError.unknown") + return + } + XCTAssertEqual(description, "Invalid response for completing multipart upload") + completeMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: completeMultipartUpload is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testCompleteMultipartUpload_withError_shouldFail() { + let completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload") + awsS3.completeMultipartUploadResult = .failure(StorageError.authError("AuthError", "", nil)) + adapter.completeMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId", parts: [])) { result in + XCTAssertEqual(self.awsS3.completeMultipartUploadCount, 1) + guard case .failure(let error) = result, + case .authError(let description, _, _) = error else { + XCTFail("Expected StorageError.authError") + return + } + XCTAssertEqual(description, "AuthError") + completeMultipartUploadExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: abortMultipartUpload is invoked and the s3 client returns a valid response + /// Then: A .success result is returned + func testAbortMultipartUpload_withSuccess_shouldSucceed() { + let abortExpectation = expectation(description: "Abort Multipart Upload") + adapter.abortMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId")) { result in + XCTAssertEqual(self.awsS3.abortMultipartUploadCount, 1) + guard case .success = result else { + XCTFail("Expected success") + return + } + abortExpectation.fulfill() + } + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: abortMultipartUpload is invoked and the s3 client returns an error + /// Then: A .failure result is returned + func testAbortMultipartUpload_withError_shouldFail() { + let abortExpectation = expectation(description: "Abort Multipart Upload") + awsS3.abortMultipartUploadResult = .failure(StorageError.keyNotFound("InvalidKey", "", "", nil)) + adapter.abortMultipartUpload(.init(bucket: "bucket", key: "key", uploadId: "uploadId")) { result in + XCTAssertEqual(self.awsS3.abortMultipartUploadCount, 1) + guard case .failure(let error) = result, + case .keyNotFound(let key, _, _, _) = error else { + XCTFail("Expected StorageError.keyNotFound") + return + } + XCTAssertEqual(key, "InvalidKey") + abortExpectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3Adapter + /// When: getS3 is invoked + /// Then: The underlying S3ClientProtocol instance is returned + func testGetS3() { + XCTAssertTrue(adapter.getS3() is S3ClientMock) + } +} + +private class S3ClientMock: S3ClientProtocol { + var deleteObjectCount = 0 + var deleteObjectResult: Result = .success(.init()) + func deleteObject(input: AWSS3.DeleteObjectInput) async throws -> AWSS3.DeleteObjectOutputResponse { + deleteObjectCount += 1 + return try deleteObjectResult.get() + } + + var listObjectsV2Count = 0 + var listObjectsV2Result: Result = .success(.init()) + func listObjectsV2(input: AWSS3.ListObjectsV2Input) async throws -> AWSS3.ListObjectsV2OutputResponse { + listObjectsV2Count += 1 + return try listObjectsV2Result.get() + } + + var createMultipartUploadCount = 0 + var createMultipartUploadResult: Result = .success(.init()) + func createMultipartUpload(input: AWSS3.CreateMultipartUploadInput) async throws -> AWSS3.CreateMultipartUploadOutputResponse { + createMultipartUploadCount += 1 + return try createMultipartUploadResult.get() + } + + var listPartsCount = 0 + var listPartsResult: Result = .success(.init()) + func listParts(input: AWSS3.ListPartsInput) async throws -> AWSS3.ListPartsOutputResponse { + listPartsCount += 1 + return try listPartsResult.get() + } + + var completeMultipartUploadCount = 0 + var completeMultipartUploadResult: Result = .success(.init()) + func completeMultipartUpload(input: AWSS3.CompleteMultipartUploadInput) async throws -> AWSS3.CompleteMultipartUploadOutputResponse { + completeMultipartUploadCount += 1 + return try completeMultipartUploadResult.get() + } + + var abortMultipartUploadCount = 0 + var abortMultipartUploadResult: Result = .success(.init()) + func abortMultipartUpload(input: AWSS3.AbortMultipartUploadInput) async throws -> AWSS3.AbortMultipartUploadOutputResponse { + abortMultipartUploadCount += 1 + return try abortMultipartUploadResult.get() + } + + func copyObject(input: AWSS3.CopyObjectInput) async throws -> AWSS3.CopyObjectOutputResponse { + fatalError("Not Implemented") + } + + func createBucket(input: AWSS3.CreateBucketInput) async throws -> AWSS3.CreateBucketOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucket(input: AWSS3.DeleteBucketInput) async throws -> AWSS3.DeleteBucketOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketAnalyticsConfiguration(input: AWSS3.DeleteBucketAnalyticsConfigurationInput) async throws -> AWSS3.DeleteBucketAnalyticsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketCors(input: AWSS3.DeleteBucketCorsInput) async throws -> AWSS3.DeleteBucketCorsOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketEncryption(input: AWSS3.DeleteBucketEncryptionInput) async throws -> AWSS3.DeleteBucketEncryptionOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketIntelligentTieringConfiguration(input: AWSS3.DeleteBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.DeleteBucketIntelligentTieringConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketInventoryConfiguration(input: AWSS3.DeleteBucketInventoryConfigurationInput) async throws -> AWSS3.DeleteBucketInventoryConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketLifecycle(input: AWSS3.DeleteBucketLifecycleInput) async throws -> AWSS3.DeleteBucketLifecycleOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketMetricsConfiguration(input: AWSS3.DeleteBucketMetricsConfigurationInput) async throws -> AWSS3.DeleteBucketMetricsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketOwnershipControls(input: AWSS3.DeleteBucketOwnershipControlsInput) async throws -> AWSS3.DeleteBucketOwnershipControlsOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketPolicy(input: AWSS3.DeleteBucketPolicyInput) async throws -> AWSS3.DeleteBucketPolicyOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketReplication(input: AWSS3.DeleteBucketReplicationInput) async throws -> AWSS3.DeleteBucketReplicationOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketTagging(input: AWSS3.DeleteBucketTaggingInput) async throws -> AWSS3.DeleteBucketTaggingOutputResponse { + fatalError("Not Implemented") + } + + func deleteBucketWebsite(input: AWSS3.DeleteBucketWebsiteInput) async throws -> AWSS3.DeleteBucketWebsiteOutputResponse { + fatalError("Not Implemented") + } + + func deleteObjects(input: AWSS3.DeleteObjectsInput) async throws -> AWSS3.DeleteObjectsOutputResponse { + fatalError("Not Implemented") + } + + func deleteObjectTagging(input: AWSS3.DeleteObjectTaggingInput) async throws -> AWSS3.DeleteObjectTaggingOutputResponse { + fatalError("Not Implemented") + } + + func deletePublicAccessBlock(input: AWSS3.DeletePublicAccessBlockInput) async throws -> AWSS3.DeletePublicAccessBlockOutputResponse { + fatalError("Not Implemented") + } + + func getBucketAccelerateConfiguration(input: AWSS3.GetBucketAccelerateConfigurationInput) async throws -> AWSS3.GetBucketAccelerateConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketAcl(input: AWSS3.GetBucketAclInput) async throws -> AWSS3.GetBucketAclOutputResponse { + fatalError("Not Implemented") + } + + func getBucketAnalyticsConfiguration(input: AWSS3.GetBucketAnalyticsConfigurationInput) async throws -> AWSS3.GetBucketAnalyticsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketCors(input: AWSS3.GetBucketCorsInput) async throws -> AWSS3.GetBucketCorsOutputResponse { + fatalError("Not Implemented") + } + + func getBucketEncryption(input: AWSS3.GetBucketEncryptionInput) async throws -> AWSS3.GetBucketEncryptionOutputResponse { + fatalError("Not Implemented") + } + + func getBucketIntelligentTieringConfiguration(input: AWSS3.GetBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.GetBucketIntelligentTieringConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketInventoryConfiguration(input: AWSS3.GetBucketInventoryConfigurationInput) async throws -> AWSS3.GetBucketInventoryConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketLifecycleConfiguration(input: AWSS3.GetBucketLifecycleConfigurationInput) async throws -> AWSS3.GetBucketLifecycleConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketLocation(input: AWSS3.GetBucketLocationInput) async throws -> AWSS3.GetBucketLocationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketLogging(input: AWSS3.GetBucketLoggingInput) async throws -> AWSS3.GetBucketLoggingOutputResponse { + fatalError("Not Implemented") + } + + func getBucketMetricsConfiguration(input: AWSS3.GetBucketMetricsConfigurationInput) async throws -> AWSS3.GetBucketMetricsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketNotificationConfiguration(input: AWSS3.GetBucketNotificationConfigurationInput) async throws -> AWSS3.GetBucketNotificationConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketOwnershipControls(input: AWSS3.GetBucketOwnershipControlsInput) async throws -> AWSS3.GetBucketOwnershipControlsOutputResponse { + fatalError("Not Implemented") + } + + func getBucketPolicy(input: AWSS3.GetBucketPolicyInput) async throws -> AWSS3.GetBucketPolicyOutputResponse { + fatalError("Not Implemented") + } + + func getBucketPolicyStatus(input: AWSS3.GetBucketPolicyStatusInput) async throws -> AWSS3.GetBucketPolicyStatusOutputResponse { + fatalError("Not Implemented") + } + + func getBucketReplication(input: AWSS3.GetBucketReplicationInput) async throws -> AWSS3.GetBucketReplicationOutputResponse { + fatalError("Not Implemented") + } + + func getBucketRequestPayment(input: AWSS3.GetBucketRequestPaymentInput) async throws -> AWSS3.GetBucketRequestPaymentOutputResponse { + fatalError("Not Implemented") + } + + func getBucketTagging(input: AWSS3.GetBucketTaggingInput) async throws -> AWSS3.GetBucketTaggingOutputResponse { + fatalError("Not Implemented") + } + + func getBucketVersioning(input: AWSS3.GetBucketVersioningInput) async throws -> AWSS3.GetBucketVersioningOutputResponse { + fatalError("Not Implemented") + } + + func getBucketWebsite(input: AWSS3.GetBucketWebsiteInput) async throws -> AWSS3.GetBucketWebsiteOutputResponse { + fatalError("Not Implemented") + } + + func getObject(input: AWSS3.GetObjectInput) async throws -> AWSS3.GetObjectOutputResponse { + fatalError("Not Implemented") + } + + func getObjectAcl(input: AWSS3.GetObjectAclInput) async throws -> AWSS3.GetObjectAclOutputResponse { + fatalError("Not Implemented") + } + + func getObjectAttributes(input: AWSS3.GetObjectAttributesInput) async throws -> AWSS3.GetObjectAttributesOutputResponse { + fatalError("Not Implemented") + } + + func getObjectLegalHold(input: AWSS3.GetObjectLegalHoldInput) async throws -> AWSS3.GetObjectLegalHoldOutputResponse { + fatalError("Not Implemented") + } + + func getObjectLockConfiguration(input: AWSS3.GetObjectLockConfigurationInput) async throws -> AWSS3.GetObjectLockConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func getObjectRetention(input: AWSS3.GetObjectRetentionInput) async throws -> AWSS3.GetObjectRetentionOutputResponse { + fatalError("Not Implemented") + } + + func getObjectTagging(input: AWSS3.GetObjectTaggingInput) async throws -> AWSS3.GetObjectTaggingOutputResponse { + fatalError("Not Implemented") + } + + func getObjectTorrent(input: AWSS3.GetObjectTorrentInput) async throws -> AWSS3.GetObjectTorrentOutputResponse { + fatalError("Not Implemented") + } + + func getPublicAccessBlock(input: AWSS3.GetPublicAccessBlockInput) async throws -> AWSS3.GetPublicAccessBlockOutputResponse { + fatalError("Not Implemented") + } + + func headBucket(input: AWSS3.HeadBucketInput) async throws -> AWSS3.HeadBucketOutputResponse { + fatalError("Not Implemented") + } + + func headObject(input: AWSS3.HeadObjectInput) async throws -> AWSS3.HeadObjectOutputResponse { + fatalError("Not Implemented") + } + + func listBucketAnalyticsConfigurations(input: AWSS3.ListBucketAnalyticsConfigurationsInput) async throws -> AWSS3.ListBucketAnalyticsConfigurationsOutputResponse { + fatalError("Not Implemented") + } + + func listBucketIntelligentTieringConfigurations(input: AWSS3.ListBucketIntelligentTieringConfigurationsInput) async throws -> AWSS3.ListBucketIntelligentTieringConfigurationsOutputResponse { + fatalError("Not Implemented") + } + + func listBucketInventoryConfigurations(input: AWSS3.ListBucketInventoryConfigurationsInput) async throws -> AWSS3.ListBucketInventoryConfigurationsOutputResponse { + fatalError("Not Implemented") + } + + func listBucketMetricsConfigurations(input: AWSS3.ListBucketMetricsConfigurationsInput) async throws -> AWSS3.ListBucketMetricsConfigurationsOutputResponse { + fatalError("Not Implemented") + } + + func listBuckets(input: AWSS3.ListBucketsInput) async throws -> AWSS3.ListBucketsOutputResponse { + fatalError("Not Implemented") + } + + func listMultipartUploads(input: AWSS3.ListMultipartUploadsInput) async throws -> AWSS3.ListMultipartUploadsOutputResponse { + fatalError("Not Implemented") + } + + func listObjects(input: AWSS3.ListObjectsInput) async throws -> AWSS3.ListObjectsOutputResponse { + fatalError("Not Implemented") + } + + func listObjectVersions(input: AWSS3.ListObjectVersionsInput) async throws -> AWSS3.ListObjectVersionsOutputResponse { + fatalError("Not Implemented") + } + + func putBucketAccelerateConfiguration(input: AWSS3.PutBucketAccelerateConfigurationInput) async throws -> AWSS3.PutBucketAccelerateConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketAcl(input: AWSS3.PutBucketAclInput) async throws -> AWSS3.PutBucketAclOutputResponse { + fatalError("Not Implemented") + } + + func putBucketAnalyticsConfiguration(input: AWSS3.PutBucketAnalyticsConfigurationInput) async throws -> AWSS3.PutBucketAnalyticsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketCors(input: AWSS3.PutBucketCorsInput) async throws -> AWSS3.PutBucketCorsOutputResponse { + fatalError("Not Implemented") + } + + func putBucketEncryption(input: AWSS3.PutBucketEncryptionInput) async throws -> AWSS3.PutBucketEncryptionOutputResponse { + fatalError("Not Implemented") + } + + func putBucketIntelligentTieringConfiguration(input: AWSS3.PutBucketIntelligentTieringConfigurationInput) async throws -> AWSS3.PutBucketIntelligentTieringConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketInventoryConfiguration(input: AWSS3.PutBucketInventoryConfigurationInput) async throws -> AWSS3.PutBucketInventoryConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketLifecycleConfiguration(input: AWSS3.PutBucketLifecycleConfigurationInput) async throws -> AWSS3.PutBucketLifecycleConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketLogging(input: AWSS3.PutBucketLoggingInput) async throws -> AWSS3.PutBucketLoggingOutputResponse { + fatalError("Not Implemented") + } + + func putBucketMetricsConfiguration(input: AWSS3.PutBucketMetricsConfigurationInput) async throws -> AWSS3.PutBucketMetricsConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketNotificationConfiguration(input: AWSS3.PutBucketNotificationConfigurationInput) async throws -> AWSS3.PutBucketNotificationConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketOwnershipControls(input: AWSS3.PutBucketOwnershipControlsInput) async throws -> AWSS3.PutBucketOwnershipControlsOutputResponse { + fatalError("Not Implemented") + } + + func putBucketPolicy(input: AWSS3.PutBucketPolicyInput) async throws -> AWSS3.PutBucketPolicyOutputResponse { + fatalError("Not Implemented") + } + + func putBucketReplication(input: AWSS3.PutBucketReplicationInput) async throws -> AWSS3.PutBucketReplicationOutputResponse { + fatalError("Not Implemented") + } + + func putBucketRequestPayment(input: AWSS3.PutBucketRequestPaymentInput) async throws -> AWSS3.PutBucketRequestPaymentOutputResponse { + fatalError("Not Implemented") + } + + func putBucketTagging(input: AWSS3.PutBucketTaggingInput) async throws -> AWSS3.PutBucketTaggingOutputResponse { + fatalError("Not Implemented") + } + + func putBucketVersioning(input: AWSS3.PutBucketVersioningInput) async throws -> AWSS3.PutBucketVersioningOutputResponse { + fatalError("Not Implemented") + } + + func putBucketWebsite(input: AWSS3.PutBucketWebsiteInput) async throws -> AWSS3.PutBucketWebsiteOutputResponse { + fatalError("Not Implemented") + } + + func putObject(input: AWSS3.PutObjectInput) async throws -> AWSS3.PutObjectOutputResponse { + fatalError("Not Implemented") + } + + func putObjectAcl(input: AWSS3.PutObjectAclInput) async throws -> AWSS3.PutObjectAclOutputResponse { + fatalError("Not Implemented") + } + + func putObjectLegalHold(input: AWSS3.PutObjectLegalHoldInput) async throws -> AWSS3.PutObjectLegalHoldOutputResponse { + fatalError("Not Implemented") + } + + func putObjectLockConfiguration(input: AWSS3.PutObjectLockConfigurationInput) async throws -> AWSS3.PutObjectLockConfigurationOutputResponse { + fatalError("Not Implemented") + } + + func putObjectRetention(input: AWSS3.PutObjectRetentionInput) async throws -> AWSS3.PutObjectRetentionOutputResponse { + fatalError("Not Implemented") + } + + func putObjectTagging(input: AWSS3.PutObjectTaggingInput) async throws -> AWSS3.PutObjectTaggingOutputResponse { + fatalError("Not Implemented") + } + + func putPublicAccessBlock(input: AWSS3.PutPublicAccessBlockInput) async throws -> AWSS3.PutPublicAccessBlockOutputResponse { + fatalError("Not Implemented") + } + + func restoreObject(input: AWSS3.RestoreObjectInput) async throws -> AWSS3.RestoreObjectOutputResponse { + fatalError("Not Implemented") + } + + func selectObjectContent(input: AWSS3.SelectObjectContentInput) async throws -> AWSS3.SelectObjectContentOutputResponse { + fatalError("Not Implemented") + } + + func uploadPart(input: AWSS3.UploadPartInput) async throws -> AWSS3.UploadPartOutputResponse { + fatalError("Not Implemented") + } + + func uploadPartCopy(input: AWSS3.UploadPartCopyInput) async throws -> AWSS3.UploadPartCopyOutputResponse { + fatalError("Not Implemented") + } + + func writeGetObjectResponse(input: AWSS3.WriteGetObjectResponseInput) async throws -> AWSS3.WriteGetObjectResponseOutputResponse { + fatalError("Not Implemented") + } +} diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift new file mode 100644 index 0000000000..25f5410d4a --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Service/Storage/AWSS3StorageServiceTests.swift @@ -0,0 +1,455 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +@testable import AWSPluginsTestCommon +@testable import AWSS3StoragePlugin +import ClientRuntime +import AWSS3 +import XCTest + +class AWSS3StorageServiceTests: XCTestCase { + private var service: AWSS3StorageService! + private var authService: MockAWSAuthService! + private var database: StorageTransferDatabaseMock! + private var task: StorageTransferTask! + private var fileSystem: MockFileSystem! + + override func setUp() async throws { + authService = MockAWSAuthService() + database = StorageTransferDatabaseMock() + fileSystem = MockFileSystem() + task = StorageTransferTask( + transferType: .download(onEvent: { _ in}), + bucket: "bucket", + key: "key" + ) + task.uploadId = "uploadId" + task.sessionTask = MockStorageSessionTask(taskIdentifier: 1) + database.recoverResult = .success([ + .init(transferTask: task, + multipartUploads: [ + .created( + uploadId: "uploadId", + uploadFile:UploadFile( + fileURL: FileSystem.default.createTemporaryFileURL(), + temporaryFileCreated: true, + size: UInt64(Bytes.megabytes(12).bytes) + ) + ) + ] + ) + ]) + service = try AWSS3StorageService( + authService: authService, + region: "region", + bucket: "bucket", + httpClientEngineProxy: MockHttpClientEngineProxy(), + storageTransferDatabase: database, + fileSystem: fileSystem, + logger: MockLogger() + ) + } + + override func tearDown() { + authService = nil + service = nil + database = nil + task = nil + fileSystem = nil + } + + /// Given: An AWSS3StorageService + /// When: it's deallocated + /// Then: StorageBackgroundEventsRegistry.identifier should be set to nil + func testDeinit_shouldUnregisterIdentifier() { + XCTAssertNotNil(StorageBackgroundEventsRegistry.identifier) + service = nil + XCTAssertNil(StorageBackgroundEventsRegistry.identifier) + } + + /// Given: An AWSS3StorageService + /// When: reset is invoked + /// Then: Its members should be set to nil + func testReset_shouldSetValuesToNil() { + service.reset() + XCTAssertNil(service.preSignedURLBuilder) + XCTAssertNil(service.awsS3) + XCTAssertNil(service.region) + XCTAssertNil(service.bucket) + XCTAssertTrue(service.tasks.isEmpty) + XCTAssertTrue(service.multipartUploadSessions.isEmpty) + } + + /// Given: An AWSS3StorageService + /// When: attachEventHandlers is invoked and a .completed event is sent + /// Then: A .completed event is dispatched to the event handler + func testAttachEventHandlers() { + let expectation = self.expectation(description: "Attach Event Handlers") + service.attachEventHandlers( + onUpload: { event in + guard case .completed(_) = event else { + XCTFail("Expected completed") + return + } + expectation.fulfill() + } + ) + XCTAssertNotNil(database.onUploadHandler) + database.onUploadHandler?(.completed(())) + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3StorageService + /// When: register is invoked with a task + /// Then: The task should be added to its map of tasks + func testRegisterTask_shouldAddItToTasksDictionary() { + service.register(task: task) + XCTAssertEqual(service.tasks.count, 1) + XCTAssertNotNil(service.tasks[1]) + } + + /// Given: An AWSS3StorageService with a task in its map of tasks + /// When: unregister is invoked with said task + /// Then: The task should be removed from the map of tasks + func testUnregisterTask_shouldRemoveItToTasksDictionary() { + service.tasks = [ + 1: task + ] + service.unregister(task: task) + XCTAssertTrue(service.tasks.isEmpty) + XCTAssertNil(service.tasks[1]) + } + + /// Given: An AWSS3StorageService with some tasks in its map of tasks + /// When: unregister is invoked with an identifier that is known to be mapped to a task. + /// Then: The task corresponding to the given identifier should be removed from the map of tasks + func testUnregisterTaskIdentifiers_shouldRemoveItToTasksDictionary() { + service.tasks = [ + 1: task, + 2: task + ] + service.unregister(taskIdentifiers: [1]) + XCTAssertEqual(service.tasks.count, 1) + XCTAssertNotNil(service.tasks[2]) + XCTAssertNil(service.tasks[1]) + } + + /// Given: An AWSS3StorageService with a task in its map of tasks + /// When: findTask is invoked with the identifier known to be mapped to a task + /// Then: The task corresponding to the given identifier is returned + func testFindTask_shouldReturnTask() { + service.tasks = [ + 1: task + ] + XCTAssertNotNil(service.findTask(taskIdentifier: 1)) + } + + /// Given: An AWSS3StorageService + /// When: validateParameters is invoked with an empty bucket parameter + /// Then: A .validation error is thrown + func testValidateParameters_withEmptyBucket_shouldThrowError() { + do { + try service.validateParameters(bucket: "", key: "key", accelerationModeEnabled: true) + XCTFail("Expected error") + } catch { + guard case .validation(let field, let description, let recovery, _) = error as? StorageError else { + XCTFail("Expected StorageError.validation") + return + } + XCTAssertEqual(field, "bucket") + XCTAssertEqual(description, "Invalid bucket specified.") + XCTAssertEqual(recovery, "Please specify a bucket name or configure the bucket property.") + } + } + + /// Given: An AWSS3StorageService + /// When: validateParameters is invoked with an empty key parameter + /// Then: A .validation error is thrown + func testValidateParameters_withEmptyKey_shouldThrowError() { + do { + try service.validateParameters(bucket: "bucket", key: "", accelerationModeEnabled: true) + XCTFail("Expected error") + } catch { + guard case .validation(let field, let description, let recovery, _) = error as? StorageError else { + XCTFail("Expected StorageError.validation") + return + } + XCTAssertEqual(field, "key") + XCTAssertEqual(description, "Invalid key specified.") + XCTAssertEqual(recovery, "Please specify a key.") + } + } + + /// Given: An AWSS3StorageService + /// When: validateParameters is invoked with valid bucket and key parameters + /// Then: No error is thrown + func testValidateParameters_withValidParams_shouldNotThrowError() { + do { + try service.validateParameters(bucket: "bucket", key: "key", accelerationModeEnabled: true) + } catch { + XCTFail("Expected success, got \(error)") + } + } + + /// Given: An AWSS3StorageService + /// When: createTransferTask is invoked with valid parameters + /// Then: A task is returned with attributes matching the ones provided + func testCreateTransferTask_shouldReturnTask() { + let task = service.createTransferTask( + transferType: .upload(onEvent: { event in }), + bucket: "bucket", + key: "key", + requestHeaders: [ + "header": "value" + ] + ) + XCTAssertEqual(task.bucket, "bucket") + XCTAssertEqual(task.key, "key") + XCTAssertEqual(task.requestHeaders?.count, 1) + XCTAssertEqual(task.requestHeaders?["header"], "value") + guard case .upload(_) = task.transferType else { + XCTFail("Expected .upload transferType") + return + } + } + + /// Given: An AWSS3StorageService with a non-completed download task + /// When: completeDownload is invoked for the identifier matching the task + /// Then: The task is marked as completed and a .completed event is dispatched + func testCompleteDownload_shouldReturnData() { + let expectation = self.expectation(description: "Complete Download") + + let downloadTask = StorageTransferTask( + transferType: .download(onEvent: { event in + guard case .completed(let data) = event, + let data = data else { + XCTFail("Expected .completed event with data") + return + } + XCTAssertEqual(String(decoding: data, as: UTF8.self), "someFile") + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + + let sourceUrl = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).txt") + try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8) + + service.tasks = [ + 1: downloadTask + ] + + service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl) + XCTAssertEqual(downloadTask.status, .completed) + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3StorageService with a non-completed download task that sets a location + /// When: completeDownload is invoked for the identifier matching the task + /// Then: The task is marked as completed and the file is moved to the expected location + func testCompleteDownload_withLocation_shouldMoveFileToLocation() { + let temporaryDirectory = FileManager.default.temporaryDirectory + let location = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-newFile.txt") + + let downloadTask = StorageTransferTask( + transferType: .download(onEvent: { _ in }), + bucket: "bucket", + key: "key", + location: location + ) + + let sourceUrl = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-oldFile.txt") + try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8) + + service.tasks = [ + 1: downloadTask + ] + + service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl) + XCTAssertTrue(FileManager.default.fileExists(atPath: location.path)) + XCTAssertFalse(FileManager.default.fileExists(atPath: sourceUrl.path)) + XCTAssertEqual(downloadTask.status, .completed) + } + + /// Given: An AWSS3StorageService with a non-completed download task that sets a location + /// When: completeDownload is invoked for the identifier matching the task, but the file system fails to move the file + /// Then: The task is marked as error and the file is not moved to the expected location + func testCompleteDownload_withLocation_andError_shouldFailTask() { + let temporaryDirectory = FileManager.default.temporaryDirectory + let location = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-newFile.txt") + + let downloadTask = StorageTransferTask( + transferType: .download(onEvent: { _ in }), + bucket: "bucket", + key: "key", + location: location + ) + + let sourceUrl = temporaryDirectory.appendingPathComponent("\(UUID().uuidString)-oldFile.txt") + try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8) + + service.tasks = [ + 1: downloadTask + ] + + fileSystem.moveFileError = StorageError.unknown("Unable to move file", nil) + service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl) + XCTAssertFalse(FileManager.default.fileExists(atPath: location.path)) + XCTAssertTrue(FileManager.default.fileExists(atPath: sourceUrl.path)) + XCTAssertEqual(downloadTask.status, .error) + } + + /// Given: An AWSS3StorageService with a non-completed upload task that sets a location + /// When: completeDownload is invoked for the identifier matching the task + /// Then: The task status is not updated and an .upload event is not dispatched + func testCompleteDownload_withNoDownload_shouldDoNothing() { + let expectation = self.expectation(description: "Complete Download") + expectation.isInverted = true + + let uploadTask = StorageTransferTask( + transferType: .upload(onEvent: { event in + XCTFail("Should not report event") + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + + let sourceUrl = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).txt") + try! "someFile".write(to: sourceUrl, atomically: true, encoding: .utf8) + + service.tasks = [ + 1: uploadTask + ] + + service.completeDownload(taskIdentifier: 1, sourceURL: sourceUrl) + XCTAssertNotEqual(uploadTask.status, .completed) + XCTAssertNotEqual(uploadTask.status, .error) + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3StorageService that cannot create a pre signed url + /// When: upload is invoked + /// Then: A .failed event is dispatched with an .unknown error + func testUpload_withoutPreSignedURL_shouldSendFailEvent() { + let data = "someData".data(using: .utf8)! + let expectation = self.expectation(description: "Upload") + service.upload( + serviceKey: "key", + uploadSource: .data(data), + contentType: "application/json", + metadata: [:], + accelerate: true, + onEvent: { event in + guard case .failed(let error) = event, + case .unknown(let description, _) = error else { + XCTFail("Expected .failed event with .unknown error, got \(event)") + return + } + XCTAssertEqual(description, "Failed to get pre-signed URL") + expectation.fulfill() + } + ) + + waitForExpectations(timeout: 1) + } + + /// Given: An AWSS3StorageService that can create a pre signed url + /// When: upload is invoked + /// Then: An .initiated event is dispatched + func testUpload_withPreSignedURL_shouldSendInitiatedEvent() { + let data = "someData".data(using: .utf8)! + let expectation = self.expectation(description: "Upload") + service.preSignedURLBuilder = MockAWSS3PreSignedURLBuilder() + service.upload( + serviceKey: "key", + uploadSource: .data(data), + contentType: "application/json", + metadata: [:], + accelerate: true, + onEvent: { event in + guard case .initiated(_) = event else { + XCTFail("Expected .initiated event, got \(event)") + return + } + expectation.fulfill() + } + ) + + waitForExpectations(timeout: 1) + } +} + +private class MockHttpClientEngineProxy: HttpClientEngineProxy { + var target: HttpClientEngine? = nil + + var executeCount = 0 + var executeRequest: SdkHttpRequest? + func execute(request: SdkHttpRequest) async throws -> HttpResponse { + executeCount += 1 + executeRequest = request + return .init(body: .empty, statusCode: .accepted) + } +} + +private class StorageTransferDatabaseMock: StorageTransferDatabase { + + func prepareForBackground(completion: (() -> Void)?) { + completion?() + } + + func insertTransferRequest(task: StorageTransferTask) { + + } + + func updateTransferRequest(task: StorageTransferTask) { + + } + + func removeTransferRequest(task: StorageTransferTask) { + + } + + func defaultTransferType(persistableTransferTask: StoragePersistableTransferTask) -> StorageTransferType? { + return nil + } + + var recoverCount = 0 + var recoverResult: Result = .failure(StorageError.unknown("Result not set", nil)) + func recover(urlSession: StorageURLSession, + completionHandler: @escaping (Result) -> Void) { + recoverCount += 1 + completionHandler(recoverResult) + } + + var attachEventHandlersCount = 0 + var onUploadHandler: AWSS3StorageServiceBehavior.StorageServiceUploadEventHandler? = nil + var onDownloadHandler: AWSS3StorageServiceBehavior.StorageServiceDownloadEventHandler? = nil + var onMultipartUploadHandler: AWSS3StorageServiceBehavior.StorageServiceMultiPartUploadEventHandler? = nil + func attachEventHandlers( + onUpload: AWSS3StorageServiceBehavior.StorageServiceUploadEventHandler?, + onDownload: AWSS3StorageServiceBehavior.StorageServiceDownloadEventHandler?, + onMultipartUpload: AWSS3StorageServiceBehavior.StorageServiceMultiPartUploadEventHandler? + ) { + attachEventHandlersCount += 1 + onUploadHandler = onUpload + onDownloadHandler = onDownload + onMultipartUploadHandler = onMultipartUpload + } +} + +private class MockFileSystem: FileSystem { + var moveFileError: Error? = nil + override func moveFile(from sourceFileURL: URL, to destinationURL: URL) throws { + if let moveFileError = moveFileError { + throw moveFileError + } + try super.moveFile(from: sourceFileURL, to: destinationURL) + } +} diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift new file mode 100644 index 0000000000..561543f57d --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageMultipartUploadClientTests.swift @@ -0,0 +1,458 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +@testable import func AmplifyTestCommon.XCTAssertThrowFatalError +@testable import AWSS3StoragePlugin +import AWSS3 +import XCTest + +class DefaultStorageMultipartUploadClientTests: XCTestCase { + private var defaultClient: DefaultStorageMultipartUploadClient! + private var serviceProxy: MockStorageServiceProxy! + private var session: MockStorageMultipartUploadSession! + private var awss3Behavior: MockAWSS3Behavior! + private var uploadFile: UploadFile! + + override func setUp() async throws { + awss3Behavior = MockAWSS3Behavior() + serviceProxy = MockStorageServiceProxy( + awsS3: awss3Behavior + ) + let tempFileURL = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + .appendingPathExtension("txt") + try "Hello World".write(to: tempFileURL, atomically: true, encoding: .utf8) + uploadFile = UploadFile( + fileURL: tempFileURL, + temporaryFileCreated: false, + size: 88 + ) + defaultClient = DefaultStorageMultipartUploadClient( + serviceProxy: serviceProxy, + bucket: "bucket", + key: "key", + uploadFile: uploadFile + ) + session = MockStorageMultipartUploadSession( + client: client, + bucket: "bucket", + key: "key", + onEvent: { event in } + ) + client.integrate(session: session) + } + + private var client: StorageMultipartUploadClient! { + defaultClient + } + + override func tearDown() { + defaultClient = nil + serviceProxy = nil + session = nil + awss3Behavior = nil + uploadFile = nil + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: createMultipartUpload is invoked and AWSS3Behavior returns .success + /// Then: A .created event is reported to the session and the session is registered + func testCreateMultipartUpload_withSuccess_shouldSucceed() throws { + awss3Behavior.createMultipartUploadExpectation = expectation(description: "Create Multipart Upload") + awss3Behavior.createMultipartUploadResult = .success(.init( + bucket: "bucket", + key: "key", + uploadId: "uploadId" + )) + try client.createMultipartUpload() + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.createMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 2) + XCTAssertEqual(session.failCount, 0) + if case .created(let uploadFile, let uploadId) = try XCTUnwrap(session.lastMultipartUploadEvent) { + XCTAssertEqual(uploadFile.fileURL, uploadFile.fileURL) + XCTAssertEqual(uploadId, "uploadId") + } + XCTAssertEqual(serviceProxy.registerMultipartUploadSessionCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: createMultipartUpload is invoked and AWSS3Behavior returns .failure + /// Then: An .unknown error is reported to the session and the session is not registered + func testCreateMultipartUpload_withError_shouldFail() throws { + awss3Behavior.createMultipartUploadExpectation = expectation(description: "Create Multipart Upload") + awss3Behavior.createMultipartUploadResult = .failure(.unknown("Unknown Error", nil)) + try client.createMultipartUpload() + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.createMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 1) + XCTAssertEqual(session.failCount, 1) + if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError { + XCTAssertEqual(description, "Unknown Error") + } + XCTAssertEqual(serviceProxy.registerMultipartUploadSessionCount, 0) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: serviceProxy is set to nil and createMultipartUpload is invoked + /// Then: A fatal error is thrown + func testCreateMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws { + serviceProxy = nil + try XCTAssertThrowFatalError { + try? self.client.createMultipartUpload() + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: uploadPart is invoked with valid parts + /// Then: A .started event is reported to the session + func testUploadPart_withParts_shouldSucceed() throws { + session.handleUploadPartExpectation = expectation(description: "Upload Part with parts") + + try client.uploadPart( + partNumber: 1, + multipartUpload: .parts( + uploadId: "uploadId", + uploadFile: uploadFile, + partSize: .default, + parts: [ + .pending(bytes: 10), + .pending(bytes: 20) + ] + ), + subTask: .init( + transferType: .upload(onEvent: { event in }), + bucket: "bucket", + key: "key" + ) + ) + + waitForExpectations(timeout: 1) + XCTAssertEqual(session.handleUploadPartCount, 1) + XCTAssertEqual(session.failCount, 0) + if case .started(let partNumber, _) = try XCTUnwrap(session.lastUploadEvent) { + XCTAssertEqual(partNumber, 1) + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: uploadPart is invoked with a non-existing file + /// Then: An error is reported to the session + func testUploadPart_withInvalidFile_shouldFail() throws { + session.failExpectation = expectation(description: "Upload Part with invalid file") + + try client.uploadPart( + partNumber: 1, + multipartUpload: .parts( + uploadId: "uploadId", + uploadFile: .init( + fileURL: FileManager.default.temporaryDirectory.appendingPathComponent("noFile.txt"), + temporaryFileCreated: false, + size: 1024), + partSize: .default, + parts: [ + .pending(bytes: 10), + .pending(bytes: 20) + ] + ), + subTask: .init( + transferType: .upload(onEvent: { event in }), + bucket: "bucket", + key: "key" + ) + ) + + waitForExpectations(timeout: 1) + XCTAssertEqual(session.handleUploadPartCount, 0) + XCTAssertEqual(session.failCount, 1) + XCTAssertNil(session.lastUploadEvent) + XCTAssertNotNil(session.lastError) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: serviceProxy is set to nil and uploadPart is invoked + /// Then: A fatal error is thrown + func testUploadPart_withoutServiceProxy_shouldThrowFatalError() throws { + self.serviceProxy = nil + try XCTAssertThrowFatalError { + try? self.client.uploadPart( + partNumber: 1, + multipartUpload: .parts( + uploadId: "uploadId", + uploadFile: self.uploadFile, + partSize: .default, + parts: [ + .pending(bytes: 10), + .pending(bytes: 20) + ] + ), + subTask: .init( + transferType: .upload(onEvent: { event in }), + bucket: "bucket", + key: "key" + ) + ) + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: uploadPart is invoked without parts + /// Then: A fatal error is thrown + func testUploadPart_withoutParts_shouldThrowFatalError() throws { + try XCTAssertThrowFatalError { + try? self.client.uploadPart( + partNumber: 1, + multipartUpload: .created( + uploadId: "uploadId", + uploadFile: self.uploadFile + ), + subTask: .init( + transferType: .upload(onEvent: { event in }), + bucket: "bucket", + key: "key" + ) + ) + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: completeMultipartUpload is invoked and AWSS3Behaviour returns succees + /// Then: A .completed event is reported to the session and the session is unregistered + func testCompleteMultipartUpload_withSuccess_shouldSucceed() throws { + awss3Behavior.completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload") + awss3Behavior.completeMultipartUploadResult = .success(.init( + bucket: "bucket", + key: "key", + eTag: "eTag" + )) + try client.completeMultipartUpload(uploadId: "uploadId") + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.completeMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 1) + XCTAssertEqual(session.failCount, 0) + if case .completed(let uploadId) = try XCTUnwrap(session.lastMultipartUploadEvent) { + XCTAssertEqual(uploadId, "uploadId") + } + XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: completeMultipartUpload is invoked and AWSS3Behaviour returns failure + /// Then: A .unknown error is reported to the session and the session is not unregistered + func testCompleteMultipartUpload_withError_shouldFail() throws { + awss3Behavior.completeMultipartUploadExpectation = expectation(description: "Complete Multipart Upload") + awss3Behavior.completeMultipartUploadResult = .failure(.unknown("Unknown Error", nil)) + try client.completeMultipartUpload(uploadId: "uploadId") + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.completeMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 0) + XCTAssertEqual(session.failCount, 1) + if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError { + XCTAssertEqual(description, "Unknown Error") + } + XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: serviceProxy is set to nil and completeMultipartUpload is invoked + /// Then: A fatal error is thrown + func testCompleteMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws { + serviceProxy = nil + try XCTAssertThrowFatalError { + try? self.client.completeMultipartUpload(uploadId: "uploadId") + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: abortMultipartUpload is invoked and AWSS3Behaviour returns success + /// Then: An .aborted event is reported to the session and the session is unregistered + func testAbortMultipartUpload_withSuccess_shouldSucceed() throws { + awss3Behavior.abortMultipartUploadExpectation = expectation(description: "Abort Multipart Upload") + awss3Behavior.abortMultipartUploadResult = .success(()) + try client.abortMultipartUpload(uploadId: "uploadId", error: CancellationError()) + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.abortMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 1) + XCTAssertEqual(session.failCount, 0) + if case .aborted(let uploadId, let error) = try XCTUnwrap(session.lastMultipartUploadEvent) { + XCTAssertEqual(uploadId, "uploadId") + XCTAssertTrue(error is CancellationError) + } + XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: abortMultipartUpload is invoked and AWSS3Behaviour returns failure + /// Then: A .unknown error is reported to the session and the session is not unregistered + func testAbortMultipartUpload_withError_shouldFail() throws { + awss3Behavior.abortMultipartUploadExpectation = expectation(description: "Abort Multipart Upload") + awss3Behavior.abortMultipartUploadResult = .failure(.unknown("Unknown Error", nil)) + try client.abortMultipartUpload(uploadId: "uploadId") + + waitForExpectations(timeout: 1) + XCTAssertEqual(awss3Behavior.abortMultipartUploadCount, 1) + XCTAssertEqual(session.handleMultipartUploadCount, 0) + XCTAssertEqual(session.failCount, 1) + if case .unknown(let description, _) = try XCTUnwrap(session.lastError) as? StorageError { + XCTAssertEqual(description, "Unknown Error") + } + XCTAssertEqual(serviceProxy.unregisterMultipartUploadSessionCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: serviceProxy is set to nil and abortMultipartUpload is invoked + /// Then: A fatal error is thrown + func testAbortMultipartUpload_withoutServiceProxy_shouldThrowFatalError() throws { + serviceProxy = nil + try XCTAssertThrowFatalError { + try? self.client.abortMultipartUpload(uploadId: "uploadId") + } + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: cancelUploadTasks is invoked with identifiers + /// Then: The tasks are unregistered + func testCancelUploadTasks_shouldSucceed() throws { + let cancelExpectation = expectation(description: "Cancel Upload Tasks") + client.cancelUploadTasks(taskIdentifiers: [0, 1,2], done: { + cancelExpectation.fulfill() + }) + + waitForExpectations(timeout: 1) + XCTAssertEqual(serviceProxy.unregisterTaskIdentifiersCount, 1) + } + + /// Given: a DefaultStorageMultipartUploadClient + /// When: filter is invoked with some disallowed values + /// Then: a dictionary is returned with the disallowed values removed + func testFilterRequestHeaders_shouldResultFilteredHeaders() { + let filteredHeaders = defaultClient.filter( + requestHeaders: [ + "validHeader": "validValue", + "x-amz-acl": "invalidValue", + "x-amz-tagging": "invalidValue", + "x-amz-storage-class": "invalidValue", + "x-amz-server-side-encryption": "invalidValue", + "x-amz-meta-invalid_one": "invalidValue", + "x-amz-meta-invalid_two": "invalidValue", + "x-amz-grant-invalid_one": "invalidvalue", + "x-amz-grant-invalid_two": "invalidvalue" + ] + ) + + XCTAssertEqual(filteredHeaders.count, 1) + XCTAssertEqual(filteredHeaders["validHeader"], "validValue") + } +} + +private class MockStorageServiceProxy: StorageServiceProxy { + var preSignedURLBuilder: AWSS3PreSignedURLBuilderBehavior! = MockAWSS3PreSignedURLBuilder() + var awsS3: AWSS3Behavior! + var urlSession = URLSession.shared + var userAgent: String = "" + var urlRequestDelegate: URLRequestDelegate? = nil + + init(awsS3: AWSS3Behavior) { + self.awsS3 = awsS3 + } + + func register(task: StorageTransferTask) {} + + func unregister(task: StorageTransferTask) {} + + var unregisterTaskIdentifiersCount = 0 + func unregister(taskIdentifiers: [TaskIdentifier]) { + unregisterTaskIdentifiersCount += 1 + } + + var registerMultipartUploadSessionCount = 0 + func register(multipartUploadSession: StorageMultipartUploadSession) { + registerMultipartUploadSessionCount += 1 + } + + var unregisterMultipartUploadSessionCount = 0 + func unregister(multipartUploadSession: StorageMultipartUploadSession) { + unregisterMultipartUploadSessionCount += 1 + } +} + +private class MockAWSS3Behavior: AWSS3Behavior { + func deleteObject(_ request: AWSS3DeleteObjectRequest, completion: @escaping (Result) -> Void) {} + + func listObjectsV2(_ request: AWSS3ListObjectsV2Request, completion: @escaping (Result) -> Void) {} + + var createMultipartUploadCount = 0 + var createMultipartUploadResult: Result? = nil + var createMultipartUploadExpectation: XCTestExpectation? = nil + func createMultipartUpload(_ request: CreateMultipartUploadRequest, completion: @escaping (Result) -> Void) { + createMultipartUploadCount += 1 + if let result = createMultipartUploadResult { + completion(result) + } + createMultipartUploadExpectation?.fulfill() + } + + var completeMultipartUploadCount = 0 + var completeMultipartUploadResult: Result? = nil + var completeMultipartUploadExpectation: XCTestExpectation? = nil + func completeMultipartUpload(_ request: AWSS3CompleteMultipartUploadRequest, completion: @escaping (Result) -> Void) { + completeMultipartUploadCount += 1 + if let result = completeMultipartUploadResult { + completion(result) + } + completeMultipartUploadExpectation?.fulfill() + } + + var abortMultipartUploadCount = 0 + var abortMultipartUploadResult: Result? = nil + var abortMultipartUploadExpectation: XCTestExpectation? = nil + func abortMultipartUpload(_ request: AWSS3AbortMultipartUploadRequest, completion: @escaping (Result) -> Void) { + abortMultipartUploadCount += 1 + if let result = abortMultipartUploadResult { + completion(result) + } + abortMultipartUploadExpectation?.fulfill() + } + + func getS3() -> S3ClientProtocol { + return MockS3Client() + } +} + +class MockStorageMultipartUploadSession: StorageMultipartUploadSession { + var handleMultipartUploadCount = 0 + var lastMultipartUploadEvent: StorageMultipartUploadEvent? = nil + override func handle(multipartUploadEvent: StorageMultipartUploadEvent) { + handleMultipartUploadCount += 1 + lastMultipartUploadEvent = multipartUploadEvent + } + + var handleUploadPartCount = 0 + var lastUploadEvent: StorageUploadPartEvent? = nil + var handleUploadPartExpectation: XCTestExpectation? = nil + + override func handle(uploadPartEvent: StorageUploadPartEvent) { + handleUploadPartCount += 1 + lastUploadEvent = uploadPartEvent + handleUploadPartExpectation?.fulfill() + } + + var failCount = 0 + var lastError: Error? = nil + var failExpectation: XCTestExpectation? = nil + override func fail(error: Error) { + failCount += 1 + lastError = error + failExpectation?.fulfill() + } +} diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift new file mode 100644 index 0000000000..127c151229 --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/DefaultStorageTransferDatabaseTests.swift @@ -0,0 +1,241 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +@testable import AWSS3StoragePlugin +import XCTest + +class DefaultStorageTransferDatabaseTests: XCTestCase { + private var database: DefaultStorageTransferDatabase! + private var uploadFile: UploadFile! + private var session: MockStorageSessionTask! + + override func setUp() { + database = DefaultStorageTransferDatabase( + databaseDirectoryURL: FileManager.default.temporaryDirectory, + logger: MockLogger() + ) + uploadFile = UploadFile( + fileURL: FileSystem.default.createTemporaryFileURL(), + temporaryFileCreated: true, + size: UInt64(Bytes.megabytes(12).bytes) + ) + session = MockStorageSessionTask(taskIdentifier: 1) + } + + override func tearDown() { + database = nil + uploadFile = nil + session = nil + } + + /// Given: A DefaultStorageTransferDatabase + /// When: linkTasksWithSessions is invoked with tasks containing multipart uploads and a sessionTask, and a session + /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned + func testLinkTasksWithSessions_withMultipartUpload_shouldReturnPairs() { + let transferTask1 = StorageTransferTask( + transferType: .multiPartUpload(onEvent: { _ in }), + bucket: "bucket", + key: "key1" + ) + transferTask1.sessionTask = session + transferTask1.multipartUpload = .created( + uploadId: "uploadId", + uploadFile: uploadFile + ) + + let transferTask2 = StorageTransferTask( + transferType: .multiPartUpload(onEvent: { _ in }), + bucket: "bucket", + key: "key2" + ) + transferTask2.sessionTask = session + transferTask2.multipartUpload = .created( + uploadId: "uploadId", + uploadFile: uploadFile + ) + + let pairs = database.linkTasksWithSessions( + persistableTransferTasks: [ + "taskId1": .init(task: transferTask1), + "taskId2": .init(task: transferTask2) + ], + sessionTasks: [ + session + ] + ) + + XCTAssertEqual(pairs.count, 2) + XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" })) + XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key2" })) + } + + /// Given: A DefaultStorageTransferDatabase + /// When: linkTasksWithSessions is invoked with tasks containing multipart uploads but without a sessionTask, and a session + /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned + func testLinkTasksWithSessions_withMultipartUpload_andNoSession_shouldReturnPairs() { + let transferTask1 = StorageTransferTask( + transferType: .multiPartUpload(onEvent: { _ in }), + bucket: "bucket", + key: "key1" + ) + transferTask1.multipartUpload = .created( + uploadId: "uploadId", + uploadFile: uploadFile + ) + + let transferTask2 = StorageTransferTask( + transferType: .multiPartUpload(onEvent: { _ in }), + bucket: "bucket", + key: "key2" + ) + transferTask2.multipartUpload = .created( + uploadId: "uploadId", + uploadFile: uploadFile + ) + + let pairs = database.linkTasksWithSessions( + persistableTransferTasks: [ + "taskId1": .init(task: transferTask1), + "taskId2": .init(task: transferTask2) + ], + sessionTasks: [ + session + ] + ) + + XCTAssertEqual(pairs.count, 2) + XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" })) + XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key2" })) + } + + /// Given: A DefaultStorageTransferDatabase + /// When: linkTasksWithSessions is invoked with tasks containing multipart upload parts, and a session + /// Then: A StorageTransferTaskPairs linking the tasks with the session is returned + func testLinkTasksWithSessions_withMultipartUploadPart_shouldReturnPairs() { + let transferTask0 = StorageTransferTask( + transferType: .multiPartUpload(onEvent: { _ in }), + bucket: "bucket", + key: "key1" + ) + transferTask0.sessionTask = session + transferTask0.multipartUpload = .created( + uploadId: "uploadId", + uploadFile: uploadFile + ) + + let transferTask1 = StorageTransferTask( + transferType: .multiPartUploadPart( + uploadId: "uploadId", + partNumber: 1 + ), + bucket: "bucket", + key: "key1" + ) + transferTask1.sessionTask = session + transferTask1.uploadId = "uploadId" + transferTask1.multipartUpload = .parts( + uploadId: "uploadId", + uploadFile: uploadFile, + partSize: try! .init(fileSize: UInt64(Bytes.megabytes(6).bytes)), + parts: [ + .inProgress( + bytes: Bytes.megabytes(6).bytes, + bytesTransferred: Bytes.megabytes(3).bytes, + taskIdentifier: 1 + ), + .completed( + bytes: Bytes.megabytes(6).bytes, + eTag: "eTag") + , + .pending(bytes: Bytes.megabytes(6).bytes) + ] + ) + transferTask1.uploadPart = .completed( + bytes: Bytes.megabytes(6).bytes, + eTag: "eTag" + ) + + let transferTask2 = StorageTransferTask( + transferType: .multiPartUploadPart( + uploadId: "uploadId", + partNumber: 2 + ), + bucket: "bucket", + key: "key1" + ) + transferTask2.sessionTask = session + transferTask2.uploadId = "uploadId" + transferTask2.multipartUpload = .parts( + uploadId: "uploadId", + uploadFile: uploadFile, + partSize: try! .init(fileSize: UInt64(Bytes.megabytes(6).bytes)), + parts: [ + .pending(bytes: Bytes.megabytes(6).bytes), + .pending(bytes: Bytes.megabytes(6).bytes) + ] + ) + transferTask2.uploadPart = .inProgress( + bytes: Bytes.megabytes(6).bytes, + bytesTransferred: Bytes.megabytes(3).bytes, + taskIdentifier: 1 + ) + + let pairs = database.linkTasksWithSessions( + persistableTransferTasks: [ + "taskId0": .init(task: transferTask0), + "taskId1": .init(task: transferTask1), + "taskId2": .init(task: transferTask2) + ], + sessionTasks: [ + session + ] + ) + + XCTAssertEqual(pairs.count, 3) + XCTAssertTrue(pairs.contains(where: { $0.transferTask.key == "key1" })) + XCTAssertFalse(pairs.contains(where: { $0.transferTask.key == "key2" })) + } + + /// Given: A DefaultStorageTransferDatabase + /// When: recover is invoked with a StorageURLSession that returns a session + /// Then: A .success is returned + func testLoadPersistableTasks() { + let urlSession = MockStorageURLSession( + sessionTasks: [ + session + ]) + let expectation = self.expectation(description: "Recover") + database.recover(urlSession: urlSession) { result in + guard case .success(_) = result else { + XCTFail("Expected success") + return + } + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: A DefaultStorageTransferDatabase + /// When: prepareForBackground is invoked + /// Then: A callback is invoked + func testPrepareForBackground() { + let expectation = self.expectation(description: "Prepare for Background") + database.prepareForBackground() { + expectation.fulfill() + } + waitForExpectations(timeout: 1) + } + + /// Given: The StorageTransferDatabase Type + /// When: default is invoked + /// Then: An instance of DefaultStorageTransferDatabase is returned + func testDefault_shouldReturnDefaultInstance() { + let defaultProtocol: StorageTransferDatabase = .default + XCTAssertTrue(defaultProtocol is DefaultStorageTransferDatabase) + } +} diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift index fa808b4ccc..dad5a0a531 100644 --- a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageMultipartUploadSessionTests.swift @@ -41,6 +41,44 @@ class StorageMultipartUploadSessionTests: XCTestCase { XCTAssertFalse(session.partsFailed) } + /// Given: A StorageTransferTask with a valid StorageTransferType + /// When: A StorageMultipartUploadSession is created from the task + /// Then: Its values are set correctly + func testSessionCreation_withTransferTask() throws { + let client = MockMultipartUploadClient() + let transferType: StorageTransferType = .multiPartUpload(onEvent: {_ in }) + let transferTask = StorageTransferTask( + transferType: transferType, + bucket: "bucket", + key: "key" + ) + + let session = try XCTUnwrap(StorageMultipartUploadSession(client: client, transferTask: transferTask, multipartUpload: .none, logger: MockLogger())) + XCTAssertEqual(session.partsCount, 0) + XCTAssertEqual(session.inProgressCount, 0) + XCTAssertFalse(session.partsCompleted) + XCTAssertFalse(session.partsFailed) + } + + /// Given: A StorageTransferTask with an invalid StorageTransferType + /// When: A StorageMultipartUploadSession is created from the task + /// Then: Its values are set correctly + func testSessionCreation_withTransferTask_andInvalidTransferType_shouldReturnNil() throws { + let client = MockMultipartUploadClient() + let transferType: StorageTransferType = .list(onEvent: {_ in }) + let transferTask = StorageTransferTask( + transferType: transferType, + bucket: "bucket", + key: "key" + ) + + XCTAssertNil(StorageMultipartUploadSession( + client: client, + transferTask: transferTask, + multipartUpload: .none + )) + } + func testCompletedMultipartUploadSession() throws { let initiatedExp = expectation(description: "Initiated") let completedExp = expectation(description: "Completed") @@ -105,7 +143,7 @@ class StorageMultipartUploadSessionTests: XCTestCase { let client = MockMultipartUploadClient() // creates an UploadFile for the mock process client.didCompletePartUpload = { (_, partNumber, _, _) in if partNumber == 5 { - closureSession?.handle(multipartUploadEvent: .aborting(error: nil)) + closureSession?.cancel() XCTAssertTrue(closureSession?.isAborted ?? false) } @@ -156,10 +194,10 @@ class StorageMultipartUploadSessionTests: XCTestCase { if pauseCount == 0, partNumber > 5, bytesTransferred > 0 { print("pausing on \(partNumber)") pauseCount += 1 - closureSession?.handle(multipartUploadEvent: .pausing) + closureSession?.pause() XCTAssertTrue(closureSession?.isPaused ?? false) print("resuming on \(partNumber)") - closureSession?.handle(multipartUploadEvent: .resuming) + closureSession?.resume() XCTAssertFalse(closureSession?.isPaused ?? true) } } diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift new file mode 100644 index 0000000000..b8498a0387 --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageServiceSessionDelegateTests.swift @@ -0,0 +1,363 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +@testable import Amplify +@testable import AWSPluginsTestCommon +@testable import AWSS3StoragePlugin +import ClientRuntime +import AWSS3 +import XCTest + +class StorageServiceSessionDelegateTests: XCTestCase { + private var delegate: StorageServiceSessionDelegate! + private var service: AWSS3StorageServiceMock! + private var logger: MockLogger! + + override func setUp() { + service = try! AWSS3StorageServiceMock() + logger = MockLogger() + delegate = StorageServiceSessionDelegate( + identifier: "delegateTest", + logger: logger + ) + delegate.storageService = service + } + + override func tearDown() { + logger = nil + service = nil + delegate = nil + } + + /// Given: A StorageServiceSessionDelegate + /// When: logURLSessionActivity is invoked with warning set to true + /// Then: A warn message is logged + func testLogURLSession_withWarningTrue_shouldLogWarning() { + delegate.logURLSessionActivity("message", warning: true) + XCTAssertEqual(logger.warnCount, 1) + XCTAssertEqual(logger.infoCount, 0) + } + + /// Given: A StorageServiceSessionDelegate + /// When: logURLSessionActivity is invoked without setting warning + /// Then: An info message is logged + func testLogURLSession_shouldLogInfo() { + delegate.logURLSessionActivity("message") + XCTAssertEqual(logger.warnCount, 0) + XCTAssertEqual(logger.infoCount, 1) + } + + /// Given: A StorageServiceSessionDelegate and an identifier registered in the registry + /// When: the registry's handleBackgroundEvents is invoked with a matching identifier and then urlSessionDidFinishEvents is invoked + /// Then: The registry's continuation is triggered with true + func testDidFinishEvents_withMatchingIdentifiers_shouldTriggerContinuationWithTrue() async { + let handleEventsExpectation = self.expectation(description: "Handle Background Events") + let finishEventsExpectation = self.expectation(description: "Did Finish Events") + StorageBackgroundEventsRegistry.register(identifier: "identifier") + Task { + let result = await withCheckedContinuation { continuation in + StorageBackgroundEventsRegistry.handleBackgroundEvents( + identifier: "identifier", + continuation: continuation + ) + handleEventsExpectation.fulfill() + } + XCTAssertTrue(result) + finishEventsExpectation.fulfill() + } + + await fulfillment(of: [handleEventsExpectation], timeout: 1) + XCTAssertNotNil(StorageBackgroundEventsRegistry.continuation) + delegate.urlSessionDidFinishEvents(forBackgroundURLSession: .shared) + await fulfillment(of: [finishEventsExpectation], timeout: 1) + XCTAssertNil(StorageBackgroundEventsRegistry.continuation) + } + + /// Given: A StorageServiceSessionDelegate and an identifier registered in the registry + /// When: the registry's handleBackgroundEvents is invoked first with a matching identifier and then with a non-matching one, and after that urlSessionDidFinishEvents is invoked + /// Then: The registry's continuation for the non-matching identifier is triggered immediately with false, while the one for the matching identifier is triggered with true only after urlSessionDidFinishEvents is invoked + func testDidFinishEvents_withNonMatchingIdentifiers_shouldTriggerContinuationWithFalse() async { + let handleEventsMatchingExpectation = self.expectation(description: "Handle Background Events with Matching Identifiers") + let finishEventsExpectation = self.expectation(description: "Did Finish Events") + StorageBackgroundEventsRegistry.register(identifier: "identifier") + Task { + let result = await withCheckedContinuation { continuation in + StorageBackgroundEventsRegistry.handleBackgroundEvents( + identifier: "identifier", + continuation: continuation + ) + handleEventsMatchingExpectation.fulfill() + } + XCTAssertTrue(result) + finishEventsExpectation.fulfill() + } + + await fulfillment(of: [handleEventsMatchingExpectation], timeout: 1) + XCTAssertNotNil(StorageBackgroundEventsRegistry.continuation) + + let handleEventsNonMatchingExpectation = self.expectation(description: "Handle Background Events with Matching Identifiers") + Task { + let result = await withCheckedContinuation { continuation in + StorageBackgroundEventsRegistry.handleBackgroundEvents( + identifier: "identifier2", + continuation: continuation + ) + } + XCTAssertFalse(result) + handleEventsNonMatchingExpectation.fulfill() + } + await fulfillment(of: [handleEventsNonMatchingExpectation], timeout: 1) + delegate.urlSessionDidFinishEvents(forBackgroundURLSession: .shared) + await fulfillment(of: [finishEventsExpectation], timeout: 1) + XCTAssertNil(StorageBackgroundEventsRegistry.continuation) + } + + /// Given: A StorageServiceSessionDelegate + /// When: didBecomeInvalidWithError is invoked with a StorageError + /// Then: The service's resetURLSession is invoked + func testDidBecomeInvalid_withError_shouldResetURLSession() { + delegate.urlSession(.shared, didBecomeInvalidWithError: StorageError.accessDenied("", "", nil)) + XCTAssertEqual(service.resetURLSessionCount, 1) + } + + /// Given: A StorageServiceSessionDelegate + /// When: didBecomeInvalidWithError is invoked with a nil error + /// Then: The service's resetURLSession is invoked + func testDidBecomeInvalid_withNilError_shouldResetURLSession() { + delegate.urlSession(.shared, didBecomeInvalidWithError: nil) + XCTAssertEqual(service.resetURLSessionCount, 1) + } + + /// Given: A StorageServiceSessionDelegate and a StorageTransferTask with a NSError with a NSURLErrorCancelled reason + /// When: didComplete is invoked + /// Then: The task is not unregistered + func testDidComplete_withNSURLErrorCancelled_shouldNotCompleteTask() { + let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory) + let reasons = [ + NSURLErrorCancelledReasonBackgroundUpdatesDisabled, + NSURLErrorCancelledReasonInsufficientSystemResources, + NSURLErrorCancelledReasonUserForceQuitApplication, + NSURLErrorCancelled + ] + + for reason in reasons { + let expectation = self.expectation(description: "Did Complete With Error Reason \(reason)") + expectation.isInverted = true + let storageTask = StorageTransferTask( + transferType: .upload(onEvent: { _ in + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + let error: Error = NSError( + domain: NSURLErrorDomain, + code: NSURLErrorCancelled, + userInfo: [ + NSURLErrorBackgroundTaskCancelledReasonKey: reason + ] + ) + + delegate.urlSession(.shared, task: task, didCompleteWithError: error) + waitForExpectations(timeout: 1) + XCTAssertEqual(storageTask.status, .unknown) + XCTAssertEqual(service.unregisterCount, 0) + } + } + + /// Given: A StorageServiceSessionDelegate and a StorageTransferTask with a StorageError + /// When: didComplete is invoked + /// Then: The task status is set to error and it's unregistered + func testDidComplete_withError_shouldFailTask() { + let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory) + let expectation = self.expectation(description: "Did Complete With Error") + let storageTask = StorageTransferTask( + transferType: .upload(onEvent: { _ in + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + + delegate.urlSession(.shared, task: task, didCompleteWithError: StorageError.accessDenied("", "", nil)) + waitForExpectations(timeout: 1) + XCTAssertEqual(storageTask.status, .error) + XCTAssertEqual(service.unregisterCount, 1) + } + + /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .upload + /// When: didSendBodyData is invoked + /// Then: An .inProcess event is reported, with the corresponding values + func testDidSendBodyData_upload_shouldSendInProcessEvent() { + let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory) + let expectation = self.expectation(description: "Did Send Body Data") + let storageTask = StorageTransferTask( + transferType: .upload(onEvent: { event in + guard case .inProcess(let progress) = event else { + XCTFail("Expected .inProcess event, got \(event)") + return + } + XCTAssertEqual(progress.totalUnitCount, 120) + XCTAssertEqual(progress.completedUnitCount, 100) + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + + delegate.urlSession( + .shared, + task: task, + didSendBodyData: 10, + totalBytesSent: 100, + totalBytesExpectedToSend: 120 + ) + + waitForExpectations(timeout: 1) + } + + /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .multiPartUploadPart + /// When: didSendBodyData is invoked + /// Then: A .progressUpdated event is reported to the session + func testDidSendBodyData_multiPartUploadPart_shouldSendInProcessEvent() { + let task = URLSession.shared.dataTask(with: FileManager.default.temporaryDirectory) + let storageTask = StorageTransferTask( + transferType: .multiPartUploadPart( + uploadId: "uploadId", + partNumber: 3 + ), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + let multipartSession = MockStorageMultipartUploadSession( + client: MockMultipartUploadClient(), + bucket: "bucket", + key: "key", + onEvent: { event in } + ) + service.mockedMultipartUploadSession = multipartSession + + delegate.urlSession( + .shared, + task: task, + didSendBodyData: 10, + totalBytesSent: 100, + totalBytesExpectedToSend: 120 + ) + XCTAssertEqual(multipartSession.handleUploadPartCount, 1) + guard case .progressUpdated(let partNumber, let bytesTransferred, let taskIdentifier) = multipartSession.lastUploadEvent else { + XCTFail("Expected .progressUpdated event") + return + } + + XCTAssertEqual(partNumber, 3) + XCTAssertEqual(bytesTransferred, 10) + XCTAssertEqual(taskIdentifier, task.taskIdentifier) + } + + /// Given: A StorageServiceSessionDelegate and a StorageTransferTask of type .download + /// When: didWriteData is invoked + /// Then: An .inProcess event is reported, with the corresponding values + func testDidWriteData_shouldNotifyProgress() { + let task = URLSession.shared.downloadTask(with: FileManager.default.temporaryDirectory) + let expectation = self.expectation(description: "Did Write Data") + let storageTask = StorageTransferTask( + transferType: .download(onEvent: { event in + guard case .inProcess(let progress) = event else { + XCTFail("Expected .inProcess event, got \(event)") + return + } + XCTAssertEqual(progress.totalUnitCount, 300) + XCTAssertEqual(progress.completedUnitCount, 200) + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + + delegate.urlSession( + .shared, + downloadTask: task, + didWriteData: 15, + totalBytesWritten: 200, + totalBytesExpectedToWrite: 300 + ) + + waitForExpectations(timeout: 1) + } + + /// Given: A StorageServiceSessionDelegate and a URLSessionDownloadTask without a httpResponse + /// When: didFinishDownloadingTo is invoked + /// Then: No event is reported and the task is not completed + func testDiFinishDownloading_withError_shouldNotCompleteDownload() { + let task = URLSession.shared.downloadTask(with: FileManager.default.temporaryDirectory) + let expectation = self.expectation(description: "Did Finish Downloading") + expectation.isInverted = true + let storageTask = StorageTransferTask( + transferType: .download(onEvent: { _ in + expectation.fulfill() + }), + bucket: "bucket", + key: "key" + ) + service.mockedTask = storageTask + + delegate.urlSession( + .shared, + downloadTask: task, + didFinishDownloadingTo: FileManager.default.temporaryDirectory + ) + + waitForExpectations(timeout: 1) + XCTAssertEqual(service.completeDownloadCount, 0) + } +} + +private class AWSS3StorageServiceMock: AWSS3StorageService { + convenience init() throws { + try self.init( + authService: MockAWSAuthService(), + region: "region", + bucket: "bucket", + storageTransferDatabase: MockStorageTransferDatabase() + ) + } + + override var identifier: String { + return "identifier" + } + + var mockedTask: StorageTransferTask? = nil + override func findTask(taskIdentifier: TaskIdentifier) -> StorageTransferTask? { + return mockedTask + } + + var resetURLSessionCount = 0 + override func resetURLSession() { + resetURLSessionCount += 1 + } + + var unregisterCount = 0 + override func unregister(task: StorageTransferTask) { + unregisterCount += 1 + } + + var mockedMultipartUploadSession: StorageMultipartUploadSession? = nil + override func findMultipartUploadSession(uploadId: UploadID) -> StorageMultipartUploadSession? { + return mockedMultipartUploadSession + } + + var completeDownloadCount = 0 + override func completeDownload(taskIdentifier: TaskIdentifier, sourceURL: URL) { + completeDownloadCount += 1 + } +} diff --git a/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift new file mode 100644 index 0000000000..c9b96aea8d --- /dev/null +++ b/AmplifyPlugins/Storage/Tests/AWSS3StoragePluginTests/Support/Internal/StorageTransferTaskTests.swift @@ -0,0 +1,658 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Amplify +@testable import AWSS3StoragePlugin +import XCTest + +class StorageTransferTaskTests: XCTestCase { + + // MARK: - Resume tests + /// Given: A StorageTransferTask with a sessionTask + /// When: resume is invoked + /// Then: an .initiated event is reported and the task set to .inProgress + func testResume_withSessionTask_shouldCallResume_andReportInitiatedEvent() { + let expectation = expectation(description: ".initiated event received on resume with only sessionTask") + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { event in + guard case .initiated(_) = event else { + XCTFail("Expected .initiated, got \(event)") + return + } + expectation.fulfill() + }), + sessionTask: sessionTask, + proxyStorageTask: nil + ) + XCTAssertEqual(task.status, .paused) + + task.resume() + waitForExpectations(timeout: 0.5) + + XCTAssertEqual(sessionTask.resumeCount, 1) + XCTAssertEqual(task.status, .inProgress) + } + + /// Given: A StorageTransferTask with a proxyStorageTask + /// When: resume is invoked + /// Then: an .initiated event is reported and the task set to .inProgress + func testResume_withProxyStorageTask_shouldCallResume_andReportInitiatedEvent() { + let expectation = expectation(description: ".initiated event received on resume with only proxyStorageTask") + let sessionTask = MockSessionTask() + let storageTask = MockStorageTask() + let task = createTask( + transferType: .download(onEvent: { event in + guard case .initiated(_) = event else { + XCTFail("Expected .initiated, got \(event)") + return + } + expectation.fulfill() + }), + sessionTask: sessionTask, // Set the sessionTask to set task.status = .paused + proxyStorageTask: storageTask + ) + task.sessionTask = nil // Remove the session task + XCTAssertEqual(task.status, .paused) + + task.resume() + waitForExpectations(timeout: 0.5) + + XCTAssertEqual(sessionTask.resumeCount, 0) + XCTAssertEqual(storageTask.resumeCount, 1) + XCTAssertEqual(task.status, .inProgress) + } + + /// Given: A StorageTransferTask with a sessionTask and a proxyStorageTask + /// When: resume is invoked + /// Then: an .initiated event is reported and the task set to .inProgress + func testResume_withSessionTask_andProxyStorageTask_shouldCallResume_andReportInitiatedEvent() { + let expectation = expectation(description: ".initiated event received on resume with sessionTask and proxyStorageTask") + let sessionTask = MockSessionTask() + let storageTask = MockStorageTask() + let task = createTask( + transferType: .multiPartUpload(onEvent: { event in + guard case .initiated(_) = event else { + XCTFail("Expected .initiated, got \(event)") + return + } + expectation.fulfill() + }), + sessionTask: sessionTask, + proxyStorageTask: storageTask + ) + XCTAssertEqual(task.status, .paused) + + task.resume() + waitForExpectations(timeout: 0.5) + + XCTAssertEqual(sessionTask.resumeCount, 1) + XCTAssertEqual(storageTask.resumeCount, 0) + XCTAssertEqual(task.status, .inProgress) + } + + /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask + /// When: resume is invoked + /// Then: No event is reported and the task is not to .inProgress + func testResume_withoutSessionTask_withoutProxyStorateTask_shouldNotCallResume_andNotReportEvent() { + let expectation = expectation(description: "no event is received on resume when no sessionTask nor proxyStorageTask") + expectation.isInverted = true + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .multiPartUpload(onEvent: { event in + XCTFail("No event expected, got \(event)") + expectation.fulfill() + }), + sessionTask: sessionTask, // Set the sessionTask to set task.status = .paused + proxyStorageTask: nil + ) + task.sessionTask = nil // Remove the sessionTask + XCTAssertEqual(task.status, .paused) + + task.resume() + waitForExpectations(timeout: 0.5) + + XCTAssertEqual(sessionTask.resumeCount, 0) + XCTAssertEqual(task.status, .paused) + } + + /// Given: A StorageTransferTask with status not being paused + /// When: resume is invoked + /// Then: No event is reported and the task is not set to .inProgress + func testResume_withTaskNotPaused_shouldNotCallResume_andNotReportEvent() { + let expectation = expectation(description: "no event is received on resume when the session is not paused") + expectation.isInverted = true + let task = createTask( + transferType: .multiPartUpload(onEvent: { event in + XCTFail("No event expected, got \(event)") + expectation.fulfill() + }), + sessionTask: nil, // Do not set session task so task.status = .unknown + proxyStorageTask: nil + ) + XCTAssertEqual(task.status, .unknown) + + task.resume() + waitForExpectations(timeout: 0.5) + + XCTAssertEqual(task.status, .unknown) + } + + // MARK: - Suspend Tests + /// Given: A StorageTransferTask with a sessionTask + /// When: suspend is invoked + /// Then: The task is set to .paused + func testSuspend_withSessionTask_shouldCallSuspend() { + let sessionTask = MockSessionTask(state: .running) + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: nil + ) + // Set the task to inProgress by setting a multiPartUpload.creating + task.multipartUpload = .creating + XCTAssertEqual(task.status, .inProgress) + + task.suspend() + + XCTAssertEqual(sessionTask.suspendCount, 1) + XCTAssertEqual(task.status, .paused) + } + + /// Given: A StorageTransferTask with a proxyStorageTask + /// When: suspend is invoked + /// Then: The task is set to .paused + func testSuspend_withProxyStorageTask_shouldCallPause() { + let storageTask = MockStorageTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: nil, + proxyStorageTask: storageTask + ) + // Set the task to inProgress by setting a multiPartUpload.creating + task.multipartUpload = .creating + XCTAssertEqual(task.status, .inProgress) + + task.suspend() + + XCTAssertEqual(storageTask.pauseCount, 1) + XCTAssertEqual(task.status, .paused) + } + + /// Given: A StorageTransferTask with a sessionTask and a proxyStorageTask + /// When: suspend is invoked + /// Then: The task is set to .paused + func testSuspend_withSessionTask_andProxyStorageTask_shouldCallSuspend() { + let sessionTask = MockSessionTask(state: .running) + let storageTask = MockStorageTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: storageTask + ) + // Set the task to inProgress by setting a multiPartUpload.creating + task.multipartUpload = .creating + XCTAssertEqual(task.status, .inProgress) + + task.suspend() + + XCTAssertEqual(sessionTask.suspendCount, 1) + XCTAssertEqual(storageTask.pauseCount, 0) + XCTAssertEqual(task.status, .paused) + } + + /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask + /// When: suspend is invoked + /// Then: The task remains .inProgress + func testSuspend_withoutSessionTask_andWithoutProxyStorageTask_shouldDoNothing() { + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: nil, + proxyStorageTask: nil + ) + // Set the task to inProgress by setting a multiPartUpload.creating + task.multipartUpload = .creating + XCTAssertEqual(task.status, .inProgress) + + task.suspend() + + XCTAssertEqual(task.status, .inProgress) + } + + /// Given: A StorageTransferTask with status completed + /// When: suspend is invoked + /// Then: The task remains completed + func testSuspend_withTaskNotInProgress_shouldDoNothing() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: nil + ) + // Set the task to completed by setting a multiPartUpload.completed + task.multipartUpload = .completed(uploadId: "") + XCTAssertEqual(task.status, .completed) + + task.suspend() + + XCTAssertEqual(sessionTask.suspendCount, 0) + XCTAssertEqual(task.status, .completed) + } + + /// Given: A StorageTransferTask + /// When: pause is invoked + /// Then: The task is set to .paused + func testPause_shouldCallSuspend() { + let sessionTask = MockSessionTask(state: .running) + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: nil + ) + // Set the task to inProgress by setting a multiPartUpload.creating + task.multipartUpload = .creating + XCTAssertEqual(task.status, .inProgress) + + task.pause() + + XCTAssertEqual(sessionTask.suspendCount, 1) + XCTAssertEqual(task.status, .paused) + } + + // MARK: - Cancel Tests + /// Given: A StorageTransferTask with a sessionTask + /// When: cancel is invoked + /// Then: The task is set to .cancelled + func testCancel_withSessionTask_shouldCancel() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: MockStorageTask() + ) + + // Set the task to completed by setting a multiPartUpload.completed + XCTAssertNotEqual(task.status, .completed) + + task.cancel() + + XCTAssertEqual(task.status, .cancelled) + XCTAssertEqual(sessionTask.cancelCount, 1) + XCTAssertNil(task.proxyStorageTask) + } + + /// Given: A StorageTransferTask with a proxyStorageTask + /// When: cancel is invoked + /// Then: The task is set to .cancelled + func testCancel_withProxyStorageTask_shouldCancel() { + let storageTask = MockStorageTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: nil, + proxyStorageTask: storageTask + ) + + task.cancel() + XCTAssertEqual(task.status, .cancelled) + XCTAssertEqual(storageTask.cancelCount, 1) + XCTAssertNil(task.proxyStorageTask) + } + + /// Given: A StorageTransferTask without a sessionTask and without a proxyStorageTask + /// When: cancel is invoked + /// Then: The task is not set to .cancelled + func testCancel_withoutSessionTask_withoutProxyStorageTask_shouldDoNothing() { + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: nil, + proxyStorageTask: nil + ) + + task.cancel() + XCTAssertNotEqual(task.status, .cancelled) + } + + /// Given: A StorageTransferTask with status completed + /// When: cancel is invoked + /// Then: The task is not set to .cancelled + func testCancel_withTaskCompleted_shouldDoNothing() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: MockStorageTask() + ) + // Set the task to completed by setting a multiPartUpload.completed + task.multipartUpload = .completed(uploadId: "") + XCTAssertEqual(task.status, .completed) + + task.cancel() + XCTAssertNotEqual(task.status, .cancelled) + XCTAssertEqual(sessionTask.cancelCount, 0) + XCTAssertNotNil(task.proxyStorageTask) + } + + // MARK: - Complete Tests + /// Given: A StorageTransferTask with sessionTask + /// When: complete is invoked + /// Then: The task is set to .completed + func testComplete_withSessionTask_shouldComplete() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: MockStorageTask() + ) + + task.complete() + XCTAssertEqual(task.status, .completed) + XCTAssertNil(task.proxyStorageTask) + } + + /// Given: A StorageTransferTask with status cancelled + /// When: complete is invoked + /// Then: The task is remains .cancelled + func testComplete_withTaskCancelled_shouldDoNothing() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: nil + ) + task.cancel() + XCTAssertEqual(task.status, .cancelled) + + task.complete() + XCTAssertEqual(task.status, .cancelled) + } + + /// Given: A StorageTransferTask with status completed + /// When: complete is invoked + /// Then: The task is remains .completed + func testComplete_withTaskCompleted_shouldDoNothing() { + let sessionTask = MockSessionTask() + let task = createTask( + transferType: .upload(onEvent: { _ in }), + sessionTask: sessionTask, + proxyStorageTask: MockStorageTask() + ) + // Set the task to completed by setting a multiPartUpload.completed + task.multipartUpload = .completed(uploadId: "") + XCTAssertEqual(task.status, .completed) + + task.complete() + + XCTAssertNotNil(task.proxyStorageTask) + } + + // MARK: - Fail Tests + /// Given: A StorageTransferTask + /// When: fail is invoked + /// Then: A .failed event is reported + func testFail_shouldReportFailEvent() { + let expectation = expectation(description: ".failed event received on fail") + let task = createTask( + transferType: .upload(onEvent: { event in + guard case .failed(_) = event else { + XCTFail("Expected .failed, got \(event)") + return + } + expectation.fulfill() + }), + sessionTask: MockSessionTask(), + proxyStorageTask: MockStorageTask() + ) + task.fail(error: CancellationError()) + + waitForExpectations(timeout: 0.5) + XCTAssertEqual(task.status, .error) + XCTAssertTrue(task.isFailed) + XCTAssertNil(task.proxyStorageTask) + } + + /// Given: A StorageTransferTask with status .failed + /// When: fail is invoked + /// Then: No event is reported + func testFail_withFailedTask_shouldNotReportEvent() { + let expectation = expectation(description: "event received on fail for failed task") + expectation.isInverted = true + let task = createTask( + transferType: .upload(onEvent: { event in + XCTFail("No event expected, got \(event)") + expectation.fulfill() + }), + sessionTask: MockSessionTask(), + proxyStorageTask: MockStorageTask() + ) + + // Set the task to error by setting a multiPartUpload.failed + task.multipartUpload = .failed(uploadId: "", parts: nil, error: CancellationError()) + XCTAssertEqual(task.status, .error) + task.fail(error: CancellationError()) + + waitForExpectations(timeout: 0.5) + XCTAssertNotNil(task.proxyStorageTask) + } + + // MARK: - Response Tests + /// Given: A StorageTransferTask with a valid responseData + /// When: responseText is invoked + /// Then: A string representing the data is returned + func testResponseText_withValidData_shouldReturnText() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil + ) + task.responseData = "Test".data(using: .utf8) + + XCTAssertEqual(task.responseText, "Test") + } + + /// Given: A StorageTransferTask with an invalid responseData + /// When: responseText is invoked + /// Then: nil is returned + func testResponseText_withInvalidData_shouldReturnNil() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil + ) + task.responseData = Data(count: 9999) + + XCTAssertNil(task.responseText) + } + + /// Given: A StorageTransferTask with a nil responseData + /// When: responseText is invoked + /// Then: nil is returned + func testResponseText_withoutData_shouldReturnNil() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil + ) + task.responseData = nil + + XCTAssertNil(task.responseText) + } + + // MARK: - PartNumber Tests + /// Given: A StorageTransferTask of type .multiPartUploadPart + /// When: partNumber is invoked + /// Then: The corresponding part number is returned + func testPartNumber_withMultipartUpload_shouldReturnPartNumber() { + let partNumber: PartNumber = 5 + let task = createTask( + transferType: .multiPartUploadPart(uploadId: "", partNumber: partNumber), + sessionTask: nil, + proxyStorageTask: nil + ) + + XCTAssertEqual(task.partNumber, partNumber) + } + + /// Given: A StorageTransferTask of type .upload + /// When: partNumber is invoked + /// Then: nil is returned + func testPartNumber_withOtherTransferType_shouldReturnNil() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil + ) + + XCTAssertNil(task.partNumber) + } + + // MARK: - HTTPRequestHeaders Tests + /// Given: A StorageTransferTask with requestHeaders + /// When: URLRequest.setHTTPRequestHeaders is invoked with said task + /// Then: The request includes the corresponding headers + func testHTTPRequestHeaders_shouldSetValues() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil, + requestHeaders: [ + "header1": "value1", + "header2": "value2" + ] + ) + + var request = URLRequest(url: FileManager.default.temporaryDirectory) + XCTAssertNil(request.allHTTPHeaderFields) + + request.setHTTPRequestHeaders(transferTask: task) + XCTAssertEqual(request.allHTTPHeaderFields?.count, 2) + XCTAssertEqual(request.allHTTPHeaderFields?["header1"], "value1") + XCTAssertEqual(request.allHTTPHeaderFields?["header2"], "value2") + } + + /// Given: A StorageTransferTask with nil requestHeaders + /// When: URLRequest.setHTTPRequestHeaders is invoked with said task + /// Then: The request does not adds headers + func testHTTPRequestHeaders_withoutHeaders_shouldDoNothing() { + let task = createTask( + transferType: .upload(onEvent: { _ in}), + sessionTask: nil, + proxyStorageTask: nil, + requestHeaders: nil + ) + + var request = URLRequest(url: FileManager.default.temporaryDirectory) + XCTAssertNil(request.allHTTPHeaderFields) + + request.setHTTPRequestHeaders(transferTask: task) + XCTAssertNil(request.allHTTPHeaderFields) + } +} + +extension StorageTransferTaskTests { + private func createTask( + transferType: StorageTransferType, + sessionTask: StorageSessionTask?, + proxyStorageTask: StorageTask?, + requestHeaders: [String: String]? = nil + ) -> StorageTransferTask { + let transferID = UUID().uuidString + let bucket = "BUCKET" + let key = UUID().uuidString + let task = StorageTransferTask( + transferID: transferID, + transferType: transferType, + bucket: bucket, + key: key, + location: nil, + contentType: nil, + requestHeaders: requestHeaders, + storageTransferDatabase: MockStorageTransferDatabase(), + logger: MockLogger() + ) + task.sessionTask = sessionTask + task.proxyStorageTask = proxyStorageTask + return task + } +} + + +private class MockStorageTask: StorageTask { + var pauseCount = 0 + func pause() { + pauseCount += 1 + } + + var resumeCount = 0 + func resume() { + resumeCount += 1 + } + + var cancelCount = 0 + func cancel() { + cancelCount += 1 + } +} + +private class MockSessionTask: StorageSessionTask { + let taskIdentifier: TaskIdentifier + let state: URLSessionTask.State + + init( + taskIdentifier: TaskIdentifier = 1, + state: URLSessionTask.State = .suspended + ) { + self.taskIdentifier = taskIdentifier + self.state = state + } + + var resumeCount = 0 + func resume() { + resumeCount += 1 + } + + var suspendCount = 0 + func suspend() { + suspendCount += 1 + } + + var cancelCount = 0 + func cancel() { + cancelCount += 1 + } +} + +class MockLogger: Logger { + var logLevel: LogLevel = .verbose + + func error(_ message: @autoclosure () -> String) { + print(message()) + } + + func error(error: Error) { + print(error) + } + + var warnCount = 0 + func warn(_ message: @autoclosure () -> String) { + print(message()) + warnCount += 1 + } + + var infoCount = 0 + func info(_ message: @autoclosure () -> String) { + print(message()) + infoCount += 1 + } + + func debug(_ message: @autoclosure () -> String) { + print(message()) + } + + func verbose(_ message: @autoclosure () -> String) { + print(message()) + } +}