Skip to content

Commit

Permalink
Add support for cttz (count trailing zeros) intrinsic (#689)
Browse files Browse the repository at this point in the history
Signed-off-by: Hernan Ponce de Leon <[email protected]>
  • Loading branch information
hernanponcedeleon authored Jun 7, 2024
1 parent 28885e2 commit 929b15b
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 4 deletions.
22 changes: 22 additions & 0 deletions benchmarks/c/miscellaneous/cttz.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <stdint.h>
#include <assert.h>

volatile int32_t x = INT32_MAX+1;
volatile int32_t y = INT32_MAX-1;
volatile int32_t z = 1;
volatile int32_t u;

int main()
{
// x = 1000 0000 0000 0000 0000 0000 0000 0000
u = __builtin_ctz(x);
assert(u == 31);
// y = 1111 1111 1111 1111 1111 1111 1111 1110
u = __builtin_ctz(y);
assert(u == 1);
// z = 0000 0000 0000 0000 0000 0000 0000 0001
u = __builtin_ctz(z);
assert(u == 0);

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,33 @@ public Formula visitIntUnaryExpression(IntUnaryExpr iUn) {
case CTLZ -> {
if (inner instanceof BitvectorFormula bv) {
BitvectorFormulaManager bvmgr = bitvectorFormulaManager();
// enc = extract(bv, 63, 63) == 1 ? 0 : (extract(bv, 62, 62) == 1 ? 1 : extract ... extract(bv, 0, 0) ? 63 : 64)
// enc = extract(bv, 63, 63) == 1 ? 0 : (extract(bv, 62, 62) == 1 ? 1 : extract ... extract(bv, 0, 0) == 1 ? 63 : 64)
int bvLength = bvmgr.getLength(bv);
BitvectorFormula bv1 = bvmgr.makeBitvector(1, 1);
BitvectorFormula enc = bvmgr.makeBitvector(bvLength, bvLength);
for(int i = bvmgr.getLength(bv) - 1; i >= 0; i--) {
for(int i = bvLength - 1; i >= 0; i--) {
BitvectorFormula bvi = bvmgr.makeBitvector(bvLength, i);
BitvectorFormula bvbit = bvmgr.extract(bv, bvLength - (i + 1), bvLength - (i + 1));
enc = booleanFormulaManager.ifThenElse(bvmgr.equal(bvbit, bv1), bvi, enc);
}
return enc;
}
}
case CTTZ -> {
if (inner instanceof BitvectorFormula bv) {
BitvectorFormulaManager bvmgr = bitvectorFormulaManager();
// enc = extract(bv, 0, 0) == 1 ? 0 : (extract(bv, 1, 1) == 1 ? 1 : extract ... extract(bv, 63, 63) == 1? 63 : 64)
int bvLength = bvmgr.getLength(bv);
BitvectorFormula bv1 = bvmgr.makeBitvector(1, 1);
BitvectorFormula enc = bvmgr.makeBitvector(bvLength, bvLength);
for(int i = bvLength - 1; i >= 0; i--) {
BitvectorFormula bvi = bvmgr.makeBitvector(bvLength, i);
BitvectorFormula bvbit = bvmgr.extract(bv, i, i);
enc = booleanFormulaManager.ifThenElse(bvmgr.equal(bvbit, bv1), bvi, enc);
}
return enc;
}
}
}
throw new UnsupportedOperationException(
String.format("Encoding of (%s) %s %s not supported.", iUn.getType(), iUn.getKind(), inner));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ public Expression makeCTLZ(Expression operand) {
return makeIntUnary(IntUnaryOp.CTLZ, operand);
}

public Expression makeCTTZ(Expression operand) {
return makeIntUnary(IntUnaryOp.CTTZ, operand);
}

public Expression makeAdd(Expression leftOperand, Expression rightOperand) {
return makeIntBinary(leftOperand, IntBinaryOp.ADD, rightOperand);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.dat3m.dartagnan.expression.ExpressionKind;

public enum IntUnaryOp implements ExpressionKind {
CTLZ, MINUS;
CTLZ, CTTZ, MINUS;

@Override
public String toString() {
Expand All @@ -14,6 +14,7 @@ public String toString() {
public String getSymbol() {
return switch (this) {
case CTLZ -> "ctlz ";
case CTTZ -> "cttz ";
case MINUS -> "-";
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ public Expression visitIntUnaryExpression(IntUnaryExpr expr) {
final BigInteger newValue = switch (expr.getKind()) {
case MINUS -> IntegerHelper.neg(lit.getValue(), bitWidth);
case CTLZ -> IntegerHelper.ctlz(lit.getValue(), bitWidth);
case CTTZ -> IntegerHelper.cttz(lit.getValue(), bitWidth);
};
return expressions.makeValue(newValue, expr.getType());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,16 @@ public static BigInteger ctlz(BigInteger x, int bitWidth) {
return BigInteger.valueOf(leadingZeroes);
}

public static BigInteger cttz(BigInteger x, int bitWidth) {
int trailingZeroes = 0;
for (int i = 0; i < bitWidth; i++) {
if (!x.testBit(i)) {
trailingZeroes++;
} else {
break;
}
}
return BigInteger.valueOf(trailingZeroes);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public enum Info {
LLVM(List.of("llvm.smax", "llvm.umax", "llvm.smin", "llvm.umin",
"llvm.ssub.sat", "llvm.usub.sat", "llvm.sadd.sat", "llvm.uadd.sat", // TODO: saturated shifts
"llvm.sadd.with.overflow", "llvm.ssub.with.overflow", "llvm.smul.with.overflow",
"llvm.ctlz", "llvm.ctpop"),
"llvm.ctlz", "llvm.cttz", "llvm.ctpop"),
false, false, true, true, Intrinsics::handleLLVMIntrinsic),
LLVM_ASSUME("llvm.assume", false, false, true, true, Intrinsics::inlineLLVMAssume),
LLVM_META(List.of("llvm.stacksave", "llvm.stackrestore", "llvm.lifetime"), false, false, true, true, Intrinsics::inlineAsZero),
Expand Down Expand Up @@ -930,6 +930,8 @@ private List<Event> handleLLVMIntrinsic(FunctionCall call) {

if (name.startsWith("llvm.ctlz")) {
return inlineLLVMCtlz(valueCall);
} else if (name.startsWith("llvm.cttz")) {
return inlineLLVMCtlz(valueCall);
} else if (name.startsWith("llvm.ctpop")) {
return inlineLLVMCtpop(valueCall);
} else if (name.contains("add.sat")) {
Expand Down Expand Up @@ -981,6 +983,23 @@ private List<Event> inlineLLVMCtlz(ValueFunctionCall call) {
return List.of(assignment);
}

private List<Event> inlineLLVMCttz(ValueFunctionCall call) {
//see https://llvm.org/docs/LangRef.html#llvm-cttz-intrinsic
checkArgument(call.getArguments().size() == 2,
"Expected 2 parameters for \"llvm.cttz\", got %s.", call.getArguments().size());
final Expression input = call.getArguments().get(0);
// TODO: Handle the second parameter as well
final Register resultReg = call.getResultRegister();
final Type type = resultReg.getType();
checkArgument(resultReg.getType() instanceof IntegerType,
"Non-integer %s type for \"llvm.cttz\".", type);
checkArgument(input.getType().equals(type),
"Return type %s of \"llvm.cttz\" must match argument type %s.", type, input.getType());
final Expression resultExpression = expressions.makeCTTZ(input);
final Event assignment = EventFactory.newLocal(resultReg, resultExpression);
return List.of(assignment);
}

private List<Event> inlineLLVMCtpop(ValueFunctionCall call) {
//see https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic
final Expression input = call.getArguments().get(0);
Expand Down

0 comments on commit 929b15b

Please sign in to comment.