From 80382838f7c57e1c934abd8597d40ae469cc84eb Mon Sep 17 00:00:00 2001 From: Natalia Gavrilenko Date: Tue, 5 Nov 2024 13:46:05 +0100 Subject: [PATCH] Support explicitly defined offsets in aggregate type --- .../expression/ExpressionFactory.java | 11 +- .../expression/misc/ConstructExpr.java | 4 +- .../expression/misc/ExtractExpr.java | 16 ++- .../processing/ExprTransformer.java | 2 +- .../expression/type/AggregateType.java | 43 +++++-- .../expression/type/TypeFactory.java | 106 +++++++++++++--- .../dartagnan/expression/type/TypeLayout.java | 62 ---------- .../dartagnan/expression/type/TypeOffset.java | 36 ++---- .../parsers/program/visitors/VisitorLlvm.java | 11 +- .../visitors/spirv/VisitorOpsAnnotation.java | 21 +++- .../visitors/spirv/VisitorOpsConstant.java | 13 +- .../visitors/spirv/VisitorOpsType.java | 25 +++- .../visitors/spirv/VisitorSpirvInput.java | 9 +- .../spirv/builders/DecorationsBuilder.java | 9 +- .../visitors/spirv/decorations/Offset.java | 31 +++++ .../VisitorExtensionClspvReflection.java | 90 +++++++------- .../visitors/spirv/helpers/HelperInputs.java | 9 +- .../visitors/spirv/helpers/HelperTypes.java | 20 ++- .../com/dat3m/dartagnan/program/Program.java | 9 +- .../program/memory/MemoryObject.java | 2 +- .../program/processing/Intrinsics.java | 5 +- .../expression/type/AggregateTypeTest.java | 115 ++++++++++++++++++ .../spirv/VisitorOpsConstantTest.java | 31 +++-- .../visitors/spirv/VisitorOpsMemoryTest.java | 53 ++++---- .../visitors/spirv/VisitorOpsTypeTest.java | 38 ++++-- .../VisitorExtensionClspvReflectionTest.java | 73 ++++++++--- .../spirv/mocks/MockProgramBuilder.java | 18 ++- .../dartagnan/spirv/header/AbstractTest.java | 5 + .../dartagnan/spirv/header/BadIndexTest.java | 2 +- .../spirv/basic/array-of-vector1.spv.dis | 2 + .../resources/spirv/basic/mixed-size.spv.dis | 11 ++ 31 files changed, 587 insertions(+), 295 deletions(-) delete mode 100644 dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeLayout.java create mode 100644 dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/Offset.java create mode 100644 dartagnan/src/test/java/com/dat3m/dartagnan/expression/type/AggregateTypeTest.java diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/ExpressionFactory.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/ExpressionFactory.java index 90d8481105..79cc12a5bb 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/ExpressionFactory.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/ExpressionFactory.java @@ -257,8 +257,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea // ----------------------------------------------------------------------------------------------------------------- // Aggregates - public Expression makeConstruct(List arguments) { - final AggregateType type = types.getAggregateType(arguments.stream().map(Expression::getType).toList()); + public Expression makeConstruct(Type type, List arguments) { return new ConstructExpr(type, arguments); } @@ -302,11 +301,11 @@ public Expression makeGeneralZero(Type type) { } return makeArray(arrayType.getElementType(), zeroes, true); } else if (type instanceof AggregateType structType) { - List zeroes = new ArrayList<>(structType.getDirectFields().size()); - for (Type fieldType : structType.getDirectFields()) { - zeroes.add(makeGeneralZero(fieldType)); + List zeroes = new ArrayList<>(structType.getTypeOffsets().size()); + for (TypeOffset typeOffset : structType.getTypeOffsets()) { + zeroes.add(makeGeneralZero(typeOffset.type())); } - return makeConstruct(zeroes); + return makeConstruct(structType, zeroes); } else if (type instanceof IntegerType intType) { return makeZero(intType); } else if (type instanceof BooleanType) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ConstructExpr.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ConstructExpr.java index f118bffba7..05eee31598 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ConstructExpr.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ConstructExpr.java @@ -7,6 +7,7 @@ import com.dat3m.dartagnan.expression.base.NaryExpressionBase; import com.dat3m.dartagnan.expression.type.AggregateType; import com.dat3m.dartagnan.expression.type.ArrayType; +import com.dat3m.dartagnan.expression.type.TypeOffset; import java.util.List; import java.util.stream.Collectors; @@ -20,7 +21,8 @@ public ConstructExpr(Type type, List arguments) { checkArgument(type instanceof AggregateType || type instanceof ArrayType, "Non-constructible type %s.", type); checkArgument(!(type instanceof AggregateType a) || - arguments.stream().map(Expression::getType).toList().equals(a.getDirectFields()), + arguments.stream().map(Expression::getType).toList() + .equals(a.getTypeOffsets().stream().map(TypeOffset::type).toList()), "Arguments do not match the constructor signature."); checkArgument(!(type instanceof ArrayType a) || !a.hasKnownNumElements() || diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ExtractExpr.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ExtractExpr.java index cbae78305c..663502d8f9 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ExtractExpr.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/misc/ExtractExpr.java @@ -7,8 +7,11 @@ import com.dat3m.dartagnan.expression.base.UnaryExpressionBase; import com.dat3m.dartagnan.expression.type.AggregateType; import com.dat3m.dartagnan.expression.type.ArrayType; +import com.dat3m.dartagnan.expression.type.TypeOffset; import com.google.common.base.Preconditions; +import java.util.List; + import static com.google.common.base.Preconditions.checkArgument; public final class ExtractExpr extends UnaryExpressionBase { @@ -25,13 +28,14 @@ private static Type extractType(Expression expr, int index) { Preconditions.checkArgument(exprType instanceof AggregateType || exprType instanceof ArrayType, "Cannot extract from a non-aggregate expression: (%s)[%d].", expr, index); if (exprType instanceof AggregateType aggregateType) { - return aggregateType.getDirectFields().get(index); - } else { - final ArrayType arrayType = (ArrayType) exprType; - checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()), - "Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1); - return arrayType.getElementType(); + final List typeOffsets = aggregateType.getTypeOffsets(); + checkArgument(0 <= index && index < typeOffsets.size()); + return typeOffsets.get(index).type(); } + final ArrayType arrayType = (ArrayType) exprType; + checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()), + "Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1); + return arrayType.getElementType(); } public int getFieldIndex() { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/processing/ExprTransformer.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/processing/ExprTransformer.java index 59452dfd08..385f4df364 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/processing/ExprTransformer.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/processing/ExprTransformer.java @@ -85,7 +85,7 @@ public Expression visitConstructExpression(ConstructExpr construct) { for (final Expression argument : construct.getOperands()) { arguments.add(argument.accept(this)); } - return expressions.makeConstruct(arguments); + return expressions.makeConstruct(construct.getType(), arguments); } @Override diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/AggregateType.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/AggregateType.java index ed85391e8e..f8ac95a9d5 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/AggregateType.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/AggregateType.java @@ -2,33 +2,58 @@ import com.dat3m.dartagnan.expression.Type; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; -public final class AggregateType implements Type { +import static com.dat3m.dartagnan.expression.type.TypeFactory.paddedSize; - private final List fields; +public class AggregateType implements Type { - AggregateType(List directFields) { - this.fields = List.copyOf(directFields); + private static final TypeFactory types = TypeFactory.getInstance(); + + private final List directFields; + + AggregateType(List fields) { + this(fields, computeDefaultOffsets(fields)); + } + + AggregateType(List fields, List offsets) { + this.directFields = IntStream.range(0, fields.size()).boxed().map(i -> new TypeOffset(fields.get(i), offsets.get(i))).toList(); + } + + private static List computeDefaultOffsets(List fields) { + List offsets = new ArrayList<>(); + int offset = 0; + if (!fields.isEmpty()) { + offset = types.getMemorySizeInBytes(fields.get(0)); + offsets.add(0); + } + for (int i = 1; i < fields.size(); i++) { + offset = paddedSize(offset, types.getAlignment(fields.get(i))); + offsets.add(offset); + offset += types.getMemorySizeInBytes(fields.get(i)); + } + return offsets; } - public List getDirectFields() { - return fields; + public List getTypeOffsets() { + return directFields; } @Override public int hashCode() { - return fields.hashCode(); + return directFields.hashCode(); } @Override public boolean equals(Object obj) { - return this == obj || obj instanceof AggregateType o && fields.equals(o.fields); + return this == obj || obj instanceof AggregateType o && directFields.equals(o.directFields); } @Override public String toString() { - return fields.stream().map(Type::toString).collect(Collectors.joining(", ", "{ ", " }")); + return directFields.stream().map(f -> f.offset() + ": " + f.type()).collect(Collectors.joining(", ", "{ ", " }")); } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeFactory.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeFactory.java index 148b4cb13d..f098931604 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeFactory.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeFactory.java @@ -2,10 +2,11 @@ import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.utils.Normalizer; +import com.google.common.math.IntMath; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.math.RoundingMode; +import java.util.*; +import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; @@ -74,6 +75,22 @@ public AggregateType getAggregateType(List fields) { return typeNormalizer.normalize(new AggregateType(fields)); } + public AggregateType getAggregateType(List fields, List offsets) { + checkNotNull(fields); + checkNotNull(offsets); + checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed"); + checkArgument(fields.size() == offsets.size(), "Offsets number does not match the fields number"); + checkArgument(offsets.stream().noneMatch(o -> o < 0), "Offset cannot be negative"); + checkArgument(offsets.isEmpty() || offsets.get(0) == 0, "The first offset must be zero"); + checkArgument(IntStream.range(1, offsets.size()).boxed().allMatch( + i -> offsets.get(i) >= offsets.get(i - 1) + Integer.max(0, getMemorySizeInBytes(fields.get(i - 1), false))), + "Offset is too small"); + checkArgument(IntStream.range(0, offsets.size() - 1).boxed().allMatch( + i -> getMemorySizeInBytes(fields.get(i)) > 0), + "Non-last element with unknown size"); + return typeNormalizer.normalize(new AggregateType(fields, offsets)); + } + public ArrayType getArrayType(Type element) { return typeNormalizer.normalize(new ArrayType(element, -1)); } @@ -92,7 +109,63 @@ public IntegerType getByteType() { } public int getMemorySizeInBytes(Type type) { - return TypeLayout.of(type).totalSizeInBytes(); + return getMemorySizeInBytes(type, true); + } + + public int getMemorySizeInBytes(Type type, boolean padded) { + if (type instanceof BooleanType) { + return 1; + } + if (type instanceof IntegerType integerType) { + return IntMath.divide(integerType.getBitWidth(), 8, RoundingMode.CEILING); + } + if (type instanceof FloatType floatType) { + return IntMath.divide(floatType.getBitWidth(), 8, RoundingMode.CEILING); + } + if (type instanceof ArrayType arrayType) { + if (arrayType.hasKnownNumElements()) { + Type elType = arrayType.getElementType(); + return getMemorySizeInBytes(elType) * arrayType.getNumElements(); + } + return -1; + } + if (type instanceof AggregateType aType) { + List typeOffsets = aType.getTypeOffsets(); + if (typeOffsets.isEmpty()) { + return 0; + } + if (aType.getTypeOffsets().stream().anyMatch(o -> getMemorySizeInBytes(o.type()) < 0)) { + return -1; + } + TypeOffset lastTypeOffset = typeOffsets.get(typeOffsets.size() - 1); + int baseSize = lastTypeOffset.offset() + getMemorySizeInBytes(lastTypeOffset.type()); + if (padded) { + return paddedSize(baseSize, getAlignment(type)); + } + return baseSize; + } + throw new UnsupportedOperationException("Cannot compute memory layout of type " + type); + } + + public int getAlignment(Type type) { + if (type instanceof BooleanType || type instanceof IntegerType || type instanceof FloatType) { + return getMemorySizeInBytes(type); + } + if (type instanceof ArrayType arrayType) { + return getMemorySizeInBytes(arrayType.getElementType()); + } + if (type instanceof AggregateType aType) { + return aType.getTypeOffsets().stream().map(o -> getAlignment(o.type())).max(Integer::compare).orElseThrow(); + } + throw new UnsupportedOperationException("Cannot compute memory layout of type " + type); + } + + public static int paddedSize(int size, int alignment) { + int mod = size % alignment; + if (mod > 0) { + return size + alignment - mod; + } + return size; } public int getMemorySizeInBits(Type type) { @@ -119,16 +192,13 @@ public Map decomposeIntoPrimitives(Type type) { } } } else if (type instanceof AggregateType aggregateType) { - final List fields = aggregateType.getDirectFields(); - for (int i = 0; i < fields.size(); i++) { - final int offset = getOffsetInBytes(aggregateType, i); - final Map innerDecomposition = decomposeIntoPrimitives(fields.get(i)); + for (TypeOffset typeOffset : aggregateType.getTypeOffsets()) { + final Map innerDecomposition = decomposeIntoPrimitives(typeOffset.type()); if (innerDecomposition == null) { return null; } - for (Map.Entry entry : innerDecomposition.entrySet()) { - decomposition.put(entry.getKey() + offset, entry.getValue()); + decomposition.put(typeOffset.offset() + entry.getKey(), entry.getValue()); } } } else { @@ -147,12 +217,7 @@ public static boolean isStaticType(Type type) { return aType.hasKnownNumElements() && isStaticType(aType.getElementType()); } if (type instanceof AggregateType aType) { - for (Type elType : aType.getDirectFields()) { - if (!isStaticType(elType)) { - return false; - } - } - return true; + return aType.getTypeOffsets().stream().allMatch(o -> isStaticType(o.type())); } throw new UnsupportedOperationException("Cannot compute if type '" + type + "' is static"); } @@ -162,12 +227,15 @@ public static boolean isStaticTypeOf(Type staticType, Type runtimeType) { return true; } if (staticType instanceof AggregateType aStaticType && runtimeType instanceof AggregateType aRuntimeType) { - int size = aStaticType.getDirectFields().size(); - if (size != aRuntimeType.getDirectFields().size()) { + int size = aStaticType.getTypeOffsets().size(); + if (size != aRuntimeType.getTypeOffsets().size()) { return false; } for (int i = 0; i < size; i++) { - if (!isStaticTypeOf(aStaticType.getDirectFields().get(i), aRuntimeType.getDirectFields().get(i))) { + TypeOffset staticTypeOffset = aStaticType.getTypeOffsets().get(i); + TypeOffset runtimeTypeOffset = aRuntimeType.getTypeOffsets().get(i); + if (staticTypeOffset.offset() != runtimeTypeOffset.offset() + || !isStaticTypeOf(staticTypeOffset.type(), runtimeTypeOffset.type())) { return false; } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeLayout.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeLayout.java deleted file mode 100644 index ae71f8212e..0000000000 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeLayout.java +++ /dev/null @@ -1,62 +0,0 @@ -package com.dat3m.dartagnan.expression.type; - -import com.dat3m.dartagnan.expression.Type; -import com.google.common.math.IntMath; - -import java.math.RoundingMode; - -public record TypeLayout(int unpaddedSize, int alignment) { - - public int totalSizeInBytes() { return paddedSize(unpaddedSize, alignment); } - - @Override - public String toString() { - return String.format("[totalSize = %s bytes, unpaddedSize = %s bytes, alignment = %s bytes]", - totalSizeInBytes(), unpaddedSize(), alignment()); - } - - public static TypeLayout of(Type type) { - final int unpaddedSize; - final int alignment; - - // For primitives, we assume that size and alignment requirement coincide - if (type instanceof BooleanType) { - unpaddedSize = 1; - alignment = unpaddedSize; - } else if (type instanceof IntegerType integerType) { - unpaddedSize = IntMath.divide(integerType.getBitWidth(), 8, RoundingMode.CEILING); - alignment = unpaddedSize; - } else if (type instanceof FloatType floatType) { - unpaddedSize = IntMath.divide(floatType.getBitWidth(), 8, RoundingMode.CEILING); - alignment = unpaddedSize; - } else if (type instanceof ArrayType arrayType) { - final TypeLayout elemTypeLayout = of(arrayType.getElementType()); - unpaddedSize = elemTypeLayout.totalSizeInBytes() * arrayType.getNumElements(); - alignment = elemTypeLayout.alignment(); - } else if (type instanceof AggregateType aggregateType) { - return of(aggregateType.getDirectFields()); - } else { - throw new UnsupportedOperationException("Cannot compute memory layout of type " + type); - } - - return new TypeLayout(unpaddedSize, alignment); - } - - public static TypeLayout of(Iterable aggregate) { - int aggregateSize = 0; - int maxAlignment = 1; - for (Type fieldType : aggregate) { - final TypeLayout layout = of(fieldType); - aggregateSize = paddedSize(aggregateSize, layout.alignment()) + layout.totalSizeInBytes(); - maxAlignment = Math.max(maxAlignment, layout.alignment()); - } - return new TypeLayout(aggregateSize, maxAlignment); - } - - public static int paddedSize(int size, int alignment) { - final int mod = size % alignment; - final int padding = mod == 0 ? 0 : (alignment - mod); - return size + padding; - } - -} diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeOffset.java b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeOffset.java index 9ad4dc7878..a16de48dde 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeOffset.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/expression/type/TypeOffset.java @@ -1,43 +1,23 @@ package com.dat3m.dartagnan.expression.type; import com.dat3m.dartagnan.expression.Type; -import com.google.common.base.Preconditions; - -import java.util.List; - -import static com.dat3m.dartagnan.expression.type.TypeLayout.paddedSize; public record TypeOffset(Type type, int offset) { + private static final TypeFactory types = TypeFactory.getInstance(); + public static TypeOffset of(Type type, int index) { if (index == 0) { return new TypeOffset(type, 0); } - if (type instanceof ArrayType arrayType) { - final Type elemType = arrayType.getElementType(); - return new TypeOffset(elemType, TypeLayout.of(elemType).totalSizeInBytes() * index); - } else if (type instanceof AggregateType aggregateType) { - final List fields = aggregateType.getDirectFields(); - Preconditions.checkArgument(index < fields.size()); - final TypeLayout prefixLayout = TypeLayout.of(fields.subList(0, index)); - final TypeLayout fieldLayout = TypeLayout.of(fields.get(index)); - final int offset = paddedSize(prefixLayout.unpaddedSize(), fieldLayout.alignment()); - return new TypeOffset(fields.get(index), offset); - } else { - final String error = String.format("Cannot compute offset of index %d into type %s.", index, type); - throw new UnsupportedOperationException(error); + Type elType = arrayType.getElementType(); + return new TypeOffset(elType, types.getMemorySizeInBytes(elType) * index); } - } - - public static TypeOffset of(Type type, Iterable indices) { - int totalOffset = 0; - for (int i : indices) { - final TypeOffset inner = of(type, i); - type = inner.type(); - totalOffset += inner.offset(); + if (type instanceof AggregateType aggregateType) { + return aggregateType.getTypeOffsets().get(index); } - - return new TypeOffset(type, totalOffset); + String error = String.format("Cannot compute offset of index %d into type %s.", index, type); + throw new UnsupportedOperationException(error); } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLlvm.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLlvm.java index e9be54d72f..2043c3c2a9 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLlvm.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLlvm.java @@ -738,7 +738,8 @@ public Expression visitCmpXchgInst(CmpXchgInstContext ctx) { getOrNewCurrentRegister(types.getAggregateType(List.of(comparator.getType(), getIntegerType(1)))); if (register != null) { final Expression cast = expressions.makeIntegerCast(asExpected, getIntegerType(1), false); - final Expression result = expressions.makeConstruct(List.of(value, cast)); + final Type type = types.getAggregateType(List.of(value.getType(), cast.getType())); + final Expression result = expressions.makeConstruct(type, List.of(value, cast)); block.events.add(newLocal(register, result)); } return register; @@ -886,11 +887,9 @@ public Expression visitPoisonConst(PoisonConstContext ctx) { @Override public Expression visitStructConst(StructConstContext ctx) { - List structMembers = new ArrayList<>(); - for (TypeConstContext typeCtx : ctx.typeConst()) { - structMembers.add(visitTypeConst(typeCtx)); - } - return expressions.makeConstruct(structMembers); + List structMembers = ctx.typeConst().stream().map(this::visitTypeConst).toList(); + List structTypes = structMembers.stream().map(Expression::getType).toList(); + return expressions.makeConstruct(types.getAggregateType(structTypes), structMembers); } @Override diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsAnnotation.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsAnnotation.java index d1de71fbce..2502cebb93 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsAnnotation.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsAnnotation.java @@ -15,10 +15,12 @@ public class VisitorOpsAnnotation extends SpirvBaseVisitor { private final Decoration builtIn; private final Decoration specId; + private final Decoration offset; public VisitorOpsAnnotation(ProgramBuilder builder) { this.builtIn = builder.getDecorationsBuilder().getDecoration(BUILT_IN); this.specId = builder.getDecorationsBuilder().getDecoration(SPEC_ID); + this.offset = builder.getDecorationsBuilder().getDecoration(OFFSET); } @Override @@ -34,10 +36,25 @@ public Void visitOpDecorate(SpirvParser.OpDecorateContext ctx) { String value = ctx.decoration().specializationConstantID().getText(); specId.addDecoration(id, value); } - case ARRAY_STRIDE, BINDING, BLOCK, BUFFER_BLOCK, COHERENT, DESCRIPTOR_SET, OFFSET, NO_CONTRACTION, NO_PERSPECTIVE, NON_WRITABLE -> { + case ARRAY_STRIDE, BINDING, BLOCK, BUFFER_BLOCK, COHERENT, DESCRIPTOR_SET, NO_CONTRACTION, NO_PERSPECTIVE, NON_WRITABLE -> { // TODO: Implementation } - default -> throw new ParsingException("Unsupported decoration type '%s'", type); + default -> throw new ParsingException("Unsupported decoration '%s'", type); + } + return null; + } + + @Override + public Void visitOpMemberDecorate(SpirvParser.OpMemberDecorateContext ctx) { + String id = ctx.structureType().getText(); + String index = ctx.member().getText(); + DecorationType type = fromString(ctx.decoration().getChild(0).getText()); + switch (type) { + case OFFSET -> { + String value = ctx.decoration().byteOffset().getText(); + offset.addDecoration(id, index, value); + } + default -> throw new ParsingException("Unsupported member decoration '%s'", type); } return null; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstant.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstant.java index 4de218bf58..d2c2de689d 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstant.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstant.java @@ -5,10 +5,7 @@ import com.dat3m.dartagnan.expression.ExpressionFactory; import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.integers.IntLiteral; -import com.dat3m.dartagnan.expression.type.AggregateType; -import com.dat3m.dartagnan.expression.type.ArrayType; -import com.dat3m.dartagnan.expression.type.BooleanType; -import com.dat3m.dartagnan.expression.type.IntegerType; +import com.dat3m.dartagnan.expression.type.*; import com.dat3m.dartagnan.parsers.SpirvBaseVisitor; import com.dat3m.dartagnan.parsers.SpirvParser; import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn; @@ -181,7 +178,7 @@ private Expression makeConstantComposite(String id, Type type, List elem } private Expression makeConstantStruct(String id, AggregateType type, List elementIds) { - List elementTypes = type.getDirectFields(); + List elementTypes = type.getTypeOffsets(); if (elementTypes.size() != elementIds.size()) { throw new ParsingException("Mismatching number of elements in the composite constant '%s', " + "expected %d elements but received %d elements", id, elementTypes.size(), elementIds.size()); @@ -189,14 +186,14 @@ private Expression makeConstantStruct(String id, AggregateType type, List elements = new ArrayList<>(); for (int i = 0; i < elementTypes.size(); i++) { Expression expression = builder.getExpression(elementIds.get(i)); - if (!expression.getType().equals(elementTypes.get(i))) { + if (!expression.getType().equals(elementTypes.get(i).type())) { throw new ParsingException("Mismatching type of a composite constant '%s' element '%s', " + "expected '%s' but received '%s'", id, elementIds.get(i), - elementTypes.get(i), expression.getType()); + elementTypes.get(i).type(), expression.getType()); } elements.add(expression); } - return expressions.makeConstruct(elements); + return expressions.makeConstruct(type, elements); } private Expression makeConstantArray(String id, ArrayType type, List elementIds) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsType.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsType.java index ad4719ccfc..45047a189a 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsType.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsType.java @@ -7,20 +7,24 @@ import com.dat3m.dartagnan.expression.type.TypeFactory; import com.dat3m.dartagnan.parsers.SpirvBaseVisitor; import com.dat3m.dartagnan.parsers.SpirvParser; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Offset; import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags; import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.stream.IntStream; public class VisitorOpsType extends SpirvBaseVisitor { private static final TypeFactory types = TypeFactory.getInstance(); private final ProgramBuilder builder; + private final Offset offset; public VisitorOpsType(ProgramBuilder builder) { this.builder = builder; + this.offset = (Offset) builder.getDecorationsBuilder().getDecoration(DecorationType.OFFSET); } @Override @@ -92,8 +96,21 @@ public Type visitOpTypeStruct(SpirvParser.OpTypeStructContext ctx) { String id = ctx.idResult().getText(); List memberTypes = ctx.memberType().stream() .map(memberCtx -> builder.getType(memberCtx.getText())).toList(); - Type type = types.getAggregateType(memberTypes); - return builder.addType(id, type); + Map offsets = offset.getValue(id); + if (offsets != null) { + if (memberTypes.size() == offsets.size()) { + List memberOffsets = IntStream.range(0, offsets.size()).boxed().map(i -> { + if (!offsets.containsKey(i)) { + throw new ParsingException("Missing member offset decoration for struct '%s' index '%s'", id, i); + } + return offsets.get(i); + }).toList(); + Type type = types.getAggregateType(memberTypes, memberOffsets); + return builder.addType(id, type); + } + throw new ParsingException("Illegal member offset decorations for struct '%s'", id); + } + throw new ParsingException("Missing member offset decorations for struct '%s'", id); } @Override diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorSpirvInput.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorSpirvInput.java index 765c728dcf..a6aa6ef271 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorSpirvInput.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorSpirvInput.java @@ -3,12 +3,15 @@ import com.dat3m.dartagnan.exception.ParsingException; import com.dat3m.dartagnan.expression.Expression; import com.dat3m.dartagnan.expression.ExpressionFactory; +import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.type.IntegerType; import com.dat3m.dartagnan.expression.type.TypeFactory; import com.dat3m.dartagnan.parsers.SpirvBaseVisitor; import com.dat3m.dartagnan.parsers.SpirvParser; import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder; +import java.util.List; + public class VisitorSpirvInput extends SpirvBaseVisitor { private static final TypeFactory types = TypeFactory.getInstance(); private static final ExpressionFactory expressions = ExpressionFactory.getInstance(); @@ -49,8 +52,8 @@ public Expression visitInitBaseValue(SpirvParser.InitBaseValueContext ctx) { @Override public Expression visitInitCollectionValue(SpirvParser.InitCollectionValueContext ctx) { - return expressions.makeConstruct(ctx.initValues().initValue().stream() - .map(this::visitInitValue) - .toList()); + List structMembers = ctx.initValues().initValue().stream().map(this::visitInitValue).toList(); + List structTypes = structMembers.stream().map(Expression::getType).toList(); + return expressions.makeConstruct(types.getAggregateType(structTypes), structMembers); } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/DecorationsBuilder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/DecorationsBuilder.java index 9114261fb7..3767758e35 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/DecorationsBuilder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/DecorationsBuilder.java @@ -1,16 +1,12 @@ package com.dat3m.dartagnan.parsers.program.visitors.spirv.builders; import com.dat3m.dartagnan.exception.ParsingException; -import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn; -import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration; -import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType; -import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.SpecId; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.*; import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid; import java.util.EnumMap; -import static com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType.BUILT_IN; -import static com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType.SPEC_ID; +import static com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType.*; public class DecorationsBuilder { @@ -19,6 +15,7 @@ public class DecorationsBuilder { public DecorationsBuilder(ThreadGrid grid) { mapping.put(BUILT_IN, new BuiltIn(grid)); mapping.put(SPEC_ID, new SpecId()); + mapping.put(OFFSET, new Offset()); } public Decoration getDecoration(DecorationType type) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/Offset.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/Offset.java new file mode 100644 index 0000000000..91ff687a66 --- /dev/null +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/Offset.java @@ -0,0 +1,31 @@ +package com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations; + +import com.dat3m.dartagnan.exception.ParsingException; + +import java.util.HashMap; +import java.util.Map; + +public class Offset implements Decoration { + + private final Map> mapping = new HashMap<>(); + + @Override + public void addDecoration(String id, String... params) { + if (params.length != 2) { + throw new ParsingException("Illegal decoration '%s' for '%s'", + getClass().getSimpleName(), id); + } + int index = Integer.parseInt(params[0]); + int offset = Integer.parseInt(params[1]); + Map typeOffsets = mapping.computeIfAbsent(id, x -> new HashMap<>()); + if (typeOffsets.containsKey(index)) { + throw new ParsingException("Duplicated '%s' decoration for '%s' index '%s'", + getClass().getSimpleName(), id, index); + } + typeOffsets.put(index, offset); + } + + public Map getValue(String id) { + return mapping.get(id); + } +} diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extenstions/VisitorExtensionClspvReflection.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extenstions/VisitorExtensionClspvReflection.java index f08a39da6d..7d6426d5ec 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extenstions/VisitorExtensionClspvReflection.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extenstions/VisitorExtensionClspvReflection.java @@ -5,15 +5,12 @@ import com.dat3m.dartagnan.expression.ExpressionFactory; import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.integers.IntLiteral; -import com.dat3m.dartagnan.expression.type.AggregateType; -import com.dat3m.dartagnan.expression.type.ArrayType; -import com.dat3m.dartagnan.expression.type.IntegerType; -import com.dat3m.dartagnan.expression.type.TypeFactory; +import com.dat3m.dartagnan.expression.type.*; import com.dat3m.dartagnan.parsers.SpirvParser; import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder; import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid; -import com.dat3m.dartagnan.program.memory.ScopedPointerVariable; import com.dat3m.dartagnan.program.event.Tag; +import com.dat3m.dartagnan.program.memory.ScopedPointerVariable; import java.util.List; import java.util.Set; @@ -26,8 +23,6 @@ public class VisitorExtensionClspvReflection extends VisitorExtension { private final ProgramBuilder builder; private ScopedPointerVariable pushConstant; private AggregateType pushConstantType; - private int pushConstantIndex = 0; - private int pushConstantOffset = 0; public VisitorExtensionClspvReflection(ProgramBuilder builder) { this.builder = builder; @@ -65,78 +60,61 @@ public Void visitSpecConstantWorkgroupSize(SpirvParser.SpecConstantWorkgroupSize @Override public Void visitPushConstantGlobalOffset(SpirvParser.PushConstantGlobalOffsetContext ctx) { - return setPushConstantValue("PushConstantGlobalOffset", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantGlobalOffset", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitPushConstantGlobalSize(SpirvParser.PushConstantGlobalSizeContext ctx) { - return setPushConstantValue("PushConstantGlobalSize", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantGlobalSize", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitPushConstantEnqueuedLocalSize(SpirvParser.PushConstantEnqueuedLocalSizeContext ctx) { - return setPushConstantValue("PushConstantEnqueuedLocalSize", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantEnqueuedLocalSize", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitPushConstantNumWorkgroups(SpirvParser.PushConstantNumWorkgroupsContext ctx) { - return setPushConstantValue("PushConstantNumWorkgroups", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantNumWorkgroups", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitPushConstantRegionOffset(SpirvParser.PushConstantRegionOffsetContext ctx) { - return setPushConstantValue("PushConstantRegionOffset", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantRegionOffset", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitPushConstantRegionGroupOffset(SpirvParser.PushConstantRegionGroupOffsetContext ctx) { - return setPushConstantValue("PushConstantRegionGroupOffset", ctx.sizeIdRef().getText()); + return setPushConstantValue("PushConstantRegionGroupOffset", ctx.offsetIdRef().getText(), ctx.sizeIdRef().getText()); } @Override public Void visitArgumentPodPushConstant(SpirvParser.ArgumentPodPushConstantContext ctx) { initPushConstant(); - if (pushConstantIndex >= pushConstantType.getDirectFields().size()) { - throw new ParsingException("Out of bounds definition 'ArgumentPodPushConstant' in PushConstant '%s'", - pushConstant.getId()); - } - Type type = pushConstantType.getDirectFields().get(pushConstantIndex); - int typeSize = types.getMemorySizeInBytes(type); - if (typeSize != getExpressionAsConstInteger(ctx.sizeIdRef().getText())) { - throw new ParsingException("Unexpected offset in PushConstant '%s' element '%s'", - pushConstant.getId(), pushConstantIndex); - } - pushConstantOffset += typeSize; - pushConstantIndex++; + int argOffset = getExpressionAsConstInteger(ctx.offsetIdRef().getText()); + int argSize = getExpressionAsConstInteger(ctx.sizeIdRef().getText()); + getTypeOffset("ArgumentPodPushConstant", pushConstantType, argOffset, argSize); return null; } - private Void setPushConstantValue(String decorationId, String sizeId) { + private Void setPushConstantValue(String argument, String offsetId, String sizeId) { initPushConstant(); - if (pushConstantIndex >= pushConstantType.getDirectFields().size()) { - throw new ParsingException("Out of bounds definition '%s' in PushConstant '%s'", - decorationId, pushConstant.getId()); - } - Type type = pushConstantType.getDirectFields().get(pushConstantIndex); - int typeSize = types.getMemorySizeInBytes(type); - int expectedSize = getExpressionAsConstInteger(sizeId); - if (type instanceof ArrayType aType && aType.getNumElements() == 3 && typeSize == expectedSize) { + int argOffset = getExpressionAsConstInteger(offsetId); + int argSize = getExpressionAsConstInteger(sizeId); + TypeOffset typeOffset = getTypeOffset(argument, pushConstantType, argOffset, argSize); + if (typeOffset.type() instanceof ArrayType aType && aType.getNumElements() == 3) { Type elType = aType.getElementType(); if (elType instanceof IntegerType iType) { - List values = computePushConstantValue(decorationId); - int localOffset = 0; - for (int value : values) { + int offset = typeOffset.offset(); + for (int value : computePushConstantValue(argument)) { Expression elExpr = expressions.makeValue(value, iType); - pushConstant.setInitialValue(pushConstantOffset + localOffset, elExpr); - localOffset += types.getMemorySizeInBytes(elExpr.getType()); + pushConstant.setInitialValue(offset, elExpr); + offset += types.getMemorySizeInBytes(elExpr.getType()); } - pushConstantOffset += localOffset; - pushConstantIndex++; return null; } } - throw new ParsingException("Unexpected element type in '%s' at index %s", - pushConstant.getId(), pushConstantIndex); + throw new ParsingException("Argument %s doesn't match the PushConstant type", argument); } private List computePushConstantValue(String command) { @@ -147,8 +125,7 @@ private List computePushConstantValue(String command) { case "PushConstantNumWorkgroups" -> List.of(grid.qfSize() / grid.wgSize(), 1, 1); case "PushConstantGlobalOffset", "PushConstantRegionOffset", - "PushConstantRegionGroupOffset" - -> List.of(0, 0, 0); + "PushConstantRegionGroupOffset" -> List.of(0, 0, 0); default -> throw new ParsingException("Unsupported PushConstant command '%s'", command); }; } @@ -181,6 +158,29 @@ private int getExpressionAsConstInteger(String id) { throw new ParsingException("Expression '%s' is not an integer constant", id); } + private TypeOffset getTypeOffset(String argument, AggregateType type, int argOffset, int argSize) { + TypeOffset lastOffset = null; + for (TypeOffset typeOffset : type.getTypeOffsets()) { + if (argOffset <= typeOffset.offset()) { + if (argOffset == typeOffset.offset()) { + lastOffset = typeOffset; + } + break; + } + lastOffset = typeOffset; + } + if (lastOffset != null) { + if (argOffset == lastOffset.offset() && argSize == types.getMemorySizeInBytes(lastOffset.type())) { + return new TypeOffset(lastOffset.type(), lastOffset.offset()); + } + if (lastOffset.type() instanceof AggregateType aType) { + TypeOffset subTypeOffset = getTypeOffset(argument, aType, argOffset - lastOffset.offset(), argSize); + return new TypeOffset(subTypeOffset.type(), lastOffset.offset() + subTypeOffset.offset()); + } + } + throw new ParsingException("Argument %s doesn't match the PushConstant type", argument); + } + @Override public Set getSupportedInstructions() { return Set.of( diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperInputs.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperInputs.java index dc01fc65e0..92a874ae44 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperInputs.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperInputs.java @@ -48,16 +48,19 @@ private static Expression castArray(String id, ArrayType type, Expression value) private static Expression castAggregate(String id, AggregateType type, Expression value) { if (value instanceof ConstructExpr aValue) { - int expectedSize = type.getDirectFields().size(); + int expectedSize = type.getTypeOffsets().size(); int actualSize = aValue.getOperands().size(); if (expectedSize != actualSize) { throw new ParsingException(errorMismatchingElementCount(id, expectedSize, actualSize)); } List elements = new ArrayList<>(); for (int i = 0; i < actualSize; i++) { - elements.add(castInput(id, type.getDirectFields().get(i), aValue.getOperands().get(i))); + elements.add(castInput(id, type.getTypeOffsets().get(i).type(), aValue.getOperands().get(i))); } - return expressions.makeConstruct(elements); + List fields = elements.stream().map(Expression::getType).toList(); + List offsets = type.getTypeOffsets().stream().map(TypeOffset::offset).toList(); + AggregateType aType = types.getAggregateType(fields, offsets); + return expressions.makeConstruct(aType, elements); } throw new ParsingException(errorMismatchingType(id, type, value.getType())); } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperTypes.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperTypes.java index 1d278b69df..68f8f70e11 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperTypes.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperTypes.java @@ -77,8 +77,8 @@ private static Type getArrayMemberType(String id, ArrayType type, List private static Type getStructMemberType(String id, AggregateType type, List indexes) { int index = indexes.get(0); if (index >= 0) { - if (index < type.getDirectFields().size()) { - return getMemberType(id, type.getDirectFields().get(index), indexes.subList(1, indexes.size())); + if (index < type.getTypeOffsets().size()) { + return getMemberType(id, type.getTypeOffsets().get(index).type(), indexes.subList(1, indexes.size())); } throw new ParsingException(indexOutOfBoundsError(id)); } @@ -101,9 +101,9 @@ private static int getArrayMemberOffset(String id, int offset, ArrayType type, L private static int getStructMemberOffset(String id, int offset, AggregateType type, List indexes) { int index = indexes.get(0); if (index >= 0) { - if (index < type.getDirectFields().size()) { - offset += types.getOffsetInBytes(type, index); - Type elType = type.getDirectFields().get(index); + if (index < type.getTypeOffsets().size()) { + offset += type.getTypeOffsets().get(index).offset(); + Type elType = type.getTypeOffsets().get(index).type(); return getMemberOffset(id, offset, elType, indexes.subList(1, indexes.size())); } throw new ParsingException(indexOutOfBoundsError(id)); @@ -125,14 +125,12 @@ private static Expression getStructMemberAddress(String id, Expression base, Agg Expression indexExpr = indexes.get(0); if (indexExpr instanceof IntLiteral intLiteral) { int index = intLiteral.getValueAsInt(); - if (index < type.getDirectFields().size()) { - int offset = 0; - for (int i = 0; i < index; i++) { - offset += types.getMemorySizeInBytes(type.getDirectFields().get(i)); - } + if (index < type.getTypeOffsets().size()) { + Type subType = type.getTypeOffsets().get(index).type(); + int offset = type.getTypeOffsets().get(index).offset(); IntLiteral offsetExpr = expressions.makeValue(offset, archType); Expression expression = expressions.makeBinary(base, ADD, offsetExpr); - return getMemberAddress(id, expression, type.getDirectFields().get(index), indexes.subList(1, indexes.size())); + return getMemberAddress(id, expression, subType, indexes.subList(1, indexes.size())); } throw new ParsingException(indexOutOfBoundsError(id)); } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java index 614162b97d..ce42dc7ed5 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java @@ -6,6 +6,7 @@ import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.type.AggregateType; import com.dat3m.dartagnan.expression.type.ArrayType; +import com.dat3m.dartagnan.expression.type.TypeOffset; import com.dat3m.dartagnan.program.event.Event; import com.dat3m.dartagnan.program.memory.Memory; import com.dat3m.dartagnan.program.misc.NonDetValue; @@ -145,11 +146,11 @@ public Expression newConstant(Type type) { return expressions.makeArray(arrayType.getElementType(), entries, true); } if (type instanceof AggregateType aggregateType) { - final List elements = new ArrayList<>(aggregateType.getDirectFields().size()); - for (Type fieldType : aggregateType.getDirectFields()) { - elements.add(newConstant(fieldType)); + final List elements = new ArrayList<>(aggregateType.getTypeOffsets().size()); + for (TypeOffset typeOffset : aggregateType.getTypeOffsets()) { + elements.add(newConstant(typeOffset.type())); } - return expressions.makeConstruct(elements); + return expressions.makeConstruct(type, elements); } var expression = new NonDetValue(type, nextConstantId++); constants.add(expression); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/memory/MemoryObject.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/memory/MemoryObject.java index c35588b0ef..6b606ba576 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/memory/MemoryObject.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/memory/MemoryObject.java @@ -119,7 +119,7 @@ public void setInitialValue(int offset, Expression value) { final ConstructExpr constStruct = (ConstructExpr) value; final List structElements = constStruct.getOperands(); for (int i = 0; i < structElements.size(); i++) { - int innerOffset = types.getOffsetInBytes(aggregateType, i); + int innerOffset = aggregateType.getTypeOffsets().get(i).offset(); setInitialValue(offset + innerOffset, structElements.get(i)); } } else if (value.getType() instanceof IntegerType diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/Intrinsics.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/Intrinsics.java index bc6576c3d1..1ef4613035 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/Intrinsics.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/Intrinsics.java @@ -1181,14 +1181,13 @@ private List inlineLLVMSOpWithOverflow(ValueFunctionCall call, IntBinaryO final Expression yExt = expressions.makeCast(y, types.getIntegerType(width + 1), true); final Expression resultExt = expressions.makeCast(result, types.getIntegerType(width + 1), true); final Expression bvCheck = expressions.makeEQ(expressions.makeIntBinary(xExt, op, yExt), resultExt); - final Expression flag = expressions.makeCast( expressions.makeNot(expressions.makeAnd(bvCheck, rangeCheck)), types.getIntegerType(1) ); - + final Type type = types.getAggregateType(List.of(result.getType(), flag.getType())); return List.of( - EventFactory.newLocal(resultReg, expressions.makeConstruct(List.of(result, flag))) + EventFactory.newLocal(resultReg, expressions.makeConstruct(type, List.of(result, flag))) ); } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/expression/type/AggregateTypeTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/expression/type/AggregateTypeTest.java new file mode 100644 index 0000000000..9f8abfd9ba --- /dev/null +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/expression/type/AggregateTypeTest.java @@ -0,0 +1,115 @@ +package com.dat3m.dartagnan.expression.type; + +import com.dat3m.dartagnan.expression.Type; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class AggregateTypeTest { + + private static final TypeFactory types = TypeFactory.getInstance(); + + @Test + public void testDefaultOffsets() { + Type i8 = types.getIntegerType(8); + Type i16 = types.getIntegerType(16); + Type i32 = types.getIntegerType(32); + Type i64 = types.getIntegerType(64); + + testStandardOffsets(List.of(i8, i8, i8, i32), List.of(0, 1, 2, 4), 8); + testStandardOffsets(List.of(i32, i8, i8, i8), List.of(0, 4, 5, 6), 8); + testStandardOffsets(List.of(i8, i16, i32), List.of(0, 2, 4), 8); + testStandardOffsets(List.of(i32, i16, i8), List.of(0, 4, 6), 8); + testStandardOffsets(List.of(i8, i64), List.of(0, 8), 16); + testStandardOffsets(List.of(i64, i8), List.of(0, 8), 16); + + Type arr1 = types.getArrayType(i16, 3); + Type arr2 = types.getArrayType(i16); + + testStandardOffsets(List.of(arr1), List.of(0), 6); + testStandardOffsets(List.of(i16, arr1), List.of(0, 2), 8); + testStandardOffsets(List.of(arr1, i16), List.of(0, 6), 8); + testStandardOffsets(List.of(arr2), List.of(0), -1); + testStandardOffsets(List.of(i16, arr2), List.of(0, 2), -1); + + Type s1 = types.getAggregateType(List.of(i32, i8)); + + testStandardOffsets(List.of(s1, i16), List.of(0, 8), 12); + } + + @Test + public void testExplicitOffsets() { + Type i8 = types.getIntegerType(8); + + testExplicitOffsets(List.of(i8, i8, i8), List.of(0, 1, 2), 3); + testExplicitOffsets(List.of(i8, i8, i8), List.of(0, 3, 5), 6); + testExplicitOffsets(List.of(i8, i8, i8), List.of(0, 1, 7), 8); + testExplicitOffsets(List.of(i8, i8, i8), List.of(0, 7, 13), 14); + + Type i16 = types.getIntegerType(16); + Type i32 = types.getIntegerType(32); + Type i64 = types.getIntegerType(64); + + testExplicitOffsets(List.of(i8, i8, i8, i32), List.of(0, 4, 8, 12), 16); + testExplicitOffsets(List.of(i8, i8, i8, i32), List.of(0, 4, 8, 12), 16); + testExplicitOffsets(List.of(i8, i64), List.of(0, 1), 16); + testExplicitOffsets(List.of(i64, i8), List.of(0, 15), 16); + + Type arr1 = types.getArrayType(i16, 3); + Type arr2 = types.getArrayType(i16); + + testExplicitOffsets(List.of(i16, arr1), List.of(0, 8), 14); + testExplicitOffsets(List.of(arr1, i16), List.of(0, 8), 10); + testExplicitOffsets(List.of(i16, arr2), List.of(0, 8), -1); + + Type s1 = types.getAggregateType(List.of(i32, i8)); + + testExplicitOffsets(List.of(s1, i16), List.of(0, 5), 8); + testExplicitOffsets(List.of(s1, i16), List.of(0, 6), 8); + testExplicitOffsets(List.of(s1, i16), List.of(0, 8), 12); + } + + @Test + public void testExplicitOffsetsTooSmall() { + Type i32 = types.getIntegerType(32); + Type arr = types.getArrayType(i32); + + testIllegalOffsets(List.of(i32, i32), List.of(0, 2), "Offset is too small"); + testIllegalOffsets(List.of(i32, i32), List.of(4, 8), "The first offset must be zero"); + testIllegalOffsets(List.of(i32, i32), List.of(0, -1), "Offset cannot be negative"); + testIllegalOffsets(List.of(i32, i32), List.of(0), "Offsets number does not match the fields number"); + testIllegalOffsets(List.of(i32, i32), List.of(0, 4, 8), "Offsets number does not match the fields number"); + testIllegalOffsets(List.of(arr, i32), List.of(0, 8), "Non-last element with unknown size"); + } + + private void testStandardOffsets(List fields, List offsets, int size) { + testDefaultOffsets(fields, offsets, size); + testExplicitOffsets(fields, offsets, size); + } + + private void testDefaultOffsets(List fields, List offsets, int size) { + AggregateType type = types.getAggregateType(fields); + assertEquals(fields, type.getTypeOffsets().stream().map(TypeOffset::type).toList()); + assertEquals(offsets, type.getTypeOffsets().stream().map(TypeOffset::offset).toList()); + assertEquals(size, types.getMemorySizeInBytes(type)); + } + + private void testExplicitOffsets(List fields, List offsets, int size) { + AggregateType type = types.getAggregateType(fields, offsets); + assertEquals(fields, type.getTypeOffsets().stream().map(TypeOffset::type).toList()); + assertEquals(offsets, type.getTypeOffsets().stream().map(TypeOffset::offset).toList()); + assertEquals(size, types.getMemorySizeInBytes(type)); + } + + private void testIllegalOffsets(List fields, List offsets, String error) { + try { + types.getAggregateType(fields, offsets); + fail("Should throw exception"); + } catch (IllegalArgumentException e) { + assertEquals(error, e.getMessage()); + } + } +} diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstantTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstantTest.java index 522818c997..d52c6b0665 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstantTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsConstantTest.java @@ -3,10 +3,7 @@ import com.dat3m.dartagnan.exception.ParsingException; import com.dat3m.dartagnan.expression.Expression; import com.dat3m.dartagnan.expression.ExpressionFactory; -import com.dat3m.dartagnan.expression.type.AggregateType; -import com.dat3m.dartagnan.expression.type.BooleanType; -import com.dat3m.dartagnan.expression.type.IntegerType; -import com.dat3m.dartagnan.expression.type.TypeFactory; +import com.dat3m.dartagnan.expression.type.*; import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.SpecId; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilder; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser; @@ -209,7 +206,7 @@ private void doTestOpConstantStruct(String input) { // given builder.mockBoolType("%bool"); IntegerType iType = builder.mockIntType("%int", 64); - builder.mockAggregateType("%struct", "%bool", "%int"); + AggregateType aType = builder.mockAggregateType("%struct", "%bool", "%int"); // when Map data = parseConstants(input); @@ -219,7 +216,7 @@ private void doTestOpConstantStruct(String input) { Expression b = expressions.makeTrue(); Expression i = expressions.makeValue(7, iType); - Expression s = expressions.makeConstruct(List.of(b, i)); + Expression s = expressions.makeConstruct(aType, List.of(b, i)); assertEquals(b, data.get("%b")); assertEquals(i, data.get("%i")); @@ -298,9 +295,9 @@ public void testNestedCompositeTypes() { builder.mockBoolType("%bool"); IntegerType iType = builder.mockIntType("%int", 64); - AggregateType aType = builder.mockAggregateType("%inner", "%bool", "%int"); + AggregateType innerType = builder.mockAggregateType("%inner", "%bool", "%int"); builder.mockVectorType("%v2inner", "%inner", 2); - builder.mockAggregateType("%outer", "%inner", "%v2inner"); + AggregateType outerType = builder.mockAggregateType("%outer", "%inner", "%v2inner"); // when Map data = parseConstants(input); @@ -316,12 +313,12 @@ public void testNestedCompositeTypes() { Expression i1 = expressions.makeValue(1, iType); Expression i2 = expressions.makeValue(2, iType); - Expression s0 = expressions.makeConstruct(List.of(b0, i0)); - Expression s1 = expressions.makeConstruct(List.of(b1, i1)); - Expression s2 = expressions.makeConstruct(List.of(b2, i2)); + Expression s0 = expressions.makeConstruct(innerType, List.of(b0, i0)); + Expression s1 = expressions.makeConstruct(innerType, List.of(b1, i1)); + Expression s2 = expressions.makeConstruct(innerType, List.of(b2, i2)); - Expression a0 = expressions.makeArray(aType, List.of(s1, s2), true); - Expression s = expressions.makeConstruct(List.of(s0, a0)); + Expression a0 = expressions.makeArray(innerType, List.of(s1, s2), true); + Expression s = expressions.makeConstruct(outerType, List.of(s0, a0)); assertEquals(b0, data.get("%b0")); assertEquals(b1, data.get("%b1")); @@ -653,7 +650,7 @@ public void testOverrideNotSpecConstant() { builder.mockBoolType("%bool"); IntegerType iType = builder.mockIntType("%int", 64); - builder.mockAggregateType("%struct", "%bool", "%bool", "%int"); + AggregateType aType = builder.mockAggregateType("%struct", "%bool", "%bool", "%int"); specId.addDecoration("%f", "1"); specId.addDecoration("%t", "0"); specId.addDecoration("%i", "2"); @@ -670,7 +667,7 @@ public void testOverrideNotSpecConstant() { assertEquals(t, data.get("%t")); assertEquals(i, data.get("%i")); - assertEquals(expressions.makeConstruct(List.of(f, t, i)), data.get("%s")); + assertEquals(expressions.makeConstruct(aType, List.of(f, t, i)), data.get("%s")); } @Test @@ -812,7 +809,7 @@ public void testInputNotSpecConstant() { builder.mockBoolType("%bool"); IntegerType iType = builder.mockIntType("%int", 64); - builder.mockAggregateType("%struct", "%bool", "%bool", "%int"); + AggregateType aType = builder.mockAggregateType("%struct", "%bool", "%bool", "%int"); // when Map data = parseConstants(input); @@ -826,7 +823,7 @@ public void testInputNotSpecConstant() { assertEquals(t, data.get("%t")); assertEquals(i, data.get("%i")); - assertEquals(expressions.makeConstruct(List.of(f, t, i)), data.get("%s")); + assertEquals(expressions.makeConstruct(aType, List.of(f, t, i)), data.get("%s")); } private Map parseConstants(String input) { diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMemoryTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMemoryTest.java index ca5cb0ac27..14ef9f4fb5 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMemoryTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMemoryTest.java @@ -6,9 +6,7 @@ import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.integers.IntBinaryExpr; import com.dat3m.dartagnan.expression.misc.ConstructExpr; -import com.dat3m.dartagnan.expression.type.ArrayType; -import com.dat3m.dartagnan.expression.type.IntegerType; -import com.dat3m.dartagnan.expression.type.TypeFactory; +import com.dat3m.dartagnan.expression.type.*; import com.dat3m.dartagnan.parsers.SpirvParser; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilder; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser; @@ -238,7 +236,8 @@ public void testInitializedVariableInput() { Expression i2 = expressions.makeValue(7890, archType); List iValues = Stream.of(1, 2, 3).map(i -> (Expression) expressions.makeValue(i, archType)).toList(); Expression i3 = expressions.makeArray(archType, iValues, true); - Expression i4 = expressions.makeConstruct(List.of(i1, i2, i3)); + AggregateType aType = types.getAggregateType(List.of(i1.getType(), i2.getType(), i3.getType())); + Expression i4 = expressions.makeConstruct(aType, List.of(i1, i2, i3)); builder = new MockProgramBuilder(); builder.addInput("%v1", i1); @@ -270,7 +269,8 @@ private void doTestInitializedVariable(String input) { Expression o2 = expressions.makeValue(7890, iType); List oValues = Stream.of(1, 2, 3).map(i -> (Expression) expressions.makeValue(i, iType)).toList(); Expression o3 = expressions.makeArray(iType, oValues, true); - Expression o4 = expressions.makeConstruct(List.of(o1, o2, o3)); + AggregateType aType = types.getAggregateType(List.of(o1.getType(), o2.getType(), o3.getType())); + Expression o4 = expressions.makeConstruct(aType, List.of(o1, o2, o3)); ScopedPointerVariable v1 = (ScopedPointerVariable) builder.getExpression("%v1"); assertNotNull(v1); @@ -311,7 +311,8 @@ public void testRuntimeArray() { """; IntegerType archType = types.getArchType(); - Type aType = types.getArrayType(archType, 2); + Type arrType = types.getArrayType(archType, 2); + AggregateType aggType = types.getAggregateType(List.of(archType, arrType)); Expression i1 = expressions.makeValue(1, archType); Expression i2 = expressions.makeValue(2, archType); @@ -324,8 +325,8 @@ public void testRuntimeArray() { Expression a2 = expressions.makeArray(archType, List.of(i3, i4), true); Expression a3 = expressions.makeArray(archType, List.of(i5, i6), true); - Expression a3a = expressions.makeArray(aType, List.of(a1, a2, a3), true); - Expression s = expressions.makeConstruct(List.of(i1, a1)); + Expression a3a = expressions.makeArray(arrType, List.of(a1, a2, a3), true); + Expression s = expressions.makeConstruct(aggType, List.of(i1, a1)); builder = new MockProgramBuilder(); builder.addInput("%v1", a1); @@ -600,9 +601,11 @@ public void testMismatchingValueTypeInNestedStruct() { // given String input = "%v = OpVariable %struct2_ptr Uniform %const"; - builder.mockBoolType("%bool"); + BooleanType boolType = builder.mockBoolType("%bool"); builder.mockIntType("%int16", 16); IntegerType i32Type = builder.mockIntType("%int32", 32); + AggregateType a1Type = types.getAggregateType(List.of(boolType, i32Type)); + AggregateType a2Type = types.getAggregateType(List.of(boolType, a1Type)); builder.mockAggregateType("%struct1", "%bool", "%int16"); builder.mockAggregateType("%struct2", "%bool", "%struct1"); @@ -611,8 +614,8 @@ public void testMismatchingValueTypeInNestedStruct() { Expression bool = expressions.makeTrue(); Expression int32 = expressions.makeValue(1, i32Type); - Expression struct1 = expressions.makeConstruct(List.of(bool, int32)); - Expression struct2 = expressions.makeConstruct(List.of(bool, struct1)); + Expression struct1 = expressions.makeConstruct(a1Type, List.of(bool, int32)); + Expression struct2 = expressions.makeConstruct(a2Type, List.of(bool, struct1)); builder.addExpression("%const", struct2); @@ -623,7 +626,9 @@ public void testMismatchingValueTypeInNestedStruct() { } catch (ParsingException e) { // then assertEquals("Mismatching value type for variable '%v', " + - "expected '{ bool, { bool, bv16 } }' but received '{ bool, { bool, bv32 } }'", e.getMessage()); + "expected '{ 0: bool, 2: { 0: bool, 2: bv16 } }' " + + "but received '{ 0: bool, 4: { 0: bool, 4: bv32 } }'", + e.getMessage()); } } @@ -701,10 +706,12 @@ public void testAccessChainStruct() { %element = OpAccessChain %i32_ptr %variable %4 %2 """; - builder.mockBoolType("%bool"); + BooleanType boolType = builder.mockBoolType("%bool"); IntegerType i16Type = builder.mockIntType("%int16", 16); IntegerType i32Type = builder.mockIntType("%int32", 32); IntegerType i64Type = builder.mockIntType("%int64", 64); + AggregateType a1Type = types.getAggregateType(List.of(boolType, i16Type, i32Type, i64Type)); + AggregateType a2Type = types.getAggregateType(List.of(boolType, i16Type, i32Type, i64Type, a1Type)); builder.mockAggregateType("%agg1", "%bool", "%int16", "%int32", "%int64"); builder.mockAggregateType("%agg2", "%bool", "%int16", "%int32", "%int64", "%agg1"); @@ -715,8 +722,8 @@ public void testAccessChainStruct() { Expression i16 = expressions.makeValue(1, i16Type); Expression i32 = expressions.makeValue(11, i32Type); Expression i64 = expressions.makeValue(111, i64Type); - Expression agg1 = expressions.makeConstruct(List.of(b, i16, i32, i64)); - Expression agg2 = expressions.makeConstruct(List.of(b, i16, i32, i64, agg1)); + Expression agg1 = expressions.makeConstruct(a1Type, List.of(b, i16, i32, i64)); + Expression agg2 = expressions.makeConstruct(a2Type, List.of(b, i16, i32, i64, agg1)); builder.addExpression("%const", agg2); @@ -729,9 +736,9 @@ public void testAccessChainStruct() { // then IntBinaryExpr e1 = (IntBinaryExpr) ((ScopedPointer)builder.getExpression("%element")).getAddress(); assertEquals(types.getArchType(), e1.getType()); - assertEquals(expressions.makeValue(3, i64Type), e1.getRight()); + assertEquals(expressions.makeValue(4, i64Type), e1.getRight()); IntBinaryExpr e2 = (IntBinaryExpr) e1.getLeft(); - assertEquals(expressions.makeValue(15, i64Type), e2.getRight()); + assertEquals(expressions.makeValue(16, i64Type), e2.getRight()); assertEquals(builder.getExpression("%variable"), e2.getLeft()); } @@ -776,16 +783,16 @@ public void testAccessChainStructureRegister() { IntegerType i16Type = builder.mockIntType("%int16", 16); IntegerType i32Type = builder.mockIntType("%int32", 32); - builder.mockAggregateType("%agg", "%int16", "%int32"); + AggregateType aType = builder.mockAggregateType("%agg", "%int16", "%int32"); builder.mockPtrType("%i16_ptr", "%int16", "Uniform"); builder.mockPtrType("%agg_ptr", "%agg", "Uniform"); Expression i1 = expressions.makeValue(1, i16Type); Expression i2 = expressions.makeValue(2, i32Type); - Expression arr = expressions.makeConstruct(List.of(i1, i2)); + Expression struct = expressions.makeConstruct(aType, List.of(i1, i2)); - builder.addExpression("%const", arr); + builder.addExpression("%const", struct); builder.mockFunctionStart(true); builder.addExpression("%register", builder.addRegister("%register", "%int32")); VisitorOpsMemory visitor = new VisitorOpsMemory(builder); @@ -874,16 +881,16 @@ public void testAccessChainMismatchingTypeStructure() { IntegerType i16Type = builder.mockIntType("%int16", 16); IntegerType i32Type = builder.mockIntType("%int32", 32); - builder.mockAggregateType("%agg", "%int16", "%int32"); + AggregateType aType = builder.mockAggregateType("%agg", "%int16", "%int32"); builder.mockPtrType("%i16_ptr", "%int16", "Uniform"); builder.mockPtrType("%agg_ptr", "%agg", "Uniform"); Expression i1 = expressions.makeValue(1, i16Type); Expression i2 = expressions.makeValue(2, i32Type); - Expression arr = expressions.makeConstruct(List.of(i1, i2)); + Expression struct = expressions.makeConstruct(aType, List.of(i1, i2)); - builder.addExpression("%const", arr); + builder.addExpression("%const", struct); builder.addExpression("%1", expressions.makeValue(1, i32Type)); try { diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsTypeTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsTypeTest.java index 982f3b5b10..dcc91893fc 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsTypeTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsTypeTest.java @@ -4,10 +4,11 @@ import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.integers.IntLiteral; import com.dat3m.dartagnan.expression.type.IntegerType; +import com.dat3m.dartagnan.expression.type.ScopedPointerType; import com.dat3m.dartagnan.expression.type.TypeFactory; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilder; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser; -import com.dat3m.dartagnan.expression.type.ScopedPointerType; import com.dat3m.dartagnan.program.event.Tag; import org.junit.Test; @@ -16,6 +17,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; public class VisitorOpsTypeTest { @@ -37,6 +39,9 @@ public void testSupportedTypes() { """; addIntConstant("%val_20", 20); + addMemberOffset("%struct", "0", "0"); + addMemberOffset("%struct", "1", "2"); + addMemberOffset("%struct", "2", "10"); // when Map data = parseTypes(input); @@ -51,7 +56,7 @@ public void testSupportedTypes() { Type typeArray = types.getArrayType(typeInteger, 20); Type typePointer = types.getScopedPointerType(Tag.Spirv.SC_INPUT, typeInteger); Type typeFunction = types.getFunctionType(typeVoid, List.of(typePointer, typeInteger)); - Type typeStruct = types.getAggregateType(List.of(typeInteger, typePointer, typeArray)); + Type typeStruct = types.getAggregateType(List.of(typeInteger, typePointer, typeArray), List.of(0, 2, 10)); assertEquals(typeVoid, data.get("%void")); assertEquals(typeBoolean, data.get("%bool")); @@ -186,15 +191,15 @@ public void testPointerType() { // then assertEquals(5, data.size()); - ScopedPointerType boolPtr = (ScopedPointerType)data.get("%ptr_input_bool"); + ScopedPointerType boolPtr = (ScopedPointerType) data.get("%ptr_input_bool"); assertEquals(Tag.Spirv.SC_INPUT, boolPtr.getScopeId()); assertEquals(builder.getType("%bool"), boolPtr.getPointedType()); - ScopedPointerType inputIntPtr = (ScopedPointerType)data.get("%ptr_input_int"); + ScopedPointerType inputIntPtr = (ScopedPointerType) data.get("%ptr_input_int"); assertEquals(Tag.Spirv.SC_INPUT, inputIntPtr.getScopeId()); assertEquals(builder.getType("%int"), inputIntPtr.getPointedType()); - ScopedPointerType workgroupIntPtr = (ScopedPointerType)data.get("%ptr_workgroup_int"); + ScopedPointerType workgroupIntPtr = (ScopedPointerType) data.get("%ptr_workgroup_int"); assertEquals(Tag.Spirv.SC_WORKGROUP, workgroupIntPtr.getScopeId()); assertEquals(builder.getType("%int"), workgroupIntPtr.getPointedType()); } @@ -273,25 +278,35 @@ public void testStructType() { %s1 = OpTypeStruct %int %array %ptr = OpTypePointer Input %s1 %s2 = OpTypeStruct %bool %ptr + %s3 = OpTypeStruct %bool %ptr """; addIntConstant("%val_10", 10); + addMemberOffset("%s1", "0", "0"); + addMemberOffset("%s1", "1", "4"); + addMemberOffset("%s2", "0", "0"); + addMemberOffset("%s2", "1", "1"); + addMemberOffset("%s3", "0", "0"); + addMemberOffset("%s3", "1", "2"); // when Map data = parseTypes(input); // then - assertEquals(6, data.size()); + assertEquals(7, data.size()); Type typeBoolean = types.getBooleanType(); Type typeInteger = types.getIntegerType(32); Type typeArray = types.getArrayType(typeInteger, 10); - Type typeStructFirst = types.getAggregateType(List.of(typeInteger, typeArray)); + Type typeStructFirst = types.getAggregateType(List.of(typeInteger, typeArray), List.of(0, 4)); Type typePointer = types.getScopedPointerType(Tag.Spirv.SC_INPUT, typeStructFirst); - Type typeStructSecond = types.getAggregateType(List.of(typeBoolean, typePointer)); + Type typeStructSecond = types.getAggregateType(List.of(typeBoolean, typePointer), List.of(0, 1)); + Type typeStructThird = types.getAggregateType(List.of(typeBoolean, typePointer), List.of(0, 2)); assertEquals(data.get("%s1"), typeStructFirst); assertEquals(data.get("%s2"), typeStructSecond); + assertEquals(data.get("%s3"), typeStructThird); + assertNotEquals(data.get("%s2"), data.get("%s3")); } @Test(expected = ParsingException.class) @@ -302,6 +317,9 @@ public void testStructTypeUndefinedReference() { %s1 = OpTypeStruct %int %ptr """; + addMemberOffset("%s1", "0", "0"); + addMemberOffset("%s1", "1", "4"); + // when parseTypes(input); } @@ -316,4 +334,8 @@ private void addIntConstant(String id, int value) { IntLiteral iValue = new IntLiteral(type, new BigInteger(Integer.toString(value))); builder.addExpression(id, iValue); } + + private void addMemberOffset(String id, String idx, String offset) { + builder.getDecorationsBuilder().getDecoration(DecorationType.OFFSET).addDecoration(id, idx, offset); + } } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java index b97d9b9752..a0cf5406d7 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java @@ -22,7 +22,6 @@ public class VisitorExtensionClspvReflectionTest { @Before public void before() { builder.mockIntType("%uint", 32); - builder.mockVectorType("%v1uint", "%uint", 1); builder.mockVectorType("%v2uint", "%uint", 2); builder.mockVectorType("%v3uint", "%uint", 3); @@ -32,6 +31,14 @@ public void before() { builder.mockConstant("%uint_4", "%uint", 4); builder.mockConstant("%uint_8", "%uint", 8); builder.mockConstant("%uint_12", "%uint", 12); + builder.mockConstant("%uint_16", "%uint", 16); + builder.mockConstant("%uint_24", "%uint", 24); + builder.mockConstant("%uint_32", "%uint", 32); + builder.mockConstant("%uint_36", "%uint", 36); + builder.mockConstant("%uint_48", "%uint", 48); + builder.mockConstant("%uint_60", "%uint", 60); + builder.mockConstant("%uint_64", "%uint", 64); + builder.mockConstant("%uint_80", "%uint", 80); } @Test @@ -40,13 +47,44 @@ public void testPushConstant() { String input = """ %ext = OpExtInstImport "NonSemantic.ClspvReflection.5" %1 = OpExtInst %void %ext PushConstantGlobalOffset %uint_0 %uint_12 - %2 = OpExtInst %void %ext PushConstantGlobalSize %uint_0 %uint_12 - %3 = OpExtInst %void %ext PushConstantEnqueuedLocalSize %uint_0 %uint_12 - %4 = OpExtInst %void %ext PushConstantNumWorkgroups %uint_0 %uint_12 - %5 = OpExtInst %void %ext PushConstantRegionOffset %uint_0 %uint_12 - %6 = OpExtInst %void %ext PushConstantRegionGroupOffset %uint_0 %uint_12 + %2 = OpExtInst %void %ext PushConstantGlobalSize %uint_16 %uint_12 + %3 = OpExtInst %void %ext PushConstantEnqueuedLocalSize %uint_32 %uint_12 + %4 = OpExtInst %void %ext PushConstantNumWorkgroups %uint_48 %uint_12 + %5 = OpExtInst %void %ext PushConstantRegionOffset %uint_64 %uint_12 + %6 = OpExtInst %void %ext PushConstantRegionGroupOffset %uint_80 %uint_12 + """; + + builder.mockStructMemberOffsets("%6x_v3uint", 0, 16, 32, 48, 64, 80); + builder.mockAggregateType("%6x_v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint"); + builder.mockPtrType("%ptr_6x_v3uint", "%6x_v3uint", "PushConstant"); + ScopedPointerVariable pointer = builder.mockVariable("%var", "%ptr_6x_v3uint"); + + // when + new MockSpirvParser(input).spv().accept(new VisitorExtensionClspvReflection(builder)); + + // then + verifyPushConstant(pointer, 0, List.of(0, 0, 0)); + verifyPushConstant(pointer, 16, List.of(24, 1, 1)); + verifyPushConstant(pointer, 32, List.of(6, 1, 1)); + verifyPushConstant(pointer, 48, List.of(4, 1, 1)); + verifyPushConstant(pointer, 64, List.of(0, 0, 0)); + verifyPushConstant(pointer, 80, List.of(0, 0, 0)); + } + + @Test + public void testPushConstantOffsets() { + // given + String input = """ + %ext = OpExtInstImport "NonSemantic.ClspvReflection.5" + %1 = OpExtInst %void %ext PushConstantGlobalOffset %uint_0 %uint_12 + %2 = OpExtInst %void %ext PushConstantGlobalSize %uint_12 %uint_12 + %3 = OpExtInst %void %ext PushConstantEnqueuedLocalSize %uint_24 %uint_12 + %4 = OpExtInst %void %ext PushConstantNumWorkgroups %uint_36 %uint_12 + %5 = OpExtInst %void %ext PushConstantRegionOffset %uint_48 %uint_12 + %6 = OpExtInst %void %ext PushConstantRegionGroupOffset %uint_60 %uint_12 """; + builder.mockStructMemberOffsets("%6x_v3uint", 0, 12, 24, 36, 48, 60); builder.mockAggregateType("%6x_v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint", "%v3uint"); builder.mockPtrType("%ptr_6x_v3uint", "%6x_v3uint", "PushConstant"); ScopedPointerVariable pointer = builder.mockVariable("%var", "%ptr_6x_v3uint"); @@ -134,10 +172,11 @@ public void testPushConstantIndexOutOfBound() { String input = """ %ext = OpExtInstImport "NonSemantic.ClspvReflection.5" %1 = OpExtInst %void %ext PushConstantGlobalOffset %uint_0 %uint_12 - %2 = OpExtInst %void %ext PushConstantGlobalSize %uint_0 %uint_12 - %3 = OpExtInst %void %ext PushConstantEnqueuedLocalSize %uint_0 %uint_12 + %2 = OpExtInst %void %ext PushConstantGlobalSize %uint_16 %uint_12 + %3 = OpExtInst %void %ext PushConstantEnqueuedLocalSize %uint_32 %uint_12 """; + builder.mockStructMemberOffsets("%2x_v3uint", 0, 16); builder.mockAggregateType("%2x_v3uint", "%v3uint", "%v3uint"); builder.mockPtrType("%ptr_2x_v3uint", "%2x_v3uint", "PushConstant"); builder.mockVariable("%var", "%ptr_2x_v3uint"); @@ -149,8 +188,8 @@ public void testPushConstantIndexOutOfBound() { } catch (ParsingException e) { // then - assertEquals("Out of bounds definition 'PushConstantEnqueuedLocalSize' " + - "in PushConstant '%var'", e.getMessage()); + assertEquals("Argument PushConstantEnqueuedLocalSize doesn't match the PushConstant type", + e.getMessage()); } } @@ -173,7 +212,7 @@ public void testPushConstantMismatchingElementType() { } catch (ParsingException e) { // then - assertEquals("Unexpected element type in '%var' at index 0", e.getMessage()); + assertEquals("Argument PushConstantGlobalOffset doesn't match the PushConstant type", e.getMessage()); } } @@ -204,7 +243,7 @@ public void testPodPushConstantMixed() { %ext = OpExtInstImport "NonSemantic.ClspvReflection.5" %1 = OpExtInst %void %ext PushConstantGlobalSize %uint_0 %uint_12 %2 = OpExtInst %void %ext ArgumentInfo %kernel - %3 = OpExtInst %void %ext ArgumentPodPushConstant %kernel %uint_1 %uint_0 %uint_4 %2 + %3 = OpExtInst %void %ext ArgumentPodPushConstant %kernel %uint_1 %uint_12 %uint_4 %2 """; builder.mockAggregateType("%v3uint_v1uint", "%v3uint", "%v1uint"); @@ -311,8 +350,8 @@ public void testPodPushConstantIndexOutOfBound() { } catch (ParsingException e) { // then - assertEquals("Out of bounds definition 'ArgumentPodPushConstant' " + - "in PushConstant '%var'", e.getMessage()); + assertEquals("Argument ArgumentPodPushConstant doesn't match the PushConstant type", + e.getMessage()); } } @@ -337,7 +376,8 @@ public void testPodPushConstantMismatchingElementType() { } catch (ParsingException e) { // then - assertEquals("Unexpected offset in PushConstant '%var' element '1'", e.getMessage()); + assertEquals("Argument ArgumentPodPushConstant doesn't match the PushConstant type", + e.getMessage()); } } @@ -362,7 +402,8 @@ public void testPodPushConstantElementSizeOutOfBound() { } catch (ParsingException e) { // then - assertEquals("Unexpected offset in PushConstant '%var' element '1'", e.getMessage()); + assertEquals("Argument ArgumentPodPushConstant doesn't match the PushConstant type", + e.getMessage()); } } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java index fe19f6231a..0713e232ce 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java @@ -6,6 +6,9 @@ import com.dat3m.dartagnan.expression.booleans.BoolLiteral; import com.dat3m.dartagnan.expression.integers.IntLiteral; import com.dat3m.dartagnan.expression.type.*; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Offset; import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags; import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder; import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid; @@ -63,8 +66,12 @@ public ArrayType mockVectorType(String id, String innerTypeId, int size) { } public AggregateType mockAggregateType(String id, String... innerTypeIds) { + Offset decoration = (Offset) getDecorationsBuilder().getDecoration(DecorationType.OFFSET); + Map offsets = decoration.getValue(id); List innerTypes = Arrays.stream(innerTypeIds).map(this::getType).toList(); - AggregateType type = typeFactory.getAggregateType(innerTypes); + AggregateType type = offsets != null + ? typeFactory.getAggregateType(innerTypes, new ArrayList<>(offsets.values())) + : typeFactory.getAggregateType(innerTypes); return (AggregateType) addType(id, type); } @@ -90,7 +97,7 @@ public Expression mockConstant(String id, String typeId, Object value) { return addExpression(id, construction); } else if (type instanceof AggregateType) { List members = ((List) value).stream().map(s -> getExpression((String) s)).toList(); - Expression construction = exprFactory.makeConstruct(members); + Expression construction = exprFactory.makeConstruct(type, members); return addExpression(id, construction); } throw new UnsupportedOperationException("Unsupported mock constant type " + typeId); @@ -125,6 +132,13 @@ public ScopedPointerVariable mockVariable(String id, String typeId) { return (ScopedPointerVariable) addExpression(id, pointer); } + public void mockStructMemberOffsets(String id, Integer... offsets) { + Decoration decoration = getDecorationsBuilder().getDecoration(DecorationType.OFFSET); + for (int i = 0; i < offsets.length; i++) { + decoration.addDecoration(id, Integer.toString(i), Integer.toString(offsets[i])); + } + } + public void mockFunctionStart(boolean addStartLabel) { FunctionType type = typeFactory.getFunctionType(typeFactory.getVoidType(), List.of()); startCurrentFunction(new Function("mock_function", type, List.of(), 0, null)); diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/AbstractTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/AbstractTest.java index 26e73ea7ea..90dddf69c2 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/AbstractTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/AbstractTest.java @@ -13,6 +13,11 @@ public abstract class AbstractTest { OpMemoryModel Logical Vulkan OpEntryPoint GLCompute %main "main" OpSource GLSL 450 + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 1 Offset 16 + OpMemberDecorate %struct_2 0 Offset 0 + OpMemberDecorate %struct_2 1 Offset 4 + OpMemberDecorate %struct_2 2 Offset 8 %void = OpTypeVoid %uint16 = OpTypeInt 16 0 %uint32 = OpTypeInt 32 0 diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/BadIndexTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/BadIndexTest.java index 24ce20246c..dcced3e61c 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/BadIndexTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/BadIndexTest.java @@ -33,7 +33,7 @@ public static Iterable data() { {"; @Output: forall %v8[0][1]==0", "Index is out of bounds for variable '%v8[0][1]'"}, {"; @Input: %v8={{{0}}}", - "Mismatching value type for variable '%v8[0][0]', expected 'bv64' but received '{ bv64 }'"}, + "Mismatching value type for variable '%v8[0][0]', expected 'bv64' but received '{ 0: bv64 }'"}, {"; @Input: %v8={0}", "Mismatching value type for variable '%v8[0]', expected '[1 x bv64]' but received 'bv64'"}, {"; @Input: %v8={{0}, {0}}", diff --git a/dartagnan/src/test/resources/spirv/basic/array-of-vector1.spv.dis b/dartagnan/src/test/resources/spirv/basic/array-of-vector1.spv.dis index 703a730b8c..10ea4e9e2e 100644 --- a/dartagnan/src/test/resources/spirv/basic/array-of-vector1.spv.dis +++ b/dartagnan/src/test/resources/spirv/basic/array-of-vector1.spv.dis @@ -8,6 +8,8 @@ OpMemoryModel Logical Vulkan OpEntryPoint GLCompute %main "main" OpSource GLSL 450 + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 1 Offset 8 %void = OpTypeVoid %func = OpTypeFunction %void %uint = OpTypeInt 64 0 diff --git a/dartagnan/src/test/resources/spirv/basic/mixed-size.spv.dis b/dartagnan/src/test/resources/spirv/basic/mixed-size.spv.dis index 0fa6b56f68..8fc59795ce 100644 --- a/dartagnan/src/test/resources/spirv/basic/mixed-size.spv.dis +++ b/dartagnan/src/test/resources/spirv/basic/mixed-size.spv.dis @@ -12,6 +12,17 @@ OpMemoryModel Logical Vulkan OpEntryPoint GLCompute %main "main" OpSource GLSL 450 + OpMemberDecorate %struct1 0 Offset 0 + OpMemberDecorate %struct1 1 Offset 2 + OpMemberDecorate %struct1 2 Offset 6 + OpMemberDecorate %struct2 0 Offset 0 + OpMemberDecorate %struct2 1 Offset 1 + OpMemberDecorate %struct2 2 Offset 3 + OpMemberDecorate %struct2 3 Offset 6 + OpMemberDecorate %struct2 4 Offset 10 + OpMemberDecorate %struct2 5 Offset 15 + OpMemberDecorate %struct2 6 Offset 21 + OpMemberDecorate %struct2 7 Offset 28 %void = OpTypeVoid %uint = OpTypeInt 64 0 %ptr_uint = OpTypePointer Uniform %uint