diff --git a/src/TreeSearch/explore.jl b/src/TreeSearch/explore.jl index 31e83b7f2..c94f48849 100644 --- a/src/TreeSearch/explore.jl +++ b/src/TreeSearch/explore.jl @@ -1,20 +1,17 @@ +"Generic implementation of the tree search algorithm for a given explore strategy." +@mustimplement "TreeSearch" tree_search(s::AbstractExploreStrategy, space, env, input) = nothing + + +################################################################################ +# Depth First Strategy +################################################################################ + """ Explore the tree search space with a depth-first strategy. The next visited node is the last one pushed in the stack of unexplored nodes. """ struct DepthFirstStrategy <: AbstractExploreStrategy end -abstract type AbstractBestFirstSearch <: AbstractExploreStrategy end - -""" -Explore the tree search space with a best-first strategy. -The next visited node is the one with the highest local dual bound. -""" -struct BestDualBoundStrategy <: AbstractBestFirstSearch end - -"Generic implementation of the tree search algorithm for a given explore strategy." -@mustimplement "TreeSearch" tree_search(s::AbstractExploreStrategy, space, env, input) = nothing - function tree_search(::DepthFirstStrategy, space, env, input) root_node = new_root(space, input) stack = Stack{typeof(root_node)}() @@ -28,6 +25,18 @@ function tree_search(::DepthFirstStrategy, space, env, input) return TreeSearch.tree_search_output(space, stack) end +################################################################################ +# Best First Strategy +################################################################################ + +abstract type AbstractBestFirstSearch <: AbstractExploreStrategy end + +""" +Explore the tree search space with a best-first strategy. +The next visited node is the one with the highest local dual bound. +""" +struct BestDualBoundStrategy <: AbstractBestFirstSearch end + function tree_search(strategy::AbstractBestFirstSearch, space, env, input) root_node = new_root(space, input) pq = PriorityQueue{typeof(root_node), Float64}() @@ -39,4 +48,43 @@ function tree_search(strategy::AbstractBestFirstSearch, space, env, input) end end return TreeSearch.tree_search_output(space, pq) +end + +################################################################################ +# Limited Discrepancy +################################################################################ + +struct LimitedDiscrepancyStrategy <: AbstractExploreStrategy + max_discrepancy::Int +end + +struct LimitedDiscrepancySpace <: AbstractSearchSpace + inner_space::AbstractSearchSpace + max_discrepancy::Int +end +struct LimitedDiscrepancyNode + inner_node::AbstractNode + discrepancy::Int +end + +new_root(space::LimitedDiscrepancySpace, input) = LimitedDiscrepancyNode(new_root(space.inner_space, input), space.max_discrepancy) +stop(space::LimitedDiscrepancySpace, nodes) = stop(space.inner_space, nodes) +tree_search_output(space::LimitedDiscrepancySpace, nodes) = tree_search_output(space.inner_space, nodes) + +function children(space::LimitedDiscrepancySpace, current::LimitedDiscrepancyNode, env, input) + lds_children = LimitedDiscrepancyNode[] + inner_children = children(space.inner_space, current.inner_node, env, input) + for (i, child) in enumerate(inner_children) + discrepancy = current.discrepancy - i + 1 + if discrepancy < 0 + break + end + pushfirst!(lds_children, LimitedDiscrepancyNode(child, discrepancy)) + end + return lds_children +end + +function tree_search(strategy::LimitedDiscrepancyStrategy, space, env, input) + space = LimitedDiscrepancySpace(space, strategy.max_discrepancy) + return tree_search(DepthFirstStrategy(), space, env, input) end \ No newline at end of file diff --git a/test/unit/Algorithm/explore.jl b/test/unit/Algorithm/explore.jl index a74469d36..b6e60bf73 100644 --- a/test/unit/Algorithm/explore.jl +++ b/test/unit/Algorithm/explore.jl @@ -66,3 +66,81 @@ function test_bfs() @test visit_order == [1, 3, 5, 7, 6, 4, 9, 8, 2, 11, 10] end register!(unit_tests, "explore", test_bfs) + +############################################################################################ +# Limited Discrepancy Explore Strategy +############################################################################################ + +struct NodeAe2 <: Coluna.TreeSearch.AbstractNode + id::Int + depth::Int + parent::Union{Nothing, NodeAe2} + + function NodeAe2(id::Int, parent::Union{Nothing, NodeAe2} = nothing) + depth = isnothing(parent) ? 0 : parent.depth + 1 + return new(id, depth, parent) + end +end + +Coluna.TreeSearch.get_root(node::NodeAe2) = isnothing(node.parent) ? node : ClA.root(node.parent) + +mutable struct CustomSearchSpaceAe2 <: Coluna.TreeSearch.AbstractSearchSpace + nb_branches::Int + max_depth::Int + nb_nodes_generated::Int + visit_order::Vector{Int} + + function CustomSearchSpaceAe2(nb_branches::Int, max_depth::Int) + return new(nb_branches, max_depth, 0, Int[]) + end +end + +function Coluna.TreeSearch.new_root(space::CustomSearchSpaceAe2, input) + space.nb_nodes_generated += 1 + return NodeAe2(1) +end + +Coluna.TreeSearch.stop(sp::CustomSearchSpaceAe2, _) = false + +function Coluna.TreeSearch.children(space::CustomSearchSpaceAe2, current, _, _) + children = NodeAe2[] + push!(space.visit_order, current.id) + if current.depth != space.max_depth + for _ in 1:space.nb_branches + space.nb_nodes_generated += 1 + node_id = space.nb_nodes_generated + child = NodeAe2(node_id, current) + push!(children, child) + end + end + return children +end + +Coluna.TreeSearch.tree_search_output(space::CustomSearchSpaceAe2, _) = space.visit_order + + +function test_lds() + # max_depth = 3, max_discrepancy = 2 + # + # 01------------------------------------------ + # 02-------------------- 12---------- 18 + # 03------ 07--- 10 13--- 16 19 + # 04 05 06 08 09 11 14 15 17 20 + + # ============================================= + + # 1* + # -------------------------------------------------------------------------------------- + # 2* 3* 4* + # -------------------------------- ---------------------------------- ------------------------------ + # 5* 6* 7* 17* 18* 19 26* 27 28 + # 8* 9* 10* 11* 12* 13 14* 15 16 20* 21* 22 23* 24 25 29* 30 31 + + + search_space = CustomSearchSpaceAe2(3, 3) + visit_order = Coluna.TreeSearch.tree_search(Coluna.TreeSearch.LimitedDiscrepancyStrategy(2), search_space, nothing, nothing) + + @test visit_order == [1, 2, 5, 8, 9, 10, 6, 11, 12, 7, 14, 3, 17, 20, 21, 18, 23, 4, 26, 29] + @test length(visit_order) == 20 +end +register!(unit_tests, "explore", test_lds)