diff --git a/Amplify/Categories/API/Request/GraphQLOperationRequest.swift b/Amplify/Categories/API/Request/GraphQLOperationRequest.swift index 99115342bb..2f5ebf1ed2 100644 --- a/Amplify/Categories/API/Request/GraphQLOperationRequest.swift +++ b/Amplify/Categories/API/Request/GraphQLOperationRequest.swift @@ -25,6 +25,9 @@ public struct GraphQLOperationRequest: AmplifyOperationRequest { /// The path to traverse before decoding to `responseType`. public let decodePath: String? + /// The authorization mode + public let authMode: AuthorizationMode? + /// Options to adjust the behavior of this request, including plugin-options public let options: Options @@ -35,6 +38,7 @@ public struct GraphQLOperationRequest: AmplifyOperationRequest { variables: [String: Any]? = nil, responseType: R.Type, decodePath: String? = nil, + authMode: AuthorizationMode? = nil, options: Options) { self.apiName = apiName self.operationType = operationType @@ -42,6 +46,7 @@ public struct GraphQLOperationRequest: AmplifyOperationRequest { self.variables = variables self.responseType = responseType self.decodePath = decodePath + self.authMode = authMode self.options = options } } diff --git a/Amplify/Categories/API/Request/GraphQLRequest.swift b/Amplify/Categories/API/Request/GraphQLRequest.swift index 5c566d2fca..ba0086de66 100644 --- a/Amplify/Categories/API/Request/GraphQLRequest.swift +++ b/Amplify/Categories/API/Request/GraphQLRequest.swift @@ -5,6 +5,9 @@ // SPDX-License-Identifier: Apache-2.0 // +/// Empty protocol for plugins to define specific `AuthorizationMode` types for the request. +public protocol AuthorizationMode { } + /// GraphQL Request public struct GraphQLRequest { @@ -21,6 +24,9 @@ public struct GraphQLRequest { /// Type to decode the graphql response data object to public let responseType: R.Type + /// The authorization mode + public let authMode: AuthorizationMode? + /// The path to decode to the graphQL response data to `responseType`. Delimited by `.` The decode path /// "listTodos.items" will traverse to the object at `listTodos`, and decode the object at `items` to `responseType` /// The data at that decode path is a list of Todo objects so `responseType` should be `[Todo].self` @@ -34,11 +40,13 @@ public struct GraphQLRequest { variables: [String: Any]? = nil, responseType: R.Type, decodePath: String? = nil, + authMode: AuthorizationMode? = nil, options: GraphQLRequest.Options? = nil) { self.apiName = apiName self.document = document self.variables = variables self.responseType = responseType + self.authMode = authMode self.decodePath = decodePath self.options = options } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift index d57c2ba1c4..b3a61608fb 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift @@ -46,7 +46,7 @@ final public class AWSGraphQLOperation: GraphQLOperation { } let urlRequest = validateRequest(request).flatMap(buildURLRequest(from:)) - let finalRequest = await getEndpointInterceptors(from: request).flatMapAsync { requestInterceptors in + let finalRequest = await getEndpointInterceptors().flatMapAsync { requestInterceptors in let preludeInterceptors = requestInterceptors?.preludeInterceptors ?? [] let customerInterceptors = requestInterceptors?.interceptors ?? [] let postludeInterceptors = requestInterceptors?.postludeInterceptors ?? [] @@ -150,7 +150,7 @@ final public class AWSGraphQLOperation: GraphQLOperation { } } - private func getEndpointInterceptors(from request: GraphQLOperationRequest) -> Result { + func getEndpointInterceptors() -> Result { getEndpointConfig(from: request).flatMap { endpointConfig in do { if let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions, @@ -159,6 +159,11 @@ final public class AWSGraphQLOperation: GraphQLOperation { withConfig: endpointConfig, authType: authType )) + } else if let authType = request.authMode as? AWSAuthorizationType { + return .success(try pluginConfig.interceptorsForEndpoint( + withConfig: endpointConfig, + authType: authType + )) } else { return .success(pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)) } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift index 12427ad9ab..44e2cf378d 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift @@ -91,14 +91,21 @@ public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, return } - let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions + let authType: AWSAuthorizationType? + if let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions { + authType = pluginOptions.authType + } else if let authorizationMode = request.authMode as? AWSAuthorizationType { + authType = authorizationMode + } else { + authType = nil + } // Retrieve the subscription connection do { self.appSyncClient = try await appSyncClientFactory.getAppSyncRealTimeClient( for: endpointConfig, endpoint: endpointConfig.baseURL, authService: authService, - authType: pluginOptions?.authType, + authType: authType, apiAuthProviderFactory: apiAuthProviderFactory ) @@ -262,14 +269,21 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri return } - let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions + let authType: AWSAuthorizationType? + if let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions { + authType = pluginOptions.authType + } else if let authorizationMode = request.authMode as? AWSAuthorizationType { + authType = authorizationMode + } else { + authType = nil + } Task { do { appSyncRealTimeClient = try await appSyncRealTimeClientFactory.getAppSyncRealTimeClient( for: endpointConfig, endpoint: endpointConfig.baseURL, authService: authService, - authType: pluginOptions?.authType, + authType: authType, apiAuthProviderFactory: apiAuthProviderFactory ) diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLRequest+toOperationRequest.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLRequest+toOperationRequest.swift index 1079685b66..ba518bbe6e 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLRequest+toOperationRequest.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLRequest+toOperationRequest.swift @@ -16,6 +16,7 @@ extension GraphQLRequest { variables: variables, responseType: responseType, decodePath: decodePath, + authMode: authMode, options: requestOptions) } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+GraphQLBehaviorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+GraphQLBehaviorTests.swift index 098d3ae447..16b4ff573c 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+GraphQLBehaviorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+GraphQLBehaviorTests.swift @@ -8,6 +8,7 @@ import XCTest import Amplify @testable import AWSAPIPlugin +import AWSPluginsCore class AWSAPICategoryPluginGraphQLBehaviorTests: AWSAPICategoryPluginTestBase { @@ -15,10 +16,11 @@ class AWSAPICategoryPluginGraphQLBehaviorTests: AWSAPICategoryPluginTestBase { func testQuery() { let operationFinished = expectation(description: "Operation should finish") - let request = GraphQLRequest(apiName: apiName, - document: testDocument, - variables: nil, - responseType: JSONValue.self) + let request = GraphQLRequest(apiName: apiName, + document: testDocument, + variables: nil, + responseType: JSONValue.self, + authMode: AWSAuthorizationType.apiKey) let operation = apiPlugin.query(request: request) { _ in operationFinished.fulfill() } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLOperationTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLOperationTests.swift index 93af2539b2..717a87f5ab 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLOperationTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLOperationTests.swift @@ -9,6 +9,8 @@ import XCTest @testable import Amplify @testable import AmplifyTestCommon @testable import AWSAPIPlugin +@testable import AWSPluginsTestCommon +import AWSPluginsCore class AWSGraphQLOperationTests: AWSAPICategoryPluginTestBase { @@ -37,4 +39,44 @@ class AWSGraphQLOperationTests: AWSAPICategoryPluginTestBase { XCTAssertNil(task) } + + /// Request for `.amazonCognitoUserPool` at runtime with `request` while passing in what + /// is configured as `.apiKey`. Expect that the interceptor is the token interceptor + func testGetEndpointInterceptors() throws { + let request = GraphQLRequest(apiName: apiName, + document: testDocument, + variables: nil, + responseType: JSONValue.self, + authMode: AWSAuthorizationType.amazonCognitoUserPools) + let task = try OperationTestBase.makeSingleValueErrorMockTask() + let mockSession = MockURLSession(onTaskForRequest: { _ in task }) + let pluginConfig = AWSAPICategoryPluginConfiguration( + endpoints: [ + apiName: try .init( + name: apiName, + baseURL: URL(string: "url")!, + region: "us-test-1", + authorizationType: .apiKey, + endpointType: .graphQL, + apiKey: "apiKey", + apiAuthProviderFactory: .init())], + apiAuthProviderFactory: .init(), + authService: MockAWSAuthService()) + let operation = AWSGraphQLOperation(request: request.toOperationRequest(operationType: .query), + session: mockSession, + mapper: OperationTaskMapper(), + pluginConfig: pluginConfig, + resultListener: { _ in }) + + // Act + let results = operation.getEndpointInterceptors() + + // Assert + guard case let .success(interceptors) = results, + let interceptor = interceptors?.preludeInterceptors.first, + (interceptor as? AuthTokenURLRequestInterceptor) != nil else { + XCTFail("Should be token interceptor for Cognito User Pool") + return + } + } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift index 26d9014091..b36ea8ccf0 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift @@ -59,7 +59,7 @@ class OperationTestBase: XCTestCase { } func setUpPluginForSingleError(for endpointType: AWSAPICategoryPluginEndpointType) throws { - let task = try makeSingleValueErrorMockTask() + let task = try Self.makeSingleValueErrorMockTask() let mockSession = MockURLSession(onTaskForRequest: { _ in task }) let sessionFactory = MockSessionFactory(returning: mockSession) try setUpPlugin(sessionFactory: sessionFactory, endpointType: endpointType) @@ -102,7 +102,7 @@ class OperationTestBase: XCTestCase { return task } - func makeSingleValueErrorMockTask() throws -> MockURLSessionTask { + static func makeSingleValueErrorMockTask() throws -> MockURLSessionTask { var mockTask: MockURLSessionTask! mockTask = MockURLSessionTask(onResume: { guard let mockSession = mockTask.mockSession, diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthorizationType.swift b/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthorizationType.swift index 0530183e3c..465cc35337 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthorizationType.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthorizationType.swift @@ -6,6 +6,7 @@ // import Foundation +import Amplify // swiftlint:disable line_length @@ -13,7 +14,7 @@ import Foundation /// GraphQL backend, or an Amazon API Gateway endpoint. /// /// - SeeAlso: [https://docs.aws.amazon.com/appsync/latest/devguide/security.html](AppSync Security) -public enum AWSAuthorizationType: String { +public enum AWSAuthorizationType: String, AuthorizationMode { /// For public APIs case none = "NONE" diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+Model.swift b/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+Model.swift index 7229c4ea1b..7338fab830 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+Model.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+Model.swift @@ -35,7 +35,8 @@ protocol ModelGraphQLRequestFactory { static func list(_ modelType: M.Type, where predicate: QueryPredicate?, includes: IncludedAssociations, - limit: Int?) -> GraphQLRequest> + limit: Int?, + authMode: AWSAuthorizationType?) -> GraphQLRequest> /// Creates a `GraphQLRequest` that represents a query that expects a single value as a result. /// The request will be created with the correct correct document based on the `ModelSchema` and @@ -50,16 +51,19 @@ protocol ModelGraphQLRequestFactory { /// - seealso: `GraphQLQuery`, `GraphQLQueryType.get` static func get(_ modelType: M.Type, byId id: String, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest static func get(_ modelType: M.Type, byIdentifier id: String, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest where M: ModelIdentifiable, M.IdentifierFormat == ModelIdentifierFormat.Default static func get(_ modelType: M.Type, byIdentifier id: ModelIdentifier, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest where M: ModelIdentifiable // MARK: Mutation @@ -76,7 +80,8 @@ protocol ModelGraphQLRequestFactory { modelSchema: ModelSchema, where predicate: QueryPredicate?, includes: IncludedAssociations, - type: GraphQLMutationType) -> GraphQLRequest + type: GraphQLMutationType, + authMode: AWSAuthorizationType?) -> GraphQLRequest /// Creates a `GraphQLRequest` that represents a create mutation /// for a given `model` instance. @@ -85,7 +90,9 @@ protocol ModelGraphQLRequestFactory { /// - model: the model instance populated with values /// - Returns: a valid `GraphQLRequest` instance /// - seealso: `GraphQLRequest.mutation(of:where:type:)` - static func create(_ model: M, includes: IncludedAssociations) -> GraphQLRequest + static func create(_ model: M, + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest /// Creates a `GraphQLRequest` that represents an update mutation /// for a given `model` instance. @@ -97,7 +104,8 @@ protocol ModelGraphQLRequestFactory { /// - seealso: `GraphQLRequest.mutation(of:where:type:)` static func update(_ model: M, where predicate: QueryPredicate?, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest /// Creates a `GraphQLRequest` that represents a delete mutation /// for a given `model` instance. @@ -109,7 +117,8 @@ protocol ModelGraphQLRequestFactory { /// - seealso: `GraphQLRequest.mutation(of:where:type:)` static func delete(_ model: M, where predicate: QueryPredicate?, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest // MARK: Subscription @@ -125,7 +134,8 @@ protocol ModelGraphQLRequestFactory { /// - seealso: `GraphQLSubscription`, `GraphQLSubscriptionType` static func subscription(of: M.Type, type: GraphQLSubscriptionType, - includes: IncludedAssociations) -> GraphQLRequest + includes: IncludedAssociations, + authMode: AWSAuthorizationType?) -> GraphQLRequest } // MARK: - Extension @@ -141,52 +151,97 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return modelType.schema } - public static func create(_ model: M, includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return create(model, modelSchema: modelSchema(for: model), includes: includes) + public static func create( + _ model: M, + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return create( + model, + modelSchema: modelSchema(for: model), + includes: includes, + authMode: authMode) } public static func update(_ model: M, where predicate: QueryPredicate? = nil, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return update(model, modelSchema: modelSchema(for: model), where: predicate, includes: includes) + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return update( + model, + modelSchema: modelSchema(for: model), + where: predicate, + includes: includes, + authMode: authMode) } public static func delete(_ model: M, where predicate: QueryPredicate? = nil, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return delete(model, modelSchema: modelSchema(for: model), where: predicate, includes: includes) + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return delete( + model, + modelSchema: modelSchema(for: model), + where: predicate, + includes: includes, + authMode: authMode) } - public static func create(_ model: M, modelSchema: ModelSchema, includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return mutation(of: model, modelSchema: modelSchema, includes: includes, type: .create) + public static func create(_ model: M, + modelSchema: ModelSchema, + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return mutation(of: model, + modelSchema: modelSchema, + includes: includes, + type: .create, + authMode: authMode) } public static func update(_ model: M, modelSchema: ModelSchema, where predicate: QueryPredicate? = nil, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return mutation(of: model, modelSchema: modelSchema, where: predicate, includes: includes, type: .update) + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return mutation(of: model, + modelSchema: modelSchema, + where: predicate, + includes: includes, + type: .update, + authMode: authMode) } public static func delete(_ model: M, modelSchema: ModelSchema, where predicate: QueryPredicate? = nil, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { - return mutation(of: model, modelSchema: modelSchema, where: predicate, includes: includes, type: .delete) + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + return mutation(of: model, + modelSchema: modelSchema, + where: predicate, + includes: includes, + type: .delete, + authMode: authMode) } public static func mutation(of model: M, where predicate: QueryPredicate? = nil, includes: IncludedAssociations = { _ in [] }, - type: GraphQLMutationType) -> GraphQLRequest { - mutation(of: model, modelSchema: model.schema, where: predicate, includes: includes, type: type) + type: GraphQLMutationType, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { + mutation(of: model, + modelSchema: model.schema, + where: predicate, + includes: includes, + type: type, + authMode: authMode) } public static func mutation(of model: M, modelSchema: ModelSchema, where predicate: QueryPredicate? = nil, includes: IncludedAssociations = { _ in [] }, - type: GraphQLMutationType) -> GraphQLRequest { + type: GraphQLMutationType, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelSchema, operationType: .mutation) documentBuilder.add(decorator: DirectiveNameDecorator(type: type)) @@ -216,12 +271,14 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return GraphQLRequest(document: document.stringValue, variables: document.variables, responseType: M.self, - decodePath: document.name) + decodePath: document.name, + authMode: authMode) } public static func get(_ modelType: M.Type, byId id: String, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelType.schema, operationType: .query) documentBuilder.add(decorator: DirectiveNameDecorator(type: .get)) @@ -237,19 +294,22 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return GraphQLRequest(document: document.stringValue, variables: document.variables, responseType: M?.self, - decodePath: document.name) + decodePath: document.name, + authMode: authMode) } public static func get(_ modelType: M.Type, byIdentifier id: String, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest where M: ModelIdentifiable, M.IdentifierFormat == ModelIdentifierFormat.Default { - return .get(modelType, byId: id, includes: includes) + return .get(modelType, byId: id, includes: includes, authMode: authMode) } public static func get(_ modelType: M.Type, byIdentifier id: ModelIdentifier, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest where M: ModelIdentifiable { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelType.schema, operationType: .query) @@ -265,13 +325,15 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return GraphQLRequest(document: document.stringValue, variables: document.variables, responseType: M?.self, - decodePath: document.name) + decodePath: document.name, + authMode: authMode) } public static func list(_ modelType: M.Type, where predicate: QueryPredicate? = nil, includes: IncludedAssociations = { _ in [] }, - limit: Int? = nil) -> GraphQLRequest> { + limit: Int? = nil, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest> { let primaryKeysOnly = (M.rootPath != nil) ? true : false var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelType.schema, operationType: .query) @@ -292,12 +354,14 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return GraphQLRequest>(document: document.stringValue, variables: document.variables, responseType: List.self, - decodePath: document.name) + decodePath: document.name, + authMode: authMode) } public static func subscription(of modelType: M.Type, type: GraphQLSubscriptionType, - includes: IncludedAssociations = { _ in [] }) -> GraphQLRequest { + includes: IncludedAssociations = { _ in [] }, + authMode: AWSAuthorizationType? = nil) -> GraphQLRequest { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelType.schema, operationType: .subscription) documentBuilder.add(decorator: DirectiveNameDecorator(type: type)) @@ -312,6 +376,7 @@ extension GraphQLRequest: ModelGraphQLRequestFactory { return GraphQLRequest(document: document.stringValue, variables: document.variables, responseType: modelType, - decodePath: document.name) + decodePath: document.name, + authMode: authMode) } } diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLRequest/GraphQLRequestModelTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLRequest/GraphQLRequestModelTests.swift index 63ff15e50a..332ec0b328 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLRequest/GraphQLRequestModelTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLRequest/GraphQLRequestModelTests.swift @@ -29,11 +29,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: ModelDecorator(model: post, mutationType: .create)) let document = documentBuilder.build() - let request = GraphQLRequest.create(post) + let request = GraphQLRequest.create(post, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) XCTAssert(request.variables != nil) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testUpdateMutationGraphQLRequest() { @@ -43,11 +44,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: ModelDecorator(model: post, mutationType: .update)) let document = documentBuilder.build() - let request = GraphQLRequest.update(post) + let request = GraphQLRequest.update(post, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) XCTAssert(request.variables != nil) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testDeleteMutationGraphQLRequest() { @@ -57,11 +59,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: ModelDecorator(model: post, mutationType: .delete)) let document = documentBuilder.build() - let request = GraphQLRequest.delete(post) + let request = GraphQLRequest.delete(post, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) XCTAssert(request.variables != nil) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testQueryByIdGraphQLRequest() { @@ -70,11 +73,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: ModelIdDecorator(id: "id")) let document = documentBuilder.build() - let request = GraphQLRequest.get(Post.self, byId: "id") + let request = GraphQLRequest.get(Post.self, byId: "id", authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post?.self) XCTAssert(request.variables != nil) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testListQueryGraphQLRequest() { @@ -87,11 +91,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: PaginationDecorator()) let document = documentBuilder.build() - let request = GraphQLRequest.list(Post.self, where: predicate) + let request = GraphQLRequest.list(Post.self, where: predicate, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == List.self) XCTAssertNotNil(request.variables) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testPaginatedListQueryGraphQLRequest() { @@ -104,11 +109,12 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: PaginationDecorator(limit: 10)) let document = documentBuilder.build() - let request = GraphQLRequest.list(Post.self, where: predicate, limit: 10) + let request = GraphQLRequest.list(Post.self, where: predicate, limit: 10, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == List.self) XCTAssertNotNil(request.variables) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testOnCreateSubscriptionGraphQLRequest() { @@ -116,11 +122,11 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: DirectiveNameDecorator(type: .onCreate)) let document = documentBuilder.build() - let request = GraphQLRequest.subscription(of: Post.self, type: .onCreate) + let request = GraphQLRequest.subscription(of: Post.self, type: .onCreate, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) - + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testOnUpdateSubscriptionGraphQLRequest() { @@ -128,10 +134,11 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: DirectiveNameDecorator(type: .onUpdate)) let document = documentBuilder.build() - let request = GraphQLRequest.subscription(of: Post.self, type: .onUpdate) + let request = GraphQLRequest.subscription(of: Post.self, type: .onUpdate, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) } func testOnDeleteSubscriptionGraphQLRequest() { @@ -139,9 +146,20 @@ class GraphQLRequestModelTest: XCTestCase { documentBuilder.add(decorator: DirectiveNameDecorator(type: .onDelete)) let document = documentBuilder.build() - let request = GraphQLRequest.subscription(of: Post.self, type: .onDelete) + let request = GraphQLRequest.subscription(of: Post.self, type: .onDelete, authMode: .amazonCognitoUserPools) XCTAssertEqual(document.stringValue, request.document) XCTAssert(request.responseType == Post.self) + assertEquals(actualAuthMode: request.authMode, expectedAuthMode: .amazonCognitoUserPools) + } + + // MARK: - Helpers + + func assertEquals(actualAuthMode: AuthorizationMode?, expectedAuthMode: AWSAuthorizationType) { + guard let authMode = actualAuthMode as? AWSAuthorizationType else { + XCTFail("Missing authorizationMode on request") + return + } + XCTAssertEqual(authMode, expectedAuthMode) } }