Waterhorse committed on
Commit
3985e26
1 Parent(s): 859d91e

Create README.md

Files changed (1)
  1. README.md +92 -0
README.md ADDED
@@ -0,0 +1,92 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ ---
+
+ # ChessCLIP
+
+ ChessCLIP is a CLIP model trained to align (board, action) representations from chess games with natural language and to compute the similarity between the two.
+
+ ## Model Details
+ - **Language(s)**: English
+ - **License**: Apache 2.0
+ - **Model Description**: A CLIP model for chess.
+
+ # Quick Start
+
+ ```bash
+ git clone https://github.com/waterhorse1/ChessGPT
+ ```
+ Clone our codebase and install all dependencies as described in its README.
+
+ ## Inference
+
+ ```python
+ import io
+ import sys
+
+ # Make the chessclip sources importable (assumes the working directory
+ # is the root of the cloned ChessGPT repository).
+ sys.path.append('./chessclip/src')
+
+ import chess.pgn
+ import numpy as np
+ import torch
+ import open_clip
+ from open_clip.factory import get_tokenizer, load_checkpoint
+ from data.chessclip_data.feature_converter import get_lc0_input_planes_tf
+ from data.chessclip_data.pgn_base import generate_examples_from_game_no_comment
+
+ # init
+ model_name = 'chessclip-quickgelu'
+ model = open_clip.create_model(model_name, pretrained='openai')
+ tokenizer = get_tokenizer(model_name)
+
+ # load the ChessCLIP checkpoint into the model
+ load_checkpoint(model, './ChessCLIP/epoch_last.pt')
+
+ # check parameters
+ model.eval()
+ context_length = model.text.context_length
+ vocab_size = model.text.vocab_size
+
+ print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
+ print("Context length:", context_length)
+ print("Vocab size:", vocab_size)
+
+ # generate the board/action representation for the final position of a PGN string
+ def generate_representation_for_final(pgn):
+     game = chess.pgn.read_game(io.StringIO(pgn))
+     # keep only the last example generated from the game
+     data = list(generate_examples_from_game_no_comment(game))[-1]
+     for key in data.keys():
+         data[key] = np.array(data[key])
+     board = get_lc0_input_planes_tf(data).numpy()
+     action = data['probs']
+     return board, action
+
+ # Prepare input
+ prompt = "Black plays Sicilian Defense"
+ pgn_str = '1. e4 c5'
+ board, action = generate_representation_for_final(pgn_str)
+ text_tokens = tokenizer([prompt])
+
+ # add a batch dimension of size 1
+ image_input = torch.from_numpy(np.stack([board], axis=0))
+ action_input = torch.from_numpy(np.stack([action], axis=0))
+
+ # infer
+ with torch.no_grad():
+     image_features = model.encode_image((image_input, action_input)).float()
+     text_features = model.encode_text(text_tokens).float()
+
+ # L2-normalise so that the dot product below is a cosine similarity
+ image_features /= image_features.norm(dim=-1, keepdim=True)  # n * dim
+ text_features /= text_features.norm(dim=-1, keepdim=True)  # m * dim
+ similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T  # m * n
+ print(similarity)
+ ```
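The last step of the snippet above L2-normalises both sets of embeddings and multiplies them, so each entry of `similarity` is the cosine similarity between one prompt and one (board, action) pair. Below is a minimal, dependency-free sketch of that arithmetic. It uses made-up 2-dimensional vectors in place of real ChessCLIP embeddings, and the helper names `l2_normalize` and `similarity_matrix` are illustrative assumptions, not part of the ChessGPT codebase.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def similarity_matrix(text_features, image_features):
    """Return an m x n matrix of cosine similarities (m texts, n boards)."""
    texts = [l2_normalize(t) for t in text_features]
    images = [l2_normalize(i) for i in image_features]
    return [[sum(a * b for a, b in zip(t, i)) for i in images] for t in texts]

# Toy embeddings, invented for illustration only.
text_features = [[1.0, 0.0], [0.0, 2.0]]  # m = 2 prompts
image_features = [[3.0, 0.0]]             # n = 1 (board, action) pair

similarity = similarity_matrix(text_features, image_features)
print(similarity)  # [[1.0], [0.0]]
```

When several prompts are scored against one position, the row with the largest value corresponds to the description the model considers closest to that position.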
+
+ # Uses
+
+ ## Limitations
+ ChessCLIP, like other CLIP-based models, has limitations that should be taken into account. For instance, it may produce incorrect similarity scores, especially for complex or ambiguous inputs, or for language that falls outside its training data.
+
+ We welcome contributions from individuals and organizations that improve the model's performance and stability. In particular, we welcome annotated data, such as annotated PGN (Portable Game Notation) files, which can be used to train a more robust and reliable CLIP model.
+
+ ## Benchmark
+
+ Please refer to our [paper](https://together.xyz) and [code](https://github.com/waterhorse1/ChessGPT) for benchmark results.