Skip to content

Commit

Permalink
Created initial crate and code for training the Hash Chess engine net…
Browse files Browse the repository at this point in the history
…works specified in the `hash-network` crate, by using an AlphaZero like training algorithm
  • Loading branch information
miestrode committed Oct 15, 2023
1 parent f734ca4 commit 26785e0
Show file tree
Hide file tree
Showing 20 changed files with 568 additions and 134 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ to do well currently, using Rust. Some areas may suffer, or just won't use Rust

## To do

The primary things as of right now to be done, are:
### Move generation (`hash-core`)

- [ ] Make the FEN parser fail when the board it is parsing is illegal, as `Board` should never, whilst only using safe
functions result in an invalid position.
Expand All @@ -23,6 +23,14 @@ The primary things as of right now to be done, are:
- [ ] Refactor the build script, and it's magic bitboards setup (consider using `phf`, and unrelatedly switching to
black
magic bitboards)

### MCTS

- [ ] Create an MCTS searcher using the networks (incorporating parallelism, Murphy Sampling and the like)
- [ ] Consider not tying a board to the tree, saving memory
- [ ] Consider to the contrary tying the relevant move to each child, or at least a move integer.

### Network training

- [ ] Create a network trainer in Rust
- [ ] Create an evaluation framework, similar to FishTest or OpenBench
3 changes: 2 additions & 1 deletion hash-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ test-case = "3.2.1"
hash-bootstrap = { path = "../hash-bootstrap" }

rustifact = "0.9.0"
arrayvec = "0.7.4"
arrayvec = "0.7.4"

17 changes: 7 additions & 10 deletions hash-core/src/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct Board {
pub playing_color: Color,
pub en_passant_capture_square: Option<Square>,
pub piece_table: PieceTable,
pub min_half_move_clock: u8,
pub min_ply_clock: u8,
pub full_moves: u16,
pub hash: u64,
}
Expand Down Expand Up @@ -257,10 +257,10 @@ impl Board {

self.full_moves += (self.playing_color == Color::Black) as u16;

self.min_half_move_clock = if moved_piece_kind == PieceKind::Pawn || is_capture {
self.min_ply_clock = if moved_piece_kind == PieceKind::Pawn || is_capture {
0
} else {
self.min_half_move_clock.saturating_add(1)
self.min_ply_clock.saturating_add(1)
};

self.playing_color = !self.playing_color;
Expand Down Expand Up @@ -316,9 +316,9 @@ impl FromStr for Board {
square => Some(Square::from_str(square)?),
};

let half_move_clock = parts[4]
let ply_clock = parts[4]
.parse::<u8>()
.map_err(|_| "Input contains invalid number for the half move clock")?;
.map_err(|_| "Input contains invalid number for the half-move clock")?;

let full_moves = parts[5]
.parse::<u16>()
Expand Down Expand Up @@ -390,7 +390,7 @@ impl FromStr for Board {
^ zobrist::castling_rights(&black.castling_rights),
checkers: BitBoard::EMPTY,
pinned: BitBoard::EMPTY,
min_half_move_clock: half_move_clock,
min_ply_clock: ply_clock,
full_moves,
};

Expand Down Expand Up @@ -478,9 +478,6 @@ impl Display for Board {
} else {
'-'.fmt(f)?;
}
f.write_fmt(format_args!(
" {} {}",
self.min_half_move_clock, self.full_moves
))
f.write_fmt(format_args!(" {} {}", self.min_ply_clock, self.full_moves))
}
}
26 changes: 15 additions & 11 deletions hash-core/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{array, marker::PhantomData};

pub trait CacheHash {
fn hash(&self) -> u64;
}
Expand All @@ -11,44 +13,46 @@ struct Entry<T> {
impl<T: Copy> Copy for Entry<T> {}

#[derive(Clone)]
pub struct Cache<T: Clone, const N: usize> {
data: [Option<Entry<T>>; N],
pub struct Cache<K: CacheHash, V, const N: usize> {
data: [Option<Entry<V>>; N],
_marker: PhantomData<K>,
}

impl<T: Clone, const N: usize> Cache<T, N> {
impl<K: CacheHash, V, const N: usize> Cache<K, V, N> {
// TODO: Rework this implementation to be less simple. The replacement strategy shown here
// should be tweaked to be more balanced, and of course, fixed-probing should be explored
// (probing up to some number H of buckets, and then simply replacing)
pub fn insert<K: CacheHash>(&mut self, key: &K, value: T) {
pub fn insert(&mut self, key: &K, value: V) {
let hash = key.hash();

self.data[hash as usize % N] = Some(Entry { value, hash });
}

pub fn get<K: CacheHash>(&self, key: &K) -> Option<T> {
pub fn get(&self, key: &K) -> Option<&V> {
let hash = key.hash();
let entry = &self.data[hash as usize % self.data.len()];

entry.as_ref().and_then(|entry| {
if entry.hash == hash {
Some(entry.value.clone())
Some(&entry.value)
} else {
None
}
})
}
}

impl<T: Copy, const N: usize> Cache<T, N> {
impl<K: CacheHash, V, const N: usize> Cache<K, V, N> {
pub fn new() -> Self {
Self { data: [None; N] }
Self {
data: array::from_fn(|_| None),
_marker: PhantomData,
}
}
}

impl<T: Copy, const N: usize> Default for Cache<T, N> {
impl<K: CacheHash, V, const N: usize> Default for Cache<K, V, N> {
fn default() -> Self {
Self::new()
}
}

impl<T: Copy, const N: usize> Copy for Cache<T, N> {}
87 changes: 87 additions & 0 deletions hash-core/src/game.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::str::FromStr;

use hash_bootstrap::Color;

use crate::{board::Board, cache::Cache, mg, repr::Move};

const REPETITIONS: usize = 1000;

pub enum Outcome {
Win(Color),
Draw,
}

#[derive(PartialEq)]
enum Repetition {
Once,
Never,
}

pub struct Game {
board: Board,
repetitions: Cache<Board, Repetition, REPETITIONS>,
}

impl Game {
pub fn starting_position() -> Self {
Self {
board: Board::starting_position(),
repetitions: Cache::new(),
}
}

fn was_current_board_repeated_thrice(&self) -> bool {
if let Some(repetition) = self.repetitions.get(&self.board) {
*repetition == Repetition::Once
} else {
false
}
}

fn can_either_player_claim_draw(&self) -> bool {
self.board.min_ply_clock >= 100 || self.was_current_board_repeated_thrice()
}

pub fn outcome(&self) -> Option<Outcome> {
if mg::gen_moves(&self.board).is_empty() || self.can_either_player_claim_draw() {
Some(if self.board.in_check() {
Outcome::Win(!self.board.playing_color)
} else {
Outcome::Draw
})
} else {
None
}
}

pub unsafe fn make_move_unchecked(&mut self, chess_move: Move) {
self.repetitions.insert(
&self.board,
if self.repetitions.get(&self.board).is_none() {
Repetition::Never
} else {
Repetition::Once
},
);

// SAFETY: Move is assumed to be legal in this position
unsafe {
self.board.make_move_unchecked(&chess_move);
}
}

pub fn board(&self) -> &Board {
&self.board
}
}

impl FromStr for Game {
type Err = &'static str;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
board: Board::from_str(s)?,
repetitions: Cache::new(),
})
}
}
3 changes: 2 additions & 1 deletion hash-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![feature(test, custom_test_frameworks)]
#![feature(test)]

pub mod board;
mod cache;
pub mod game;
mod index;
pub mod mg;
pub mod repr;
Expand Down
1 change: 0 additions & 1 deletion hash-network/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ edition.workspace = true
[dependencies]
hash-bootstrap = { path = "../hash-bootstrap" }
hash-core = { path = "../hash-core" }

burn = "0.9.0"
serde = "1.0.188"

Expand Down
8 changes: 6 additions & 2 deletions hash-network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use hash_core::{board::Board, repr::Player};

pub mod model;

fn stack<B: Backend, const D: usize, const D2: usize>(tensors: Vec<Tensor<B, D>>) -> Tensor<B, D2> {
pub fn stack<B: Backend, const D: usize, const D2: usize>(
tensors: Vec<Tensor<B, D>>,
) -> Tensor<B, D2> {
Tensor::cat(
tensors
.into_iter()
Expand Down Expand Up @@ -55,7 +57,7 @@ pub fn board_to_tensor<B: Backend>(board: Option<&Board>) -> Tensor<B, 3> {
Color::Black => Tensor::ones(Shape::new([8, 8])).neg(),
}
.unsqueeze(),
Tensor::from_floats([board.min_half_move_clock as f32; 64])
Tensor::from_floats([board.min_ply_clock as f32; 64])
.reshape(Shape::new([1, 8, 8])),
boolean_to_tensor(true).unsqueeze(), // This is for the existence layer
],
Expand All @@ -66,6 +68,8 @@ pub fn board_to_tensor<B: Backend>(board: Option<&Board>) -> Tensor<B, 3> {
}
}

// TODO: It might be the best to just fill the rest with zeroes on the tensor level, instead of
// requiring one to pass a bunch of zeros
pub fn boards_to_tensor<B: Backend>(boards: Vec<Option<&Board>>) -> Tensor<B, 3> {
stack(
boards
Expand Down
23 changes: 19 additions & 4 deletions hash-network/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use burn::{
conv::{Conv2d, Conv2dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, ReLU,
},
tensor::{backend::Backend, Tensor},
tensor::{activation, backend::Backend, Shape, Tensor},
};
use std::iter;

Expand Down Expand Up @@ -86,6 +86,11 @@ impl ConvBlockConfig {
}
}

pub struct BatchOutput<B: Backend> {
pub values: Tensor<B, 1>,
pub probabilities: Tensor<B, 2>,
}

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
move_history: usize,
Expand All @@ -99,15 +104,25 @@ impl<B: Backend> Model<B> {
self.move_history
}

pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 2> {
pub fn forward(&self, input: Tensor<B, 4>) -> BatchOutput<B> {
let x = self
.conv_blocks
.iter()
.fold(input, |x, block| block.forward(x));
let x = x.flatten(1, 3);
let x = self.fc_1.forward(x);
let x = self.output.forward(x);

let shape = x.shape().dims;
let value_index_tensor = Tensor::zeros(Shape::new([shape[0], 1]));

self.output.forward(x)
let values = x.clone().gather(1, value_index_tensor.clone()).squeeze(1);
let probabilities = activation::softmax(x.slice([0..shape[0], 1..shape[1]]), 1);

BatchOutput {
values,
probabilities,
}
}
}

Expand All @@ -121,7 +136,7 @@ pub struct ModelConfig {
move_history: usize,
#[config(default = 3)]
kernel_length: usize,
#[config(default = 256)]
#[config(default = 32)]
filters: usize,
}

Expand Down
8 changes: 3 additions & 5 deletions hash-search/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ hash-bootstrap = { path = "../hash-bootstrap" }
hash-network = { path = "../hash-network" }

arrayvec = "0.7.4"
num-traits = "0.2.16"
num-traits = "0.2.17"
rand = { version = "0.8.5", features = ["min_const_gen"] }

burn = "0.9.0"
burn-ndarray = "0.9.0"
serde = "1.0.188"
burn = { version = "0.9.0", features = ["ndarray-blas-openblas-system"] }
serde = "1.0.189"

[lints]
workspace = true
Loading

0 comments on commit 26785e0

Please sign in to comment.