Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import random
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from mobilenet import MobileNetLarge3D
|
8 |
+
from torchvision.io import read_video
|
9 |
+
import time
|
10 |
+
|
11 |
+
def classify_pitch(confidence_scores):
|
12 |
+
if torch.argmax(confidence_scores) == 0:
|
13 |
+
call = 'Ball'
|
14 |
+
elif torch.argmax(confidence_scores) == 1:
|
15 |
+
call = 'Strike'
|
16 |
+
else:
|
17 |
+
print("that's odd, something is wrong")
|
18 |
+
pass
|
19 |
+
return call
|
20 |
+
|
21 |
+
def get_demo_call(video_name):
|
22 |
+
video_name = os.path.basename(video_name).lower()
|
23 |
+
if video_name.startswith("strike"):
|
24 |
+
ump_out = "Strike"
|
25 |
+
elif video_name.startswith("ball"):
|
26 |
+
ump_out = "Ball"
|
27 |
+
else:
|
28 |
+
ump_out = "Error"
|
29 |
+
return ump_out
|
30 |
+
|
31 |
+
def call_pitch(pitch):
|
32 |
+
std = (0.2104, 0.1986, 0.1829)
|
33 |
+
mean = (0.3939, 0.3817, 0.3314)
|
34 |
+
# Convert the mean and std to tensors
|
35 |
+
mean = torch.tensor(mean).view(1, 3, 1, 1, 1)
|
36 |
+
std = torch.tensor(std).view(1, 3, 1, 1, 1)
|
37 |
+
ump_out = get_demo_call(pitch)
|
38 |
+
pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
|
39 |
+
pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
|
40 |
+
video_length = pitch_tensor.shape[2]/15
|
41 |
+
model = MobileNetLarge3D()
|
42 |
+
model.load_state_dict(torch.load('weights/MobileNetLarge.pth',map_location=torch.device('cpu')))
|
43 |
+
model.eval()
|
44 |
+
|
45 |
+
#run the model
|
46 |
+
with torch.no_grad():
|
47 |
+
output = model(pitch_tensor)
|
48 |
+
output = F.softmax(output,dim=1)
|
49 |
+
final_call = classify_pitch(output)
|
50 |
+
time.sleep(video_length-1.5) #wait until the video is done to return the call
|
51 |
+
|
52 |
+
return final_call,ump_out
|
53 |
+
|
54 |
+
def generate_random_pitch():
|
55 |
+
random_number = random.randint(1,2645342) #random number in our range to select a random pitch
|
56 |
+
|
57 |
+
df = pd.read_csv('picklebot_2m.csv',skiprows=random_number,nrows=1,header=None) #just load one row
|
58 |
+
video_link = df.values[0][0]
|
59 |
+
label = df.values[0][1]
|
60 |
+
|
61 |
+
if label == 0:
|
62 |
+
ump_out = "Ball"
|
63 |
+
elif label == 1:
|
64 |
+
ump_out = "Strike"
|
65 |
+
else:
|
66 |
+
#error if it's not a 0 or 1
|
67 |
+
ump_out = "Error"
|
68 |
+
|
69 |
+
random_pitch = download_pitch(video_link,ump_out)
|
70 |
+
|
71 |
+
return random_pitch
|
72 |
+
def download_pitch(video_link,ump_out):
|
73 |
+
#download and process the video link using yt-dlp and ffmpeg
|
74 |
+
os.system(f"yt-dlp -q --no-warnings --force-overwrites -f mp4 {video_link} -o 'demo_files/downloaded_videos/{ump_out}_demo_video.mp4'")
|
75 |
+
os.system(f'ffmpeg -y -nostats -loglevel 0 -i demo_files/downloaded_videos/{ump_out}_demo_video.mp4 -vf "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224" -r 15 -c:v libx264 -crf 23 -c:a aac -strict experimental demo_files/downloaded_videos/{ump_out}_cropped_video.mp4')
|
76 |
+
return f"demo_files/downloaded_videos/{ump_out}_cropped_video.mp4"
|
77 |
+
|
78 |
+
|
79 |
+
demo_files = os.listdir("demo_files/")
|
80 |
+
demo_files = [os.path.join("demo_files/", file) for file in demo_files if file.endswith(".mp4")]
|
81 |
+
with gr.Blocks(title="Picklebot") as demo:
|
82 |
+
with gr.Row():
|
83 |
+
with gr.Column(scale=3):
|
84 |
+
inp = gr.Video(interactive=False, label="Pitch Video")
|
85 |
+
with gr.Column(scale=2):
|
86 |
+
ump_out = gr.Label(label="Umpire's Original Call")
|
87 |
+
pb_out = gr.Label(label="Picklebot's Call")
|
88 |
+
random_button = gr.Button("🔀 Load a Random Pitch")
|
89 |
+
ball = 0
|
90 |
+
random_button.click(fn=generate_random_pitch,outputs=[inp])
|
91 |
+
inp.play(fn=call_pitch,inputs=inp,outputs=[pb_out,ump_out])
|
92 |
+
gr.ClearButton([inp,pb_out,ump_out])
|
93 |
+
gr.Examples(demo_files,inputs=inp,outputs=[inp])
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
demo.launch()
|