Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate structure-s where possible #4731

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
Mutual,
Signature,
SimpleFieldVal,
StructCtor,
Structure,
StructVal,
Term,
)
Expand Down Expand Up @@ -69,14 +71,17 @@ def _sort_block(self, sorts: list[str]) -> Command | None:
def _transform_sort(self, sort: str) -> Declaration:
def is_inductive(sort: str) -> bool:
decl = self.defn.sorts[sort]
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key and not self._is_cell(sort)

def is_collection(sort: str) -> bool:
return sort in self.defn.collections

if is_inductive(sort):
return self._inductive(sort)

if self._is_cell(sort):
return self._cell(sort)

if is_collection(sort):
return self._collection(sort)

Expand Down Expand Up @@ -109,7 +114,33 @@ def _symbol_ident(symbol: str) -> str:
symbol = f'«{symbol}»'
return symbol

def _collection(self, sort: str) -> Inductive:
@staticmethod
def _is_cell(sort: str) -> bool:
return sort.endswith('Cell')

def _cell(self, sort: str) -> Structure:
(cell_ctor,) = self.defn.constructors[sort]
decl = self.defn.symbols[cell_ctor]
param_sorts = _param_sorts(decl)

param_names: list[str]

if all(self._is_cell(sort) for sort in param_sorts):
param_names = []
for param_sort in param_sorts:
assert param_sort.startswith('Sort')
assert param_sort.endswith('Cell')
name = param_sort[4:-4]
name = name[0].lower() + name[1:]
param_names.append(name)
else:
assert len(param_sorts) == 1
param_names = ['val']

fields = tuple(ExplBinder((name,), Term(sort)) for name, sort in zip(param_names, param_sorts, strict=True))
return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(fields))

def _collection(self, sort: str) -> Structure:
coll = self.defn.collections[sort]
elem = self.defn.symbols[coll.element]
sorts = _param_sorts(elem)
Expand All @@ -124,8 +155,8 @@ def _collection(self, sort: str) -> Inductive:
case CollectionKind.MAP:
key, value = sorts
val = Term(f'List ({key} × {value})')
ctor = Ctor('mk', Signature((ExplBinder(('coll',), val),), Term(sort)))
return Inductive(sort, Signature((), Term('Type')), ctors=(ctor,))
field = ExplBinder(('coll',), val)
return Structure(sort, Signature((), Term('Type')), ctor=StructCtor((field,)))

def inj_module(self) -> Module:
return Module(commands=self._inj_commands())
Expand Down
95 changes: 95 additions & 0 deletions pyk/src/pyk/k2lean4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,101 @@ def __str__(self) -> str:
return f'| {patterns} => {self.rhs}'


@final
@dataclass(frozen=True)
class Structure(Declaration):
ident: DeclId
signature: Signature | None
extends: tuple[Term, ...]
ctor: StructCtor | None
deriving: tuple[str, ...]
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
signature: Signature | None = None,
extends: Iterable[Term] | None = None,
ctor: StructCtor | None = None,
deriving: Iterable[str] | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
extends = tuple(extends) if extends is not None else ()
deriving = tuple(deriving) if deriving is not None else ()
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'extends', extends)
object.__setattr__(self, 'ctor', ctor)
object.__setattr__(self, 'deriving', deriving)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
lines = []

modifiers = f'{self.modifiers} ' if self.modifiers else ''
binders = (
' '.join(str(binder) for binder in self.signature.binders)
if self.signature and self.signature.binders
else ''
)
binders = f' {binders}' if binders else ''
extends = ', '.join(str(extend) for extend in self.extends)
extends = f' extends {extends}' if extends else ''
ty = f' : {self.signature.ty}' if self.signature and self.signature.ty else ''
where = ' where' if self.ctor else ''
lines.append(f'{modifiers}structure {self.ident}{binders}{extends}{ty}{where}')

if self.deriving:
lines.append(f' deriving {self.deriving}')

if self.ctor:
lines.extend(f' {line}' for line in str(self.ctor).splitlines())

return '\n'.join(lines)


@final
@dataclass(frozen=True)
class StructCtor:
fields: tuple[Binder, ...] # TODO implement StructField
ident: StructIdent | None

def __init__(
self,
fields: Iterable[Binder],
ident: str | StructIdent | None = None,
):
fields = tuple(fields)
ident = StructIdent(ident) if isinstance(ident, str) else ident
object.__setattr__(self, 'fields', fields)
object.__setattr__(self, 'ident', ident)

def __str__(self) -> str:
lines = []
if self.ident:
lines.append(f'{self.ident} ::')
for field in self.fields:
if isinstance(field, ExplBinder) and len(field.idents) == 1:
(ident,) = field.idents
ty = '' if field.ty is None else f' : {field.ty}'
lines.append(f'{ident}{ty}')
else:
lines.append(str(field))
return '\n'.join(lines)


@final
@dataclass(frozen=True)
class StructIdent:
ident: str
modifiers: Modifiers | None = None

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
return f'{modifiers}{ self.ident}'


@final
@dataclass(frozen=True)
class DeclId:
Expand Down
Loading