Skip to content

Commit

Permalink
Merge pull request #6003 from neo4j/count-aggregation
Browse files Browse the repository at this point in the history
Count aggregation
  • Loading branch information
angrykoala authored Feb 17, 2025
2 parents d672232 + 7a68d7a commit 128051d
Show file tree
Hide file tree
Showing 119 changed files with 3,737 additions and 993 deletions.
32 changes: 32 additions & 0 deletions .changeset/hot-bikes-complain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
---
"@neo4j/graphql": patch
---

Add count fields in aggregations with support for nodes and edges count:

```graphql
query {
moviesConnection {
aggregate {
count {
nodes
}
}
}
}
```

```graphql
query {
movies {
actorsConnection {
aggregate {
count {
nodes
edges
}
}
}
}
}
```
2 changes: 1 addition & 1 deletion packages/graphql/src/graphql/objects/CartesianPoint.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export const CartesianPoint = new GraphQLObjectType({
},
srid: {
type: new GraphQLNonNull(GraphQLInt),
resolve: (source, args, context, info) => numericalResolver(source, args, context, info),
resolve: numericalResolver,
},
},
});
2 changes: 1 addition & 1 deletion packages/graphql/src/graphql/objects/Point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export const Point = new GraphQLObjectType({
},
srid: {
type: new GraphQLNonNull(GraphQLInt),
resolve: (source, args, context, info) => numericalResolver(source, args, context, info),
resolve: numericalResolver,
},
},
});
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,15 @@ export class ImplementingEntityOperations<T extends InterfaceEntityAdapter | Con
return `${this.entityAdapter.name}ImplementationsSubscriptionWhere`;
}

/** @deprecated use `getAggregateFieldTypename` instead */
public getAggregationFieldTypename(): string {
return this.aggregateTypeNames.selection;
}

public getAggregateFieldTypename(): string {
return this.aggregateTypeNames.selection;
}

public get rootTypeFieldNames(): RootTypeFieldNames {
return {
connection: `${this.entityAdapter.plural}Connection`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,6 @@ export class RelationshipAdapter {
return false;
}

if (this.target instanceof UnionEntityAdapter || this.source instanceof InterfaceEntityAdapter) {
return false;
}

return this.annotations.selectable?.onAggregate !== false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ export abstract class RelationshipBaseOperations<T extends RelationshipAdapter |

protected abstract get edgePrefix(): string;

/**Note: Required for now to infer the types without ResolveTree */
/**Note: Required for now to infer the types without ResolveTree
* @deprecated use getAggregateFieldTypename
*
*/
public getAggregationFieldTypename(nestedField?: "node" | "edge"): string {
const nestedFieldStr = upperFirst(nestedField || "");
const aggregationStr = nestedField ? "Aggregate" : "Aggregation";
Expand All @@ -48,6 +51,13 @@ export abstract class RelationshipBaseOperations<T extends RelationshipAdapter |
)}${nestedFieldStr}${aggregationStr}Selection`;
}

public getAggregateFieldTypename(nestedField?: "node" | "edge"): string {
const nestedFieldStr = upperFirst(nestedField || "");
return `${this.relationship.source.name}${this.relationship.target.name}${upperFirst(
this.relationship.name
)}${nestedFieldStr}AggregateSelection`;
}

public getTargetTypePrettyName(): string {
if (this.relationship.isList) {
return `[${this.relationship.target.name}!]${!this.relationship.isNullable ? "!" : ""}`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { RelationshipAdapter } from "../../schema-model/relationship/model-adapt
import type { RelationshipDeclarationAdapter } from "../../schema-model/relationship/model-adapters/RelationshipDeclarationAdapter";
import type { Neo4jFeaturesSettings } from "../../types";
import { DEPRECATE_ID_AGGREGATION } from "../constants";
import { getCountConnectionType } from "../generation/aggregate-types";
import { shouldAddDeprecatedFields } from "../generation/utils";
import { numericalResolver } from "../resolvers/field/numerical";
import { AggregationTypesMapper } from "./aggregation-types-mapper";
Expand Down Expand Up @@ -81,6 +82,15 @@ export class FieldAggregationComposer {
);
}

this.composer.createObjectTC({
name: relationshipAdapter.operations.getAggregateFieldTypename(),
fields: {
count: getCountConnectionType(this.composer).NonNull,
...(aggregateSelectionNode ? { node: aggregateSelectionNode } : {}),
...(aggregateSelectionEdge ? { edge: aggregateSelectionEdge } : {}),
},
});

return this.composer.createObjectTC({
name: relationshipAdapter.operations.getAggregationFieldTypename(),
fields: {
Expand Down
65 changes: 51 additions & 14 deletions packages/graphql/src/schema/generation/aggregate-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ export function withAggregateSelectionType({
return aggregateSelection;
}

/** Top level count */
export function getCountType(composer: SchemaComposer): ObjectTypeComposer {
const countFieldName = "Count";
return composer.getOrCreateOTC(countFieldName, (countField) => {
countField.addFields({
nodes: {
type: new GraphQLNonNull(GraphQLInt),
resolve: numericalResolver,
},
});
});
}

/** Nested count */
export function getCountConnectionType(composer: SchemaComposer): ObjectTypeComposer {
const countFieldName = "CountConnection";
return composer.getOrCreateOTC(countFieldName, (countField) => {
countField.addFields({
nodes: {
type: new GraphQLNonNull(GraphQLInt),
resolve: numericalResolver,
},
edges: {
type: new GraphQLNonNull(GraphQLInt),
resolve: numericalResolver,
},
});
});
}

/** Create aggregate field inside connections */
function createConnectionAggregate({
entityAdapter,
Expand All @@ -88,26 +118,33 @@ function createConnectionAggregate({
composer: SchemaComposer;
features: Neo4jFeaturesSettings | undefined;
}): ObjectTypeComposer {
const aggregateNode = composer.createObjectTC({
name: entityAdapter.operations.aggregateTypeNames.node,
fields: {
count: {
type: new GraphQLNonNull(GraphQLInt),
resolve: numericalResolver,
args: {},
},
},
directives: graphqlDirectivesToCompose(propagatedDirectives),
});
aggregateNode.addFields(makeAggregableFields({ entityAdapter, aggregationTypesMapper, features }));
const aggregableFields = makeAggregableFields({ entityAdapter, aggregationTypesMapper, features });
let aggregateNode: ObjectTypeComposer | undefined;
const hasNodeAggregateFields = Object.keys(aggregableFields).length > 0;
if (hasNodeAggregateFields) {
aggregateNode = composer.createObjectTC({
name: entityAdapter.operations.aggregateTypeNames.node,
fields: {},
directives: graphqlDirectivesToCompose(propagatedDirectives),
});
aggregateNode.addFields(aggregableFields);
}

return composer.createObjectTC({
const connectionAggregate = composer.createObjectTC({
name: entityAdapter.operations.aggregateTypeNames.connection,
fields: {
node: aggregateNode.NonNull,
count: getCountType(composer).NonNull,
},
directives: graphqlDirectivesToCompose(propagatedDirectives),
});

if (aggregateNode) {
connectionAggregate.addFields({
node: aggregateNode.NonNull,
});
}

return connectionAggregate;
}

function makeAggregableFields({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export function withConnectionObjectType({

if (relationshipAdapter.isAggregable() && !isTargetUnion && !isSourceInterface) {
connectionObjectType.addFields({
aggregate: composer.getOTC(relationshipAdapter.operations.getAggregationFieldTypename()).NonNull,
aggregate: composer.getOTC(relationshipAdapter.operations.getAggregateFieldTypename()).NonNull,
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ import { AggregationField } from "./AggregationField";

export class CountField extends AggregationField {
private entity: Entity;
public edgeVar: Cypher.Variable | undefined;

constructor({ alias, entity }: { alias: string; entity: Entity }) {
private countFields: { nodes: boolean; edges: boolean };

constructor({
alias,
entity,
fields,
}: {
alias: string;
entity: Entity;
fields: { nodes: boolean; edges: boolean };
}) {
super(alias);
this.entity = entity;
this.countFields = fields;
}

public getChildren(): QueryASTNode[] {
Expand All @@ -43,6 +55,20 @@ export class CountField extends AggregationField {
}

public getAggregationProjection(target: Cypher.Variable, returnVar: Cypher.Variable): Cypher.Clause {
return new Cypher.Return([this.getAggregationExpr(target), returnVar]);
const resultMap = new Cypher.Map();

if (this.countFields.nodes) {
resultMap.set("nodes", this.getAggregationExpr(target));
}
if (this.countFields.edges) {
if (!this.edgeVar) {
throw new Error(
"Edge variable not defined in Count field. This is likely a bug with the GraphQL library."
);
}
resultMap.set("edges", this.getAggregationExpr(this.edgeVar));
}

return new Cypher.Return([resultMap, returnVar]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import Cypher from "@neo4j/cypher-builder";
import type { Entity } from "../../../../../schema-model/entity/Entity";
import type { QueryASTNode } from "../../QueryASTNode";
import { AggregationField } from "./AggregationField";

export class DeprecatedCountField extends AggregationField {
private entity: Entity;

constructor({ alias, entity }: { alias: string; entity: Entity }) {
super(alias);
this.entity = entity;
}

public getChildren(): QueryASTNode[] {
return [];
}

public getProjectionField(variable: Cypher.Variable): Record<string, Cypher.Expr> {
return { [this.alias]: variable };
}

public getAggregationExpr(variable: Cypher.Variable): Cypher.Expr {
return Cypher.count(variable);
}

public getAggregationProjection(target: Cypher.Variable, returnVar: Cypher.Variable): Cypher.Clause {
return new Cypher.Return([this.getAggregationExpr(target), returnVar]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { wrapSubqueriesInCypherCalls } from "../../utils/wrap-subquery-in-calls"
import { QueryASTContext } from "../QueryASTContext";
import type { QueryASTNode } from "../QueryASTNode";
import type { AggregationField } from "../fields/aggregation-fields/AggregationField";
import { CountField } from "../fields/aggregation-fields/CountField";
import type { Filter } from "../filters/Filter";
import type { AuthorizationFilters } from "../filters/authorization-filters/AuthorizationFilters";
import type { EntitySelection } from "../selection/EntitySelection";
Expand Down Expand Up @@ -172,12 +173,7 @@ export class AggregationOperation extends Operation {
const nodeMap = new Cypher.Map();
const fieldSubqueries = this.fields.map((f) => {
const returnVariable = new Cypher.Variable();
if (this.isInConnectionField) {
// Default fields are in node in connection translation
nodeMap.set(f.getProjectionField(returnVariable));
} else {
this.aggregationProjectionMap.set(f.getProjectionField(returnVariable));
}
this.aggregationProjectionMap.set(f.getProjectionField(returnVariable));
return this.createSubquery(f, pattern, returnVariable, context);
});

Expand Down Expand Up @@ -238,6 +234,10 @@ export class AggregationOperation extends Operation {
}
}

if (field instanceof CountField) {
field.edgeVar = nestedContext.relationship;
}

const ret = this.getFieldProjectionClause(targetVar, returnVariable, field);

return Cypher.utils.concat(matchClause, ...selectionClauses, ...nestedSubqueries, extraSelectionWith, ret);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,7 @@ export class CompositeAggregationOperation extends Operation {
private transpileAggregationOperation(context: QueryASTContext, addWith = true): OperationTranspileResult {
this.addWith = addWith;

let fieldSubqueries: Cypher.CompositeClause[];
if (this.isInConnectionField) {
fieldSubqueries = this.createSubqueries(this.fields, context, this.nodeMap);
} else {
// NOTE: this is to support deprecated aggregations
fieldSubqueries = this.createSubqueries(this.fields, context, this.aggregationProjectionMap);
}
const fieldSubqueries = this.createSubqueries(this.fields, context, this.aggregationProjectionMap);

const nodeFieldSubqueries = this.createSubqueries(this.nodeFields, context, this.nodeMap);
const edgeFieldSubqueries = this.createSubqueries(
Expand Down
Loading

0 comments on commit 128051d

Please sign in to comment.