Skip to content

Commit

Permalink
Support explicitly defined offsets in aggregate type
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Nov 8, 2024
1 parent 4d03388 commit 4af61ca
Show file tree
Hide file tree
Showing 31 changed files with 579 additions and 296 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea
// -----------------------------------------------------------------------------------------------------------------
// Aggregates

public Expression makeConstruct(List<Expression> arguments) {
final AggregateType type = types.getAggregateType(arguments.stream().map(Expression::getType).toList());
public Expression makeConstruct(Type type, List<Expression> arguments) {
return new ConstructExpr(type, arguments);
}

Expand Down Expand Up @@ -302,11 +301,11 @@ public Expression makeGeneralZero(Type type) {
}
return makeArray(arrayType.getElementType(), zeroes, true);
} else if (type instanceof AggregateType structType) {
List<Expression> zeroes = new ArrayList<>(structType.getDirectFields().size());
for (Type fieldType : structType.getDirectFields()) {
zeroes.add(makeGeneralZero(fieldType));
List<Expression> zeroes = new ArrayList<>(structType.getTypeOffsets().size());
for (TypeOffset typeOffset : structType.getTypeOffsets()) {
zeroes.add(makeGeneralZero(typeOffset.type()));
}
return makeConstruct(zeroes);
return makeConstruct(structType, zeroes);
} else if (type instanceof IntegerType intType) {
return makeZero(intType);
} else if (type instanceof BooleanType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.dat3m.dartagnan.expression.base.NaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;

import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -20,7 +21,8 @@ public ConstructExpr(Type type, List<Expression> arguments) {
checkArgument(type instanceof AggregateType || type instanceof ArrayType,
"Non-constructible type %s.", type);
checkArgument(!(type instanceof AggregateType a) ||
arguments.stream().map(Expression::getType).toList().equals(a.getDirectFields()),
arguments.stream().map(Expression::getType).toList()
.equals(a.getTypeOffsets().stream().map(TypeOffset::type).toList()),
"Arguments do not match the constructor signature.");
checkArgument(!(type instanceof ArrayType a) ||
!a.hasKnownNumElements() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import com.dat3m.dartagnan.expression.base.UnaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;
import com.google.common.base.Preconditions;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;

public final class ExtractExpr extends UnaryExpressionBase<Type, ExpressionKind.Other> {
Expand All @@ -25,13 +28,14 @@ private static Type extractType(Expression expr, int index) {
Preconditions.checkArgument(exprType instanceof AggregateType || exprType instanceof ArrayType,
"Cannot extract from a non-aggregate expression: (%s)[%d].", expr, index);
if (exprType instanceof AggregateType aggregateType) {
return aggregateType.getDirectFields().get(index);
} else {
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
final List<TypeOffset> typeOffsets = aggregateType.getTypeOffsets();
checkArgument(0 <= index && index < typeOffsets.size());
return typeOffsets.get(index).type();
}
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
}

public int getFieldIndex() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Expression visitConstructExpression(ConstructExpr construct) {
for (final Expression argument : construct.getOperands()) {
arguments.add(argument.accept(this));
}
return expressions.makeConstruct(arguments);
return expressions.makeConstruct(construct.getType(), arguments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,32 @@

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class AggregateType implements Type {
public class AggregateType implements Type {

private final List<Type> fields;
private final List<TypeOffset> directFields;

AggregateType(List<Type> directFields) {
this.fields = List.copyOf(directFields);
AggregateType(List<Type> fields, List<Integer> offsets) {
this.directFields = IntStream.range(0, fields.size()).boxed().map(i -> new TypeOffset(fields.get(i), offsets.get(i))).toList();
}

public List<Type> getDirectFields() {
return fields;
public List<TypeOffset> getTypeOffsets() {
return directFields;
}

@Override
public int hashCode() {
return fields.hashCode();
return directFields.hashCode();
}

@Override
public boolean equals(Object obj) {
return this == obj || obj instanceof AggregateType o && fields.equals(o.fields);
return this == obj || obj instanceof AggregateType o && directFields.equals(o.directFields);
}

@Override
public String toString() {
return fields.stream().map(Type::toString).collect(Collectors.joining(", ", "{ ", " }"));
return directFields.stream().map(f -> f.offset() + ": " + f.type()).collect(Collectors.joining(", ", "{ ", " }"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.utils.Normalizer;
import com.google.common.math.IntMath;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -71,7 +72,38 @@ public FunctionType getFunctionType(Type returnType, List<? extends Type> parame
public AggregateType getAggregateType(List<Type> fields) {
checkNotNull(fields);
checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed");
return typeNormalizer.normalize(new AggregateType(fields));
return typeNormalizer.normalize(new AggregateType(fields, computeDefaultOffsets(fields)));
}

public AggregateType getAggregateType(List<Type> fields, List<Integer> offsets) {
checkNotNull(fields);
checkNotNull(offsets);
checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed");
checkArgument(fields.size() == offsets.size(), "Offsets number does not match the fields number");
checkArgument(offsets.stream().noneMatch(o -> o < 0), "Offset cannot be negative");
checkArgument(offsets.isEmpty() || offsets.get(0) == 0, "The first offset must be zero");
checkArgument(IntStream.range(1, offsets.size()).boxed().allMatch(
i -> offsets.get(i) >= offsets.get(i - 1) + Integer.max(0, getMemorySizeInBytes(fields.get(i - 1), false))),
"Offset is too small");
checkArgument(IntStream.range(0, offsets.size() - 1).boxed().allMatch(
i -> getMemorySizeInBytes(fields.get(i)) > 0),
"Non-last element with unknown size");
return typeNormalizer.normalize(new AggregateType(fields, offsets));
}

private List<Integer> computeDefaultOffsets(List<Type> fields) {
List<Integer> offsets = new ArrayList<>();
int offset = 0;
if (!fields.isEmpty()) {
offset = getMemorySizeInBytes(fields.get(0));
offsets.add(0);
}
for (int i = 1; i < fields.size(); i++) {
offset = paddedSize(offset, getAlignment(fields.get(i)));
offsets.add(offset);
offset += getMemorySizeInBytes(fields.get(i));
}
return offsets;
}

public ArrayType getArrayType(Type element) {
Expand All @@ -92,7 +124,63 @@ public IntegerType getByteType() {
}

public int getMemorySizeInBytes(Type type) {
return TypeLayout.of(type).totalSizeInBytes();
return getMemorySizeInBytes(type, true);
}

public int getMemorySizeInBytes(Type type, boolean padded) {
if (type instanceof BooleanType) {
return 1;
}
if (type instanceof IntegerType integerType) {
return IntMath.divide(integerType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof FloatType floatType) {
return IntMath.divide(floatType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof ArrayType arrayType) {
if (arrayType.hasKnownNumElements()) {
Type elType = arrayType.getElementType();
return getMemorySizeInBytes(elType) * arrayType.getNumElements();
}
return -1;
}
if (type instanceof AggregateType aType) {
List<TypeOffset> typeOffsets = aType.getTypeOffsets();
if (typeOffsets.isEmpty()) {
return 0;
}
if (aType.getTypeOffsets().stream().anyMatch(o -> getMemorySizeInBytes(o.type()) < 0)) {
return -1;
}
TypeOffset lastTypeOffset = typeOffsets.get(typeOffsets.size() - 1);
int baseSize = lastTypeOffset.offset() + getMemorySizeInBytes(lastTypeOffset.type());
if (padded) {
return paddedSize(baseSize, getAlignment(type));
}
return baseSize;
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public int getAlignment(Type type) {
if (type instanceof BooleanType || type instanceof IntegerType || type instanceof FloatType) {
return getMemorySizeInBytes(type);
}
if (type instanceof ArrayType arrayType) {
return getMemorySizeInBytes(arrayType.getElementType());
}
if (type instanceof AggregateType aType) {
return aType.getTypeOffsets().stream().map(o -> getAlignment(o.type())).max(Integer::compare).orElseThrow();
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public static int paddedSize(int size, int alignment) {
int mod = size % alignment;
if (mod > 0) {
return size + alignment - mod;
}
return size;
}

public int getMemorySizeInBits(Type type) {
Expand All @@ -119,16 +207,13 @@ public Map<Integer, Type> decomposeIntoPrimitives(Type type) {
}
}
} else if (type instanceof AggregateType aggregateType) {
final List<Type> fields = aggregateType.getDirectFields();
for (int i = 0; i < fields.size(); i++) {
final int offset = getOffsetInBytes(aggregateType, i);
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(fields.get(i));
for (TypeOffset typeOffset : aggregateType.getTypeOffsets()) {
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(typeOffset.type());
if (innerDecomposition == null) {
return null;
}

for (Map.Entry<Integer, Type> entry : innerDecomposition.entrySet()) {
decomposition.put(entry.getKey() + offset, entry.getValue());
decomposition.put(typeOffset.offset() + entry.getKey(), entry.getValue());
}
}
} else {
Expand All @@ -147,12 +232,7 @@ public static boolean isStaticType(Type type) {
return aType.hasKnownNumElements() && isStaticType(aType.getElementType());
}
if (type instanceof AggregateType aType) {
for (Type elType : aType.getDirectFields()) {
if (!isStaticType(elType)) {
return false;
}
}
return true;
return aType.getTypeOffsets().stream().allMatch(o -> isStaticType(o.type()));
}
throw new UnsupportedOperationException("Cannot compute if type '" + type + "' is static");
}
Expand All @@ -162,12 +242,15 @@ public static boolean isStaticTypeOf(Type staticType, Type runtimeType) {
return true;
}
if (staticType instanceof AggregateType aStaticType && runtimeType instanceof AggregateType aRuntimeType) {
int size = aStaticType.getDirectFields().size();
if (size != aRuntimeType.getDirectFields().size()) {
int size = aStaticType.getTypeOffsets().size();
if (size != aRuntimeType.getTypeOffsets().size()) {
return false;
}
for (int i = 0; i < size; i++) {
if (!isStaticTypeOf(aStaticType.getDirectFields().get(i), aRuntimeType.getDirectFields().get(i))) {
TypeOffset staticTypeOffset = aStaticType.getTypeOffsets().get(i);
TypeOffset runtimeTypeOffset = aRuntimeType.getTypeOffsets().get(i);
if (staticTypeOffset.offset() != runtimeTypeOffset.offset()
|| !isStaticTypeOf(staticTypeOffset.type(), runtimeTypeOffset.type())) {
return false;
}
}
Expand Down

This file was deleted.

Loading

0 comments on commit 4af61ca

Please sign in to comment.