From 48f5734df312b56307ea65499ea9a3b1408c17a1 Mon Sep 17 00:00:00 2001 From: ianagbip1oti Date: Sat, 27 Jan 2024 21:18:59 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=A6=20NEW:=20Include=20best=20move=20i?= =?UTF-8?q?n=20data=20gen?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bin/train-value.rs | 2 +- src/train/data.rs | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/bin/train-value.rs b/src/bin/train-value.rs index 4c2c85b..1e9391e 100644 --- a/src/bin/train-value.rs +++ b/src/bin/train-value.rs @@ -1,4 +1,4 @@ -use princhess::train::{ValueNetwork, TrainingPosition}; +use princhess::train::{TrainingPosition, ValueNetwork}; use goober::{FeedForwardNetwork, OutputLayer, Vector}; use std::env; diff --git a/src/train/data.rs b/src/train/data.rs index 5930343..d1b495b 100644 --- a/src/train/data.rs +++ b/src/train/data.rs @@ -13,13 +13,15 @@ pub struct TrainingPosition { occupied: Bitboard, pieces: [u8; 16], stm: Color, - result: i8, - #[allow(dead_code)] + result: i8, evaluation: i32, previous_moves: [Move; 4], + #[allow(dead_code)] + best_move: Move, + #[allow(dead_code)] legal_moves: [Move; TrainingPosition::MAX_MOVES], @@ -131,6 +133,8 @@ impl From<&SearchTree> for TrainingPosition { v => pv.sum_rewards() / i64::from(v), } as i32; + let best_move = *pv.get_move(); + // white relative evaluation evaluation = stm .fold(evaluation, -evaluation) @@ -147,6 +151,7 @@ impl From<&SearchTree> for TrainingPosition { result, evaluation, previous_moves, + best_move, legal_moves, visits, }