diff --git a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/StorageEngine+SyncRequirement.swift b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/StorageEngine+SyncRequirement.swift index 711ad7ab57..0b99894ea6 100644 --- a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/StorageEngine+SyncRequirement.swift +++ b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/StorageEngine+SyncRequirement.swift @@ -34,7 +34,10 @@ extension StorageEngine { )) } - let authPluginRequired = StorageEngine.requiresAuthPlugin(api) + let authPluginRequired = StorageEngine.requiresAuthPlugin( + api, + authModeStrategy: dataStoreConfiguration.authModeStrategyType + ) guard authPluginRequired else { syncEngine.start(api: apiGraphQL, auth: nil) return .success(.successfullyInitialized) @@ -96,20 +99,45 @@ extension StorageEngine { } } - static func requiresAuthPlugin(_ apiPlugin: APICategoryPlugin) -> Bool { + static func requiresAuthPlugin( + _ apiPlugin: APICategoryPlugin, + authModeStrategy: AuthModeStrategyType + ) -> Bool { let modelsRequireAuthPlugin = ModelRegistry.modelSchemas.contains { schema in guard schema.isSyncable else { return false } - return StorageEngine.requiresAuthPlugin(apiPlugin, authRules: schema.authRules) + return requiresAuthPlugin(apiPlugin, + authRules: schema.authRules, + authModeStrategy: authModeStrategy) } return modelsRequireAuthPlugin } - static func requiresAuthPlugin(_ apiPlugin: APICategoryPlugin, authRules: [AuthRule]) -> Bool { - if let rulesRequireAuthPlugin = authRules.requireAuthPlugin { - return rulesRequireAuthPlugin + static func requiresAuthPlugin( + _ apiPlugin: APICategoryPlugin, + authRules: [AuthRule], + authModeStrategy: AuthModeStrategyType + ) -> Bool { + switch authModeStrategy { + case .default: + if authRules.isEmpty { + return false + } + // Only use the auth rule as determination for auth plugin requirement when there is + // exactly one. If there is more than one auth rule AND multi-auth is not enabled, + // then immediately fall back to using the default auth type configured on the APIPlugin because + // we do not have enough information to know which provider to use to make the determination. + if authRules.count == 1, + let singleAuthRule = authRules.first, + let ruleRequireAuthPlugin = singleAuthRule.requiresAuthPlugin { + return ruleRequireAuthPlugin + } + case .multiAuth: + if let rulesRequireAuthPlugin = authRules.requireAuthPlugin { + return rulesRequireAuthPlugin + } } // Fall back to the endpoint's auth type if a determination cannot be made from the auth rules. This can diff --git a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/StorageEngineSyncRequirementsTests.swift b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/StorageEngineSyncRequirementsTests.swift index 24467a3b67..917616005f 100644 --- a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/StorageEngineSyncRequirementsTests.swift +++ b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/StorageEngineSyncRequirementsTests.swift @@ -17,80 +17,91 @@ class StorageEngineSyncRequirementsTests: XCTestCase { func testRequiresAuthPluginFalseForMissingAuthRules() { let apiPlugin = MockAPICategoryPlugin() - let result = StorageEngine.requiresAuthPlugin(apiPlugin) - XCTAssertFalse(result) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleAPIKey() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .apiKey)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleOIDC() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .oidc)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleFunction() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .private, provider: .function)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleUserPools() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .userPools)] - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleIAM() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .iam)] - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeFunction() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .function - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeAPIKey() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .apiKey - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeUserPools() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .amazonCognitoUserPools - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeIAM() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .awsIAM - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeODIC() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .openIDConnect - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeNone() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = AWSAuthorizationType.none - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginOIDCProvider() { @@ -99,7 +110,41 @@ class StorageEngineSyncRequirementsTests: XCTestCase { apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) let oidcProvider = MockOIDCAuthProvider() apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(oidcProvider: oidcProvider) - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) + } + + func testRequiresAuthPluginOIDCProvider_MultiAuthRules() { + // OIDC requires an auth provider on the API, this is added below + let authRules = [AuthRule(allow: .owner, provider: .oidc), + AuthRule(allow: .private, provider: .iam)] + let apiPlugin = MockAPIAuthInformationPlugin() + apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) + let oidcProvider = MockOIDCAuthProvider() + apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(oidcProvider: oidcProvider) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .default), + "Should be false since OIDC is the default auth type on the API.") + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .multiAuth), + "Should be true since IAM requires auth plugin.") + } + + func testRequiresAuthPluginUserPoolProvider_MultiAuthRules() { + let authRules = [AuthRule(allow: .owner, provider: .userPools), + AuthRule(allow: .private, provider: .iam)] + let apiPlugin = MockAPIAuthInformationPlugin() + apiPlugin.authType = AWSAuthorizationType.amazonCognitoUserPools + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .default), + "Should be true since UserPool is the default auth type on the API.") + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .multiAuth), + "Should be true since both UserPool and IAM requires auth plugin.") } func testRequiresAuthPluginFunctionProvider() { @@ -108,14 +153,16 @@ class StorageEngineSyncRequirementsTests: XCTestCase { apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) let functionProvider = MockFunctionAuthProvider() apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(functionProvider: functionProvider) - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginWithAuthRules() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } // MARK: - AuthRules tests