hbfreed commited on
Commit
6d5c70d
1 Parent(s): 799380d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import random
4
+ import os
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from mobilenet import MobileNetLarge3D
8
+ from torchvision.io import read_video
9
+ import time
10
+
11
+ def classify_pitch(confidence_scores):
12
+ if torch.argmax(confidence_scores) == 0:
13
+ call = 'Ball'
14
+ elif torch.argmax(confidence_scores) == 1:
15
+ call = 'Strike'
16
+ else:
17
+ print("that's odd, something is wrong")
18
+ pass
19
+ return call
20
+
21
+ def get_demo_call(video_name):
22
+ video_name = os.path.basename(video_name).lower()
23
+ if video_name.startswith("strike"):
24
+ ump_out = "Strike"
25
+ elif video_name.startswith("ball"):
26
+ ump_out = "Ball"
27
+ else:
28
+ ump_out = "Error"
29
+ return ump_out
30
+
31
+ def call_pitch(pitch):
32
+ std = (0.2104, 0.1986, 0.1829)
33
+ mean = (0.3939, 0.3817, 0.3314)
34
+ # Convert the mean and std to tensors
35
+ mean = torch.tensor(mean).view(1, 3, 1, 1, 1)
36
+ std = torch.tensor(std).view(1, 3, 1, 1, 1)
37
+ ump_out = get_demo_call(pitch)
38
+ pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
39
+ pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
40
+ video_length = pitch_tensor.shape[2]/15
41
+ model = MobileNetLarge3D()
42
+ model.load_state_dict(torch.load('weights/MobileNetLarge.pth',map_location=torch.device('cpu')))
43
+ model.eval()
44
+
45
+ #run the model
46
+ with torch.no_grad():
47
+ output = model(pitch_tensor)
48
+ output = F.softmax(output,dim=1)
49
+ final_call = classify_pitch(output)
50
+ time.sleep(video_length-1.5) #wait until the video is done to return the call
51
+
52
+ return final_call,ump_out
53
+
54
+ def generate_random_pitch():
55
+ random_number = random.randint(1,2645342) #random number in our range to select a random pitch
56
+
57
+ df = pd.read_csv('picklebot_2m.csv',skiprows=random_number,nrows=1,header=None) #just load one row
58
+ video_link = df.values[0][0]
59
+ label = df.values[0][1]
60
+
61
+ if label == 0:
62
+ ump_out = "Ball"
63
+ elif label == 1:
64
+ ump_out = "Strike"
65
+ else:
66
+ #error if it's not a 0 or 1
67
+ ump_out = "Error"
68
+
69
+ random_pitch = download_pitch(video_link,ump_out)
70
+
71
+ return random_pitch
72
+ def download_pitch(video_link,ump_out):
73
+ #download and process the video link using yt-dlp and ffmpeg
74
+ os.system(f"yt-dlp -q --no-warnings --force-overwrites -f mp4 {video_link} -o 'demo_files/downloaded_videos/{ump_out}_demo_video.mp4'")
75
+ os.system(f'ffmpeg -y -nostats -loglevel 0 -i demo_files/downloaded_videos/{ump_out}_demo_video.mp4 -vf "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224" -r 15 -c:v libx264 -crf 23 -c:a aac -strict experimental demo_files/downloaded_videos/{ump_out}_cropped_video.mp4')
76
+ return f"demo_files/downloaded_videos/{ump_out}_cropped_video.mp4"
77
+
78
+
79
+ demo_files = os.listdir("demo_files/")
80
+ demo_files = [os.path.join("demo_files/", file) for file in demo_files if file.endswith(".mp4")]
81
+ with gr.Blocks(title="Picklebot") as demo:
82
+ with gr.Row():
83
+ with gr.Column(scale=3):
84
+ inp = gr.Video(interactive=False, label="Pitch Video")
85
+ with gr.Column(scale=2):
86
+ ump_out = gr.Label(label="Umpire's Original Call")
87
+ pb_out = gr.Label(label="Picklebot's Call")
88
+ random_button = gr.Button("🔀 Load a Random Pitch")
89
+ ball = 0
90
+ random_button.click(fn=generate_random_pitch,outputs=[inp])
91
+ inp.play(fn=call_pitch,inputs=inp,outputs=[pb_out,ump_out])
92
+ gr.ClearButton([inp,pb_out,ump_out])
93
+ gr.Examples(demo_files,inputs=inp,outputs=[inp])
94
+
95
+ if __name__ == "__main__":
96
+ demo.launch()