Samuel CHAINEAU
committed on
Commit
•
e8a8dd9
1
Parent(s):
ace9d63
QBGPT V1
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- __pycache__/app_tools.cpython-311.pyc +0 -0
- __pycache__/pages.cpython-311.pyc +0 -0
- __pycache__/tools.cpython-311.pyc +0 -0
- app.py +79 -0
- assets/__pycache__/models.cpython-311.pyc +0 -0
- assets/index.parquet +3 -0
- assets/model_mediumv2/QBGPT.data-00000-of-00001 +3 -0
- assets/model_mediumv2/QBGPT.index +0 -0
- assets/model_mediumv2/checkpoint +2 -0
- assets/models.py +377 -0
- assets/moves_index.parquet +3 -0
- assets/photo_cv.jpg +0 -0
- assets/plays_index.parquet +3 -0
- assets/positions_index.parquet +3 -0
- assets/ref.json +0 -0
- assets/ref_df.json +1 -0
- assets/scrimmage_index.parquet +3 -0
- assets/starts_index.parquet +3 -0
- assets/time_index.parquet +3 -0
- pages.py +185 -0
- tools.py +375 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
__pycache__/app_tools.cpython-311.pyc
ADDED
Binary file (24.8 kB). View file
|
|
__pycache__/pages.cpython-311.pyc
ADDED
Binary file (13.8 kB). View file
|
|
__pycache__/tools.cpython-311.pyc
ADDED
Binary file (30.6 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pages import set_app_title_and_logo, qb_gpt_page, contacts_and_disclaimers
|
3 |
+
import json
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
from tools import tokenizer
|
8 |
+
|
9 |
+
from assets.models import QBGPT
|
10 |
+
|
11 |
+
# Sizes handed to the QBGPT constructor below.
# NOTE(review): these presumably have to match the checkpoint loaded from
# app/assets/model_mediumv2 -- confirm before changing any of them.
moves_to_pred = 11170
input_size = 11172
starts_size = 1954
scrimmage_size = 100
positions_id = 29

temp_ids = 52
off_def_size = 2
token_type_size = 3
play_type_size = 9

# Build the network and restore its trained weights.
qbgpt = QBGPT(input_vocab_size=input_size,
              positional_vocab_size=temp_ids,
              position_vocab_size=positions_id,
              start_vocab_size=starts_size,
              scrimmage_vocab_size=scrimmage_size,
              offdef_vocab_size=off_def_size,
              type_vocab_size=token_type_size,
              playtype_vocab_size=play_type_size,
              embedding_dim=256,
              hidden_dim=256,
              num_heads=3,
              diag_masks=False,
              to_pred_size=moves_to_pred)

qbgpt.load_weights("app/assets/model_mediumv2/QBGPT")


# Tokenizer backed by the parquet index files shipped in app/assets.
qb_tok = tokenizer(moves_index="./app/assets/moves_index.parquet",
                   play_index="./app/assets/plays_index.parquet",
                   positions_index="./app/assets/positions_index.parquet",
                   scrimmage_index="./app/assets/scrimmage_index.parquet",
                   starts_index="./app/assets/starts_index.parquet",
                   time_index="./app/assets/time_index.parquet",
                   window_size=20)

# Read the JSON file with an explicit encoding so the result does not depend
# on the platform's default locale encoding.
with open('./app/assets/ref.json', 'r', encoding="utf-8") as fp:
    ref_json = json.load(fp)
|
51 |
+
|
52 |
+
def convert_numpy(d):
    """Return a new dict with the same keys as ``d`` and every value passed through ``np.array``."""
    arrays = {}
    for key, values in d.items():
        arrays[key] = np.array(values)
    return arrays
|
54 |
+
|
55 |
+
# JSON object keys are always strings: turn the outer keys back into ints and
# the nested values into NumPy arrays.
ref_json = {
    int(raw_key): convert_numpy(payload)
    for raw_key, payload in ref_json.items()
}

ref_df = pd.read_json("./app/assets/ref_df.json")
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
def main():
    """Render the app: page configuration, sidebar navigation, then the selected page."""
    set_app_title_and_logo()

    # Sidebar navigation between the two pages.
    st.sidebar.title("Navigation")
    page = st.sidebar.radio("Go to:", ("QB-GPT", "Contacts and Disclaimers"))

    if page == "Contacts and Disclaimers":
        contacts_and_disclaimers()
    elif page == "QB-GPT":
        st.title("QB-GPT")
        qb_gpt_page(ref_df, ref_json, qb_tok, qbgpt)


if __name__ == "__main__":
    main()
|
assets/__pycache__/models.cpython-311.pyc
ADDED
Binary file (22.4 kB). View file
|
|
assets/index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ce0a2888c66b821c990837d1ea76061923d9ab85bde410aa2291f426c63317eb
|
3 |
+
size 29240
|
assets/model_mediumv2/QBGPT.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a725719eb71e137bfc5f126af82012085830f95a2992374e196cf3d7a054fc6
|
3 |
+
size 86410198
|
assets/model_mediumv2/QBGPT.index
ADDED
Binary file (4.55 kB). View file
|
|
assets/model_mediumv2/checkpoint
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
model_checkpoint_path: "QBGPT"
|
2 |
+
all_model_checkpoint_paths: "QBGPT"
|
assets/models.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """
    Return the shape of a tensor, preferring static sizes over dynamic ones.

    Args:
        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.

    Returns:
        `List[int]`: The shape of the tensor as a list. Dimensions whose static
        size is `None` are filled with the matching entry of `tf.shape(tensor)`.
        When the static shape equals `tf.TensorShape(None)`, the result of
        `tf.shape(tensor)` is returned as-is.
    """
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)

    runtime_shape = tf.shape(tensor)

    if tensor.shape == tf.TensorShape(None):
        return runtime_shape

    resolved = []
    for axis, size in enumerate(tensor.shape.as_list()):
        # Fall back to the dynamic size only where the static one is unknown.
        resolved.append(runtime_shape[axis] if size is None else size)
    return resolved
|
26 |
+
|
27 |
+
class PlayTypeEncoder(tf.keras.Model):
    """Embedding lookup for the ``"PlayType"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["PlayType"])
|
37 |
+
|
38 |
+
class PositionEncoder(tf.keras.Model):
    """Embedding lookup for the ``"position_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["position_ids"])
|
48 |
+
|
49 |
+
class ScrimmageEncoder(tf.keras.Model):
    """Embedding lookup for the ``"scrim_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["scrim_ids"])
|
59 |
+
|
60 |
+
class StartEncoder(tf.keras.Model):
    """Embedding lookup for the ``"start_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["start_ids"])
|
70 |
+
|
71 |
+
class OffDefEncoder(tf.keras.Model):
    """Embedding lookup for the ``"OffDef"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["OffDef"])
|
81 |
+
|
82 |
+
class TypeEncoder(tf.keras.Model):
    """Embedding lookup for the ``"token_type_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["token_type_ids"])
|
92 |
+
|
93 |
+
class PositionalEncoder(tf.keras.Model):
    """Embedding lookup for the ``"pos_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["pos_ids"])
|
103 |
+
|
104 |
+
class InputEncoder(tf.keras.Model):
    """Embedding lookup for the ``"input_ids"`` entry of the input dict."""

    def __init__(self, vocab_size : int, embedding_dim : int):
        super().__init__()

        # NOTE(review): the attribute name is presumably part of the saved
        # checkpoint's object path -- do not rename without checking.
        self.Embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        )

    def call(self, x):
        return self.Embedding(x["input_ids"])
|
114 |
+
|
115 |
+
class Embedding(tf.keras.Model):
    """Sum of the eight per-feature embeddings of one input dict.

    Each sub-encoder reads its own key of the dict passed to ``call``
    (``input_ids``, ``pos_ids``, ``position_ids``, ``scrim_ids``,
    ``start_ids``, ``token_type_ids``, ``OffDef``, ``PlayType``) and all of
    them share the same ``embedding_dim``; the results are added element-wise.
    """

    def __init__(self,
                 input_vocab_size : int,
                 positional_vocab_size : int,
                 position_vocab_size : int,
                 scrimmage_vocab_size : int,
                 start_vocab_size: int,
                 offdef_vocab_size : int,
                 type_vocab_size : int,
                 playtype_vocab_size : int,
                 embedding_dim : int):
        super().__init__()

        # NOTE(review): attribute names are presumably part of the saved
        # checkpoint's object paths -- do not rename without checking.
        self.InputEmbedding = InputEncoder(
            vocab_size=input_vocab_size, embedding_dim=embedding_dim)
        self.PositionalEmbedding = PositionalEncoder(
            vocab_size=positional_vocab_size, embedding_dim=embedding_dim)
        self.PositionEmbedding = PositionEncoder(
            vocab_size=position_vocab_size, embedding_dim=embedding_dim)
        self.ScrimEmbedding = ScrimmageEncoder(
            vocab_size=scrimmage_vocab_size, embedding_dim=embedding_dim)
        self.StartEmbedding = StartEncoder(
            vocab_size=start_vocab_size, embedding_dim=embedding_dim)
        self.OffDefEmbedding = OffDefEncoder(
            vocab_size=offdef_vocab_size, embedding_dim=embedding_dim)
        self.TypeEmbedding = TypeEncoder(
            vocab_size=type_vocab_size, embedding_dim=embedding_dim)
        self.PlayTypeEmbedding = PlayTypeEncoder(
            vocab_size=playtype_vocab_size, embedding_dim=embedding_dim)
        self.Add = tf.keras.layers.Add()

    def call(self, x):
        # Local tuple only (not stored on self); the order is both the call
        # order and the summation order.
        encoders = (
            self.InputEmbedding,
            self.PositionalEmbedding,
            self.PositionEmbedding,
            self.ScrimEmbedding,
            self.StartEmbedding,
            self.TypeEmbedding,
            self.OffDefEmbedding,
            self.PlayTypeEmbedding,
        )
        return self.Add([encoder(x) for encoder in encoders])
|
166 |
+
|
167 |
+
class Transformers(tf.keras.Model):
    """Multi-head self-attention block followed by a dense residual branch.

    Args:
        num_heads: Number of attention heads.
        hidden_dim: Size of each attention head; also the output size of the
            projection applied to the concatenated heads.
        output_dim: Output size of the final dense layer.
        diag_masks: When true, a large negative value is also added on the
            diagonal of the attention scores before the softmax.

    The two residual additions require ``hidden_dim``, ``output_dim`` and the
    last dimension of the incoming hidden states to be equal.
    """

    def __init__(self,
                 num_heads : int,
                 hidden_dim : int,
                 output_dim : int,
                 diag_masks : bool):
        super(Transformers, self).__init__()

        self.diag_masks = diag_masks
        self.num_attention_heads = num_heads
        self.attention_head_size = hidden_dim
        self.total_dim = num_heads * hidden_dim
        self.output_dim = output_dim

        self.NormIn = tf.keras.layers.LayerNormalization(name = "Norm_in")
        self.Query = tf.keras.layers.Dense(self.total_dim, name = "Query", use_bias = False)
        self.Key = tf.keras.layers.Dense(self.total_dim, name = "Key", use_bias = False)
        self.Value = tf.keras.layers.Dense(self.total_dim, name = "Value", use_bias = False)

        self.DenseAtt = tf.keras.layers.Dense(hidden_dim, activation = "relu", use_bias = False)

        self.Add = tf.keras.layers.Add(name = "Add")
        self.Drop = tf.keras.layers.Dropout(rate = 0.1)

        self.DenseOut = tf.keras.layers.Dense(output_dim, name = "Dense", activation = "relu")
        self.NormOut = tf.keras.layers.LayerNormalization(name = "Norm_out")

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        """Split the last axis of ``tensor`` into heads and move the head axis forward."""
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def create_causal_masks(self, temp_ids):
        """Build an additive mask from temporal ids, one copy per head.

        Entry ``[b, i, j]`` is 0 where ``temp_ids[b, i] >= temp_ids[b, j]``
        and -1000000 otherwise.
        """
        # Use broadcasting to create the 2D comparison tensor
        causal_mask = temp_ids[:, :, tf.newaxis] >= temp_ids[:, tf.newaxis, :]
        causal_mask = (tf.cast(causal_mask, dtype=tf.float32) - 1) * 1000000
        reshaped_tensor = tf.expand_dims(causal_mask, axis=1)
        duplicated_tensor = tf.tile(reshaped_tensor, multiples=[1, self.num_attention_heads, 1, 1])
        return duplicated_tensor

    def create_diag_masks(self, hidden_state):
        """Return a tensor holding -1000000 on the diagonal of its last two axes and 0 elsewhere.

        The batch shape and diagonal length come from the first three
        dimensions of ``hidden_state``.
        """
        dims = shape_list(hidden_state)
        matrix = tf.linalg.diag(tf.ones((dims[0], dims[1], dims[2]), dtype=tf.float32))
        return matrix*-1000000

    def create_attention_mask(self, attn_mask):
        """Turn a 1/0 mask into an additive mask (0 / -1000000), one copy per head.

        Two axes are inserted after the first one so the mask broadcasts over
        heads and query positions.
        # assumes attn_mask is (batch, seq_length) -- TODO confirm
        """
        attn_mask = (tf.cast(attn_mask, dtype=tf.float32) -1) * 1000000
        reshaped_tensor = tf.expand_dims(attn_mask, axis=1)
        reshaped_tensor = tf.expand_dims(reshaped_tensor, axis=1)
        duplicated_tensor = tf.tile(reshaped_tensor, multiples=[1, self.num_attention_heads, 1, 1])
        return duplicated_tensor

    def compute_scaled_attn_scores(self, query, key):
        """Return ``query @ key^T`` divided by the square root of the last dimension of ``query``."""
        attention_scores = tf.matmul(query, key, transpose_b=True)  # Transpose the second sequence

        # Scaled dot-product attention: divide by the square root of the last (static) dimension
        embedding_dim = query.shape[-1]
        scaled_attention_scores = attention_scores / tf.math.sqrt(tf.cast(embedding_dim, dtype=tf.float32))

        return scaled_attention_scores

    def compute_attention_weigths(self, query, key, temp_ids, masks):
        """Return softmax attention weights over the last axis after adding all masks."""
        attn_masks = self.create_attention_mask(masks)
        causal_masks = self.create_causal_masks(temp_ids)
        scaled_attn_scores = self.compute_scaled_attn_scores(query, key)

        attn_scores = scaled_attn_scores + attn_masks + causal_masks
        if self.diag_masks:
            attn_scores = attn_scores + self.create_diag_masks(query)

        return tf.nn.softmax(attn_scores, axis = -1)

    def _attend(self, hidden_states, temporal_ids, attention_masks):
        """Run the full block once and return ``(output, attention_weights)``."""
        batch_size = shape_list(hidden_states)[0]

        norm_hidden_states = self.NormIn(hidden_states)

        query = self.Query(norm_hidden_states)
        queries = self.transpose_for_scores(query, batch_size)

        key = self.Key(norm_hidden_states)
        keys = self.transpose_for_scores(key, batch_size)

        value = self.Value(norm_hidden_states)
        values = self.transpose_for_scores(value, batch_size)

        attention_weights = self.compute_attention_weigths(queries, keys, temporal_ids, attention_masks)
        attention_scores = tf.matmul(attention_weights, values)
        # Merge the heads back: [batch, heads, seq, head_size] -> [batch, seq, heads * head_size]
        attention_scores = tf.transpose(attention_scores, perm=[0, 2, 1, 3])
        attention_scores = tf.reshape(tensor=attention_scores, shape=(batch_size, -1, self.total_dim))
        attention_scores = self.DenseAtt(attention_scores)

        # First residual connection: around the attention.
        output = self.Add([attention_scores, hidden_states])
        norm_output = self.NormOut(output)

        # Second residual connection: around the dense layer.
        densed_output = self.DenseOut(norm_output)
        output = self.Add([densed_output, output])
        output = self.Drop(output)
        return output, attention_weights

    def get_preds_and_attention(self,
                                embeddings,
                                temporal_ids,
                                attention_masks):
        """Return the block output together with its attention weights.

        Runs the same computation as ``call``. The previous body referenced
        ``self.Dense`` and ``self.Norm``, which this class never defines, so
        it raised ``AttributeError`` on every call.

        Returns:
            Tuple ``(output, attention_weights)``.
        """
        return self._attend(embeddings, temporal_ids, attention_masks)

    def call(self,
             hidden_states : tf.Tensor,
             temporal_ids,
             attention_masks):
        """Apply the attention block to ``hidden_states`` and return the new hidden states."""
        output, _ = self._attend(hidden_states, temporal_ids, attention_masks)
        return output
|
293 |
+
|
294 |
+
class Encoder(tf.keras.Model):
    """Embedding stack, one ``Transformers`` block, and a ReLU dense head.

    ``call`` takes the input dict ``x``: the whole dict goes to the embedding
    stack, while ``x["pos_ids"]`` and ``x["attention_mask"]`` are forwarded to
    the attention block.
    """

    def __init__(self,
                 input_vocab_size : int,
                 positional_vocab_size : int,
                 position_vocab_size : int,
                 scrimmage_vocab_size : int,
                 start_vocab_size: int,
                 offdef_vocab_size : int,
                 type_vocab_size : int,
                 playtype_vocab_size : int,
                 embedding_dim : int,
                 hidden_dim : int,
                 num_heads : int,
                 diag_masks : bool):
        super().__init__()

        self.num_heads = num_heads
        self.diag_masks = diag_masks

        self.Embedding = Embedding(
            input_vocab_size=input_vocab_size,
            positional_vocab_size=positional_vocab_size,
            position_vocab_size=position_vocab_size,
            scrimmage_vocab_size=scrimmage_vocab_size,
            start_vocab_size=start_vocab_size,
            type_vocab_size=type_vocab_size,
            offdef_vocab_size=offdef_vocab_size,
            playtype_vocab_size=playtype_vocab_size,
            embedding_dim=embedding_dim,
        )

        self.Attention1 = Transformers(
            num_heads=self.num_heads,
            hidden_dim=hidden_dim,
            output_dim=embedding_dim,
            diag_masks=self.diag_masks,
        )

        self.DenseHead = tf.keras.layers.Dense(embedding_dim, activation="relu")

    def call(self, x):
        embedded = self.Embedding(x)
        attended = self.Attention1(embedded, x["pos_ids"], x["attention_mask"])
        return self.DenseHead(attended)
|
338 |
+
|
339 |
+
|
340 |
+
class QBGPT(tf.keras.Model):
    """``Encoder`` followed by a linear layer producing ``to_pred_size`` logits."""

    def __init__(self,
                 input_vocab_size : int,
                 positional_vocab_size : int,
                 position_vocab_size : int,
                 scrimmage_vocab_size : int,
                 start_vocab_size: int,
                 offdef_vocab_size : int,
                 type_vocab_size : int,
                 playtype_vocab_size : int,
                 embedding_dim : int,
                 hidden_dim : int,
                 num_heads : int,
                 diag_masks : bool,
                 to_pred_size : int):
        super().__init__()

        self.Encoder = Encoder(
            input_vocab_size=input_vocab_size,
            positional_vocab_size=positional_vocab_size,
            position_vocab_size=position_vocab_size,
            scrimmage_vocab_size=scrimmage_vocab_size,
            start_vocab_size=start_vocab_size,
            type_vocab_size=type_vocab_size,
            offdef_vocab_size=offdef_vocab_size,
            playtype_vocab_size=playtype_vocab_size,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            diag_masks=diag_masks,
        )

        # No activation: the layer outputs raw logits.
        self.Logits = tf.keras.layers.Dense(to_pred_size)

    def call(self, x):
        return self.Logits(self.Encoder(x))
|
assets/moves_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a42fbda4f9377e74f2d79024f8b8628a9cc462b21a18526b7e7342188a6245e
|
3 |
+
size 32456
|
assets/photo_cv.jpg
ADDED
assets/plays_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e9a9697f12ba01faa7a5a395cd6eafe435868b34b5c634322ec67424e24a6fcc
|
3 |
+
size 760
|
assets/positions_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7abdabf96a14a3457287d6147280ab2051d8c65b2875d71a6402f4b880756bfb
|
3 |
+
size 1052
|
assets/ref.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
assets/ref_df.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"gameId":{"0":2019100600,"1":2018122201,"2":2019092600,"3":2019101300,"4":2020092011,"5":2019112801,"6":2018102802,"7":2018120912,"8":2019120109,"9":2019092900,"10":2018122312,"11":2019101311,"12":2019102702,"13":2019120102,"14":2019121507,"15":2019091507,"16":2019112404,"17":2018121605,"18":2019102000,"19":2019122901,"20":2019111002,"21":2018092303,"22":2019120109,"23":2018110404,"24":2018121609,"25":2020101801,"26":2021010302,"27":2020112200,"28":2018102809,"29":2019122909,"30":2019102710,"31":2018110500,"32":2018122314,"33":2020122002,"34":2019122901,"35":2020100401,"36":2019122205,"37":2020102513,"38":2018100706,"39":2018120211,"40":2018112600,"41":2018120200,"42":2019122912,"43":2019102007,"44":2020110500,"45":2019110308,"46":2018112504,"47":2018090903,"48":2018101408,"49":2020092013,"50":2018102110,"51":2019110310,"52":2019111100,"53":2018090900,"54":2019111001,"55":2020121306,"56":2019122102,"57":2021010310,"58":2021010311,"59":2018093009,"60":2020111503,"61":2019111000,"62":2018101404,"63":2019101309,"64":2019120500,"65":2020122012,"66":2018121604,"67":2019092207,"68":2019100600,"69":2019122913,"70":2018112508,"71":2018091601,"72":2020120702,"73":2019100608,"74":2018102808,"75":2019092209,"76":2018112505,"77":2018120909,"78":2018091613,"79":2019110303,"80":2018120204,"81":2019120802,"82":2018102111,"83":2018120902,"84":2018120907,"85":2019100602,"86":2020100501,"87":2020122004,"88":2018101401,"89":2019092211,"90":2018112201,"91":2020121312,"92":2020101200,"93":2020121700,"94":2021010304,"95":2020112204,"96":2018123011,"97":2019111703,"98":2020110900,"99":2020110808,"100":2018102805,"101":2020110802,"102":2019092210,"103":2018120912,"104":2018112510,"105":2019122905,"106":2018100100,"107":2020112202,"108":2020122002,"109":2021010305,"110":2018100706,"111":2019092903,"112":2019122209,"113":2018120300,"114":2019110303,"115":2018101402,"116":2019120106,"117":2020110108,"118":2020101801,"119":2020101812,"120":2018090600,"121":2018123014,"122":2018093004,"123":20
21102401,"124":2018122313,"125":2018091607,"126":2018102111,"127":2018121700,"128":2018110405,"129":2018090901,"130":2018121609,"131":2018102900,"132":2018101400,"133":2018110409,"134":2018093009,"135":2018111103,"136":2018111802,"137":2018111809,"138":2021103102,"139":2018121501,"140":2018111101,"141":2021101700,"142":2018102107,"143":2021101707,"144":2021101001,"145":2018123012,"146":2018110404,"147":2018092000,"148":2021091908,"149":2018122309,"150":2018102104,"151":2018110406,"152":2018102109,"153":2021092607,"154":2018111108,"155":2021102100,"156":2018120909,"157":2021091910,"158":2018122313,"159":2018111108,"160":2018091001,"161":2018092000,"162":2018111802,"163":2021103106,"164":2018102801,"165":2018100711,"166":2018091604,"167":2021091203,"168":2018090600,"169":2021092603,"170":2018111111,"171":2018112507,"172":2018101401,"173":2018123013,"174":2021101800,"175":2021101703,"176":2018111111,"177":2021091204,"178":2021102406,"179":2018101100,"180":2021091208,"181":2018101100,"182":2019111001,"183":2021101011,"184":2021103109,"185":2021101009,"186":2021102800,"187":2021101800,"188":2019122210,"189":2021092600,"190":2021101008,"191":2021091209,"192":2018120600,"193":2018112507,"194":2021091910,"195":2021101701,"196":2019122900,"197":2018122313,"198":2021091211,"199":2021091903,"200":2021103104,"201":2021091300,"202":2021091600,"203":2021092700,"204":2021102406,"205":2021091600,"206":2021091300,"207":2021102401,"208":2020111505,"209":2021102401,"210":2021102500,"211":2021091900,"212":2021103100,"213":2021101704,"214":2021102405,"215":2021100308,"216":2021091901,"217":2019111708,"218":2019120108,"219":2021101001,"220":2018123001,"221":2020112601,"222":2021102800,"223":2021102400,"224":2020010400,"225":2019100612,"226":2021092600,"227":2019120200,"228":2021092600,"229":2020010401,"230":2021102405,"231":2021102407,"232":2018111102,"233":2021101703,"234":2021101701,"235":2019111001,"236":2021100400,"237":2021100307,"238":2021092607,"239":2021102100},"playId":{"0":2982
,"1":1468,"2":1354,"3":2856,"4":2465,"5":2414,"6":2963,"7":1968,"8":3361,"9":2619,"10":1476,"11":1994,"12":2907,"13":1121,"14":3588,"15":2316,"16":2334,"17":2244,"18":587,"19":4339,"20":4405,"21":2001,"22":3361,"23":1637,"24":2901,"25":3860,"26":2517,"27":752,"28":1593,"29":2913,"30":4343,"31":299,"32":2594,"33":1831,"34":2158,"35":1446,"36":1330,"37":338,"38":3240,"39":3633,"40":2075,"41":3719,"42":2100,"43":366,"44":2860,"45":3363,"46":2578,"47":3690,"48":3602,"49":1482,"50":2823,"51":2104,"52":4387,"53":1226,"54":880,"55":3189,"56":402,"57":812,"58":1148,"59":1087,"60":3389,"61":2781,"62":2709,"63":2091,"64":1389,"65":1290,"66":3032,"67":581,"68":2542,"69":2693,"70":614,"71":618,"72":1864,"73":1899,"74":1114,"75":825,"76":1804,"77":4008,"78":2135,"79":2561,"80":1839,"81":36,"82":1113,"83":1020,"84":3589,"85":441,"86":3336,"87":3817,"88":2558,"89":260,"90":2939,"91":939,"92":4450,"93":2100,"94":3139,"95":2052,"96":3847,"97":1214,"98":1436,"99":1942,"100":4119,"101":770,"102":3615,"103":4107,"104":771,"105":36,"106":3169,"107":960,"108":40,"109":2171,"110":2799,"111":324,"112":3592,"113":1812,"114":36,"115":36,"116":3006,"117":4314,"118":2249,"119":41,"120":1085,"121":555,"122":2577,"123":3528,"124":1547,"125":1021,"126":3193,"127":852,"128":1283,"129":5511,"130":4242,"131":954,"132":2207,"133":1797,"134":3845,"135":1515,"136":4328,"137":325,"138":62,"139":4297,"140":1803,"141":2085,"142":295,"143":2535,"144":2135,"145":1359,"146":3051,"147":3860,"148":3110,"149":400,"150":3904,"151":1592,"152":463,"153":2325,"154":275,"155":1587,"156":566,"157":756,"158":2419,"159":1061,"160":3049,"161":3947,"162":2406,"163":153,"164":3963,"165":3218,"166":508,"167":620,"168":3678,"169":2221,"170":685,"171":2004,"172":3333,"173":2329,"174":3505,"175":3218,"176":1248,"177":1579,"178":1355,"179":2765,"180":3874,"181":2858,"182":1000,"183":54,"184":3119,"185":2668,"186":3007,"187":3998,"188":2421,"189":3942,"190":1905,"191":1295,"192":1238,"193":153,"194":2359,"195":1326,"196":3123,"
197":3997,"198":142,"199":1170,"200":3699,"201":148,"202":3392,"203":2565,"204":896,"205":3392,"206":356,"207":893,"208":4156,"209":2098,"210":3030,"211":2784,"212":1624,"213":340,"214":3280,"215":1258,"216":988,"217":3909,"218":3886,"219":3479,"220":3552,"221":3187,"222":3670,"223":3028,"224":122,"225":4026,"226":3544,"227":3214,"228":3639,"229":698,"230":2124,"231":2169,"232":2155,"233":3602,"234":1326,"235":2257,"236":3637,"237":3395,"238":1227,"239":655},"Traj":{"0":0,"1":0,"2":0,"3":0,"4":0,"5":0,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":0,"21":0,"22":0,"23":0,"24":0,"25":0,"26":0,"27":0,"28":0,"29":0,"30":0,"31":0,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":0,"44":0,"45":0,"46":0,"47":0,"48":0,"49":0,"50":0,"51":0,"52":0,"53":0,"54":0,"55":0,"56":0,"57":0,"58":0,"59":0,"60":0,"61":0,"62":0,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0,"81":0,"82":0,"83":0,"84":0,"85":0,"86":0,"87":0,"88":0,"89":0,"90":0,"91":0,"92":0,"93":0,"94":0,"95":0,"96":0,"97":0,"98":0,"99":0,"100":0,"101":0,"102":0,"103":0,"104":0,"105":0,"106":0,"107":0,"108":0,"109":0,"110":0,"111":0,"112":0,"113":0,"114":0,"115":0,"116":0,"117":0,"118":0,"119":0,"120":0,"121":0,"122":0,"123":0,"124":0,"125":0,"126":0,"127":0,"128":0,"129":0,"130":0,"131":0,"132":0,"133":0,"134":0,"135":0,"136":0,"137":0,"138":0,"139":0,"140":0,"141":0,"142":0,"143":0,"144":0,"145":0,"146":0,"147":0,"148":0,"149":0,"150":0,"151":0,"152":0,"153":0,"154":0,"155":0,"156":0,"157":0,"158":0,"159":0,"160":0,"161":0,"162":0,"163":0,"164":0,"165":0,"166":0,"167":0,"168":0,"169":0,"170":0,"171":0,"172":0,"173":0,"174":0,"175":0,"176":0,"177":0,"178":0,"179":0,"180":0,"181":0,"182":0,"183":0,"184":0,"185":0,"186":0,"187":0,"188":0,"189":0,"190":0,"191":0,"192":0,"193":0,"194":0,"195":0,"196":0,"197":0,"198":0,"199":0,"200":0,"201":0,"202":0,"203":0,"204":0
,"205":0,"206":0,"207":0,"208":0,"209":0,"210":0,"211":0,"212":0,"213":0,"214":0,"215":0,"216":0,"217":0,"218":0,"219":0,"220":0,"221":0,"222":0,"223":0,"224":0,"225":0,"226":0,"227":0,"228":0,"229":0,"230":0,"231":0,"232":0,"233":0,"234":0,"235":0,"236":0,"237":0,"238":0,"239":0},"PlayType":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":1,"7":1,"8":1,"9":1,"10":1,"11":1,"12":1,"13":1,"14":1,"15":1,"16":1,"17":1,"18":1,"19":1,"20":1,"21":1,"22":1,"23":1,"24":1,"25":1,"26":1,"27":1,"28":1,"29":1,"30":1,"31":1,"32":1,"33":1,"34":1,"35":1,"36":1,"37":1,"38":1,"39":1,"40":1,"41":1,"42":1,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":1,"50":1,"51":1,"52":1,"53":1,"54":1,"55":1,"56":1,"57":1,"58":1,"59":1,"60":2,"61":2,"62":2,"63":2,"64":2,"65":2,"66":2,"67":2,"68":2,"69":2,"70":2,"71":2,"72":2,"73":2,"74":2,"75":2,"76":2,"77":2,"78":2,"79":2,"80":2,"81":2,"82":2,"83":2,"84":2,"85":2,"86":2,"87":2,"88":2,"89":2,"90":2,"91":2,"92":2,"93":2,"94":2,"95":2,"96":2,"97":2,"98":2,"99":2,"100":2,"101":2,"102":2,"103":2,"104":2,"105":2,"106":2,"107":2,"108":2,"109":2,"110":2,"111":2,"112":2,"113":2,"114":2,"115":2,"116":2,"117":2,"118":2,"119":2,"120":4,"121":4,"122":4,"123":4,"124":4,"125":4,"126":4,"127":4,"128":4,"129":4,"130":4,"131":4,"132":4,"133":4,"134":4,"135":4,"136":4,"137":4,"138":4,"139":4,"140":4,"141":4,"142":4,"143":4,"144":4,"145":4,"146":4,"147":4,"148":4,"149":4,"150":4,"151":4,"152":4,"153":4,"154":4,"155":4,"156":4,"157":4,"158":4,"159":4,"160":4,"161":4,"162":4,"163":4,"164":4,"165":4,"166":4,"167":4,"168":4,"169":4,"170":4,"171":4,"172":4,"173":4,"174":4,"175":4,"176":4,"177":4,"178":4,"179":4,"180":7,"181":7,"182":7,"183":7,"184":7,"185":7,"186":7,"187":7,"188":7,"189":7,"190":7,"191":7,"192":7,"193":7,"194":7,"195":7,"196":7,"197":7,"198":7,"199":7,"200":7,"201":7,"202":7,"203":7,"204":7,"205":7,"206":7,"207":7,"208":7,"209":7,"210":7,"211":7,"212":7,"213":7,"214":7,"215":7,"216":7,"217":7,"218":7,"219":7,"220":7,"221":7,"222":7,"223":7,"224":7,"225":7,"226":7,
"227":7,"228":7,"229":7,"230":7,"231":7,"232":7,"233":7,"234":7,"235":7,"236":7,"237":7,"238":7,"239":7},"index":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80,"81":81,"82":82,"83":83,"84":84,"85":85,"86":86,"87":87,"88":88,"89":89,"90":90,"91":91,"92":92,"93":93,"94":94,"95":95,"96":96,"97":97,"98":98,"99":99,"100":100,"101":101,"102":102,"103":103,"104":104,"105":105,"106":106,"107":107,"108":108,"109":109,"110":110,"111":111,"112":112,"113":113,"114":114,"115":115,"116":116,"117":117,"118":118,"119":119,"120":120,"121":121,"122":122,"123":123,"124":124,"125":125,"126":126,"127":127,"128":128,"129":129,"130":130,"131":131,"132":132,"133":133,"134":134,"135":135,"136":136,"137":137,"138":138,"139":139,"140":140,"141":141,"142":142,"143":143,"144":144,"145":145,"146":146,"147":147,"148":148,"149":149,"150":150,"151":151,"152":152,"153":153,"154":154,"155":155,"156":156,"157":157,"158":158,"159":159,"160":160,"161":161,"162":162,"163":163,"164":164,"165":165,"166":166,"167":167,"168":168,"169":169,"170":170,"171":171,"172":172,"173":173,"174":174,"175":175,"176":176,"177":177,"178":178,"179":179,"180":180,"181":181,"182":182,"183":183,"184":184,"185":185,"186":186,"187":187,"188":188,"189":189,"190":190,"191":191,"192":192,"193":193,"194":194,"195":195,"196":196,"197":197,"198":198,"199":199,"200":200,"201":201,"202":202,"203":203,"204":204,"205":205,"206":206,"207":207,"208":208,"209":209,"210":
210,"211":211,"212":212,"213":213,"214":214,"215":215,"216":216,"217":217,"218":218,"219":219,"220":220,"221":221,"222":222,"223":223,"224":224,"225":225,"226":226,"227":227,"228":228,"229":229,"230":230,"231":231,"232":232,"233":233,"234":234,"235":235,"236":236,"237":237,"238":238,"239":239}}
|
assets/scrimmage_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a46d1fbccfcc76f8adab0a3653ee2eed8c4ebbad5b47348e62f309c0c7235981
|
3 |
+
size 1291
|
assets/starts_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7fa4326a5b21d554d21611679dd8eeb69000a8be2258a789617237f49ea447b4
|
3 |
+
size 10308
|
assets/time_index.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:990ef64cc53d1aefcd5916daeb91c090372833f2ac9422f60fef0264fd91a8a7
|
3 |
+
size 825
|
pages.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import plotly.express as px
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from tools import generator
|
7 |
+
|
8 |
+
def set_app_title_and_logo():
    """Configure the Streamlit page: "QB-GPT" title, rocket icon, wide layout."""
    page_settings = {
        "page_title": "QB-GPT",
        "page_icon": ":rocket:",
        "layout": "wide",
    }
    st.set_page_config(**page_settings)
|
14 |
+
|
15 |
+
def _trajectory_figure(plot_df):
    """Build the animated trajectory chart shared by the generated and true plays.

    Args:
        plot_df: DataFrame with the columns referenced below
            (``input_ids_x``, ``input_ids_y``, ``pos_ids``, ``OffDef``,
            ``ids``, ``position_ids``).

    Returns:
        The configured plotly figure.
    """
    fig = px.line(plot_df, x="input_ids_x", y="input_ids_y", animation_frame="pos_ids",
                  color="OffDef", symbol="ids", text="position_ids",
                  title="Player Trajectories Over Time", line_shape="linear",
                  range_x=[0, 140], range_y=[0, 60],  # Set X and Y axis ranges
                  render_mode="svg")  # Render mode for smoother lines

    # Customize the appearance of the plot
    fig.update_traces(marker=dict(size=10), selector=dict(mode='lines'))
    fig.update_layout(width=800, height=600)
    return fig


def qb_gpt_page(ref_df, ref, tokenizer, model):
    """Render the interactive generation page.

    The user picks a game and a play, tunes the sampling parameters, and on
    "Generate" the page plots the generated continuation followed by the true
    play.

    Args:
        ref_df: DataFrame with at least the ``gameId``, ``playId`` and
            ``index`` columns; ``index`` is the key of the play inside ``ref``.
        ref: Mapping (or sequence) of encoded plays, indexed by ``ref_df["index"]``.
        tokenizer: Tokenizer handed to ``generator``.
        model: Model handed to ``generator``.
    """
    with st.container():
        cola, colb = st.columns(2)
        with cola:
            selected_gameId = st.selectbox("Select Game ID", ref_df['gameId'].unique())
            filtered_df1 = ref_df[(ref_df['gameId'] == selected_gameId)]
        with colb:
            selected_Play = st.selectbox("Select Play ", filtered_df1['playId'].unique())
            # drop must be the boolean True (the original passed the string "True").
            filtered_df = filtered_df1[(filtered_df1['playId'] == selected_Play)].reset_index(drop=True)

        # Display the filtered DataFrame
        st.write("Filtered Data:")
        st.dataframe(filtered_df)

    with st.container():
        col1, col2 = st.columns(2)
        with col1:
            temperature = st.slider("Temperature", 1.0, 10.0, 1.5, step=0.5)
        with col2:
            n_select = st.slider("N movements to shortlist", 2, 100, 10, step=1)

    QB_gen = generator(model=model,
                       tokenizer=tokenizer,
                       temp=temperature,
                       n_select=n_select)

    # After reset_index the first row is the selected play.
    selected = filtered_df["index"][0]
    selection = ref[selected]

    colc, cold = st.columns(2)

    # These two sliders previously reused the labels "Temperature" and
    # "n select" although they drive the prompt length and the generation
    # length; the labels now say what the values are used for.
    with colc:
        starts = st.slider("Frames of the true play used as prompt", 1, 21, 1, step=1)
    with cold:
        frames = st.slider("Frames to generate", 1, 50, 20, step=1)

    if st.button("Generate"):
        # Keep only the entries whose pos_ids are below `starts`, then let the
        # model extend that prompt by `frames` steps.
        trial_d = QB_gen.tokenizer.truncate_to_time_t(selection, starts)
        generated = QB_gen.generate_sequence(trial_d, frames)
        decoded = QB_gen.tokenizer.decode_sequence(generated)

        step1 = QB_gen.prepare_for_plot(decoded)
        plot = pd.DataFrame(step1)

        decoded_true = QB_gen.tokenizer.decode_sequence(selection)
        step1_true = QB_gen.prepare_for_plot(decoded_true)
        plot_true = pd.DataFrame(step1_true)

        # Generated play first, true play second.
        fig_gen = _trajectory_figure(plot)
        st.plotly_chart(fig_gen)

        fig_true = _trajectory_figure(plot_true)
        st.plotly_chart(fig_true)
|
86 |
+
|
87 |
+
|
88 |
+
def contacts_and_disclaimers():
    """Render the "about" page: intro, model description, author, disclaimers and license.

    Each section is a Streamlit expander whose content is a markdown string
    passed to ``st.markdown``. The function takes no arguments and returns
    nothing; it only emits Streamlit widgets.
    """

    st.title("QB-GPT - Your Football Playbook Powerhouse!")

    qb_gpt_text_intro = """
    Are you a data scientist, a machine learning enthusiast, or simply a die-hard NFL fan looking to explore the power of Transformers in the world of American football? Look no further!
    """
    st.markdown(qb_gpt_text_intro)

    with st.expander("***What is QB-GPT?***"):

        qb_gpt_what = """
        QuarterBack-GPT (QB-GPT) is a companion in the world of football strategy and analytics. It's an innovative application with a model relying on the remarkable capabilities of Transformers to generate football plays that are not only strategic but also incredibly realistic. Imagine having an AI-powered coach in your corner, designing plays that could turn the tide of any game.
        """
        st.markdown(qb_gpt_what)

    with st.expander("***What's inside QB-GPT***"):

        # NOTE(review): the markdown link target below is the literal text
        # "link" — looks like a placeholder for the blog post URL; confirm.
        qb_gpt_transf = """
        At the heart of QB-GPT lies the cutting-edge Transformer model, a deep learning architecture known for its prowess in understanding sequential data. It doesn't just create plays; it understands the game at a granular level, taking into account player positions, game situations, and historical data. It relies on the same conceptual approach behind the now famous "GPT" model of OpenAI. It's the playbook of the future, driven by the technology of tomorrow.

        A more detailed blogpost about the model QB-GPT can be found [here](link)
        """
        st.markdown(qb_gpt_transf)

    with st.expander("***QB-GPT in Action***"):

        qb_gpt_act = """
        With QB-GPT, you can explore a wide range of football scenarios. Design plays that are tailored to your team's strengths, simulate game situations, and experiment with different strategies—all at your fingertips. Whether you're a coach looking to refine your playbook or an NFL enthusiast seeking the thrill of strategic gameplay, QB-GPT has something for everyone.
        """
        st.markdown(qb_gpt_act)



    with st.expander("***Author***"):
        # Text takes 4/5 of the width, the portrait the remaining 1/5.
        col1, col2 = st.columns([4, 1])
        with col1:
            # NOTE(review): the e-mail address in this text reads as a
            # redacted placeholder — confirm the intended address.
            author_text = """
            My name is Samuel Chaineau, I am 26 and live in Paris. This is my second work related to Deep-Learning applied to NFL NGS data. This work is released as an app in order to facilitate the interaction and feedbacks from user. I also hope having constructive academic/scientific feedbacks to improve the model and bring it to the next level.

            I have a background in economics, management, statistics and computer sciences. I am currently the CTO of a french healthcare start-up called Nuvocare. Prior to that I worked 2 and a half years as a Data Scientist at Ekimetrics.

            ***Contacts***

            If interested by the project, the app or wishing to discuss any related topics, feel free to contact me on :
            - My email : [email protected]

            - Linkedin : [My profile](https://www.linkedin.com/in/samuel-chaineau-734b13122/)

            - X (Twitter, nobody says X) : [My profile](https://twitter.com/samboucon)

            - Or you can follow me on Medium : [My blog](https://medium.com/@sam.chaineau)
            """
            st.markdown(author_text)

        with col2:
            # NOTE(review): relative path, resolved against the process
            # working directory — confirm it matches the deployed layout.
            image = Image.open('app/assets/photo_cv.jpg')
            st.image(image)

    with st.expander("***Disclaimers***"):
        disclaimer_text = """
        This work is at a very early stage and while I think it shows promising results, I acknowledge that the model may yield disturbing results and potentially wrong.
        Maintaining and improving QB-GPT will be a long run.

        I used data only found publicly on the internet (GitHub and Kaggle). I don't hold any relationship with NFL officials or any NFL teams.
        I do not intend to have any payments or commercial activities via this app. It is a POC showing abilities, advantages and flaws of current SotA technologies applied to sports analytics.

        """
        st.markdown(disclaimer_text)

    with st.expander("***License***"):
        # NOTE(review): the text names CC BY-NC 4.0 and lists "Adapt" as a
        # freedom, yet also lists a "No Derivatives (ND)" term and a closing
        # sentence restricting derivative works — confirm which is intended.
        license_text = """
        **License**

        This application and its associated work are released under the **Creative Commons Attribution-NonCommercial 4.0 International License**.

        **Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)**

        You are free to:

        - **Share** - copy and redistribute the material in any medium or format.

        - **Adapt** - remix, transform, and build upon the material.

        Under the following terms:

        - **Attribution (BY)**: You must give appropriate credits and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.

        - **Non-Commercial (NC)**: You may not use the material for commercial purposes.

        - **No Derivatives (ND)**: If you remix, transform, or build upon the material, you may not distribute the modified material.

        For a full description of the license terms, please visit the [Creative Commons Attribution-NonCommercial 4.0 International License](https://creativecommons.org/licenses/by-nc/4.0/).

        This license is designed to allow others to use, remix, and build upon your work, but not for commercial purposes. It requires proper attribution and restricts commercial use and the creation of derivative works.
        """
        st.markdown(license_text)
|
tools.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import polars as pl
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
class tokenizer:
    """Translate between integer token ids and their decoded values.

    Six lookup tables are loaded from parquet files (moves, play types,
    positions, scrimmage lines, starts, times) and turned into
    ``{id: [value, ...]}`` dictionaries. ``self.index`` maps each sequence
    field name to the table used to decode/encode it.
    """

    # Fields that decode_sequence / encode_sequence return untouched.
    _PASSTHROUGH_KEYS = ("side_ids", "token_type_ids", "labels", "attention_mask", "ids")

    def __init__(self,
                 moves_index : str,
                 play_index : str,
                 positions_index : str,
                 scrimmage_index : str,
                 starts_index : str,
                 time_index : str,
                 window_size : int):
        """Load every index from its parquet path.

        Args:
            moves_index, play_index, positions_index, scrimmage_index,
            starts_index, time_index: paths of the parquet index files.
            window_size: number of most recent time steps kept visible by
                ``prepare_for_call``.
        """
        self.window = window_size

        self.moves_index = self.convert_index_to_dict(pl.read_parquet(moves_index))
        self.play_index = self.convert_index_to_dict(pl.read_parquet(play_index))
        self.positions_index = self.convert_index_to_dict(pl.read_parquet(positions_index))
        self.scrimmage_index = self.convert_index_to_dict(pl.read_parquet(scrimmage_index))
        self.starts_index = self.convert_index_to_dict(pl.read_parquet(starts_index))
        self.time_index = self.convert_index_to_dict(pl.read_parquet(time_index))

        self.offdef_index = {0 : ["Def"],
                             1 : ["Off"]}

        self.index = {"input_ids" : self.moves_index,
                      "PlayType" : self.play_index,
                      "position_ids" : self.positions_index,
                      "scrim_ids" : self.scrimmage_index,
                      "start_ids" : self.starts_index,
                      "pos_ids" : self.time_index,
                      "OffDef" : self.offdef_index}

    def convert_index_to_dict(self, df : "pl.DataFrame"):
        """Turn an index table into ``{id: [value columns...]}``.

        The id column is the single column whose name contains ``"ID"``; every
        other column except ``"Cat"`` is a value column, kept in table order.

        Raises:
            ValueError: if the table does not have exactly one id column.
        """
        id_cols = [c for c in df.columns if "ID" in c]
        if len(id_cols) != 1:
            raise ValueError(
                "expected exactly one column containing 'ID', found %r" % (id_cols,))
        id_col = id_cols[0]

        val_cols = [c for c in df.columns if c != id_col and c != "Cat"]

        # Selecting the wanted columns leaves "Cat" out whether or not the
        # table has one, so no explicit drop is needed.
        columns = df.select([id_col] + val_cols).to_dict(as_series=False)
        ids = columns[id_col]

        return {ids[i] : [columns[c][i] for c in val_cols] for i in range(len(ids))}

    def base_decode(self,
                    pad_element,
                    inputs : list,
                    index : dict,
                    first : bool):
        """Look each id up in ``index``; unknown ids become ``pad_element``.

        With ``first`` true only the first stored value is returned for each
        id, otherwise the whole value list is returned.
        """
        if first:
            return [index[v][0] if v in index else pad_element for v in inputs]
        return [index[v] if v in index else pad_element for v in inputs]

    def decode(self,
               inputs : list,
               type : str):
        """Decode a list of ids of the given field ``type``.

        ``input_ids`` and ``start_ids`` decode to full value lists and pad
        with ``[-1000, -1000]``; ``scrim_ids`` and ``pos_ids`` decode to a
        single value and pad with ``-1000``; every other field decodes to a
        single value and pads with ``"[PAD]"``.
        """
        if type in ["input_ids", "start_ids"]:
            return self.base_decode([-1000, -1000], inputs, index = self.index[type], first=False)

        padding = -1000 if type in ["scrim_ids", "pos_ids"] else "[PAD]"
        return self.base_decode(padding, inputs, index = self.index[type], first=True)

    def find_id_by_values(self,
                          input_dict : dict,
                          target_list : list):
        """Return the first key whose value list equals ``target_list``.

        The comparison is ordered and element-wise, so ``[1, 2]`` no longer
        matches ``[2, 1]`` (the previous set-based comparison conflated them).
        Returns ``None`` when nothing matches.
        """
        target = list(target_list)
        for key, values in input_dict.items():
            if list(values) == target:
                return key
        return None

    def base_encode(self,
                    inputs : list,
                    index : dict):
        """Map each decoded value back to its id (``None`` when unknown).

        Scalar values are wrapped into a one-element list; vector values
        (list / tuple / ndarray, e.g. an ``[x, y]`` pair) are used as-is.
        """
        encoded = []
        for v in inputs:
            target = list(v) if isinstance(v, (list, tuple, np.ndarray)) else [v]
            encoded.append(self.find_id_by_values(index, target))
        return encoded

    def encode(self,
               inputs : list,
               type : str):
        """Encode a list of decoded values of the given field ``type``."""
        return self.base_encode(inputs, index = self.index[type])

    def decode_sequence(self,
                        input : dict):
        """Decode every field of a sequence dict, leaving passthrough keys as-is."""
        return {k : v if k in self._PASSTHROUGH_KEYS else self.decode(v, k)
                for k, v in input.items()}

    def encode_sequence(self,
                        input : dict):
        """Encode every field of a sequence dict, leaving passthrough keys as-is."""
        return {k : v if k in self._PASSTHROUGH_KEYS else self.encode(v, k)
                for k, v in input.items()}

    def truncate_to_time_t(self,
                           input : dict,
                           t : int):
        """Keep, in every field, only the entries whose ``pos_ids`` is below ``t``."""
        to_keep = [pos < t for pos in input["pos_ids"]]
        return {k : [item for i, item in enumerate(v) if to_keep[i]]
                for k, v in input.items()}

    def resize_window(self,
                      input : dict,
                      pos_id):
        """Return a shallow copy whose ``attention_mask`` is 0 before ``pos_id`` and 1 from it on."""
        out = input.copy()
        out["attention_mask"] = [0 if pos < pos_id else 1 for pos in out["pos_ids"]]
        return out

    def prepare_for_call(self,
                         input : dict):
        """Convert a sequence dict into tensors ready for the model.

        If the latest position exceeds the window size, the attention mask is
        rebuilt so only the last ``self.window`` positions stay visible.
        One-dimensional inputs get a leading batch axis.
        """
        # 51 is excluded from the max — presumably the padding position id;
        # TODO confirm against the time index.
        positions = [v for v in np.array(input["pos_ids"]).flatten() if v != 51]
        resize_limit = max(positions) - self.window
        if resize_limit > 0:
            input = self.resize_window(input, resize_limit)

        done = {k : tf.constant(v) for k, v in input.items()}
        if len(done["pos_ids"].shape) == 1:
            # Expand the tensors just built instead of re-converting the raw input.
            done = {k : tf.expand_dims(v, axis=0) for k, v in done.items()}
        return done
|
144 |
+
|
145 |
+
class generator:
    """Autoregressive sampler around the model plus helpers to plot its output.

    A sequence is a dict of parallel lists/arrays (``input_ids``, ``pos_ids``,
    ...). Several segments are laid end to end in it; a new segment starts
    wherever ``pos_ids`` stops increasing.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 temp,
                 n_select):
        """Store the model, the tokenizer and the sampling parameters.

        Args:
            model: callable returning logits for a prepared input.
            tokenizer: object providing ``prepare_for_call``.
            temp: base sampling temperature.
            n_select: number of top candidates sampled from at each step.
        """
        self.QBGPT = model
        self.tokenizer = tokenizer

        self.temperature = temp
        self.n_select = n_select

    def get_unique_lists(self,
                         l_of_ls : list):
        """Return the distinct inner lists, in order of first appearance.

        Callers take element ``[0]`` of the result, so the order has to be
        deterministic; a plain ``set`` does not guarantee that.
        """
        ordered_unique = dict.fromkeys(tuple(inner) for inner in l_of_ls)
        return [list(unique_tuple) for unique_tuple in ordered_unique]

    def cut(self, l, ref):
        """Split ``l`` into consecutive chunks; a falsy ``ref[i]`` starts a new chunk at ``i``.

        Returns an empty list for an empty ``l``.
        """
        splitted = []
        cutted = []
        for i, item in enumerate(l):
            if not ref[i]:
                # Close the running chunk and open a new one with this item.
                splitted.append(cutted)
                cutted = []
            cutted.append(item)
        if len(l) > 0:
            splitted.append(cutted)
        return splitted

    def get_last_preds(self,
                       logits,
                       input : dict):
        """Keep the predictions located at the highest ``pos_ids`` value."""
        last_pos = max(input["pos_ids"])  # computed once, not once per element
        to_keep = [pos == last_pos for pos in input["pos_ids"]]
        return np.array([logits[i] for i in range(len(logits)) if to_keep[i]])

    def get_logits(self,
                   input : dict):
        """Prepare the input with the tokenizer and run the model on it."""
        x = self.tokenizer.prepare_for_call(input)
        return self.QBGPT(x)

    def convert_to_preds(self,
                         logits):
        """Drop the leading (batch) axis."""
        preds = tf.squeeze(logits, axis=0)
        return preds

    def set_temperature(self,
                        x):
        """Return the temperature for step ``x``.

        Full temperature below 5, half of it below 10, a fifth of it below 20
        and a fixed 1.0 from 20 on.
        """
        if x < 5:
            return self.temperature
        if x < 10:
            return self.temperature/2
        if x < 20:
            return self.temperature/5
        return 1.0

    def select_and_temp(self,
                        tensor,
                        n,
                        temp):
        """Return softmax probabilities and indices of the ``n`` largest logits.

        Logits are divided by ``temp`` before the softmax. Both results keep
        ascending order along the last axis, so they stay aligned.
        """
        probas = tf.nn.softmax(tf.sort(tensor/temp, axis = -1)[:,:,-n:], axis = 2)
        indices = tf.argsort(tensor, axis = -1)[:,:,-n:]
        return probas, indices

    def draw_random(self,
                    probas):
        """Draw one one-hot sample per row of ``probas[0]``.

        Only the first batch entry is used. Each row is cast to float64 and
        renormalised first, because ``np.random.multinomial`` rejects
        probabilities whose sum exceeds 1 and the incoming softmax values can
        carry that rounding error.
        """
        rows = []
        for p in probas[0]:
            pvals = p.numpy().astype(np.float64)
            pvals = pvals / pvals.sum()
            rows.append(np.random.multinomial(1, pvals, 1))
        drawn = np.vstack(rows)
        drawn = tf.expand_dims(drawn, axis = 0)
        return tf.cast(drawn, dtype="int32")

    def get_indices(self,
                    drawn,
                    ind):
        """Pick the index selected by each one-hot draw."""
        return tf.reduce_sum(drawn*ind, axis = 2)

    def process_logits(self,
                       logits,
                       temp,
                       n):
        """Sample one token id per position among the ``n`` most likely ones."""
        probas, indices = self.select_and_temp(logits, n, temp)
        drawn = self.draw_random(probas)
        results = self.get_indices(drawn, indices)
        return results

    def generate(self,
                 input : dict):
        """Sample the next token for the entries at the latest time step."""
        logits = self.get_logits(input)
        temperature_parameter = self.set_temperature(max(input["pos_ids"]))
        processed_logits = self.process_logits(logits, n=self.n_select, temp=temperature_parameter)
        preds = self.convert_to_preds(processed_logits)
        return self.get_last_preds(preds, input)

    def slice_inputs(self,
                     input : dict):
        """Split every field into segments; a segment ends where ``pos_ids`` stops increasing."""
        pos = input["pos_ids"]
        flags = [True] + [pos[i+1] > pos[i] for i in range(len(pos)-1)]
        cutted_inputs = {k : self.cut(v, flags) for k, v in input.items()}
        return cutted_inputs

    def continue_by_token(self,
                          arr,
                          token :str):
        """Extend one segment of field ``token`` by one step.

        ``input_ids`` is returned unchanged (the prediction was already
        appended), ``pos_ids`` gets ``max + 1``, ``token_type_ids`` gets 1 and
        every other field repeats its last value.
        """
        if token == "input_ids":
            return arr
        if token == "pos_ids":
            insert = max(arr)+1
            return np.concatenate([arr, np.array([insert])])
        if token == "token_type_ids":
            return np.concatenate([arr, np.array([1])])
        return np.concatenate([arr, [arr[-1]]])

    def append_prediction(self,
                          arr,
                          pred):
        """Append a single prediction to one segment."""
        return np.concatenate([arr, [pred]])

    def append_predictions(self,
                           d : dict,
                           preds):
        """Return a shallow copy of ``d`` with ``preds[i]`` appended to segment ``i`` of ``input_ids``."""
        new = d.copy()
        new["input_ids"] = [self.append_prediction(new["input_ids"][i], preds[i]) for i in range(len(preds))]
        return new

    def merge_cuts(self,
                   input : dict):
        """Concatenate the segments of every field back into one array."""
        return {k : np.concatenate(v) for k, v in input.items()}

    def update_inputs(self,
                      input,
                      preds):
        """Append one predicted step per segment and extend every other field to match."""
        sliced = self.slice_inputs(input)
        appended = self.append_predictions(sliced, preds)
        continued = {k : [self.continue_by_token(e, k) for e in v] for k, v in appended.items()}
        merged = self.merge_cuts(continued)
        return merged

    def generate_sequence(self,
                          input,
                          t):
        """Run ``t`` generation steps from ``input`` and return the extended sequence."""
        new_input = input.copy()
        for _ in range(t):
            generated = self.generate(new_input)
            new_input = self.update_inputs(new_input, generated)
        return new_input

    def convert_list(self,
                     d,
                     keep_original):
        """Split the 2-element ``start_ids`` / ``input_ids`` entries into ``_x`` / ``_y`` fields.

        With ``keep_original`` false the two source fields are dropped from
        the returned dict.
        """
        new_df = d.copy()
        new_df["start_ids_x"] = [v[0] for v in new_df["start_ids"]]
        new_df["start_ids_y"] = [v[1] for v in new_df["start_ids"]]
        new_df["input_ids_x"] = [v[0] for v in new_df["input_ids"]]
        new_df["input_ids_y"] = [v[1] for v in new_df["input_ids"]]
        if keep_original:
            return new_df
        return {k : v for k, v in new_df.items() if k not in ["start_ids", "input_ids"]}

    def remove_pad(self,
                   seq):
        """Drop the rows whose start or move is the -1000 padding value.

        Both conditions are applied together; previously the second filter
        was run on the unfiltered frame, discarding the first one.
        """
        df = pd.DataFrame(seq)
        keep = (df["start_ids_x"] != -1000) & (df["input_ids_x"] != -1000)
        filtered = df[keep].reset_index(drop=True)
        return filtered.to_dict(orient = "list")

    def _compute_true_sequence(self,
                               scrimmage_line,
                               start : list,
                               moves : list):
        """Return absolute positions: the start followed by ``start + move`` for each move.

        Every position is then shifted by ``[scrimmage_line, 26.5]``
        (26.5 is presumably the vertical centre of the field — TODO confirm).
        An empty ``moves`` yields just the shifted start.
        """
        scrimmage = np.array([scrimmage_line, 26.5])

        origin = np.asarray(start)
        # reshape keeps a (0, 2)-shaped array for an empty move list so the
        # concatenation below still works.
        offsets = np.asarray(moves).reshape(-1, origin.shape[-1])

        appended = np.concatenate([np.expand_dims(origin, axis = 0), origin + offsets])

        final = appended + scrimmage
        return final

    def compute_true_sequence(self,
                              scrims,
                              starts,
                              moves):
        """Compute absolute positions for one segment from its first scrimmage and start values."""
        return self._compute_true_sequence(np.unique(scrims)[0], self.get_unique_lists(starts)[0], moves)

    def _resize_variable(self,
                         x,
                         ref: str):
        """Prepend one element to a segment so it lines up with the extra start frame.

        ``pos_ids`` / ``token_type_ids`` get a leading 0, ``input_ids`` /
        ``start_ids`` get their first distinct row, and every other field
        repeats its first value. The previous ``np.unique(x)`` prefix added
        several elements whenever a segment was not constant, which broke the
        length alignment between fields.
        """
        if ref in ["pos_ids", "token_type_ids"]:
            return np.concatenate([[0], x])

        elif ref in ["input_ids", "start_ids"]:
            return np.vstack([self.get_unique_lists(x)[0], x])

        else:
            return np.concatenate([x[:1], x])

    def prepare_for_plot(self,
                         seq):
        """Turn a decoded sequence into flat columns ready for a DataFrame.

        Steps: split coordinates, drop padded rows, slice into segments,
        convert moves to absolute positions (which adds a start frame per
        segment), pad every other field by one element to match, tag each row
        with its segment number in ``ids``, merge, and drop ``labels``.
        """
        sequence = seq.copy()
        sequence = self.convert_list(sequence, keep_original = True)
        sequence = self.remove_pad(sequence)
        cutted = self.slice_inputs(sequence)
        moves_updated = [self.compute_true_sequence(cutted["scrim_ids"][i], cutted["start_ids"][i], cutted["input_ids"][i]) for i in range(len(cutted["input_ids"]))]
        cutted["input_ids"] = moves_updated
        # input_ids already gained its start frame above, so it is not resized.
        cutted = {k : [self._resize_variable(e, k) if k != "input_ids" else e for e in v] for k, v in cutted.items()}
        cutted["ids"] = [[i] * len(segment) for i, segment in enumerate(cutted["input_ids"])]
        merged = self.merge_cuts(cutted)
        converted = self.convert_list(merged, keep_original = False)
        structured = {k : v for k, v in converted.items() if k != "labels"}
        return structured

    def insert_ids(self,
                   input):
        """Add an ``ids`` field holding the segment number of every entry."""
        cutted = self.slice_inputs(input)
        cutted["ids"] = [[i] * len(segment) for i, segment in enumerate(cutted["input_ids"])]
        merged = self.merge_cuts(cutted)
        return merged
|