Skip to content

Commit

Permalink
Migration of the Federation resolver to use the translation layer (#4631
Browse files Browse the repository at this point in the history
)
  • Loading branch information
MacondoExpress authored Feb 1, 2024
1 parent 428a181 commit 7d6d196
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 88 deletions.
2 changes: 1 addition & 1 deletion packages/graphql/src/classes/Neo4jGraphQL.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class Neo4jGraphQL {
this._relationships = relationships;

// TODO: Move into makeAugmentedSchema, add resolvers alongside other resolvers
const referenceResolvers = subgraph.getReferenceResolvers(this._nodes, this.schemaModel);
const referenceResolvers = subgraph.getReferenceResolvers(this.schemaModel);

const schema = subgraph.buildSchema({
typeDefs,
Expand Down
18 changes: 7 additions & 11 deletions packages/graphql/src/classes/Subgraph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import type {
import { Kind, parse, print } from "graphql";
import type { Neo4jGraphQLSchemaModel } from "../schema-model/Neo4jGraphQLSchemaModel";
import { translateResolveReference } from "../translate/translate-resolve-reference";
import type { Node } from "../types";
import { execute } from "../utils";
import getNeo4jResolveTree from "../utils/get-neo4j-resolve-tree";
import { isInArray } from "../utils/is-in-array";
Expand Down Expand Up @@ -106,7 +105,7 @@ export class Subgraph {
});
}

public getReferenceResolvers(nodes: Node[], schemaModel: Neo4jGraphQLSchemaModel): IResolvers {
public getReferenceResolvers(schemaModel: Neo4jGraphQLSchemaModel): IResolvers {
const resolverMap: IResolvers = {};

const document = mergeTypeDefs(this.typeDefs);
Expand All @@ -125,33 +124,30 @@ export class Subgraph {
}

resolverMap[def.name.value] = {
__resolveReference: this.getReferenceResolver(nodes),
__resolveReference: this.getReferenceResolver(schemaModel),
};
}
});

return resolverMap;
}

private getReferenceResolver(nodes: Node[]): ReferenceResolver {
private getReferenceResolver(schemaModel: Neo4jGraphQLSchemaModel): ReferenceResolver {
const __resolveReference = async (
reference,
context: Neo4jGraphQLComposedContext,
info: GraphQLResolveInfo
): Promise<unknown> => {
const { __typename } = reference;

const node = nodes.find((n) => n.name === __typename);

if (!node) {
throw new Error("Unable to find matching node");
const entityAdapter = schemaModel.getConcreteEntityAdapter(__typename);
if (!entityAdapter) {
throw new Error(`Unable to find matching entity with name ${__typename}`);
}

(context as Neo4jGraphQLTranslationContext).resolveTree = getNeo4jResolveTree(info);

const { cypher, params } = translateResolveReference({
context: context as Neo4jGraphQLTranslationContext,
node,
entityAdapter,
reference,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ export class OperationsFactory {
resolveTree,
context,
varName,
reference,
}: {
entity?: EntityAdapter;
resolveTree: ResolveTree;
context: Neo4jGraphQLTranslationContext;
varName?: string;
reference?: any;
}): Operation {
const operationMatch = parseTopLevelOperationField(resolveTree.name, context.schemaModel, entity);
if (!entity && operationMatch.isCustomCypher) {
Expand Down Expand Up @@ -154,6 +156,7 @@ export class OperationsFactory {
resolveTree,
context,
varName,
reference,
}) as ReadOperation;
}
op.nodeAlias = TOP_LEVEL_NODE_NAME;
Expand Down Expand Up @@ -203,7 +206,7 @@ export class OperationsFactory {
resolveTree: ResolveTree,
context: Neo4jGraphQLTranslationContext
): FulltextOperation {
let resolveTreeWhere: Record<string, any> = isObject(resolveTree.args.where) ? resolveTree.args.where : {};
let resolveTreeWhere: Record<string, any> = this.getWhereArgs(resolveTree);
let sortOptions: Record<string, any> = (resolveTree.args.options as Record<string, any>) || {};
let fieldsByTypeName = resolveTree.fieldsByTypeName;
let resolverArgs = resolveTree.args;
Expand Down Expand Up @@ -293,15 +296,17 @@ export class OperationsFactory {
resolveTree,
context,
varName,
reference,
}: {
entityOrRel: EntityAdapter | RelationshipAdapter;
resolveTree: ResolveTree;
context: Neo4jGraphQLTranslationContext;
varName?: string;
reference?: any;
}): ReadOperation | CompositeReadOperation {
const entity = entityOrRel instanceof RelationshipAdapter ? entityOrRel.target : entityOrRel;
const relationship = entityOrRel instanceof RelationshipAdapter ? entityOrRel : undefined;
const resolveTreeWhere: Record<string, any> = isObject(resolveTree.args.where) ? resolveTree.args.where : {};
const resolveTreeWhere: Record<string, any> = this.getWhereArgs(resolveTree, reference);

if (isConcreteEntity(entity)) {
checkEntityAuthentication({
Expand Down Expand Up @@ -408,7 +413,7 @@ export class OperationsFactory {
entity = entityOrRel;
}

const resolveTreeWhere = (resolveTree.args.where || {}) as Record<string, unknown>;
const resolveTreeWhere = this.getWhereArgs(resolveTree);

if (entityOrRel instanceof RelationshipAdapter) {
if (isConcreteEntity(entity)) {
Expand Down Expand Up @@ -499,7 +504,7 @@ export class OperationsFactory {

operation.setFields(fields);

const whereArgs = (resolveTree.args.where || {}) as Record<string, unknown>;
const whereArgs = this.getWhereArgs(resolveTree);
const authFilters = this.authorizationFactory.getAuthFilters({
entity,
operations: ["AGGREGATE"],
Expand Down Expand Up @@ -571,7 +576,7 @@ export class OperationsFactory {
throw new Error("Top-Level Connection are currently supported only for concrete entities");
}
const directed = Boolean(resolveTree.args.directed) ?? true;
const resolveTreeWhere: Record<string, any> = isObject(resolveTree.args.where) ? resolveTree.args.where : {};
const resolveTreeWhere: Record<string, any> = this.getWhereArgs(resolveTree);

let nodeWhere: Record<string, any>;
if (isInterfaceEntity(target)) {
Expand Down Expand Up @@ -628,7 +633,7 @@ export class OperationsFactory {
context: Neo4jGraphQLTranslationContext;
}): ConnectionReadOperation {
const directed = Boolean(resolveTree.args.directed) ?? true;
const resolveTreeWhere: Record<string, any> = isObject(resolveTree.args.where) ? resolveTree.args.where : {};
const resolveTreeWhere: Record<string, any> = this.getWhereArgs(resolveTree);
checkEntityAuthentication({
entity: target.entity,
targetOperations: ["READ"],
Expand Down Expand Up @@ -1411,4 +1416,14 @@ export class OperationsFactory {
}
}
}

private getWhereArgs(resolveTree: ResolveTree, reference?: any): Record<string, any> {
const whereArgs = isRecord(resolveTree.args.where) ? resolveTree.args.where : {};

if (resolveTree.name === "_entities" && reference) {
const { __typename, ...referenceWhere } = reference;
return { ...referenceWhere, ...whereArgs };
}
return whereArgs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,19 @@ export class QueryASTFactory {
resolveTree,
entityAdapter,
context,
reference,
}: {
resolveTree: ResolveTree;
entityAdapter?: EntityAdapter;
context: Neo4jGraphQLTranslationContext;
reference?: any;
}): QueryAST {
const operation = this.operationsFactory.createTopLevelOperation({
entity: entityAdapter,
resolveTree,
context,
varName: "this",
reference,
});
return new QueryAST(operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function parseOperationField(
): OperationFieldMatch {
const rootTypeFieldNames = entityAdapter.operations.rootTypeFieldNames;
return {
isRead: field === rootTypeFieldNames.read,
isRead: field === rootTypeFieldNames.read || field === "_entities",
isConnection: field === rootTypeFieldNames.connection,
isAggregation: field === rootTypeFieldNames.aggregate,
isCreate: field === rootTypeFieldNames.create,
Expand Down
2 changes: 0 additions & 2 deletions packages/graphql/src/translate/translate-read.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ export function translateRead(
): Cypher.CypherResult {
const { resolveTree } = context;
const operationsTreeFactory = new QueryASTFactory(context.schemaModel, context.experimental);

if (!entityAdapter) throw new Error("Entity not found");
const operationsTree = operationsTreeFactory.createQueryAST({resolveTree, entityAdapter, context});
debug(operationsTree.print());
const clause = operationsTree.build(context, varName);
Expand Down
79 changes: 14 additions & 65 deletions packages/graphql/src/translate/translate-resolve-reference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,79 +17,28 @@
* limitations under the License.
*/

import type { Node } from "../classes";
import createProjectionAndParams from "./create-projection-and-params";
import { createMatchClause } from "./translate-top-level-match";
import Cypher from "@neo4j/cypher-builder";
import { compileCypher } from "../utils/compile-cypher";
import type Cypher from "@neo4j/cypher-builder";
import type { Neo4jGraphQLTranslationContext } from "../types/neo4j-graphql-translation-context";
import Debug from "debug";
import { QueryASTFactory } from "./queryAST/factory/QueryASTFactory";
import type { EntityAdapter } from "../schema-model/entity/EntityAdapter";
import { DEBUG_TRANSLATE } from "../constants";

const debug = Debug(DEBUG_TRANSLATE);

export function translateResolveReference({
node,
entityAdapter,
context,
reference,
}: {
context: Neo4jGraphQLTranslationContext;
node: Node;
entityAdapter: EntityAdapter;
reference: any;
}): Cypher.CypherResult {
const varName = "this";
const { resolveTree } = context;

const matchNode = new Cypher.NamedNode(varName, { labels: node.getLabels(context) });

const { __typename, ...where } = reference;

const {
matchClause: topLevelMatch,
preComputedWhereFieldSubqueries,
whereClause: topLevelWhereClause,
} = createMatchClause({
matchNode,
node,
context,
operation: "READ",
where,
});

const projection = createProjectionAndParams({
node,
context,
resolveTree,
varName: matchNode,
cypherFieldAliasMap: {},
});

let projAuth: Cypher.Clause | undefined;

const predicates: Cypher.Predicate[] = [];

predicates.push(...projection.predicates);

if (predicates.length) {
projAuth = new Cypher.With("*").where(Cypher.and(...predicates));
}

const projectionSubqueries = Cypher.concat(...projection.subqueries, ...projection.subqueriesBeforeSort);

const projectionExpression = new Cypher.Raw((env) => {
return [`${varName} ${compileCypher(projection.projection, env)}`, projection.params];
});

const returnClause = new Cypher.Return([projectionExpression, varName]);

const preComputedWhereFields =
preComputedWhereFieldSubqueries && !preComputedWhereFieldSubqueries.empty
? Cypher.concat(preComputedWhereFieldSubqueries, topLevelWhereClause)
: topLevelWhereClause;

const readQuery = Cypher.concat(
topLevelMatch,
preComputedWhereFields,
projAuth,
projectionSubqueries,
returnClause
);

return readQuery.build();
const operationsTreeFactory = new QueryASTFactory(context.schemaModel, context.experimental);
const operationsTree = operationsTreeFactory.createQueryAST({ resolveTree, entityAdapter, context, reference });
debug(operationsTree.print());
const clause = operationsTree.build(context);
return clause.build();
}
3 changes: 1 addition & 2 deletions packages/graphql/tests/tck/federation/authorization.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ describe("Federation and authorization", () => {

expect(formatCypher(result.cypher)).toMatchInlineSnapshot(`
"MATCH (this:User)
WHERE this.id = $param0
WITH *
WHERE ($isAuthenticated = true AND ($jwt.sub IS NOT NULL AND this.id = $jwt.sub))
WHERE (this.id = $param0 AND ($isAuthenticated = true AND ($jwt.sub IS NOT NULL AND this.id = $jwt.sub)))
RETURN this { .id, .name, .password } AS this"
`);

Expand Down

0 comments on commit 7d6d196

Please sign in to comment.