import os
import random
import subprocess
import time
from functools import lru_cache

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.io import read_video

from mobilenet import MobileNetLarge3D, MobileNetSmall3D

# Per-channel normalization statistics, shaped (1, C, T, H, W) so they
# broadcast over the batched video tensor built in call_pitch.
_MEAN = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)
_STD = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)

_WEIGHTS_PATH = "weights/MobileNetSmall.pth"
_DOWNLOAD_DIR = "demo_files/downloaded_videos"
_PITCH_CSV = "picklebot_2m.csv"
# Upper bound for the random row offset into _PITCH_CSV.
_PITCH_CSV_MAX_SKIP = 2645342
# Frame rate that download_pitch re-encodes clips to (ffmpeg `-r 15`).
_FPS = 15
# Return the call this many frames before the clip ends so it feels snappy.
_EARLY_FRAMES = 5

# Dataset label -> umpire call; anything else is reported as "Error".
_LABEL_TO_CALL = {0: "Ball", 1: "Strike"}


def classify_pitch(confidence_scores):
    """Turn the model's class scores into a call string.

    Args:
        confidence_scores: model output tensor (presumably shape (1, 2) —
            TODO confirm). ``torch.argmax`` without ``dim`` indexes the
            flattened tensor, so this assumes a single-sample batch.

    Returns:
        ``"Ball"`` for index 0, ``"Strike"`` for index 1.

    Raises:
        ValueError: if the arg-max index is neither 0 nor 1.
    """
    index = int(torch.argmax(confidence_scores))
    if index == 0:
        return "Ball"
    if index == 1:
        return "Strike"
    # Previously this printed a message and then crashed with UnboundLocalError.
    raise ValueError(
        f"unexpected class index {index}; expected 0 (Ball) or 1 (Strike)"
    )


def get_demo_call(video_name):
    """Infer the umpire's original call from a video's file name.

    Returns ``"Strike"`` or ``"Ball"`` when the (case-insensitive) base name
    starts with that word, otherwise ``None``.
    """
    name = os.path.basename(video_name).lower()
    if name.startswith("strike"):
        return "Strike"
    if name.startswith("ball"):
        return "Ball"
    return None


@lru_cache(maxsize=1)
def _load_model():
    """Build MobileNetSmall3D with CPU weights in eval mode; cached after first call."""
    model = MobileNetSmall3D()
    model.load_state_dict(
        torch.load(_WEIGHTS_PATH, map_location=torch.device("cpu"))
    )
    model.eval()
    return model


def call_pitch(pitch):
    """Run Picklebot on a pitch video and return its call.

    Args:
        pitch: path to the video file.

    Returns:
        ``(picklebot_call, umpire_call)`` when the umpire's call can be read
        from the file name (see ``get_demo_call``), otherwise just
        ``picklebot_call``.
    """
    start = time.monotonic()
    ump_out = get_demo_call(pitch)

    # read_video returns frames as (T, H, W, C) uint8; reorder to (C, T, H, W),
    # add a batch dim, scale to [0, 1], then normalize.
    frames = read_video(pitch, pts_unit="sec")[0]
    pitch_tensor = frames.permute(-1, 0, 1, 2).unsqueeze(0) / 255
    pitch_tensor = (pitch_tensor - _MEAN) / _STD
    video_length = pitch_tensor.shape[2]  # number of frames

    model = _load_model()
    with torch.no_grad():
        output = F.softmax(model(pitch_tensor), dim=1)
    final_call = classify_pitch(output)

    # Hold the answer until the clip has (nearly) finished playing, so the
    # call is revealed at the end of the pitch rather than the start.
    reveal_at = (video_length - _EARLY_FRAMES) / _FPS
    remaining = reveal_at - (time.monotonic() - start)
    if remaining > 0:
        time.sleep(remaining)

    if ump_out is not None:
        return final_call, ump_out
    return final_call


def generate_random_pitch():
    """Download a random pitch from the dataset CSV and return the cropped video path.

    Column 0 of the CSV row is the video link; column 1 is the label
    (0 = Ball, 1 = Strike). Any other label is reported as "Error".
    """
    # Skip a random number of rows and read just the next one.
    random_number = random.randint(1, _PITCH_CSV_MAX_SKIP)
    df = pd.read_csv(_PITCH_CSV, skiprows=random_number, nrows=1, header=None)
    video_link, label = df.values[0][0], df.values[0][1]
    ump_out = _LABEL_TO_CALL.get(label, "Error")
    return download_pitch(video_link, ump_out)


def download_pitch(video_link, ump_out=None):
    """Download a pitch video with yt-dlp and crop/rescale it with ffmpeg.

    Args:
        video_link: URL of the video. May come straight from a user text box,
            so it is passed to the tools as a single argv entry, never through a shell.
        ump_out: the umpire's call, used as the file-name prefix so
            ``get_demo_call`` can recover it later; ``None`` uses "own".

    Returns:
        Path of the cropped 224x224, 15 fps video.

    Failures from yt-dlp/ffmpeg are not raised (same best-effort behavior as
    before); the returned path may then not exist.
    """
    prefix = ump_out if ump_out is not None else "own"
    video_out = f"{_DOWNLOAD_DIR}/{prefix}_demo_video.mp4"
    cropped_out = f"{video_out}_cropped_video.mp4"
    os.makedirs(_DOWNLOAD_DIR, exist_ok=True)

    # Argument lists (no shell) prevent command injection via video_link;
    # "--" stops yt-dlp from treating a link that starts with "-" as an option.
    subprocess.run(
        [
            "yt-dlp", "-q", "--no-warnings", "--force-overwrites",
            "-f", "mp4", "-o", video_out, "--", str(video_link).strip(),
        ],
        check=False,
    )
    # Center-crop 700x700, scale to 224x224, and re-encode at 15 fps without audio.
    subprocess.run(
        [
            "ffmpeg", "-y", "-nostats", "-loglevel", "0", "-i", video_out,
            "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
            "-r", "15", "-c:v", "libx264", "-crf", "23",
            "-c:a", "aac", "-an", "-strict", "experimental",
            cropped_out,
        ],
        check=False,
    )
    return cropped_out


# Example clips shown under the player (sorted for a stable order).
demo_files = sorted(
    os.path.join("demo_files/", file)
    for file in os.listdir("demo_files/")
    if file.endswith(".mp4")
)

with gr.Blocks(title="Picklebot") as demo:
    with gr.Tab("Picklebot"):
        gr.Markdown(value="""
To load a video, click the random button or choose from the examples.
Play the video by clicking anywhere on the video, and watch Picklebot's call!
To use your own video, click the "Use Your Own Video!" tab and follow the instructions.""")
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                ump_out = gr.Label(label="Umpire's Original Call")
                pb_out = gr.Label(label="Picklebot's Call")
        random_button = gr.Button("🔀 Load a Random Pitch")
        random_button.click(fn=generate_random_pitch, outputs=[inp])
        inp.play(fn=call_pitch, inputs=inp, outputs=[pb_out, ump_out])
        gr.ClearButton([inp, pb_out, ump_out])
        gr.Examples(demo_files, inputs=inp, outputs=[inp])

    with gr.Tab("Use Your Own Video!"):
        with gr.Row():
            gr.Markdown(value="""
# Here's how to use your own video:
1. Navigate to [Baseball Savant's statcast search](https://baseballsavant.mlb.com/statcast_search)
2. Choose the situation and pitch you want (just keep in mind that the network was only trained on called balls and called strikes)
3. Click on the pitcher or batter's name to get to list of pitches
4. Right click on the camera icon to the right, and copy the link
5. Paste the video url in below, and press go to download the pitch
6. Watch the video and see what Picklebot thinks!
""")
        with gr.Row():
            with gr.Column(scale=3):
                vid_inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                # Textbox for the user's video link; Enter or "Go!" downloads it.
                input_txt = gr.Textbox(
                    placeholder="Paste your video link here", label="Video Link"
                )
                input_txt.submit(fn=download_pitch, inputs=input_txt, outputs=vid_inp)
                submit_button = gr.Button("Go!")
                submit_button.click(fn=download_pitch, inputs=input_txt, outputs=vid_inp)
                own_pb_out = gr.Label(label="Picklebot's Call")
        # Own videos have no umpire call in their file name, so call_pitch
        # returns a single value here.
        vid_inp.play(fn=call_pitch, inputs=vid_inp, outputs=own_pb_out)
        gr.ClearButton([vid_inp, own_pb_out, input_txt])

    with gr.Tab("About"):
        gr.Markdown(value="""Picklebot is a 3D Convolutional Neural Network based on MobileNetV3 that classifies baseball pitches as balls or strikes.
The network was trained on the [Picklebot-50K Dataset](https://huggingface.co/datasets/hbfreed/Picklebot-50K), comprised of over fifty thousand pitches to achieve ~80% accuracy.
Here's the [GitHub](https://github.com/hbfreed/Picklebot) for the project.
Here's my [LinkedIn](https://www.linkedin.com/in/hbfreed/) if you want to connect.""")

if __name__ == "__main__":
    demo.launch()