diff --git a/src/localconstraints/local_ordered.jl b/src/localconstraints/local_ordered.jl index c3c3db4..880271b 100644 --- a/src/localconstraints/local_ordered.jl +++ b/src/localconstraints/local_ordered.jl @@ -34,44 +34,51 @@ function propagate(c::LocalOrdered, ::Grammar, context::GrammarContext, domain:: # Match attempt failed because we had to compare with a hole. # If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint. return domain, [c] - else - hole_var = nothing - hole_path::Vector{Int} = [] - if match isa Tuple{Symbol, Vector{Int}} - hole_var, hole_path = match - end + elseif match isa Tuple{Symbol, Vector{Int}} + hole_var, hole_path = match @assert hole_var ∈ keys(vars) @assert hole_var ∈ c.order hole_index = findfirst(isequal(hole_var), c.order) - + can_be_deleted = true for var ∈ c.order[1:hole_index-1] - new_domain = make_greater_or_equal(vars[hole_var], vars[var], domain, hole_path) - new_domain ≡ softfail && continue + new_domain, can_be_deletedᵢ = make_greater_or_equal(vars[hole_var], vars[var], domain, hole_path) + if !can_be_deletedᵢ + can_be_deleted = false + end domain = new_domain end for var ∈ c.order[hole_index+1:end] - new_domain = make_smaller_or_equal(vars[hole_var], vars[var], domain, hole_path) - new_domain ≡ softfail && continue + new_domain, can_be_deletedᵢ = make_smaller_or_equal(vars[hole_var], vars[var], domain, hole_path) + if !can_be_deletedᵢ + can_be_deleted = false + end domain = new_domain end - end - return domain, [] + return domain, can_be_deleted ? [] : [c] + else + @error("Unexpected result from pattern match, not propagating constraint $c") + return domain, [c] + end end +""" +Filters the `domain` of the hole at `hole_location` in `rn₁` to make `rn₁` be ordered before `rn₂`. +Returns the filtered domain, and a boolean indicating if this constraint can be deleted. +""" function make_smaller_or_equal( rn₁::RuleNode, rn₂::RuleNode, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} if rn₁.ind < rn₂.ind - return domain + return domain, true elseif rn₁.ind > rn₂.ind - return Int[] + return Int[], true else # rn₁.ind == rn₂.ind for (i, (c₁, c₂)) ∈ enumerate(zip(rn₁.children, rn₂.children)) @@ -79,46 +86,45 @@ function make_smaller_or_equal( return make_smaller_or_equal(c₁, c₂, domain, hole_location[2:end]) else comparison_value = _rulenode_compare(c₁, c₂) - comparison_value ≡ softfail && return softfail + comparison_value ≡ softfail && return domain, false if comparison_value == -1 # c₁ < c₂ - return domain + return domain, true elseif comparison_value == 1 # c₁ > c₂ - return Int[] + return Int[], true end end end end end -function make_smaller_or_equal( - ::Hole, - ::RuleNode, - ::Vector{Int}, - ::Vector{Int} -)::Union{Vector{Int}, MatchFail} - return softfail -end function make_smaller_or_equal( h::Hole, rn::RuleNode, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} @assert hole_location == [] - return filter(x -> x ≤ rn.ind, domain) + return filter(x -> x ≤ rn.ind, domain), false end +function make_smaller_or_equal( + ::RuleNode, + ::Hole, + domain::Vector{Int}, + ::Vector{Int} +)::Tuple{Vector{Int}, Bool} + return domain, false +end -function make_greater_or_equal( - h₁::Hole, - h₂::Hole, +function make_smaller_or_equal( + ::Hole, + ::Hole, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} @assert hole_location == [] - m = maximum(findall(h₂.domain)) - return filter(x → x ≤ m, domain) + return domain, false end @@ -127,12 +133,12 @@ function make_greater_or_equal( rn₂::RuleNode, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} if rn₁.ind > rn₂.ind - return domain + return domain, true elseif rn₁.ind < rn₂.ind - return Int[] + return Int[], true else # rn₁.ind == rn₂.ind for (i, (c₁, c₂)) ∈ enumerate(zip(rn₁.children, rn₂.children)) @@ -140,46 +146,44 @@ function make_greater_or_equal( return make_greater_or_equal(c₁, c₂, domain, hole_location[2:end]) else comparison_value = _rulenode_compare(c₁, c₂) - comparison_value ≡ softfail && return softfail + comparison_value ≡ softfail && return domain, false if comparison_value == 1 # c₁ > c₂ - return domain + return domain, true elseif comparison_value == -1 # c₁ < c₂ - return Int[] + return Int[], true end end end end end -function make_greater_or_equal( - ::Hole, - ::RuleNode, - ::Vector{Int}, - ::Vector{Int} -)::Union{Vector{Int}, MatchFail} - return softfail -end - function make_greater_or_equal( h::Hole, rn::RuleNode, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} @assert hole_location == [] + return filter(x -> x ≥ rn.ind, domain), false +end - return filter(x -> x ≥ rn.ind, domain) +function make_greater_or_equal( + ::RuleNode, + ::Hole, + ::Vector{Int}, + ::Vector{Int} +)::Tuple{Vector{Int}, Bool} + return domain, false end -function make_larger_or_equal( +function make_greater_or_equal( h₁::Hole, h₂::Hole, domain::Vector{Int}, hole_location::Vector{Int} -)::Union{Vector{Int}, MatchFail} +)::Tuple{Vector{Int}, Bool} @assert hole_location == [] - m = maximum(findall(h₂.domain)) - return filter(x → x ≤ m, domain) + return domain, false end """