Skip to content

Commit

Permalink
Pass in the learning model instance instead of the class
Browse files Browse the repository at this point in the history
  • Loading branch information
shayakbanerjee committed Dec 18, 2017
1 parent c242ced commit e2bc98d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 47 deletions.
3 changes: 1 addition & 2 deletions learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def loadLearning(self, filename):

class NNUltimateLearning(GenericLearning):
STATE_TO_NUMBER_MAP = {GridStates.EMPTY: 0, GridStates.PLAYER_O: -1, GridStates.PLAYER_X: 1}
TABLE_LEARNING_FILE = 'table_learning.json'

def __init__(self, DecisionClass=TTTBoardDecision):
self.DecisionClass = DecisionClass
Expand All @@ -81,9 +80,9 @@ def initializeModel(self):
self.model.add(Dense(1, activation='linear', kernel_initializer='glorot_uniform'))
self.model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
plot_model(self.model, to_file='model.png')
#self.initialModelTraining(self.TABLE_LEARNING_FILE)

def initialModelTraining(self, jsonFile):
# If the neural network model should be seeded from some known state/value pairs
import os
if os.path.isfile(jsonFile):
self.values = json.load(open(jsonFile, 'r'))
Expand Down
34 changes: 2 additions & 32 deletions plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,10 @@ def setFontProperties():
matplotlib.rc('text', usetex=False)
matplotlib.rcParams.update({'font.size': 14})

def visualizeStateData(stateDict, dataBind, dataKey,
thresholds=None,
mapTitle=None,
jsonFile=None,
htmlFile=None):
statePanda = {}
stateCount = 0
for (key, value) in stateDict.items():
statePanda[stateCount] = [key, value]
stateCount += 1
stateData = pd.DataFrame.from_dict(statePanda, orient='index')
stateData.columns = [dataKey, dataBind]
visualizePandasData(stateData, dataBind, dataKey, thresholds, mapTitle, jsonFile, htmlFile)

def visualizePandasData(stateData,dataBind, dataKey,
thresholds=None,
mapTitle=None,
jsonFile=None,
htmlFile=None):
state_topo = 'https://raw.githubusercontent.com/wrobstory/vincent_map_data/master/us_states.topo.json'
geo_data = [{'name': 'states', 'url': state_topo, 'feature': 'us_states.geo'}]
vis = vincent.Map(data=stateData, geo_data=geo_data, scale=1000,
projection='albersUsa', data_bind=dataBind, data_key=dataKey,
map_key={'states':'properties.NAME'}, brew='RdPu')
vis.scales[0].type='threshold'
vis.scales[0].domain = thresholds
vis.legend(title=mapTitle)
if jsonFile is not None and htmlFile is not None:
vis.to_json(jsonFile, html_out=True, html_path=htmlFile)

def drawXYPlotByFactor(dataDict, xlabel='', ylabel='', legend=None,
title=None, logy=False, location=2):
title=None, logy=False, location=5):
# Assuming that the data is in the format { factor: [(x1, y1),(x2,y2),...] }
PLOT_STYLES = ['r^-', 'bo-', 'g^-', 'ks-', 'co-', 'ms-', 'y^-']
PLOT_STYLES = ['r^-', 'bo-', 'g^-', 'ks-', 'ms-', 'co-', 'y^-']
styleCount = 0
displayedPlots = []
pltfn = plt.semilogy if logy else plt.plot
Expand Down
21 changes: 11 additions & 10 deletions test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@ def playTTTAndPlotResults():
learningPlayer = RLTTTPlayer()
randomPlayer = RandomTTTPlayer()
results = []
numberOfSetsOfGames = 5
numberOfSetsOfGames = 50
for i in range(numberOfSetsOfGames):
games = GameSequence(1000, learningPlayer, randomPlayer)
games = GameSequence(100, learningPlayer, randomPlayer)
results.append(games.playGamesAndGetWinPercent())
plotValues = {'X Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[0], results)),
'O Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[1], results)),
'Draw Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[2], results))}
drawXYPlotByFactor(plotValues, 'Set Number', 'Fraction')
drawXYPlotByFactor(plotValues, 'Number of Sets (of 100 Games)', 'Fraction', title='RL Player (X) vs. Random Player (O)')

def playUltimateAndPlotResults():
learningPlayer = RLUTTTPlayer(NNUltimateLearning)
learningModel = NNUltimateLearning(UTTTBoardDecision)
learningPlayer = RLUTTTPlayer(learningModel)
randomPlayer = RandomUTTTPlayer()
results = []
numberOfSetsOfGames = 50
numberOfSetsOfGames = 60
if os.path.isfile(LEARNING_FILE):
learningPlayer.loadLearning(LEARNING_FILE)
for i in range(numberOfSetsOfGames):
Expand All @@ -37,10 +38,10 @@ def playUltimateAndPlotResults():
plotValues = {'X Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[0], results)),
'O Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[1], results)),
'Draw Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[2], results))}
drawXYPlotByFactor(plotValues, 'Set Number', 'Fraction')
drawXYPlotByFactor(plotValues, 'Number of Sets (of 100 Games)', 'Fraction',title='RL Player (O) vs. Random Player (X)')

def playUltimateForTraining():
learningPlayer = RLUTTTPlayer(TableLearning)
learningPlayer = RLUTTTPlayer(TableLearning())
randomPlayer = RandomUTTTPlayer()
games = GameSequence(4000, learningPlayer, randomPlayer, BoardClass=UTTTBoard, BoardDecisionClass=UTTTBoardDecision)
games.playGamesAndGetWinPercent()
Expand All @@ -60,10 +61,10 @@ def plotResultsFromFile(resultsFile):
plotValues = {'X Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[0], results)),
'O Win Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[1], results)),
'Draw Fraction': zip(range(numberOfSetsOfGames), map(lambda x: x[2], results))}
drawXYPlotByFactor(plotValues, 'Set Number', 'Fraction')
drawXYPlotByFactor(plotValues, 'Number of Sets (of 100 Games)', 'Fraction', title='RL Player (O) vs. Random Player (X)')

if __name__ == '__main__':
#playTTTAndPlotResults()
#playUltimateForTraining()
#playUltimateAndPlotResults()
plotResultsFromFile('results/ultmate_nn1_results.csv')
playUltimateAndPlotResults()
#plotResultsFromFile('results/ultimate_nn1_results_o.csv')
6 changes: 3 additions & 3 deletions ultimateplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def learnFromMove(self, prevBoardState):
pass # Random player does not learn from move

class RLUTTTPlayer(UTTTPlayer):
def __init__(self, Learning=TableLearning):
self.learningAlgo = Learning(UTTTBoardDecision)
def __init__(self, learningModel):
self.learningAlgo = learningModel
super(RLUTTTPlayer, self).__init__()

def printValues(self):
Expand Down Expand Up @@ -96,4 +96,4 @@ def loadLearning(self, filename):
if __name__ == '__main__':
board = UTTTBoard()
player1 = RandomUTTTPlayer()
player2 = RLUTTTPlayer()
player2 = RLUTTTPlayer(TableLearning(UTTTBoardDecision))

0 comments on commit e2bc98d

Please sign in to comment.