Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to_axiom to class Rule #4701

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 155 additions & 47 deletions pyk/src/pyk/kore/rule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import logging
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar, final
from functools import reduce
from typing import TYPE_CHECKING, Generic, TypeVar, cast, final

from .prelude import inj
from .prelude import BOOL, SORT_GENERATED_TOP_CELL, TRUE, inj
from .syntax import (
DV,
And,
Expand All @@ -28,7 +29,7 @@
if TYPE_CHECKING:
from typing import Final

from .syntax import Definition
from .syntax import Definition, Sort

Attrs = dict[str, tuple[Pattern, ...]]

Expand Down Expand Up @@ -68,8 +69,12 @@ class Rule(ABC):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@abstractmethod
def to_axiom(self) -> Axiom: ...

@staticmethod
def from_axiom(axiom: Axiom) -> Rule:
if isinstance(axiom.pattern, Rewrites):
Expand All @@ -89,22 +94,25 @@ def from_axiom(axiom: Axiom) -> Rule:
raise ValueError(f'Cannot parse simplification rule: {axiom.text}')

@staticmethod
def extract_all(defn: Definition) -> list[Rule]:
def is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False
def is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False
tothtamas28 marked this conversation as resolved.
Show resolved Hide resolved

if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False
if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False

return True
return True

return [Rule.from_axiom(axiom) for axiom in defn.axioms if is_rule(axiom)]
@staticmethod
def extract_all(defn: Definition) -> list[Rule]:
return [Rule.from_axiom(axiom) for axiom in defn.axioms if Rule.is_rule(axiom)]


@final
@dataclass(frozen=True)
class RewriteRule(Rule):
sort = SORT_GENERATED_TOP_CELL

lhs: App
rhs: App
req: Pattern | None
Expand All @@ -114,6 +122,19 @@ class RewriteRule(Rule):
uid: str
label: str | None

def to_axiom(self) -> Axiom:
lhs = self.lhs if self.ctx is None else And(self.sort, (self.lhs, self.ctx))
req = _to_ml_pred(self.req, self.sort)
ens = _to_ml_pred(self.ens, self.sort)
return Axiom(
(),
Rewrites(
self.sort,
And(self.sort, (lhs, req)),
And(self.sort, (self.rhs, ens)),
),
)

@staticmethod
def from_axiom(axiom: Axiom) -> RewriteRule:
lhs, rhs, req, ens, ctx = RewriteRule._extract(axiom)
Expand Down Expand Up @@ -166,60 +187,125 @@ class FunctionRule(Rule):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
arg_sorts: tuple[Sort, ...]
anti_left: Pattern | None
priority: int

def to_axiom(self) -> Axiom:
R = SortVar('R') # noqa N806

def arg_list(rest: Pattern, arg_pair: tuple[EVar, Pattern]) -> Pattern:
var, arg = arg_pair
return And(R, (In(var.sort, R, var, arg), rest))

vars = tuple(EVar(f'X{i}', sort) for i, sort in enumerate(self.arg_sorts))

# \and{R}(\in{S1, R}(X1 : S1, Arg1), \and{R}(\in{S2, R}(X2 : S2, Arg2), \top{R}())) etc.
_args = reduce(
arg_list,
reversed(tuple(zip(vars, self.lhs.args, strict=True))),
cast('Pattern', Top(R)),
)

_req = _to_ml_pred(self.req, R)
req = And(R, (_req, _args))
if self.anti_left:
req = And(R, (Not(R, self.anti_left), req))

app = self.lhs.let(args=vars)
ens = _to_ml_pred(self.ens, self.sort)

return Axiom(
(R,),
Implies(
R,
req,
Equals(self.sort, R, app, And(self.sort, (self.rhs, ens))),
),
)

@staticmethod
def from_axiom(axiom: Axiom) -> FunctionRule:
lhs, rhs, req, ens = FunctionRule._extract(axiom)
anti_left: Pattern | None = None
match axiom.pattern:
case Implies(
left=And(ops=(Not(pattern=anti_left), And(ops=(_req, _args)))),
right=Equals(op_sort=sort, left=App() as app, right=_rhs),
):
pass
case Implies(
left=And(ops=(_req, _args)),
right=Equals(op_sort=sort, left=App() as app, right=_rhs),
):
pass
case _:
raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}')

arg_sorts, args = FunctionRule._extract_args(_args)
lhs = app.let(args=args)
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)

priority = _extract_priority(axiom)
return FunctionRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
arg_sorts=arg_sorts,
anti_left=anti_left,
priority=priority,
)

@staticmethod
def _extract(axiom: Axiom) -> tuple[App, Pattern, Pattern | None, Pattern | None]:
match axiom.pattern:
case Implies(
left=And(
ops=(Not(), And(ops=(_req, _args))) | (_req, _args),
),
right=Equals(left=App() as app, right=_rhs),
):
args = FunctionRule._extract_args(_args)
lhs = app.let(args=args)
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
return lhs, rhs, req, ens
case _:
raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}')

@staticmethod
def _extract_args(pattern: Pattern) -> tuple[Pattern, ...]:
def _extract_args(pattern: Pattern) -> tuple[tuple[Sort, ...], tuple[Pattern, ...]]:
match pattern:
case Top():
return ()
case And(ops=(In(left=EVar(), right=arg), rest)):
return (arg,) + FunctionRule._extract_args(rest)
return (), ()
case And(ops=(In(left=EVar(sort=sort), right=arg), rest)):
sorts, args = FunctionRule._extract_args(rest)
return (sort,) + sorts, (arg,) + args
case _:
raise ValueError(f'Cannot extract argument list from pattern: {pattern.text}')


class SimpliRule(Rule, Generic[P], ABC):
lhs: P
sort: Sort

def to_axiom(self) -> Axiom:
R = SortVar('R') # noqa N806

vars = (R, self.sort) if isinstance(self.sort, SortVar) else (R,)
req = _to_ml_pred(self.req, R)
ens = _to_ml_pred(self.ens, self.sort)

return Axiom(
vars,
Implies(
R,
req,
Equals(self.sort, R, self.lhs, And(self.sort, (self.rhs, ens))),
),
attrs=(
App(
'simplification',
args=() if self.priority == 50 else (String(str(self.priority)),),
),
),
)

@staticmethod
def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None]:
def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None, Sort]:
match axiom.pattern:
case Implies(left=_req, right=Equals(left=lhs, right=_rhs)):
case Implies(left=_req, right=Equals(op_sort=sort, left=lhs, right=_rhs)):
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
if not isinstance(lhs, lhs_type):
raise ValueError(f'Invalid LHS type from simplification axiom: {axiom.text}')
return lhs, rhs, req, ens
return lhs, rhs, req, ens, sort
case _:
raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}')

Expand All @@ -231,63 +317,67 @@ class AppRule(SimpliRule[App]):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> AppRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, App)
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, App)
priority = _extract_simpl_priority(axiom)
return AppRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)


@final
@dataclass(frozen=True)
class CeilRule(SimpliRule):
class CeilRule(SimpliRule[Ceil]):
lhs: Ceil
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> CeilRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Ceil)
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Ceil)
priority = _extract_simpl_priority(axiom)
return CeilRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)


@final
@dataclass(frozen=True)
class EqualsRule(SimpliRule):
class EqualsRule(SimpliRule[Equals]):
tothtamas28 marked this conversation as resolved.
Show resolved Hide resolved
lhs: Equals
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> EqualsRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Equals)
if not isinstance(lhs, Equals):
raise ValueError(f'Cannot extract LHS as Equals from axiom: {axiom.text}')
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Equals)
priority = _extract_simpl_priority(axiom)
return EqualsRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)

Expand Down Expand Up @@ -340,3 +430,21 @@ def _extract_priority(axiom: Axiom) -> int:
return 200 if 'owise' in attrs else 50
case _:
raise ValueError(f'Cannot extract priority from axiom: {axiom.text}')


def _extract_simpl_priority(axiom: Axiom) -> int:
attrs = axiom.attrs_by_key
match attrs['simplification']:
case App(args=() | (String(''),)):
return 50
case App(args=(String(p),)):
return int(p)
case _:
raise ValueError(f'Cannot extract simplification priority from axiom: {axiom.text}')


def _to_ml_pred(pattern: Pattern | None, sort: Sort) -> Pattern:
if pattern is None:
return Top(sort)

return Equals(BOOL, sort, pattern, TRUE)
28 changes: 27 additions & 1 deletion pyk/src/tests/integration/kore/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

from pyk.kore.parser import KoreParser
from pyk.kore.rule import Rule
from pyk.kore.syntax import App, String

from ..utils import K_FILES

if TYPE_CHECKING:
from pyk.kore.syntax import Definition
from pyk.kore.syntax import Axiom, Definition
from pyk.testing import Kompiler


Expand All @@ -36,3 +37,28 @@ def test_extract_all(definition: Definition) -> None:
assert cnt['AppRule']
assert cnt['CeilRule']
assert cnt['EqualsRule']


def test_to_axiom(definition: Definition) -> None:
def adjust_atts(axiom: Axiom) -> Axiom:
match axiom.attrs_by_key.get('simplification'):
case None:
return axiom.let(attrs=())
case App(args=(String('' | '50'),)):
return axiom.let(attrs=(App('simplification'),))
case attr:
return axiom.let(attrs=(attr,))

for axiom in definition.axioms:
if not Rule.is_rule(axiom):
continue

# Given
expected = adjust_atts(axiom)

# When
rule = Rule.from_axiom(axiom)
actual = rule.to_axiom()

# Then
assert expected == actual
Loading