Skip to content

Commit

Permalink
feat: Make True and False branches unconditional (#740)
Browse files Browse the repository at this point in the history
For example, when building a CFG for `if True: ...`, build an
unconditional branch. This improves linearity checking for functions
involving `while True` loops, since dataflow analysis now understands
that such loops can only exit on an explicit `break` or `return`.

It also means that Guppy is now better at detecting unreachable code.
~For now, we error out in this case~. In the future it would be nice to
emit a warning when we detect this (see #739).

One complication introduced by this is that Guppy can now generate CFGs
where the exit block is unreachable. This led to some problems with
borrowed variables. This edge case now needs to be specially handled.

I recommend reviewing each of the three commits separately.
  • Loading branch information
mark-koch authored Jan 15, 2025
1 parent 2f5eed3 commit 748ea95
Show file tree
Hide file tree
Showing 26 changed files with 437 additions and 118 deletions.
40 changes: 36 additions & 4 deletions guppylang/cfg/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def eq(self, t1: T, t2: T, /) -> bool:
"""Equality on lattice values"""
return t1 == t2

@abstractmethod
def include_unreachable(self) -> bool:
    """Whether unreachable BBs and jumps should be taken into account for the
    analysis.

    The fixpoint iteration consults this flag: when it is false, the forward pass
    drops BBs whose `reachable` flag is unset before iterating, and only regular
    predecessor/successor edges are joined. When it is true, the dummy
    predecessors/successors of each BB are joined in as well.
    """

@abstractmethod
def initial(self) -> T:
"""Initial lattice value"""
Expand Down Expand Up @@ -46,12 +51,19 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
Returns a mapping from basic blocks to lattice values at the start of each BB.
"""
if not self.include_unreachable():
bbs = [bb for bb in bbs if bb.reachable]
vals_before = {bb: self.initial() for bb in bbs} # return value
vals_after = {bb: self.apply_bb(vals_before[bb], bb) for bb in bbs} # cache
queue = set(bbs)
while len(queue) > 0:
bb = queue.pop()
vals_before[bb] = self.join(*(vals_after[pred] for pred in bb.predecessors))
preds = (
bb.predecessors + bb.dummy_predecessors
if self.include_unreachable()
else bb.predecessors
)
vals_before[bb] = self.join(*(vals_after[pred] for pred in preds))
val_after = self.apply_bb(vals_before[bb], bb)
if not self.eq(val_after, vals_after[bb]):
vals_after[bb] = val_after
Expand All @@ -75,7 +87,12 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
queue = set(bbs)
while len(queue) > 0:
bb = queue.pop()
val_after = self.join(*(vals_before[succ] for succ in bb.successors))
succs = (
bb.successors + bb.dummy_successors
if self.include_unreachable()
else bb.successors
)
val_after = self.join(*(vals_before[succ] for succ in succs))
val_before = self.apply_bb(val_after, bb)
if not self.eq(vals_before[bb], val_before):
vals_before[bb] = val_before
Expand All @@ -97,16 +114,26 @@ class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):

stats: dict[BB, VariableStats[VId]]

def __init__(self, stats: dict[BB, VariableStats[VId]]) -> None:
def __init__(
self,
stats: dict[BB, VariableStats[VId]],
initial: LivenessDomain[VId] | None = None,
include_unreachable: bool = False,
) -> None:
self.stats = stats
self._initial = initial or {}
self._include_unreachable = include_unreachable

def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
# Only check that both contain the same variables. We don't care about the BB
# in which the use occurs, we just need any one, to report to the user.
return live1.keys() == live2.keys()

def initial(self) -> LivenessDomain[VId]:
return {}
return self._initial

def include_unreachable(self) -> bool:
    """Whether unreachable BBs and dummy jumps take part in the liveness analysis.

    Returns the `include_unreachable` flag that was passed to the constructor.
    """
    return self._include_unreachable

def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
res: LivenessDomain[VId] = {}
Expand Down Expand Up @@ -150,6 +177,7 @@ def __init__(
stats: dict[BB, VariableStats[VId]],
ass_before_entry: set[VId],
maybe_ass_before_entry: set[VId],
include_unreachable: bool = False,
) -> None:
"""Constructs an `AssignmentAnalysis` pass for a CFG.
Expand All @@ -164,12 +192,16 @@ def __init__(
set.union(*(set(stat.assigned.keys()) for stat in stats.values()))
| ass_before_entry
)
self._include_unreachable = include_unreachable

def initial(self) -> AssignmentDomain[VId]:
# Note that definite assignment must start with `all_vars` instead of only
# `ass_before_entry` since we want to compute the *greatest* fixpoint.
return self.all_vars, self.maybe_ass_before_entry

def include_unreachable(self) -> bool:
    """Whether unreachable BBs and dummy jumps take part in the assignment analysis.

    Returns the `include_unreachable` flag that was passed to the constructor.
    """
    return self._include_unreachable

def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
# We always include the variables that are definitely assigned before the entry,
# even if the join is empty
Expand Down
13 changes: 10 additions & 3 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ class BB(ABC):
predecessors: list[Self] = field(default_factory=list)
successors: list[Self] = field(default_factory=list)

# Whether this BB is reachable from the entry
reachable: bool = False

# Dummy predecessors and successors that correspond to branches that are provably
# never taken. For example, `if False: ...` statements emit only dummy control-flow
# links.
dummy_predecessors: list[Self] = field(default_factory=list)
dummy_successors: list[Self] = field(default_factory=list)

# If the BB has multiple successors, we need a predicate to decide which one to
# jump to
branch_pred: ast.expr | None = None
Expand Down Expand Up @@ -93,9 +102,7 @@ def compute_variable_stats(self) -> VariableStats[str]:
@property
def is_exit(self) -> bool:
"""Whether this is the exit BB."""
# The exit BB is the only one without successors (otherwise we would have gotten
# an unreachable code error during CFG building)
return len(self.successors) == 0
return self == self.containing_cfg.exit_bb


class VariableVisitor(ast.NodeVisitor):
Expand Down
45 changes: 41 additions & 4 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,48 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) ->
nodes, self.cfg.entry_bb, Jumps(self.cfg.exit_bb, None, None)
)

# Compute reachable BBs
self.cfg.update_reachable()

# If we're still in a basic block after compiling the whole body, we have to add
# an implicit void return
if final_bb is not None:
if not returns_none:
raise GuppyError(ExpectedError(nodes[-1], "return statement"))
self.cfg.link(final_bb, self.cfg.exit_bb)
if final_bb.reachable:
self.cfg.exit_bb.reachable = True
if not returns_none:
raise GuppyError(ExpectedError(nodes[-1], "return statement"))

# Prune the CFG such that there are no jumps from unreachable code back into
# reachable code. Otherwise, unreachable code could lead to unnecessary type
# checking errors, e.g. if unreachable code changes the type of a variable.
for bb in self.cfg.bbs:
if not bb.reachable:
for succ in list(bb.successors):
if succ.reachable:
bb.successors.remove(succ)
succ.predecessors.remove(bb)
# Similarly, if a BB is reachable, then there is no need to hold on to dummy
# jumps into it. Dummy jumps are only needed to propagate type information
# into and between unreachable BBs
else:
for pred in bb.dummy_predecessors:
pred.dummy_successors.remove(bb)
bb.dummy_predecessors = []

return self.cfg

def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
prev_bb = bb
bb_opt: BB | None = bb
next_functional = False
for node in nodes:
# If the previous statement jumped, then all following statements are
# unreachable. Just create a new dummy BB and keep going so we can still
# check the unreachable code.
if bb_opt is None:
raise GuppyError(UnreachableError(node))
bb_opt = self.cfg.new_bb()
self.cfg.dummy_link(prev_bb, bb_opt)
if is_functional_annotation(node):
next_functional = True
continue
Expand All @@ -101,7 +128,7 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
raise NotImplementedError
next_functional = False
else:
bb_opt = self.visit(node, bb_opt, jumps)
prev_bb, bb_opt = bb_opt, self.visit(node, bb_opt, jumps)
return bb_opt

def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
Expand Down Expand Up @@ -368,6 +395,16 @@ def add_branch(node: ast.expr, cfg: CFG, bb: BB, true_bb: BB, false_bb: BB) -> N
builder = BranchBuilder(cfg)
builder.visit(node, bb, true_bb, false_bb)

def visit_Constant(
    self, node: ast.Constant, bb: BB, true_bb: BB, false_bb: BB
) -> None:
    """Builds the branch for a constant used as a branch condition.

    For the literals `True` and `False`, `bb` gets a regular link to the branch
    that is taken and a dummy link to the branch that can never be taken. Any other
    constant (including ints such as `1`, which are not `bool` instances) is
    forwarded to `generic_visit` unchanged.
    """
    # Branching on `True` or `False` constant should be unconditional
    if isinstance(node.value, bool):
        taken_bb, dead_bb = (true_bb, false_bb) if node.value else (false_bb, true_bb)
        self.cfg.link(bb, taken_bb)
        # Keep a dummy edge to the dead branch so that it stays connected to `bb`
        self.cfg.dummy_link(bb, dead_bb)
    else:
        self.generic_visit(node, bb, true_bb, false_bb)

def visit_BoolOp(self, node: ast.BoolOp, bb: BB, true_bb: BB, false_bb: BB) -> None:
# Add short-circuit evaluation of boolean expression. If there are more than 2
# operators, we turn the flat operator list into a right-nested tree to allow
Expand Down
30 changes: 28 additions & 2 deletions guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def ancestors(self, *bbs: T) -> Iterator[T]:
yield bb
queue += bb.predecessors

def update_reachable(self) -> None:
    """Sets the reachability flags on the BBs in this CFG.

    Starting at the entry BB, follows the regular `successors` edges and sets
    `reachable = True` on every BB that is encountered. Dummy successors are not
    followed. Flags are only ever switched on here; BBs that are not visited keep
    whatever value their flag already had.
    """
    queue = {self.entry_bb}
    while queue:
        bb = queue.pop()
        if not bb.reachable:
            bb.reachable = True
            # Successors are only enqueued the first time a BB is marked, so the
            # traversal also terminates on CFGs that contain loops
            queue.update(bb.successors)


class CFG(BaseCFG[BB]):
"""A control-flow graph of unchecked basic blocks."""
Expand All @@ -75,6 +85,15 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None:
src_bb.successors.append(tgt_bb)
tgt_bb.predecessors.append(src_bb)

def dummy_link(self, src_bb: BB, tgt_bb: BB) -> None:
    """Adds a dummy control-flow edge between two basic blocks that is provably
    never taken.

    For example, an `if False: ...` statement emits such a dummy link.

    The edge is recorded only in `src_bb.dummy_successors` and
    `tgt_bb.dummy_predecessors`; the regular `successors` and `predecessors` lists
    are left untouched.
    """
    src_bb.dummy_successors.append(tgt_bb)
    tgt_bb.dummy_predecessors.append(src_bb)

def analyze(
self,
def_ass_before: set[str],
Expand All @@ -84,8 +103,15 @@ def analyze(
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
# Mark all borrowed variables as implicitly used in the exit BB
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
# This also means borrowed variables are always live, so we can use them as the
# initial value in the liveness analysis. This solves the edge case that
# borrowed variables should be considered live, even if the exit is actually
# unreachable (to avoid linearity violations later).
inout_live = {x: self.exit_bb for x in inout_vars}
self.live_before = LivenessAnalysis(
stats, initial=inout_live, include_unreachable=True
).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
stats, def_ass_before, maybe_ass_before
stats, def_ass_before, maybe_ass_before, include_unreachable=True
).run_unpacked(self.bbs)
return stats
55 changes: 39 additions & 16 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ class Signature(Generic[V]):
input_row: Row[V]
output_rows: Sequence[Row[V]] # One for each successor

dummy_output_rows: Sequence[Row[V]] = field(default_factory=list)

@staticmethod
def empty() -> "Signature[V]":
return Signature([], [])
return Signature([], [], [])


@dataclass(eq=False) # Disable equality to recover hash from `object`
Expand Down Expand Up @@ -76,8 +78,8 @@ def check_cfg(
"""
# First, we need to run program analysis
ass_before = {v.name for v in inputs}
inout_vars = [v.name for v in inputs if InputFlags.Inout in v.flags]
cfg.analyze(ass_before, ass_before, inout_vars)
inout_vars = [v for v in inputs if InputFlags.Inout in v.flags]
cfg.analyze(ass_before, ass_before, [v.name for v in inout_vars])

# We start by compiling the entry BB
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
Expand All @@ -93,13 +95,16 @@ def check_cfg(
(checked_cfg.entry_bb, i, succ)
# We enumerate the successor starting from the back, so we start with the `True`
# branch. This way, we find errors in a more natural order
for i, succ in reverse_enumerate(cfg.entry_bb.successors)
for i, succ in reverse_enumerate(
cfg.entry_bb.successors + cfg.entry_bb.dummy_successors
)
)
while len(queue) > 0:
pred, num_output, bb = queue.popleft()
pred_outputs = [*pred.sig.output_rows, *pred.sig.dummy_output_rows]
input_row = [
Variable(v.name, v.ty, v.defined_at, v.flags)
for v in pred.sig.output_rows[num_output]
for v in pred_outputs[num_output]
]

if bb in compiled:
Expand All @@ -119,16 +124,26 @@ def check_cfg(
]
compiled[bb] = checked_bb

# Link up BBs in the checked CFG
compiled[bb].predecessors.append(pred)
pred.successors[num_output] = compiled[bb]

checked_cfg.bbs = list(compiled.values())
checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable
checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs}
checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs}
# Link up BBs in the checked CFG, excluding the unreachable ones
if bb.reachable:
compiled[bb].predecessors.append(pred)
pred.successors[num_output] = compiled[bb]

# The exit BB might be unreachable. In that case it won't be visited above and we
# have to handle it here
if cfg.exit_bb not in compiled:
assert not cfg.exit_bb.reachable
compiled[cfg.exit_bb] = CheckedBB(
cfg.exit_bb.idx, checked_cfg, reachable=False, sig=Signature(inout_vars, [])
)

required_bbs = [bb for bb in cfg.bbs if bb.reachable or bb.is_exit]
checked_cfg.bbs = [compiled[bb] for bb in required_bbs]
checked_cfg.exit_bb = compiled[cfg.exit_bb]
checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in required_bbs}
checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in required_bbs}
checked_cfg.maybe_ass_before = {
compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs
}

# Finally, run the linearity check
Expand Down Expand Up @@ -205,7 +220,7 @@ def check_bb(
bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred)
bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx)

for succ in bb.successors:
for succ in bb.successors + bb.dummy_successors:
for x, use_bb in cfg.live_before[succ].items():
# Check that the variables requested by the successor are defined
if x not in ctx.locals and x not in ctx.globals:
Expand All @@ -227,10 +242,18 @@ def check_bb(
[ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals]
for succ in bb.successors
]
dummy_outputs = [
[ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals]
for succ in bb.dummy_successors
]

# Also prepare the successor list so we can fill it in later
checked_bb = CheckedBB(
bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs)
bb.idx,
checked_cfg,
checked_stmts,
reachable=bb.reachable,
sig=Signature(inputs, outputs, dummy_outputs),
)
checked_bb.successors = [None] * len(bb.successors) # type: ignore[list-item]
checked_bb.branch_pred = bb.branch_pred
Expand Down
Loading

0 comments on commit 748ea95

Please sign in to comment.