import random
import sys
from pathlib import Path

import model
import representation
import tetris
# Filesystem layout: checkpoints and the training log live in a "checkpoints"
# folder next to this script, regardless of the current working directory.
script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
# exist_ok=True keeps re-runs from failing when the folder is already there.
checkpoints_dir.mkdir(exist_ok=True)
log_file_path = checkpoints_dir / "log.txt"
# --- Training settings ---

# Checkpoint to resume from. Leave empty to start with a newly constructed model.
model_save_path = r""

# Weight applied to the best predicted value of the next state when the
# training target for a non-terminal experience is computed.
gamma: float = 0.5

# Probability of playing a random move instead of the model's top prediction
# while collecting experiences.
epsilon: float = 0.2

# Number of experiences collected per pass of the main loop before training.
batch_size: int = 100

# A checkpoint is written once at least this many experiences have been
# trained on since the previous save.
save_model_every_experiences: int = 5000
# Build the model that the loop below trains. A non-empty `model_save_path`
# is handed to the TetrisAI constructor (presumably restoring saved weights —
# the behaviour lives in model.TetrisAI); otherwise a new model is constructed.
tmodel: model.TetrisAI
if model_save_path is not None and model_save_path != "":
    print(f"Loading model checkpoint at '{model_save_path}'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()
# --- Progress counters, updated by the training loop ---

# Total number of experiences the model has been trained on in this run.
experiences_trained: int = 0

# Value of `experiences_trained` at the time of the most recent checkpoint.
model_last_saved_at_experiences_trained: int = 0

# Index used in the file name of the next checkpoint to be written.
on_checkpoint: int = 0
def log(path: str, content: str) -> None:
    """Append ``content`` followed by a newline to the file at ``path``.

    Does nothing when ``path`` is ``None`` or an empty string, so logging can
    be switched off by leaving the path unset.

    Args:
        path: Destination file, opened in append mode (created if missing).
            Anything ``open`` accepts works, including ``pathlib.Path``.
        content: Text to write; a trailing newline is added.
    """
    if path is None or path == "":
        return
    # The context manager closes the handle even if the write raises.
    with open(path, "a") as f:
        f.write(content + "\n")
# Main training loop; runs until the process is interrupted. Each pass:
#   1. plays moves with an epsilon-greedy policy to collect a batch of experiences,
#   2. logs the batch's average reward,
#   3. trains the model on every collected experience,
#   4. writes a checkpoint once enough experiences have been trained on.
while True:

    # --- 1. Collect up to `batch_size` experiences ---
    gs: tetris.GameState = tetris.GameState()
    experiences: list[model.Experience] = []
    for ei in range(batch_size):
        sys.stdout.write(f"\rCollecting experience {ei + 1} / {batch_size}... ")
        sys.stdout.flush()

        # Board representation before the move; stored as the experience's state.
        state_board: list[int] = representation.BoardState(gs)

        # Epsilon-greedy choice: random move with probability `epsilon`,
        # otherwise the move with the highest predicted value.
        move: int
        if random.random() < epsilon:
            # NOTE(review): assumes exactly four moves (0..3) — confirm this
            # matches the length of tmodel.predict()'s output.
            move = random.randint(0, 3)
        else:
            predictions: list[float] = tmodel.predict(state_board)
            move = predictions.index(max(predictions))

        # Play the move. An invalid drop ends the episode with a fixed penalty.
        illegal_move_played: bool = False
        move_reward: float
        try:
            move_reward = gs.drop(move)
        except tetris.InvalidDropException:
            illegal_move_played = True
            move_reward = -3.0
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")
            # No reward was produced for this move, so there is nothing valid
            # to record (previously this fell through and used an unassigned
            # or stale reward). Drop the sample and restart with a new game,
            # since the current one may be in an inconsistent state.
            gs = tetris.GameState()
            continue

        # Record the transition (state, action, reward, next state, done).
        episode_ended: bool = gs.over() or illegal_move_played
        exp: model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = move_reward
        exp.next_state = representation.BoardState(gs)
        exp.done = episode_ended
        experiences.append(exp)

        # Start a new game once the current one has finished.
        if episode_ended:
            gs = tetris.GameState()
    print()

    # Every move in the batch may have been skipped; nothing to average or train on.
    if not experiences:
        continue

    # --- 2. Report the batch's average reward ---
    rewards: float = sum((e.reward for e in experiences), 0.0)
    status: str = (
        f"Average reward over those {len(experiences)} experiences on model w/ "
        f"{experiences_trained} trained experiences: "
        f"{round(rewards / len(experiences), 2)}"
    )
    log(log_file_path, status)
    print(status)

    # --- 3. Train on each experience ---
    for ei, exp in enumerate(experiences):
        sys.stdout.write(f"\rTraining on experience {ei + 1} / {len(experiences)}... ")
        sys.stdout.flush()

        # Target for the action taken: the reward alone when the episode ended,
        # otherwise the reward plus the discounted best prediction for the next state.
        new_target: float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state: float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state)

        # Keep the model's current predictions for the other actions and
        # overwrite only the entry of the action that was actually played.
        qvalues: list[float] = tmodel.predict(exp.state)
        qvalues[exp.action] = new_target

        tmodel.train(exp.state, qvalues)
        experiences_trained += 1
    print("Training complete!")

    # --- 4. Checkpoint when enough new experiences have been trained on ---
    if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
        print("Time to save model!")
        path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
        tmodel.save(path)
        print(f"Checkpoint # {on_checkpoint} saved to {path}!")
        on_checkpoint += 1
        model_last_saved_at_experiences_trained = experiences_trained
        print(f"Model saved to {path}!")