Delete MAGInet_final.ipynb
MAGInet_final.ipynb (DELETED, +0 -776)
@@ -1,776 +0,0 @@

The deleted notebook consisted of a single code cell (id e2789c26-d501-41f2-98e1-3ef2940e3f80, execution_count null, no saved outputs). Its Python source is reconstructed below with the notebook-JSON wrapper stripped; the notebook metadata is summarized at the end.
# Final version...
import torch
import torch.nn as nn
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
import os
import logging
import joblib
from tqdm import tqdm
import tempfile
import json
from math import radians, cos, sin, asin, sqrt, atan2, degrees
import time

# ============================
# Configure Logging
# ============================

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ============================
# Helper Functions
# ============================

def add_time_decimal_feature(df):
    """
    Add 'time_decimal' feature by combining 'hour' and 'minutes'.

    :param df: DataFrame with 'hour' and 'minutes' columns.
    :return: DataFrame with 'time_decimal' and without 'hour' and 'minutes'.
    """
    if 'hour' in df.columns and 'minutes' in df.columns:
        logging.info("Adding 'time_decimal' feature...")
        df['time_decimal'] = df['hour'] + df['minutes'] / 60.0
        df = df.drop(columns=['hour', 'minutes'])  # Drop 'hour' and 'minutes' after creation
        logging.info("'time_decimal' feature added.")
    else:
        logging.warning("'hour' and/or 'minutes' columns not found. Skipping 'time_decimal' feature addition.")
    return df

def haversine(lon1, lat1, lon2, lat2):
    """
    Calculate the great-circle distance between two points on the Earth.

    :param lon1: Longitude of point 1 (in decimal degrees)
    :param lat1: Latitude of point 1 (in decimal degrees)
    :param lon2: Longitude of point 2 (in decimal degrees)
    :param lat2: Latitude of point 2 (in decimal degrees)
    :return: Distance in kilometers
    """
    # Convert decimal degrees to radians
    lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(np.radians, [lon1, lat1, lon2, lat2])

    # Haversine formula
    dlon = lon2_rad - lon1_rad
    dlat = lat2_rad - lat1_rad
    a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2
    c = 2 * np.arcsin(np.sqrt(a))
    r = 6371  # Radius of Earth in kilometers
    return c * r
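# Illustrative aside (not in the original notebook): haversine() is written
# with NumPy ufuncs, so it accepts scalars or whole arrays. As a sanity check,
# one degree of latitude is about 111.19 km for r = 6371 km.
def _demo_haversine_check():
    print(haversine(0.0, 0.0, 0.0, 1.0))    # ~111.19 km per degree of latitude
    lons = np.array([0.0, 1.0, 2.0])
    lats = np.array([1.0, 0.0, 0.0])
    print(haversine(0.0, 0.0, lons, lats))  # vectorized: three distances at once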
def calculate_bearing(lon1, lat1, lon2, lat2):
    """
    Calculate the initial bearing between two points.

    :param lon1: Longitude of point 1 (in decimal degrees)
    :param lat1: Latitude of point 1 (in decimal degrees)
    :param lon2: Longitude of point 2 (in decimal degrees)
    :param lat2: Latitude of point 2 (in decimal degrees)
    :return: Bearing in degrees, measured clockwise from north, in [0, 360)
    """
    # Convert decimal degrees to radians
    lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(radians, [lon1, lat1, lon2, lat2])

    dlon = lon2_rad - lon1_rad
    x = sin(dlon) * cos(lat2_rad)
    y = cos(lat1_rad) * sin(lat2_rad) - (sin(lat1_rad) * cos(lat2_rad) * cos(dlon))

    # Standard initial-bearing formula:
    # atan2(sin(dlon)*cos(lat2), cos(lat1)*sin(lat2) - sin(lat1)*cos(lat2)*cos(dlon))
    initial_bearing = atan2(x, y)

    # Convert from radians to degrees and normalize
    initial_bearing = degrees(initial_bearing)
    compass_bearing = (initial_bearing + 360) % 360

    return compass_bearing

def angular_divergence(bearing1, bearing2):
    """
    Calculate the smallest angle difference between two bearings.

    :param bearing1: First bearing in degrees
    :param bearing2: Second bearing in degrees
    :return: Angular divergence in degrees, in [0, 180]
    """
    diff = abs(bearing1 - bearing2) % 360
    return min(diff, 360 - diff)
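# Illustrative aside (not in the original notebook): the modulo in
# angular_divergence() handles the 0/360 degree wraparound, so headings on
# either side of due north compare correctly.
def _demo_bearing_examples():
    print(calculate_bearing(0.0, 0.0, 0.0, 1.0))  # 0.0 (due north)
    print(angular_divergence(350.0, 10.0))        # 20.0, not 340.0
    print(angular_divergence(90.0, 270.0))        # 180.0 (opposite headings)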
def denormalize(scaled_lat, scaled_lon, scaler, lat_idx, lon_idx):
    """
    Denormalize latitude and longitude using the scaler's parameters.

    :param scaled_lat: Scaled latitude values (numpy array).
    :param scaled_lon: Scaled longitude values (numpy array).
    :param scaler: The scaler object used for normalization.
    :param lat_idx: Index of 'latitude_degrees' in the scaler's feature list.
    :param lon_idx: Index of 'longitude_degrees' in the scaler's feature list.
    :return: Tuple of (denormalized_lat, denormalized_lon).
    """
    lat_min = scaler.data_min_[lat_idx]
    lat_max = scaler.data_max_[lat_idx]
    lon_min = scaler.data_min_[lon_idx]
    lon_max = scaler.data_max_[lon_idx]

    denorm_lat = scaled_lat * (lat_max - lat_min) + lat_min
    denorm_lon = scaled_lon * (lon_max - lon_min) + lon_min
    return denorm_lat, denorm_lon
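# Illustrative aside (not in the original notebook): denormalize() inverts a
# MinMaxScaler column by column through its data_min_/data_max_ attributes,
# which assumes the scaler's default feature_range of (0, 1). A minimal
# round-trip check on synthetic [lat, lon] data:
def _demo_denormalize_roundtrip():
    from sklearn.preprocessing import MinMaxScaler
    coords = np.array([[40.0, -30.0], [50.0, -20.0], [60.0, -10.0]])  # [lat, lon] rows
    scaler = MinMaxScaler().fit(coords)
    scaled = scaler.transform(coords)
    lat, lon = denormalize(scaled[:, 0], scaled[:, 1], scaler, lat_idx=0, lon_idx=1)
    assert np.allclose(lat, coords[:, 0]) and np.allclose(lon, coords[:, 1])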
def create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale):
    """
    Create input and output sequences grouped by original MMSI.
    Returns scaled last known positions.
    """
    Xs, ys, mmsis = [], [], []
    last_known_positions_scaled = []

    grouped = df_scaled.groupby('original_mmsi')

    for mmsi, group in tqdm(grouped, desc="Creating sequences"):
        if len(group) >= seq_len + forecast_horizon:
            for i in range(len(group) - seq_len - forecast_horizon + 1):
                # Select scaled features for the sequence
                sequence = group.iloc[i:(i + seq_len)][features_to_scale].to_numpy()

                # Future positions to predict (scaled)
                future_positions = group[['latitude_degrees', 'longitude_degrees']].iloc[i + seq_len:i + seq_len + forecast_horizon].to_numpy()

                # Future hour feature
                future_hour = group[['time_decimal']].iloc[i + seq_len].values[0]
                future_hour_feature = np.full((seq_len, 1), future_hour)

                # Combine sequence with future_hour_feature
                sequence_with_future_hour = np.hstack((sequence, future_hour_feature))

                Xs.append(sequence_with_future_hour)
                ys.append(future_positions)
                mmsis.append(mmsi)

                # Store last known positions (scaled)
                last_lat_scaled = group['latitude_degrees'].iloc[i + seq_len - 1]
                last_lon_scaled = group['longitude_degrees'].iloc[i + seq_len - 1]
                last_known_positions_scaled.append((last_lat_scaled, last_lon_scaled))

    return np.array(Xs, dtype=np.float32), np.array(ys, dtype=np.float32), np.array(mmsis), last_known_positions_scaled
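# Illustrative aside (not in the original notebook): per track of length T the
# sliding window yields T - seq_len - forecast_horizon + 1 samples, and each
# input window gains the next step's time_decimal as a broadcast 15th column.
# A shape check on synthetic data (column names as in the notebook):
def _demo_sequence_shapes():
    demo_features = [
        "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
        "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
        "ship_type", "day", "month", "year", "time_decimal"
    ]
    rng = np.random.default_rng(0)
    df_demo = pd.DataFrame(rng.random((30, len(demo_features))), columns=demo_features)
    df_demo['original_mmsi'] = 123456789  # one synthetic track, 30 points
    X, y, _, _ = create_dataset_grouped_by_mmsi(df_demo, 24, 1, demo_features)
    print(X.shape, y.shape)  # (6, 24, 15) and (6, 1, 2): 30 - 24 - 1 + 1 = 6 windows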
# ============================
# Model Definitions
# ============================

class LSTMModelTeacher(nn.Module):
    def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=7, dropout=0.2):
        """
        Teacher LSTM Model.

        :param in_dim: Number of input features.
        :param hidden_dim: Number of hidden units.
        :param forecast_horizon: Number of future steps to predict.
        :param n_layers: Number of LSTM layers.
        :param dropout: Dropout rate.
        """
        super(LSTMModelTeacher, self).__init__()
        self.forecast_horizon = forecast_horizon  # Store as an instance attribute
        self.embedding = nn.Linear(in_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])  # Use the last timestep for prediction
        x = x.view(-1, self.forecast_horizon, 2)  # Shape: (batch_size, forecast_horizon, 2)
        return x

class LSTMModelStudent(nn.Module):
    def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=3, dropout=0.2):
        """
        Student LSTM Model.

        :param in_dim: Number of input features.
        :param hidden_dim: Number of hidden units.
        :param forecast_horizon: Number of future steps to predict.
        :param n_layers: Number of LSTM layers.
        :param dropout: Dropout rate.
        """
        super(LSTMModelStudent, self).__init__()
        self.forecast_horizon = forecast_horizon  # Store as an instance attribute
        self.embedding = nn.Linear(in_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])  # Use the last timestep for prediction
        x = x.view(-1, self.forecast_horizon, 2)  # Shape: (batch_size, forecast_horizon, 2)
        return x
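# Illustrative aside (not in the original notebook): both classes map a tensor
# of shape (batch, seq_len, in_dim) to (batch, forecast_horizon, 2) lat/lon
# predictions taken from the last LSTM timestep; they differ only in default
# depth (7 layers for the teacher vs 3 for the student). A smoke test with
# random, untrained weights:
def _demo_model_shapes():
    student = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1)
    student.eval()
    with torch.no_grad():
        out = student(torch.randn(32, 24, 15))  # (batch, seq_len, features)
    print(out.shape)  # torch.Size([32, 1, 2])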
# ============================
# Model Loading Functions
# ============================

def load_models(model_paths):
    """
    Load teacher and student models, including submodels for North, Mid, and South areas.

    :param model_paths: Dictionary containing paths to the models.
    :return: Dictionary of loaded models.
    """
    models = {}
    logging.info("Loading Teacher model...")
    # Load Teacher Model (Global)
    teacher = LSTMModelTeacher(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=7, dropout=0.2)  # 15 features including 'future_hour_feature'
    teacher.load_state_dict(torch.load(model_paths['teacher'], map_location=torch.device('cpu')))
    teacher.eval()
    models['Teacher'] = teacher
    logging.info("Teacher model loaded successfully.")

    logging.info("Loading Student North model...")
    # Load Student Models (Sub-areas)
    student_north = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_north.load_state_dict(torch.load(model_paths['student_north'], map_location=torch.device('cpu')))
    student_north.eval()
    models['Student_North'] = student_north
    logging.info("Student North model loaded successfully.")

    logging.info("Loading Student Mid model...")
    student_mid = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_mid.load_state_dict(torch.load(model_paths['student_mid'], map_location=torch.device('cpu')))
    student_mid.eval()
    models['Student_Mid'] = student_mid
    logging.info("Student Mid model loaded successfully.")

    logging.info("Loading Student South model...")
    student_south = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_south.load_state_dict(torch.load(model_paths['student_south'], map_location=torch.device('cpu')))
    student_south.eval()
    models['Student_South'] = student_south
    logging.info("Student South model loaded successfully.")

    return models

def load_scalers(scaler_paths):
    """
    Load scalers for each model.

    :param scaler_paths: Dictionary containing paths to the scaler files.
    :return: Dictionary of loaded scalers.
    """
    loaded_scalers = {}
    for model_name, scaler_path in scaler_paths.items():
        if os.path.exists(scaler_path):
            loaded_scalers[model_name] = joblib.load(scaler_path)
            logging.info(f"Loaded scaler for {model_name} from '{scaler_path}'.")
        else:
            logging.error(f"Scaler file for {model_name} not found at '{scaler_path}'.")
            raise FileNotFoundError(f"Scaler file for {model_name} not found at '{scaler_path}'. Please provide the correct path.")
    return loaded_scalers
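# Illustrative aside (not in the original notebook): load_scalers() expects
# joblib-serialized scaler objects, one per model, fitted on the matching
# training split. A sketch of how such a file would be produced (synthetic
# stand-in data; the notebook's real training code is not shown on this page):
def _demo_dump_scaler(path='scaler_train_wholedata.joblib'):
    from sklearn.preprocessing import MinMaxScaler
    synthetic_train = np.random.rand(100, 14)  # stand-in for the 14 input features
    joblib.dump(MinMaxScaler().fit(synthetic_train), path)
    return joblib.load(path)  # load_scalers() does exactly this per model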
# ============================
# Model Selection Logic
# ============================

def determine_subarea(df):
    """
    Determine the sub-area (North, Mid, South) based on latitude and longitude ranges.

    :param df: DataFrame containing 'latitude_degrees' and 'longitude_degrees'.
    :return: String indicating the sub-area.
    """
    # Define sub-area boundaries
    subareas = {
        'North': {'lat_min': 30, 'lat_max': 60, 'lon_min': -80, 'lon_max': -10},
        'Mid': {'lat_min': 0, 'lat_max': 30, 'lon_min': -80, 'lon_max': 10},
        'South': {'lat_min': -80, 'lat_max': 0, 'lon_min': -60, 'lon_max': 20}
    }

    # Count the number of data points in each sub-area
    counts = {}
    for area, bounds in subareas.items():
        count = df[
            (df['latitude_degrees'] >= bounds['lat_min']) & (df['latitude_degrees'] <= bounds['lat_max']) &
            (df['longitude_degrees'] >= bounds['lon_min']) & (df['longitude_degrees'] <= bounds['lon_max'])
        ].shape[0]
        counts[area] = count
        logging.info(f"Sub-area '{area}': {count} records.")

    # Determine the sub-area with the maximum count
    predominant_subarea = max(counts, key=counts.get)
    logging.info(f"Predominant sub-area determined: {predominant_subarea}")

    # If no data points fall into any sub-area, default to Teacher
    if counts[predominant_subarea] == 0:
        logging.warning("No data points found in any sub-area. Defaulting to Teacher model.")
        return 'Teacher'

    return predominant_subarea

def select_model(models, subarea):
    """
    Select the appropriate model based on the sub-area.

    :param models: Dictionary of loaded models.
    :param subarea: String indicating the sub-area.
    :return: Tuple of (selected_model, selected_model_name).
    """
    if subarea in ['North', 'Mid', 'South']:
        selected_model = models.get(f'Student_{subarea}')
        selected_model_name = f'Student_{subarea}'
        logging.info(f"Selected model: {selected_model_name}")
        return selected_model, selected_model_name
    else:
        selected_model = models.get('Teacher')
        selected_model_name = 'Teacher'
        logging.info(f"Selected model: {selected_model_name}")
        return selected_model, selected_model_name
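# Illustrative aside (not in the original notebook): auto-selection counts how
# many rows fall inside each sub-area's bounding box and routes to that student
# model, falling back to the teacher when every count is zero.
def _demo_subarea_routing():
    df_points = pd.DataFrame({
        'latitude_degrees':  [45.0, 48.0, 52.0, 10.0],     # three points in the North box
        'longitude_degrees': [-40.0, -30.0, -20.0, -50.0]  # one point in the Mid box
    })
    print(determine_subarea(df_points))  # 'North' (3 records vs 1)
    # select_model(models, 'North') would then return models['Student_North']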
# ============================
# Evaluation Metrics Calculation
# ============================

def calculate_classic_metrics(y_true, y_pred):
    """
    Calculate MAE, MSE, and RMSE directly on latitude/longitude pairs.

    :param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
    :param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
    :return: Dictionary containing the classic metrics.
    """
    # Calculate MAE
    mae = mean_absolute_error(y_true, y_pred)

    # Calculate MSE
    mse = mean_squared_error(y_true, y_pred)

    # Calculate RMSE
    rmse = np.sqrt(mse)

    classic_metrics = {
        'MAE (degrees)': mae,
        'MSE (degrees^2)': mse,
        'RMSE (degrees)': rmse
    }

    logging.info(f"Calculated classic metrics: {classic_metrics}")

    return classic_metrics

def calculate_distance_metrics(y_true, y_pred):
    """
    Calculate metrics based on distance (in kilometers).

    :param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
    :param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
    :return: Dictionary containing the distance-based metrics.
    """
    # Calculate haversine distance between predicted and true positions
    distances = np.array([
        haversine(y_true[i, 1], y_true[i, 0], y_pred[i, 1], y_pred[i, 0])
        for i in range(len(y_true))
    ])  # Assuming columns are [latitude, longitude]

    # Calculate MAE
    mae = np.mean(np.abs(distances))

    # Calculate MSE
    mse = np.mean(np.square(distances))

    # Calculate RMSE
    rmse = np.sqrt(mse)

    # Calculate RSE (Relative Squared Error)
    variance = np.var(distances)
    rse = mse / variance if variance != 0 else float('inf')

    metrics = {
        'MAE (km)': mae,
        'MSE (km^2)': mse,
        'RMSE (km)': rmse,
        'RSE': rse
    }

    logging.info(f"Calculated distance metrics: {metrics}")

    return metrics
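# Illustrative aside (not in the original notebook): the degree-based metrics
# treat lat/lon as plain numbers, while the distance-based ones first convert
# each (true, predicted) pair into a haversine error in kilometers, so the two
# sets of numbers are not directly comparable.
def _demo_metric_units():
    y_true = np.array([[45.0, -30.0], [46.0, -31.0]])  # [lat, lon] pairs
    y_pred = np.array([[45.1, -30.0], [46.0, -31.2]])
    print(calculate_classic_metrics(y_true, y_pred))   # errors in degrees
    print(calculate_distance_metrics(y_true, y_pred))  # errors in kilometers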
# ============================
# Classical Metrics Prediction
# ============================

def classical_prediction(file, model_choice, min_mmsi, max_mmsi, models, loaded_scalers):
    """
    Preprocess the input CSV and make predictions using the selected model.
    Calculate classical evaluation metrics and include inference time.
    """
    try:
        logging.info("Starting classical prediction...")

        # Load the uploaded CSV file and filter based on MMSI
        logging.info("Loading uploaded CSV file...")
        df = pd.read_csv(file.name, delimiter=',')
        logging.info(f"Uploaded CSV file loaded with {df.shape[0]} records.")

        df = df[(df['mmsi'] >= min_mmsi) & (df['mmsi'] <= max_mmsi)]
        if df.empty:
            error_message = "No data available after applying MMSI filters."
            logging.error(error_message)
            return {"error": error_message}, None, None

        # Check if 'time_decimal' exists
        if 'time_decimal' not in df.columns:
            df = add_time_decimal_feature(df)
        else:
            logging.info("'time_decimal' feature already exists. Skipping creation.")

        expected_columns = [
            "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
            "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
            "ship_type", "day", "month", "year", "time_decimal"
        ]

        if list(df.columns) != expected_columns:
            error_message = (
                f"Input data does not have the correct columns.\n"
                f"Expected columns: {expected_columns}\n"
                f"Got columns: {list(df.columns)}"
            )
            logging.error(error_message)
            return {"error": error_message}, None, None

        logging.info("Input CSV has the correct columns.")

        # Select the appropriate model and scaler
        if model_choice == "Auto-Select":
            temp_df = df.copy()
            subarea = determine_subarea(temp_df)
            selected_model, selected_model_name = select_model(models, subarea)
            scaler = loaded_scalers[selected_model_name]
        else:
            if model_choice in models:
                selected_model = models[model_choice]
                selected_model_name = model_choice
                scaler = loaded_scalers[selected_model_name]
            else:
                error_message = f"Selected model '{model_choice}' is not available."
                logging.error(error_message)
                return {"error": error_message}, None, None

        logging.info(f"Using scaler for model: {selected_model_name}")

        # Normalize the data
        logging.info("Normalizing the data...")
        features_to_scale = [
            "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
            "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
            "ship_type", "day", "month", "year", "time_decimal"
        ]
        X_new = df[features_to_scale]
        X_scaled = scaler.transform(X_new)
        df_scaled = pd.DataFrame(X_scaled, columns=features_to_scale, index=df.index)
        df_scaled['original_mmsi'] = df['mmsi']

        # Create sequences and get last known positions (scaled)
        seq_len = 24
        forecast_horizon = 1
        X, y, mmsi_seq, last_known_positions_scaled = create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale)

        if X.size == 0:
            error_message = "Not enough data to create sequences."
            logging.error(error_message)
            return {"error": error_message}, None, None

        logging.info(f"Created {X.shape[0]} sequences.")

        # Inference
        logging.info("Starting model inference...")
        test_dataset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
        all_predictions = []
        all_y_true = []

        start_time = time.time()  # Start inference time tracking

        with torch.no_grad():
            for batch in test_loader:
                X_batch, y_batch = batch
                predictions = selected_model(X_batch).cpu().numpy()
                all_predictions.append(predictions)
                all_y_true.append(y_batch.numpy())

        inference_time = time.time() - start_time  # End inference time tracking

        all_predictions = np.concatenate(all_predictions, axis=0)
        y_true = np.concatenate(all_y_true, axis=0)
        y_pred = all_predictions

        logging.info(f"Inference completed in {inference_time:.2f} seconds.")

        # Denormalize predictions and real values
        lat_idx = features_to_scale.index("latitude_degrees")
        lon_idx = features_to_scale.index("longitude_degrees")
        pred_lat, pred_lon = denormalize(y_pred[:, :, 0], y_pred[:, :, 1], scaler, lat_idx, lon_idx)
        true_lat, true_lon = denormalize(y_true[:, :, 0], y_true[:, :, 1], scaler, lat_idx, lon_idx)

        # Denormalize last known positions
        last_lat_scaled = np.array([pos[0] for pos in last_known_positions_scaled])
        last_lon_scaled = np.array([pos[1] for pos in last_known_positions_scaled])

        last_lat_denorm, last_lon_denorm = denormalize(
            last_lat_scaled, last_lon_scaled, scaler, lat_idx, lon_idx
        )

        # Calculate the classic evaluation metrics
        y_true_pairs = np.column_stack((true_lat.flatten(), true_lon.flatten()))
        y_pred_pairs = np.column_stack((pred_lat.flatten(), pred_lon.flatten()))
        classic_metrics = calculate_classic_metrics(y_true=y_true_pairs, y_pred=y_pred_pairs)
        classic_metrics['Inference Time (seconds)'] = inference_time  # Include inference time

        # Prepare metrics and output CSV
        metrics_df = pd.DataFrame([classic_metrics])
        metrics_json = metrics_df.to_json(orient="records")
        metrics_json = json.loads(metrics_json)[0]

        # Prepare predicted and real positions DataFrame
        predicted_df = pd.DataFrame({
            'MMSI': mmsi_seq[:len(y_pred)].flatten(),
            'Last Known Latitude': last_lat_denorm.flatten(),
            'Last Known Longitude': last_lon_denorm.flatten(),
            'Predicted Latitude': pred_lat.flatten(),
            'Predicted Longitude': pred_lon.flatten(),
            'Real Latitude': true_lat.flatten(),
            'Real Longitude': true_lon.flatten()
        })

        # Save predictions as CSV
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_positions_file:
            predicted_df.to_csv(tmp_positions_file, index=False)
            positions_csv_path = tmp_positions_file.name

        logging.info("Classical prediction completed.")
        return metrics_json, positions_csv_path, inference_time
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return {"error": str(e)}, None, None
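# Illustrative aside (not in the original notebook): classical_prediction()
# rejects any CSV whose columns differ from expected_columns; 'hour'/'minutes'
# in place of 'time_decimal' are combined by add_time_decimal_feature() first.
# A synthetic file that would pass the column check:
def _demo_make_sample_input(path='sample_input.csv'):
    columns = [
        "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
        "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
        "ship_type", "day", "month", "year", "time_decimal"
    ]
    rows = 30  # needs at least seq_len + forecast_horizon = 25 rows per MMSI
    df = pd.DataFrame(np.random.rand(rows, len(columns)), columns=columns)
    df['mmsi'] = 123456789
    df.to_csv(path, index=False)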
# ============================
# Abnormal Behavior Detection
# ============================

def abnormal_behavior_detection(prediction_file, alpha=0.5, threshold=10.0):
    """
    Detect abnormal behavior based on angular divergence and distance difference.
    Accepts a CSV file containing real and predicted positions.
    """
    try:
        logging.info("Starting abnormal behavior detection...")

        # Load the CSV file containing real and predicted positions
        logging.info("Loading prediction CSV file...")
        df = pd.read_csv(prediction_file.name)
        logging.info(f"Prediction CSV file loaded with {df.shape[0]} records.")

        # Check if necessary columns exist
        expected_columns = [
            'MMSI', 'Last Known Latitude', 'Last Known Longitude',
            'Predicted Latitude', 'Predicted Longitude',
            'Real Latitude', 'Real Longitude'
        ]

        if not all(col in df.columns for col in expected_columns):
            error_message = (
                f"Input data does not have the correct columns.\n"
                f"Expected columns: {expected_columns}\n"
                f"Got columns: {list(df.columns)}"
            )
            logging.error(error_message)
            return {"error": error_message}

        # Extract necessary data
        mmsi_seq = df['MMSI'].values
        last_lat_flat = df['Last Known Latitude'].values
        last_lon_flat = df['Last Known Longitude'].values
        pred_lat_flat = df['Predicted Latitude'].values
        pred_lon_flat = df['Predicted Longitude'].values
        true_lat_flat = df['Real Latitude'].values
        true_lon_flat = df['Real Longitude'].values

        # Calculate bearings
        logging.info("Calculating bearings for predictions and real values...")
        bearings_pred = [
            calculate_bearing(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i])
            for i in range(len(pred_lat_flat))
        ]
        bearings_true = [
            calculate_bearing(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i])
            for i in range(len(true_lat_flat))
        ]

        # Calculate angular divergence Δθ
        logging.info("Calculating angular divergence (Δθ)...")
        delta_theta = [
            angular_divergence(bearings_pred[i], bearings_true[i])
            for i in range(len(bearings_pred))
        ]

        # Calculate distance difference Δd
        logging.info("Calculating distance difference (Δd)...")
        delta_d = [
            haversine(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i]) -
            haversine(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i])
            for i in range(len(pred_lat_flat))
        ]

        # Compute the score
        logging.info("Computing the abnormal behavior score...")
        score = [alpha * abs(dd) + (1 - alpha) * dt for dd, dt in zip(delta_d, delta_theta)]

        # Determine abnormal behavior
        logging.info("Determining abnormal behavior based on the score...")
        abnormal_behavior = [1 if s >= threshold else 0 for s in score]  # 1: Abnormal, 0: Normal

        # Create DataFrame for saving
        abnormal_behavior_df = pd.DataFrame({
            'MMSI': mmsi_seq,
            'Last Known Latitude': last_lat_flat,
            'Last Known Longitude': last_lon_flat,
            'Predicted Latitude': pred_lat_flat,
            'Predicted Longitude': pred_lon_flat,
            'Real Latitude': true_lat_flat,
            'Real Longitude': true_lon_flat,
            'Distance Difference (Δd) [km]': delta_d,
            'Angular Divergence (Δθ) [degrees]': delta_theta,
            'Score (αΔd + (1-α)Δθ)': score,
            'Abnormal Behavior (1=Abnormal, 0=Normal)': abnormal_behavior
        })

        # Save abnormal behavior dataset as CSV
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_abnormal_file:
            abnormal_behavior_df.to_csv(tmp_abnormal_file, index=False)
            abnormal_csv_path = tmp_abnormal_file.name

        logging.info("Abnormal behavior detection completed.")
        return abnormal_csv_path
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return {"error": str(e)}
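# Illustrative aside (not in the original notebook): the flagging rule is
# score = alpha * |Δd| + (1 - alpha) * Δθ, where Δd is the kilometer gap and
# Δθ the heading gap between prediction and reality, both measured from the
# last known position. A worked line with made-up numbers:
def _demo_abnormal_score():
    alpha, threshold = 0.5, 10.0
    delta_d, delta_theta = 4.0, 30.0  # 4 km distance gap, 30 degree heading gap
    score = alpha * abs(delta_d) + (1 - alpha) * delta_theta
    print(score, score >= threshold)  # 17.0 True -> flagged as abnormal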
# ============================
# Define Gradio Interface
# ============================

def main():
    # ============================
    # Define Model and Scaler Paths
    # ============================

    model_paths = {
        'teacher': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256/horizon_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_seq_24/run_1/best_model.pth',
        'student_north': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North_seq_24/run_1/best_model.pth',
        'student_mid': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid_seq_24/run_1/best_model.pth',
        'student_south': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South_seq_24/run_1/best_model.pth'
    }

    scaler_paths = {
        'Teacher': 'scaler_train_wholedata.joblib',
        'Student_North': 'scaler_train_North.joblib',
        'Student_Mid': 'scaler_train_Mid.joblib',
        'Student_South': 'scaler_train_South.joblib'
    }

    # ============================
    # Load Models and Scalers
    # ============================

    logging.info("Loading models and scalers...")
    models = load_models(model_paths)
    loaded_scalers = load_scalers(scaler_paths)
    logging.info("All models and scalers loaded successfully.")

    # Define the Gradio components for the classical prediction tab
    classical_tab = gr.Interface(
        fn=lambda file, model_choice, min_mmsi, max_mmsi: classical_prediction(file, model_choice, min_mmsi, max_mmsi, models, loaded_scalers),
        inputs=[
            gr.File(label="Upload CSV File"),
            gr.Dropdown(choices=["Auto-Select", "Teacher", "Student_North", "Student_Mid", "Student_South"], value="Auto-Select", label="Choose Model"),
            gr.Number(label="Min MMSI", value=0),
            gr.Number(label="Max MMSI", value=999999999)
        ],
        outputs=[
            gr.JSON(label="Classical Metrics (Degrees)"),
            gr.File(label="Download Predicted & Real Positions CSV"),
            gr.Number(label="Inference Time (seconds)")
        ],
        title="Classical Prediction & Metrics",
        description="Upload a CSV file and select a model to get classical evaluation metrics such as MAE, MSE, RMSE. The inference time is also provided."
    )

    # Define the Gradio components for the abnormal behavior detection tab
    abnormal_tab = gr.Interface(
        fn=lambda prediction_file, alpha, threshold: abnormal_behavior_detection(prediction_file, alpha, threshold),
        inputs=[
            gr.File(label="Upload Predicted Positions CSV"),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Alpha (α)"),
            gr.Number(label="Threshold", value=10.0)
        ],
        outputs=[
            gr.File(label="Download Abnormal Behavior CSV")
        ],
        title="Abnormal Behavior Detection",
        description=(
            "Upload the CSV file containing real and predicted positions from the Classical Prediction tab. "
            "Adjust the Alpha and Threshold parameters to compute abnormal behavior."
        )
    )

    # Combine the two tabs using the Gradio Tabs component
    with gr.Blocks() as demo:
        gr.Markdown("# Vessel Trajectory Prediction and Abnormal Behavior Detection")
        with gr.Tabs():
            with gr.TabItem("Classical Prediction"):
                classical_tab.render()
            with gr.TabItem("Abnormal Behavior Detection"):
                abnormal_tab.render()

    # Launch the Gradio interface
    logging.info("Launching Gradio interface...")
    demo.launch(share=True)
    logging.info("Gradio interface launched successfully.")

# Run the app
if __name__ == "__main__":
    main()
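# Illustrative aside (not in the original notebook): launch(share=True)
# additionally requests a temporary public *.gradio.live URL. For a local-only
# run, one would instead call, inside main():
#     demo.launch(server_name="127.0.0.1", server_port=7860, share=False)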
Notebook metadata (from the deleted JSON): kernelspec "Python 3 (ipykernel)"; language_info: Python 3.11.8, codemirror_mode ipython (v3), pygments_lexer ipython3, nbconvert_exporter python, file_extension .py; nbformat 4, nbformat_minor 5.