-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from Herb-AI/idm_project_2705
Idm project 2705
- Loading branch information
Showing
15 changed files
with
573 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
|
||
""" | ||
@csgrammar_annotated | ||
Define an annotated grammar and return it as a ContextSensitiveGrammar. | ||
Allows for adding optional annotations per rule. | ||
As well as that, allows for adding optional labels per rule, which can be referenced in annotations. | ||
Syntax is backwards-compatible with @csgrammar. | ||
Examples: | ||
```julia-repl | ||
g₁ = @csgrammar_annotated begin | ||
Element = 1 | ||
Element = x | ||
Element = Element + Element := commutative | ||
Element = Element * Element := (commutative, transitive) | ||
end | ||
``` | ||
```julia-repl | ||
g₁ = @csgrammar_annotated begin | ||
Element = 1 | ||
Element = x | ||
Element = Element + Element := forbidden_path([3, 1]) | ||
Element = Element * Element := (commutative, transitive) | ||
end | ||
``` | ||
```julia-repl | ||
g₁ = @csgrammar_annotated begin | ||
one:: Element = 1 | ||
variable:: Element = x | ||
addition:: Element = Element + Element := ( | ||
commutative, | ||
transitive, | ||
forbidden_path([:addition, :one]) || forbidden_path([:one, :variable]) | ||
) | ||
multiplication:: Element = Element * Element := (commutative, transitive) | ||
end | ||
``` | ||
""" | ||
macro csgrammar_annotated(expression) | ||
# collect and remove labels | ||
labels = _get_labels!(expression) | ||
|
||
# parse rules, get constraints from annotations | ||
rules = Any[] | ||
types = Symbol[] | ||
bytype = Dict{Symbol,Vector{Int}}() | ||
constraints = Vector{Constraint}() | ||
|
||
rule_index = 1 | ||
|
||
for (e, label) in zip(expression.args, labels) | ||
# only consider if e is of type ... = ... | ||
if !(e isa Expr && e.head == :(=)) continue end | ||
|
||
# get the left and right hand side of a rule | ||
lhs = e.args[1] | ||
rhs = e.args[2] | ||
|
||
# parse annotations if present | ||
if rhs isa Expr && rhs.head == :(:=) | ||
# get new annotations as a list | ||
annotations = rhs.args[2] | ||
if annotations isa Expr && annotations.head == :tuple | ||
annotations = annotations.args | ||
else | ||
annotations = [annotations] | ||
end | ||
|
||
# convert annotations, append to constraints | ||
append!(constraints, annotation2constraint(a, rule_index, labels) for a ∈ annotations) | ||
|
||
# discard annotation | ||
rhs = rhs.args[1] | ||
end | ||
|
||
# parse rules | ||
new_rules = Any[] | ||
parse_rule!(new_rules, rhs) | ||
|
||
@assert (length(new_rules) == 1 || label == "") "Cannot give rule name $(label) to multiple rules!" | ||
|
||
# add new rules to data | ||
for new_rule ∈ new_rules | ||
push!(rules, new_rule) | ||
push!(types, lhs) | ||
bytype[lhs] = push!(get(bytype, lhs, Int[]), rule_index) | ||
|
||
rule_index += 1 | ||
end | ||
end | ||
|
||
# determine parameters | ||
alltypes = collect(keys(bytype)) | ||
is_terminal = [isterminal(rule, alltypes) for rule ∈ rules] | ||
is_eval = [iseval(rule) for rule ∈ rules] | ||
childtypes = [get_childtypes(rule, alltypes) for rule ∈ rules] | ||
domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) | ||
|
||
return ContextSensitiveGrammar( | ||
rules, | ||
types, | ||
is_terminal, | ||
is_eval, | ||
bytype, | ||
domains, | ||
childtypes, | ||
nothing, | ||
constraints | ||
) | ||
end | ||
|
||
|
||
# gets the labels from an expression | ||
function _get_labels!(expression::Expr)::Vector{String} | ||
labels = Vector{String}() | ||
|
||
for e in expression.args | ||
# only consider if e is of type ... = ... | ||
if !(e isa Expr && e.head == :(=)) continue end | ||
|
||
lhs = e.args[1] | ||
|
||
label = "" | ||
if lhs isa Expr && lhs.head == :(::) | ||
label = string(lhs.args[1]) | ||
|
||
# discard rule name | ||
e.args[1] = lhs.args[2] | ||
end | ||
|
||
push!(labels, label) | ||
end | ||
|
||
# flatten linenums into expression | ||
Base.remove_linenums!(expression) | ||
|
||
return labels | ||
end | ||
|
||
|
||
""" | ||
Converts an annotation to a constraint. | ||
commutative: creates an Ordered constraint | ||
transitive: creates an (incorrect) Forbidden constraint | ||
forbidden_path(path::Vector{Union{Symbol, Int}}): creates a ForbiddenPath constraint with the original rule included | ||
... || ...: creates a OneOf constraint (also works with ... || ... || ... et cetera, though not very performant) | ||
""" | ||
function annotation2constraint(annotation::Any, rule_index::Int, labels::Vector{String})::Constraint | ||
if annotation isa Expr | ||
# function-like annotations | ||
if annotation.head == :call | ||
func_name = annotation.args[1] | ||
func_args = annotation.args[2:end] | ||
|
||
if func_name == :forbidden_path | ||
string_args = eval(func_args[1]) | ||
index_args = [arg isa Symbol ? _get_rule_index(labels, string(arg)) : arg for arg in string_args] | ||
|
||
return ForbiddenPath( | ||
[rule_index; index_args] | ||
) | ||
end | ||
end | ||
|
||
# disjunctive annotations | ||
if annotation.head == :|| | ||
return OneOf( | ||
@show [annotation2constraint(a, rule_index, labels) for a in annotation.args] | ||
) | ||
end | ||
end | ||
|
||
# commutative annotations | ||
if annotation == :commutative | ||
return Ordered( | ||
MatchNode(rule_index, [MatchVar(:x), MatchVar(:y)]), | ||
[:x, :y] | ||
) | ||
end | ||
|
||
if annotation == :transitive | ||
return Forbidden( | ||
MatchNode(rule_index, [MatchVar(:x), MatchNode(rule_index, [MatchVar(:y), MatchVar(:z)])]) | ||
) | ||
end | ||
|
||
# unknown constraint | ||
throw(ArgumentError("Annotation $(annotation) at rule $(rule_index) not found!")) | ||
end | ||
|
||
|
||
# helper function for label lookup | ||
_get_rule_index(labels::Vector{String}, label::String)::Int = findfirst(isequal(label), labels) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
mutable struct LocalCondition <: LocalConstraint | ||
path::Vector{Int} | ||
tree::AbstractMatchNode | ||
condition::Function | ||
end | ||
|
||
function propagate( | ||
c::LocalCondition, | ||
::Grammar, | ||
context::GrammarContext, | ||
domain::Vector{Int}, | ||
filled_hole::Union{HoleReference, Nothing} | ||
)::Tuple{PropagatedDomain, Set{LocalConstraint}} | ||
# Skip the propagator if a node is being propagated that it isn't targeting | ||
if length(c.path) > length(context.nodeLocation) || c.path ≠ context.nodeLocation[1:length(c.path)] | ||
return domain, Set([c]) | ||
end | ||
|
||
# Skip the propagator if the filled hole wasn't part of the path | ||
if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path ≠ filled_hole.path[1:length(c.path)]) | ||
return domain, Set([c]) | ||
end | ||
|
||
n = get_node_at_location(context.originalExpr, c.path) | ||
|
||
hole_location = context.nodeLocation[length(c.path)+1:end] | ||
|
||
vars = Dict{Symbol, AbstractRuleNode}() | ||
|
||
match = _pattern_match_with_hole(n, c.tree, hole_location, vars) | ||
if match ≡ hardfail | ||
# Match attempt failed due to mismatched rulenode indices. | ||
# This means that we can remove the current constraint. | ||
return domain, Set() | ||
elseif match ≡ softfail | ||
# Match attempt failed because we had to compare with a hole. | ||
# If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint. | ||
return domain, Set([c]) | ||
end | ||
|
||
function is_in_domain(rule) | ||
vars_copy = copy(vars) | ||
vars_copy[match[1]] = RuleNode(rule) | ||
return c.condition(vars_copy) | ||
end | ||
|
||
return filter(is_in_domain, domain), Set() | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.