diff --git a/Cargo.toml b/Cargo.toml index bb76cb1..2c1bd9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ path = "src/main.rs" [dependencies] arrayvec = "=0.7.4" -bytemuck = { version = "=1.17.1", features = ["derive", "extern_crate_alloc", "min_const_generics"] } +bytemuck = { version = "=1.17.1", features = ["derive", "extern_crate_alloc", "min_const_generics", "must_cast"] } fastapprox = "=0.3.1" goober = { git = "https://github.com/jw1912/goober/", rev = "32b9b52" } nohash-hasher = "=0.2.0" diff --git a/bin/first_moves.epd b/bin/first_moves.epd deleted file mode 100644 index d03ba32..0000000 --- a/bin/first_moves.epd +++ /dev/null @@ -1,10 +0,0 @@ -rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2 -rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2 -rnbqkbnr/pppp1ppp/4p3/8/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2 -rnbqkbnr/pp1ppppp/2p5/8/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2 -rnbqkbnr/ppp1pppp/3p4/8/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2 -rnbqkb1r/pppppppp/5n2/8/3P4/8/PPP1PPPP/RNBQKBNR w KQkq - 1 2 -rnbqkbnr/ppp1pppp/8/3p4/3P4/8/PPP1PPPP/RNBQKBNR w KQkq - 0 2 -rnbqkb1r/pppppppp/5n2/8/8/5N2/PPPPPPPP/RNBQKB1R w KQkq - 2 2 -rnbqkbnr/ppp1pppp/8/3p4/8/5N2/PPPPPPPP/RNBQKB1R w KQkq - 0 2 -rnbqkb1r/pppppppp/5n2/8/2P5/8/PP1PPPPP/RNBQKBNR w KQkq - 1 2 diff --git a/bin/gm2001.bin b/bin/gm2001.bin deleted file mode 100644 index d1626aa..0000000 Binary files a/bin/gm2001.bin and /dev/null differ diff --git a/bin/tuning_config.json b/bin/tuning_config.json index 1fdc1c6..6d23fb6 100644 --- a/bin/tuning_config.json +++ b/bin/tuning_config.json @@ -5,6 +5,7 @@ "fixed_parameters": { "Threads": 1, "Hash": 128, + "PolicyTemperatureRoot": 100, "SyzygyPath": "/syzygy" } }, @@ -18,9 +19,8 @@ } ], "parameter_ranges": { - "CPuct": "Real(0.0, 3.0)", - "CPuctTau": "Real(0.5, 1.0)", - "PolicyTemperatureRoot": "Real(1.0, 20.0)" + "CPuct": "Integer(1, 100)", + "CPuctTau": "Integer(50, 100)" }, "rounds": 15, "engine1_npm": "25000", diff --git a/src/bin/data-gen.rs b/src/bin/data-gen.rs index aacab60..93c4910 100644 --- a/src/bin/data-gen.rs +++ b/src/bin/data-gen.rs @@ -1,41 +1,67 @@ use princhess::chess::{Board, Move}; +use princhess::evaluation; use princhess::math::Rng; -use princhess::options::SearchOptions; +use princhess::options::{MctsOptions, SearchOptions}; use princhess::search::Search; use princhess::state::State; use princhess::tablebase::{self, Wdl}; use princhess::train::TrainingPosition; use princhess::transposition_table::LRTable; +use bytemuck::allocation; +use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::fs::File; use std::io::{self, BufWriter, Write}; -use std::sync::atomic::{AtomicU64, AtomicUsize}; +use std::ops::Neg; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; const HASH_SIZE_MB: usize = 128; -const THREADS: usize = 6; -const DATA_WRITE_RATE: usize = 16384; +const VISITS_PER_POSITION: u64 = 3000; +const THREADS: u64 = 5; const DFRC_PCT: u64 = 10; +const MAX_POSITIONS_PER_FILE: u64 = 20_000_000; +const MAX_POSITIONS_TOTAL: u64 = MAX_POSITIONS_PER_FILE * THREADS; + +const CPUCT: f32 = 2.82; +const CPUCT_TAU: f32 = 0.5; +const POLICY_TEMPERATURE: f32 = 1.0; +const POLICY_TEMPERATURE_ROOT: f32 = 1.1; +const MAX_VARIATIONS: usize = 16; + struct Stats { start: Instant, games: AtomicU64, positions: AtomicU64, skipped: AtomicU64, + aborted: AtomicU64, white_wins: AtomicU64, black_wins: AtomicU64, draws: AtomicU64, blunders: AtomicU64, + variations: AtomicUsize, nodes: AtomicUsize, playouts: AtomicUsize, depth: AtomicUsize, seldepth: AtomicUsize, } +struct GameStats { + pub positions: u64, + pub skipped: u64, + pub blunders: u64, + pub variations: usize, + pub nodes: usize, + pub playouts: usize, + pub depth: usize, + pub seldepth: usize, +} + impl Stats { pub fn zero() -> Self { Self { @@ -43,10 +69,12 @@ impl Stats { games: AtomicU64::new(0), positions: AtomicU64::new(0), skipped: AtomicU64::new(0), + aborted: AtomicU64::new(0), white_wins: AtomicU64::new(0), black_wins: AtomicU64::new(0), draws: AtomicU64::new(0), blunders: AtomicU64::new(0), + variations: AtomicUsize::new(0), nodes: AtomicUsize::new(0), playouts: AtomicUsize::new(0), depth: AtomicUsize::new(0), @@ -54,241 +82,308 @@ impl Stats { } } - pub fn inc_games(&self) { - self.games - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_positions(&self) { - self.positions - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_skipped(&self) { - self.skipped - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_white_wins(&self) { - self.white_wins - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_black_wins(&self) { - self.black_wins - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_draws(&self) { - self.draws - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - - pub fn inc_blunders(&self) { - self.blunders - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + fn add_game(&self, game: &GameStats) { + self.games.fetch_add(1, Ordering::Relaxed); + self.positions.fetch_add(game.positions, Ordering::Relaxed); + self.skipped.fetch_add(game.skipped, Ordering::Relaxed); + self.blunders.fetch_add(game.blunders, Ordering::Relaxed); + self.variations + .fetch_add(game.variations, Ordering::Relaxed); + self.nodes.fetch_add(game.nodes, Ordering::Relaxed); + self.playouts.fetch_add(game.playouts, Ordering::Relaxed); + self.depth.fetch_add(game.depth, Ordering::Relaxed); + self.seldepth.fetch_add(game.seldepth, Ordering::Relaxed); } - pub fn plus_nodes(&self, nodes: usize) { - self.nodes - .fetch_add(nodes, std::sync::atomic::Ordering::Relaxed); + pub fn add_white_win(&self, game: &GameStats) { + self.add_game(game); + self.white_wins.fetch_add(1, Ordering::Relaxed); } - pub fn plus_playouts(&self, playouts: usize) { - self.playouts - .fetch_add(playouts, std::sync::atomic::Ordering::Relaxed); + pub fn add_black_win(&self, game: &GameStats) { + self.add_game(game); + self.black_wins.fetch_add(1, Ordering::Relaxed); } - pub fn plus_depth(&self, depth: usize) { - self.depth - .fetch_add(depth, std::sync::atomic::Ordering::Relaxed); + pub fn add_draw(&self, game: &GameStats) { + self.add_game(game); + self.draws.fetch_add(1, Ordering::Relaxed); } - pub fn plus_seldepth(&self, seldepth: usize) { - self.seldepth - .fetch_add(seldepth, std::sync::atomic::Ordering::Relaxed); + pub fn add_aborted(&self) { + self.aborted.fetch_add(1, Ordering::Relaxed); } } impl Display for Stats { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let games = self.games.load(std::sync::atomic::Ordering::Relaxed); - let white_wins = self.white_wins.load(std::sync::atomic::Ordering::Relaxed); - let draws = self.draws.load(std::sync::atomic::Ordering::Relaxed); - let black_wins = self.black_wins.load(std::sync::atomic::Ordering::Relaxed); - let positions = self.positions.load(std::sync::atomic::Ordering::Relaxed); - let skipped = self.skipped.load(std::sync::atomic::Ordering::Relaxed); - let blunders = self.blunders.load(std::sync::atomic::Ordering::Relaxed); - let nodes = self.nodes.load(std::sync::atomic::Ordering::Relaxed); - let playouts = self.playouts.load(std::sync::atomic::Ordering::Relaxed); - let depth = self.depth.load(std::sync::atomic::Ordering::Relaxed); - let seldepth = self.seldepth.load(std::sync::atomic::Ordering::Relaxed); + let games = self.games.load(Ordering::Relaxed); + let white_wins = self.white_wins.load(Ordering::Relaxed); + let draws = self.draws.load(Ordering::Relaxed); + let black_wins = self.black_wins.load(Ordering::Relaxed); + let positions = self.positions.load(Ordering::Relaxed); + let skipped = self.skipped.load(Ordering::Relaxed); + let aborted = self.aborted.load(Ordering::Relaxed); + let blunders = self.blunders.load(Ordering::Relaxed); + let variations = self.variations.load(Ordering::Relaxed); + let nodes = self.nodes.load(Ordering::Relaxed); + let playouts = self.playouts.load(Ordering::Relaxed); + let depth = self.depth.load(Ordering::Relaxed); + let seldepth = self.seldepth.load(Ordering::Relaxed); let seconds = self.start.elapsed().as_secs().max(1); write!( f, - "G {:>7} | +{:>2} ={:>2} -{:>2} % | B {:>3.1}% | S {:>3.1}% | N {:>5} P {:>5} | D {:>2}/{:>2} | P {:>5.1}m ({:>3}/g, {:>4}/s, {:>3.1}m/h)", + "G {:>7} | +{:>2}={:>2}-{:>2} | B {:>4.1} V {:>4.1} S {:>3.1} X {:>4.1} | N {:>4} P {:>4} | D {:>2}/{:>2} | P {:>5.1}m ({:>2}/g, {:>3.1}m/h)", games, white_wins * 100 / games, draws * 100 / games, black_wins * 100 / games, blunders as f32 * 100. / games as f32, + variations as f32 * 100. / positions as f32, skipped as f32 * 100. / positions as f32, + aborted as f32 * 100. / games as f32, nodes / positions as usize, playouts / positions as usize, depth / positions as usize, seldepth / positions as usize, positions as f32 / 1000000.0, positions / games, - positions / seconds, (positions * 3600 / seconds) as f32 / 1000000.0 ) } } +impl GameStats { + pub fn zero() -> Self { + Self { + positions: 0, + skipped: 0, + blunders: 0, + variations: 0, + nodes: 0, + playouts: 0, + depth: 0, + seldepth: 0, + } + } +} + +#[derive(Clone, Copy, Debug)] +enum GameResult { + WhiteWin, + Draw, + BlackWin, + Aborted, +} + +impl From for i8 { + fn from(result: GameResult) -> Self { + match result { + GameResult::WhiteWin => 1, + GameResult::Draw => 0, + GameResult::BlackWin => -1, + GameResult::Aborted => unreachable!(), + } + } +} + +impl Neg for GameResult { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + GameResult::WhiteWin => GameResult::BlackWin, + GameResult::BlackWin => GameResult::WhiteWin, + GameResult::Draw => GameResult::Draw, + GameResult::Aborted => GameResult::Aborted, + } + } +} + fn run_game(stats: &Stats, positions: &mut Vec, rng: &mut Rng) { - let startpos = if rng.next_u64() % 100 < DFRC_PCT { - Board::dfrc(rng.next_usize() % 960, rng.next_usize() % 960) - } else { - Board::startpos() - }; + let mut variations = Vec::with_capacity(MAX_VARIATIONS + 1); + let mut variations_count = 0; + let mut seen_positions = HashSet::new(); - let mut state = State::from_board(startpos); - let mut table = LRTable::empty(HASH_SIZE_MB); + variations.push(random_start(rng)); - let search_options = SearchOptions::default(); + let mcts_options = MctsOptions { + cpuct: CPUCT, + cpuct_tau: CPUCT_TAU, + policy_temperature: POLICY_TEMPERATURE, + policy_temperature_root: POLICY_TEMPERATURE_ROOT, + }; - let mut game_positions = Vec::with_capacity(256); + let search_options = SearchOptions { + mcts_options, + ..SearchOptions::default() + }; - let mut prev_moves = [Move::NONE; 4]; + while let Some(mut state) = variations.pop() { + let mut game_stats = GameStats::zero(); + let mut table = LRTable::empty(HASH_SIZE_MB); - for _ in 0..(8 + rng.next_u64() % 2) { - let moves = state.available_moves(); + if !state.is_available_move() { + continue; + } - if moves.is_empty() { - return; + if state.drawn_by_fifty_move_rule() + || state.is_repetition() + || state.board().is_insufficient_material() + { + continue; } - let index = rng.next_usize() % moves.len(); - let best_move = moves[index]; + if tablebase::probe_wdl(state.board()).is_some() { + continue; + } - state.make_move(best_move); - } + let mut game_positions = Vec::with_capacity(256); - if !state.is_available_move() { - return; - } + let result = loop { + let search = Search::new(state.clone(), table, search_options); + let legal_moves = search.root_node().hots().len(); - let result = loop { - let search = Search::new(state.clone(), table, search_options); - let legal_moves = search.root_node().hots().len(); - - let mut max_visits = search - .root_node() - .hots() - .iter() - .map(|hot| hot.visits()) - .max() - .unwrap_or(0); - - if legal_moves > 1 { - while max_visits < TrainingPosition::MAX_VISITS { - search.playout_sync((TrainingPosition::MAX_VISITS - max_visits) as usize); - - max_visits = search - .root_node() - .hots() - .iter() - .map(|hot| hot.visits()) - .max() - .unwrap_or(0); + if legal_moves > 1 { + let visits = search.root_node().visits(); + search.playout_sync(VISITS_PER_POSITION.saturating_sub(visits)); } - } - let best_move = search.best_move(); + let best_move = search.best_move(); - if legal_moves <= 1 || legal_moves > TrainingPosition::MAX_MOVES { - stats.inc_skipped(); - } else { - let mut position = TrainingPosition::from(search.tree()); + if variations_count < MAX_VARIATIONS { + let varation = search.most_visited_move(); - if position.evaluation() > 0.95 { - break 1; - } else if position.evaluation() < -0.95 { - break -1; + if varation != best_move && state.phase() > 18 { + game_stats.variations += 1; + let mut state = state.clone(); + state.make_move(varation); + variations.push(state); + variations_count += 1; + } } - position.set_previous_moves(prev_moves); - game_positions.push(position); - stats.inc_positions(); - stats.plus_nodes(search.tree().num_nodes()); - stats.plus_playouts(search.tree().playouts()); - stats.plus_depth(search.tree().depth()); - stats.plus_seldepth(search.tree().max_depth()); - } + if legal_moves == 1 || legal_moves > TrainingPosition::MAX_MOVES { + game_stats.skipped += 1; + } else { + let position = TrainingPosition::from(search.tree()); - state.make_move(best_move); + if position.evaluation() > 0.95 { + break GameResult::WhiteWin; + } else if position.evaluation() < -0.95 { + break GameResult::BlackWin; + } - prev_moves.rotate_right(1); - prev_moves[0] = best_move; + game_positions.push(position); - if !state.is_available_move() { - break if state.is_check() { - // The stm has been checkmated. Convert to white relative result - state.side_to_move().fold(-1, 1) - } else { - 0 - }; - } + game_stats.positions += 1; + game_stats.nodes += search.tree().num_nodes(); + game_stats.playouts += search.tree().playouts(); + game_stats.depth += search.tree().depth(); + game_stats.seldepth += search.tree().max_depth(); + } - if state.drawn_by_fifty_move_rule() - || state.is_repetition() - || state.board().is_insufficient_material() - { - break 0; + state.make_move(best_move); + + if !state.is_available_move() { + break if state.is_check() { + // The stm has been checkmated. Convert to white relative result + state + .side_to_move() + .fold(GameResult::BlackWin, GameResult::WhiteWin) + } else { + GameResult::Draw + }; + } + + if state.drawn_by_fifty_move_rule() + || state.is_repetition() + || state.board().is_insufficient_material() + { + break GameResult::Draw; + } + + if let Some(wdl) = tablebase::probe_wdl(state.board()) { + let result = match wdl { + Wdl::Win => GameResult::WhiteWin, + Wdl::Draw => GameResult::Draw, + Wdl::Loss => GameResult::BlackWin, + }; + + break state.side_to_move().fold(result, -result); + } + + if !seen_positions.insert(state.hash()) { + game_positions.clear(); + break GameResult::Aborted; + } + + table = search.table(); + }; + + let mut blunder = false; + + for position in game_positions.iter_mut() { + position.set_result(i8::from(result)); + + blunder |= match result { + GameResult::WhiteWin => position.evaluation() < -0.5, + GameResult::BlackWin => position.evaluation() > 0.5, + GameResult::Draw => position.evaluation().abs() > 0.75, + GameResult::Aborted => false, + } } - if let Some(wdl) = tablebase::probe_wdl(state.board()) { - let result = match wdl { - Wdl::Win => 1, - Wdl::Draw => 0, - Wdl::Loss => -1, - }; + positions.append(&mut game_positions); - break state.side_to_move().fold(result, -result); + if blunder { + game_stats.blunders += 1; } - table = search.table(); + match result { + GameResult::WhiteWin => stats.add_white_win(&game_stats), + GameResult::Draw => stats.add_draw(&game_stats), + GameResult::BlackWin => stats.add_black_win(&game_stats), + GameResult::Aborted => stats.add_aborted(), + } + } +} + +fn random_start(rng: &mut Rng) -> State { + let startpos = if rng.next_u64() % 100 < DFRC_PCT { + Board::dfrc(rng.next_usize() % 960, rng.next_usize() % 960) + } else { + Board::startpos() }; - let mut blunder = false; + let mut state = State::from_board(startpos); + + for p in 0..16 { + let t = 1. + ((p as f32) / 8.).powi(2); - for position in game_positions.iter_mut() { - position.set_result(result); + let best_move = select_weighted_random_move(&state, t, rng); - blunder |= match result { - 1 => position.evaluation() < -0.5, - -1 => position.evaluation() > 0.5, - 0 => position.evaluation().abs() > 0.75, - _ => unreachable!(), + if best_move == Move::NONE { + return state; } + + state.make_move(best_move); } - positions.append(&mut game_positions); + state +} + +fn select_weighted_random_move(state: &State, t: f32, rng: &mut Rng) -> Move { + let moves = state.available_moves(); - if blunder { - stats.inc_blunders(); + if moves.is_empty() { + return Move::NONE; } - stats.inc_games(); + let policy = evaluation::policy(state, &moves, t); - if result == 1 { - stats.inc_white_wins(); - } else if result == -1 { - stats.inc_black_wins(); - } else { - stats.inc_draws(); - } + moves[rng.weighted(&policy)] } fn main() { @@ -303,31 +398,48 @@ fn main() { thread::scope(|s| { for t in 0..THREADS { - thread::sleep(Duration::from_millis(10 * t as u64)); + thread::sleep(Duration::from_millis(10 * t)); let mut writer = BufWriter::new( File::create(format!("data/princhess-{timestamp}-{t}.data")).unwrap(), ); - let stats = stats.clone(); - let mut rng = Rng::default(); + { + let stats = stats.clone(); + let mut rng = Rng::default(); - s.spawn(move || loop { - let mut positions = Vec::new(); + s.spawn(move || { + let mut positions = Vec::new(); - while positions.len() < DATA_WRITE_RATE { - run_game(&stats, &mut positions, &mut rng); - } + let mut buffer: Box<[TrainingPosition; TrainingPosition::BUFFER_COUNT]> = + allocation::zeroed_box(); - TrainingPosition::write_batch(&mut writer, &positions).unwrap(); + while stats.positions.load(Ordering::Relaxed) < MAX_POSITIONS_TOTAL { + while positions.len() < TrainingPosition::BUFFER_COUNT { + run_game(&stats, &mut positions, &mut rng); + } - if t == 0 { - print!("{}\r", stats); - io::stdout().flush().unwrap(); - } - }); + buffer.copy_from_slice( + positions.drain(..TrainingPosition::BUFFER_COUNT).as_slice(), + ); + + TrainingPosition::write_buffer(&mut writer, &buffer); + } + }); + } } + + let stats = stats.clone(); + + s.spawn(move || { + while stats.positions.load(Ordering::Relaxed) < MAX_POSITIONS_TOTAL { + thread::sleep(Duration::from_secs(1)); + print!("{}\r", stats); + io::stdout().flush().unwrap(); + } + println!("\nStopping..."); + }); }); - println!(); + println!("{}", stats); } diff --git a/src/bin/data-shuffle.rs b/src/bin/data-shuffle.rs index 421af0a..c79b327 100644 --- a/src/bin/data-shuffle.rs +++ b/src/bin/data-shuffle.rs @@ -1,9 +1,10 @@ use princhess::math::Rng; use princhess::train::TrainingPosition; +use bytemuck::allocation; use std::env; use std::fs::File; -use std::io::{self, BufWriter, Write}; +use std::io::{self, BufRead, BufReader, BufWriter, Write}; use std::time::Instant; fn main() { @@ -14,12 +15,30 @@ fn main() { let mut rng = Rng::default(); - for file in files { - let mut bytes = std::fs::read(file.clone()).unwrap(); - let positions = TrainingPosition::read_batch_mut(&mut bytes); + for input in files { + let file = File::open(input.clone()).unwrap(); + let mut positions = + Vec::with_capacity(file.metadata().unwrap().len() as usize / TrainingPosition::SIZE); + + let mut buffer = BufReader::with_capacity(TrainingPosition::BUFFER_SIZE, file); let start = Instant::now(); - println!("Shuffling {} positions from {}...", positions.len(), file); + + println!("Loading {}...", input); + + while let Ok(bytes) = buffer.fill_buf() { + if bytes.is_empty() { + break; + } + + let data = TrainingPosition::read_buffer(bytes); + positions.extend_from_slice(data); + + let consumed = bytes.len(); + buffer.consume(consumed); + } + + println!("Shuffling {} positions from {}...", positions.len(), input); for i in 0..positions.len() - 1 { let j = rng.next_usize() % (positions.len() - i); @@ -35,9 +54,17 @@ fn main() { } } - println!("\nDone ({}ms).", start.elapsed().as_millis()); + println!("Saving {}.shuffled...", input); + + let mut writer = BufWriter::new(File::create(format!("{}.shuffled", input)).unwrap()); + let mut buffer: Box<[TrainingPosition; TrainingPosition::BUFFER_COUNT]> = + allocation::zeroed_box(); + + while !positions.is_empty() { + buffer.copy_from_slice(positions.drain(..TrainingPosition::BUFFER_COUNT).as_slice()); + TrainingPosition::write_buffer(&mut writer, &buffer); + } - let mut writer = BufWriter::new(File::create(format!("{}.shuffled", file)).unwrap()); - TrainingPosition::write_batch(&mut writer, positions).unwrap(); + println!("Done ({}ms).", start.elapsed().as_millis()); } } diff --git a/src/bin/data-summary.rs b/src/bin/data-summary.rs index bb0d526..e1e0c5e 100644 --- a/src/bin/data-summary.rs +++ b/src/bin/data-summary.rs @@ -1,4 +1,4 @@ -use princhess::chess::MoveIndex; +use princhess::policy::MoveIndex; use princhess::state::{self, State}; use princhess::train::TrainingPosition; @@ -15,8 +15,11 @@ fn main() { let file = File::open(path).expect("could not open file"); let records = file.metadata().unwrap().len() as usize / TrainingPosition::SIZE; - let capacity = 16 * TrainingPosition::SIZE; - let mut buffer = BufReader::with_capacity(capacity, file); + let mut buffer = BufReader::with_capacity(TrainingPosition::BUFFER_SIZE, file); + + let mut phase_win: [u64; 25] = [0; 25]; + let mut phase_draw: [u64; 25] = [0; 25]; + let mut phase_loss: [u64; 25] = [0; 25]; let mut policy_inputs: [u64; state::POLICY_NUMBER_FEATURES] = [0; state::POLICY_NUMBER_FEATURES]; @@ -25,18 +28,32 @@ fn main() { let mut count = 0; + let mut first = true; + while let Ok(buf) = buffer.fill_buf() { if buf.is_empty() { break; } - let positions = TrainingPosition::read_batch(buf); + let positions = TrainingPosition::read_buffer(buf); + + if first { + first = false; + println!("samples: {:?}", &positions[..10]); + } for position in positions.iter() { let features = position.get_policy_features(); let moves = position.moves().iter().map(|(mv, _)| *mv).collect(); let state = State::from(position); + match position.stm_relative_result() { + 1 => phase_win[state.phase()] += 1, + 0 => phase_draw[state.phase()] += 1, + -1 => phase_loss[state.phase()] += 1, + _ => (), + } + for feature in features.iter() { policy_inputs[*feature] += 1; } @@ -63,6 +80,22 @@ fn main() { println!("records: {}", records); + println!("phase:"); + for idx in 0..25 { + let (w, d, l) = (phase_win[idx], phase_draw[idx], phase_loss[idx]); + let total = w + d + l; + + println!( + "{:>2}: {:>15}/{:>5.2}% +{:>2} ={:>2} -{:>2} %", + idx, + total, + total as f32 / records as f32 * 100.0, + w * 100 / total, + d * 100 / total, + l * 100 / total + ); + } + println!("policy inputs:"); for (idx, input) in policy_inputs.iter().enumerate() { print!( diff --git a/src/bin/data-truncate.rs b/src/bin/data-truncate.rs index c0ca9ec..9ea5f22 100644 --- a/src/bin/data-truncate.rs +++ b/src/bin/data-truncate.rs @@ -5,8 +5,6 @@ use std::fs::File; use std::io::{self, BufRead, BufReader, BufWriter, Write}; use std::time::Instant; -const BUFFER_COUNT: usize = 1 << 16; - fn main() { let mut args = env::args(); args.next(); @@ -18,8 +16,7 @@ fn main() { let file = File::open(input.clone()).unwrap(); let positions = file.metadata().unwrap().len() as usize / TrainingPosition::SIZE; - let buffer_size = BUFFER_COUNT * TrainingPosition::SIZE; - let mut buffer = BufReader::with_capacity(buffer_size, file); + let mut buffer = BufReader::with_capacity(TrainingPosition::BUFFER_SIZE, file); let out_file = format!("{}.{}m.truncated", input, truncate_to); let mut writer = BufWriter::new(File::create(out_file).unwrap()); @@ -37,8 +34,8 @@ fn main() { break; } - let data = TrainingPosition::read_batch(bytes); - TrainingPosition::write_batch(&mut writer, data).unwrap(); + let data = TrainingPosition::read_buffer(bytes); + TrainingPosition::write_buffer(&mut writer, data); processed += data.len(); diff --git a/src/bin/train-policy.rs b/src/bin/train-policy.rs index 97c3f1f..383b5af 100644 --- a/src/bin/train-policy.rs +++ b/src/bin/train-policy.rs @@ -6,20 +6,19 @@ use std::thread; use std::time::{Instant, SystemTime, UNIX_EPOCH}; use princhess::math; -use princhess::policy::PolicyNetwork; +use princhess::policy::{PolicyCount, PolicyNetwork}; use princhess::state::State; use princhess::train::TrainingPosition; -const EPOCHS: usize = 10; +const TARGET_BATCH_COUNT: usize = 150_000; const BATCH_SIZE: usize = 16384; const THREADS: usize = 6; -const BUFFER_SIZE: usize = 1 << 16; const LR: f32 = 0.001; -const LR_DROP_AT: usize = EPOCHS * 2 / 3; +const LR_DROP_AT: f32 = 0.7; const LR_DROP_FACTOR: f32 = 0.5; -const _BUFFER_SIZE_CHECK: () = assert!(BUFFER_SIZE % BATCH_SIZE == 0); +const _BUFFER_SIZE_CHECK: () = assert!(TrainingPosition::BUFFER_SIZE % BATCH_SIZE == 0); fn main() { println!("Running..."); @@ -47,8 +46,12 @@ fn main() { println!("Positions: {}", count); println!("File: {}", input); - for epoch in 1..=EPOCHS { - println!("\nEpoch {}/{} (LR: {})...", epoch, EPOCHS, lr); + let epochs = TARGET_BATCH_COUNT.div_ceil(count / BATCH_SIZE); + let lr_drop_at = (epochs as f32 * LR_DROP_AT) as usize; + let net_save_period = epochs.div_ceil(10); + + for epoch in 1..=epochs { + println!("\nEpoch {}/{} (LR: {})...", epoch, epochs, lr); let start = Instant::now(); train( @@ -68,17 +71,19 @@ fn main() { (seconds % 60) ); - let dir_name = format!("nets/policy-{}-e{:03}", timestamp, epoch); + if epoch % net_save_period == 0 || epoch == epochs { + let dir_name = format!("nets/policy-{}-e{:03}", timestamp, epoch); - fs::create_dir(&dir_name).unwrap(); + fs::create_dir(&dir_name).unwrap(); - let dir = Path::new(&dir_name); + let dir = Path::new(&dir_name); - network.to_boxed_and_quantized().save_to_bin(dir); + network.to_boxed_and_quantized().save_to_bin(dir); - println!("Saved to {}", dir_name); + println!("Saved to {}", dir_name); + } - if epoch % LR_DROP_AT == 0 { + if epoch % lr_drop_at == 0 { lr *= LR_DROP_FACTOR; } } @@ -94,8 +99,7 @@ fn train( let file = File::open(input).unwrap(); let positions = file.metadata().unwrap().len() as usize / TrainingPosition::SIZE; - let buffer_size = BUFFER_SIZE * TrainingPosition::SIZE; - let mut buffer = BufReader::with_capacity(buffer_size, file); + let mut buffer = BufReader::with_capacity(TrainingPosition::BUFFER_SIZE, file); let mut running_loss = 0.0; let mut running_acc = 0.; @@ -107,18 +111,17 @@ fn train( break; } - let data = TrainingPosition::read_batch(bytes); + let data = TrainingPosition::read_buffer(bytes); for batch in data.chunks(BATCH_SIZE) { let mut gradients = PolicyNetwork::zeroed(); + let mut count = PolicyCount::zeroed(); - let (loss, acc) = gradients_batch(network, &mut gradients, batch); + let (loss, acc) = gradients_batch(network, &mut gradients, &mut count, batch); running_loss += loss; running_acc += acc; - let adj = 2.0 / batch.len() as f32; - - network.adam(&gradients, momentum, velocity, adj, lr); + network.adam(&gradients, momentum, velocity, &count, lr); batch_n += 1; print!("Batch {}/{}\r", batch_n, batches,); @@ -136,6 +139,7 @@ fn train( fn gradients_batch( network: &PolicyNetwork, gradients: &mut PolicyNetwork, + count: &mut PolicyCount, batch: &[TrainingPosition], ) -> (f32, f32) { let size = (batch.len() / THREADS) + 1; @@ -149,16 +153,28 @@ fn gradients_batch( .map(|(chunk, (loss, acc))| { s.spawn(move || { let mut inner_gradients = PolicyNetwork::zeroed(); + let mut inner_count = PolicyCount::zeroed(); + for position in chunk { - update_gradient(position, network, &mut inner_gradients, loss, acc); + update_gradient( + position, + network, + &mut inner_gradients, + &mut inner_count, + loss, + acc, + ); } - inner_gradients + (inner_gradients, inner_count) }) }) .collect::>() .into_iter() .map(|handle| handle.join().unwrap()) - .for_each(|inner_gradients| *gradients += &inner_gradients); + .for_each(|(inner_gradients, inner_count)| { + *gradients += &inner_gradients; + *count += &inner_count; + }); }); (loss.iter().sum::(), acc.iter().sum::()) @@ -168,6 +184,7 @@ fn update_gradient( position: &TrainingPosition, network: &PolicyNetwork, gradients: &mut PolicyNetwork, + count: &mut PolicyCount, loss: &mut f32, acc: &mut f32, ) { @@ -201,6 +218,7 @@ fn update_gradient( *loss -= expected * actual.ln(); network.backprop(&features, gradients, move_idx, error); + count.increment(move_idx); } if argmax(&expected) == argmax(&actual) { diff --git a/src/bin/train-value.rs b/src/bin/train-value.rs index 6a2377b..4ef6535 100644 --- a/src/bin/train-value.rs +++ b/src/bin/train-value.rs @@ -9,20 +9,19 @@ use std::path::Path; use std::thread; use std::time::{Instant, SystemTime, UNIX_EPOCH}; -const EPOCHS: usize = 20; +const TARGET_BATCH_COUNT: usize = 250_000; const BATCH_SIZE: usize = 16384; const THREADS: usize = 6; -const BUFFER_SIZE: usize = 1 << 24; const LR: f32 = 0.001; -const LR_DROP_AT: usize = EPOCHS * 2 / 3; +const LR_DROP_AT: f32 = 0.7; const LR_DROP_FACTOR: f32 = 0.1; const WEIGHT_DECAY: f32 = 0.01; -const WDL_WEIGHT: f32 = 0.5; +const WDL_WEIGHT: f32 = 0.3; -const _BUFFER_SIZE_CHECK: () = assert!(BUFFER_SIZE % BATCH_SIZE == 0); +const _BUFFER_SIZE_CHECK: () = assert!(TrainingPosition::BUFFER_SIZE % BATCH_SIZE == 0); fn main() { println!("Running..."); @@ -50,8 +49,11 @@ fn main() { println!("Positions: {}", count); println!("File: {}", input); - for epoch in 1..=EPOCHS { - println!("\nEpoch {}/{} (LR: {})...", epoch, EPOCHS, lr); + let epochs = TARGET_BATCH_COUNT.div_ceil(count / BATCH_SIZE); + let lr_drop_at = (epochs as f32 * LR_DROP_AT) as usize; + + for epoch in 1..=epochs { + println!("\nEpoch {}/{} (LR: {})...", epoch, epochs, lr); let start = Instant::now(); train( @@ -81,7 +83,7 @@ fn main() { println!("Saved to {}", dir_name); - if epoch % LR_DROP_AT == 0 { + if epoch % lr_drop_at == 0 { lr *= LR_DROP_FACTOR; } } @@ -97,8 +99,7 @@ fn train( let file = File::open(input).unwrap(); let positions = file.metadata().unwrap().len() as usize / TrainingPosition::SIZE; - let buffer_size = BUFFER_SIZE * TrainingPosition::SIZE; - let mut buffer = BufReader::with_capacity(buffer_size, file); + let mut buffer = BufReader::with_capacity(TrainingPosition::BUFFER_SIZE, file); let mut running_loss = 0.0; let mut batch_n = 0; @@ -109,7 +110,7 @@ fn train( break; } - let data = TrainingPosition::read_batch(bytes); + let data = TrainingPosition::read_buffer(bytes); for batch in data.chunks(BATCH_SIZE) { let mut gradients = ValueNetwork::zeroed(); diff --git a/src/chess/bitboard.rs b/src/chess/bitboard.rs index 09e4892..1c6491a 100644 --- a/src/chess/bitboard.rs +++ b/src/chess/bitboard.rs @@ -1,10 +1,12 @@ +use bytemuck::{Pod, Zeroable}; use std::iter::FusedIterator; use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, Not}; use crate::chess::Square; #[must_use] -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Pod, Zeroable)] +#[repr(transparent)] pub struct Bitboard(pub u64); impl Bitboard { diff --git a/src/chess/castling.rs b/src/chess/castling.rs index d63e929..4475b3b 100644 --- a/src/chess/castling.rs +++ b/src/chess/castling.rs @@ -87,15 +87,12 @@ impl Castling { } pub fn discard_color(&mut self, color: Color) { - match color { - Color::WHITE => { - self.white_king = Square::NONE; - self.white_queen = Square::NONE; - } - Color::BLACK => { - self.black_king = Square::NONE; - self.black_queen = Square::NONE; - } + if color == Color::WHITE { + self.white_king = Square::NONE; + self.white_queen = Square::NONE; + } else { + self.black_king = Square::NONE; + self.black_queen = Square::NONE; } } diff --git a/src/chess/mod.rs b/src/chess/mod.rs index 10edef9..03a2e4f 100644 --- a/src/chess/mod.rs +++ b/src/chess/mod.rs @@ -16,7 +16,6 @@ pub use crate::chess::board::Board; pub use crate::chess::castling::Castling; pub use crate::chess::color::Color; pub use crate::chess::mv::Move; -pub use crate::chess::mv::MoveIndex; pub use crate::chess::mv::MoveList; pub use crate::chess::piece::Piece; pub use crate::chess::square::File; diff --git a/src/chess/mv.rs b/src/chess/mv.rs index 6e07802..1af18a2 100644 --- a/src/chess/mv.rs +++ b/src/chess/mv.rs @@ -1,9 +1,11 @@ use arrayvec::ArrayVec; +use bytemuck::{Pod, Zeroable}; use crate::chess::{Piece, Square}; #[must_use] -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Pod, Zeroable)] +#[repr(transparent)] pub struct Move(u16); pub type MoveList = ArrayVec; @@ -91,63 +93,3 @@ impl Move { format!("{from}{to}{promotion}") } } - -#[must_use] -#[derive(Debug, Copy, Clone)] -pub struct MoveIndex { - piece: Piece, - from_sq: Square, - to_sq: Square, - from_threats: u8, - to_threats: u8, -} - -impl MoveIndex { - const FROM_BUCKETS: usize = 4; - const TO_BUCKETS: usize = 10; - - pub const FROM_COUNT: usize = 64 * Self::FROM_BUCKETS; - pub const TO_COUNT: usize = 64 * Self::TO_BUCKETS; - - const THREAT_SHIFT: u8 = 0; - const DEFEND_SHIFT: u8 = 1; - const SEE_SHIFT: u8 = 0; - - pub fn new(piece: Piece, from: Square, to: Square) -> Self { - Self { - piece, - from_sq: from, - to_sq: to, - from_threats: 0, - to_threats: 0, - } - } - - pub fn set_from_threat(&mut self, is_threat: bool) { - self.from_threats |= u8::from(is_threat) << Self::THREAT_SHIFT; - } - - pub fn set_from_defend(&mut self, is_defend: bool) { - self.from_threats |= u8::from(is_defend) << Self::DEFEND_SHIFT; - } - - pub fn set_to_good_see(&mut self, is_good_see: bool) { - self.to_threats |= u8::from(is_good_see) << Self::SEE_SHIFT; - } - - #[must_use] - pub fn from_index(&self) -> usize { - let bucket = usize::from(self.from_threats); - bucket * 64 + self.from_sq.index() - } - - #[must_use] - pub fn to_index(&self) -> usize { - let bucket = match self.piece { - Piece::KING => 0, - Piece::PAWN => 1, - p => 2 + usize::from(self.to_threats) * 4 + p.index() - 1, - }; - bucket * 64 + self.to_sq.index() - } -} diff --git a/src/lib.rs b/src/lib.rs index 99b91d2..c0b95ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod arena; mod mem; mod nets; mod search_tree; +mod subnets; mod tree_policy; pub mod chess; diff --git a/src/math.rs b/src/math.rs index 722987d..b7e1453 100644 --- a/src/math.rs +++ b/src/math.rs @@ -60,6 +60,21 @@ impl Rng { pub fn next_f32_range(&mut self, min: f32, max: f32) -> f32 { min + self.next_f32() * (max - min) } + + pub fn weighted(&mut self, weights: &[f32]) -> usize { + let r = self.next_f32(); + + let mut cumulative = 0.; + + for (i, &w) in weights.iter().enumerate() { + cumulative += w; + if r < cumulative { + return i; + } + } + + weights.len() - 1 + } } impl Default for Rng { diff --git a/src/nets.rs b/src/nets.rs index 91e27f0..610aaab 100644 --- a/src/nets.rs +++ b/src/nets.rs @@ -2,8 +2,11 @@ use bytemuck::{self, Pod, Zeroable}; use goober::activation::Activation; use std::fs; use std::io::Write; +use std::ops::AddAssign; use std::path::Path; +use crate::subnets::{QAA, QB}; + // Workaround for error in how goober handles an activation such as SCReLU #[derive(Clone, Copy)] pub struct SCReLU; @@ -25,31 +28,39 @@ impl Activation for SCReLU { #[derive(Clone, Copy, Debug, Zeroable)] #[repr(C)] -pub struct Accumulator { - pub vals: [i16; H], +pub struct Accumulator { + pub vals: [T; H], } -unsafe impl Pod for Accumulator {} +unsafe impl Pod for Accumulator {} -impl Accumulator { - pub fn set(&mut self, weights: &Accumulator) { +impl Accumulator { + pub fn set(&mut self, weights: &Accumulator) + where + T: From, + { for (i, d) in self.vals.iter_mut().zip(&weights.vals) { - *i += *d; + *i += T::from(*d); } } +} - pub fn dot_relu(&self, rhs: &Accumulator) -> i32 { - let mut result = 0; +impl Accumulator { + pub fn dot_relu(&self, rhs: &Accumulator) -> f32 { + let mut result: i32 = 0; for (a, b) in self.vals.iter().zip(&rhs.vals) { - result += relu(*a) * relu(*b); + result += relu(*a) * relu(*b / QB); } - result + result as f32 / QAA as f32 } } -pub fn relu(x: i16) -> i32 { +pub fn relu(x: F) -> i32 +where + i32: From, +{ i32::from(x).max(0) } diff --git a/src/nets/policy.bin b/src/nets/policy.bin index ab2268b..019b4c4 100644 Binary files a/src/nets/policy.bin and b/src/nets/policy.bin differ diff --git a/src/options.rs b/src/options.rs index 415ebb2..b54880b 100644 --- a/src/options.rs +++ b/src/options.rs @@ -210,6 +210,12 @@ pub struct TimeManagementOptions { pub visits_m: f32, } +impl Default for MctsOptions { + fn default() -> Self { + MctsOptions::from(&UciOptionMap::default()) + } +} + impl From<&UciOptionMap> for MctsOptions { fn from(map: &UciOptionMap) -> Self { Self { diff --git a/src/policy.rs b/src/policy.rs index 6d144ce..f83b41a 100644 --- a/src/policy.rs +++ b/src/policy.rs @@ -1,64 +1,76 @@ use bytemuck::{allocation, Pod, Zeroable}; -use goober::activation::ReLU; -use goober::layer::SparseConnected; use goober::{FeedForwardNetwork, OutputLayer, SparseVector}; use std::fmt::{self, Display}; -use std::mem; use std::ops::AddAssign; use std::path::Path; -use crate::chess::{MoveIndex, Square}; -use crate::math::{randomize_sparse, Rng}; -use crate::mem::Align16; -use crate::nets::{q_i16, save_to_bin, Accumulator}; -use crate::state; - -const INPUT_SIZE: usize = state::POLICY_NUMBER_FEATURES; -const ATTENTION_SIZE: usize = 8; - -const QA: i32 = 256; -const QAA: i32 = QA * QA; - -type Output = SparseConnected; - -type QuantizedOutputWeights = [Align16>; INPUT_SIZE]; -type QuantizedOutputBias = Align16>; +use crate::chess::{Piece, Square}; +use crate::nets::save_to_bin; +use crate::subnets::{LayerNetwork, LinearNetwork, QuantizedLayerNetwork, QuantizedLinearNetwork}; + +#[must_use] +#[derive(Debug, Copy, Clone)] +pub struct MoveIndex { + piece: Piece, + from_sq: Square, + to_sq: Square, + from_threats: u8, + to_threats: u8, +} -type RawOutputWeights = [[i16; ATTENTION_SIZE]; INPUT_SIZE]; -type RawOutputBias = [i16; ATTENTION_SIZE]; +impl MoveIndex { + const FROM_BUCKETS: usize = 4; + const TO_BUCKETS: usize = 10; -#[repr(C)] -#[derive(FeedForwardNetwork)] -pub struct FromNetwork { - output: Output, -} + pub const FROM_COUNT: usize = 64 * Self::FROM_BUCKETS; + pub const TO_COUNT: usize = 64 * Self::TO_BUCKETS; -unsafe impl Zeroable for FromNetwork {} + const THREAT_SHIFT: u8 = 0; + const DEFEND_SHIFT: u8 = 1; + const SEE_SHIFT: u8 = 0; -impl FromNetwork { - pub fn randomize(&mut self) { - let mut rng = Rng::default(); + pub fn new(piece: Piece, from: Square, to: Square) -> Self { + Self { + piece, + from_sq: from, + to_sq: to, + from_threats: 0, + to_threats: 0, + } + } - randomize_sparse(&mut self.output, &mut rng); + pub fn set_from_threat(&mut self, is_threat: bool) { + self.from_threats |= u8::from(is_threat) << Self::THREAT_SHIFT; } -} -#[repr(C)] -#[derive(FeedForwardNetwork)] -pub struct ToNetwork { - output: Output, -} + pub fn set_from_defend(&mut self, is_defend: bool) { + self.from_threats |= u8::from(is_defend) << Self::DEFEND_SHIFT; + } -unsafe impl Zeroable for ToNetwork {} + pub fn set_to_good_see(&mut self, is_good_see: bool) { + self.to_threats |= u8::from(is_good_see) << Self::SEE_SHIFT; + } -impl ToNetwork { - pub fn randomize(&mut self) { - let mut rng = Rng::default(); + #[must_use] + pub fn from_index(&self) -> usize { + let bucket = usize::from(self.from_threats); + bucket * 64 + self.from_sq.index() + } - randomize_sparse(&mut self.output, &mut rng); + #[must_use] + pub fn to_index(&self) -> usize { + let bucket = match self.piece { + Piece::KING => 0, + Piece::PAWN => 1, + p => 2 + usize::from(self.to_threats) * 4 + p.index() - 1, + }; + bucket * 64 + self.to_sq.index() } } +type FromNetwork = LinearNetwork; +type ToNetwork = LayerNetwork; + #[allow(clippy::module_name_repetitions)] #[derive(Zeroable)] pub struct PolicyNetwork { @@ -69,22 +81,20 @@ pub struct PolicyNetwork { #[repr(C)] #[derive(Copy, Clone, Pod, Zeroable)] pub struct QuantizedPolicyNetwork { - from_weights: [QuantizedOutputWeights; MoveIndex::FROM_COUNT], - from_bias: [QuantizedOutputBias; MoveIndex::FROM_COUNT], - to_weights: [QuantizedOutputWeights; MoveIndex::TO_COUNT], - to_bias: [QuantizedOutputBias; MoveIndex::TO_COUNT], + from: QuantizedLinearNetwork<{ MoveIndex::FROM_COUNT }>, + to: QuantizedLayerNetwork<{ MoveIndex::TO_COUNT }>, +} + +#[allow(clippy::module_name_repetitions)] +pub struct PolicyCount { + pub from: [u64; MoveIndex::FROM_COUNT], + pub to: [u64; MoveIndex::TO_COUNT], } impl Display for PolicyNetwork { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let from = format!( - "from: [{INPUT_SIZE}->{ATTENTION_SIZE}; {}]", - MoveIndex::FROM_COUNT - ); - let to = format!( - "to: [{INPUT_SIZE}->{ATTENTION_SIZE}; {}]", - MoveIndex::TO_COUNT - ); + let from = format!("from: [{}; {}]", self.from[0], MoveIndex::FROM_COUNT); + let to = format!("to: [{}; {}]", self.to[0], MoveIndex::TO_COUNT); write!(f, "{from} * {to}") } } @@ -150,25 +160,31 @@ impl PolicyNetwork { } } - pub fn adam(&mut self, g: &Self, m: &mut Self, v: &mut Self, adj: f32, lr: f32) { + pub fn adam(&mut self, g: &Self, m: &mut Self, v: &mut Self, count: &PolicyCount, lr: f32) { for subnet_idx in 0..self.from.len() { - self.from[subnet_idx].adam( - &g.from[subnet_idx], - &mut m.from[subnet_idx], - &mut v.from[subnet_idx], - adj, - lr, - ); + match count.from[subnet_idx] { + 0 => continue, + n => self.from[subnet_idx].adam( + &g.from[subnet_idx], + &mut m.from[subnet_idx], + &mut v.from[subnet_idx], + 1.0 / n as f32, + lr, + ), + } } for subnet_idx in 0..self.to.len() { - self.to[subnet_idx].adam( - &g.to[subnet_idx], - &mut m.to[subnet_idx], - &mut v.to[subnet_idx], - adj, - lr, - ); + match count.to[subnet_idx] { + 0 => continue, + n => self.to[subnet_idx].adam( + &g.to[subnet_idx], + &mut m.to[subnet_idx], + &mut v.to[subnet_idx], + 1.0 / n as f32, + lr, + ), + } } } @@ -196,43 +212,12 @@ impl PolicyNetwork { #[must_use] pub fn to_boxed_and_quantized(&self) -> Box { - let mut from_weights: Box<[RawOutputWeights; MoveIndex::FROM_COUNT]> = - allocation::zeroed_box(); - let mut from_bias: Box<[RawOutputBias; MoveIndex::FROM_COUNT]> = allocation::zeroed_box(); - let mut to_weights: Box<[RawOutputWeights; MoveIndex::TO_COUNT]> = allocation::zeroed_box(); - let mut to_bias: Box<[RawOutputBias; MoveIndex::TO_COUNT]> = allocation::zeroed_box(); - - for (subnet, raw) in self.from.iter().zip(from_weights.iter_mut()) { - for (row_idx, weights) in raw.iter_mut().enumerate() { - let row = subnet.output.weights_row(row_idx); - for weight_idx in 0..ATTENTION_SIZE { - weights[weight_idx] = q_i16(row[weight_idx], QA); - } - } - } + let mut result: Box = allocation::zeroed_box(); - for (subnet, raw) in self.from.iter().zip(from_bias.iter_mut()) { - for (weight_idx, bias) in raw.iter_mut().enumerate() { - *bias = q_i16(subnet.output.bias()[weight_idx], QA); - } - } + result.from = *QuantizedLinearNetwork::boxed_from(&self.from); + result.to = *QuantizedLayerNetwork::boxed_from(&self.to); - for (subnet, raw) in self.to.iter().zip(to_weights.iter_mut()) { - for (row_idx, weights) in raw.iter_mut().enumerate() { - let row = subnet.output.weights_row(row_idx); - for weight_idx in 0..ATTENTION_SIZE { - weights[weight_idx] = q_i16(row[weight_idx], QA); - } - } - } - - for (subnet, raw) in self.to.iter().zip(to_bias.iter_mut()) { - for (weight_idx, bias) in raw.iter_mut().enumerate() { - *bias = q_i16(subnet.output.bias()[weight_idx], QA); - } - } - - QuantizedPolicyNetwork::from_slices(&from_weights, &from_bias, &to_weights, &to_bias) + result } } @@ -242,74 +227,10 @@ impl QuantizedPolicyNetwork { allocation::zeroed_box() } - #[must_use] - pub fn from_slices( - from_weights: &[RawOutputWeights; MoveIndex::FROM_COUNT], - from_bias: &[RawOutputBias; MoveIndex::FROM_COUNT], - to_weights: &[RawOutputWeights; MoveIndex::TO_COUNT], - to_bias: &[RawOutputBias; MoveIndex::TO_COUNT], - ) -> Box { - let mut network = Self::zeroed(); - - network.from_weights = unsafe { - mem::transmute::< - [RawOutputWeights; MoveIndex::FROM_COUNT], - [QuantizedOutputWeights; MoveIndex::FROM_COUNT], - >(*from_weights) - }; - - network.from_bias = unsafe { - mem::transmute::< - [RawOutputBias; MoveIndex::FROM_COUNT], - [QuantizedOutputBias; MoveIndex::FROM_COUNT], - >(*from_bias) - }; - - network.to_weights = unsafe { - mem::transmute::< - [RawOutputWeights; MoveIndex::TO_COUNT], - [QuantizedOutputWeights; MoveIndex::TO_COUNT], - >(*to_weights) - }; - - network.to_bias = unsafe { - mem::transmute::< - [RawOutputBias; MoveIndex::TO_COUNT], - [QuantizedOutputBias; MoveIndex::TO_COUNT], - >(*to_bias) - }; - - network - } - pub fn save_to_bin(&self, dir: &Path) { save_to_bin(dir, "policy.bin", self); } - fn get_from_bias(&self, from_idx: usize) -> Accumulator { - unsafe { **self.from_bias.get_unchecked(from_idx) } - } - - fn get_from_weights(&self, from_idx: usize, feat_idx: usize) -> &Accumulator { - unsafe { - self.from_weights - .get_unchecked(from_idx) - .get_unchecked(feat_idx) - } - } - - fn get_to_bias(&self, to_idx: usize) -> Accumulator { - unsafe { **self.to_bias.get_unchecked(to_idx) } - } - - fn get_to_weights(&self, to_idx: usize, feat_idx: usize) -> &Accumulator { - unsafe { - self.to_weights - .get_unchecked(to_idx) - .get_unchecked(feat_idx) - } - } - pub fn get_all>( &self, features: &SparseVector, @@ -320,18 +241,44 @@ impl QuantizedPolicyNetwork { let from_idx = move_idx.from_index(); let to_idx = move_idx.to_index(); - let mut from = self.get_from_bias(from_idx); - let mut to = self.get_to_bias(to_idx); + let mut from = self.from.get_bias(from_idx); + let mut to = self.to.get_bias(to_idx); for f in features.iter() { - let from_weight = self.get_from_weights(from_idx, *f); - let to_weight = self.get_to_weights(to_idx, *f); - - from.set(from_weight); - to.set(to_weight); + self.from.set(from_idx, *f, &mut from); + self.to.set(to_idx, *f, &mut to); } - out.push(from.dot_relu(&to) as f32 / QAA as f32); + let to_out = self.to.out(to_idx, &to); + + out.push(from.dot_relu(&to_out)); + } + } +} + +impl PolicyCount { + #[must_use] + pub fn zeroed() -> Self { + Self { + from: [0; MoveIndex::FROM_COUNT], + to: [0; MoveIndex::TO_COUNT], + } + } + + pub fn increment(&mut self, move_idx: MoveIndex) { + self.from[move_idx.from_index()] += 1; + self.to[move_idx.to_index()] += 1; + } +} + +impl AddAssign<&Self> for PolicyCount { + fn add_assign(&mut self, rhs: &Self) { + for (lhs, rhs) in self.from.iter_mut().zip(&rhs.from) { + *lhs += rhs; + } + + for (lhs, rhs) in self.to.iter_mut().zip(&rhs.to) { + *lhs += rhs; } } } diff --git a/src/search.rs b/src/search.rs index af76897..9330502 100644 --- a/src/search.rs +++ b/src/search.rs @@ -304,7 +304,7 @@ impl Search { }); } - pub fn playout_sync(&self, playouts: usize) { + pub fn playout_sync(&self, playouts: u64) { let mut tld = ThreadData::create(&self.search_tree); let cpuct = self.search_options.mcts_options.cpuct; let tm = TimeManagement::infinite(); @@ -363,6 +363,14 @@ impl Search { self.search_tree.best_move() } + pub fn most_visited_move(&self) -> Move { + *self + .search_tree + .root_node() + .select_child_by_visits() + .get_move() + } + fn to_uci(&self, mov: Move) -> String { mov.to_uci(self.search_options.is_chess960) } diff --git a/src/search_tree.rs b/src/search_tree.rs index 9796fea..80153dc 100644 --- a/src/search_tree.rs +++ b/src/search_tree.rs @@ -113,6 +113,10 @@ impl PositionNode { .max_by_key(|x| x.reward().average) .unwrap() } + + pub fn select_child_by_visits(&self) -> &MoveEdge { + self.hots().iter().max_by_key(|x| x.visits()).unwrap() + } } impl MoveEdge { diff --git a/src/state.rs b/src/state.rs index 07d3db9..e0c937d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,7 @@ use arrayvec::ArrayVec; -use crate::chess::{Board, Color, File, Move, MoveIndex, MoveList, Piece, Square}; +use crate::chess::{Board, Color, File, Move, MoveList, Piece, Rank, Square}; +use crate::policy::MoveIndex; use crate::uci::Tokens; const NUMBER_KING_BUCKETS: usize = 3; @@ -90,6 +91,14 @@ impl State { self.board.is_legal_move() } + #[must_use] + pub fn phase(&self) -> usize { + let b = self.board; + + (4 * b.queens().count() + 2 * b.rooks().count() + b.bishops().count() + b.knights().count()) + .clamp(0, 24) + } + #[must_use] pub fn available_moves(&self) -> MoveList { self.board.legal_moves() @@ -253,7 +262,13 @@ impl State { let flip_from = flip_square(from_sq); let flip_to = flip_square(to_sq); - let mut mi = MoveIndex::new(piece, flip_from, flip_to); + let adj_to = match mv.promotion() { + Piece::KNIGHT => flip_to.with_rank(Rank::_1), + Piece::BISHOP | Piece::ROOK => flip_to.with_rank(Rank::_2), + _ => flip_to, + }; + + let mut mi = MoveIndex::new(piece, flip_from, adj_to); mi.set_from_threat(threats.contains(from_sq)); mi.set_from_defend(defends.contains(from_sq)); diff --git a/src/subnets.rs b/src/subnets.rs new file mode 100644 index 0000000..f0c167a --- /dev/null +++ b/src/subnets.rs @@ -0,0 +1,231 @@ +use bytemuck::{allocation, Pod, Zeroable}; +use goober::activation::ReLU; +use goober::layer::{DenseConnected, SparseConnected}; +use goober::FeedForwardNetwork; +use std::fmt::{self, Display}; + +use crate::math::{randomize_dense, randomize_sparse, Rng}; +use crate::mem::Align16; +use crate::nets::{q_i16, q_i32, relu, Accumulator}; +use crate::state; + +const INPUT_SIZE: usize = state::POLICY_NUMBER_FEATURES; +const ATTENTION_SIZE: usize = 8; + +pub const QA: i32 = 256; +pub const QB: i32 = 256; +pub const QAA: i32 = QA * QA; +pub const QAB: i32 = QA * QB; + +type Linear = SparseConnected; + +type QuantizedLinearWeights = [Align16>; INPUT_SIZE]; +type QuantizedLinearBias = Align16>; + +type RawLinearWeights = Align16<[[i16; ATTENTION_SIZE]; INPUT_SIZE]>; +type RawLinearBias = Align16<[i16; ATTENTION_SIZE]>; + +#[repr(C)] +#[derive(FeedForwardNetwork)] +pub struct LinearNetwork { + output: Linear, +} + +unsafe impl Zeroable for LinearNetwork {} + +impl LinearNetwork { + pub fn randomize(&mut self) { + let mut rng = Rng::default(); + + randomize_sparse(&mut self.output, &mut rng); + } +} + +impl Display for LinearNetwork { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{INPUT_SIZE}->{ATTENTION_SIZE}") + } +} + +#[repr(C)] +#[derive(Copy, Clone, Zeroable)] +pub struct QuantizedLinearNetwork { + weights: [QuantizedLinearWeights; N], + bias: [QuantizedLinearBias; N], +} + +unsafe impl Pod for QuantizedLinearNetwork {} + +impl QuantizedLinearNetwork { + #[must_use] + pub fn boxed_from(subnets: &[LinearNetwork; N]) -> Box { + let mut weights: Box<[RawLinearWeights; N]> = allocation::zeroed_box(); + let mut bias: Box<[RawLinearBias; N]> = allocation::zeroed_box(); + + for (subnet, raw) in subnets.iter().zip(weights.iter_mut()) { + for (row_idx, weights) in raw.iter_mut().enumerate() { + let row = subnet.output.weights_row(row_idx); + for weight_idx in 0..ATTENTION_SIZE { + weights[weight_idx] = q_i16(row[weight_idx], QA); + } + } + } + + for (subnet, raw) in subnets.iter().zip(bias.iter_mut()) { + for (weight_idx, bias) in raw.iter_mut().enumerate() { + *bias = q_i16(subnet.output.bias()[weight_idx], QA); + } + } + + let mut result: Box = allocation::zeroed_box(); + + result.weights = *bytemuck::must_cast_ref(&*weights); + result.bias = *bytemuck::must_cast_ref(&*bias); + + result + } + + pub fn get_bias(&self, idx: usize) -> Accumulator { + unsafe { **self.bias.get_unchecked(idx) } + } + + fn get_weights(&self, idx: usize, feat_idx: usize) -> &Accumulator { + unsafe { self.weights.get_unchecked(idx).get_unchecked(feat_idx) } + } + + pub fn set(&self, idx: usize, feat_idx: usize, acc: &mut Accumulator) { + acc.set(self.get_weights(idx, feat_idx)); + } +} + +const HIDDEN_SIZE: usize = 8; + +type Feature = SparseConnected; +type Output = DenseConnected; + +type QuantizedFeatureWeights = [Align16>; INPUT_SIZE]; +type QuantizedFeatureBias = Align16>; +type QuantizedOutputWeights = [Align16>; HIDDEN_SIZE]; +type QuantizedOutputBias = Align16>; + +type RawFeatureWeights = Align16<[[i16; HIDDEN_SIZE]; INPUT_SIZE]>; +type RawFeatureBias = Align16<[i16; HIDDEN_SIZE]>; +type RawOutputWeights = Align16<[[i16; ATTENTION_SIZE]; HIDDEN_SIZE]>; +type RawOutputBias = Align16<[i32; ATTENTION_SIZE]>; + +#[repr(C)] +#[derive(FeedForwardNetwork)] +pub struct LayerNetwork { + feature: Feature, + output: Output, +} + +unsafe impl Zeroable for LayerNetwork {} + +impl Display for LayerNetwork { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{INPUT_SIZE}->{HIDDEN_SIZE}->{ATTENTION_SIZE}") + } +} + +impl LayerNetwork { + pub fn randomize(&mut self) { + let mut rng = Rng::default(); + + randomize_sparse(&mut self.feature, &mut rng); + randomize_dense(&mut self.output, &mut rng); + } +} + +#[repr(C)] +#[derive(Copy, Clone, Zeroable)] +pub struct QuantizedLayerNetwork { + feature_weights: [QuantizedFeatureWeights; N], + feature_bias: [QuantizedFeatureBias; N], + output_weights: [QuantizedOutputWeights; N], + output_bias: [QuantizedOutputBias; N], +} + +unsafe impl Pod for QuantizedLayerNetwork {} + +impl QuantizedLayerNetwork { + #[must_use] + pub fn boxed_from(subnets: &[LayerNetwork; N]) -> Box { + let mut feature_weights: Box<[RawFeatureWeights; N]> = allocation::zeroed_box(); + let mut feature_bias: Box<[RawFeatureBias; N]> = allocation::zeroed_box(); + let mut output_weights: Box<[RawOutputWeights; N]> = allocation::zeroed_box(); + let mut output_bias: Box<[RawOutputBias; N]> = allocation::zeroed_box(); + + for (subnet, raw) in subnets.iter().zip(feature_weights.iter_mut()) { + for (row_idx, weights) in raw.iter_mut().enumerate() { + let row = subnet.feature.weights_row(row_idx); + for weight_idx in 0..HIDDEN_SIZE { + weights[weight_idx] = q_i16(row[weight_idx], QA); + } + } + } + + for (subnet, raw) in subnets.iter().zip(feature_bias.iter_mut()) { + for (weight_idx, bias) in raw.iter_mut().enumerate() { + *bias = q_i16(subnet.feature.bias()[weight_idx], QA); + } + } + + for (subnet, raw) in subnets.iter().zip(output_weights.iter_mut()) { + for a in 0..ATTENTION_SIZE { + let col = subnet.output.weights_col(a); + for h in 0..HIDDEN_SIZE { + raw[h][a] = q_i16(col[h], QB); + } + } + } + + for (subnet, raw) in subnets.iter().zip(output_bias.iter_mut()) { + for (weight_idx, bias) in raw.iter_mut().enumerate() { + *bias = q_i32(subnet.output.bias()[weight_idx], QAB); + } + } + + let mut result: Box = allocation::zeroed_box(); + + result.feature_weights = *bytemuck::must_cast_ref(&*feature_weights); + result.feature_bias = *bytemuck::must_cast_ref(&*feature_bias); + result.output_weights = *bytemuck::must_cast_ref(&*output_weights); + result.output_bias = *bytemuck::must_cast_ref(&*output_bias); + + result + } + + pub fn get_bias(&self, idx: usize) -> Accumulator { + unsafe { **self.feature_bias.get_unchecked(idx) } + } + + pub fn set(&self, idx: usize, feat_idx: usize, acc: &mut Accumulator) { + let weights = unsafe { + **self + .feature_weights + .get_unchecked(idx) + .get_unchecked(feat_idx) + }; + acc.set(&weights); + } + + pub fn out( + &self, + idx: usize, + acc: &Accumulator, + ) -> Accumulator { + assert!(idx < N); + + let mut outs = *self.output_bias[idx]; + let weights = self.output_weights[idx]; + + for (out, weight) in outs.vals.iter_mut().zip(weights.iter()) { + for (a, b) in acc.vals.iter().zip(weight.vals) { + *out += relu(*a) * i32::from(b); + } + } + + outs + } +} diff --git a/src/train/data.rs b/src/train/data.rs index 046a182..acf668d 100644 --- a/src/train/data.rs +++ b/src/train/data.rs @@ -1,55 +1,50 @@ -use crate::chess::{Bitboard, Board, Castling, Color, Move, Piece, Square}; -use crate::search::SCALE; -use crate::search_tree::SearchTree; -use crate::state::State; - use arrayvec::ArrayVec; +use bytemuck::{self, Pod, Zeroable}; use goober::SparseVector; use std::fs::File; use std::io::{BufWriter, Write}; -use std::{io, mem, slice}; +use std::mem; -#[derive(Clone, Copy, Debug)] +use crate::chess::{Bitboard, Board, Castling, Color, Move, Piece, Square}; +use crate::search::SCALE; +use crate::search_tree::SearchTree; +use crate::state::State; + +#[derive(Clone, Copy, Debug, Pod, Zeroable)] +#[repr(C)] pub struct TrainingPosition { occupied: Bitboard, pieces: [u8; 16], - stm: Color, - result: i8, evaluation: i32, - - previous_moves: [Move; 4], - - #[allow(dead_code)] - best_move: Move, + result: i8, + stm: u8, legal_moves: [Move; TrainingPosition::MAX_MOVES], visits: [u8; TrainingPosition::MAX_MOVES], } -const _SIZE_CHECK: () = assert!(mem::size_of::() == 256); +const _SIZE_CHECK: () = assert!(mem::size_of::() == 192); impl TrainingPosition { - pub const MAX_MOVES: usize = 72; - pub const MAX_VISITS: u32 = 1024; + pub const MAX_MOVES: usize = 54; + pub const SIZE: usize = mem::size_of::(); + pub const BATCH_SIZE: usize = 16384; + pub const BUFFER_COUNT: usize = 1 << 16; + pub const BUFFER_SIZE: usize = Self::BUFFER_COUNT * Self::SIZE; - pub fn write_batch(out: &mut BufWriter, data: &[TrainingPosition]) -> io::Result<()> { - let src_size = mem::size_of_val(data); - let data_slice = unsafe { slice::from_raw_parts(data.as_ptr().cast(), src_size) }; - out.write_all(data_slice)?; - Ok(()) + pub fn write_buffer(out: &mut BufWriter, data: &[TrainingPosition; Self::BUFFER_COUNT]) { + out.write_all(bytemuck::bytes_of(data)).unwrap(); } #[must_use] - pub fn read_batch(buffer: &[u8]) -> &[TrainingPosition] { - let len = buffer.len() / TrainingPosition::SIZE; - unsafe { slice::from_raw_parts(buffer.as_ptr().cast(), len) } + pub fn read_buffer(buffer: &[u8]) -> &[TrainingPosition; Self::BUFFER_COUNT] { + bytemuck::from_bytes(buffer) } - pub fn read_batch_mut(buffer: &mut [u8]) -> &mut [TrainingPosition] { - let len = buffer.len() / TrainingPosition::SIZE; - unsafe { slice::from_raw_parts_mut(buffer.as_mut_ptr().cast(), len) } + pub fn read_buffer_mut(buffer: &mut [u8]) -> &mut [TrainingPosition; Self::BUFFER_COUNT] { + bytemuck::from_bytes_mut(buffer) } #[must_use] @@ -57,15 +52,19 @@ impl TrainingPosition { self.evaluation as f32 / SCALE } + pub fn stm(&self) -> Color { + Color::from(self.stm) + } + #[must_use] pub fn stm_relative_evaluation(&self) -> f32 { let e = self.evaluation(); - self.stm.fold(e, -e) + self.stm().fold(e, -e) } #[must_use] pub fn stm_relative_result(&self) -> i8 { - self.stm.fold(self.result, -self.result) + self.stm().fold(self.result, -self.result) } #[must_use] @@ -78,10 +77,6 @@ impl TrainingPosition { .collect() } - pub fn set_previous_moves(&mut self, moves: [Move; 4]) { - self.previous_moves = moves; - } - pub fn set_result(&mut self, result: i8) { self.result = result; } @@ -126,41 +121,28 @@ impl From<&SearchTree> for TrainingPosition { } let mut nodes = [(Move::NONE, 0); Self::MAX_MOVES]; + let mut max_visits = 0; for (node, hot) in nodes.iter_mut().zip(tree.root_node().hots().iter()) { - *node = (*hot.get_move(), hot.visits()); + let vs = hot.visits(); + max_visits = max_visits.max(vs); + *node = (*hot.get_move(), vs); } let mut legal_moves = [Move::NONE; Self::MAX_MOVES]; let mut visits = [0; Self::MAX_MOVES]; - let mut max_visits = 0; - - for (idx, (mv, vs)) in nodes - .iter() - .take_while(|(m, _)| *m != Move::NONE) - .enumerate() - { - let vs = vs.min(&Self::MAX_VISITS); - - assert!(*vs <= Self::MAX_VISITS); - assert!(u8::try_from(vs * u32::from(u8::MAX) / Self::MAX_VISITS).is_ok()); - - if *vs > max_visits { - max_visits = *vs; - } - + for (idx, (mv, vs)) in nodes.iter().enumerate() { legal_moves[idx] = *mv; - visits[idx] = (*vs * u32::from(u8::MAX) / Self::MAX_VISITS) as u8; - } - assert!(max_visits == Self::MAX_VISITS); + let scaled_visits = (*vs * u32::from(u8::MAX)).div_ceil(max_visits); + assert!(u8::try_from(scaled_visits).is_ok()); + visits[idx] = scaled_visits as u8; + } let pv = tree.best_edge(); let mut evaluation = pv.reward().average; - let best_move = *pv.get_move(); - // white relative evaluation evaluation = stm .fold(evaluation, -evaluation) @@ -168,18 +150,15 @@ impl From<&SearchTree> for TrainingPosition { // zero'd to be filled in later let result = 0; - let previous_moves = [Move::NONE; 4]; TrainingPosition { occupied, pieces, - stm, - result, - evaluation, - previous_moves, - best_move, legal_moves, visits, + evaluation, + result, + stm: u8::from(stm), } } } @@ -198,8 +177,13 @@ impl From<&TrainingPosition> for State { pieces[piece].toggle(sq); } - let board = - Board::from_bitboards(colors, pieces, position.stm, Square::NONE, Castling::none()); + let board = Board::from_bitboards( + colors, + pieces, + position.stm(), + Square::NONE, + Castling::none(), + ); State::from_board(board) } diff --git a/src/value.rs b/src/value.rs index ae8e232..faf77cf 100644 --- a/src/value.rs +++ b/src/value.rs @@ -23,13 +23,13 @@ const QAB: i32 = QA * QB; type Feature = SparseConnected; type Output = DenseConnected; -type QuantizedFeatureWeights = [Align64>; INPUT_SIZE]; -type QuantizedFeatureBias = Align64>; -type QuantizedOutputWeights = [Align64>; 2]; +type QuantizedFeatureWeights = [Align64>; INPUT_SIZE]; +type QuantizedFeatureBias = Align64>; +type QuantizedOutputWeights = [Align64>; 2]; -type RawFeatureWeights = [[i16; HIDDEN_SIZE]; INPUT_SIZE]; -type RawFeatureBias = [i16; HIDDEN_SIZE]; -type RawOutputWeights = [i16; HIDDEN_SIZE * 2]; +type RawFeatureWeights = Align64<[[i16; HIDDEN_SIZE]; INPUT_SIZE]>; +type RawFeatureBias = Align64<[i16; HIDDEN_SIZE]>; +type RawOutputWeights = Align64<[i16; HIDDEN_SIZE * 2]>; #[allow(clippy::module_name_repetitions)] pub struct ValueNetwork { @@ -116,9 +116,11 @@ impl ValueNetwork { #[must_use] pub fn to_boxed_and_quantized(&self) -> Box { let mut stm_weights: Box = allocation::zeroed_box(); - let mut stm_bias = [0; HIDDEN_SIZE]; + let mut stm_bias: Box = allocation::zeroed_box(); + let mut nstm_weights: Box = allocation::zeroed_box(); - let mut nstm_bias = [0; HIDDEN_SIZE]; + let mut nstm_bias: Box = allocation::zeroed_box(); + let mut output_weights: Box = allocation::zeroed_box(); for (row_idx, weights) in stm_weights.iter_mut().enumerate() { @@ -249,21 +251,13 @@ impl QuantizedValueNetwork { ) -> Box { let mut network = Self::zeroed(); - network.stm_weights = unsafe { - std::mem::transmute::(*stm_weights) - }; - network.stm_bias = - unsafe { std::mem::transmute::(*stm_bias) }; - - network.nstm_weights = unsafe { - std::mem::transmute::(*nstm_weights) - }; - network.nstm_bias = - unsafe { std::mem::transmute::(*nstm_bias) }; - - network.output_weights = unsafe { - std::mem::transmute::(*output_weights) - }; + network.stm_weights = *bytemuck::must_cast_ref(stm_weights); + network.stm_bias = *bytemuck::must_cast_ref(stm_bias); + + network.nstm_weights = *bytemuck::must_cast_ref(nstm_weights); + network.nstm_bias = *bytemuck::must_cast_ref(nstm_bias); + + network.output_weights = *bytemuck::must_cast_ref(output_weights); network.output_bias = output_bias; network