import random
import sys
from pathlib import Path

import model
import representation
import tetris


# Filesystem locations, all resolved relative to the directory holding this script.
script_dir = Path(__file__).parent.resolve()

# Checkpoints and the training log live together in "<script dir>/checkpoints".
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)  # created on first run; no error if it already exists

# Text file that receives one status line per collected batch.
log_file_path = checkpoints_dir / "log.txt"

# Path of an existing checkpoint to resume from.
# Leave empty to construct a brand-new model instead.
model_save_path = r""

# Discount factor applied to the best predicted Q-value of the next state
# when building the training target.
gamma: float = 0.5

# Exploration rate: probability of playing a random move instead of the
# model's highest-scoring move while collecting experiences.
epsilon: float = 0.2

# Number of experiences collected (and then trained on) per loop iteration.
batch_size: int = 100

# A checkpoint is written once at least this many experiences have been
# trained on since the previous checkpoint.
save_model_every_experiences: int = 5000

# Build the model: resume from the configured checkpoint when a path is set,
# otherwise start from a freshly constructed model.
tmodel: model.TetrisAI
if model_save_path:  # None and "" both mean "no checkpoint configured"
    print(f"Loading model checkpoint at '{model_save_path}'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()

# Progress counters, updated by the training loop.
experiences_trained: int = 0  # total experiences the model has been trained on
model_last_saved_at_experiences_trained: int = 0  # value of experiences_trained at the last checkpoint
on_checkpoint: int = 0  # index used in the next checkpoint's file name


def log(path: str, content: str) -> None:
    """Append ``content`` followed by a newline to the file at ``path``.

    Args:
        path: Destination file. Anything ``open()`` accepts works (a string
            or a path-like object). ``None`` or an empty string disables
            logging and the call does nothing.
        content: Text to append; a trailing newline is added.
    """
    if path is None or path == "":
        return

    # The context manager guarantees the handle is closed even if write() raises.
    with open(path, "a", encoding="utf-8") as f:
        f.write(content + "\n")


# Main training loop. Runs until the process is interrupted; each iteration
#   1. plays the game to collect up to `batch_size` experiences,
#   2. reports the batch's average reward,
#   3. trains the model on every collected experience, and
#   4. writes a checkpoint when enough experiences have been trained on.
while True:

    # ---- 1. Collect experiences by playing ----
    gs: tetris.GameState = tetris.GameState()
    experiences: list[model.Experience] = []
    for ei in range(batch_size):

        # In-place progress indicator ("\r" returns to the start of the line).
        sys.stdout.write("\r" + "Collecting experience " + str(ei + 1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()

        # Board representation BEFORE the move is played.
        state_board: list[int] = representation.BoardState(gs)

        # Epsilon-greedy move selection: explore with probability `epsilon`,
        # otherwise play the move with the highest predicted value.
        move: int
        if random.random() < epsilon:
            # randint is inclusive on both ends, so this picks one of 0..3.
            # NOTE(review): assumes the model scores exactly four moves - confirm.
            move = random.randint(0, 3)
        else:
            predictions: list[float] = tmodel.predict(state_board)
            move = predictions.index(max(predictions))

        # Play the move and obtain its reward.
        illegal_move_played: bool = False
        move_reward: float
        try:
            move_reward = gs.drop(move)
        except tetris.InvalidDropException:
            # An illegal drop is penalised and ends the current game.
            illegal_move_played = True
            move_reward = -3.0
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")
            # No reward was produced for this move. Recording an experience
            # here would use an undefined (or stale, from the previous
            # iteration) reward, so discard the sample and start a fresh game
            # rather than continuing from a possibly inconsistent state.
            gs = tetris.GameState()
            continue

        # The episode ends when the game is over or the move was illegal.
        episode_over: bool = gs.over() or illegal_move_played

        # Record the transition (state, action, reward, next state, done).
        exp: model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = move_reward
        exp.next_state = representation.BoardState(gs)
        exp.done = episode_over
        experiences.append(exp)

        # Start a new game once the current one has ended.
        if episode_over:
            gs = tetris.GameState()

    print()  # terminate the in-place progress line

    # ---- 2. Report the average reward of the batch ----
    total_reward: float = sum((exp.reward for exp in experiences), 0.0)
    # Guard against an empty batch (possible if every move was discarded above).
    average_reward: float = total_reward / len(experiences) if experiences else 0.0
    status: str = (
        f"Average reward over those {len(experiences)} experiences on model w/ "
        f"{experiences_trained} trained experiences: {round(average_reward, 2)}"
    )
    log(log_file_path, status)
    print(status)

    # ---- 3. Train on every collected experience ----
    for ei, exp in enumerate(experiences):

        sys.stdout.write("\r" + "Training on experience " + str(ei + 1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()

        # Target for the action taken: the reward alone when the episode
        # ended, otherwise the reward plus the discounted best predicted
        # value of the next state.
        new_target: float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state: float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state)

        # Keep the current predictions for all other actions and overwrite
        # only the entry of the action that was actually played.
        qvalues: list[float] = tmodel.predict(exp.state)
        qvalues[exp.action] = new_target

        tmodel.train(exp.state, qvalues)
        experiences_trained = experiences_trained + 1

    print("Training complete!")

    # ---- 4. Periodic checkpoint ----
    if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
        print("Time to save model!")
        path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
        tmodel.save(path)
        print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!")
        on_checkpoint = on_checkpoint + 1
        model_last_saved_at_experiences_trained = experiences_trained
        print("Model saved to " + str(path) + "!")