# Non-code page text captured with this file; commented out so it parses:
# Spaces:
# Sleeping
# Sleeping
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import xgboost as xgb
from catboost import CatBoostRegressor
# Artifacts used by the predictors, resolved relative to the working directory.
DATASET_PATH = "Score_prediction_dataset_11th_July.csv"
PCA_MODEL_PATH = "pca_model7.pkl"
XGB_MODEL_PATH = "xgbr1_exp10_model.json"
CATBOOST_MODEL_PATH = "catbr1_exp11_model.json"

# Column order of the single-row frame built from the UI inputs.
_RAW_COLUMNS = [
    "Team_Name", "Opposition_Team", "Inning", "Home/Away", "Hits",
    "Opp_Hits", "Errors", "Runs", "Opp_Runs", "LOB",
]
_TEAM_COLUMNS = ["Team_Name", "Opposition_Team"]
_HOME_AWAY_STATUS = {"Home": 0, "Away": 1}


@functools.lru_cache(maxsize=1)
def _training_columns():
    """Return the one-hot column layout of the training dataset.

    The CSV is read once per process instead of on every prediction. Inputs
    are reindexed to this layout before the PCA step (presumably the order
    the PCA was fitted on -- confirm against the training code).
    """
    df_main = pd.read_csv(DATASET_PATH)
    df_main = df_main.drop(columns=["Final_Score", "Opp_LOB"])
    return pd.get_dummies(df_main, columns=_TEAM_COLUMNS).columns


@functools.lru_cache(maxsize=1)
def _load_pca():
    """Load the pickled PCA transformer once per process."""
    # NOTE: unpickling can execute arbitrary code; this must only ever point
    # at a trusted file shipped with the app.
    with open(PCA_MODEL_PATH, "rb") as f:
        return pickle.load(f)


@functools.lru_cache(maxsize=1)
def _load_xgb_model():
    """Load the XGBoost regressor once per process."""
    model = xgb.XGBRegressor()
    model.load_model(XGB_MODEL_PATH)
    return model


@functools.lru_cache(maxsize=1)
def _load_catboost_model():
    """Load the CatBoost regressor once per process."""
    model = CatBoostRegressor()
    model.load_model(CATBOOST_MODEL_PATH)
    return model


def _build_features(team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits):
    """Encode one game state and project it through the PCA transformer."""
    row = [team, opp_team, inning, venue, hits, opp_hits, errors, runs, opp_runs, lob]
    df = pd.DataFrame([row], columns=_RAW_COLUMNS)
    df = pd.get_dummies(df, columns=_TEAM_COLUMNS)
    # Align to the training layout: team columns absent from this row become 0.
    # A team name never seen in training therefore yields all-zero dummies.
    df = df.reindex(columns=_training_columns(), fill_value=0)
    # Any venue other than "Home"/"Away" maps to NaN, which the int cast rejects.
    df["Home/Away"] = df["Home/Away"].map(_HOME_AWAY_STATUS)
    return _load_pca().transform(df.astype(int))


def _final_score(prediction, runs):
    """Clamp a raw prediction and round it to one decimal place.

    A final score can be neither negative nor below the runs already scored,
    so the prediction is floored at both 0 and ``runs``.
    """
    return round(float(max(prediction, runs, 0)), 1)


def predict(team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits):
    """Predict ``team``'s final score with the XGBoost model.

    Args:
        team, opp_team: team names for the side being predicted and its opponent.
        inning: inning number entered in the UI.
        venue: "Home" or "Away" for ``team``.
        hits, errors, lob, runs: ``team``'s hits, errors, left-on-base and runs.
        opp_runs, opp_hits: the opponent's runs and hits.

    Returns:
        The predicted final score as a float rounded to one decimal place,
        never lower than 0 or ``runs``.
    """
    features = _build_features(team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits)
    return _final_score(_load_xgb_model().predict(features)[0], runs)


def predict_2(team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits):
    """Predict ``team``'s final score with the CatBoost model.

    Takes the same arguments and applies the same clamping and rounding as
    ``predict``. Only the regressor differs.
    """
    features = _build_features(team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits)
    return _final_score(_load_catboost_model().predict(features)[0], runs)
# The 30 MLB clubs in alphabetical order, offered as choices in both team
# dropdowns. Each name becomes a one-hot "Team_Name_*" / "Opposition_Team_*"
# feature, so it must match the team names in the training CSV exactly.
team_names = [
    "Arizona Diamondbacks",
    "Atlanta Braves",
    "Baltimore Orioles",
    "Boston Red Sox",
    "Chicago Cubs",
    "Chicago White Sox",
    "Cincinnati Reds",
    "Cleveland Guardians",
    "Colorado Rockies",
    "Detroit Tigers",
    "Houston Astros",
    "Kansas City Royals",
    "Los Angeles Angels",
    "Los Angeles Dodgers",
    "Miami Marlins",
    "Milwaukee Brewers",
    "Minnesota Twins",
    "New York Mets",
    "New York Yankees",
    "Oakland Athletics",
    "Philadelphia Phillies",
    "Pittsburgh Pirates",
    "San Diego Padres",
    "San Francisco Giants",
    "Seattle Mariners",
    "St. Louis Cardinals",
    "Tampa Bay Rays",
    "Texas Rangers",
    "Toronto Blue Jays",
    "Washington Nationals",
]
# Two-sided input form: each stat has a field for the team (left column) and
# for the opposition (right column). Predict runs both models for both sides.
with gr.Blocks() as demo:
    gr.Markdown("""
<center>
<span style='font-size: 50px; font-weight: Bold; font-family: "Graduate", serif'>
MLB Score Predictor
</span>
</center>
""")

    with gr.Row():
        inning = gr.Number(None, label="Inning", minimum=1, maximum=8, scale=1)

    # Single-select dropdowns: `max_choices` only applies when multiselect is
    # enabled, so it is not passed here.
    with gr.Row():
        with gr.Column():
            venue = gr.Dropdown(choices=["Home", "Away"], value="Away", label="Home/Away Status", scale=1)
        with gr.Column():
            opp_venue = gr.Dropdown(choices=["Home", "Away"], value="Home", label="Opposition Home/Away Status", scale=1)

    with gr.Row():
        with gr.Column():
            team = gr.Dropdown(choices=team_names, label="Team", scale=1)
        with gr.Column():
            opp_team = gr.Dropdown(choices=team_names, label="Opposition Team", scale=1)

    with gr.Row():
        with gr.Column():
            hits = gr.Number(None, minimum=0, label="Hits - (H)", scale=1)
        with gr.Column():
            opp_hits = gr.Number(None, minimum=0, label="Opposition Hits - (H)", scale=1)

    with gr.Row():
        with gr.Column():
            errors = gr.Number(None, minimum=0, label="Errors - (E)", scale=2)
        with gr.Column():
            opp_errors = gr.Number(None, minimum=0, label="Opposition Errors - (E)", scale=2)

    with gr.Row():
        with gr.Column():
            lob = gr.Number(None, minimum=0, label="Left on Base - (LOB)", scale=1)
        with gr.Column():
            opp_lob = gr.Number(None, minimum=0, label="Opposition Left on Base - (LOB)", scale=1)

    with gr.Row():
        with gr.Column():
            runs = gr.Number(None, minimum=0, label="Runs - (R)", scale=1)
        with gr.Column():
            opp_runs = gr.Number(None, minimum=0, label="Opposition Runs - (R)", scale=1)

    with gr.Row():
        predict_btn = gr.Button(value="Predict", size="sm")

    # Whatever venue is selected, the *_away* boxes show `team`'s predictions
    # and the *_home* boxes show the opposition's.
    with gr.Row():
        with gr.Column():
            final_score_away1 = gr.Textbox(label="Predicted Score Model XGB", scale=1)
        with gr.Column():
            final_score_home1 = gr.Textbox(label="Opposition Predicted Score Model XGB", scale=1)
    with gr.Row():
        with gr.Column():
            final_score_away2 = gr.Textbox(label="Predicted Score Model CATB", scale=1)
        with gr.Column():
            final_score_home2 = gr.Textbox(label="Opposition Predicted Score Model CATB", scale=1)

    # Argument order matches predict/predict_2. The opposition list mirrors the
    # team list so each side is scored from its own point of view.
    team_inputs = [team, inning, venue, hits, errors, lob, runs, opp_team, opp_runs, opp_hits]
    opp_inputs = [opp_team, inning, opp_venue, opp_hits, opp_errors, opp_lob, opp_runs, team, runs, hits]
    predict_btn.click(predict, inputs=team_inputs, outputs=final_score_away1)
    predict_btn.click(predict, inputs=opp_inputs, outputs=final_score_home1)
    predict_btn.click(predict_2, inputs=team_inputs, outputs=final_score_away2)
    predict_btn.click(predict_2, inputs=opp_inputs, outputs=final_score_home2)

demo.launch(inline=False)