---
license: apache-2.0
language:
- en
---

# ChessCLIP

ChessCLIP is a CLIP model trained to align (board, action) representations with natural language, so that it can compute the similarity between a chess position (with its action) and a text description.
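This card does not spell out the training objective. Assuming ChessCLIP follows the standard CLIP recipe, alignment is learned with a symmetric contrastive (InfoNCE) loss over a batch of matched (board/action, text) pairs: the cosine similarity of each matched pair is pushed above the similarities of all mismatched pairs in the batch. The NumPy sketch below illustrates that objective only. The function names and the fixed temperature of 0.07 are illustrative assumptions, not taken from the ChessCLIP code.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_contrastive_loss(board_emb: np.ndarray, text_emb: np.ndarray,
                          temperature: float = 0.07) -> float:
    """Symmetric InfoNCE loss over a batch of matched (board/action, text) pairs.

    Row i of board_emb is assumed to be paired with row i of text_emb.
    The temperature value is an illustrative assumption.
    """
    b = l2_normalize(board_emb)
    t = l2_normalize(text_emb)
    logits = (b @ t.T) / temperature  # (batch, batch) scaled cosine similarities
    idx = np.arange(logits.shape[0])
    # Cross-entropy in both directions: board -> text (over rows), text -> board (over columns).
    loss_b2t = -log_softmax(logits, axis=1)[idx, idx].mean()
    loss_t2b = -log_softmax(logits, axis=0)[idx, idx].mean()
    return float((loss_b2t + loss_t2b) / 2)


# Placeholder embeddings: texts are near-duplicates of boards, i.e. well-aligned pairs.
rng = np.random.default_rng(0)
boards = rng.normal(size=(8, 16))
texts = boards + 0.01 * rng.normal(size=(8, 16))
print(clip_contrastive_loss(boards, texts))        # small: pairs match
print(clip_contrastive_loss(boards, texts[::-1]))  # larger: every pair is mismatched
```

Correctly paired rows give a much lower loss than the same rows shuffled, which is the signal that drives board/action and text embeddings of the same position toward each other.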

## Model Details
- **Language(s)**: English
- **License**: Apache 2.0
- **Model Description**: A CLIP model for chess.

# Quick Start

```bash
git clone https://github.com/waterhorse1/ChessGPT
```
Clone our codebase and install all dependencies according to our README.

## Inference

```python
import sys
sys.path.append('./chessclip/src')  # make the ChessCLIP sources (data, open_clip) importable
import io

import chess.pgn
import numpy as np
import torch
import open_clip  # needed for open_clip.create_model below
from data.chessclip_data.feature_converter import get_lc0_input_planes_tf
from data.chessclip_data.pgn_base import generate_examples_from_game_no_comment

from open_clip.factory import get_tokenizer, load_checkpoint

# init
model_name = 'chessclip-quickgelu'
model = open_clip.create_model(model_name, pretrained='openai')
tokenizer = get_tokenizer(model_name)

# load model
load_checkpoint(model, './ChessCLIP/epoch_last.pt')

# check parameters
model.eval()
context_length = model.text.context_length
vocab_size = model.text.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)

# convert the final position of a PGN string into board planes and an action vector (model inputs)
def generate_representation_for_final(pgn):
    game = chess.pgn.read_game(io.StringIO(pgn))
    data = list(generate_examples_from_game_no_comment(game))[-1]
    for key in data.keys():
        data[key] = np.array(data[key])
    board = get_lc0_input_planes_tf(data).numpy()
    action = data['probs']
    return board, action

# Prepare input
prompt = "Black plays Sicilian Defense"
pgn_str = '1. e4 c5'
board, action = generate_representation_for_final(pgn_str)
text_tokens = tokenizer([prompt])

image_input = torch.from_numpy(np.stack([board], axis=0))
action_input = torch.from_numpy(np.stack([action], axis=0))

# infer
with torch.no_grad():
    image_features = model.encode_image((image_input, action_input)).float()
    text_features = model.encode_text(text_tokens).float()
image_features /= image_features.norm(dim=-1, keepdim=True)  # n * dim
text_features /= text_features.norm(dim=-1, keepdim=True)  # m * dim
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T # m * n
print(similarity)
```
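The snippet above scores a single prompt against a single position. To compare several candidate descriptions, one option is to pass a list of prompts to `tokenizer` and rank them per board using the resulting `m * n` similarity matrix. The helper below is a hypothetical sketch, not part of the ChessCLIP codebase; the softmax scale of 100.0 mirrors the example in the open_clip README and is an assumption for ChessCLIP. The embeddings in the demo are random placeholders standing in for real model outputs.

```python
import numpy as np


def rank_prompts(text_features: np.ndarray, board_features: np.ndarray,
                 prompts: list, scale: float = 100.0) -> list:
    """Hypothetical helper: rank candidate prompts for each board/action embedding.

    text_features:  (m, dim) L2-normalized text embeddings.
    board_features: (n, dim) L2-normalized board/action embeddings.
    Returns, for each of the n boards, (prompt, probability) pairs sorted
    from most to least similar.
    """
    similarity = text_features @ board_features.T  # (m, n), same layout as the snippet above
    logits = scale * similarity.T                  # (n, m): one row of prompt scores per board
    logits -= logits.max(axis=1, keepdims=True)   # stabilize the softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    rankings = []
    for row in probs:
        order = np.argsort(-row)  # indices of prompts, best first
        rankings.append([(prompts[i], float(row[i])) for i in order])
    return rankings


# Placeholder embeddings: the single "board" is constructed to lie close to prompt 1.
rng = np.random.default_rng(0)
text = rng.normal(size=(3, 8))
text /= np.linalg.norm(text, axis=1, keepdims=True)
board = text[1:2] + 0.01 * rng.normal(size=(1, 8))
board /= np.linalg.norm(board, axis=1, keepdims=True)
prompts = ["Black plays French Defense", "Black plays Sicilian Defense", "White castles kingside"]
print(rank_prompts(text, board, prompts)[0][0][0])  # Black plays Sicilian Defense
```

With the real model, `text_features` and `board_features` would be `text_features.cpu().numpy()` and `image_features.cpu().numpy()` from the inference snippet, and `prompts` the list passed to `tokenizer`.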

## Limitations
ChessCLIP, like other CLIP-based models, has limitations that should be taken into account. For instance, the model may produce incorrect similarity scores, especially for complex or ambiguous inputs, or for language inputs that fall outside its training data.

We welcome contributions from individuals and organizations to improve the model's performance and stability. In particular, annotated data such as annotated PGN (Portable Game Notation) files can be used to train a more robust and reliable CLIP model.

## Benchmark

Please refer to our [paper](https://together.xyz) and [code](https://github.com/waterhorse1/ChessGPT) for benchmark results.