Spaces:
Runtime error
Runtime error
using Random | |
using StatsBase | |
""" | |
simluate_rollout(b::Board, policy, side [rng=MersenneTwister(420)) | |
Simulate one rollout of a simulation based on the given `Chess.board` state. | |
Policy is a function, given the board and `MoveList`, returns an `AbstractArray` of probability | |
weights for each `Move` in `Move`List` based on index. | |
""" | |
function simulate_rollout(b::Board, policy, side; rng = MersenneTwister(420))::Tuple{Board, Int64} | |
#pprint(b) # Debugging | |
movelist = MoveList(200) | |
num_sim_moves = 0 | |
while !isterminal(b) # TODO Use `matein1` possibly to trim leaf nodes in sims? | |
moves(b, movelist) | |
policy_weights = ProbabilityWeights(policy(b, movelist)) | |
#pprint(b) | |
#println(movelist, policy_weights) | |
domove!(b, sample(movelist, policy_weights)) | |
recycle!(movelist) | |
num_sim_moves += 1 | |
end | |
return b, num_sim_moves | |
end | |
""" | |
CESPF(b::Board, movelist::MoveList) | |
Utilizes `Chess.jl`'s `see()` function to simulate (C)apture / (E)scape (S)tronger (P)iece | |
(F)irst heuristic in simulation/rollout policy. We use Chess weights set in `see` function to get weight for which | |
move we prefer to take. | |
""" | |
function CESPF(b::Board, movelist::MoveList) | |
unnorm_policy_weights = map(x -> see(b, x), movelist) | |
# Center raw centipawn values to 1 to then normalize | |
centered_policy_weights = (1 + abs(min(unnorm_policy_weights...))) .+ | |
unnorm_policy_weights | |
return centered_policy_weights / sum(centered_policy_weights) | |
end | |
""" | |
CESPF_greedy(b::Board, movelist::MoveList) | |
Utilizes `Chess.jl`'s `see()` function to simulate (C)apture / (E)scape (S)tronger (P)iece | |
(F)irst heuristic in simulation/rollout policy. We use Chess weights set in `see` function to get weight for which | |
move we prefer to take. This is greedy, and will set only the maximal valued policies to a non-zero | |
probability | |
""" | |
function CESPF_greedy(b::Board, movelist::MoveList) | |
unnorm_policy_weights = map(x -> see(b, x), movelist) | |
policy_weights = zeros(length(unnorm_policy_weights)) | |
max_idxs = findall(unnorm_policy_weights .== maximum(unnorm_policy_weights)) | |
for max_idx in max_idxs | |
policy_weights[max_idx] = 1.0 / length(max_idxs) | |
end | |
return policy_weights | |
end | |