From 9a0ccf690e4c9cbf323108408231b52558a044d8 Mon Sep 17 00:00:00 2001 From: Di Wu Date: Fri, 26 Jan 2024 18:03:03 -0800 Subject: [PATCH 1/2] fix(datastore-v1): correct the filter predicate logic applied to optional fields --- .../Model/Internal/Persistable.swift | 36 +++++++++++++++++++ .../DataStore/Query/QueryOperator.swift | 6 ++-- .../DataStore/Query/QueryPredicate.swift | 30 +--------------- ...enticationProviderConfirmSignupTests.swift | 4 +-- ...icationProviderResendSignupCodeTests.swift | 5 +-- .../AuthenticationProviderSignupTests.swift | 12 ++++--- AmplifyPlugins/Auth/Podfile.lock | 2 +- ...yPredicateEvaluateGeneratedBoolTests.swift | 4 +-- ...yPredicateEvaluateGeneratedDateTests.swift | 8 ++--- ...dicateEvaluateGeneratedDateTimeTests.swift | 8 ++--- ...icateEvaluateGeneratedDoubleIntTests.swift | 12 +++---- ...redicateEvaluateGeneratedDoubleTests.swift | 12 +++---- ...ryPredicateEvaluateGeneratedIntTests.swift | 6 ++-- ...redicateEvaluateGeneratedStringTests.swift | 8 ++--- ...yPredicateEvaluateGeneratedTimeTests.swift | 8 ++--- 15 files changed, 87 insertions(+), 74 deletions(-) diff --git a/Amplify/Categories/DataStore/Model/Internal/Persistable.swift b/Amplify/Categories/DataStore/Model/Internal/Persistable.swift index aa32cbddb4..5ac322cf3a 100644 --- a/Amplify/Categories/DataStore/Model/Internal/Persistable.swift +++ b/Amplify/Categories/DataStore/Model/Internal/Persistable.swift @@ -65,6 +65,12 @@ struct PersistableHelper { return lhs == rhs case let (lhs, rhs) as (String, String): return lhs == rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue == rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs == rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue == rhs.rawValue default: return false } @@ -95,6 +101,12 @@ struct PersistableHelper { return lhs == Double(rhs) case let (lhs, rhs) as (String, String): return lhs == rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue == rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs == rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue == rhs.rawValue default: return false } @@ -123,6 +135,12 @@ struct PersistableHelper { return lhs <= Double(rhs) case let (lhs, rhs) as (String, String): return lhs <= rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue <= rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs <= rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue <= rhs.rawValue default: return false } @@ -151,6 +169,12 @@ struct PersistableHelper { return lhs < Double(rhs) case let (lhs, rhs) as (String, String): return lhs < rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue < rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs < rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue < rhs.rawValue default: return false } @@ -179,6 +203,12 @@ struct PersistableHelper { return lhs >= Double(rhs) case let (lhs, rhs) as (String, String): return lhs >= rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue >= rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs >= rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue >= rhs.rawValue default: return false } @@ -207,6 +237,12 @@ struct PersistableHelper { return Double(lhs) > rhs case let (lhs, rhs) as (String, String): return lhs > rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue > rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs > rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue > rhs.rawValue default: return false } diff --git a/Amplify/Categories/DataStore/Query/QueryOperator.swift b/Amplify/Categories/DataStore/Query/QueryOperator.swift index abe59657d3..f9c4d855b3 100644 --- a/Amplify/Categories/DataStore/Query/QueryOperator.swift +++ b/Amplify/Categories/DataStore/Query/QueryOperator.swift @@ -19,7 +19,7 @@ public enum QueryOperator { case beginsWith(_ value: String) // swiftlint:disable:next cyclomatic_complexity - public func evaluate(target: Any) -> Bool { + public func evaluate(target: Any?) -> Bool { switch self { case .notEqual(let predicateValue): return !PersistableHelper.isEqual(target, predicateValue) @@ -34,14 +34,14 @@ public enum QueryOperator { case .greaterThan(let predicateValue): return PersistableHelper.isGreaterThan(target, predicateValue) case .contains(let predicateString): - if let targetString = target as? String { + if let targetString = target.flatMap({ $0 as? String }) { return targetString.contains(predicateString) } return false case .between(let start, let end): return PersistableHelper.isBetween(start, end, target) case .beginsWith(let predicateValue): - if let targetString = target as? String { + if let targetString = target.flatMap({ $0 as? String }) { return targetString.starts(with: predicateValue) } } diff --git a/Amplify/Categories/DataStore/Query/QueryPredicate.swift b/Amplify/Categories/DataStore/Query/QueryPredicate.swift index 622330d8b9..2802010c54 100644 --- a/Amplify/Categories/DataStore/Query/QueryPredicate.swift +++ b/Amplify/Categories/DataStore/Query/QueryPredicate.swift @@ -127,34 +127,6 @@ public class QueryPredicateOperation: QueryPredicate { } public func evaluate(target: Model) -> Bool { - guard let fieldValue = target[field] else { - return false - } - - guard let value = fieldValue else { - return false - } - - if let booleanValue = value as? Bool { - return self.operator.evaluate(target: booleanValue) - } - - if let doubleValue = value as? Double { - return self.operator.evaluate(target: doubleValue) - } - - if let intValue = value as? Int { - return self.operator.evaluate(target: intValue) - } - - if let timeValue = value as? Temporal.Time { - return self.operator.evaluate(target: timeValue) - } - - if let enumValue = value as? EnumPersistable { - return self.operator.evaluate(target: enumValue.rawValue) - } - - return self.operator.evaluate(target: value) + return self.operator.evaluate(target: target[field]?.flatMap { $0 }) } } diff --git a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderConfirmSignupTests.swift b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderConfirmSignupTests.swift index 8b474de1c0..1a62908c2a 100644 --- a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderConfirmSignupTests.swift +++ b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderConfirmSignupTests.swift @@ -26,7 +26,7 @@ class AuthenticationProviderConfirmSignupTests: BaseAuthenticationProviderTest { /// func testSuccessfulConfirmSignUp() { - let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil) + let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil, userSub: nil) mockAWSMobileClient?.confirmSignUpMockResult = .success(mockSignupResult) let resultExpectation = expectation(description: "Should receive a result") @@ -60,7 +60,7 @@ class AuthenticationProviderConfirmSignupTests: BaseAuthenticationProviderTest { /// func testConfirmSignUpWithEmptyUserName() { - let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil) + let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil, userSub: nil) mockAWSMobileClient?.confirmSignUpMockResult = .success(mockSignupResult) let resultExpectation = expectation(description: "Should receive a result") diff --git a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderResendSignupCodeTests.swift b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderResendSignupCodeTests.swift index 60aaf4fdd0..adbd4617bf 100644 --- a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderResendSignupCodeTests.swift +++ b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderResendSignupCodeTests.swift @@ -31,7 +31,7 @@ class AuthenticationProviderResendSignupCodeTests: BaseAuthenticationProviderTes destination: nil, attributeName: nil) let mockResendSignUpCodeResult = SignUpResult(signUpState: .confirmed, - codeDeliveryDetails: codeDelieveryDetails) + codeDeliveryDetails: codeDelieveryDetails, userSub: nil) mockAWSMobileClient?.resendSignUpCodeMockResult = .success(mockResendSignUpCodeResult) let resultExpectation = expectation(description: "Should receive a result") _ = plugin.resendSignUpCode(for: "username") { result in @@ -66,7 +66,8 @@ class AuthenticationProviderResendSignupCodeTests: BaseAuthenticationProviderTes destination: nil, attributeName: nil) let mockResendSignUpCodeResult = SignUpResult(signUpState: .confirmed, - codeDeliveryDetails: codeDelieveryDetails) + codeDeliveryDetails: codeDelieveryDetails, + userSub: nil) mockAWSMobileClient?.resendSignUpCodeMockResult = .success(mockResendSignUpCodeResult) let resultExpectation = expectation(description: "Should receive a result") _ = plugin.resendSignUpCode(for: "") { result in diff --git a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderSignupTests.swift b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderSignupTests.swift index df91f3953f..225dc036f2 100644 --- a/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderSignupTests.swift +++ b/AmplifyPlugins/Auth/AWSCognitoAuthPluginTests/AuthenticationProviderTests/AuthenticationProviderSignupTests.swift @@ -26,7 +26,7 @@ class AuthenticationProviderSignupTests: BaseAuthenticationProviderTest { /// func testSignupWithSuccess() { - let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil) + let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil, userSub: nil) mockAWSMobileClient?.signupMockResult = .success(mockSignupResult) let emailAttribute = AuthUserAttribute(.email, value: "email") @@ -61,7 +61,7 @@ class AuthenticationProviderSignupTests: BaseAuthenticationProviderTest { /// func testSignupWithEmptyUserName() { - let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil) + let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil, userSub: nil) mockAWSMobileClient?.signupMockResult = .success(mockSignupResult) let emailAttribute = AuthUserAttribute(.email, value: "email") @@ -96,7 +96,7 @@ class AuthenticationProviderSignupTests: BaseAuthenticationProviderTest { /// func testSignupWithEmptyPassword() { - let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil) + let mockSignupResult = SignUpResult(signUpState: .confirmed, codeDeliveryDetails: nil, userSub: nil) mockAWSMobileClient?.signupMockResult = .success(mockSignupResult) let emailAttribute = AuthUserAttribute(.email, value: "email") @@ -138,7 +138,11 @@ class AuthenticationProviderSignupTests: BaseAuthenticationProviderTest { let mockCodeDelivery = UserCodeDeliveryDetails(deliveryMedium: .email, destination: mockEmail, attributeName: "email") - let mockSignupResult = SignUpResult(signUpState: .unconfirmed, codeDeliveryDetails: mockCodeDelivery) + let mockSignupResult = SignUpResult( + signUpState: .unconfirmed, + codeDeliveryDetails: mockCodeDelivery, + userSub: nil + ) mockAWSMobileClient?.signupMockResult = .success(mockSignupResult) let resultExpectation = expectation(description: "Should receive a result") diff --git a/AmplifyPlugins/Auth/Podfile.lock b/AmplifyPlugins/Auth/Podfile.lock index 68377e8c3d..e898f59fb7 100644 --- a/AmplifyPlugins/Auth/Podfile.lock +++ b/AmplifyPlugins/Auth/Podfile.lock @@ -99,4 +99,4 @@ SPEC CHECKSUMS: PODFILE CHECKSUM: 371cf67fe35ebb5167d0880bad12b01618a0fb0e -COCOAPODS: 1.11.3 +COCOAPODS: 1.14.3 diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift index e3e013d248..465ccfe256 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift @@ -41,7 +41,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testBoolfalsenotEqualBooltrue() throws { @@ -70,7 +70,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testBooltrueequalsBooltrue() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift index ae7c9c8f34..5055bf2230 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift @@ -60,7 +60,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue1to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -109,7 +109,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue2to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -158,7 +158,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue3to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -207,7 +207,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_nowequalsTemporalDateTemporal_Date_now() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift index 226b3d7908..14728e550a 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift @@ -66,7 +66,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue1to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -120,7 +120,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue2to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -174,7 +174,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue3to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -228,7 +228,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_nowequalsTemporalDateTimeTemporal_DateTime_now() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift index 31fe364268..8866439b17 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift @@ -50,7 +50,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2_1notEqualInt1() throws { @@ -89,7 +89,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3_1notEqualInt1() throws { @@ -128,7 +128,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1notEqualInt1() throws { @@ -167,7 +167,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2notEqualInt1() throws { @@ -206,7 +206,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3notEqualInt1() throws { @@ -245,7 +245,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1_1equalsInt1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift index fbc4c6566d..12f54572b6 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift @@ -80,7 +80,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2_1notEqualDouble1_1() throws { @@ -149,7 +149,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3_1notEqualDouble1_1() throws { @@ -218,7 +218,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1notEqualDouble1_1() throws { @@ -287,7 +287,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2notEqualDouble1_1() throws { @@ -356,7 +356,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3notEqualDouble1_1() throws { @@ -425,7 +425,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1_1equalsDouble1_1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift index 1e5c4ec370..315c648ae1 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift @@ -54,7 +54,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt2notEqualInt1() throws { @@ -93,7 +93,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt3notEqualInt1() throws { @@ -132,7 +132,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt1equalsInt1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift index bca4f41179..f785a73937 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift @@ -65,7 +65,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringbbnotEqualStringa() throws { @@ -114,7 +114,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringaanotEqualStringa() throws { @@ -163,7 +163,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringcnotEqualStringa() throws { @@ -212,7 +212,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringaequalsStringa() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift index 13e03b330c..d692178d7c 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift @@ -70,7 +70,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue1to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -124,7 +124,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue2to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -178,7 +178,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue3to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -232,7 +232,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_nowequalsTemporalTimeTemporal_Time_now() throws { From 18d3a76c2bf42f669361e5a8828c21a483ee1dcd Mon Sep 17 00:00:00 2001 From: Michael Law <1365977+lawmicha@users.noreply.github.com> Date: Fri, 2 Feb 2024 09:04:07 -0500 Subject: [PATCH 2/2] fix(DataStore-v1): auth plugin requirement for single auth rule (#3454) * fix(DataStore-v1): auth plugin requirement for single auth rule * fix code structure * Update AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Storage/StorageEngine+SyncRequirement.swift Co-authored-by: Sebastian Villena <97059974+ruisebas@users.noreply.github.com> --------- Co-authored-by: Sebastian Villena <97059974+ruisebas@users.noreply.github.com> --- .../StorageEngine+SyncRequirement.swift | 40 ++++++++-- .../StorageEngineSyncRequirementsTests.swift | 79 +++++++++++++++---- 2 files changed, 97 insertions(+), 22 deletions(-) diff --git a/AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Storage/StorageEngine+SyncRequirement.swift b/AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Storage/StorageEngine+SyncRequirement.swift index d036673735..d586a5e599 100644 --- a/AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Storage/StorageEngine+SyncRequirement.swift +++ b/AmplifyPlugins/DataStore/AWSDataStoreCategoryPlugin/Storage/StorageEngine+SyncRequirement.swift @@ -25,7 +25,10 @@ extension StorageEngine { )) } - let authPluginRequired = StorageEngine.requiresAuthPlugin(api) + let authPluginRequired = StorageEngine.requiresAuthPlugin( + api, + authModeStrategy: dataStoreConfiguration.authModeStrategyType + ) guard authPluginRequired else { syncEngine.start(api: api, auth: nil) @@ -81,20 +84,45 @@ extension StorageEngine { } } - static func requiresAuthPlugin(_ apiPlugin: APICategoryPlugin) -> Bool { + static func requiresAuthPlugin( + _ apiPlugin: APICategoryPlugin, + authModeStrategy: AuthModeStrategyType + ) -> Bool { let modelsRequireAuthPlugin = ModelRegistry.modelSchemas.contains { schema in guard schema.isSyncable else { return false } - return StorageEngine.requiresAuthPlugin(apiPlugin, authRules: schema.authRules) + return StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: schema.authRules, + authModeStrategy: authModeStrategy) } return modelsRequireAuthPlugin } - static func requiresAuthPlugin(_ apiPlugin: APICategoryPlugin, authRules: [AuthRule]) -> Bool { - if let rulesRequireAuthPlugin = authRules.requireAuthPlugin { - return rulesRequireAuthPlugin + static func requiresAuthPlugin( + _ apiPlugin: APICategoryPlugin, + authRules: [AuthRule], + authModeStrategy: AuthModeStrategyType + ) -> Bool { + switch authModeStrategy { + case .default: + if authRules.isEmpty { + return false + } + // Only use the auth rule as determination for auth plugin requirement when there is + // exactly one. If there is more than one auth rule AND multi-auth is not enabled, + // then immediately fall back to using the default auth type configured on the APIPlugin because + // we do not have enough information to know which provider to use to make the determination. + if authRules.count == 1, + let singleAuthRule = authRules.first, + let ruleRequireAuthPlugin = singleAuthRule.requiresAuthPlugin { + return ruleRequireAuthPlugin + } + case .multiAuth: + if let rulesRequireAuthPlugin = authRules.requireAuthPlugin { + return rulesRequireAuthPlugin + } } // Fall back to the endpoint's auth type if a determination cannot be made from the auth rules. This can diff --git a/AmplifyPlugins/DataStore/AWSDataStoreCategoryPluginTests/Sync/StorageEngineSyncRequirementsTests.swift b/AmplifyPlugins/DataStore/AWSDataStoreCategoryPluginTests/Sync/StorageEngineSyncRequirementsTests.swift index e4b07e1f54..50b6dcad38 100644 --- a/AmplifyPlugins/DataStore/AWSDataStoreCategoryPluginTests/Sync/StorageEngineSyncRequirementsTests.swift +++ b/AmplifyPlugins/DataStore/AWSDataStoreCategoryPluginTests/Sync/StorageEngineSyncRequirementsTests.swift @@ -17,80 +17,91 @@ class StorageEngineSyncRequirementsTests: XCTestCase { func testRequiresAuthPluginFalseForMissingAuthRules() { let apiPlugin = MockAPICategoryPlugin() - let result = StorageEngine.requiresAuthPlugin(apiPlugin) - XCTAssertFalse(result) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleAPIKey() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .apiKey)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleOIDC() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .oidc)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleFunction() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .private, provider: .function)] - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleUserPools() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .userPools)] - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginSingleAuthRuleIAM() { let apiPlugin = MockAPICategoryPlugin() let authRules = [AuthRule(allow: .owner, provider: .iam)] - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeFunction() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .function - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeAPIKey() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .apiKey - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeUserPools() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .amazonCognitoUserPools - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeIAM() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .awsIAM - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeODIC() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = .openIDConnect - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginNoProvidersWithAuthTypeNone() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.authType = AWSAuthorizationType.none - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginOIDCProvider() { @@ -99,7 +110,41 @@ class StorageEngineSyncRequirementsTests: XCTestCase { apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) let oidcProvider = MockOIDCAuthProvider() apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(oidcProvider: oidcProvider) - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) + } + + func testRequiresAuthPluginOIDCProvider_MultiAuthRules() { + // OIDC requires an auth provider on the API, this is added below + let authRules = [AuthRule(allow: .owner, provider: .oidc), + AuthRule(allow: .private, provider: .iam)] + let apiPlugin = MockAPIAuthInformationPlugin() + apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) + let oidcProvider = MockOIDCAuthProvider() + apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(oidcProvider: oidcProvider) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .default), + "Should be false since OIDC is the default auth type on the API.") + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .multiAuth), + "Should be true since IAM requires auth plugin.") + } + + func testRequiresAuthPluginUserPoolProvider_MultiAuthRules() { + let authRules = [AuthRule(allow: .owner, provider: .userPools), + AuthRule(allow: .private, provider: .iam)] + let apiPlugin = MockAPIAuthInformationPlugin() + apiPlugin.authType = AWSAuthorizationType.amazonCognitoUserPools + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .default), + "Should be true since UserPool is the default auth type on the API.") + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, + authRules: authRules, + authModeStrategy: .multiAuth), + "Should be true since both UserPool and IAM requires auth plugin.") } func testRequiresAuthPluginFunctionProvider() { @@ -108,14 +153,16 @@ class StorageEngineSyncRequirementsTests: XCTestCase { apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) let functionProvider = MockFunctionAuthProvider() apiPlugin.authProviderFactory = MockAPIAuthProviderFactory(functionProvider: functionProvider) - XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertFalse(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } func testRequiresAuthPluginWithAuthRules() { let authRules = [AuthRule(allow: .owner)] let apiPlugin = MockAPIAuthInformationPlugin() apiPlugin.defaultAuthTypeError = APIError.unknown("Could not get default auth type", "", nil) - XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .default)) + XCTAssertTrue(StorageEngine.requiresAuthPlugin(apiPlugin, authRules: authRules, authModeStrategy: .multiAuth)) } // MARK: - AuthRules tests