Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from prophet import Prophet | |
import yfinance as yf | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
import matplotlib.pyplot as plt | |
from prophet.plot import plot_plotly, plot_components_plotly | |
# Function to fetch stock data from Yahoo Finance | |
def fetch_stock_data(ticker_symbol, start_date, end_date): | |
stock_data = yf.download(ticker_symbol, start=start_date, end=end_date) | |
df = stock_data[['Adj Close']].reset_index() | |
df = df.rename(columns={'Date': 'ds', 'Adj Close': 'y'}) | |
return df | |
def train_prophet_model(df):
    """Fit a default-configured Prophet model to ``df`` and return it.

    Args:
        df: training frame in Prophet's layout, i.e. a ``ds`` date column and
            a ``y`` target column (as produced by ``fetch_stock_data``).

    Returns:
        The fitted ``Prophet`` instance.
    """
    prophet_model = Prophet()  # no custom seasonality/changepoint settings
    prophet_model.fit(df)
    return prophet_model
def make_forecast(model, periods):
    """Predict ``periods`` days past the end of the model's training history.

    Args:
        model: a fitted Prophet model.
        periods: number of future days to forecast. Non-integer values are
            truncated to an int (e.g. ``182.5`` -> ``182``).

    Returns:
        The DataFrame returned by ``model.predict`` covering the training
        history plus the future dates.

    Raises:
        ValueError: if ``periods`` is negative.
    """
    if periods < 0:
        raise ValueError(f"periods must be non-negative, got {periods!r}")
    # make_future_dataframe slices a date index with `periods` and pandas
    # deprecates float `periods` in date_range, so it must be a real int.
    future = model.make_future_dataframe(periods=int(periods))
    return model.predict(future)
def calculate_performance_metrics(actual, predicted):
    """Compute point-forecast error metrics.

    Args:
        actual: observed values (array-like).
        predicted: forecast values, same length as ``actual``.

    Returns:
        dict with keys ``'MAE'`` (mean absolute error), ``'MSE'`` (mean
        squared error) and ``'RMSE'`` (square root of the MSE).

    NOTE(review): not called anywhere in the visible code.
    """
    squared_error = mean_squared_error(actual, predicted)
    return {
        'MAE': mean_absolute_error(actual, predicted),
        'MSE': squared_error,
        'RMSE': np.sqrt(squared_error),
    }
def _save_forecast_plot(df, forecast, path):
    """Plot history, forecast and its uncertainty band, and save the figure to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(df['ds'], df['y'], label='Actual Data')
        ax.plot(forecast['ds'], forecast['yhat'], label='Forecast', color='orange')
        ax.fill_between(forecast['ds'], forecast['yhat_lower'], forecast['yhat_upper'], color='orange', alpha=0.2)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.set_title('Stock Price Forecast')
        fig.savefig(path)
    finally:
        # Always release the figure, even if saving fails, so figures do not
        # accumulate in a long-running server process.
        plt.close(fig)


def _save_components_plot(model, forecast, path):
    """Save Prophet's trend/seasonality component plot for ``forecast`` to ``path``."""
    fig = model.plot_components(forecast)
    try:
        fig.savefig(path)
    finally:
        plt.close(fig)


def forecast_stock(ticker_symbol, start_date, end_date, forecast_horizon):
    """Run the full pipeline: download prices, fit Prophet, forecast, and plot.

    Args:
        ticker_symbol: Yahoo Finance ticker symbol.
        start_date: first date of the training data (``YYYY-MM-DD``).
        end_date: end date of the training data (``YYYY-MM-DD``).
        forecast_horizon: one of the labels offered by the UI dropdown
            (``'1 Month'``, ``'6 months'``, ``'1 year'``, ``'2 years'``,
            ``'3 years'``, ``'5 years'``).

    Returns:
        Tuple of two image paths (relative to the working directory): the
        forecast plot and the forecast-components plot.

    Raises:
        KeyError: if ``forecast_horizon`` is not a known label.
    """
    # Whole days only: Prophet's make_future_dataframe needs an integer
    # horizon, so half a year is 182 days rather than 365/2.
    horizon_mapping = {
        '1 Month': 30,
        '6 months': 182,
        '1 year': 365,
        '2 years': 730,
        '3 years': 1095,
        '5 years': 1825,
    }
    # Resolve the horizon first so a bad label fails before the slow
    # download and model fit.
    forecast_days = horizon_mapping[forecast_horizon]

    df = fetch_stock_data(ticker_symbol, start_date, end_date)
    model = train_prophet_model(df)
    forecast = make_forecast(model, forecast_days)

    # NOTE(review): these fixed filenames in the working directory are shared
    # by every request; concurrent sessions could overwrite each other's
    # images. Consider per-request temporary files.
    forecast_plot_path = 'forecast_plot.png'
    components_plot_path = 'forecast_components.png'
    _save_forecast_plot(df, forecast, forecast_plot_path)
    _save_components_plot(model, forecast, components_plot_path)
    return forecast_plot_path, components_plot_path
def main():
    """Build the "Stock Forecasting" Gradio UI and launch the server.

    Components appear top to bottom: ticker, start date, end date (defaulting
    to today's date), forecast-horizon dropdown, the forecast button, then
    two output images. Clicking the button calls ``forecast_stock`` with the
    four inputs and displays the two image paths it returns.
    """
    horizon_choices = ['1 Month', '6 months', '1 year', '2 years', '3 years', '5 years']
    today_str = str(pd.to_datetime('today').date())

    with gr.Blocks() as demo:
        gr.Markdown("# Stock Forecasting")

        # Inputs (creation order determines on-screen order).
        ticker_box = gr.Textbox(label="Enter Ticker Symbol", value="RACE")
        start_box = gr.Textbox(label="Start Date (YYYY-MM-DD) of Data", value="2015-01-01")
        end_box = gr.Textbox(label="End Date (YYYY-MM-DD) of Data", value=today_str)
        horizon_dropdown = gr.Dropdown(label="Forecast Horizon", choices=horizon_choices, value='1 year')

        run_button = gr.Button("Forecast Stock Prices")

        # Outputs.
        forecast_image = gr.Image(label="Forecast Plot")
        components_image = gr.Image(label="Forecast Components")

        run_button.click(
            fn=forecast_stock,
            inputs=[ticker_box, start_box, end_box, horizon_dropdown],
            outputs=[forecast_image, components_image],
        )

    demo.launch()
# Build and launch the Gradio app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()