diff --git a/src/HerbConstraints.jl b/src/HerbConstraints.jl index 6509f62..3e9e473 100644 --- a/src/HerbConstraints.jl +++ b/src/HerbConstraints.jl @@ -7,20 +7,29 @@ abstract type PropagatorConstraint <: Constraint end abstract type LocalConstraint <: Constraint end +@enum PropagateFailureReason unchanged_domain=1 +PropagatedDomain = Union{PropagateFailureReason, Vector{Int}} + include("matchfail.jl") include("matchnode.jl") include("context.jl") include("patternmatch.jl") include("rulenodematch.jl") +include("csg_annotated/csg_annotated.jl") + include("propagatorconstraints/comesafter.jl") include("propagatorconstraints/forbidden_path.jl") include("propagatorconstraints/require_on_left.jl") include("propagatorconstraints/forbidden.jl") include("propagatorconstraints/ordered.jl") +include("propagatorconstraints/condition.jl") +include("propagatorconstraints/one_of.jl") include("localconstraints/local_forbidden.jl") include("localconstraints/local_ordered.jl") +include("localconstraints/local_condition.jl") +include("localconstraints/local_one_of.jl") export AbstractMatchNode, @@ -36,17 +45,26 @@ export PropagatorConstraint, LocalConstraint, + PropagateFailureReason, + PropagatedDomain, propagate, check_tree, + generateconstraints!, + ComesAfter, ForbiddenPath, RequireOnLeft, Forbidden, Ordered, + Condition, + OneOf, LocalForbidden, - LocalOrdered + LocalOrdered, + LocalCondition + LocalOrdered, + LocalOneOf end # module HerbConstraints diff --git a/src/context.jl b/src/context.jl index 2eaeac1..1782a33 100644 --- a/src/context.jl +++ b/src/context.jl @@ -8,8 +8,8 @@ Contains: """ mutable struct GrammarContext originalExpr::AbstractRuleNode # original expression being modified - nodeLocation::Vector{Int} # path to he current node in the expression, - constraints::Vector{LocalConstraint} # local constraints that should be propagated + nodeLocation::Vector{Int} # path to the current node in the expression, + constraints::Set{LocalConstraint} # local constraints that should be propagated end GrammarContext(originalExpr::AbstractRuleNode) = GrammarContext(originalExpr, [], []) diff --git a/src/csg_annotated/csg_annotated.jl b/src/csg_annotated/csg_annotated.jl new file mode 100644 index 0000000..b2e3012 --- /dev/null +++ b/src/csg_annotated/csg_annotated.jl @@ -0,0 +1,194 @@ + +""" +@csgrammar_annotated +Define an annotated grammar and return it as a ContextSensitiveGrammar. +Allows for adding optional annotations per rule. +As well as that, allows for adding optional labels per rule, which can be referenced in annotations. +Syntax is backwards-compatible with @csgrammar. +Examples: +```julia-repl +g₁ = @csgrammar_annotated begin + Element = 1 + Element = x + Element = Element + Element := commutative + Element = Element * Element := (commutative, transitive) +end +``` + +```julia-repl +g₁ = @csgrammar_annotated begin + Element = 1 + Element = x + Element = Element + Element := forbidden_path([3, 1]) + Element = Element * Element := (commutative, transitive) +end +``` + +```julia-repl +g₁ = @csgrammar_annotated begin + one:: Element = 1 + variable:: Element = x + addition:: Element = Element + Element := ( + commutative, + transitive, + forbidden_path([:addition, :one]) || forbidden_path([:one, :variable]) + ) + multiplication:: Element = Element * Element := (commutative, transitive) +end +``` +""" +macro csgrammar_annotated(expression) + # collect and remove labels + labels = _get_labels!(expression) + + # parse rules, get constraints from annotations + rules = Any[] + types = Symbol[] + bytype = Dict{Symbol,Vector{Int}}() + constraints = Vector{Constraint}() + + rule_index = 1 + + for (e, label) in zip(expression.args, labels) + # only consider if e is of type ... = ... + if !(e isa Expr && e.head == :(=)) continue end + + # get the left and right hand side of a rule + lhs = e.args[1] + rhs = e.args[2] + + # parse annotations if present + if rhs isa Expr && rhs.head == :(:=) + # get new annotations as a list + annotations = rhs.args[2] + if annotations isa Expr && annotations.head == :tuple + annotations = annotations.args + else + annotations = [annotations] + end + + # convert annotations, append to constraints + append!(constraints, annotation2constraint(a, rule_index, labels) for a ∈ annotations) + + # discard annotation + rhs = rhs.args[1] + end + + # parse rules + new_rules = Any[] + parse_rule!(new_rules, rhs) + + @assert (length(new_rules) == 1 || label == "") "Cannot give rule name $(label) to multiple rules!" + + # add new rules to data + for new_rule ∈ new_rules + push!(rules, new_rule) + push!(types, lhs) + bytype[lhs] = push!(get(bytype, lhs, Int[]), rule_index) + + rule_index += 1 + end + end + + # determine parameters + alltypes = collect(keys(bytype)) + is_terminal = [isterminal(rule, alltypes) for rule ∈ rules] + is_eval = [iseval(rule) for rule ∈ rules] + childtypes = [get_childtypes(rule, alltypes) for rule ∈ rules] + domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) + + return ContextSensitiveGrammar( + rules, + types, + is_terminal, + is_eval, + bytype, + domains, + childtypes, + nothing, + constraints + ) +end + + +# gets the labels from an expression +function _get_labels!(expression::Expr)::Vector{String} + labels = Vector{String}() + + for e in expression.args + # only consider if e is of type ... = ... + if !(e isa Expr && e.head == :(=)) continue end + + lhs = e.args[1] + + label = "" + if lhs isa Expr && lhs.head == :(::) + label = string(lhs.args[1]) + + # discard rule name + e.args[1] = lhs.args[2] + end + + push!(labels, label) + end + + # flatten linenums into expression + Base.remove_linenums!(expression) + + return labels +end + + +""" +Converts an annotation to a constraint. +commutative: creates an Ordered constraint +transitive: creates an (incorrect) Forbidden constraint +forbidden_path(path::Vector{Union{Symbol, Int}}): creates a ForbiddenPath constraint with the original rule included +... || ...: creates a OneOf constraint (also works with ... || ... || ... et cetera, though not very performant) +""" +function annotation2constraint(annotation::Any, rule_index::Int, labels::Vector{String})::Constraint + if annotation isa Expr + # function-like annotations + if annotation.head == :call + func_name = annotation.args[1] + func_args = annotation.args[2:end] + + if func_name == :forbidden_path + string_args = eval(func_args[1]) + index_args = [arg isa Symbol ? _get_rule_index(labels, string(arg)) : arg for arg in string_args] + + return ForbiddenPath( + [rule_index; index_args] + ) + end + end + + # disjunctive annotations + if annotation.head == :|| + return OneOf( + @show [annotation2constraint(a, rule_index, labels) for a in annotation.args] + ) + end + end + + # commutative annotations + if annotation == :commutative + return Ordered( + MatchNode(rule_index, [MatchVar(:x), MatchVar(:y)]), + [:x, :y] + ) + end + + if annotation == :transitive + return Forbidden( + MatchNode(rule_index, [MatchVar(:x), MatchNode(rule_index, [MatchVar(:y), MatchVar(:z)])]) + ) + end + + # unknown constraint + throw(ArgumentError("Annotation $(annotation) at rule $(rule_index) not found!")) +end + + +# helper function for label lookup +_get_rule_index(labels::Vector{String}, label::String)::Int = findfirst(isequal(label), labels) \ No newline at end of file diff --git a/src/localconstraints/local_condition.jl b/src/localconstraints/local_condition.jl new file mode 100644 index 0000000..99af0a0 --- /dev/null +++ b/src/localconstraints/local_condition.jl @@ -0,0 +1,48 @@ +mutable struct LocalCondition <: LocalConstraint + path::Vector{Int} + tree::AbstractMatchNode + condition::Function +end + +function propagate( + c::LocalCondition, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if a node is being propagated that it isn't targeting + if length(c.path) > length(context.nodeLocation) || c.path ≠ context.nodeLocation[1:length(c.path)] + return domain, Set([c]) + end + + # Skip the propagator if the filled hole wasn't part of the path + if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path ≠ filled_hole.path[1:length(c.path)]) + return domain, Set([c]) + end + + n = get_node_at_location(context.originalExpr, c.path) + + hole_location = context.nodeLocation[length(c.path)+1:end] + + vars = Dict{Symbol, AbstractRuleNode}() + + match = _pattern_match_with_hole(n, c.tree, hole_location, vars) + if match ≡ hardfail + # Match attempt failed due to mismatched rulenode indices. + # This means that we can remove the current constraint. + return domain, Set() + elseif match ≡ softfail + # Match attempt failed because we had to compare with a hole. + # If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint. + return domain, Set([c]) + end + + function is_in_domain(rule) + vars_copy = copy(vars) + vars_copy[match[1]] = RuleNode(rule) + return c.condition(vars_copy) + end + + return filter(is_in_domain, domain), Set() +end diff --git a/src/localconstraints/local_forbidden.jl b/src/localconstraints/local_forbidden.jl index ce6a17a..732b743 100644 --- a/src/localconstraints/local_forbidden.jl +++ b/src/localconstraints/local_forbidden.jl @@ -14,11 +14,23 @@ Propagates the LocalForbidden constraint. It removes rules from the domain that would make the RuleNode at the given path match the pattern defined by the MatchNode. """ -function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} +function propagate( + c::LocalForbidden, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if a node is being propagated that it isn't targeting if length(c.path) > length(context.nodeLocation) || c.path ≠ context.nodeLocation[1:length(c.path)] - return domain, [c] + return domain, Set([c]) end + # Skip the propagator if the filled hole wasn't part of the path + if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path ≠ filled_hole.path[1:length(c.path)]) + return domain, Set([c]) + end + n = get_node_at_location(context.originalExpr, c.path) hole_location = context.nodeLocation[length(c.path)+1:end] @@ -29,11 +41,11 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain if match ≡ hardfail # Match attempt failed due to mismatched rulenode indices. # This means that we can remove the current constraint. - return domain, [] + return domain, Set() elseif match ≡ softfail # Match attempt failed because we had to compare with a hole. # If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint. - return domain, [c] + return domain, Set([c]) end remove_from_domain::Int = 0 @@ -42,7 +54,7 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain remove_from_domain = match elseif match isa Tuple{Symbol, Vector{Int}} # The hole is matched with an otherwise unassigned variable (wildcard). - return [], [] + return Vector{Int}(), Set() end # Remove the rule that would complete the forbidden tree from the domain @@ -52,5 +64,5 @@ function propagate(c::LocalForbidden, ::Grammar, context::GrammarContext, domain end # If the domain is pruned, we do not need this constraint anymore after expansion, # since no equality is possible with the new domain. - return domain, [] -end \ No newline at end of file + return domain, Set() +end diff --git a/src/localconstraints/local_one_of.jl b/src/localconstraints/local_one_of.jl new file mode 100644 index 0000000..cb6624f --- /dev/null +++ b/src/localconstraints/local_one_of.jl @@ -0,0 +1,52 @@ +""" +Meta-constraint that enforces the disjunction of its given constraints. +""" +mutable struct LocalOneOf <: LocalConstraint + global_constraints::Vector{PropagatorConstraint} + local_constraints::Set{LocalConstraint} +end + + +""" +Propagates the LocalOneOf constraint. +It enforces that at least one of its given constraints hold. +""" +function propagate( + c::LocalOneOf, + g::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + if length(c.global_constraints) == 0 + return domain, Set() + end + + # Copy the context to add the local constraints belonging to this one of constraint as well. + # This way, we don't unnecessarily keep creating new local constraints. + new_context = deepcopy(context) + union!(new_context.constraints, c.local_constraints) + + new_local_constraints::Set{LocalConstraint} = Set() + new_domain = BitVector(undef, length(g.rules)) + + # Iterate over the global constraints & local constraints + any_domain_updated = false + for constraint ∈ Iterators.flatten((c.global_constraints, c.local_constraints)) + curr_domain, curr_local_constraints = propagate(constraint, g, new_context, copy(domain), filled_hole) + + # If we are actually intending to update the domain, OR it and set the domain updated flag to true. + if !isa(curr_domain, PropagateFailureReason) + new_domain .|= get_domain(g, curr_domain) + any_domain_updated = true + end + + union!(new_local_constraints, curr_local_constraints) + end + + # If we have updated the domain, use that domain. Otherwise, simply return the original domain. + returned_domain = any_domain_updated ? findall(new_domain) : domain + + # Make a copy of the one of constraint. Otherwise, every tree will have the same reference to it (as we only create 1). + return returned_domain, Set([LocalOneOf(c.global_constraints, new_local_constraints)]) +end diff --git a/src/localconstraints/local_ordered.jl b/src/localconstraints/local_ordered.jl index 880271b..39d1aac 100644 --- a/src/localconstraints/local_ordered.jl +++ b/src/localconstraints/local_ordered.jl @@ -14,9 +14,21 @@ Propagates the LocalOrdered constraint. It removes rules from the domain that would violate the order of variables as defined in the constraint. """ -function propagate(c::LocalOrdered, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} +function propagate( + c::LocalOrdered, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if a node is being propagated that it isn't targeting if length(c.path) > length(context.nodeLocation) || c.path ≠ context.nodeLocation[1:length(c.path)] - return domain, [c] + return domain, Set([c]) + end + + # Skip the propagator if the filled hole wasn't part of the path + if !isnothing(filled_hole) && (length(c.path) > length(filled_hole.path) || c.path ≠ filled_hole.path[1:length(c.path)]) + return domain, Set([c]) end n = get_node_at_location(context.originalExpr, c.path) @@ -29,13 +41,14 @@ function propagate(c::LocalOrdered, ::Grammar, context::GrammarContext, domain:: if match ≡ hardfail # Match attempt failed due to mismatched rulenode indices. # This means that we can remove the current constraint. - return domain, [] + return domain, Set() elseif match ≡ softfail # Match attempt failed because we had to compare with a hole. # If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint. - return domain, [c] + return domain, Set([c]) elseif match isa Tuple{Symbol, Vector{Int}} hole_var, hole_path = match + @assert hole_var ∈ keys(vars) @assert hole_var ∈ c.order @@ -57,10 +70,10 @@ function propagate(c::LocalOrdered, ::Grammar, context::GrammarContext, domain:: domain = new_domain end - return domain, can_be_deleted ? [] : [c] + return domain, can_be_deleted ? Set() : Set([c]) else @error("Unexpected result from pattern match, not propagating constraint $c") - return domain, [c] + return domain, Set([c]) end end @@ -209,4 +222,4 @@ end # TODO: Can we analyze the hole domains? _rulenode_compare(::Hole, ::RuleNode) = softfail _rulenode_compare(::RuleNode, ::Hole) = softfail -_rulenode_compare(::Hole, ::Hole) = softfail \ No newline at end of file +_rulenode_compare(::Hole, ::Hole) = softfail diff --git a/src/propagatorconstraints/comesafter.jl b/src/propagatorconstraints/comesafter.jl index c84b90a..8d32ca7 100644 --- a/src/propagatorconstraints/comesafter.jl +++ b/src/propagatorconstraints/comesafter.jl @@ -13,15 +13,36 @@ ComesAfter(rule::Int, predecessor::Int) = ComesAfter(rule, [predecessor]) Propagates the ComesAfter constraint. It removes the rule from the domain if the predecessors sequence is in the ancestors. """ -function propagate(c::ComesAfter, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} - ancestors = get_rulesequence(context.originalExpr, context.nodeLocation[begin:end-1]) # remove the current node from the node sequence +function propagate( + c::ComesAfter, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + if c.rule in domain # if rule is in domain, check the ancestors + ancestors = get_rulesequence(context.originalExpr, context.nodeLocation[begin:end-1]) # remove the current node from the node sequence if containedin(c.predecessors, ancestors) - return domain, [] + return domain, Set() else - return filter(e -> e != c.rule, domain), [] + return filter(e -> e != c.rule, domain), Set() end else # if it is not in the domain, just return domain - return domain, [] + return domain, Set() end end + + +""" +Checks if the given tree abides the constraint. +""" +function check_tree(c::ComesAfter, g::Grammar, tree::AbstractRuleNode)::Bool + @warn "ComesAfter.check_tree not implemented!" + + return true +end diff --git a/src/propagatorconstraints/condition.jl b/src/propagatorconstraints/condition.jl new file mode 100644 index 0000000..0c314ed --- /dev/null +++ b/src/propagatorconstraints/condition.jl @@ -0,0 +1,43 @@ +struct Condition <: PropagatorConstraint + tree::AbstractMatchNode + condition::Function +end + + +function propagate( + c::Condition, + g::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + + _condition_constraint = LocalCondition(context.nodeLocation, c.tree, c.condition) + if in(_condition_constraint, context.constraints) return domain, Set() end + + new_domain, new_constraints = propagate(_condition_constraint, g, context, domain, filled_hole) + return new_domain, new_constraints +end + + +""" +Checks if the given tree abides the constraint. +""" +function check_tree(c::Condition, g::Grammar, tree::RuleNode)::Bool + vars = Dict{Symbol, AbstractRuleNode}() + + # Return false if the node fits the pattern, but not the condition + if _pattern_match(tree, c.tree, vars) ≡ nothing && !c.condition(vars) + return false + end + + return all(check_tree(c, g, child) for child ∈ tree.children) +end + +function check_tree(::Condition, ::Grammar, ::Hole)::Bool + return false +end diff --git a/src/propagatorconstraints/forbidden.jl b/src/propagatorconstraints/forbidden.jl index 1d439ba..5bf990f 100644 --- a/src/propagatorconstraints/forbidden.jl +++ b/src/propagatorconstraints/forbidden.jl @@ -14,9 +14,22 @@ end Propagates the Forbidden constraint. It removes the elements from the domain that would complete the forbidden tree. """ -function propagate(c::Forbidden, g::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} +function propagate( + c::Forbidden, + g::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + notequals_constraint = LocalForbidden(context.nodeLocation, c.tree) - new_domain, new_constraints = propagate(notequals_constraint, g, context, domain) + if in(notequals_constraint, context.constraints) return domain, Set() end + + new_domain, new_constraints = propagate(notequals_constraint, g, context, domain, filled_hole) return new_domain, new_constraints end @@ -34,4 +47,4 @@ end function check_tree(c::Forbidden, ::Grammar, tree::Hole)::Bool vars = Dict{Symbol, AbstractRuleNode}() return _pattern_match(tree, c.tree, vars) !== nothing -end \ No newline at end of file +end diff --git a/src/propagatorconstraints/forbidden_path.jl b/src/propagatorconstraints/forbidden_path.jl index cd8c6d8..7225509 100644 --- a/src/propagatorconstraints/forbidden_path.jl +++ b/src/propagatorconstraints/forbidden_path.jl @@ -11,13 +11,34 @@ end Propagates the ForbiddenPath constraint. It removes the elements from the domain that would complete the forbidden sequence. """ -function propagate(c::ForbiddenPath, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} +function propagate( + c::ForbiddenPath, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + ancestors = get_rulesequence(context.originalExpr, context.nodeLocation[begin:end-1]) if subsequenceof(c.sequence[begin:end-1], ancestors) last_in_seq = c.sequence[end] - return filter(x -> !(x == last_in_seq), domain), [] + return filter(x -> !(x == last_in_seq), domain), Set() end - return domain, [] + return domain, Set() +end + + +""" +Checks if the given tree abides the constraint. +""" +function check_tree(c::ForbiddenPath, g::Grammar, tree::AbstractRuleNode)::Bool + @warn "ForbiddenPath.check_tree not implemented!" + + return true end diff --git a/src/propagatorconstraints/one_of.jl b/src/propagatorconstraints/one_of.jl new file mode 100644 index 0000000..0fb6c5a --- /dev/null +++ b/src/propagatorconstraints/one_of.jl @@ -0,0 +1,37 @@ +""" +Meta-constraint that enforces the disjunction of its given constraints. +""" +struct OneOf <: PropagatorConstraint + constraints::Vector{PropagatorConstraint} +end + +function OneOf(constraint::PropagatorConstraint) return OneOf([constraint]) end +function OneOf(constraints...) return OneOf([constraints...]) end + + +""" +Propagates the OneOf constraint. +It enforces that at least one of its given constraints hold. +""" +function propagate( + c::OneOf, + g::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Only ever create 1 instance mounted at the root. We do require a local constraint to have multiple instances (one for every PQ node). + if context.nodeLocation != [] return domain, Set() end + + _one_of_constraint = LocalOneOf(c.constraints, Set()) + new_domain, new_constraints = propagate(_one_of_constraint, g, context, domain, filled_hole) + return new_domain, new_constraints +end + + +""" +Checks if the given tree abides the constraint. +""" +function check_tree(c::OneOf, g::Grammar, tree::AbstractRuleNode)::Bool + return any(check_tree(cons, g, tree) for cons in c.constraints) +end diff --git a/src/propagatorconstraints/ordered.jl b/src/propagatorconstraints/ordered.jl index ef7e0ab..18a1eb0 100644 --- a/src/propagatorconstraints/ordered.jl +++ b/src/propagatorconstraints/ordered.jl @@ -17,9 +17,22 @@ end """ Propagates the Ordered constraint. """ -function propagate(c::Ordered, g::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} +function propagate( + c::Ordered, + g::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + ordered_constraint = LocalOrdered(context.nodeLocation, c.tree, c.order) - new_domain, new_constraints = propagate(ordered_constraint, g, context, domain) + if in(ordered_constraint, context.constraints) return domain, Set() end + + new_domain, new_constraints = propagate(ordered_constraint, g, context, domain, filled_hole) return new_domain, new_constraints end diff --git a/src/propagatorconstraints/require_on_left.jl b/src/propagatorconstraints/require_on_left.jl index bdb3346..8875d3a 100644 --- a/src/propagatorconstraints/require_on_left.jl +++ b/src/propagatorconstraints/require_on_left.jl @@ -12,15 +12,39 @@ Propagates the RequireOnLeft constraint. It removes every element from the domain that does not have a necessary predecessor in the left subtree. """ -function propagate(c::RequireOnLeft, ::Grammar, context::GrammarContext, domain::Vector{Int})::Tuple{Vector{Int}, Vector{LocalConstraint}} - rules_on_left = rulesonleft(context.originalExpr, context.nodeLocation) - +function propagate( + c::RequireOnLeft, + ::Grammar, + context::GrammarContext, + domain::Vector{Int}, + filled_hole::Union{HoleReference, Nothing} +)::Tuple{PropagatedDomain, Set{LocalConstraint}} + # Skip the propagator if the hole that was filled isn't a parent of the current hole + if !isnothing(filled_hole) && filled_hole.path != context.nodeLocation[begin:end-1] + return domain, Set() + end + + + if context.nodeLocation == [] + rules_on_left = Set{Int}() + else + rules_on_left = rulesonleft(context.originalExpr, context.nodeLocation) + end + last_rule_index = 0 for (i, r) ∈ enumerate(c.order) - r in rules_on_left ? last_rule_index = i : break + r ∈ rules_on_left ? last_rule_index = i : break end - rules_to_remove = Set(c.order[last_rule_index+2:end]) # +2 because the one after the last index can be used + return filter((x) -> !(x in rules_to_remove), domain), Set() +end + + +""" +Checks if the given tree abides the constraint. +""" +function check_tree(c::RequireOnLeft, g::Grammar, tree::AbstractRuleNode)::Bool + @warn "RequireOnLeft.check_tree not implemented!" - return filter((x) -> !(x in rules_to_remove), domain), [] + return true end diff --git a/test/test_propagators.jl b/test/test_propagators.jl index 8f3fc28..b3ab3c0 100644 --- a/test/test_propagators.jl +++ b/test/test_propagators.jl @@ -9,22 +9,29 @@ @testset "Propagating comesafter" begin constraint = ComesAfter(1, [9]) - context = GrammarContext(RuleNode(10, [Hole(get_domain(g₁, :Real)), Hole(get_domain(g₁, :Real))]), [1], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [Hole(get_domain(g₁, :Real)), Hole(get_domain(g₁, :Real))]), [1], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == Vector(2:9) end @testset "Propagating require on left" begin constraint = RequireOnLeft([2, 1]) - context = GrammarContext(RuleNode(10, [RuleNode(3), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(3), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == Vector(2:9) end + @testset "Propagating require on left 2" begin + constraint = RequireOnLeft([2, 1]) + context = GrammarContext(RuleNode(10, [RuleNode(2), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) + @test domain == Vector(1:9) + end + @testset "Propagating forbidden path" begin constraint = ForbiddenPath([10, 1]) - context = GrammarContext(RuleNode(10, [RuleNode(3), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(3), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == Vector(2:9) end @@ -33,8 +40,8 @@ [], MatchNode(10, [MatchNode(1), MatchNode(1)]) ) - context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == Vector(2:9) end @@ -43,8 +50,8 @@ [], MatchNode(10, [MatchNode(1), MatchVar(:x)]) ) - context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == [] end @@ -53,12 +60,12 @@ [], MatchNode(10, [MatchVar(:x), MatchVar(:x)]) ) - context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(1), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == Vector(2:9) - context = GrammarContext(RuleNode(10, [RuleNode(5), Hole(get_domain(g₁, :Real))]), [2], []) - domain, _ = propagate(constraint, g₁, context, Vector(1:9)) + context = GrammarContext(RuleNode(10, [RuleNode(5), Hole(get_domain(g₁, :Real))]), [2], Set{Int}()) + domain, _ = propagate(constraint, g₁, context, Vector(1:9), nothing) @test domain == append!(Vector(1:4), Vector(6:9)) end @@ -72,9 +79,9 @@ MatchNode(10, [MatchVar(:x), MatchVar(:x)]) ) expr = RuleNode(10, [RuleNode(10, [RuleNode(2), RuleNode(1)]), RuleNode(10, [RuleNode(2), Hole(Herb.HerbGrammar.get_domain(g₁, :Real))])]) - context = GrammarContext(expr, [2, 2], []) - domain, _ = propagate(constraint₁, g₁, context, [1,2,3]) - domain, _ = propagate(constraint₂, g₁, context, domain) + context = GrammarContext(expr, [2, 2], Set{Int}()) + domain, _ = propagate(constraint₁, g₁, context, [1,2,3], nothing) + domain, _ = propagate(constraint₂, g₁, context, domain, nothing) @test domain == [3] end @@ -86,8 +93,8 @@ ) expr = RuleNode(10, [RuleNode(8), Hole(get_domain(g₁, :Real))]) - context = GrammarContext(expr, [2], []) - domain, _ = propagate(constraint₁, g₁, context, collect(1:9)) + context = GrammarContext(expr, [2], Set{Int}()) + domain, _ = propagate(constraint₁, g₁, context, collect(1:9), nothing) @test domain == [8, 9] end @@ -102,7 +109,7 @@ Hole(get_domain(g₁, :Real)) ]) ]) - context = GrammarContext(expr, [2, 2], []) + context = GrammarContext(expr, [2, 2], Set{Int}()) constraint = LocalOrdered( [], @@ -110,7 +117,7 @@ [:x₂, :x₁] ) - domain, _ = propagate(constraint, g₁, context, [1,2,3]) + domain, _ = propagate(constraint, g₁, context, [1,2,3], nothing) @test domain == [1] end