Skip to content

Commit

Permalink
Merge pull request #13 from Herb-AI/idm_project_2705
Browse files Browse the repository at this point in the history
Idm project 2705
  • Loading branch information
jaapdejong15 authored Jun 30, 2023
2 parents 189e7f1 + adfdf3f commit 543cdad
Show file tree
Hide file tree
Showing 15 changed files with 573 additions and 57 deletions.
20 changes: 19 additions & 1 deletion src/HerbConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,29 @@ abstract type PropagatorConstraint <: Constraint end

abstract type LocalConstraint <: Constraint end

@enum PropagateFailureReason unchanged_domain=1
PropagatedDomain = Union{PropagateFailureReason, Vector{Int}}

include("matchfail.jl")
include("matchnode.jl")
include("context.jl")
include("patternmatch.jl")
include("rulenodematch.jl")

include("csg_annotated/csg_annotated.jl")

include("propagatorconstraints/comesafter.jl")
include("propagatorconstraints/forbidden_path.jl")
include("propagatorconstraints/require_on_left.jl")
include("propagatorconstraints/forbidden.jl")
include("propagatorconstraints/ordered.jl")
include("propagatorconstraints/condition.jl")
include("propagatorconstraints/one_of.jl")

include("localconstraints/local_forbidden.jl")
include("localconstraints/local_ordered.jl")
include("localconstraints/local_condition.jl")
include("localconstraints/local_one_of.jl")

export
AbstractMatchNode,
Expand All @@ -36,17 +45,26 @@ export

PropagatorConstraint,
LocalConstraint,
PropagateFailureReason,
PropagatedDomain,

propagate,
check_tree,

generateconstraints!,

ComesAfter,
ForbiddenPath,
RequireOnLeft,
Forbidden,
Ordered,
Condition,
OneOf,

LocalForbidden,
LocalOrdered
LocalOrdered,
LocalCondition
LocalOrdered,
LocalOneOf

end # module HerbConstraints
4 changes: 2 additions & 2 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ Contains:
"""
mutable struct GrammarContext
originalExpr::AbstractRuleNode # original expression being modified
nodeLocation::Vector{Int} # path to he current node in the expression,
constraints::Vector{LocalConstraint} # local constraints that should be propagated
nodeLocation::Vector{Int} # path to the current node in the expression,
constraints::Set{LocalConstraint} # local constraints that should be propagated
end

GrammarContext(originalExpr::AbstractRuleNode) = GrammarContext(originalExpr, [], [])
Expand Down
194 changes: 194 additions & 0 deletions src/csg_annotated/csg_annotated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@

"""
@csgrammar_annotated
Define an annotated grammar and return it as a ContextSensitiveGrammar.
Allows for adding optional annotations per rule.
As well as that, allows for adding optional labels per rule, which can be referenced in annotations.
Syntax is backwards-compatible with @csgrammar.
Examples:
```julia-repl
g₁ = @csgrammar_annotated begin
Element = 1
Element = x
Element = Element + Element := commutative
Element = Element * Element := (commutative, transitive)
end
```
```julia-repl
g₁ = @csgrammar_annotated begin
Element = 1
Element = x
Element = Element + Element := forbidden_path([3, 1])
Element = Element * Element := (commutative, transitive)
end
```
```julia-repl
g₁ = @csgrammar_annotated begin
one:: Element = 1
variable:: Element = x
addition:: Element = Element + Element := (
commutative,
transitive,
forbidden_path([:addition, :one]) || forbidden_path([:one, :variable])
)
multiplication:: Element = Element * Element := (commutative, transitive)
end
```
"""
macro csgrammar_annotated(expression)
# collect and remove labels
labels = _get_labels!(expression)

# parse rules, get constraints from annotations
rules = Any[]
types = Symbol[]
bytype = Dict{Symbol,Vector{Int}}()
constraints = Vector{Constraint}()

rule_index = 1

for (e, label) in zip(expression.args, labels)
# only consider if e is of type ... = ...
if !(e isa Expr && e.head == :(=)) continue end

# get the left and right hand side of a rule
lhs = e.args[1]
rhs = e.args[2]

# parse annotations if present
if rhs isa Expr && rhs.head == :(:=)
# get new annotations as a list
annotations = rhs.args[2]
if annotations isa Expr && annotations.head == :tuple
annotations = annotations.args
else
annotations = [annotations]
end

# convert annotations, append to constraints
append!(constraints, annotation2constraint(a, rule_index, labels) for a annotations)

# discard annotation
rhs = rhs.args[1]
end

# parse rules
new_rules = Any[]
parse_rule!(new_rules, rhs)

@assert (length(new_rules) == 1 || label == "") "Cannot give rule name $(label) to multiple rules!"

# add new rules to data
for new_rule new_rules
push!(rules, new_rule)
push!(types, lhs)
bytype[lhs] = push!(get(bytype, lhs, Int[]), rule_index)

rule_index += 1
end
end

# determine parameters
alltypes = collect(keys(bytype))
is_terminal = [isterminal(rule, alltypes) for rule rules]
is_eval = [iseval(rule) for rule rules]
childtypes = [get_childtypes(rule, alltypes) for rule rules]
domains = Dict(type => BitArray(r bytype[type] for r 1:length(rules)) for type alltypes)

return ContextSensitiveGrammar(
rules,
types,
is_terminal,
is_eval,
bytype,
domains,
childtypes,
nothing,
constraints
)
end


# gets the labels from an expression
function _get_labels!(expression::Expr)::Vector{String}
labels = Vector{String}()

for e in expression.args
# only consider if e is of type ... = ...
if !(e isa Expr && e.head == :(=)) continue end

lhs = e.args[1]

label = ""
if lhs isa Expr && lhs.head == :(::)
label = string(lhs.args[1])

# discard rule name
e.args[1] = lhs.args[2]
end

push!(labels, label)
end

# flatten linenums into expression
Base.remove_linenums!(expression)

return labels
end


"""
Converts an annotation to a constraint.
commutative: creates an Ordered constraint
transitive: creates an (incorrect) Forbidden constraint
forbidden_path(path::Vector{Union{Symbol, Int}}): creates a ForbiddenPath constraint with the original rule included
... || ...: creates a OneOf constraint (also works with ... || ... || ... et cetera, though not very performant)
"""
function annotation2constraint(annotation::Any, rule_index::Int, labels::Vector{String})::Constraint
if annotation isa Expr
# function-like annotations
if annotation.head == :call
func_name = annotation.args[1]
func_args = annotation.args[2:end]

if func_name == :forbidden_path
string_args = eval(func_args[1])
index_args = [arg isa Symbol ? _get_rule_index(labels, string(arg)) : arg for arg in string_args]

return ForbiddenPath(
[rule_index; index_args]
)
end
end

# disjunctive annotations
if annotation.head == :||
return OneOf(
@show [annotation2constraint(a, rule_index, labels) for a in annotation.args]
)
end
end

# commutative annotations
if annotation == :commutative
return Ordered(
MatchNode(rule_index, [MatchVar(:x), MatchVar(:y)]),
[:x, :y]
)
end

if annotation == :transitive
return Forbidden(
MatchNode(rule_index, [MatchVar(:x), MatchNode(rule_index, [MatchVar(:y), MatchVar(:z)])])
)
end

# unknown constraint
throw(ArgumentError("Annotation $(annotation) at rule $(rule_index) not found!"))
end


# helper function for label lookup
_get_rule_index(labels::Vector{String}, label::String)::Int = findfirst(isequal(label), labels)
48 changes: 48 additions & 0 deletions src/localconstraints/local_condition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
mutable struct LocalCondition <: LocalConstraint
path::Vector{Int}
tree::AbstractMatchNode
condition::Function
end

function propagate(
c::LocalCondition,
::Grammar,
context::GrammarContext,
domain::Vector{Int},
filled_hole::Union{HoleReference, Nothing}
)::Tuple{PropagatedDomain, Set{LocalConstraint}}
# Skip the propagator if a node is being propagated that it isn't targeting
if length(c.path) > length(context.nodeLocation) || c.path context.nodeLocation[1:length(c.path)]
return domain, Set([c])
end

# Skip the propagator if the filled hole wasn't part of the path
if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path filled_hole.path[1:length(c.path)])
return domain, Set([c])
end

n = get_node_at_location(context.originalExpr, c.path)

hole_location = context.nodeLocation[length(c.path)+1:end]

vars = Dict{Symbol, AbstractRuleNode}()

match = _pattern_match_with_hole(n, c.tree, hole_location, vars)
if match hardfail
# Match attempt failed due to mismatched rulenode indices.
# This means that we can remove the current constraint.
return domain, Set()
elseif match softfail
# Match attempt failed because we had to compare with a hole.
# If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint.
return domain, Set([c])
end

function is_in_domain(rule)
vars_copy = copy(vars)
vars_copy[match[1]] = RuleNode(rule)
return c.condition(vars_copy)
end

return filter(is_in_domain, domain), Set()
end
26 changes: 19 additions & 7 deletions src/localconstraints/local_forbidden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,23 @@ Propagates the LocalForbidden constraint.
It removes rules from the domain that would make
the RuleNode at the given path match the pattern defined by the MatchNode.
"""
function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}}
function propagate(
c::LocalForbidden,
::Grammar,
context::GrammarContext,
domain::Vector{Int},
filled_hole::Union{HoleReference, Nothing}
)::Tuple{PropagatedDomain, Set{LocalConstraint}}
# Skip the propagator if a node is being propagated that it isn't targeting
if length(c.path) > length(context.nodeLocation) || c.path context.nodeLocation[1:length(c.path)]
return domain, [c]
return domain, Set([c])
end

# Skip the propagator if the filled hole wasn't part of the path
if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path filled_hole.path[1:length(c.path)])
return domain, Set([c])
end

n = get_node_at_location(context.originalExpr, c.path)

hole_location = context.nodeLocation[length(c.path)+1:end]
Expand All @@ -29,11 +41,11 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain
if match hardfail
# Match attempt failed due to mismatched rulenode indices.
# This means that we can remove the current constraint.
return domain, []
return domain, Set()
elseif match softfail
# Match attempt failed because we had to compare with a hole.
# If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint.
return domain, [c]
return domain, Set([c])
end

remove_from_domain::Int = 0
Expand All @@ -42,7 +54,7 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain
remove_from_domain = match
elseif match isa Tuple{Symbol, Vector{Int}}
# The hole is matched with an otherwise unassigned variable (wildcard).
return [], []
return Vector{Int}(), Set()
end

# Remove the rule that would complete the forbidden tree from the domain
Expand All @@ -52,5 +64,5 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain
end
# If the domain is pruned, we do not need this constraint anymore after expansion,
# since no equality is possible with the new domain.
return domain, []
end
return domain, Set()
end
Loading

0 comments on commit 543cdad

Please sign in to comment.