Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📦 NEW: Add MultiPV support #333

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::RwLock;

static NUM_THREADS: AtomicUsize = AtomicUsize::new(1);
static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16);
static MULTI_PV: AtomicUsize = AtomicUsize::new(1);

static CPUCT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(0.16));
static CPUCT_TAU: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(0.84));
Expand All @@ -31,6 +32,14 @@ pub fn get_hash_size_mb() -> usize {
max(1, HASH_SIZE_MB.load(Ordering::Relaxed))
}

/// Records the requested MultiPV count in the `MULTI_PV` atomic.
///
/// The value is stored as given, using relaxed ordering; no range check is
/// performed here. `get_multi_pv` raises anything below 1 to 1 on read.
pub fn set_multi_pv(pv: usize) {
    MULTI_PV.store(pv, Ordering::Relaxed);
}

pub fn get_multi_pv() -> usize {
max(1, MULTI_PV.load(Ordering::Relaxed))
}

pub fn set_cpuct(c: f32) {
let mut cp = CPUCT.write().unwrap();
*cp = c;
Expand Down
26 changes: 1 addition & 25 deletions src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,32 +327,8 @@ impl Search {
);
}

pub fn principal_variation(&self, num_moves: usize) -> Vec<Move> {
self.search_tree
.principal_variation(num_moves)
.into_iter()
.map(MoveEdge::get_move)
.copied()
.collect()
}

pub fn best_move(&self) -> Move {
*self.principal_variation(1).first().unwrap()
}

pub fn best_move_by_visits(&self) -> Move {
let root_node = self.search_tree.root_node();
let root_moves = root_node.hots();

let mut best = &root_moves[0];

for mov in root_moves.iter().skip(1) {
if mov.visits() > best.visits() {
best = mov;
}
}

*best.get_move()
self.search_tree.best_move()
}
}

Expand Down
127 changes: 76 additions & 51 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::arena::Error as ArenaError;
use crate::chess;
use crate::evaluation::{self, Flag};
use crate::options::{
get_cpuct, get_cpuct_tau, get_cvisits_selection, get_policy_temperature,
get_cpuct, get_cpuct_tau, get_cvisits_selection, get_multi_pv, get_policy_temperature,
get_policy_temperature_root,
};
use crate::search::{eval_in_cp, ThreadData};
Expand Down Expand Up @@ -99,6 +99,24 @@ impl PositionNode {
h.child.store(null_mut(), Ordering::SeqCst);
}
}

/// Returns the edge from `self.hots()` with the highest average reward.
///
/// An edge whose `average_reward()` is `None` is scored as `-SCALE`. The
/// comparison is strict, so when several edges share the top score the one
/// that appears first in `hots()` is returned.
///
/// # Panics
///
/// Panics if `self.hots()` is empty, because the first edge is taken by index.
pub fn select_child_by_rewards(&self) -> &MoveEdge {
    let reward_of = |edge: &MoveEdge| edge.average_reward().unwrap_or(-SCALE);

    let edges = self.hots();
    let first = &edges[0];

    // Carry (current leader, its reward) through the remaining edges; only a
    // strictly greater reward replaces the leader.
    let (winner, _) = edges
        .iter()
        .skip(1)
        .fold((first, reward_of(first)), |leader, edge| {
            let reward = reward_of(edge);
            if reward > leader.1 {
                (edge, reward)
            } else {
                leader
            }
        });

    winner
}
}

impl MoveEdge {
Expand Down Expand Up @@ -452,21 +470,12 @@ impl SearchTree {
&self.root_node
}

pub fn principal_variation(&self, num_moves: usize) -> Vec<&MoveEdge> {
let mut result = Vec::new();
let mut crnt = &self.root_node;
while !crnt.hots().is_empty() && result.len() < num_moves {
let choice = select_child_after_search(crnt.hots());
result.push(choice);
let child = choice.child.load(Ordering::SeqCst).cast_const();
if child.is_null() {
break;
}
unsafe {
crnt = &*child;
}
}
result
/// Returns a copy of the move carried by the edge that `best_edge` selects.
///
/// # Panics
///
/// Panics under the same condition as `best_edge`.
pub fn best_move(&self) -> chess::Move {
    let edge = self.best_edge();
    *edge.get_move()
}

/// Returns the first root edge in the order produced by `sort_moves`.
///
/// # Panics
///
/// Panics if the root node has no edges, since the first element of the
/// ranked list is taken by index.
pub fn best_edge(&self) -> &MoveEdge {
    let ranked = sort_moves(self.root_node.hots());
    ranked[0]
}

pub fn print_info(&self, time_management: &TimeManagement) {
Expand All @@ -475,44 +484,62 @@ impl SearchTree {
let nodes = self.num_nodes();
let depth = self.depth();
let sel_depth = self.max_depth();
let pv = self.principal_variation(depth.max(1));
let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| {
write!(out, " {}", x.get_move().to_uci()).unwrap();
out
});

let nps = if search_time_ms == 0 {
nodes
} else {
nodes * 1000 / search_time_ms as usize
};

let info_str = format!(
"info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} pv{}",
depth.max(1),
sel_depth.max(1),
nodes,
nps,
self.tb_hits(),
self.eval_in_cp(),
search_time_ms,
pv_string,
);
println!("{info_str}");
}
let moves = sort_moves(self.root_node.hots());

pub fn eval(&self) -> f32 {
self.principal_variation(1)
.first()
.map_or(0., |x| x.average_reward().unwrap_or(-SCALE) / SCALE)
for (idx, edge) in moves.iter().enumerate().take(get_multi_pv()) {
let pv = match edge.child() {
Some(child) => principal_variation(child, depth.max(1) - 1),
None => vec![],
};

let pv_string: String = pv.into_iter().fold(edge.get_move().to_uci(), |mut out, x| {
write!(out, " {}", x.get_move().to_uci()).unwrap();
out
});

let eval = eval_in_cp(edge.average_reward().unwrap_or(-SCALE) / SCALE);

let info_str = format!(
"info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} multipv {} pv {}",
depth.max(1),
sel_depth.max(1),
nodes,
nps,
self.tb_hits(),
eval,
search_time_ms,
idx + 1,
pv_string,
);
println!("{info_str}");
}
}
}

fn eval_in_cp(&self) -> String {
eval_in_cp(self.eval())
/// Collects up to `num_moves` edges by walking down the tree from `from`.
///
/// At each node the edge chosen by `select_child_by_rewards` is appended, and
/// the walk continues into that edge's child. The walk stops when the current
/// node has no edges, when `num_moves` edges have been collected, or when the
/// chosen edge has no child.
fn principal_variation(from: &PositionNode, num_moves: usize) -> Vec<&MoveEdge> {
    let mut line = Vec::with_capacity(num_moves);
    let mut node = from;

    loop {
        if node.hots().is_empty() || line.len() >= num_moves {
            break;
        }

        let edge = node.select_child_by_rewards();
        line.push(edge);

        // Descend only when the chosen edge leads to a child node.
        if let Some(next) = edge.child() {
            node = next;
        } else {
            break;
        }
    }

    line
}

fn select_child_after_search(children: &[MoveEdge]) -> &MoveEdge {
fn sort_moves(children: &[MoveEdge]) -> Vec<&MoveEdge> {
let k = get_cvisits_selection();

let reward = |child: &MoveEdge| {
Expand All @@ -527,18 +554,16 @@ fn select_child_after_search(children: &[MoveEdge]) -> &MoveEdge {
sum_rewards as f32 / visits as f32 - (k * 2. * SCALE) / (visits as f32).sqrt()
};

let mut best = &children[0];
let mut best_reward = reward(best);
let mut result = Vec::with_capacity(children.len());

for child in children.iter().skip(1) {
let reward = reward(child);
if reward > best_reward {
best = child;
best_reward = reward;
}
for child in children {
result.push((child, reward(child)));
}

best
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
result.reverse();

result.into_iter().map(|x| x.0).collect()
}

pub fn print_size_list() {
Expand Down
3 changes: 1 addition & 2 deletions src/train/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ impl From<&SearchTree> for TrainingPosition {

assert!(max_visits == Self::MAX_VISITS);

let pv = tree.principal_variation(1);
let pv = pv.first().unwrap();
let pv = tree.best_edge();
let mut evaluation = match pv.visits() {
0 => 0,
v => pv.sum_rewards() / i64::from(v),
Expand Down
4 changes: 3 additions & 1 deletion src/uci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::stdin;
use std::str::{FromStr, SplitWhitespace};

use crate::options::{
set_chess960, set_cpuct, set_cpuct_tau, set_cvisits_selection, set_hash_size_mb,
set_chess960, set_cpuct, set_cpuct_tau, set_cvisits_selection, set_hash_size_mb, set_multi_pv,
set_num_threads, set_policy_only, set_policy_temperature, set_policy_temperature_root,
};
use crate::search::Search;
Expand Down Expand Up @@ -75,6 +75,7 @@ pub fn uci() {
println!("id author {ENGINE_AUTHOR}");
println!("option name Hash type spin min 8 max 65536 default 16");
println!("option name Threads type spin min 1 max 255 default 1");
println!("option name MultiPV type spin min 1 max 255 default 1");
println!("option name SyzygyPath type string default <empty>");
println!("option name CPuct type string default 0.16");
println!("option name CPuctTau type string default 0.84");
Expand Down Expand Up @@ -143,6 +144,7 @@ impl UciOption {
self.set_option(set_hash_size_mb);
search.reset_table();
}
"multipv" => self.set_option(set_multi_pv),
"cpuct" => self.set_option(set_cpuct),
"cpucttau" => self.set_option(set_cpuct_tau),
"cvisitsselection" => self.set_option(set_cvisits_selection),
Expand Down
Loading