Skip to content

Commit

Permalink
Merge pull request #17 from Vadman97/cacheMoves
Browse files Browse the repository at this point in the history
Add caching for GetAllMoves, GetAllAttackableMoves
  • Loading branch information
Devan Adhia authored Mar 26, 2019
2 parents 62c3571 + c6425b6 commit 98c750c
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 53 deletions.
90 changes: 62 additions & 28 deletions pkg/chessai/board/game_board.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/Vadman97/ChessAI3/pkg/chessai/color"
"github.com/Vadman97/ChessAI3/pkg/chessai/location"
"github.com/Vadman97/ChessAI3/pkg/chessai/piece"
"github.com/Vadman97/ChessAI3/pkg/chessai/util"
"log"
"math/rand"
"time"
Expand Down Expand Up @@ -62,6 +63,11 @@ var StartRow = map[byte]map[string]int8{
},
}

type MoveCacheEntry struct {
// color -> move
moves map[byte]interface{}
}

type Board struct {
// board stores entire layout of pieces on the Width * Height board
// more efficient to use ints - faster to copy int than set of bytes
Expand All @@ -71,18 +77,19 @@ type Board struct {
// max 4 flags if we use byte
flags byte

TestRandGen *rand.Rand
TestRandGen *rand.Rand
MoveCache, AttackableCache *util.ConcurrentBoardMap
}

func (b *Board) Hash() (result [33]byte) {
// TODO(Vadim) evenly distribute output over {1,0}^264 via SHA256?
// TODO(Vadim) really thoroughly test this for correctness
// store into map[uint64]map[uint64]map[uint64]map[uint64]map[byte]uint32
// Want to lookup score for a board using hash value
// Board stored in (8 * 4 + 1) bytes = 33bytes
for i := 0; i < Height; i++ {
for bIdx := 0; bIdx < BytesPerRow; bIdx++ {
result[i*BytesPerRow+bIdx] |= byte(b.board[i] & (PieceMask << byte(bIdx*BytesPerRow)))
p := b.board[i] & (0xFF << byte(bIdx*8)) >> byte(bIdx*8)
result[i*BytesPerRow+bIdx] |= byte(p)
}
}
result[32] = b.flags
Expand Down Expand Up @@ -123,6 +130,8 @@ func (b *Board) Copy() *Board {
newBoard.board[i] = b.board[i]
}
newBoard.flags = b.flags
newBoard.MoveCache = b.MoveCache
newBoard.AttackableCache = b.AttackableCache
return &newBoard
}

Expand All @@ -131,6 +140,8 @@ func (b *Board) ResetDefault() {
b.board[1] = StartingRowHex[1]
b.board[6] = StartingRowHex[6]
b.board[7] = StartingRowHex[7]
b.MoveCache = util.NewConcurrentBoardMap()
b.AttackableCache = util.NewConcurrentBoardMap()
}

func (b *Board) ResetDefaultSlow() {
Expand Down Expand Up @@ -246,56 +257,79 @@ func (b *Board) RandomizeIllegal() {
b.flags = byte(b.TestRandGen.Uint32())
}

func (b *Board) GetAllMoves(c byte) *[]location.Move {
/*
* Only this is cached and not GetAllAttackableMoves for now because this calls GetAllAttackableMoves
* May need to cache that one too when we use it for CheckMate / Tie evaluation
*/
func (b *Board) GetAllMoves(color byte) *[]location.Move {
// TODO(Vadim) when king under attack, moves that block check are the only possible ones
black, white := b.getAllMoves(c == color.Black, c == color.White)
if c == color.Black {
return black
} else if c == color.White {
return white
h := b.Hash()
entry := &MoveCacheEntry{
moves: make(map[byte]interface{}),
}
return nil
if cacheEntry, cacheExists := b.MoveCache.Read(&h); cacheExists {
entry = cacheEntry.(*MoveCacheEntry)
// we've gotten the other color but not the one we want
if entry.moves[color] == nil {
entry.moves[color] = b.getAllMoves(color)
b.MoveCache.Store(&h, entry)
}
} else {
entry.moves[color] = b.getAllMoves(color)
b.MoveCache.Store(&h, entry)
}
return entry.moves[color].(*[]location.Move)
}

func (b *Board) getAllMoves(getBlack, getWhite bool) (black, white *[]location.Move) {
var blackMoves, whiteMoves []location.Move
// TODO(Vadim) think of how to optimize this, profile it and write tests
func (b *Board) getAllMoves(color byte) *[]location.Move {
var moves []location.Move
for r := 0; r < Height; r++ {
// this is just a speedup - if the whole row is empty don't look at pieces
if b.board[r] == 0 {
continue
}
for c := 0; c < Width; c++ {
l := location.Location{int8(r), int8(c)}
if !b.IsEmpty(l) {
p := b.GetPiece(l)
moves := p.GetMoves(b)
if moves != nil {
if getBlack && p.GetColor() == color.Black {
blackMoves = append(blackMoves, *moves...)
} else if getWhite && p.GetColor() == color.White {
whiteMoves = append(whiteMoves, *moves...)
}
if p.GetColor() == color {
moves = append(moves, *p.GetMoves(b)...)
}
}
}
}
if getBlack {
black = &blackMoves
return &moves
}

/*
* Caches getAllAttackableMoves
*/
func (b *Board) GetAllAttackableMoves(color byte) AttackableBoard {
h := b.Hash()
entry := &MoveCacheEntry{
moves: make(map[byte]interface{}),
}
if getWhite {
white = &whiteMoves
if cacheEntry, cacheExists := b.AttackableCache.Read(&h); cacheExists {
entry = cacheEntry.(*MoveCacheEntry)
// we've gotten the other color but not the one we want
if entry.moves[color] == nil {
entry.moves[color] = b.getAllAttackableMoves(color)
b.AttackableCache.Store(&h, entry)
}
} else {
entry.moves[color] = b.getAllAttackableMoves(color)
b.AttackableCache.Store(&h, entry)
}
return
return entry.moves[color].(AttackableBoard)
}

/**
* Returns all attack moves for a specific color.
* TODO We need to cache this!
*/
func (b *Board) GetAllAttackableMoves(color byte) AttackableBoard {
func (b *Board) getAllAttackableMoves(color byte) AttackableBoard {
attackable := CreateEmptyAttackableBoard()
for r := 0; r < Height; r++ {
//TODO (Devan) figure out what this check is for
// this is just a speedup - if the whole row is empty don't look at pieces
if b.board[r] == 0 {
continue
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/chessai/game/game.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ type Game struct {
}

func (g *Game) PlayTurn() {
start := time.Now()
g.Players[g.CurrentTurnColor].MakeMove(g.CurrentBoard)
g.PlayTime[g.CurrentTurnColor] += time.Now().Sub(start)
g.CurrentTurnColor ^= 1
g.MovesPlayed++
}
Expand Down
25 changes: 18 additions & 7 deletions pkg/chessai/player/ai/ai_player.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,13 @@ func (p *Player) GetBestMove(b *board.Board) *location.Move {
} else {
var m *ScoredMove
if p.Algorithm == AlgorithmMiniMax {
m = p.MiniMax(b, 4, p.PlayerColor)
m = p.MiniMax(b, 2, p.PlayerColor)
} else if p.Algorithm == AlgorithmAlphaBetaWithMemory {
m = p.AlphaBetaWithMemory(b, 8, NegInf, PosInf, p.PlayerColor)
m = p.AlphaBetaWithMemory(b, 4, NegInf, PosInf, p.PlayerColor)
} else {
panic("invalid ai algorithm")
}
c := "Black"
if p.PlayerColor == color.White {
c = "White"
}
fmt.Printf("AI (%s - %s) best move leads to score %d\n", p.Algorithm, c, m.Score)
fmt.Printf("%s best move leads to score %d\n", p.Repr(), m.Score)
debugBoard := b.Copy()
//for i := 0; i < len(m.MoveSequence); i++ {
for i := len(m.MoveSequence) - 1; i >= 0; i-- {
Expand All @@ -135,8 +131,15 @@ func (p *Player) GetBestMove(b *board.Board) *location.Move {
fmt.Printf("\t\t%s\n", move.Print())
board.MakeMove(&move, debugBoard)
}
fmt.Printf("Board evaluation metrics\n")
p.evaluationMap.PrintMetrics()
fmt.Printf("Transposition table metrics\n")
p.alphaBetaTable.PrintMetrics()
fmt.Printf("Move cache metrics\n")
b.MoveCache.PrintMetrics()
fmt.Printf("Attack Move cache metrics\n")
b.AttackableCache.PrintMetrics()
fmt.Printf("\n\n")
return &m.Move
}
}
Expand Down Expand Up @@ -201,3 +204,11 @@ func (p *Player) EvaluateBoard(b *board.Board) *board.Evaluation {
p.evaluationMap.Store(&hash, int32(eval.TotalScore))
return eval
}

func (p *Player) Repr() string {
c := "Black"
if p.PlayerColor == color.White {
c = "White"
}
return fmt.Sprintf("AI (%s - %s)", p.Algorithm, c)
}
1 change: 1 addition & 0 deletions pkg/chessai/player/ai/minimax.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func (p *Player) MiniMax(b *board.Board, depth int, currentPlayer byte) *ScoredM
}

var best ScoredMove
// TODO(Vadim) if depth is odd, flip these?
if currentPlayer == p.PlayerColor {
// maximizing player
best.Score = NegInf
Expand Down
25 changes: 14 additions & 11 deletions pkg/chessai/util/concurrent_board_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ const (
)

type ConcurrentBoardMap struct {
scoreMap [NumSlices]map[uint64]map[uint64]map[uint64]map[uint64]map[byte]interface{}
locks [NumSlices]sync.RWMutex
lockUsage [NumSlices]uint64
scoreMap [NumSlices]map[uint64]map[uint64]map[uint64]map[uint64]map[byte]interface{}
locks [NumSlices]sync.RWMutex
lockUsage [NumSlices]uint64
entriesWritten [NumSlices]uint64
}

func NewConcurrentBoardMap() *ConcurrentBoardMap {
Expand All @@ -27,7 +28,7 @@ func NewConcurrentBoardMap() *ConcurrentBoardMap {
return &m
}

func getIdx(hash *[33]byte) (idx [4]uint64) {
func HashToMapKey(hash *[33]byte) (idx [4]uint64) {
for x := 0; x < 32; x += 8 {
idx[x/8] = binary.BigEndian.Uint64((*hash)[x : x+8])
}
Expand All @@ -39,13 +40,13 @@ func (m *ConcurrentBoardMap) getLock(hash *[33]byte) (*sync.RWMutex, uint32) {
for i := 0; i < 28; i += 4 {
s += (binary.BigEndian.Uint32(hash[i:i+4]) / NumSlices) % NumSlices
}
s += uint32(hash[32]) % NumSlices
s = (s + uint32(hash[32])) % NumSlices
atomic.AddUint64(&m.lockUsage[s], 1)
return &m.locks[s], s
}

func (m *ConcurrentBoardMap) Store(hash *[33]byte, value interface{}) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

lock, lockIdx := m.getLock(hash)
lock.Lock()
Expand All @@ -67,12 +68,12 @@ func (m *ConcurrentBoardMap) Store(hash *[33]byte, value interface{}) {
if !ok {
m.scoreMap[lockIdx][idx[0]][idx[1]][idx[2]][idx[3]] = make(map[byte]interface{})
}

m.entriesWritten[lockIdx]++
m.scoreMap[lockIdx][idx[0]][idx[1]][idx[2]][idx[3]][(*hash)[32]] = value
}

func (m *ConcurrentBoardMap) Read(hash *[33]byte) (interface{}, bool) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

lock, lockIdx := m.getLock(hash)
lock.Lock()
Expand All @@ -98,10 +99,12 @@ func (m *ConcurrentBoardMap) Read(hash *[33]byte) (interface{}, bool) {

func (m *ConcurrentBoardMap) PrintMetrics() {
//fmt.Printf("Lock Usages: \n")
total := uint64(0)
totalLocks, totalEntries := uint64(0), uint64(0)
for i := 0; i < NumSlices; i++ {
//fmt.Printf("Slice #%d, Used #%d times\n", i, m.lockUsage[i])
total += m.lockUsage[i]
totalLocks += m.lockUsage[i]
totalEntries += m.entriesWritten[i]
}
fmt.Printf("Total entries in map %d\n", total)
fmt.Printf("\tLock usages in map %d\n", totalLocks)
fmt.Printf("\tTotal entries in map %d\n", totalEntries)
}
6 changes: 3 additions & 3 deletions pkg/chessai/util/transposition_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func NewTranspositionTable() *TranspositionTable {
}

func (m *TranspositionTable) Store(hash *[33]byte, entry *TranspositionTableEntry) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

_, ok := m.entryMap[idx[0]]
if !ok {
Expand All @@ -48,7 +48,7 @@ func (m *TranspositionTable) Store(hash *[33]byte, entry *TranspositionTableEntr
}

func (m *TranspositionTable) Read(hash *[33]byte) (*TranspositionTableEntry, bool) {
idx := getIdx(hash)
idx := HashToMapKey(hash)

m1, ok := m.entryMap[idx[0]]
if ok {
Expand All @@ -69,5 +69,5 @@ func (m *TranspositionTable) Read(hash *[33]byte) (*TranspositionTableEntry, boo
}

func (m *TranspositionTable) PrintMetrics() {
fmt.Printf("Total entries in transposition table %d\n", m.numStored)
fmt.Printf("\tTotal entries in transposition table %d\n", m.numStored)
}
8 changes: 4 additions & 4 deletions test/ai_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ import (
)

func TestBoardAI(t *testing.T) {
// TODO(Vadim) skip until it works better
t.Skip()
const MovesToPlay = 100
const TimeToPlay = 10 * time.Second
const TimeToPlay = 60 * time.Second

aiPlayerSmart := ai.NewAIPlayer(color.Black)
aiPlayerSmart.Algorithm = ai.AlgorithmMiniMax
Expand All @@ -34,13 +32,15 @@ func TestBoardAI(t *testing.T) {
g.PlayTurn()
fmt.Printf("Move %d\n", g.MovesPlayed)
fmt.Println(g.CurrentBoard.Print())
fmt.Printf("White %s has thought for %s\n", g.Players[color.White].Repr(), g.PlayTime[color.White])
fmt.Printf("Black %s has thought for %s\n", g.Players[color.Black].Repr(), g.PlayTime[color.Black])
util.PrintMemStats()
}

fmt.Println("After moves:")
fmt.Println(g.CurrentBoard.Print())
// comment out printing inside loop for accurate timing
fmt.Printf("Played %d moves in %d ms.\n", MovesToPlay, time.Now().Sub(start)/time.Millisecond)
fmt.Printf("Played %d moves in %d ms.\n", g.MovesPlayed, time.Now().Sub(start)/time.Millisecond)

smartScore := aiPlayerSmart.EvaluateBoard(g.CurrentBoard).TotalScore
dumbScore := aiPlayerDumb.EvaluateBoard(g.CurrentBoard).TotalScore
Expand Down

0 comments on commit 98c750c

Please sign in to comment.