Skip to content

Commit

Permalink
Add support to inject lemmas into commutativity checks (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
bkragl authored Mar 2, 2020
1 parent 282f930 commit 9ae2a77
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 215 deletions.
32 changes: 13 additions & 19 deletions Source/Concurrency/CivlTypeChecker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ public class CivlTypeChecker
public Dictionary<Procedure, AtomicAction> procToIsAbstraction;
public Dictionary<Procedure, YieldingProc> procToYieldingProc;
public Dictionary<Procedure, IntroductionProc> procToIntroductionProc;
public Dictionary<Tuple<AtomicAction, AtomicAction>,
List<WitnessFunction>> atomicActionPairToWitnessFunctions;
public CommutativityHints commutativityHints;

public List<InductiveSequentialization> inductiveSequentializations;

Expand Down Expand Up @@ -55,8 +54,6 @@ public CivlTypeChecker(Program program)
this.procToYieldingProc = new Dictionary<Procedure, YieldingProc>();
this.procToIntroductionProc = new Dictionary<Procedure, IntroductionProc>();
this.implToPendingAsyncCollector = new Dictionary<Implementation, Variable>();
this.atomicActionPairToWitnessFunctions = new Dictionary<Tuple<AtomicAction, AtomicAction>,
List<WitnessFunction>>();
this.inductiveSequentializations = new List<InductiveSequentialization>();
}

Expand Down Expand Up @@ -91,7 +88,7 @@ public void TypeCheck()

TypeCheckRefinementLayers();

TypeCheckWitnessFunctions();
TypeCheckCommutativityHints();

AttributeEraser.Erase(this);
}
Expand Down Expand Up @@ -129,21 +126,11 @@ private void TypeCheckRefinementLayers()

}

private void TypeCheckWitnessFunctions()
private void TypeCheckCommutativityHints()
{
WitnessFunctionVisitor wfv = new WitnessFunctionVisitor(this);
wfv.VisitFunctions();
foreach (var witnessFunction in wfv.allWitnessFunctions)
{
var key = Tuple.Create(
witnessFunction.firstAction,
witnessFunction.secondAction);
if (!atomicActionPairToWitnessFunctions.ContainsKey(key))
{
atomicActionPairToWitnessFunctions[key] = new List<WitnessFunction>();
}
atomicActionPairToWitnessFunctions[key].Add(witnessFunction);
}
CommutativityHintVisitor visitor = new CommutativityHintVisitor(this);
visitor.VisitFunctions();
this.commutativityHints = visitor.commutativityHints;
}

private void TypeCheckGlobalVariables()
Expand Down Expand Up @@ -893,6 +880,13 @@ public AtomicAction FindIsAbstraction(string name)
return procToIsAbstraction.Values.FirstOrDefault(a => a.proc.Name == name);
}

public AtomicAction FindAtomicActionOrAbstraction(string name)
{
var action = FindAtomicAction(name);
if (action != null) return action;
return FindIsAbstraction(name);
}

public IEnumerable<AtomicAction> AllActions =>
procToAtomicAction.Union(procToIsInvariant).Union(procToIsAbstraction)
.Select(x => x.Value);
Expand Down
259 changes: 259 additions & 0 deletions Source/Concurrency/CommutativityHints.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Microsoft.Boogie
{
public class CommutativityHint
{
public readonly Function function;
public readonly AtomicAction firstAction;
public readonly AtomicAction secondAction;
public List<Expr> args;

public CommutativityHint(Function function,
AtomicAction firstAction, AtomicAction secondAction, List<Expr> args)
{
this.function = function;
this.firstAction = firstAction;
this.secondAction = secondAction;
this.args = args;
}
}

public class CommutativityWitness : CommutativityHint
{
public readonly Variable witnessedVariable;

public CommutativityWitness(Function function, Variable witnessedVariable,
AtomicAction firstAction, AtomicAction secondAction, List<Expr> args)
: base(function, firstAction, secondAction, args)
{
this.witnessedVariable = witnessedVariable;
}

public CommutativityWitness(Variable witnessedVariable, CommutativityHint h)
: this(h.function, witnessedVariable, h.firstAction, h.secondAction, h.args) {}
}

public class CommutativityHints
{
private Dictionary<Tuple<AtomicAction, AtomicAction>, List<CommutativityWitness>> witnesses;
private Dictionary<Tuple<AtomicAction, AtomicAction>, List<CommutativityHint>> lemmas;

public CommutativityHints()
{
witnesses = new Dictionary<Tuple<AtomicAction, AtomicAction>, List<CommutativityWitness>>();
lemmas = new Dictionary<Tuple<AtomicAction, AtomicAction>, List<CommutativityHint>>();
}

private Tuple<AtomicAction, AtomicAction> Key(AtomicAction first, AtomicAction second)
{
return Tuple.Create(first, second);
}

public void AddWitness(CommutativityWitness witness)
{
var key = Key(witness.firstAction, witness.secondAction);
if (!witnesses.ContainsKey(key))
{
witnesses[key] = new List<CommutativityWitness>();
}
witnesses[key].Add(witness);
}

public void AddLemma(CommutativityHint lemma)
{
var key = Key(lemma.firstAction, lemma.secondAction);
if (!lemmas.ContainsKey(key))
{
lemmas[key] = new List<CommutativityHint>();
}
lemmas[key].Add(lemma);
}

public IEnumerable<CommutativityWitness> GetWitnesses(AtomicAction first, AtomicAction second)
{
witnesses.TryGetValue(Key(first, second), out List<CommutativityWitness> list);
if (list == null) return Enumerable.Empty<CommutativityWitness>();
return list;
}

public IEnumerable<CommutativityHint> GetLemmas(AtomicAction first, AtomicAction second)
{
lemmas.TryGetValue(Key(first, second), out List<CommutativityHint> list);
if (list == null) return Enumerable.Empty<CommutativityHint>();
return list;
}
}

class CommutativityHintVisitor
{
private const string FirstProcInputPrefix = "first_";
private const string SecondProcInputPrefix = "second_";
private const string PostStateSuffix = "'";

private readonly CivlTypeChecker ctc;
public CommutativityHints commutativityHints;

private AtomicAction firstAction;
private AtomicAction secondAction;
private List<Expr> args;

public CommutativityHintVisitor(CivlTypeChecker ctc)
{
this.ctc = ctc;
commutativityHints = new CommutativityHints();
}

public void VisitFunctions()
{
foreach (var f in ctc.program.Functions)
{
VisitFunction(f);
}
}

private void VisitFunction(Function function)
{
Debug.Assert(function.OutParams.Count == 1);
List<CommutativityHint> hints = new List<CommutativityHint>();
// First we collect all {:commutativity "first_action", "second_action"} attributes
for (QKeyValue kv = function.Attributes; kv != null; kv = kv.Next)
{
if (kv.Key != CivlAttributes.COMMUTATIVITY)
continue;
if (kv.Params.Count == 2 &&
kv.Params[0] is string firstActionName &&
kv.Params[1] is string secondActionName)
{
firstAction = ctc.FindAtomicActionOrAbstraction(firstActionName);
secondAction = ctc.FindAtomicActionOrAbstraction(secondActionName);
if (firstAction == null)
{
ctc.Error(kv, $"Could not find atomic action {firstActionName}");
}
if (secondAction == null)
{
ctc.Error(kv, $"Could not find atomic action {secondActionName}");
}
if (firstAction != null && secondAction != null)
{
CheckInParams(function.InParams);
}
hints.Add(new CommutativityHint(function, firstAction, secondAction, args));
}
else
{
ctc.Error(kv, "Commutativity attribute expects two action names as parameters");
}
}

// And then we look for either {:witness "globalVar"} or {:lemma}
for (QKeyValue kv = function.Attributes; kv != null; kv = kv.Next)
{
if (kv.Key == CivlAttributes.WITNESS)
{
if (kv.Params.Count == 1 &&
kv.Params[0] is string witnessedVariableName)
{
Variable witnessedVariable = ctc.sharedVariables.Find(v => v.Name == witnessedVariableName);
if (witnessedVariable == null)
{
ctc.Error(kv, $"Could not find shared variable {witnessedVariableName}");
}
else if (!function.OutParams[0].TypedIdent.Type.Equals(witnessedVariable.TypedIdent.Type))
{
ctc.Error(function, "Result type does not match witnessed variable");
}
else
{
hints.ForEach(h => commutativityHints.AddWitness(new CommutativityWitness(witnessedVariable, h)));
}
}
else
{
ctc.Error(kv, "Witness attribute expects the name of a global variable as parameter");
}
break;
}
else if (kv.Key == CivlAttributes.LEMMA)
{
if (kv.Params.Count == 0)
{
if (!function.OutParams[0].TypedIdent.Type.Equals(Type.Bool))
{
ctc.Error(function, "Result type of lemma must be bool");
}
else
{
hints.ForEach(h => commutativityHints.AddLemma(h));
}
}
else
{
ctc.Error(kv, "Lemma attribute does not expect any parameters");
}
break;
}
}
}

private void CheckInParams(List<Variable> inParams)
{
args = new List<Expr>();
foreach (var param in inParams)
{
Expr arg = null;
if (param.Name.StartsWith(FirstProcInputPrefix, StringComparison.Ordinal))
{
arg = CheckLocal(param, firstAction.firstImpl, FirstProcInputPrefix);
}
else if (param.Name.StartsWith(SecondProcInputPrefix, StringComparison.Ordinal))
{
arg = CheckLocal(param, secondAction.secondImpl, SecondProcInputPrefix);
}
else
{
arg = CheckGlobal(param);
}
args.Add(arg);
}
}

private Expr CheckLocal(Variable param, Implementation impl, string prefix)
{
var var = FindVariable(param.Name, param.TypedIdent.Type,
impl.InParams.Union(impl.OutParams));
if (var != null)
return Expr.Ident(var);
var name = param.Name.Remove(0, prefix.Length);
ctc.Error(param, $"Action {impl.Name} does not have parameter {name}:{param.TypedIdent.Type}");
return null;
}

private Expr CheckGlobal(Variable param)
{
bool postState = param.Name.EndsWith(PostStateSuffix, StringComparison.Ordinal);
var name = param.Name;
if (postState)
name = name.Substring(0, name.Length - 1);
var var = FindVariable(name, param.TypedIdent.Type, ctc.sharedVariables);
if (var != null)
{
if (!postState)
return ExprHelper.Old(Expr.Ident(var));
else
return Expr.Ident(var);
}
ctc.Error(param, $"No shared variable {name}:{param.TypedIdent.Type}");
return null;
}

private Variable FindVariable(string name, Type type, IEnumerable<Variable> vars)
{
return vars.FirstOrDefault(i => i.Name == name && i.TypedIdent.Type.Equals(type));
}
}
}
2 changes: 1 addition & 1 deletion Source/Concurrency/Concurrency.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
<Compile Include="YieldingProcDuplicator.cs" />
<Compile Include="YieldTypeChecker.cs" />
<Compile Include="CivlUtil.cs" />
<Compile Include="Witnesses.cs" />
<Compile Include="CommutativityHints.cs" />
<Compile Include="CivlCoreTypes.cs" />
<Compile Include="InductiveSequentialization.cs" />
<Compile Include="PendingAsyncChecker.cs" />
Expand Down
7 changes: 5 additions & 2 deletions Source/Concurrency/MoverCheck.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ private void CreateCommutativityChecker(AtomicAction first, AtomicAction second)
foreach (AssertCmd assertCmd in Enumerable.Union(first.firstGate, second.secondGate))
requires.Add(new Requires(false, assertCmd.Expr));

civlTypeChecker.atomicActionPairToWitnessFunctions.TryGetValue(
Tuple.Create(first, second), out List<WitnessFunction> witnesses);
var witnesses = civlTypeChecker.commutativityHints.GetWitnesses(first, second);
var transitionRelation = TransitionRelationComputation.
Commutativity(second, first, frame, witnesses);

Expand All @@ -165,6 +164,10 @@ private void CreateCommutativityChecker(AtomicAction first, AtomicAction second)
second.secondImpl.OutParams.Select(Expr.Ident).ToList()
) { Proc = second.proc }
};
foreach (var lemma in civlTypeChecker.commutativityHints.GetLemmas(first,second))
{
cmds.Add(CmdHelper.AssumeCmd(ExprHelper.FunctionCall(lemma.function, lemma.args.ToArray())));
}
var block = new Block(Token.NoToken, "init", cmds, new ReturnCmd(Token.NoToken));

var secondInParamsFiltered = second.secondImpl.InParams.Where(v => linearTypeChecker.FindLinearKind(v) != LinearKind.LINEAR_IN);
Expand Down
Loading

0 comments on commit 9ae2a77

Please sign in to comment.