Skip to content

Commit

Permalink
Merge pull request #11 from Herb-AI/fix-local-ordered
Browse files Browse the repository at this point in the history
Local ordered bug
  • Loading branch information
sebdumancic authored Jun 29, 2023
2 parents b1a4c58 + 17616da commit 189e7f1
Showing 1 changed file with 60 additions and 56 deletions.
116 changes: 60 additions & 56 deletions src/localconstraints/local_ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,91 +34,97 @@ function propagate(c::LocalOrdered, ::Grammar, context::GrammarContext, domain::
# Match attempt failed because we had to compare with a hole.
# If the hole would've been filled it might have succeeded, so we cannot yet remove the constraint.
return domain, [c]
else
hole_var = nothing
hole_path::Vector{Int} = []
if match isa Tuple{Symbol, Vector{Int}}
hole_var, hole_path = match
end
elseif match isa Tuple{Symbol, Vector{Int}}
hole_var, hole_path = match
@assert hole_var keys(vars)
@assert hole_var c.order

hole_index = findfirst(isequal(hole_var), c.order)

can_be_deleted = true
for var c.order[1:hole_index-1]
new_domain = make_greater_or_equal(vars[hole_var], vars[var], domain, hole_path)
new_domain softfail && continue
new_domain, can_be_deletedᵢ = make_greater_or_equal(vars[hole_var], vars[var], domain, hole_path)
if !can_be_deletedᵢ
can_be_deleted = false
end
domain = new_domain
end

for var c.order[hole_index+1:end]
new_domain = make_smaller_or_equal(vars[hole_var], vars[var], domain, hole_path)
new_domain softfail && continue
new_domain, can_be_deletedᵢ = make_smaller_or_equal(vars[hole_var], vars[var], domain, hole_path)
if !can_be_deletedᵢ
can_be_deleted = false
end
domain = new_domain
end
end

return domain, []
return domain, can_be_deleted ? [] : [c]
else
@error("Unexpected result from pattern match, not propagating constraint $c")
return domain, [c]
end
end

"""
Filters the `domain` of the hole at `hole_location` in `rn₁` to make `rn₁` be ordered before `rn₂`.
Returns the filtered domain, and a boolean indicating if this constraint can be deleted.
"""
function make_smaller_or_equal(
rn₁::RuleNode,
rn₂::RuleNode,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}

if rn₁.ind < rn₂.ind
return domain
return domain, true
elseif rn₁.ind > rn₂.ind
return Int[]
return Int[], true
else
# rn₁.ind == rn₂.ind
for (i, (c₁, c₂)) enumerate(zip(rn₁.children, rn₂.children))
if i == hole_location[1]
return make_smaller_or_equal(c₁, c₂, domain, hole_location[2:end])
else
comparison_value = _rulenode_compare(c₁, c₂)
comparison_value softfail && return softfail
comparison_value softfail && return domain, false
if comparison_value == -1 # c₁ < c₂
return domain
return domain, true
elseif comparison_value == 1 # c₁ > c₂
return Int[]
return Int[], true
end
end
end
end
end

function make_smaller_or_equal(
::Hole,
::RuleNode,
::Vector{Int},
::Vector{Int}
)::Union{Vector{Int}, MatchFail}
return softfail
end

function make_smaller_or_equal(
h::Hole,
rn::RuleNode,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}
@assert hole_location == []
return filter(x -> x rn.ind, domain)
return filter(x -> x rn.ind, domain), false
end

function make_smaller_or_equal(
::RuleNode,
::Hole,
domain::Vector{Int},
::Vector{Int}
)::Tuple{Vector{Int}, Bool}
return domain, false
end

function make_greater_or_equal(
h₁::Hole,
h₂::Hole,
function make_smaller_or_equal(
::Hole,
::Hole,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}
@assert hole_location == []
m = maximum(findall(h₂.domain))
return filter(x x m, domain)
return domain, false
end


Expand All @@ -127,59 +133,57 @@ function make_greater_or_equal(
rn₂::RuleNode,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}

if rn₁.ind > rn₂.ind
return domain
return domain, true
elseif rn₁.ind < rn₂.ind
return Int[]
return Int[], true
else
# rn₁.ind == rn₂.ind
for (i, (c₁, c₂)) enumerate(zip(rn₁.children, rn₂.children))
if i == hole_location[1]
return make_greater_or_equal(c₁, c₂, domain, hole_location[2:end])
else
comparison_value = _rulenode_compare(c₁, c₂)
comparison_value softfail && return softfail
comparison_value softfail && return domain, false
if comparison_value == 1 # c₁ > c₂
return domain
return domain, true
elseif comparison_value == -1 # c₁ < c₂
return Int[]
return Int[], true
end
end
end
end
end

function make_greater_or_equal(
::Hole,
::RuleNode,
::Vector{Int},
::Vector{Int}
)::Union{Vector{Int}, MatchFail}
return softfail
end

function make_greater_or_equal(
h::Hole,
rn::RuleNode,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}
@assert hole_location == []
return filter(x -> x rn.ind, domain), false
end

return filter(x -> x rn.ind, domain)
function make_greater_or_equal(
::RuleNode,
::Hole,
::Vector{Int},
::Vector{Int}
)::Tuple{Vector{Int}, Bool}
return domain, false
end

function make_larger_or_equal(
function make_greater_or_equal(
h₁::Hole,
h₂::Hole,
domain::Vector{Int},
hole_location::Vector{Int}
)::Union{Vector{Int}, MatchFail}
)::Tuple{Vector{Int}, Bool}
@assert hole_location == []
m = maximum(findall(h₂.domain))
return filter(x x m, domain)
return domain, false
end

"""
Expand Down

0 comments on commit 189e7f1

Please sign in to comment.