Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Final version...
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import gradio as gr
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
import joblib
|
11 |
+
from tqdm import tqdm
|
12 |
+
import tempfile
|
13 |
+
import json
|
14 |
+
from math import radians, cos, sin, asin, sqrt, atan2, degrees
|
15 |
+
import time
|
16 |
+
|
17 |
+
# ============================
|
18 |
+
# Configure Logging
|
19 |
+
# ============================
|
20 |
+
|
21 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
22 |
+
|
23 |
+
# ============================
|
24 |
+
# Helper Functions
|
25 |
+
# ============================
|
26 |
+
|
27 |
+
def add_time_decimal_feature(df):
|
28 |
+
"""
|
29 |
+
Add 'time_decimal' feature by combining 'hour' and 'minutes'.
|
30 |
+
|
31 |
+
:param df: DataFrame with 'hour' and 'minutes' columns.
|
32 |
+
:return: DataFrame with 'time_decimal' and without 'hour' and 'minutes'.
|
33 |
+
"""
|
34 |
+
if 'hour' in df.columns and 'minutes' in df.columns:
|
35 |
+
logging.info("Adding 'time_decimal' feature...")
|
36 |
+
df['time_decimal'] = df['hour'] + df['minutes'] / 60.0
|
37 |
+
df = df.drop(columns=['hour', 'minutes']) # Drop 'hour' and 'minutes' after creation
|
38 |
+
logging.info("'time_decimal' feature added.")
|
39 |
+
else:
|
40 |
+
logging.warning("'hour' and/or 'minutes' columns not found. Skipping 'time_decimal' feature addition.")
|
41 |
+
return df
|
42 |
+
|
43 |
+
def haversine(lon1, lat1, lon2, lat2):
|
44 |
+
"""
|
45 |
+
Calculate the great-circle distance between two points on the Earth.
|
46 |
+
|
47 |
+
:param lon1: Longitude of point 1 (in decimal degrees)
|
48 |
+
:param lat1: Latitude of point 1 (in decimal degrees)
|
49 |
+
:param lon2: Longitude of point 2 (in decimal degrees)
|
50 |
+
:param lat2: Latitude of point 2 (in decimal degrees)
|
51 |
+
:return: Distance in kilometers
|
52 |
+
"""
|
53 |
+
# Convert decimal degrees to radians
|
54 |
+
lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(np.radians, [lon1, lat1, lon2, lat2])
|
55 |
+
|
56 |
+
# Haversine formula
|
57 |
+
dlon = lon2_rad - lon1_rad
|
58 |
+
dlat = lat2_rad - lat1_rad
|
59 |
+
a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2
|
60 |
+
c = 2 * np.arcsin(np.sqrt(a))
|
61 |
+
r = 6371 # Radius of Earth in kilometers
|
62 |
+
return c * r
|
63 |
+
|
64 |
+
def calculate_bearing(lon1, lat1, lon2, lat2):
|
65 |
+
"""
|
66 |
+
Calculate the bearing between two points.
|
67 |
+
|
68 |
+
:param lon1: Longitude of point 1 (in decimal degrees)
|
69 |
+
:param lat1: Latitude of point 1 (in decimal degrees)
|
70 |
+
:param lon2: Longitude of point 2 (in decimal degrees)
|
71 |
+
:param lat2: Latitude of point 2 (in decimal degrees)
|
72 |
+
:return: Bearing in degrees
|
73 |
+
"""
|
74 |
+
# Convert decimal degrees to radians
|
75 |
+
lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(radians, [lon1, lat1, lon2, lat2])
|
76 |
+
|
77 |
+
dlon = lon2_rad - lon1_rad
|
78 |
+
x = sin(dlon) * cos(lat2_rad)
|
79 |
+
y = cos(lat1_rad) * sin(lat2_rad) - (sin(lat1_rad) * cos(lat2_rad) * cos(dlon))
|
80 |
+
|
81 |
+
initial_bearing = atan2(x, y)
|
82 |
+
|
83 |
+
# Convert from radians to degrees and normalize
|
84 |
+
initial_bearing = degrees(initial_bearing)
|
85 |
+
compass_bearing = (initial_bearing + 360) % 360
|
86 |
+
|
87 |
+
return compass_bearing
|
88 |
+
|
89 |
+
def angular_divergence(bearing1, bearing2):
|
90 |
+
"""
|
91 |
+
Calculate the smallest angle difference between two bearings.
|
92 |
+
|
93 |
+
:param bearing1: First bearing in degrees
|
94 |
+
:param bearing2: Second bearing in degrees
|
95 |
+
:return: Angular divergence in degrees
|
96 |
+
"""
|
97 |
+
diff = abs(bearing1 - bearing2) % 360
|
98 |
+
return min(diff, 360 - diff)
|
99 |
+
|
100 |
+
def denormalize(scaled_lat, scaled_lon, scaler, lat_idx, lon_idx):
|
101 |
+
"""
|
102 |
+
Denormalize latitude and longitude using the scaler's parameters.
|
103 |
+
|
104 |
+
:param scaled_lat: Scaled latitude values (numpy array).
|
105 |
+
:param scaled_lon: Scaled longitude values (numpy array).
|
106 |
+
:param scaler: The scaler object used for normalization.
|
107 |
+
:param lat_idx: Index of 'latitude_degrees' in the scaler's feature list.
|
108 |
+
:param lon_idx: Index of 'longitude_degrees' in the scaler's feature list.
|
109 |
+
:return: Tuple of (denormalized_lat, denormalized_lon).
|
110 |
+
"""
|
111 |
+
lat_min = scaler.data_min_[lat_idx]
|
112 |
+
lat_max = scaler.data_max_[lat_idx]
|
113 |
+
lon_min = scaler.data_min_[lon_idx]
|
114 |
+
lon_max = scaler.data_max_[lon_idx]
|
115 |
+
|
116 |
+
denorm_lat = scaled_lat * (lat_max - lat_min) + lat_min
|
117 |
+
denorm_lon = scaled_lon * (lon_max - lon_min) + lon_min
|
118 |
+
return denorm_lat, denorm_lon
|
119 |
+
|
120 |
+
def create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale):
|
121 |
+
"""
|
122 |
+
Create input and output sequences grouped by original MMSI.
|
123 |
+
Returns scaled last known positions.
|
124 |
+
"""
|
125 |
+
Xs, ys, mmsis = [], [], []
|
126 |
+
last_known_positions_scaled = []
|
127 |
+
|
128 |
+
grouped = df_scaled.groupby('original_mmsi')
|
129 |
+
|
130 |
+
for mmsi, group in tqdm(grouped, desc="Creating sequences"):
|
131 |
+
if len(group) >= seq_len + forecast_horizon:
|
132 |
+
for i in range(len(group) - seq_len - forecast_horizon + 1):
|
133 |
+
# Select scaled features for the sequence
|
134 |
+
sequence = group.iloc[i:(i + seq_len)][features_to_scale].to_numpy()
|
135 |
+
|
136 |
+
# Future positions to predict (scaled)
|
137 |
+
future_positions = group[['latitude_degrees', 'longitude_degrees']].iloc[i + seq_len:i + seq_len + forecast_horizon].to_numpy()
|
138 |
+
|
139 |
+
# Future hour feature
|
140 |
+
future_hour = group[['time_decimal']].iloc[i + seq_len].values[0]
|
141 |
+
future_hour_feature = np.full((seq_len, 1), future_hour)
|
142 |
+
|
143 |
+
# Combine sequence with future_hour_feature
|
144 |
+
sequence_with_future_hour = np.hstack((sequence, future_hour_feature))
|
145 |
+
|
146 |
+
Xs.append(sequence_with_future_hour)
|
147 |
+
ys.append(future_positions)
|
148 |
+
mmsis.append(mmsi)
|
149 |
+
|
150 |
+
# Store last known positions (scaled)
|
151 |
+
last_lat_scaled = group['latitude_degrees'].iloc[i + seq_len - 1]
|
152 |
+
last_lon_scaled = group['longitude_degrees'].iloc[i + seq_len - 1]
|
153 |
+
last_known_positions_scaled.append((last_lat_scaled, last_lon_scaled))
|
154 |
+
|
155 |
+
return np.array(Xs, dtype=np.float32), np.array(ys, dtype=np.float32), np.array(mmsis), last_known_positions_scaled
|
156 |
+
|
157 |
+
# ============================
|
158 |
+
# Model Definitions
|
159 |
+
# ============================
|
160 |
+
|
161 |
+
class LSTMModelTeacher(nn.Module):
|
162 |
+
def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=7, dropout=0.2):
|
163 |
+
"""
|
164 |
+
Teacher LSTM Model.
|
165 |
+
|
166 |
+
:param in_dim: Number of input features.
|
167 |
+
:param hidden_dim: Number of hidden units.
|
168 |
+
:param forecast_horizon: Number of future steps to predict.
|
169 |
+
:param n_layers: Number of LSTM layers.
|
170 |
+
:param dropout: Dropout rate.
|
171 |
+
"""
|
172 |
+
super(LSTMModelTeacher, self).__init__()
|
173 |
+
self.forecast_horizon = forecast_horizon # Store as an instance attribute
|
174 |
+
self.embedding = nn.Linear(in_dim, hidden_dim)
|
175 |
+
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
|
176 |
+
self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
x = self.embedding(x)
|
180 |
+
x, _ = self.lstm(x)
|
181 |
+
x = self.fc(x[:, -1, :]) # Use the last timestep for prediction
|
182 |
+
x = x.view(-1, self.forecast_horizon, 2) # Shape: (batch_size, forecast_horizon, 2)
|
183 |
+
return x
|
184 |
+
|
185 |
+
class LSTMModelStudent(nn.Module):
|
186 |
+
def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=3, dropout=0.2):
|
187 |
+
"""
|
188 |
+
Student LSTM Model.
|
189 |
+
|
190 |
+
:param in_dim: Number of input features.
|
191 |
+
:param hidden_dim: Number of hidden units.
|
192 |
+
:param forecast_horizon: Number of future steps to predict.
|
193 |
+
:param n_layers: Number of LSTM layers.
|
194 |
+
:param dropout: Dropout rate.
|
195 |
+
"""
|
196 |
+
super(LSTMModelStudent, self).__init__()
|
197 |
+
self.forecast_horizon = forecast_horizon # Store as an instance attribute
|
198 |
+
self.embedding = nn.Linear(in_dim, hidden_dim)
|
199 |
+
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
|
200 |
+
self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
x = self.embedding(x)
|
204 |
+
x, _ = self.lstm(x)
|
205 |
+
x = self.fc(x[:, -1, :]) # Use the last timestep for prediction
|
206 |
+
x = x.view(-1, self.forecast_horizon, 2) # Shape: (batch_size, forecast_horizon, 2)
|
207 |
+
return x
|
208 |
+
|
209 |
+
# ============================
|
210 |
+
# Model Loading Functions
|
211 |
+
# ============================
|
212 |
+
|
213 |
+
def load_models(model_paths):
|
214 |
+
"""
|
215 |
+
Load teacher and student models, including submodels for North, Mid, and South areas.
|
216 |
+
|
217 |
+
:param model_paths: Dictionary containing paths to the models.
|
218 |
+
:return: Dictionary of loaded models.
|
219 |
+
"""
|
220 |
+
models = {}
|
221 |
+
logging.info("Loading Teacher model...")
|
222 |
+
# Load Teacher Model (Global)
|
223 |
+
teacher = LSTMModelTeacher(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=7, dropout=0.2) # 15 features including 'future_hour_feature'
|
224 |
+
teacher.load_state_dict(torch.load(model_paths['teacher'], map_location=torch.device('cpu')))
|
225 |
+
teacher.eval()
|
226 |
+
models['Teacher'] = teacher
|
227 |
+
logging.info("Teacher model loaded successfully.")
|
228 |
+
|
229 |
+
logging.info("Loading Student North model...")
|
230 |
+
# Load Student Models (Sub-areas)
|
231 |
+
student_north = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
|
232 |
+
student_north.load_state_dict(torch.load(model_paths['student_north'], map_location=torch.device('cpu')))
|
233 |
+
student_north.eval()
|
234 |
+
models['Student_North'] = student_north
|
235 |
+
logging.info("Student North model loaded successfully.")
|
236 |
+
|
237 |
+
logging.info("Loading Student Mid model...")
|
238 |
+
student_mid = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
|
239 |
+
student_mid.load_state_dict(torch.load(model_paths['student_mid'], map_location=torch.device('cpu')))
|
240 |
+
student_mid.eval()
|
241 |
+
models['Student_Mid'] = student_mid
|
242 |
+
logging.info("Student Mid model loaded successfully.")
|
243 |
+
|
244 |
+
logging.info("Loading Student South model...")
|
245 |
+
student_south = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
|
246 |
+
student_south.load_state_dict(torch.load(model_paths['student_south'], map_location=torch.device('cpu')))
|
247 |
+
student_south.eval()
|
248 |
+
models['Student_South'] = student_south
|
249 |
+
logging.info("Student South model loaded successfully.")
|
250 |
+
|
251 |
+
return models
|
252 |
+
|
253 |
+
def load_scalers(scaler_paths):
|
254 |
+
"""
|
255 |
+
Load scalers for each model.
|
256 |
+
|
257 |
+
:param scaler_paths: Dictionary containing paths to the scaler files.
|
258 |
+
:return: Dictionary of loaded scalers.
|
259 |
+
"""
|
260 |
+
loaded_scalers = {}
|
261 |
+
for model_name, scaler_path in scaler_paths.items():
|
262 |
+
if os.path.exists(scaler_path):
|
263 |
+
loaded_scalers[model_name] = joblib.load(scaler_path)
|
264 |
+
logging.info(f"Loaded scaler for {model_name} from '{scaler_path}'.")
|
265 |
+
else:
|
266 |
+
logging.error(f"Scaler file for {model_name} not found at '{scaler_path}'.")
|
267 |
+
raise FileNotFoundError(f"Scaler file for {model_name} not found at '{scaler_path}'. Please provide the correct path.")
|
268 |
+
return loaded_scalers
|
269 |
+
|
270 |
+
# ============================
|
271 |
+
# Model Selection Logic
|
272 |
+
# ============================
|
273 |
+
|
274 |
+
def determine_subarea(df):
|
275 |
+
"""
|
276 |
+
Determine the sub-area (North, Mid, South) based on latitude and longitude ranges.
|
277 |
+
|
278 |
+
:param df: DataFrame containing 'latitude_degrees' and 'longitude_degrees'.
|
279 |
+
:return: String indicating the sub-area.
|
280 |
+
"""
|
281 |
+
# Define sub-area boundaries
|
282 |
+
subareas = {
|
283 |
+
'North': {'lat_min': 30, 'lat_max': 60, 'lon_min': -80, 'lon_max': -10},
|
284 |
+
'Mid': {'lat_min': 0, 'lat_max': 30, 'lon_min': -80, 'lon_max': 10},
|
285 |
+
'South': {'lat_min': -80, 'lat_max': 0, 'lon_min': -60, 'lon_max': 20}
|
286 |
+
}
|
287 |
+
|
288 |
+
# Count the number of data points in each sub-area
|
289 |
+
counts = {}
|
290 |
+
for area, bounds in subareas.items():
|
291 |
+
count = df[
|
292 |
+
(df['latitude_degrees'] >= bounds['lat_min']) & (df['latitude_degrees'] <= bounds['lat_max']) &
|
293 |
+
(df['longitude_degrees'] >= bounds['lon_min']) & (df['longitude_degrees'] <= bounds['lon_max'])
|
294 |
+
].shape[0]
|
295 |
+
counts[area] = count
|
296 |
+
logging.info(f"Sub-area '{area}': {count} records.")
|
297 |
+
|
298 |
+
# Determine the sub-area with the maximum count
|
299 |
+
predominant_subarea = max(counts, key=counts.get)
|
300 |
+
logging.info(f"Predominant sub-area determined: {predominant_subarea}")
|
301 |
+
|
302 |
+
# If no data points fall into any sub-area, default to Teacher
|
303 |
+
if counts[predominant_subarea] == 0:
|
304 |
+
logging.warning("No data points found in any sub-area. Defaulting to Teacher model.")
|
305 |
+
return 'Teacher'
|
306 |
+
|
307 |
+
return predominant_subarea
|
308 |
+
|
309 |
+
def select_model(models, subarea):
|
310 |
+
"""
|
311 |
+
Select the appropriate model based on the sub-area.
|
312 |
+
|
313 |
+
:param models: Dictionary of loaded models.
|
314 |
+
:param subarea: String indicating the sub-area.
|
315 |
+
:return: Tuple of (selected_model, selected_model_name).
|
316 |
+
"""
|
317 |
+
if subarea in ['North', 'Mid', 'South']:
|
318 |
+
selected_model = models.get(f'Student_{subarea}')
|
319 |
+
selected_model_name = f'Student_{subarea}'
|
320 |
+
logging.info(f"Selected model: {selected_model_name}")
|
321 |
+
return selected_model, selected_model_name
|
322 |
+
else:
|
323 |
+
selected_model = models.get('Teacher')
|
324 |
+
selected_model_name = 'Teacher'
|
325 |
+
logging.info(f"Selected model: {selected_model_name}")
|
326 |
+
return selected_model, selected_model_name
|
327 |
+
|
328 |
+
# ============================
|
329 |
+
# Evaluation Metrics Calculation
|
330 |
+
# ============================
|
331 |
+
|
332 |
+
def calculate_classic_metrics(y_true, y_pred):
|
333 |
+
"""
|
334 |
+
Calculate MAE, MSE, and RMSE directly on latitude/longitude pairs.
|
335 |
+
|
336 |
+
:param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
|
337 |
+
:param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
|
338 |
+
:return: Dictionary containing the classic metrics.
|
339 |
+
"""
|
340 |
+
# Calculate MAE
|
341 |
+
mae = mean_absolute_error(y_true, y_pred)
|
342 |
+
|
343 |
+
# Calculate MSE
|
344 |
+
mse = mean_squared_error(y_true, y_pred)
|
345 |
+
|
346 |
+
# Calculate RMSE
|
347 |
+
rmse = np.sqrt(mse)
|
348 |
+
|
349 |
+
classic_metrics = {
|
350 |
+
'MAE (degrees)': mae,
|
351 |
+
'MSE (degrees^2)': mse,
|
352 |
+
'RMSE (degrees)': rmse
|
353 |
+
}
|
354 |
+
|
355 |
+
logging.info(f"Calculated classic metrics: {classic_metrics}")
|
356 |
+
|
357 |
+
return classic_metrics
|
358 |
+
|
359 |
+
def calculate_distance_metrics(y_true, y_pred):
|
360 |
+
"""
|
361 |
+
Calculate metrics based on distance (in kilometers).
|
362 |
+
|
363 |
+
:param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
|
364 |
+
:param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
|
365 |
+
:return: Dictionary containing the distance-based metrics.
|
366 |
+
"""
|
367 |
+
# Calculate haversine distance between predicted and true positions
|
368 |
+
distances = np.array([
|
369 |
+
haversine(y_true[i, 1], y_true[i, 0], y_pred[i, 1], y_pred[i, 0])
|
370 |
+
for i in range(len(y_true))
|
371 |
+
]) # Assuming columns are [latitude, longitude]
|
372 |
+
|
373 |
+
# Calculate MAE
|
374 |
+
mae = np.mean(np.abs(distances))
|
375 |
+
|
376 |
+
# Calculate MSE
|
377 |
+
mse = np.mean(np.square(distances))
|
378 |
+
|
379 |
+
# Calculate RMSE
|
380 |
+
rmse = np.sqrt(mse)
|
381 |
+
|
382 |
+
# Calculate RSE (Relative Squared Error)
|
383 |
+
variance = np.var(distances)
|
384 |
+
rse = mse / variance if variance != 0 else float('inf')
|
385 |
+
|
386 |
+
metrics = {
|
387 |
+
'MAE (km)': mae,
|
388 |
+
'MSE (km^2)': mse,
|
389 |
+
'RMSE (km)': rmse,
|
390 |
+
'RSE': rse
|
391 |
+
}
|
392 |
+
|
393 |
+
logging.info(f"Calculated distance metrics: {metrics}")
|
394 |
+
|
395 |
+
return metrics
|
396 |
+
|
397 |
+
# ============================
|
398 |
+
# Classical Metrics Prediction
|
399 |
+
# ============================
|
400 |
+
|
401 |
+
def classical_prediction(file, model_choice, min_mmsi, max_mmsi, models, loaded_scalers):
|
402 |
+
"""
|
403 |
+
Preprocess the input CSV and make predictions using the selected model.
|
404 |
+
Calculate classical evaluation metrics and include inference time.
|
405 |
+
"""
|
406 |
+
try:
|
407 |
+
logging.info("Starting classical prediction...")
|
408 |
+
|
409 |
+
# Load the uploaded CSV file and filter based on MMSI
|
410 |
+
logging.info("Loading uploaded CSV file...")
|
411 |
+
df = pd.read_csv(file.name, delimiter=',')
|
412 |
+
logging.info(f"Uploaded CSV file loaded with {df.shape[0]} records.")
|
413 |
+
|
414 |
+
df = df[(df['mmsi'] >= min_mmsi) & (df['mmsi'] <= max_mmsi)]
|
415 |
+
if df.empty:
|
416 |
+
error_message = "No data available after applying MMSI filters."
|
417 |
+
logging.error(error_message)
|
418 |
+
return {"error": error_message}, None, None
|
419 |
+
|
420 |
+
# Check if 'time_decimal' exists
|
421 |
+
if 'time_decimal' not in df.columns:
|
422 |
+
df = add_time_decimal_feature(df)
|
423 |
+
else:
|
424 |
+
logging.info("'time_decimal' feature already exists. Skipping creation.")
|
425 |
+
|
426 |
+
expected_columns = [
|
427 |
+
"mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
|
428 |
+
"dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
|
429 |
+
"ship_type", "day", "month", "year", "time_decimal"
|
430 |
+
]
|
431 |
+
|
432 |
+
if list(df.columns) != expected_columns:
|
433 |
+
error_message = (
|
434 |
+
f"Input data does not have the correct columns.\n"
|
435 |
+
f"Expected columns: {expected_columns}\n"
|
436 |
+
f"Got columns: {list(df.columns)}"
|
437 |
+
)
|
438 |
+
logging.error(error_message)
|
439 |
+
return {"error": error_message}, None, None
|
440 |
+
|
441 |
+
logging.info("Input CSV has the correct columns.")
|
442 |
+
|
443 |
+
# Select the appropriate model and scaler
|
444 |
+
if model_choice == "Auto-Select":
|
445 |
+
temp_df = df.copy()
|
446 |
+
subarea = determine_subarea(temp_df)
|
447 |
+
selected_model, selected_model_name = select_model(models, subarea)
|
448 |
+
scaler = loaded_scalers[selected_model_name]
|
449 |
+
else:
|
450 |
+
if model_choice in models:
|
451 |
+
selected_model = models[model_choice]
|
452 |
+
selected_model_name = model_choice
|
453 |
+
scaler = loaded_scalers[selected_model_name]
|
454 |
+
else:
|
455 |
+
error_message = f"Selected model '{model_choice}' is not available."
|
456 |
+
logging.error(error_message)
|
457 |
+
return {"error": error_message}, None, None
|
458 |
+
|
459 |
+
logging.info(f"Using scaler for model: {selected_model_name}")
|
460 |
+
|
461 |
+
# Normalize the data
|
462 |
+
logging.info("Normalizing the data...")
|
463 |
+
features_to_scale = [
|
464 |
+
"mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
|
465 |
+
"dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
|
466 |
+
"ship_type", "day", "month", "year", "time_decimal"
|
467 |
+
]
|
468 |
+
X_new = df[features_to_scale]
|
469 |
+
X_scaled = scaler.transform(X_new)
|
470 |
+
df_scaled = pd.DataFrame(X_scaled, columns=features_to_scale, index=df.index)
|
471 |
+
df_scaled['original_mmsi'] = df['mmsi']
|
472 |
+
|
473 |
+
# Create sequences and get last known positions (scaled)
|
474 |
+
seq_len = 24
|
475 |
+
forecast_horizon = 1
|
476 |
+
X, y, mmsi_seq, last_known_positions_scaled = create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale)
|
477 |
+
|
478 |
+
if X.size == 0:
|
479 |
+
error_message = "Not enough data to create sequences."
|
480 |
+
logging.error(error_message)
|
481 |
+
return {"error": error_message}, None, None
|
482 |
+
|
483 |
+
logging.info(f"Created {X.shape[0]} sequences.")
|
484 |
+
|
485 |
+
# Inference
|
486 |
+
logging.info("Starting model inference...")
|
487 |
+
test_dataset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32))
|
488 |
+
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
|
489 |
+
all_predictions = []
|
490 |
+
all_y_true = []
|
491 |
+
|
492 |
+
start_time = time.time() # Start inference time tracking
|
493 |
+
|
494 |
+
with torch.no_grad():
|
495 |
+
for batch in test_loader:
|
496 |
+
X_batch, y_batch = batch
|
497 |
+
predictions = selected_model(X_batch).cpu().numpy()
|
498 |
+
all_predictions.append(predictions)
|
499 |
+
all_y_true.append(y_batch.numpy())
|
500 |
+
|
501 |
+
inference_time = time.time() - start_time # End inference time tracking
|
502 |
+
|
503 |
+
all_predictions = np.concatenate(all_predictions, axis=0)
|
504 |
+
y_true = np.concatenate(all_y_true, axis=0)
|
505 |
+
y_pred = all_predictions
|
506 |
+
|
507 |
+
logging.info(f"Inference completed in {inference_time:.2f} seconds.")
|
508 |
+
|
509 |
+
# Denormalize predictions and real values
|
510 |
+
lat_idx = features_to_scale.index("latitude_degrees")
|
511 |
+
lon_idx = features_to_scale.index("longitude_degrees")
|
512 |
+
pred_lat, pred_lon = denormalize(y_pred[:, :, 0], y_pred[:, :, 1], scaler, lat_idx, lon_idx)
|
513 |
+
true_lat, true_lon = denormalize(y_true[:, :, 0], y_true[:, :, 1], scaler, lat_idx, lon_idx)
|
514 |
+
|
515 |
+
# Denormalize last known positions
|
516 |
+
last_lat_scaled = np.array([pos[0] for pos in last_known_positions_scaled])
|
517 |
+
last_lon_scaled = np.array([pos[1] for pos in last_known_positions_scaled])
|
518 |
+
|
519 |
+
last_lat_denorm, last_lon_denorm = denormalize(
|
520 |
+
last_lat_scaled, last_lon_scaled, scaler, lat_idx, lon_idx
|
521 |
+
)
|
522 |
+
|
523 |
+
# Calculate the classic evaluation metrics
|
524 |
+
y_true_pairs = np.column_stack((true_lat.flatten(), true_lon.flatten()))
|
525 |
+
y_pred_pairs = np.column_stack((pred_lat.flatten(), pred_lon.flatten()))
|
526 |
+
classic_metrics = calculate_classic_metrics(y_true=y_true_pairs, y_pred=y_pred_pairs)
|
527 |
+
classic_metrics['Inference Time (seconds)'] = inference_time # Include inference time
|
528 |
+
|
529 |
+
# Prepare metrics and output CSV
|
530 |
+
metrics_df = pd.DataFrame([classic_metrics])
|
531 |
+
metrics_json = metrics_df.to_json(orient="records")
|
532 |
+
metrics_json = json.loads(metrics_json)[0]
|
533 |
+
|
534 |
+
# Prepare predicted and real positions DataFrame
|
535 |
+
predicted_df = pd.DataFrame({
|
536 |
+
'MMSI': mmsi_seq[:len(y_pred)].flatten(),
|
537 |
+
'Last Known Latitude': last_lat_denorm.flatten(),
|
538 |
+
'Last Known Longitude': last_lon_denorm.flatten(),
|
539 |
+
'Predicted Latitude': pred_lat.flatten(),
|
540 |
+
'Predicted Longitude': pred_lon.flatten(),
|
541 |
+
'Real Latitude': true_lat.flatten(),
|
542 |
+
'Real Longitude': true_lon.flatten()
|
543 |
+
})
|
544 |
+
|
545 |
+
# Save predictions as CSV
|
546 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_positions_file:
|
547 |
+
predicted_df.to_csv(tmp_positions_file, index=False)
|
548 |
+
positions_csv_path = tmp_positions_file.name
|
549 |
+
|
550 |
+
logging.info("Classical prediction completed.")
|
551 |
+
return metrics_json, positions_csv_path, inference_time
|
552 |
+
except Exception as e:
|
553 |
+
logging.error(f"An error occurred: {str(e)}")
|
554 |
+
return {"error": str(e)}, None, None
|
555 |
+
|
556 |
+
# ============================
|
557 |
+
# Abnormal Behavior Detection
|
558 |
+
# ============================
|
559 |
+
|
560 |
+
def abnormal_behavior_detection(prediction_file, alpha=0.5, threshold=10.0):
|
561 |
+
"""
|
562 |
+
Detect abnormal behavior based on angular divergence and distance difference.
|
563 |
+
Accepts a CSV file containing real and predicted positions.
|
564 |
+
"""
|
565 |
+
try:
|
566 |
+
logging.info("Starting abnormal behavior detection...")
|
567 |
+
|
568 |
+
# Load the CSV file containing real and predicted positions
|
569 |
+
logging.info("Loading prediction CSV file...")
|
570 |
+
df = pd.read_csv(prediction_file.name)
|
571 |
+
logging.info(f"Prediction CSV file loaded with {df.shape[0]} records.")
|
572 |
+
|
573 |
+
# Check if necessary columns exist
|
574 |
+
expected_columns = [
|
575 |
+
'MMSI', 'Last Known Latitude', 'Last Known Longitude',
|
576 |
+
'Predicted Latitude', 'Predicted Longitude',
|
577 |
+
'Real Latitude', 'Real Longitude'
|
578 |
+
]
|
579 |
+
|
580 |
+
if not all(col in df.columns for col in expected_columns):
|
581 |
+
error_message = (
|
582 |
+
f"Input data does not have the correct columns.\n"
|
583 |
+
f"Expected columns: {expected_columns}\n"
|
584 |
+
f"Got columns: {list(df.columns)}"
|
585 |
+
)
|
586 |
+
logging.error(error_message)
|
587 |
+
return {"error": error_message}
|
588 |
+
|
589 |
+
# Extract necessary data
|
590 |
+
mmsi_seq = df['MMSI'].values
|
591 |
+
last_lat_flat = df['Last Known Latitude'].values
|
592 |
+
last_lon_flat = df['Last Known Longitude'].values
|
593 |
+
pred_lat_flat = df['Predicted Latitude'].values
|
594 |
+
pred_lon_flat = df['Predicted Longitude'].values
|
595 |
+
true_lat_flat = df['Real Latitude'].values
|
596 |
+
true_lon_flat = df['Real Longitude'].values
|
597 |
+
|
598 |
+
# Calculate bearings
|
599 |
+
logging.info("Calculating bearings for predictions and real values...")
|
600 |
+
bearings_pred = [
|
601 |
+
calculate_bearing(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i])
|
602 |
+
for i in range(len(pred_lat_flat))
|
603 |
+
]
|
604 |
+
bearings_true = [
|
605 |
+
calculate_bearing(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i])
|
606 |
+
for i in range(len(true_lat_flat))
|
607 |
+
]
|
608 |
+
|
609 |
+
# Calculate angular divergence Δθ
|
610 |
+
logging.info("Calculating angular divergence (Δθ)...")
|
611 |
+
delta_theta = [
|
612 |
+
angular_divergence(bearings_pred[i], bearings_true[i])
|
613 |
+
for i in range(len(bearings_pred))
|
614 |
+
]
|
615 |
+
|
616 |
+
# Calculate distance difference Δd
|
617 |
+
logging.info("Calculating distance difference (Δd)...")
|
618 |
+
delta_d = [
|
619 |
+
haversine(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i]) -
|
620 |
+
haversine(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i])
|
621 |
+
for i in range(len(pred_lat_flat))
|
622 |
+
]
|
623 |
+
|
624 |
+
# Compute the score
|
625 |
+
logging.info("Computing the abnormal behavior score...")
|
626 |
+
score = [alpha * abs(dd) + (1 - alpha) * dt for dd, dt in zip(delta_d, delta_theta)]
|
627 |
+
|
628 |
+
# Determine abnormal behavior
|
629 |
+
logging.info("Determining abnormal behavior based on the score...")
|
630 |
+
abnormal_behavior = [1 if s >= threshold else 0 for s in score] # 1: Abnormal, 0: Normal
|
631 |
+
|
632 |
+
# Create DataFrame for saving
|
633 |
+
abnormal_behavior_df = pd.DataFrame({
|
634 |
+
'MMSI': mmsi_seq,
|
635 |
+
'Last Known Latitude': last_lat_flat,
|
636 |
+
'Last Known Longitude': last_lon_flat,
|
637 |
+
'Predicted Latitude': pred_lat_flat,
|
638 |
+
'Predicted Longitude': pred_lon_flat,
|
639 |
+
'Real Latitude': true_lat_flat,
|
640 |
+
'Real Longitude': true_lon_flat,
|
641 |
+
'Distance Difference (Δd) [km]': delta_d,
|
642 |
+
'Angular Divergence (Δθ) [degrees]': delta_theta,
|
643 |
+
'Score (αΔd + (1-α)Δθ)': score,
|
644 |
+
'Abnormal Behavior (1=Abnormal, 0=Normal)': abnormal_behavior
|
645 |
+
})
|
646 |
+
|
647 |
+
# Save abnormal behavior dataset as CSV
|
648 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_abnormal_file:
|
649 |
+
abnormal_behavior_df.to_csv(tmp_abnormal_file, index=False)
|
650 |
+
abnormal_csv_path = tmp_abnormal_file.name
|
651 |
+
|
652 |
+
logging.info("Abnormal behavior detection completed.")
|
653 |
+
return abnormal_csv_path
|
654 |
+
except Exception as e:
|
655 |
+
logging.error(f"An error occurred: {str(e)}")
|
656 |
+
return {"error": str(e)}
|
657 |
+
|
658 |
+
# ============================
|
659 |
+
# Define Gradio Interface
|
660 |
+
# ============================
|
661 |
+
|
662 |
+
def main():
|
663 |
+
# ============================
|
664 |
+
# Define Model and Scaler Paths
|
665 |
+
# ============================
|
666 |
+
|
667 |
+
model_paths = {
|
668 |
+
'teacher': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256/horizon_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_seq_24/run_1/best_model.pth',
|
669 |
+
'student_north': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North_seq_24/run_1/best_model.pth',
|
670 |
+
'student_mid': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid_seq_24/run_1/best_model.pth',
|
671 |
+
'student_south': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South_seq_24/run_1/best_model.pth'
|
672 |
+
}
|
673 |
+
|
674 |
+
scaler_paths = {
|
675 |
+
'Teacher': 'scaler_train_wholedata.joblib',
|
676 |
+
'Student_North': 'scaler_train_North.joblib',
|
677 |
+
'Student_Mid': 'scaler_train_Mid.joblib',
|
678 |
+
'Student_South': 'scaler_train_South.joblib'
|
679 |
+
}
|
680 |
+
|
681 |
+
# ============================
|
682 |
+
# Load Models and Scalers
|
683 |
+
# ============================
|
684 |
+
|
685 |
+
logging.info("Loading models and scalers...")
|
686 |
+
models = load_models(model_paths)
|
687 |
+
loaded_scalers = load_scalers(scaler_paths)
|
688 |
+
logging.info("All models and scalers loaded successfully.")
|
689 |
+
|
690 |
+
# Define the Gradio components for classical prediction tab
|
691 |
+
classical_tab = gr.Interface(
|
692 |
+
fn=lambda file, model_choice, min_mmsi, max_mmsi: classical_prediction(file, model_choice, min_mmsi, max_mmsi, models, loaded_scalers),
|
693 |
+
inputs=[
|
694 |
+
gr.File(label="Upload CSV File"),
|
695 |
+
gr.Dropdown(choices=["Auto-Select", "Teacher", "Student_North", "Student_Mid", "Student_South"], value="Auto-Select", label="Choose Model"),
|
696 |
+
gr.Number(label="Min MMSI", value=0),
|
697 |
+
gr.Number(label="Max MMSI", value=999999999)
|
698 |
+
],
|
699 |
+
outputs=[
|
700 |
+
gr.JSON(label="Classical Metrics (Degrees)"),
|
701 |
+
gr.File(label="Download Predicted & Real Positions CSV"),
|
702 |
+
gr.Number(label="Inference Time (seconds)")
|
703 |
+
],
|
704 |
+
title="Classical Prediction & Metrics",
|
705 |
+
description="Upload a CSV file and select a model to get classical evaluation metrics such as MAE, MSE, RMSE. The inference time is also provided."
|
706 |
+
)
|
707 |
+
|
708 |
+
# Define the Gradio components for abnormal behavior detection tab
|
709 |
+
abnormal_tab = gr.Interface(
|
710 |
+
fn=lambda prediction_file, alpha, threshold: abnormal_behavior_detection(prediction_file, alpha, threshold),
|
711 |
+
inputs=[
|
712 |
+
gr.File(label="Upload Predicted Positions CSV"),
|
713 |
+
gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Alpha (α)"),
|
714 |
+
gr.Number(label="Threshold", value=10.0)
|
715 |
+
],
|
716 |
+
outputs=[
|
717 |
+
gr.File(label="Download Abnormal Behavior CSV")
|
718 |
+
],
|
719 |
+
title="Abnormal Behavior Detection",
|
720 |
+
description=(
|
721 |
+
"Upload the CSV file containing real and predicted positions from the Classical Prediction tab. "
|
722 |
+
"Adjust the Alpha and Threshold parameters to compute abnormal behavior."
|
723 |
+
)
|
724 |
+
)
|
725 |
+
|
726 |
+
# Combine the two tabs using Gradio Tabs component
|
727 |
+
with gr.Blocks() as demo:
|
728 |
+
gr.Markdown("# Vessel Trajectory Prediction and Abnormal Behavior Detection")
|
729 |
+
with gr.Tabs():
|
730 |
+
with gr.TabItem("Classical Prediction"):
|
731 |
+
classical_tab.render()
|
732 |
+
with gr.TabItem("Abnormal Behavior Detection"):
|
733 |
+
abnormal_tab.render()
|
734 |
+
|
735 |
+
# Launch the Gradio interface
|
736 |
+
logging.info("Launching Gradio interface...")
|
737 |
+
demo.launch(share=True)
|
738 |
+
logging.info("Gradio interface launched successfully.")
|
739 |
+
|
740 |
+
# Run the app
|
741 |
+
if __name__ == "__main__":
|
742 |
+
main()
|