Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phoenix komi+ #40

Open
wants to merge 5 commits into
base: phoenix-komi
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/GTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ int cfg_max_visits;
TimeManagement::enabled_t cfg_timemanage;
int cfg_lagbuffer_cs;
int cfg_resignpct;

float cfg_max_wr;
float cfg_min_wr;
float cfg_mid_wr;
float cfg_adj_playouts;

int cfg_noise;
int cfg_random_cnt;
int cfg_random_min_visits;
Expand Down Expand Up @@ -107,6 +110,7 @@ void GTP::setup_default_parameters() {
cfg_max_wr = 0.12;
cfg_min_wr = 0.05;
cfg_mid_wr = 0.10;
cfg_adj_playouts = 100;
cfg_noise = false;
cfg_random_cnt = 0;
cfg_random_min_visits = 1;
Expand Down
3 changes: 3 additions & 0 deletions src/GTP.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ extern int cfg_max_visits;
extern TimeManagement::enabled_t cfg_timemanage;
extern int cfg_lagbuffer_cs;
extern int cfg_resignpct;

extern float cfg_max_wr;
extern float cfg_min_wr;
extern float cfg_mid_wr;
extern float cfg_adj_playouts;

extern int cfg_noise;
extern int cfg_random_cnt;
extern int cfg_random_min_visits;
Expand Down
7 changes: 6 additions & 1 deletion src/Leela.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ static void parse_commandline(int argc, char *argv[]) {
"-m0 -t1 -s1.")
("max-wr", po::value<float>()->default_value(cfg_max_wr), "Maximal white winrate.")
("min-wr", po::value<float>()->default_value(cfg_min_wr), "Minimal white winrate.")
("mid-wr", po::value<float>()->default_value(cfg_mid_wr), "Ideal white winrate.")
("mid-wr", po::value<float>()->default_value(cfg_mid_wr), "Target white winrate.")
("adj-playouts", po::value<int>()->default_value(cfg_adj_playouts), "Number of playouts for komi adjustment.")
;
#ifdef USE_OPENCL
po::options_description gpu_desc("GPU options");
Expand Down Expand Up @@ -195,6 +196,10 @@ static void parse_commandline(int argc, char *argv[]) {
}
}

if (vm.count("adj-playouts")) {
cfg_adj_playouts = vm["adj-playouts"].as<int>();
}

#ifdef USE_TUNER
if (vm.count("puct")) {
cfg_puct = vm["puct"].as<float>();
Expand Down
12 changes: 9 additions & 3 deletions src/UCTNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ SMP::Mutex& UCTNode::get_mutex() {
bool UCTNode::create_children(std::atomic<int>& nodecount,
GameState& state,
float& eval,
float min_psa_ratio) {
float min_psa_ratio,
int symmetry) {
// check whether somebody beat us to it (atomic)
if (!expandable(min_psa_ratio)) {
return false;
Expand All @@ -77,8 +78,13 @@ bool UCTNode::create_children(std::atomic<int>& nodecount,
m_is_expanding = true;
lock.unlock();

const auto raw_netlist = Network::get_scored_moves(
&state, Network::Ensemble::RANDOM_SYMMETRY);
Network::Netresult raw_netlist;
if (symmetry == -1) {
raw_netlist = Network::get_scored_moves(&state, Network::Ensemble::RANDOM_SYMMETRY);
}
else {
raw_netlist = Network::get_scored_moves(&state, Network::Ensemble::DIRECT, symmetry);
}

// DCNN returns winrate as side to move
m_net_eval = raw_netlist.winrate;
Expand Down
7 changes: 5 additions & 2 deletions src/UCTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include "SMP.h"
#include "UCTNodePointer.h"

class UCTSearch;

class UCTNode {
public:
// When we visit a node, add this amount of virtual losses
Expand All @@ -45,7 +47,8 @@ class UCTNode {

bool create_children(std::atomic<int>& nodecount,
GameState& state, float& eval,
float min_psa_ratio = 0.0f);
float min_psa_ratio = 0.0f,
int symmetry = -1);

const std::vector<UCTNodePointer>& get_children() const;
void sort_children(int color);
Expand Down Expand Up @@ -77,7 +80,7 @@ class UCTNode {
void randomize_first_proportionally();
void prepare_root_node(int color,
std::atomic<int>& nodecount,
GameState& state);
GameState& state, UCTSearch * search);

UCTNode* get_first_child() const;
UCTNode* get_nopass_child(FastState& state) const;
Expand Down
81 changes: 52 additions & 29 deletions src/UCTNodeRoot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <vector>

#include "UCTNode.h"
#include "UCTSearch.h"
#include "FastBoard.h"
#include "FastState.h"
#include "KoState.h"
Expand Down Expand Up @@ -183,7 +184,8 @@ const int cfg_steps = 8;
const float target_komi = 7.5f;

float white_net_eval(GameState root_state) {
auto net_eval = Network::get_scored_moves(&root_state, Network::Ensemble::AVERAGE, 8, true).winrate;
//auto net_eval = Network::get_scored_moves(&root_state, Network::Ensemble::AVERAGE, 8, true).winrate;
auto net_eval = Network::get_scored_moves(&root_state, Network::Ensemble::DIRECT, 0, true).winrate;
if (root_state.get_to_move() == FastBoard::WHITE) {
return net_eval;
}
Expand All @@ -192,11 +194,15 @@ float white_net_eval(GameState root_state) {
}
}

void binary_search_komi(GameState& root_state, float factor, float high, float low, int steps) {
// Inverse of the logistic winrate curve wr = 1 / (1 + exp(-2 * x)).
//
// Maps a winrate in (0, 1) to an unbounded score x, with inv_wr(0.5f) == 0,
// so that winrates can be compared and offset additively.
//
// @param wr  Winrate; expected to lie in [0, 1].
// @return    -log(1 / wr - 1) / 2 evaluated on the clamped winrate; always
//            finite for non-NaN input.
float inv_wr(float wr) {
    // A saturated evaluation (exactly 0 or 1) would otherwise map to -inf/+inf,
    // and the difference of two such values is NaN; an input outside [0, 1]
    // would feed a negative number to log(). Clamp just inside (0, 1) instead.
    const float eps = 1e-6f;
    if (wr < eps) {
        wr = eps;
    } else if (wr > 1.0f - eps) {
        wr = 1.0f - eps;
    }
    return -log(1.0 / wr - 1.0) / 2.0;
}

void binary_search_komi(GameState& root_state, float shift, float high, float low, int steps) {
while (steps-- > 0) {
root_state.m_komi = (high + low) / 2.0;
auto net_eval = white_net_eval(root_state);
if (net_eval * factor > cfg_mid_wr) {
if (inv_wr(net_eval) + shift > inv_wr(cfg_mid_wr)) {
high = root_state.m_komi;
}
else {
Expand All @@ -206,49 +212,50 @@ void binary_search_komi(GameState& root_state, float factor, float high, float l
root_state.m_komi = low;
}

// Repeatedly doubles root_state.m_komi (measured from target_komi - 7.5f)
// until white's network evaluation is no longer below cfg_mid_wr, then refines
// the result with binary_search_komi over the last doubling interval.
// NOTE(review): this span is a diff view, so the pre-change lines (which scale
// the evaluation by `factor`) and the post-change lines (which add `shift` to
// inv_wr(net_eval)) both appear below.
void adjust_up_komi(GameState& root_state, float factor) {
void adjust_up_komi(GameState& root_state, float shift) {
float net_eval;
do {
// Doubling step: new_komi = 2 * old_komi - (target_komi - 7.5f).
root_state.m_komi = 2.0f * root_state.m_komi - (target_komi - 7.5f);
net_eval = white_net_eval(root_state);
} while (net_eval * factor < cfg_mid_wr);
// The `high` argument is the komi just reached; the `low` argument inverts the
// doubling step above, i.e. it is the komi of the previous iteration.
binary_search_komi(root_state, factor, root_state.m_komi, (root_state.m_komi + target_komi - 7.5f) / 2.0f, cfg_steps);
} while (inv_wr(net_eval) + shift < inv_wr(cfg_mid_wr));
binary_search_komi(root_state, shift, root_state.m_komi, (root_state.m_komi + target_komi - 7.5f) / 2.0f, cfg_steps);
}

// Resets root_state.m_komi to target_komi and checks white's network
// evaluation there. If that evaluation is still below cfg_mid_wr, the komi is
// instead chosen by binary_search_komi between the previous komi and
// target_komi; otherwise the komi is left at target_komi.
// NOTE(review): this span is a diff view, so the pre-change lines (using
// `factor`) and the post-change lines (using `shift` in inv_wr space) both
// appear below.
void adjust_down_komi(GameState& root_state, float factor) {
void adjust_down_komi(GameState& root_state, float shift) {
// Remember the komi we started from; it becomes the `high` argument of the search.
auto komi = root_state.m_komi;
root_state.m_komi = target_komi;
auto net_eval = white_net_eval(root_state);
if (net_eval * factor < cfg_mid_wr) {
binary_search_komi(root_state, factor, komi, target_komi, cfg_steps);
if (inv_wr(net_eval) + shift < inv_wr(cfg_mid_wr)) {
binary_search_komi(root_state, shift, komi, target_komi, cfg_steps);
}
}

// Decides whether root_state.m_komi has to move and delegates to
// adjust_up_komi / adjust_down_komi.
//   opp == true : compare root_eval against cfg_mid_wr in both directions.
//   opp == false: adjust up when root_eval < cfg_min_wr; adjust down when
//                 root_eval > cfg_max_wr and the komi is not already target_komi.
// NOTE(review): this span is a diff view. The pre-change version computes
// root_eval itself via white_net_eval and passes a constant 1.0f factor; the
// post-change version receives root_eval from the caller and passes `shift`,
// the gap between root_eval and the raw network evaluation in inv_wr space.
void adjust_komi(GameState& root_state, bool opp) { //, float root_eval) {
auto root_eval = white_net_eval(root_state);
void adjust_komi(GameState& root_state, float root_eval, bool opp) { //, float root_eval) {
auto net_eval = white_net_eval(root_state);
// shift added to inv_wr(net_eval) reproduces inv_wr(root_eval) at the current komi.
auto shift = inv_wr(root_eval) - inv_wr(net_eval); //0.0f; root_eval = net_eval;
// Prints root_eval, net_eval and shift on every call.
Utils::myprintf("%f, %f, %f\n", root_eval, net_eval, shift);
if (opp) {
if (root_eval < cfg_mid_wr) {
adjust_up_komi(root_state, 1.0f);
adjust_up_komi(root_state, shift);
}
else if (root_eval > cfg_mid_wr) {
adjust_down_komi(root_state, 1.0f);
adjust_down_komi(root_state, shift);
}
}
else {
if (root_eval < cfg_min_wr) {
//auto net_eval = white_net_eval(root_state);
adjust_up_komi(root_state, 1.0f); // root_eval / net_eval);
adjust_up_komi(root_state, shift);
}
// Only adjust down when the komi has previously been moved away from target_komi.
else if (root_state.m_komi != target_komi && root_eval > cfg_max_wr) {
//auto net_eval = white_net_eval(root_state);
adjust_down_komi(root_state, 1.0f); // root_eval / net_eval);
adjust_down_komi(root_state, shift);
}
}
}

void UCTNode::prepare_root_node(int color,
std::atomic<int>& nodes,
GameState& root_state) {
GameState& root_state, UCTSearch * search) {
float root_eval;
const auto had_children = has_children();
if (expandable()) {
Expand All @@ -260,17 +267,6 @@ void UCTNode::prepare_root_node(int color,
update(root_eval);
root_eval = (color == FastBoard::BLACK ? root_eval : 1.0f - root_eval);
}
auto komi = root_state.m_komi;
adjust_komi(root_state, false);
if (komi != root_state.m_komi) {
NNCache::get_NNCache().clear_cache();
m_visits = 0;
m_blackevals = 0.0;
m_min_psa_ratio_children = 2.0;
m_children.clear();
create_children(nodes, root_state, root_eval);
root_eval = (color == FastBoard::BLACK ? root_eval : 1.0f - root_eval);
}

// There are a lot of special cases where code assumes
// all children of the root are inflated, so do that.
Expand All @@ -280,15 +276,42 @@ void UCTNode::prepare_root_node(int color,
// This also removes a lot of special cases.
kill_superkos(root_state);

// if playouts not enough or if enough and winrate need to be adjusted
if (m_visits < cfg_adj_playouts) {
search->adjusting = true;
search->sym_states.reserve(cfg_adj_playouts - m_visits);
for (int i = 0; i < cfg_adj_playouts - m_visits; i++) {
auto currstate = std::make_unique<GameState>(root_state);
search->play_simulation(*currstate, this);
}
}
root_eval = get_pure_eval(color);
auto white_root_eval = (color == FastBoard::BLACK ? 1.0 - root_eval : root_eval);
//
auto komi = root_state.m_komi;
adjust_komi(root_state, white_root_eval, false);
search->sym_states.clear();
if (komi != root_state.m_komi) {
NNCache::get_NNCache().clear_cache();
m_visits = 0;
m_blackevals = 0.0;
m_min_psa_ratio_children = 2.0;
m_children.clear();

float tmp_root_eval;
create_children(nodes, root_state, tmp_root_eval);
//Utils::myprintf("tmp root eval=%f\n", tmp_root_eval);
inflate_all_children();
kill_superkos(root_state);

komi = root_state.m_opp_komi;
if (root_state.m_komi == target_komi) {
root_state.m_opp_komi = target_komi;
}
else {
GameState tmpstate = root_state;
tmpstate.play_move(get_first_child()->get_move());
adjust_komi(tmpstate, true);
adjust_komi(tmpstate, white_root_eval, true);
root_state.m_opp_komi = tmpstate.m_komi;
}
if (komi != root_state.m_opp_komi) {
Expand Down
30 changes: 25 additions & 5 deletions src/UCTSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "FullBoard.h"
#include "GTP.h"
#include "GameState.h"
#include "Random.h"
#include "TimeControl.h"
#include "Timing.h"
#include "Training.h"
Expand Down Expand Up @@ -188,6 +189,7 @@ float UCTSearch::get_min_psa_ratio() const {
return 0.0f;
}


SearchResult UCTSearch::play_simulation(GameState & currstate,
UCTNode* const node) {
const auto color = currstate.get_to_move();
Expand All @@ -203,9 +205,27 @@ SearchResult UCTSearch::play_simulation(GameState & currstate,
} else if (m_nodes < MAX_TREE_SIZE) {
float eval;
const auto had_children = node->has_children();
const auto success =
node->create_children(m_nodes, currstate, eval,
get_min_psa_ratio());
///*
bool success;
if (adjusting) {
const auto rand_sym = Random::get_Rng().randfix<8>();
success = node->create_children(m_nodes, currstate, eval, get_min_psa_ratio(), rand_sym);
if (success) {
Sym_State sym_state;
sym_state.symmetry = rand_sym;
sym_state.state = currstate;
//LOCK(get_mutex(), lock);
if (sym_states.size() < cfg_adj_playouts) {
sym_states.emplace_back(sym_state);
}
}
}
else {
success = node->create_children(m_nodes, currstate, eval, get_min_psa_ratio());
}
//*/
//const auto success = node->create_children(m_nodes, currstate, eval, get_min_psa_ratio());

if (!had_children && success) {
result = SearchResult::from_eval(eval);
}
Expand Down Expand Up @@ -683,7 +703,7 @@ int UCTSearch::think(int color, passflag_t passflag) {

// create a sorted list of legal moves (make sure we
// play something legal and decent even in time trouble)
m_root->prepare_root_node(color, m_nodes, m_rootstate);
m_root->prepare_root_node(color, m_nodes, m_rootstate, this);

m_run = true;
int cpus = cfg_num_threads;
Expand Down Expand Up @@ -762,7 +782,7 @@ void UCTSearch::ponder() {
update_root();

m_root->prepare_root_node(m_rootstate.board.get_to_move(),
m_nodes, m_rootstate);
m_nodes, m_rootstate, this);

m_run = true;
ThreadGroup tg(thread_pool);
Expand Down
7 changes: 7 additions & 0 deletions src/UCTSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ namespace TimeManagement {
};
};

// A game position paired with the network symmetry index that was used when
// the position was expanded.
struct Sym_State {
    // Symmetry index passed to the network evaluation. Defaults to -1, the
    // same "no fixed symmetry" sentinel create_children() uses, so a
    // default-constructed Sym_State never carries an indeterminate value.
    int symmetry = -1;
    // Copy of the position that was evaluated.
    GameState state;
};

class UCTSearch {
public:
/*
Expand Down Expand Up @@ -99,6 +104,8 @@ class UCTSearch {
bool is_running() const;
void increment_playouts();
SearchResult play_simulation(GameState& currstate, UCTNode* const node);
bool adjusting;
std::vector<Sym_State> sym_states;

private:
float get_min_psa_ratio() const;
Expand Down
2 changes: 1 addition & 1 deletion src/config.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
This file is part of Leela Zero.
Copyright (C) 2017-2018 Gian-Carlo Pascutto and contributors

Expand Down