Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching for GetAllMoves, GetAllAttackableMoves #17

Merged
merged 9 commits into from
Mar 26, 2019
90 changes: 62 additions & 28 deletions pkg/chessai/board/game_board.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/Vadman97/ChessAI3/pkg/chessai/color"
"github.com/Vadman97/ChessAI3/pkg/chessai/location"
"github.com/Vadman97/ChessAI3/pkg/chessai/piece"
"github.com/Vadman97/ChessAI3/pkg/chessai/util"
"log"
"math/rand"
"time"
Expand Down Expand Up @@ -62,6 +63,11 @@ var StartRow = map[byte]map[string]int8{
},
}

type MoveCacheEntry struct {
// color -> move
moves map[byte]interface{}
}

type Board struct {
// board stores entire layout of pieces on the Width * Height board
// more efficient to use ints - faster to copy int than set of bytes
Expand All @@ -71,18 +77,19 @@ type Board struct {
// max 4 flags if we use byte
flags byte

TestRandGen *rand.Rand
TestRandGen *rand.Rand
MoveCache, AttackableCache *util.ConcurrentBoardMap
}

func (b *Board) Hash() (result [33]byte) {
// TODO(Vadim) evenly distribute output over {1,0}^264 via SHA256?
// TODO(Vadim) really thoroughly test this for correctness
// store into map[uint64]map[uint64]map[uint64]map[uint64]map[byte]uint32
// Want to lookup score for a board using hash value
// Board stored in (8 * 4 + 1) bytes = 33bytes
for i := 0; i < Height; i++ {
for bIdx := 0; bIdx < BytesPerRow; bIdx++ {
result[i*BytesPerRow+bIdx] |= byte(b.board[i] & (PieceMask << byte(bIdx*BytesPerRow)))
p := b.board[i] & (0xFF << byte(bIdx*8)) >> byte(bIdx*8)
result[i*BytesPerRow+bIdx] |= byte(p)
}
}
result[32] = b.flags
Expand Down Expand Up @@ -123,6 +130,8 @@ func (b *Board) Copy() *Board {
newBoard.board[i] = b.board[i]
}
newBoard.flags = b.flags
newBoard.MoveCache = b.MoveCache
newBoard.AttackableCache = b.AttackableCache
return &newBoard
}

Expand All @@ -131,6 +140,8 @@ func (b *Board) ResetDefault() {
b.board[1] = StartingRowHex[1]
b.board[6] = StartingRowHex[6]
b.board[7] = StartingRowHex[7]
b.MoveCache = util.NewConcurrentBoardMap()
b.AttackableCache = util.NewConcurrentBoardMap()
}

func (b *Board) ResetDefaultSlow() {
Expand Down Expand Up @@ -246,56 +257,79 @@ func (b *Board) RandomizeIllegal() {
b.flags = byte(b.TestRandGen.Uint32())
}

func (b *Board) GetAllMoves(c byte) *[]location.Move {
/*
* Only this is cached and not GetAllAttackableMoves for now because this calls GetAllAttackableMoves
* May need to cache that one too when we use it for CheckMate / Tie evaluation
*/
func (b *Board) GetAllMoves(color byte) *[]location.Move {
// TODO(Vadim) when king under attack, moves that block check are the only possible ones
black, white := b.getAllMoves(c == color.Black, c == color.White)
if c == color.Black {
return black
} else if c == color.White {
return white
h := b.Hash()
entry := &MoveCacheEntry{
moves: make(map[byte]interface{}),
}
Vadman97 marked this conversation as resolved.
Show resolved Hide resolved
return nil
if cacheEntry, cacheExists := b.MoveCache.Read(&h); cacheExists {
entry = cacheEntry.(*MoveCacheEntry)
// we've gotten the other color but not the one we want
if entry.moves[color] == nil {
entry.moves[color] = b.getAllMoves(color)
b.MoveCache.Store(&h, entry)
}
} else {
entry.moves[color] = b.getAllMoves(color)
b.MoveCache.Store(&h, entry)
}
return entry.moves[color].(*[]location.Move)
}

func (b *Board) getAllMoves(getBlack, getWhite bool) (black, white *[]location.Move) {
var blackMoves, whiteMoves []location.Move
// TODO(Vadim) think of how to optimize this, profile it and write tests
func (b *Board) getAllMoves(color byte) *[]location.Move {
var moves []location.Move
for r := 0; r < Height; r++ {
// this is just a speedup - if the whole row is empty don't look at pieces
if b.board[r] == 0 {
continue
}
for c := 0; c < Width; c++ {
l := location.Location{int8(r), int8(c)}
if !b.IsEmpty(l) {
p := b.GetPiece(l)
moves := p.GetMoves(b)
if moves != nil {
if getBlack && p.GetColor() == color.Black {
blackMoves = append(blackMoves, *moves...)
} else if getWhite && p.GetColor() == color.White {
whiteMoves = append(whiteMoves, *moves...)
}
if p.GetColor() == color {
moves = append(moves, *p.GetMoves(b)...)
}
}
}
}
if getBlack {
black = &blackMoves
return &moves
}

/*
* Caches getAllAttackableMoves
*/
func (b *Board) GetAllAttackableMoves(color byte) AttackableBoard {
h := b.Hash()
entry := &MoveCacheEntry{
moves: make(map[byte]interface{}),
}
if getWhite {
white = &whiteMoves
if cacheEntry, cacheExists := b.AttackableCache.Read(&h); cacheExists {
entry = cacheEntry.(*MoveCacheEntry)
// we've gotten the other color but not the one we want
if entry.moves[color] == nil {
entry.moves[color] = b.getAllAttackableMoves(color)
b.AttackableCache.Store(&h, entry)
}
} else {
entry.moves[color] = b.getAllAttackableMoves(color)
b.AttackableCache.Store(&h, entry)
}
return
return entry.moves[color].(AttackableBoard)
}

/**
* Returns all attack moves for a specific color.
* TODO We need to cache this!
*/
func (b *Board) GetAllAttackableMoves(color byte) AttackableBoard {
func (b *Board) getAllAttackableMoves(color byte) AttackableBoard {
attackable := CreateEmptyAttackableBoard()
for r := 0; r < Height; r++ {
//TODO (Devan) figure out what this check is for
// this is just a speedup - if the whole row is empty don't look at pieces
if b.board[r] == 0 {
continue
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/chessai/game/game.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ type Game struct {
}

func (g *Game) PlayTurn() {
start := time.Now()
g.Players[g.CurrentTurnColor].MakeMove(g.CurrentBoard)
g.PlayTime[g.CurrentTurnColor] += time.Now().Sub(start)
g.CurrentTurnColor ^= 1
g.MovesPlayed++
}
Expand Down
25 changes: 18 additions & 7 deletions pkg/chessai/player/ai/ai_player.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,13 @@ func (p *Player) GetBestMove(b *board.Board) *location.Move {
} else {
var m *ScoredMove
if p.Algorithm == AlgorithmMiniMax {
m = p.MiniMax(b, 4, p.PlayerColor)
m = p.MiniMax(b, 2, p.PlayerColor)
} else if p.Algorithm == AlgorithmAlphaBetaWithMemory {
m = p.AlphaBetaWithMemory(b, 8, NegInf, PosInf, p.PlayerColor)
m = p.AlphaBetaWithMemory(b, 4, NegInf, PosInf, p.PlayerColor)
} else {
panic("invalid ai algorithm")
}
c := "Black"
if p.PlayerColor == color.White {
c = "White"
}
fmt.Printf("AI (%s - %s) best move leads to score %d\n", p.Algorithm, c, m.Score)
fmt.Printf("%s best move leads to score %d\n", p.Repr(), m.Score)
debugBoard := b.Copy()
//for i := 0; i < len(m.MoveSequence); i++ {
for i := len(m.MoveSequence) - 1; i >= 0; i-- {
Expand All @@ -135,8 +131,15 @@ func (p *Player) GetBestMove(b *board.Board) *location.Move {
fmt.Printf("\t\t%s\n", move.Print())
board.MakeMove(&move, debugBoard)
}
fmt.Printf("Board evaluation metrics\n")
p.evaluationMap.PrintMetrics()
fmt.Printf("Transposition table metrics\n")
p.alphaBetaTable.PrintMetrics()
fmt.Printf("Move cache metrics\n")
b.MoveCache.PrintMetrics()
fmt.Printf("Attack Move cache metrics\n")
b.AttackableCache.PrintMetrics()
fmt.Printf("\n\n")
return &m.Move
}
}
Expand Down Expand Up @@ -201,3 +204,11 @@ func (p *Player) EvaluateBoard(b *board.Board) *board.Evaluation {
p.evaluationMap.Store(&hash, int32(eval.TotalScore))
return eval
}

func (p *Player) Repr() string {
c := "Black"
if p.PlayerColor == color.White {
c = "White"
}
return fmt.Sprintf("AI (%s - %s)", p.Algorithm, c)
}
1 change: 1 addition & 0 deletions pkg/chessai/player/ai/minimax.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func (p *Player) MiniMax(b *board.Board, depth int, currentPlayer byte) *ScoredM
}

var best ScoredMove
// TODO(Vadim) if depth is odd, flip these?
if currentPlayer == p.PlayerColor {
// maximizing player
best.Score = NegInf
Expand Down
25 changes: 14 additions & 11 deletions pkg/chessai/util/concurrent_board_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ const (
)

type ConcurrentBoardMap struct {
scoreMap [NumSlices]map[uint64]map[uint64]map[uint64]map[uint64]map[byte]interface{}
locks [NumSlices]sync.RWMutex
lockUsage [NumSlices]uint64
scoreMap [NumSlices]map[uint64]map[uint64]map[uint64]map[uint64]map[byte]interface{}
locks [NumSlices]sync.RWMutex
lockUsage [NumSlices]uint64
entriesWritten [NumSlices]uint64
}

func NewConcurrentBoardMap() *ConcurrentBoardMap {
Expand All @@ -27,7 +28,7 @@ func NewConcurrentBoardMap() *ConcurrentBoardMap {
return &m
}

func getIdx(hash *[33]byte) (idx [4]uint64) {
func HashToMapKey(hash *[33]byte) (idx [4]uint64) {
for x := 0; x < 32; x += 8 {
idx[x/8] = binary.BigEndian.Uint64((*hash)[x : x+8])
}
Expand All @@ -39,13 +40,13 @@ func (m *ConcurrentBoardMap) getLock(hash *[33]byte) (*sync.RWMutex, uint32) {
for i := 0; i < 28; i += 4 {
s += (binary.BigEndian.Uint32(hash[i:i+4]) / NumSlices) % NumSlices
}
s += uint32(hash[32]) % NumSlices
s = (s + uint32(hash[32])) % NumSlices
atomic.AddUint64(&m.lockUsage[s], 1)
return &m.locks[s], s
}

func (m *ConcurrentBoardMap) Store(hash *[33]byte, value interface{}) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

lock, lockIdx := m.getLock(hash)
lock.Lock()
Expand All @@ -67,12 +68,12 @@ func (m *ConcurrentBoardMap) Store(hash *[33]byte, value interface{}) {
if !ok {
m.scoreMap[lockIdx][idx[0]][idx[1]][idx[2]][idx[3]] = make(map[byte]interface{})
}

m.entriesWritten[lockIdx]++
m.scoreMap[lockIdx][idx[0]][idx[1]][idx[2]][idx[3]][(*hash)[32]] = value
}

func (m *ConcurrentBoardMap) Read(hash *[33]byte) (interface{}, bool) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

lock, lockIdx := m.getLock(hash)
lock.Lock()
Expand All @@ -98,10 +99,12 @@ func (m *ConcurrentBoardMap) Read(hash *[33]byte) (interface{}, bool) {

func (m *ConcurrentBoardMap) PrintMetrics() {
//fmt.Printf("Lock Usages: \n")
total := uint64(0)
totalLocks, totalEntries := uint64(0), uint64(0)
for i := 0; i < NumSlices; i++ {
//fmt.Printf("Slice #%d, Used #%d times\n", i, m.lockUsage[i])
total += m.lockUsage[i]
totalLocks += m.lockUsage[i]
totalEntries += m.entriesWritten[i]
}
fmt.Printf("Total entries in map %d\n", total)
fmt.Printf("\tLock usages in map %d\n", totalLocks)
fmt.Printf("\tTotal entries in map %d\n", totalEntries)
}
6 changes: 3 additions & 3 deletions pkg/chessai/util/transposition_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func NewTranspositionTable() *TranspositionTable {
}

func (m *TranspositionTable) Store(hash *[33]byte, entry *TranspositionTableEntry) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

_, ok := m.entryMap[idx[0]]
if !ok {
Expand All @@ -48,7 +48,7 @@ func (m *TranspositionTable) Store(hash *[33]byte, entry *TranspositionTableEntr
}

func (m *TranspositionTable) Read(hash *[33]byte) (*TranspositionTableEntry, bool) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

m1, ok := m.entryMap[idx[0]]
if ok {
Expand All @@ -69,5 +69,5 @@ func (m *TranspositionTable) Read(hash *[33]byte) (*TranspositionTableEntry, boo
}

func (m *TranspositionTable) PrintMetrics() {
fmt.Printf("Total entries in transposition table %d\n", m.numStored)
fmt.Printf("\tTotal entries in transposition table %d\n", m.numStored)
}
8 changes: 4 additions & 4 deletions test/ai_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ import (
)

func TestBoardAI(t *testing.T) {
// TODO(Vadim) skip until it works better
t.Skip()
const MovesToPlay = 100
const TimeToPlay = 10 * time.Second
const TimeToPlay = 60 * time.Second

aiPlayerSmart := ai.NewAIPlayer(color.Black)
aiPlayerSmart.Algorithm = ai.AlgorithmMiniMax
Expand All @@ -34,13 +32,15 @@ func TestBoardAI(t *testing.T) {
g.PlayTurn()
fmt.Printf("Move %d\n", g.MovesPlayed)
fmt.Println(g.CurrentBoard.Print())
fmt.Printf("White %s has thought for %s\n", g.Players[color.White].Repr(), g.PlayTime[color.White])
fmt.Printf("Black %s has thought for %s\n", g.Players[color.Black].Repr(), g.PlayTime[color.Black])
util.PrintMemStats()
}

fmt.Println("After moves:")
fmt.Println(g.CurrentBoard.Print())
// comment out printing inside loop for accurate timing
fmt.Printf("Played %d moves in %d ms.\n", MovesToPlay, time.Now().Sub(start)/time.Millisecond)
fmt.Printf("Played %d moves in %d ms.\n", g.MovesPlayed, time.Now().Sub(start)/time.Millisecond)

smartScore := aiPlayerSmart.EvaluateBoard(g.CurrentBoard).TotalScore
dumbScore := aiPlayerDumb.EvaluateBoard(g.CurrentBoard).TotalScore
Expand Down