Skip to content

Commit

Permalink
Devirtualise pthread_create (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHaas authored Dec 18, 2023
1 parent f924058 commit 340fb3b
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 65 deletions.
38 changes: 38 additions & 0 deletions benchmarks/c/miscellaneous/funcPtrInStaticMemory.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <stdio.h>
#include <assert.h>
#include <pthread.h>

/*
The test checks if function pointers in static memory are handled correctly.
Expected result: PASS
*/

/* Signature shared by both struct members: takes and returns a void pointer. */
typedef void *(*MyFuncPtr)(void *);

/* Pair of function pointers; the benchmark keeps one instance in static memory. */
typedef struct {
    MyFuncPtr funcPtrOne;
    MyFuncPtr funcPtrTwo;
} MyPtrStruct;

/* Number of times myFunc2 has run; main expects 2 (one direct call, one via the thread). */
int callCounter = 0;

/*
 * First call target stored in the struct.
 * Asserts that it received the sentinel value 1 (smuggled through the
 * void* parameter) and always returns NULL.
 */
void *myFunc1(void* arg) {
    /* Compare pointer with pointer: a bare `arg == 1` mixes pointer and integer. */
    assert (arg == (void*)1);
    return NULL;
}

/*
 * Second call target stored in the struct; also used as the thread routine.
 * Asserts that the argument is one of the two sentinels used by main
 * (42 for the direct call, 123 for the thread), bumps callCounter and
 * returns the argument unchanged.
 */
void *myFunc2(void* arg) {
    /* Cast the sentinels so the comparison is pointer-to-pointer. */
    assert (arg == (void*)42 || arg == (void*)123);
    callCounter++;
    return arg;
}

/* Statically initialised instance: funcPtrOne -> myFunc1, funcPtrTwo -> myFunc2. */
MyPtrStruct myStruct = { myFunc1, myFunc2 };

/*
 * Calls both function pointers stored in static memory directly, then
 * passes the second one to pthread_create and joins the thread.
 * myFunc2 therefore runs twice, which the final assertion checks.
 */
int main () {
    /* Integer sentinels are cast explicitly: passing a bare int where a
       void* is expected is an int-conversion error on current compilers. */
    assert(myStruct.funcPtrOne((void*)1) == NULL);
    assert(myStruct.funcPtrTwo((void*)42) == (void*)42);

    pthread_t t;
    pthread_create(&t, NULL, myStruct.funcPtrTwo, (void*)123);
    pthread_join(t, NULL);
    /* The join orders the thread's increment before this read. */
    assert (callCounter == 2);
    return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
public abstract class FunctionCall extends AbstractEvent implements RegReader {

protected FunctionType funcType;
protected Expression callTarget; // TODO: Generalize to function pointer expressions
protected Expression callTarget;
protected List<Expression> arguments;

protected FunctionCall(FunctionType funcType, Expression funcPtr, List<Expression> arguments) {
Expand Down Expand Up @@ -54,6 +54,21 @@ protected FunctionCall(FunctionCall other) {
/** Returns the expression currently stored as the target of this call. */
public Expression getCallTarget() {
    return this.callTarget;
}
/** Returns the argument list of this call (the stored list itself, not a copy). */
public List<Expression> getArguments() {
    return this.arguments;
}

/**
 * Replaces a single argument of this call.
 *
 * @param index    position of the argument to overwrite
 * @param argument the new argument expression
 */
public void setArgument(int index, Expression argument) {
    this.arguments.set(index, argument);
}

/**
 * Replaces the target of this call.
 * <p>
 * If the new target is a {@code Function}, its function type must be the very same
 * object as this call's function type; any other kind of expression is accepted unchecked.
 *
 * @param callTarget the new call target
 * @throws IllegalArgumentException if a {@code Function} with a different function type is given
 */
public void setCallTarget(Expression callTarget) {
    // Only concrete functions carry a type we can validate here.
    final boolean typeIsAcceptable =
            !(callTarget instanceof Function func) || func.getFunctionType() == funcType;
    Preconditions.checkArgument(typeIsAcceptable,
            "Call target %s has mismatching function type: expected %s", callTarget, funcType);
    this.callTarget = callTarget;
}

// Re-declared abstract with the return type narrowed to FunctionCall, so callers
// holding a FunctionCall get a FunctionCall back without casting.
// NOTE(review): presumably every concrete call class returns a copy of its own type - confirm.
@Override
public abstract FunctionCall getCopy();

@Override
public Set<Register.Read> getRegisterReads() {
final Set<Register.Read> regReads = new HashSet<>();
Expand All @@ -70,5 +85,4 @@ public void transformExpressions(ExpressionVisitor<? extends Expression> exprTra
callTarget = callTarget.accept(exprTransformer);
arguments.replaceAll(expression -> expression.accept(exprTransformer));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,8 @@ public void setResultRegister(Register reg) {

/**
 * Renders this call as {@code <result> <- call <target>(<arguments>)}.
 * A direct call prints the name of the called function; otherwise the
 * call-target expression itself is printed.
 */
@Override
protected String defaultString() {
    // The former if/else returned on both paths, which left these statements unreachable.
    final Object target = isDirectCall() ? ((Function) callTarget).getName() : callTarget;
    return String.format("%s <- call %s(%s)", resultRegister, target, super.argumentsToString());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ protected VoidFunctionCall(VoidFunctionCall other) {

/**
 * Renders this call as {@code call <target>(<arguments>)}.
 * A direct call prints the name of the called function; otherwise the
 * call-target expression itself is printed.
 */
@Override
protected String defaultString() {
    // The former if/else returned on both paths, which left these statements unreachable.
    final Object target = isDirectCall() ? ((Function) callTarget).getName() : callTarget;
    return String.format("call %s(%s)", target, super.argumentsToString());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.dat3m.dartagnan.expression.processing.ExpressionInspector;
import com.dat3m.dartagnan.expression.processing.ExpressionVisitor;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.Type;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
Expand All @@ -16,9 +17,9 @@
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.utils.RegReader;
import com.dat3m.dartagnan.program.event.functions.FunctionCall;
import com.dat3m.dartagnan.program.event.functions.ValueFunctionCall;
import com.dat3m.dartagnan.program.memory.Memory;
import com.dat3m.dartagnan.program.memory.MemoryObject;
import com.google.common.base.Verify;
import com.google.common.collect.Iterables;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -35,6 +36,11 @@
- All non-standard uses are replaced by their address values.
- Every indirect call is replaced by a switch statement over all registered functions that have a matching type.
Each case of the switch statement contains a direct call to the corresponding function.
TODO: We also need to devirtualize intrinsic functions that expect function pointers.
For now, we only devirtualize pthread_create based on its third parameter.
More generally, we could extend IntrinsicInfo to tell for each intrinsic which parameters are function pointers and
need devirtualization.
*/
public class NaiveDevirtualisation implements ProgramProcessor {

Expand Down Expand Up @@ -111,8 +117,7 @@ private boolean assignAddressToFunction(Function func, Map<Function, IValue> fun

private void applyTransformerToEvent(Event e, ExpressionVisitor<Expression> transformer) {
if (e instanceof FunctionCall call) {
if (call.isDirectCall() && call.getCalledFunction().isIntrinsic()
&& call.getCalledFunction().getName().contains("pthread_create")) {
if (call.isDirectCall() && call.getCalledFunction().getIntrinsicInfo() == Intrinsics.Info.P_THREAD_CREATE) {
// We avoid transforming functions passed as call target to pthread_create
// However, we still collect the last argument of the call, because it
// is the argument passed to the created thread (which might be a pointer to a function).
Expand All @@ -131,61 +136,112 @@ private void devirtualise(Function function, Map<Function, IValue> func2AddressM

int devirtCounter = 0;
for (FunctionCall call : function.getEvents(FunctionCall.class)) {
if (!call.isDirectCall()) {
final List<Function> possibleTargets = func2AddressMap.keySet().stream()
.filter(f -> f.getFunctionType() == call.getCallType()).collect(Collectors.toList());

// FIXME: Here we remove the calling function itself so as to avoid trivial recursion.
// However, indirect/mutual recursion is not prevented by this!
if (possibleTargets.removeIf(f -> f == function)) {
logger.warn("Found potentially recursive dynamic call \"{}\". " +
"Dartagnan (unsoundly) assumes the recursive call cannot happen.", call);
}

if (possibleTargets.isEmpty()) {
logger.warn("Cannot resolve dynamic call \"{}\", no matching functions found.", call);
}
if (!needsDevirtualization(call)) {
continue;
}

logger.trace("Devirtualizing call \"{}\" with possible targets: {}", call, possibleTargets);

final List<Label> caseLabels = new ArrayList<>(possibleTargets.size());
final List<CondJump> caseJumps = new ArrayList<>(possibleTargets.size());
final Expression funcPtr = call.getCallTarget();
// Construct call table
for (Function possibleTarget : possibleTargets) {
final IValue targetAddress = func2AddressMap.get(possibleTarget);
final Label caseLabel = EventFactory.newLabel(String.format("__Ldevirt_%s#%s", targetAddress.getValue(), devirtCounter));
final CondJump caseJump = EventFactory.newJump(expressions.makeEQ(funcPtr, targetAddress), caseLabel);
caseLabels.add(caseLabel);
caseJumps.add(caseJump);
}
final List<Function> possibleTargets = getPossibleTargets(call, func2AddressMap);
// FIXME: Here we remove the calling function itself so as to avoid trivial recursion.
// However, indirect/mutual recursion is not prevented by this!
if (possibleTargets.removeIf(f -> f == function)) {
logger.warn("Found potentially recursive dynamic call \"{}\". " +
"Dartagnan (unsoundly) assumes the recursive call cannot happen.", call);
}

final Event noMatch = EventFactory.newAssert(expressions.makeFalse(), "Invalid function pointer");
final Label endLabel = EventFactory.newLabel(String.format("__Ldevirt_end#%s", devirtCounter));

final List<Event> callReplacement = new ArrayList<>();
callReplacement.add(EventFactory.newStringAnnotation("=== Devirtualized call ==="));
callReplacement.addAll(caseJumps);
callReplacement.add(noMatch);
for (int i = 0; i < caseLabels.size(); i++) {
callReplacement.add(caseLabels.get(i));
if (call instanceof ValueFunctionCall valueCall) {
callReplacement.add(EventFactory.newValueFunctionCall(valueCall.getResultRegister(),
possibleTargets.get(i), call.getArguments()));
} else {
callReplacement.add(EventFactory.newVoidFunctionCall(possibleTargets.get(i), call.getArguments()));
}
callReplacement.add(EventFactory.newGoto(endLabel));
}
callReplacement.add(endLabel);
callReplacement.add(EventFactory.newStringAnnotation("=== End of devirtualized call ==="));
if (possibleTargets.isEmpty()) {
logger.warn("Cannot resolve dynamic call \"{}\", no matching functions found.", call);
}

call.replaceBy(callReplacement);
callReplacement.forEach(e -> e.copyAllMetadataFrom(call));
logger.trace("Devirtualizing call \"{}\" with possible targets: {}", call, possibleTargets);

final List<Label> caseLabels = new ArrayList<>(possibleTargets.size());
final List<CondJump> caseJumps = new ArrayList<>(possibleTargets.size());
final Expression funcPtr = getFunctionPointer(call);
// Construct call table
for (Function possibleTarget : possibleTargets) {
final IValue targetAddress = func2AddressMap.get(possibleTarget);
final Label caseLabel = EventFactory.newLabel(String.format("__Ldevirt_%s#%s", targetAddress.getValue(), devirtCounter));
final CondJump caseJump = EventFactory.newJump(expressions.makeEQ(funcPtr, targetAddress), caseLabel);
caseLabels.add(caseLabel);
caseJumps.add(caseJump);
}

devirtCounter++;
final Event noMatch = EventFactory.newAssert(expressions.makeFalse(), "Invalid function pointer");
final Label endLabel = EventFactory.newLabel(String.format("__Ldevirt_end#%s", devirtCounter));

final List<Event> callReplacement = new ArrayList<>();
callReplacement.add(EventFactory.newStringAnnotation("=== Devirtualized call ==="));
callReplacement.addAll(caseJumps);
callReplacement.add(noMatch);
for (int i = 0; i < caseLabels.size(); i++) {
callReplacement.add(caseLabels.get(i));
callReplacement.add(devirtualiseCall(call, possibleTargets.get(i)));
callReplacement.add(EventFactory.newGoto(endLabel));
}
callReplacement.add(endLabel);
callReplacement.add(EventFactory.newStringAnnotation("=== End of devirtualized call ==="));

call.replaceBy(callReplacement);
callReplacement.forEach(e -> e.copyAllMetadataFrom(call));

devirtCounter++;
}
}

/**
 * Decides whether a call has to be rewritten into a call table.
 * This is the case for every indirect call, and for direct calls to the
 * pthread_create intrinsic whose third argument is not already a {@code Function}.
 */
private boolean needsDevirtualization(FunctionCall call) {
    if (!call.isDirectCall()) {
        return true;
    }
    final boolean createsThread =
            call.getCalledFunction().getIntrinsicInfo() == Intrinsics.Info.P_THREAD_CREATE;
    // Index 2 is the argument read as the function pointer for pthread_create.
    return createsThread && !(call.getArguments().get(2) instanceof Function);
}

/**
 * Collects all functions with an assigned address that the given call could resolve to.
 * <ul>
 *   <li>Indirect call: every function whose type is the call's type.</li>
 *   <li>Direct call to pthread_create: every function of type {@code ptr -> ptr}.</li>
 *   <li>Anything else is reported via {@code throwInternalError}.</li>
 * </ul>
 *
 * @param call            the call to resolve
 * @param func2AddressMap functions that have been assigned an address
 * @return list of candidate targets (possibly empty)
 */
private List<Function> getPossibleTargets(FunctionCall call, Map<Function, IValue> func2AddressMap) {
    if (!call.isDirectCall()) {
        // Collectors.toList() is kept on purpose: the caller removes elements from the result.
        return func2AddressMap.keySet().stream()
                .filter(candidate -> candidate.getFunctionType() == call.getCallType())
                .collect(Collectors.toList());
    }

    if (call.getCalledFunction().getIntrinsicInfo() != Intrinsics.Info.P_THREAD_CREATE) {
        throwInternalError(call);
        return List.of();
    }

    // The function-pointer argument of pthread_create is matched against type ptr -> ptr.
    final TypeFactory typeFactory = TypeFactory.getInstance();
    final Type pointerType = typeFactory.getPointerType();
    final Type threadRoutineType = typeFactory.getFunctionType(pointerType, List.of(pointerType));
    return func2AddressMap.keySet().stream()
            .filter(candidate -> candidate.getFunctionType() == threadRoutineType)
            .collect(Collectors.toList());
}

/**
 * Produces the call for one case of the call table: a copy of the original
 * call whose function pointer is replaced by the given concrete function.
 *
 * @param virtCall         the call being devirtualised (left untouched)
 * @param devirtCallTarget the concrete function for this case
 * @return the specialised copy
 */
private FunctionCall devirtualiseCall(FunctionCall virtCall, Function devirtCallTarget) {
    final FunctionCall specialised = virtCall.getCopy();
    setFunctionPointer(specialised, devirtCallTarget);
    return specialised;
}

/**
 * Returns the expression of a call that holds the function pointer to dispatch on:
 * the call target for indirect calls, or the argument at index 2 for direct calls
 * to pthread_create. Any other call is reported via {@code throwInternalError}.
 */
private Expression getFunctionPointer(FunctionCall call) {
    if (call.isDirectCall()) {
        if (call.getCalledFunction().getIntrinsicInfo() != Intrinsics.Info.P_THREAD_CREATE) {
            throwInternalError(call);
            return null;
        }
        return call.getArguments().get(2);
    }
    return call.getCallTarget();
}

/**
 * Counterpart of {@code getFunctionPointer}: overwrites the function pointer of a call.
 * Indirect calls get a new call target; direct calls to pthread_create get a new
 * argument at index 2. Any other call is reported via {@code throwInternalError}.
 */
private void setFunctionPointer(FunctionCall call, Expression functionPtr) {
    if (!call.isDirectCall()) {
        call.setCallTarget(functionPtr);
        return;
    }
    if (call.getCalledFunction().getIntrinsicInfo() == Intrinsics.Info.P_THREAD_CREATE) {
        call.setArgument(2, functionPtr);
        return;
    }
    throwInternalError(call);
}

/**
 * Fails verification unconditionally for a call this pass does not know how to devirtualise.
 * The message includes the string form of the offending call.
 */
@SuppressWarnings("all")
private void throwInternalError(FunctionCall virtCall) {
    final String message = "Encountered unexpected virtual function call: " + virtCall;
    Verify.verify(false, message);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public static Iterable<Object[]> data() throws IOException {
{"thread_inlining_complex_2", IMM, PASS, 1},
{"thread_local", IMM, PASS, 1},
{"thread_loop", IMM, FAIL, 1},
{"thread_id", IMM, PASS, 1}
{"thread_id", IMM, PASS, 1},
{"funcPtrInStaticMemory", IMM, PASS, 1},
});
}

Expand Down
Loading

0 comments on commit 340fb3b

Please sign in to comment.