Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support explicitly defined offsets in aggregate type #770

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea
// -----------------------------------------------------------------------------------------------------------------
// Aggregates

public Expression makeConstruct(List<Expression> arguments) {
final AggregateType type = types.getAggregateType(arguments.stream().map(Expression::getType).toList());
public Expression makeConstruct(Type type, List<Expression> arguments) {
return new ConstructExpr(type, arguments);
}

Expand Down Expand Up @@ -302,11 +301,11 @@ public Expression makeGeneralZero(Type type) {
}
return makeArray(arrayType.getElementType(), zeroes, true);
} else if (type instanceof AggregateType structType) {
List<Expression> zeroes = new ArrayList<>(structType.getDirectFields().size());
for (Type fieldType : structType.getDirectFields()) {
zeroes.add(makeGeneralZero(fieldType));
List<Expression> zeroes = new ArrayList<>(structType.getTypeOffsets().size());
for (TypeOffset typeOffset : structType.getTypeOffsets()) {
zeroes.add(makeGeneralZero(typeOffset.type()));
}
return makeConstruct(zeroes);
return makeConstruct(structType, zeroes);
} else if (type instanceof IntegerType intType) {
return makeZero(intType);
} else if (type instanceof BooleanType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.dat3m.dartagnan.expression.base.NaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;

import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -20,7 +21,8 @@ public ConstructExpr(Type type, List<Expression> arguments) {
checkArgument(type instanceof AggregateType || type instanceof ArrayType,
"Non-constructible type %s.", type);
checkArgument(!(type instanceof AggregateType a) ||
arguments.stream().map(Expression::getType).toList().equals(a.getDirectFields()),
arguments.stream().map(Expression::getType).toList()
.equals(a.getTypeOffsets().stream().map(TypeOffset::type).toList()),
"Arguments do not match the constructor signature.");
checkArgument(!(type instanceof ArrayType a) ||
!a.hasKnownNumElements() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import com.dat3m.dartagnan.expression.base.UnaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;
import com.google.common.base.Preconditions;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;

public final class ExtractExpr extends UnaryExpressionBase<Type, ExpressionKind.Other> {
Expand All @@ -25,13 +28,14 @@ private static Type extractType(Expression expr, int index) {
Preconditions.checkArgument(exprType instanceof AggregateType || exprType instanceof ArrayType,
"Cannot extract from a non-aggregate expression: (%s)[%d].", expr, index);
if (exprType instanceof AggregateType aggregateType) {
return aggregateType.getDirectFields().get(index);
} else {
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
final List<TypeOffset> typeOffsets = aggregateType.getTypeOffsets();
checkArgument(0 <= index && index < typeOffsets.size());
return typeOffsets.get(index).type();
}
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
}

public int getFieldIndex() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Expression visitConstructExpression(ConstructExpr construct) {
for (final Expression argument : construct.getOperands()) {
arguments.add(argument.accept(this));
}
return expressions.makeConstruct(arguments);
return expressions.makeConstruct(construct.getType(), arguments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,32 @@

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class AggregateType implements Type {
public class AggregateType implements Type {

private final List<Type> fields;
private final List<TypeOffset> directFields;

AggregateType(List<Type> directFields) {
this.fields = List.copyOf(directFields);
AggregateType(List<Type> fields, List<Integer> offsets) {
this.directFields = IntStream.range(0, fields.size()).boxed().map(i -> new TypeOffset(fields.get(i), offsets.get(i))).toList();
}

public List<Type> getDirectFields() {
return fields;
public List<TypeOffset> getTypeOffsets() {
return directFields;
}

@Override
public int hashCode() {
return fields.hashCode();
return directFields.hashCode();
}

@Override
public boolean equals(Object obj) {
return this == obj || obj instanceof AggregateType o && fields.equals(o.fields);
return this == obj || obj instanceof AggregateType o && directFields.equals(o.directFields);
}

@Override
public String toString() {
return fields.stream().map(Type::toString).collect(Collectors.joining(", ", "{ ", " }"));
return directFields.stream().map(f -> f.offset() + ": " + f.type()).collect(Collectors.joining(", ", "{ ", " }"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.utils.Normalizer;
import com.google.common.math.IntMath;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -71,7 +72,38 @@ public FunctionType getFunctionType(Type returnType, List<? extends Type> parame
public AggregateType getAggregateType(List<Type> fields) {
checkNotNull(fields);
checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed");
return typeNormalizer.normalize(new AggregateType(fields));
return typeNormalizer.normalize(new AggregateType(fields, computeDefaultOffsets(fields)));
}

public AggregateType getAggregateType(List<Type> fields, List<Integer> offsets) {
checkNotNull(fields);
checkNotNull(offsets);
checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed");
checkArgument(fields.size() == offsets.size(), "Offsets number does not match the fields number");
checkArgument(offsets.stream().noneMatch(o -> o < 0), "Offset cannot be negative");
checkArgument(offsets.isEmpty() || offsets.get(0) == 0, "The first offset must be zero");
checkArgument(IntStream.range(1, offsets.size()).boxed().allMatch(
i -> offsets.get(i) >= offsets.get(i - 1) + Integer.max(0, getMemorySizeInBytes(fields.get(i - 1), false))),
"Offset is too small");
checkArgument(IntStream.range(0, offsets.size() - 1).boxed().allMatch(
i -> getMemorySizeInBytes(fields.get(i)) > 0),
"Non-last element with unknown size");
return typeNormalizer.normalize(new AggregateType(fields, offsets));
}

private List<Integer> computeDefaultOffsets(List<Type> fields) {
List<Integer> offsets = new ArrayList<>();
int offset = 0;
if (!fields.isEmpty()) {
offset = getMemorySizeInBytes(fields.get(0));
offsets.add(0);
}
for (int i = 1; i < fields.size(); i++) {
offset = paddedSize(offset, getAlignment(fields.get(i)));
offsets.add(offset);
offset += getMemorySizeInBytes(fields.get(i));
}
return offsets;
}

public ArrayType getArrayType(Type element) {
Expand All @@ -92,7 +124,63 @@ public IntegerType getByteType() {
}

public int getMemorySizeInBytes(Type type) {
return TypeLayout.of(type).totalSizeInBytes();
return getMemorySizeInBytes(type, true);
}

public int getMemorySizeInBytes(Type type, boolean padded) {
if (type instanceof BooleanType) {
return 1;
}
if (type instanceof IntegerType integerType) {
return IntMath.divide(integerType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof FloatType floatType) {
return IntMath.divide(floatType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof ArrayType arrayType) {
if (arrayType.hasKnownNumElements()) {
Type elType = arrayType.getElementType();
return getMemorySizeInBytes(elType) * arrayType.getNumElements();
}
return -1;
}
if (type instanceof AggregateType aType) {
List<TypeOffset> typeOffsets = aType.getTypeOffsets();
if (typeOffsets.isEmpty()) {
return 0;
}
if (aType.getTypeOffsets().stream().anyMatch(o -> getMemorySizeInBytes(o.type()) < 0)) {
return -1;
}
hernanponcedeleon marked this conversation as resolved.
Show resolved Hide resolved
TypeOffset lastTypeOffset = typeOffsets.get(typeOffsets.size() - 1);
int baseSize = lastTypeOffset.offset() + getMemorySizeInBytes(lastTypeOffset.type());
if (padded) {
return paddedSize(baseSize, getAlignment(type));
}
return baseSize;
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public int getAlignment(Type type) {
if (type instanceof BooleanType || type instanceof IntegerType || type instanceof FloatType) {
return getMemorySizeInBytes(type);
}
if (type instanceof ArrayType arrayType) {
return getMemorySizeInBytes(arrayType.getElementType());
}
if (type instanceof AggregateType aType) {
return aType.getTypeOffsets().stream().map(o -> getAlignment(o.type())).max(Integer::compare).orElseThrow();
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public static int paddedSize(int size, int alignment) {
int mod = size % alignment;
if (mod > 0) {
return size + alignment - mod;
}
return size;
}

public int getMemorySizeInBits(Type type) {
Expand All @@ -119,16 +207,13 @@ public Map<Integer, Type> decomposeIntoPrimitives(Type type) {
}
}
} else if (type instanceof AggregateType aggregateType) {
final List<Type> fields = aggregateType.getDirectFields();
for (int i = 0; i < fields.size(); i++) {
final int offset = getOffsetInBytes(aggregateType, i);
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(fields.get(i));
for (TypeOffset typeOffset : aggregateType.getTypeOffsets()) {
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(typeOffset.type());
if (innerDecomposition == null) {
return null;
}

for (Map.Entry<Integer, Type> entry : innerDecomposition.entrySet()) {
decomposition.put(entry.getKey() + offset, entry.getValue());
decomposition.put(typeOffset.offset() + entry.getKey(), entry.getValue());
}
}
} else {
Expand All @@ -147,12 +232,7 @@ public static boolean isStaticType(Type type) {
return aType.hasKnownNumElements() && isStaticType(aType.getElementType());
}
if (type instanceof AggregateType aType) {
for (Type elType : aType.getDirectFields()) {
if (!isStaticType(elType)) {
return false;
}
}
return true;
return aType.getTypeOffsets().stream().allMatch(o -> isStaticType(o.type()));
}
throw new UnsupportedOperationException("Cannot compute if type '" + type + "' is static");
}
Expand All @@ -162,12 +242,15 @@ public static boolean isStaticTypeOf(Type staticType, Type runtimeType) {
return true;
}
if (staticType instanceof AggregateType aStaticType && runtimeType instanceof AggregateType aRuntimeType) {
int size = aStaticType.getDirectFields().size();
if (size != aRuntimeType.getDirectFields().size()) {
int size = aStaticType.getTypeOffsets().size();
if (size != aRuntimeType.getTypeOffsets().size()) {
return false;
}
for (int i = 0; i < size; i++) {
if (!isStaticTypeOf(aStaticType.getDirectFields().get(i), aRuntimeType.getDirectFields().get(i))) {
TypeOffset staticTypeOffset = aStaticType.getTypeOffsets().get(i);
TypeOffset runtimeTypeOffset = aRuntimeType.getTypeOffsets().get(i);
if (staticTypeOffset.offset() != runtimeTypeOffset.offset()
|| !isStaticTypeOf(staticTypeOffset.type(), runtimeTypeOffset.type())) {
return false;
}
}
Expand Down

This file was deleted.

Loading
Loading