Skip to content

Commit

Permalink
Merge pull request #102 from Herb-AI/dev
Browse files Browse the repository at this point in the history
v0.3
  • Loading branch information
THinnerichs authored May 15, 2024
2 parents 5b315ce + a43eb05 commit 24c3600
Show file tree
Hide file tree
Showing 34 changed files with 1,634 additions and 738 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HerbSearch"
uuid = "3008d8e8-f9aa-438a-92ed-26e9c7b4829f"
authors = ["Sebastijan Dumancic <[email protected]>", "Jaap de Jong <[email protected]>", "Nicolae Filat <[email protected]>", "Piotr Cichoń <[email protected]>", "Tilman Hinnerichs <[email protected]>"]
version = "0.2.0"
version = "0.3.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -17,10 +17,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DataStructures = "0.17,0.18"
HerbConstraints = "^0.1.0"
HerbCore = "^0.2.0"
HerbGrammar = "^0.2.0"
HerbInterpret = "^0.1.1"
HerbConstraints = "^0.2.0"
HerbCore = "^0.3.0"
HerbGrammar = "^0.3.0"
HerbInterpret = "^0.1.3"
HerbSpecification = "^0.1.0"
MLStyle = "^0.4.17"
StatsBase = "^0.34"
Expand Down
13 changes: 11 additions & 2 deletions src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ using MLStyle
include("sampling_grammar.jl")

include("program_iterator.jl")
include("count_expressions.jl")
include("uniform_iterator.jl")

include("heuristics.jl")

include("fixed_shaped_iterator.jl")
include("top_down_iterator.jl")

include("evaluate.jl")
Expand All @@ -35,8 +36,9 @@ include("genetic_functions/crossover.jl")
include("genetic_functions/select_parents.jl")
include("genetic_search_iterator.jl")

include("random_iterator.jl")

export
count_expressions,
ProgramIterator,
@programiterator,

Expand All @@ -47,12 +49,19 @@ export
heuristic_random,
heuristic_smallest_domain,

derivation_heuristic,

synth,
SynthResult,
optimal_program,
suboptimal_program,

FixedShapedIterator,
UniformIterator,
next_solution!,

TopDownIterator,
RandomIterator,
BFSIterator,
DFSIterator,
MLFSIterator,
Expand Down
20 changes: 0 additions & 20 deletions src/count_expressions.jl

This file was deleted.

107 changes: 107 additions & 0 deletions src/fixed_shaped_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
Base.@doc """
@programiterator FixedShapedIterator()
Enumerates all programs that extend from the provided fixed shaped tree.
The [Solver](@ref) is required to be in a state without any [Hole](@ref)s.
!!! warning: this iterator is used as a baseline for the constraint propagation thesis. After the thesis, this iterator can (and should) be deleted.
""" FixedShapedIterator
@programiterator FixedShapedIterator()

"""
priority_function(::FixedShapedIterator, g::AbstractGrammar, tree::AbstractRuleNode, parent_value::Union{Real, Tuple{Vararg{Real}}})
Assigns a priority value to a `tree` that needs to be considered later in the search. Trees with the lowest priority value are considered first.
"""
function priority_function(
::FixedShapedIterator,
g::AbstractGrammar,
tree::AbstractRuleNode,
parent_value::Union{Real, Tuple{Vararg{Real}}}
)
parent_value + 1;
end


"""
hole_heuristic(::FixedShapedIterator, node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
Defines a heuristic over fixed shaped holes. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function hole_heuristic(::FixedShapedIterator, node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
return heuristic_leftmost_fixed_shaped_hole(node, max_depth);
end

"""
Base.iterate(iter::FixedShapedIterator)
Describes the iteration for a given [`TopDownIterator`](@ref) over the grammar. The iteration constructs a [`PriorityQueue`](@ref) first and then prunes it propagating the active constraints. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::FixedShapedIterator)
# Priority queue with number of nodes in the program
pq :: PriorityQueue{SolverState, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

solver = iter.solver
@assert !contains_nonuniform_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with Holes"

if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), 0))
end
return _find_next_complete_tree(solver, pq, iter)
end


"""
Base.iterate(iter::FixedShapedIterator, pq::DataStructures.PriorityQueue)
Describes the iteration for a given [`TopDownIterator`](@ref) and a [`PriorityQueue`](@ref) over the grammar without enqueueing new items to the priority queue. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::FixedShapedIterator, pq::DataStructures.PriorityQueue)
return _find_next_complete_tree(iter.solver, pq, iter)
end

"""
_find_next_complete_tree(solver::Solver, pq::PriorityQueue, iter::FixedShapedIterator)::Union{Tuple{RuleNode, PriorityQueue}, Nothing}
Takes a priority queue and returns the smallest AST from the grammar it can obtain from the queue or by (repeatedly) expanding trees that are in the queue.
Returns `nothing` if there are no trees left within the depth limit.
"""
function _find_next_complete_tree(
solver::Solver,
pq::PriorityQueue,
iter::FixedShapedIterator
)::Union{Tuple{RuleNode, PriorityQueue}, Nothing}
while length(pq) 0
(state, priority_value) = dequeue_pair!(pq)
load_state!(solver, state)

hole_res = hole_heuristic(iter, get_tree(solver), typemax(Int))
if hole_res already_complete
#the tree is complete
return (get_tree(solver), pq)
elseif hole_res limit_reached
# The maximum depth is reached
continue
elseif hole_res isa HoleReference
# UniformHole was found
(; hole, path) = hole_res

rules = findall(hole.domain)
number_of_rules = length(rules)
for (i, rule_index) enumerate(findall(hole.domain))
if i < number_of_rules
state = save_state!(solver)
end
@assert isfeasible(solver) "Attempting to expand an infeasible tree: $(get_tree(solver))"
remove_all_but!(solver, path, rule_index)
if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), priority_value))
end
if i < number_of_rules
load_state!(solver, state)
end
end
end
end
return nothing
end
12 changes: 7 additions & 5 deletions src/genetic_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ Returns the best program within the population with respect to the fitness funct
function get_best_program(population::Array{RuleNode}, iter::GeneticSearchIterator)::RuleNode
best_program = nothing
best_fitness = 0
grammar = get_grammar(iter.solver)
for index eachindex(population)
chromosome = population[index]
zipped_outputs = zip([example.out for example in iter.spec], execute_on_input(iter.grammar, chromosome, [example.in for example in iter.spec]))
zipped_outputs = zip([example.out for example in iter.spec], execute_on_input(grammar, chromosome, [example.in for example in iter.spec]))
fitness_value = fitness(iter, chromosome, collect(zipped_outputs))
if isnothing(best_program)
best_fitness = fitness_value
Expand All @@ -137,13 +138,14 @@ Iterates the search space using a genetic algorithm. First generates a populatio
"""
function Base.iterate(iter::GeneticSearchIterator)
validate_iterator(iter)
grammar = iter.grammar
grammar = get_grammar(iter.solver)

population = Vector{RuleNode}(undef,iter.population_size)

start_symbol = get_starting_symbol(iter.solver)
for i in 1:iter.population_size
# sample a random nodes using start symbol and grammar
population[i] = rand(RuleNode, grammar, iter.sym, iter.maximum_initial_population_depth)
population[i] = rand(RuleNode, grammar, start_symbol, iter.maximum_initial_population_depth)
end
best_program = get_best_program(population, iter)
return (best_program, GeneticIteratorState(population))
Expand All @@ -160,7 +162,7 @@ function Base.iterate(iter::GeneticSearchIterator, current_state::GeneticIterato
current_population = current_state.population

# Calculate fitness
zipped_outputs(chromosome) = zip([example.out for example in iter.spec], execute_on_input(iter.grammar, chromosome, [example.in for example in iter.spec]))
zipped_outputs(chromosome) = zip([example.out for example in iter.spec], execute_on_input(get_grammar(iter.solver), chromosome, [example.in for example in iter.spec]))
fitness_array = [fitness(iter, chromosome, collect(zipped_outputs(chromosome))) for chromosome in current_population]

new_population = Vector{RuleNode}(undef,iter.population_size)
Expand All @@ -187,7 +189,7 @@ function Base.iterate(iter::GeneticSearchIterator, current_state::GeneticIterato
for chromosome in new_population
random_number = rand()
if random_number < iter.mutation_probability
mutate!(iter, chromosome, iter.grammar)
mutate!(iter, chromosome, get_grammar(iter.solver))
end
end

Expand Down
36 changes: 32 additions & 4 deletions src/heuristics.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
using Random

"""
heuristic_leftmost_fixed_shaped_hole(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
Defines a heuristic over [FixedShapeHole](@ref)s, where the left-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found. This is the default option for enumerators.
"""
function heuristic_leftmost_fixed_shaped_hole(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function leftmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in enumerate(node.children)
new_path = push!(copy(path), i)
hole_res = leftmost(child, max_depth-1, new_path)
if (hole_res == limit_reached) || (hole_res isa HoleReference)
return hole_res
end
end

return already_complete
end

function leftmost(hole::UniformHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end

return leftmost(node, max_depth, Vector{Int}())
end


"""
heuristic_leftmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
Defines a heuristic over holes, where the left-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found. This is the default option for enumerators.
"""
function heuristic_leftmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function leftmost(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function leftmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in enumerate(node.children)
Expand Down Expand Up @@ -35,7 +63,7 @@ end
Defines a heuristic over holes, where the right-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_rightmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function rightmost(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function rightmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in Iterators.reverse(enumerate(node.children))
Expand Down Expand Up @@ -64,7 +92,7 @@ end
Defines a heuristic over holes, where random holes get chosen randomly using random exploration. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_random(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function random(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function random(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in shuffle(collect(enumerate(node.children)))
Expand Down Expand Up @@ -92,7 +120,7 @@ end
Defines a heuristic over all available holes in the unfinished AST, by considering the size of their respective domains. A domain here describes the number of possible derivations with respect to the constraints. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_smallest_domain(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function smallest_domain(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function smallest_domain(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

smallest_size::Int = typemax(Int)
Expand Down
Loading

2 comments on commit 24c3600

@ReubenJ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/106867

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 24c3600a0237c7c2021f73518c0940f730fd5ee1
git push origin v0.3.0

Please sign in to comment.