|
import io |
|
import traceback |
|
from typing import List |
|
|
|
import chess |
|
import chess.pgn |
|
import chess.svg |
|
import gradio as gr |
|
import numpy as np |
|
import tokenizers |
|
import torch |
|
from tokenizers import models, pre_tokenizers, processors |
|
from torch import Tensor as TT |
|
from transformers import (AutoModelForCausalLM, GPT2LMHeadModel, |
|
PreTrainedTokenizerFast) |
|
|
|
# Hugging Face Hub id of the chess GPT-2 checkpoint that is passed to
# AutoModelForCausalLM.from_pretrained at the bottom of this module.
checkpoint_name: str = "austindavis/chess-gpt2-uci-8x8x512"
|
|
|
|
|
class UciTokenizer(PreTrainedTokenizerFast):
    """Word-level tokenizer base class for UCI move sequences.

    Subclasses provide the vocabulary (``stoi``/``itos``) and implement
    ``_init_pretokenizer`` (how a space-separated UCI string is split into
    vocabulary tokens) and ``_process_str_tokens`` (how decoded token strings
    are regrouped into UCI moves).

    Encoding a single sequence prepends the BOS token through a
    ``TemplateProcessing`` post-processor. Decoding is overridden to map ids
    through ``itos`` and join the regrouped moves with single spaces.
    """

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str

    # String -> integer id mapping. This is the vocab.
    stoi: dict[str, int]

    # Integer id -> string mapping (inverse of ``stoi``), used for decoding.
    itos: dict[int, str]

    def __init__(
        self,
        stoi,
        itos,
        pad_token,
        unk_token,
        bos_token,
        eos_token,
        name_or_path,
    ):
        """Build the backing ``tokenizers.Tokenizer`` and the fast tokenizer.

        Args:
            stoi: token string -> id vocabulary.
            itos: id -> token string mapping used by ``_decode``.
            pad_token, unk_token, bos_token, eos_token: special token
                strings; ``bos_token`` must be a key of ``stoi``.
            name_or_path: forwarded to ``PreTrainedTokenizerFast``.
        """
        self.stoi = stoi
        self.itos = itos

        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)

        backend = tokenizers.Tokenizer(tok_model)
        backend.pre_tokenizer = self._init_pretokenizer()

        # Prepend BOS to every single sequence. Look its id up in the vocab
        # rather than assuming it sits at a fixed position.
        backend.post_processor = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, self.stoi[bos_token])],
        )

        super().__init__(
            tokenizer_object=backend,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )

    def _decode(
        self,
        token_ids: int | List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = False,
        **kwargs,
    ) -> str:
        """Decode ids into a space-separated string of UCI moves.

        A single int decodes to its raw token string. A dict is read through
        its ``"input_ids"`` entry; tensors and tuples are converted to lists.
        ``skip_special_tokens``, ``clean_up_tokenization_spaces`` and any
        extra keyword arguments forwarded by ``decode`` are accepted for API
        compatibility but ignored; special-token handling is left to
        ``_process_str_tokens``.

        Raises:
            TypeError: if ``token_ids`` is none of the supported types.
        """
        if isinstance(token_ids, dict):
            token_ids = token_ids["input_ids"]

        if isinstance(token_ids, TT):
            token_ids = token_ids.tolist()

        if isinstance(token_ids, int):
            return self.itos.get(token_ids, self._UNK_TOKEN)

        if isinstance(token_ids, tuple):
            token_ids = list(token_ids)

        if not isinstance(token_ids, list):
            raise TypeError(
                f"Cannot decode token ids of type {type(token_ids).__name__}"
            )

        tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
        return " ".join(self._process_str_tokens(tokens_str))

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        """Return the pre-tokenizer splitting UCI text into vocab tokens.

        Subclasses must override this.
        """
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        """Regroup decoded token strings into UCI move strings.

        Subclasses must override this.
        """
        raise NotImplementedError

    def get_id2square_list(self) -> list[int | None]:
        """Return the token-id -> board-square mapping.

        Subclasses must override this.
        """
        raise NotImplementedError
|
|
|
|
|
class UciTileTokenizer(UciTokenizer):
    """Uci tokenizer converting start/end tiles and promotion types each
    into individual tokens.

    Vocabulary (72 ids): ``<pad>``=0, ``<s>``=1, ``</s>``=2, ``<unk>``=3,
    then the 64 names of ``chess.SQUARE_NAMES`` (ids 4-67), then the
    promotion pieces ``q``, ``r``, ``b``, ``n`` (ids 68-71).
    """

    # Token strings in id order; the single source for ``stoi`` and ``itos``.
    _VOCAB = ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")

    stoi = {tok: idx for idx, tok in enumerate(_VOCAB)}

    itos = dict(enumerate(_VOCAB))

    # Maps token id -> square index, or None for special/promotion tokens.
    # Squares are ordered file-fastest: `A1, B1, C1, ..., F8, G8, H8`.
    id2square: List[int | None] = [None] * 4 + list(range(64)) + [None] * 4

    def get_id2square_list(self) -> List[int | None]:
        """Return the token-id -> square-index list (``None`` for non-squares)."""
        return self.id2square

    def __init__(self):
        """Create the tile tokenizer with its fixed vocabulary and special tokens."""
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        """Split on whitespace, then after every digit.

        "e2e4q" therefore becomes "e2", "e4", "q": each digit is merged with
        the text before it, and a trailing promotion letter stays separate.
        """
        pattern = tokenizers.Regex(r"\d")
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )

    def _process_str_tokens(self, token_str):
        """Group decoded token strings back into UCI moves.

        Special tokens are skipped. Square tokens are paired as
        (from-square, to-square). A single-character promotion token
        completes the move it follows. A trailing incomplete move (for
        example a lone from-square) is kept as-is.

        Returns:
            List of move strings, e.g. ``["e2e4", "e7e8q"]``.
        """
        special = set(self.all_special_tokens)  # hoisted: property is recomputed per access
        moves = []
        next_move = ""
        for token in token_str:
            if token in special:
                continue

            if len(token) == 1:
                # Promotion piece: it finishes the pending move, which must not
                # be emitted a second time when the next square arrives.
                moves.append(next_move + token)
                next_move = ""
                continue

            if len(next_move) == 4:
                moves.append(next_move)
                next_move = token
            else:
                next_move += token

        if next_move:
            moves.append(next_move)
        return moves
|
|
|
|
|
def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio App to use the GPT model for move generation.

    The model must be compatible with a UciTileTokenizer.

    One board and one PGN game tree live in this closure, so every session
    of the returned interface plays the same shared game.

    Args:
        model: causal LM whose vocabulary matches ``UciTileTokenizer``.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """
    tokenizer = UciTileTokenizer()

    board = chess.Board()
    game: chess.pgn.Game = chess.pgn.Game()
    game.headers["Event"] = "Example"

    # Sample several 3-token continuations (from-square, to-square, and room
    # for a promotion piece); the distinct candidate moves are ranked below.
    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        # ``temperature`` only applies when sampling, and greedy search cannot
        # produce several distinct return sequences.
        "do_sample": True,
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    def _mainline_end(start):
        """Follow main-line children from ``start`` to the last node."""
        while start.next() is not None:
            start = start.next()
        return start

    def _move_text(node) -> str:
        """Return the movetext (PGN without headers) of ``node``'s game."""
        # str(Game) is the header tags followed by the movetext; everything
        # after the last "]" is the movetext.
        return str(node.root()).split("]")[-1].strip()

    def _load_pgn(text: str, node, board):
        """Replace the current game with the main line of the PGN in ``text``."""
        loaded = chess.pgn.read_game(io.StringIO(text))
        if loaded is None:
            return chess.svg.board(board=board), "Could not read a game from the PGN input."

        board.reset()
        tail = node.root()
        tail.variations.clear()
        last_move = None
        for last_move in loaded.mainline_moves():
            board.push(last_move)
            # Chain each move onto the previous one so they form one main
            # line instead of sibling variations of the root.
            tail = tail.add_variation(last_move)

        return chess.svg.board(board=board, lastmove=last_move), ""

    def _reply(node, board):
        """Let the model answer the position on ``board``.

        Pushes the best-ranked legal candidate and returns the new board and
        the movetext. If no candidate is legal, the board is left unchanged,
        every parseable candidate is drawn as a red arrow and the candidates
        are listed ``|``-separated.
        """
        prefix = " ".join(m.uci() for m in board.move_stack)
        encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
        output = model.generate(encoding, **generate_kwargs)

        # Each decoded continuation starts with the candidate move; anything
        # after the first space belongs to a following move. Splitting
        # (rather than slicing 4 chars) keeps promotion suffixes intact.
        new_tokens = tokenizer.batch_decode(output.sequences[:, -3:])
        unique_moves, unique_indices = np.unique(
            [x.split(" ", 1)[0] for x in new_tokens],
            return_index=True,
        )

        # Score each distinct candidate (using the sequence where it first
        # appeared) by the mean of the peak logits of its first two generated
        # tokens; highest first. Stacked logits assumed (steps, sequences,
        # vocab) as returned by ``generate`` -- TODO confirm on upgrades.
        first_seen = torch.as_tensor(unique_indices, dtype=torch.long)
        logits = torch.stack(output.logits)[:, first_seen]
        scores = logits.max(dim=-1).values.T[:, :2].mean(-1)
        order = scores.topk(len(first_seen)).indices.tolist()
        ranked = [str(unique_moves[i]) for i in order]

        for uci in ranked:
            try:
                reply = chess.Move.from_uci(uci)
            except chess.InvalidMoveError:
                # Malformed model output: try the next candidate instead of
                # reporting the user's (valid) input as bad UCI.
                continue
            if reply in board.legal_moves:
                board.push(reply)
                _mainline_end(node).add_variation(reply)
                return chess.svg.board(board=board, lastmove=reply), _move_text(node)

        # No legal candidate: show what the model attempted.
        arrows = []
        for uci in unique_moves:
            try:
                tail = chess.parse_square(uci[:2])
                head = chess.parse_square(uci[2:4])  # ignore any promotion suffix
            except ValueError:
                continue
            arrows.append(chess.svg.Arrow(tail, head, color="red"))

        check = board.king(board.turn) if board.is_check() else None
        return (
            chess.svg.board(board=board, arrows=arrows, check=check),
            "|".join(unique_moves),
        )

    def make_move(input: str, node=game, board=board):
        """Handle one submission from the text box.

        Accepts ``reset``, a PGN (starting with ``[`` or ``1. ``), or a
        single UCI move. After a legal UCI move the model plays a reply.

        Returns:
            ``(svg, message)`` for the Board HTML and Move Sequence outputs.
        """
        # Gradio may submit None or padded text; normalise before parsing.
        text = (input or "").strip()

        try:
            if text.lower() == "reset":
                board.reset()
                node.root().variations.clear()
                return chess.svg.board(board=board), "New game!"

            if text.startswith(("[", "1. ")):
                return _load_pgn(text, node, board)

            try:
                move = chess.Move.from_uci(text)
            except chess.InvalidMoveError:
                return chess.svg.board(board=board), f"Invalid UCI format: {text}"

            if move not in board.legal_moves:
                return (
                    chess.svg.board(board=board, lastmove=move),
                    f"Illegal move: {text}",
                )

            board.push(move)
            _mainline_end(node).add_variation(move)
            return _reply(node, board)

        except Exception:
            # UI boundary: surface the traceback instead of crashing the app.
            return chess.svg.board(board=board), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")

    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[["e2e4"], ["d2d4"], ["Reset"]],
        title="Play Versus ChessGPT",
        description=(
            "Enter moves in UCI notation (e.g., e2e4 for pawn from e2 "
            "to e4). Enter 'reset' to restart the game."
        ),
        allow_flagging="never",
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )

    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"

    return iface
|
|
|
|
|
# Load the pretrained checkpoint and freeze its weights: the app only runs
# inference. requires_grad_ returns the module itself, so it can be chained.
model: GPT2LMHeadModel = (
    AutoModelForCausalLM
    .from_pretrained(checkpoint_name)
    .requires_grad_(False)
)

# Build the Gradio UI around the model and serve it with a share link.
iface = setup_app(model)
iface.launch(share=True)
|
|