Skip to content

Commit

Permalink
📦 NEW: Add self play data gen and new state network
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti authored Jan 21, 2024
1 parent 71950f3 commit 06d7955
Show file tree
Hide file tree
Showing 34 changed files with 116,286 additions and 73,871 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
- run: cargo rustc --release -- -C target-feature=+crt-static -C target-cpu=$TARGET_CPU
- run: cargo rustc --release --bin princhess -- -C target-feature=+crt-static -C target-cpu=$TARGET_CPU
env:
TARGET_CPU: ${{ matrix.cpu }}
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: Set version
run: sed -i "s/^version = .*/version = \"$(git describe --tags --dirty --always)\"/" Cargo.toml
shell: bash
- run: cargo rustc --release -- -C target-feature=+crt-static -C target-cpu=$TARGET_CPU
- run: cargo rustc --release --bin princhess -- -C target-feature=+crt-static -C target-cpu=$TARGET_CPU
env:
TARGET_CPU: ${{ matrix.cpu }}
- run: ls target/release
Expand Down
42 changes: 7 additions & 35 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,38 +1,10 @@
.DS_Store
**/*.rs.bk
*.h5
*.libsvm
*.log
*.pb
*.sqlite
*_features.txt
*debug.txt
.env
.token
/target/
bin/princhess
builds/
engine/*
expanded_feature_list.txt
feature_whitelist.txt
lichess*.pgn
old-versions/*
out/
pgn/
policy_key.txt
run_goose.sh
syzygy
train/*_data*/*
train/*model
train/logs
train/.env
train/__pycache__
train/pgn/*
train/pgn.all
train_data*
train/*_weights
train/*_bias
train/current
train/state_tuning
train/tuning_plots
train/tuning_*/
/target
/builds
/data
/nets
/pgn
/syzygy
/train
34 changes: 34 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ arc-swap = "=1.6.0"
arrayvec = "=0.7.4"
dashmap = "=5.5.3"
fastapprox = "=0.3.1"
goober = { git = "https://github.com/jw1912/goober/", rev = "30ded2d" }
nohash-hasher = "=0.2.0"
memmap = "=0.7.0"
once_cell = "=1.19.0"
Expand Down
270 changes: 270 additions & 0 deletions src/bin/data-gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
use princhess::chess::{Board, Move};
use princhess::math::Rng;
use princhess::options::set_hash_size_mb;
use princhess::search::Search;
use princhess::state::State;
use princhess::tablebase::{self, Wdl};
use princhess::train::TrainingPosition;
use princhess::transposition_table::LRTable;

use std::fmt::{Display, Formatter};
use std::fs::File;
use std::io::{self, BufWriter, Write};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

const THREADS: usize = 6;
const DATA_WRITE_RATE: usize = 16384;
const PLAYOUTS_PER_MOVE: usize = 2000;
const DFRC_PCT: u64 = 10;

struct Stats {
start: Instant,
games: AtomicU64,
positions: AtomicU64,
skipped: AtomicU64,
white_wins: AtomicU64,
black_wins: AtomicU64,
draws: AtomicU64,
blunders: AtomicU64,
}

impl Stats {
pub fn zero() -> Self {
Self {
start: Instant::now(),
games: AtomicU64::new(0),
positions: AtomicU64::new(0),
skipped: AtomicU64::new(0),
white_wins: AtomicU64::new(0),
black_wins: AtomicU64::new(0),
draws: AtomicU64::new(0),
blunders: AtomicU64::new(0),
}
}

pub fn inc_games(&self) {
self.games
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_positions(&self) {
self.positions
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_skipped(&self) {
self.skipped
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_white_wins(&self) {
self.white_wins
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_black_wins(&self) {
self.black_wins
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_draws(&self) {
self.draws
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

pub fn inc_blunders(&self) {
self.blunders
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}

impl Display for Stats {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let games = self.games.load(std::sync::atomic::Ordering::Relaxed);
let white_wins = self.white_wins.load(std::sync::atomic::Ordering::Relaxed);
let draws = self.draws.load(std::sync::atomic::Ordering::Relaxed);
let black_wins = self.black_wins.load(std::sync::atomic::Ordering::Relaxed);
let positions = self.positions.load(std::sync::atomic::Ordering::Relaxed);
let skipped = self.skipped.load(std::sync::atomic::Ordering::Relaxed);
let blunders = self.blunders.load(std::sync::atomic::Ordering::Relaxed);
let seconds = self.start.elapsed().as_secs().max(1);

write!(
f,
"G: {:>7} | +{:>2} ={:>2} -{:>2} % | Bls: {:>2}% | S: {:>5.3}% | Pos: {:>5.1}m ({:>3}/g) ({:>4}/s)",
games,
white_wins * 100 / games,
draws * 100 / games,
black_wins * 100 / games,
blunders * 100 / games,
skipped as f32 / positions as f32,
positions as f32 / 1000000.0,
positions / games,
positions / seconds
)
}
}

fn run_game(stats: &Stats, positions: &mut Vec<TrainingPosition>, rng: &mut Rng) {
let startpos = if rng.next_u64() % 100 < DFRC_PCT {
Board::dfrc(rng.next_usize() % 960, rng.next_usize() % 960)
} else {
Board::startpos()
};

let mut state = State::from_board(startpos);
let mut table = LRTable::empty();

let mut game_positions = Vec::with_capacity(256);

let mut prev_moves = [Move::NONE; 4];

for _ in 0..(8 + rng.next_u64() % 2) {
let moves = state.available_moves();

if moves.is_empty() {
return;
}

let index = rng.next_usize() % moves.len();
let best_move = moves[index];

state.make_move(best_move);
}

if !state.is_available_move() {
return;
}

let result = loop {
let search = Search::new(state.clone(), table);

search.playout_sync(PLAYOUTS_PER_MOVE);

let best_move = search.best_move();

let legal_moves = search.root_node().hots().len();

if legal_moves > TrainingPosition::MAX_MOVES {
stats.inc_skipped();
} else {
let mut position = TrainingPosition::from(search.tree());

if position.evaluation() > 0.95 {
break 1;
} else if position.evaluation() < -0.95 {
break -1;
} else if let Some(wdl) = tablebase::probe_wdl(state.board()) {
let result = match wdl {
Wdl::Win => 1,
Wdl::Draw => 0,
Wdl::Loss => -1,
};

break state.side_to_move().fold(result, -result);
}

position.set_previous_moves(prev_moves);
game_positions.push(position);
stats.inc_positions();
}

state.make_move(best_move);

prev_moves.rotate_right(1);
prev_moves[0] = best_move;

if !state.is_available_move() {
break if state.is_check() {
// The stm has been checkmated. Convert to white relative result
state.side_to_move().fold(-1, 1)
} else {
0
};
}

if state.drawn_by_fifty_move_rule()
|| state.is_repetition()
|| state.board().is_insufficient_material()
{
break 0;
}

table = search.table();
};

let mut blunder = false;

for position in game_positions.iter_mut() {
position.set_result(result);

blunder |= match result {
1 => position.evaluation() < -0.5,
-1 => position.evaluation() > 0.5,
0 => position.evaluation().abs() > 0.75,
_ => unreachable!(),
}
}

positions.append(&mut game_positions);

if blunder {
stats.inc_blunders();
}

stats.inc_games();

if result == 1 {
stats.inc_white_wins();
} else if result == -1 {
stats.inc_black_wins();
} else {
stats.inc_draws();
}
}

fn main() {
set_hash_size_mb(128);

tablebase::set_tablebase_directory("syzygy");

let stats = Arc::new(Stats::zero());

let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();

thread::scope(|s| {
for t in 0..THREADS {
thread::sleep(Duration::from_millis(10 * t as u64));

let mut writer = BufWriter::new(
File::create(format!("data/princhess-{timestamp}-{t}.data")).unwrap(),
);

let stats = stats.clone();
let mut rng = Rng::default();

s.spawn(move || loop {
let mut positions = Vec::new();

while positions.len() < DATA_WRITE_RATE {
run_game(&stats, &mut positions, &mut rng);
}

TrainingPosition::write_batch(&mut writer, &positions).unwrap();

if t == 0 {
print!("{}\r", stats);
io::stdout().flush().unwrap();
}
});
}
});

println!();
}
Loading

0 comments on commit 06d7955

Please sign in to comment.