Skip to content

Commit

Permalink
Merge branch 'release/0.3.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidfrontier45 committed Jun 18, 2024
2 parents 93c0dc8 + 10bab8a commit fae7cdc
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ keywords = ["combinatorial", "optimization", "tree", "graph"]
license-file = "LICENSE"
readme = "README.md"
categories = ["algorithms"]
version = "0.2.0"
version = "0.3.0"
edition = "2021"
authors = ["Du Shiqiao <[email protected]>"]

Expand Down
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ cargo add tree_traversal
use tree_traversal::bbs::bbs;

type Node = Vec<bool>;

fn main() {
let weights = [4, 2, 6, 3, 4];
let profits = [100, 20, 2, 5, 10];
Expand Down Expand Up @@ -79,8 +78,6 @@ fn main() {
s
};

// tree traversal assumes a minimization problem
// if you want to solve maximization problem, subtract your actual score from the MAX value
let lower_bound_fn = |n: &Node| {
let current_profit = total_profit(n);
let max_remained_profit: u32 = profits[n.len()..].into_iter().sum();
Expand All @@ -90,9 +87,17 @@ fn main() {
let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n));

let leaf_check_fn = |n: &Node| n.len() == total_items;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

dbg!((best_node, cost));
Expand Down
12 changes: 10 additions & 2 deletions examples/bbs_knapsack_problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ fn main() {
let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n));

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

dbg!((best_node, cost));
Expand Down
24 changes: 21 additions & 3 deletions src/bbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct BbsReachable<N, FN, FC, C> {
successor_fn: FN,
lower_bound_fn: FC,
current_best_cost: C,
remained_ops: usize,
}

impl<N, FN, IN, FC, C> Iterator for BbsReachable<N, FN, FC, C>
Expand All @@ -22,6 +23,10 @@ where
type Item = N;

fn next(&mut self) -> Option<Self::Item> {
if self.remained_ops == 0 {
return None;
}
self.remained_ops -= 1;
if let Some(n) = self.to_see.pop() {
// get lower bound
if let Some(lb) = (self.lower_bound_fn)(&n) {
Expand Down Expand Up @@ -53,6 +58,7 @@ pub fn bbs_reach<N, FN, IN, FC, C>(
start: N,
successor_fn: FN,
lower_bound_fn: FC,
max_ops: usize,
) -> BbsReachable<N, FN, FC, C>
where
N: Clone,
Expand All @@ -66,6 +72,7 @@ where
successor_fn,
lower_bound_fn,
current_best_cost: C::max_value(),
remained_ops: max_ops,
}
}

Expand All @@ -76,6 +83,7 @@ where
/// - `lower_bound_fn` returns the lower bound of a given node do decide wheather search deeper or not
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bbs<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -84,6 +92,7 @@ pub fn bbs<N, IN, FN, FC1, FC2, C, FR>(
lower_bound_fn: FC1,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -94,7 +103,7 @@ where
C: Ord + Copy + Bounded,
FR: Fn(&N) -> bool,
{
let mut res = bbs_reach(start, successor_fn, lower_bound_fn);
let mut res = bbs_reach(start, successor_fn, lower_bound_fn, max_ops);
let mut best_leaf_node = None;
loop {
let op_n = res.next();
Expand Down Expand Up @@ -186,8 +195,17 @@ mod test {

let leaf_check_fn = |n: &Node| n.len() == total_items;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 120);
Expand Down
6 changes: 5 additions & 1 deletion src/bfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ use crate::bms::bms;
/// - `successor_fn` returns a list of successors for a given node.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bfs<N, IN, FN, FC, C, FR>(
start: N,
successor_fn: FN,
cost_fn: FC,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -34,6 +36,7 @@ where
usize::MAX,
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -99,8 +102,9 @@ mod test {
};

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) = bfs(vec![], successor_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = bfs(vec![], successor_fn, cost_fn, leaf_check_fn, max_ops).unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 6);
Expand Down
20 changes: 19 additions & 1 deletion src/bms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub struct BmsReachable<N, FN, FC, C: Ord> {
eval_fn: FC,
branch_factor: usize,
beam_width: usize,
remained_ops: usize,
pool: BinaryHeap<ScoredItem<C, N>>,
}

Expand All @@ -54,6 +55,10 @@ where
type Item = N;

fn next(&mut self) -> Option<Self::Item> {
if self.remained_ops == 0 {
return None;
}
self.remained_ops -= 1;
if self.to_see.is_empty() {
let max_iter = std::cmp::min(self.pool.len(), self.beam_width);
for _ in 0..max_iter {
Expand Down Expand Up @@ -88,6 +93,7 @@ pub fn bms_reach<N, FN, IN, FC, C>(
eval_fn: FC,
branch_factor: usize,
beam_width: usize,
max_ops: usize,
) -> BmsReachable<N, FN, FC, C>
where
N: Clone,
Expand All @@ -102,6 +108,7 @@ where
eval_fn,
branch_factor,
beam_width,
remained_ops: max_ops,
pool: BinaryHeap::new(),
}
}
Expand All @@ -115,6 +122,7 @@ where
/// - `beam_width` decides muximum number of nodes at each depth.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bms<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -125,6 +133,7 @@ pub fn bms<N, IN, FN, FC1, FC2, C, FR>(
beam_width: usize,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -135,7 +144,14 @@ where
C: Ord + Copy + Bounded,
FR: Fn(&N) -> bool,
{
let mut res = bms_reach(start, successor_fn, eval_fn, branch_factor, beam_width);
let mut res = bms_reach(
start,
successor_fn,
eval_fn,
branch_factor,
beam_width,
max_ops,
);
let mut best_leaf_node = None;
let mut current_best_cost = C::max_value();
loop {
Expand Down Expand Up @@ -319,6 +335,7 @@ mod test {

let branch_factor = 10;
let beam_width = 5;
let max_ops = usize::MAX;
let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start));
let leaf_check_fn = |n: &Node| n.is_leaf();

Expand All @@ -330,6 +347,7 @@ mod test {
beam_width,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();

Expand Down
6 changes: 5 additions & 1 deletion src/dfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ use crate::bbs::bbs;
/// - `successor_fn` returns a list of successors for a given node.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn dfs<N, IN, FN, FC, C, FR>(
start: N,
successor_fn: FN,
cost_fn: FC,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -32,6 +34,7 @@ where
|_| Some(C::min_value()),
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -97,8 +100,9 @@ mod test {
};

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) = dfs(vec![], successor_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = dfs(vec![], successor_fn, cost_fn, leaf_check_fn, max_ops).unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 6);
Expand Down
16 changes: 14 additions & 2 deletions src/gds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::bms::bms;
/// - `eval_fn` returns the approximated cost of a given node to sort and select k-best
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn gds<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -19,6 +20,7 @@ pub fn gds<N, IN, FN, FC1, FC2, C, FR>(
eval_fn: FC1,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -37,6 +39,7 @@ where
1,
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -173,8 +176,17 @@ mod test {
let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start));
let leaf_check_fn = |n: &Node| n.is_leaf();

let (cost, best_node) =
gds(root_node, successor_fn, eval_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = gds(
root_node,
successor_fn,
eval_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();

assert!(cost < 9000);
let mut visited_cities = best_node.parents.clone();
Expand Down

0 comments on commit fae7cdc

Please sign in to comment.