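"""
Training loop for a Tetris-playing neural network trained with Q-learning.

Repeatedly collects experiences from the Tetris game using an epsilon-greedy
policy, trains the model on one-step Q-learning targets, and periodically
writes checkpoints (and a log file) to the checkpoints/ directory.
"""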
import model
import tetris
import sys
import representation
import random
from pathlib import Path
script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)
log_file_path = checkpoints_dir / "log.txt"
# To resume training from a checkpoint, set this to the path of the .pth file. To start from a fresh NN, leave it blank!
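# example (matching the checkpoint files written by the save step below): model_save_path = r"checkpoints/checkpoint0.pth"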
model_save_path = r""
# training settings
gamma: float = 0.5 # discount factor: how much estimated future rewards contribute to the Q-value target
epsilon: float = 0.2 # exploration rate: probability of playing a random move instead of the model's best move
# training config
batch_size: int = 100 # the number of experiences that will be collected and trained on per loop iteration
save_model_every_experiences: int = 5000 # save a checkpoint after this many additional experiences have been trained on
################
# construct/load model
tmodel: model.TetrisAI = None
if model_save_path is not None and model_save_path != "":
    print("Loading model checkpoint at '" + model_save_path + "'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()
# variables to track
experiences_trained: int = 0 # the number of experiences the model has been trained on
model_last_saved_at_experiences_trained: int = 0 # the value of experiences_trained when the model was last saved
on_checkpoint: int = 0 # the index of the next checkpoint file to write
def log(path, content: str) -> None:
    # append a line to the log file at the given path (skips logging if no path is given)
    if path is not None and path != "":
        with open(path, "a") as f:
            f.write(content + "\n")
# training loop
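# Each pass: (1) collect batch_size experiences with an epsilon-greedy policy,
# (2) train on each experience using a one-step Q-learning target,
# (3) save a checkpoint every save_model_every_experiences trained experiences.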
while True:
    # collect batch_size experiences
    gs: tetris.GameState = tetris.GameState()
    experiences: list[model.Experience] = []
    for ei in range(0, batch_size):

        # print progress
        sys.stdout.write("\r" + "Collecting experience " + str(ei+1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()

        # get board representation
        state_board: list[int] = representation.BoardState(gs)

        # select move to play (epsilon-greedy)
        move: int
        if random.random() < epsilon: # if by chance we should select a random move
            move = random.randint(0, 3) # choose move at random
        else:
            predictions: list[float] = tmodel.predict(state_board) # predict Q-Values
            move = predictions.index(max(predictions)) # select the move (index) with the highest Q-Value
        # play the move
        IllegalMovePlayed: bool = False
        MoveReward: float
        try:
            MoveReward = gs.drop(move)
        except tetris.InvalidDropException as ex: # the model (or at random) tried to play an illegal move
            IllegalMovePlayed = True
            MoveReward = -3.0 # small penalty for illegal moves
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")
            IllegalMovePlayed = True # treat as a failed move so the game is reset below
            MoveReward = 0.0 # no reward recorded for an unhandled failure
        # store this experience
        exp: model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = MoveReward
        exp.next_state = representation.BoardState(gs) # the state we find ourselves in now
        exp.done = gs.over() or IllegalMovePlayed # it is over if the game is completed OR an illegal move was played
        experiences.append(exp)

        # if the game is over or an illegal move was played, reset the game!
        if gs.over() or IllegalMovePlayed:
            gs = tetris.GameState()
    print()

    # print avg rewards
    rewards: float = 0.0
    for exp in experiences:
        rewards = rewards + exp.reward
    status: str = "Average reward over those " + str(len(experiences)) + " experiences on model w/ " + str(experiences_trained) + " trained experiences: " + str(round(rewards / len(experiences), 2))
    log(log_file_path, status)
    print(status)
    # train!
    for ei in range(0, len(experiences)):
        exp = experiences[ei]

        # print progress
        sys.stdout.write("\r" + "Training on experience " + str(ei+1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()

        # determine the new target: if the episode ended, use the reward alone; otherwise, blend in discounted future rewards
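        # Q-learning target:
        #   new_target = reward                                      (terminal / illegal move)
        #   new_target = reward + gamma * max_a Q(next_state, a)     (otherwise)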
        new_target: float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state: float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state) # blend immediate vs. future rewards
        # ask the model to predict again for this experience's state
        qvalues: list[float] = tmodel.predict(exp.state)

        # plug in the new target where it belongs
        qvalues[exp.action] = new_target

        # now train on the updated qvalues (with 1 changed)
        tmodel.train(exp.state, qvalues)
        experiences_trained = experiences_trained + 1
    print("Training complete!")
    # save model!
    if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
        print("Time to save model!")
        path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
        tmodel.save(path)
        print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!")
        on_checkpoint = on_checkpoint + 1
        model_last_saved_at_experiences_trained = experiences_trained
        print("Model saved to " + str(path) + "!")