# Final version...
import torch
import torch.nn as nn
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
import os
import logging
import joblib
from tqdm import tqdm
import tempfile
import json
from math import radians, cos, sin, atan2, degrees
import time
import functools

# ============================
# Configure Logging
# ============================

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ============================
# Helper Functions
# ============================

def add_time_decimal_feature(df):
    """
    Add 'time_decimal' feature by combining 'hour' and 'minutes'.

    :param df: DataFrame with 'hour' and 'minutes' columns.
    :return: DataFrame with 'time_decimal' and without 'hour' and 'minutes'.
    """
    if 'hour' in df.columns and 'minutes' in df.columns:
        logging.info("Adding 'time_decimal' feature...")
        df['time_decimal'] = df['hour'] + df['minutes'] / 60.0
        df = df.drop(columns=['hour', 'minutes'])  # Drop 'hour' and 'minutes' after creation
        logging.info("'time_decimal' feature added.")
    else:
        logging.warning("'hour' and/or 'minutes' columns not found. Skipping 'time_decimal' feature addition.")
    return df
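
# Example (hypothetical values): hour=14, minutes=30 becomes time_decimal=14.5:
#   add_time_decimal_feature(pd.DataFrame({'hour': [14], 'minutes': [30]}))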

def haversine(lon1, lat1, lon2, lat2):
    """
    Calculate the great-circle distance between two points on the Earth.

    :param lon1: Longitude of point 1 (in decimal degrees)
    :param lat1: Latitude of point 1 (in decimal degrees)
    :param lon2: Longitude of point 2 (in decimal degrees)
    :param lat2: Latitude of point 2 (in decimal degrees)
    :return: Distance in kilometers
    """
    # Convert decimal degrees to radians
    lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(np.radians, [lon1, lat1, lon2, lat2])

    # Haversine formula
    dlon = lon2_rad - lon1_rad 
    dlat = lat2_rad - lat1_rad 
    a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2
    c = 2 * np.arcsin(np.sqrt(a)) 
    r = 6371  # Radius of Earth in kilometers
    return c * r
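
# Quick sanity check (approximate; assumes the 6,371 km Earth radius above):
#   haversine(-0.1278, 51.5074, 2.3522, 48.8566)  # London -> Paris, ~344 km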

def calculate_bearing(lon1, lat1, lon2, lat2):
    """
    Calculate the bearing between two points.

    :param lon1: Longitude of point 1 (in decimal degrees)
    :param lat1: Latitude of point 1 (in decimal degrees)
    :param lon2: Longitude of point 2 (in decimal degrees)
    :param lat2: Latitude of point 2 (in decimal degrees)
    :return: Bearing in degrees
    """
    # Convert decimal degrees to radians
    lon1_rad, lat1_rad, lon2_rad, lat2_rad = map(radians, [lon1, lat1, lon2, lat2])

    dlon = lon2_rad - lon1_rad
    x = sin(dlon) * cos(lat2_rad)
    y = cos(lat1_rad) * sin(lat2_rad) - (sin(lat1_rad) * cos(lat2_rad) * cos(dlon))

    initial_bearing = atan2(x, y)

    # Convert from radians to degrees and normalize
    initial_bearing = degrees(initial_bearing)
    compass_bearing = (initial_bearing + 360) % 360

    return compass_bearing
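
# Example: a point due east of the reference yields a bearing of 90 degrees:
#   calculate_bearing(0.0, 0.0, 1.0, 0.0)  # -> 90.0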

def angular_divergence(bearing1, bearing2):
    """
    Calculate the smallest angle difference between two bearings.

    :param bearing1: First bearing in degrees
    :param bearing2: Second bearing in degrees
    :return: Angular divergence in degrees
    """
    diff = abs(bearing1 - bearing2) % 360
    return min(diff, 360 - diff)
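
# Example: bearings of 350 and 10 degrees are only 20 degrees apart, not 340:
#   angular_divergence(350, 10)  # -> 20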

def denormalize(scaled_lat, scaled_lon, scaler, lat_idx, lon_idx):
    """
    Denormalize latitude and longitude using the scaler's parameters.

    :param scaled_lat: Scaled latitude values (numpy array).
    :param scaled_lon: Scaled longitude values (numpy array).
    :param scaler: The scaler object used for normalization.
    :param lat_idx: Index of 'latitude_degrees' in the scaler's feature list.
    :param lon_idx: Index of 'longitude_degrees' in the scaler's feature list.
    :return: Tuple of (denormalized_lat, denormalized_lon).
    """
    lat_min = scaler.data_min_[lat_idx]
    lat_max = scaler.data_max_[lat_idx]
    lon_min = scaler.data_min_[lon_idx]
    lon_max = scaler.data_max_[lon_idx]

    denorm_lat = scaled_lat * (lat_max - lat_min) + lat_min
    denorm_lon = scaled_lon * (lon_max - lon_min) + lon_min
    return denorm_lat, denorm_lon
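
# This inverts a MinMaxScaler fitted with the default feature_range=(0, 1):
# e.g. if the scaler saw latitudes in [-80, 60], a scaled value of 0.5 maps
# back to 0.5 * (60 - (-80)) + (-80) = -10 degrees.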

def create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale):
    """
    Create input and output sequences grouped by original MMSI.
    Returns scaled last known positions.
    """
    Xs, ys, mmsis = [], [], []
    last_known_positions_scaled = []

    grouped = df_scaled.groupby('original_mmsi')

    for mmsi, group in tqdm(grouped, desc="Creating sequences"):
        if len(group) >= seq_len + forecast_horizon:
            for i in range(len(group) - seq_len - forecast_horizon + 1):
                # Select scaled features for the sequence
                sequence = group.iloc[i:(i + seq_len)][features_to_scale].to_numpy()

                # Future positions to predict (scaled)
                future_positions = group[['latitude_degrees', 'longitude_degrees']].iloc[i + seq_len:i + seq_len + forecast_horizon].to_numpy()

                # Future hour feature
                future_hour = group['time_decimal'].iloc[i + seq_len]
                future_hour_feature = np.full((seq_len, 1), future_hour)

                # Combine sequence with future_hour_feature
                sequence_with_future_hour = np.hstack((sequence, future_hour_feature))

                Xs.append(sequence_with_future_hour)
                ys.append(future_positions)
                mmsis.append(mmsi)

                # Store last known positions (scaled)
                last_lat_scaled = group['latitude_degrees'].iloc[i + seq_len - 1]
                last_lon_scaled = group['longitude_degrees'].iloc[i + seq_len - 1]
                last_known_positions_scaled.append((last_lat_scaled, last_lon_scaled))

    return np.array(Xs, dtype=np.float32), np.array(ys, dtype=np.float32), np.array(mmsis), last_known_positions_scaled
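
# Shape note: with seq_len=24, forecast_horizon=1 and the 14 scaled features
# used below, X has shape (num_sequences, 24, 15) -- the 14 features plus the
# appended future-hour column -- and y has shape (num_sequences, 1, 2).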

# ============================
# Model Definitions
# ============================

class LSTMModelTeacher(nn.Module):
    def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=7, dropout=0.2):
        """
        Teacher LSTM Model.

        :param in_dim: Number of input features.
        :param hidden_dim: Number of hidden units.
        :param forecast_horizon: Number of future steps to predict.
        :param n_layers: Number of LSTM layers.
        :param dropout: Dropout rate.
        """
        super(LSTMModelTeacher, self).__init__()
        self.forecast_horizon = forecast_horizon  # Store as an instance attribute
        self.embedding = nn.Linear(in_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])  # Use the last timestep for prediction
        x = x.view(-1, self.forecast_horizon, 2)  # Shape: (batch_size, forecast_horizon, 2)
        return x

class LSTMModelStudent(nn.Module):
    def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=3, dropout=0.2):
        """
        Student LSTM Model.

        :param in_dim: Number of input features.
        :param hidden_dim: Number of hidden units.
        :param forecast_horizon: Number of future steps to predict.
        :param n_layers: Number of LSTM layers.
        :param dropout: Dropout rate.
        """
        super(LSTMModelStudent, self).__init__()
        self.forecast_horizon = forecast_horizon  # Store as an instance attribute
        self.embedding = nn.Linear(in_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])  # Use the last timestep for prediction
        x = x.view(-1, self.forecast_horizon, 2)  # Shape: (batch_size, forecast_horizon, 2)
        return x
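
# Quick shape check (a minimal sketch; dimensions mirror load_models() below):
#   model = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1)
#   model(torch.randn(8, 24, 15)).shape  # -> torch.Size([8, 1, 2])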

# ============================
# Model Loading Functions
# ============================

def load_models(model_paths):
    """
    Load teacher and student models, including submodels for North, Mid, and South areas.

    :param model_paths: Dictionary containing paths to the models.
    :return: Dictionary of loaded models.
    """
    models = {}
    logging.info("Loading Teacher model...")
    # Load Teacher Model (Global)
    teacher = LSTMModelTeacher(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=7, dropout=0.2)  # 15 features including 'future_hour_feature'
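    # Note: on recent PyTorch versions, torch.load(..., weights_only=True) is a
    # safer default when loading bare state_dicts such as these.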
    teacher.load_state_dict(torch.load(model_paths['teacher'], map_location=torch.device('cpu')))
    teacher.eval()
    models['Teacher'] = teacher
    logging.info("Teacher model loaded successfully.")

    logging.info("Loading Student North model...")
    # Load Student Models (Sub-areas)
    student_north = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_north.load_state_dict(torch.load(model_paths['student_north'], map_location=torch.device('cpu')))
    student_north.eval()
    models['Student_North'] = student_north
    logging.info("Student North model loaded successfully.")

    logging.info("Loading Student Mid model...")
    student_mid = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_mid.load_state_dict(torch.load(model_paths['student_mid'], map_location=torch.device('cpu')))
    student_mid.eval()
    models['Student_Mid'] = student_mid
    logging.info("Student Mid model loaded successfully.")

    logging.info("Loading Student South model...")
    student_south = LSTMModelStudent(in_dim=15, hidden_dim=200, forecast_horizon=1, n_layers=3, dropout=0.2)
    student_south.load_state_dict(torch.load(model_paths['student_south'], map_location=torch.device('cpu')))
    student_south.eval()
    models['Student_South'] = student_south
    logging.info("Student South model loaded successfully.")

    return models

def load_scalers(scaler_paths):
    """
    Load scalers for each model.

    :param scaler_paths: Dictionary containing paths to the scaler files.
    :return: Dictionary of loaded scalers.
    """
    loaded_scalers = {}
    for model_name, scaler_path in scaler_paths.items():
        if os.path.exists(scaler_path):
            loaded_scalers[model_name] = joblib.load(scaler_path)
            logging.info(f"Loaded scaler for {model_name} from '{scaler_path}'.")
        else:
            logging.error(f"Scaler file for {model_name} not found at '{scaler_path}'.")
            raise FileNotFoundError(f"Scaler file for {model_name} not found at '{scaler_path}'. Please provide the correct path.")
    return loaded_scalers

# ============================
# Model Selection Logic
# ============================

def determine_subarea(df):
    """
    Determine the sub-area (North, Mid, South) based on latitude and longitude ranges.

    :param df: DataFrame containing 'latitude_degrees' and 'longitude_degrees'.
    :return: String indicating the sub-area.
    """
    # Define sub-area boundaries
    subareas = {
        'North': {'lat_min': 30, 'lat_max': 60, 'lon_min': -80, 'lon_max': -10},
        'Mid': {'lat_min': 0, 'lat_max': 30, 'lon_min': -80, 'lon_max': 10},
        'South': {'lat_min': -80, 'lat_max': 0, 'lon_min': -60, 'lon_max': 20}
    }

    # Count the number of data points in each sub-area
    counts = {}
    for area, bounds in subareas.items():
        count = df[
            (df['latitude_degrees'] >= bounds['lat_min']) & (df['latitude_degrees'] <= bounds['lat_max']) &
            (df['longitude_degrees'] >= bounds['lon_min']) & (df['longitude_degrees'] <= bounds['lon_max'])
        ].shape[0]
        counts[area] = count
        logging.info(f"Sub-area '{area}': {count} records.")

    # Determine the sub-area with the maximum count
    predominant_subarea = max(counts, key=counts.get)
    logging.info(f"Predominant sub-area determined: {predominant_subarea}")

    # If no data points fall into any sub-area, default to Teacher
    if counts[predominant_subarea] == 0:
        logging.warning("No data points found in any sub-area. Defaulting to Teacher model.")
        return 'Teacher'

    return predominant_subarea
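
# Example: a track whose points cluster around (45N, 30W) falls inside the
# 'North' box above, so determine_subarea() returns 'North'.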

def select_model(models, subarea):
    """
    Select the appropriate model based on the sub-area.

    :param models: Dictionary of loaded models.
    :param subarea: String indicating the sub-area.
    :return: Tuple of (selected_model, selected_model_name).
    """
    if subarea in ['North', 'Mid', 'South']:
        selected_model = models.get(f'Student_{subarea}')
        selected_model_name = f'Student_{subarea}'
        logging.info(f"Selected model: {selected_model_name}")
        return selected_model, selected_model_name
    else:
        selected_model = models.get('Teacher')
        selected_model_name = 'Teacher'
        logging.info(f"Selected model: {selected_model_name}")
        return selected_model, selected_model_name

# ============================
# Evaluation Metrics Calculation
# ============================

def calculate_classic_metrics(y_true, y_pred):
    """
    Calculate MAE, MSE, and RMSE directly on latitude/longitude pairs.

    :param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
    :param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
    :return: Dictionary containing the classic metrics.
    """
    # Calculate MAE
    mae = mean_absolute_error(y_true, y_pred)

    # Calculate MSE
    mse = mean_squared_error(y_true, y_pred)

    # Calculate RMSE
    rmse = np.sqrt(mse)

    classic_metrics = {
        'MAE (degrees)': mae,
        'MSE (degrees^2)': mse,
        'RMSE (degrees)': rmse
    }

    logging.info(f"Calculated classic metrics: {classic_metrics}")
    
    return classic_metrics

def calculate_distance_metrics(y_true, y_pred):
    """
    Calculate metrics based on distance (in kilometers).

    :param y_true: Ground truth positions (numpy array of shape (num_samples, 2)).
    :param y_pred: Predicted positions (numpy array of shape (num_samples, 2)).
    :return: Dictionary containing the distance-based metrics.
    """
    # Calculate haversine distance between predicted and true positions
    distances = np.array([
        haversine(y_true[i, 1], y_true[i, 0], y_pred[i, 1], y_pred[i, 0]) 
        for i in range(len(y_true))
    ])  # Assuming columns are [latitude, longitude]

    # Calculate MAE
    mae = np.mean(np.abs(distances))

    # Calculate MSE
    mse = np.mean(np.square(distances))

    # Calculate RMSE
    rmse = np.sqrt(mse)

    # Calculate RSE (Relative Squared Error)
    variance = np.var(distances)
    rse = mse / variance if variance != 0 else float('inf')

    metrics = {
        'MAE (km)': mae,
        'MSE (km^2)': mse,
        'RMSE (km)': rmse,
        'RSE': rse
    }

    logging.info(f"Calculated distance metrics: {metrics}")
    
    return metrics

# ============================
# Classical Metrics Prediction
# ============================

def classical_prediction(file_path, model_choice, min_mmsi, max_mmsi, models, loaded_scalers):
    """
    Preprocess the input CSV and make predictions using the selected model.
    Calculate classical evaluation metrics and include inference time.
    """
    try:
        logging.info("Starting classical prediction...")

        # Load the uploaded CSV file and filter based on MMSI
        logging.info("Loading uploaded CSV file...")
        df = pd.read_csv(file_path, delimiter=',')
        logging.info(f"Uploaded CSV file loaded with {df.shape[0]} records.")
        
        df = df[(df['mmsi'] >= min_mmsi) & (df['mmsi'] <= max_mmsi)]
        if df.empty:
            error_message = "No data available after applying MMSI filters."
            logging.error(error_message)
            return {"error": error_message}, None, None, error_message  # four values, matching the Gradio outputs

        # Check if 'time_decimal' exists
        if 'time_decimal' not in df.columns:
            df = add_time_decimal_feature(df)
        else:
            logging.info("'time_decimal' feature already exists. Skipping creation.")

        expected_columns = [
            "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
            "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
            "ship_type", "day", "month", "year", "time_decimal"
        ]

        if list(df.columns) != expected_columns:
            error_message = (
                f"Input data does not have the correct columns.\n"
                f"Expected columns: {expected_columns}\n"
                f"Got columns: {list(df.columns)}"
            )
            logging.error(error_message)
            return {"error": error_message}, None, None, error_message

        logging.info("Input CSV has the correct columns.")

        # Select the appropriate model and scaler
        if model_choice == "Auto-Select":
            temp_df = df.copy()
            subarea = determine_subarea(temp_df)
            selected_model, selected_model_name = select_model(models, subarea)
            scaler = loaded_scalers[selected_model_name]
        else:
            if model_choice in models:
                selected_model = models[model_choice]
                selected_model_name = model_choice
                scaler = loaded_scalers[selected_model_name]
            else:
                error_message = f"Selected model '{model_choice}' is not available."
                logging.error(error_message)
                return {"error": error_message}, None, None, error_message

        logging.info(f"Using scaler for model: {selected_model_name}")

        # Normalize the data
        logging.info("Normalizing the data...")
        features_to_scale = [
            "mmsi", "sog_kt", "latitude_degrees", "longitude_degrees", "cog_degrees",
            "dimension_a_m", "dimension_b_m", "dimension_c_m", "dimension_d_m",
            "ship_type", "day", "month", "year", "time_decimal"
        ]
        X_new = df[features_to_scale]
        X_scaled = scaler.transform(X_new)
        df_scaled = pd.DataFrame(X_scaled, columns=features_to_scale, index=df.index)
        df_scaled['original_mmsi'] = df['mmsi']

        # Create sequences and get last known positions (scaled)
        seq_len = 24
        forecast_horizon = 1
        X, y, mmsi_seq, last_known_positions_scaled = create_dataset_grouped_by_mmsi(df_scaled, seq_len, forecast_horizon, features_to_scale)

        if X.size == 0:
            error_message = "Not enough data to create sequences."
            logging.error(error_message)
            return {"error": error_message}, None, None, error_message

        logging.info(f"Created {X.shape[0]} sequences.")

        # Inference
        logging.info("Starting model inference...")
        test_dataset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
        all_predictions = []
        all_y_true = []

        start_time = time.time()  # Start inference time tracking

        with torch.no_grad():
            for batch in test_loader:
                X_batch, y_batch = batch
                predictions = selected_model(X_batch).cpu().numpy()
                all_predictions.append(predictions)
                all_y_true.append(y_batch.numpy())

        inference_time = time.time() - start_time  # End inference time tracking

        all_predictions = np.concatenate(all_predictions, axis=0)
        y_true = np.concatenate(all_y_true, axis=0)
        y_pred = all_predictions

        logging.info(f"Inference completed in {inference_time:.2f} seconds.")

        # Denormalize predictions and real values
        lat_idx = features_to_scale.index("latitude_degrees")
        lon_idx = features_to_scale.index("longitude_degrees")
        pred_lat, pred_lon = denormalize(y_pred[:, :, 0], y_pred[:, :, 1], scaler, lat_idx, lon_idx)
        true_lat, true_lon = denormalize(y_true[:, :, 0], y_true[:, :, 1], scaler, lat_idx, lon_idx)

        # Denormalize last known positions
        last_lat_scaled = np.array([pos[0] for pos in last_known_positions_scaled])
        last_lon_scaled = np.array([pos[1] for pos in last_known_positions_scaled])

        last_lat_denorm, last_lon_denorm = denormalize(
            last_lat_scaled, last_lon_scaled, scaler, lat_idx, lon_idx
        )

        # Calculate the classic evaluation metrics
        y_true_pairs = np.column_stack((true_lat.flatten(), true_lon.flatten()))
        y_pred_pairs = np.column_stack((pred_lat.flatten(), pred_lon.flatten()))
        classic_metrics = calculate_classic_metrics(y_true=y_true_pairs, y_pred=y_pred_pairs)
        classic_metrics['Inference Time (seconds)'] = inference_time  # Include inference time

        # Prepare metrics and output CSV
        metrics_df = pd.DataFrame([classic_metrics])
        metrics_json = json.loads(metrics_df.to_json(orient="records"))[0]  # JSON round-trip coerces numpy scalars

        # Prepare predicted and real positions DataFrame
        predicted_df = pd.DataFrame({
            'MMSI': mmsi_seq[:len(y_pred)].flatten(),
            'Last Known Latitude': last_lat_denorm.flatten(),
            'Last Known Longitude': last_lon_denorm.flatten(),
            'Predicted Latitude': pred_lat.flatten(),
            'Predicted Longitude': pred_lon.flatten(),
            'Real Latitude': true_lat.flatten(),
            'Real Longitude': true_lon.flatten()
        })

        # Save predictions as CSV
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_positions_file:
            predicted_df.to_csv(tmp_positions_file, index=False)
            positions_csv_path = tmp_positions_file.name

        logging.info("Classical prediction completed.")
        return metrics_json, positions_csv_path, inference_time, None
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return None, None, None, str(e)

# ============================
# Abnormal Behavior Detection
# ============================

def abnormal_behavior_detection(prediction_file_path, alpha=0.5, threshold=10.0):
    """
    Detect abnormal behavior based on angular divergence and distance difference.
    Accepts a CSV file containing real and predicted positions.
    """
    try:
        logging.info("Starting abnormal behavior detection...")

        # Load the CSV file containing real and predicted positions
        logging.info("Loading prediction CSV file...")
        df = pd.read_csv(prediction_file_path)
        logging.info(f"Prediction CSV file loaded with {df.shape[0]} records.")

        # Check if necessary columns exist
        expected_columns = [
            'MMSI', 'Last Known Latitude', 'Last Known Longitude',
            'Predicted Latitude', 'Predicted Longitude',
            'Real Latitude', 'Real Longitude'
        ]

        if not all(col in df.columns for col in expected_columns):
            error_message = (
                f"Input data does not have the correct columns.\n"
                f"Expected columns: {expected_columns}\n"
                f"Got columns: {list(df.columns)}"
            )
            logging.error(error_message)
            return None, error_message  # (csv_path, error), matching the success path

        # Extract necessary data
        mmsi_seq = df['MMSI'].values
        last_lat_flat = df['Last Known Latitude'].values
        last_lon_flat = df['Last Known Longitude'].values
        pred_lat_flat = df['Predicted Latitude'].values
        pred_lon_flat = df['Predicted Longitude'].values
        true_lat_flat = df['Real Latitude'].values
        true_lon_flat = df['Real Longitude'].values

        # Calculate bearings
        logging.info("Calculating bearings for predictions and real values...")
        bearings_pred = [
            calculate_bearing(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i]) 
            for i in range(len(pred_lat_flat))
        ]
        bearings_true = [
            calculate_bearing(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i]) 
            for i in range(len(true_lat_flat))
        ]

        # Calculate angular divergence Δθ
        logging.info("Calculating angular divergence (Δθ)...")
        delta_theta = [
            angular_divergence(bearings_pred[i], bearings_true[i]) 
            for i in range(len(bearings_pred))
        ]

        # Calculate distance difference Δd
        logging.info("Calculating distance difference (Δd)...")
        delta_d = [
            haversine(last_lon_flat[i], last_lat_flat[i], pred_lon_flat[i], pred_lat_flat[i]) - 
            haversine(last_lon_flat[i], last_lat_flat[i], true_lon_flat[i], true_lat_flat[i])
            for i in range(len(pred_lat_flat))
        ]

        # Compute the score
        logging.info("Computing the abnormal behavior score...")
        score = [alpha * abs(dd) + (1 - alpha) * dt for dd, dt in zip(delta_d, delta_theta)]
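        # Example: with alpha=0.5, a 4 km distance gap and a 30-degree divergence
        # give a score of 0.5*4 + 0.5*30 = 17, which the default threshold of
        # 10.0 would flag as abnormal.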

        # Determine abnormal behavior
        logging.info("Determining abnormal behavior based on the score...")
        abnormal_behavior = [1 if s >= threshold else 0 for s in score]  # 1: Abnormal, 0: Normal

        # Create DataFrame for saving
        abnormal_behavior_df = pd.DataFrame({
            'MMSI': mmsi_seq,
            'Last Known Latitude': last_lat_flat,
            'Last Known Longitude': last_lon_flat,
            'Predicted Latitude': pred_lat_flat,
            'Predicted Longitude': pred_lon_flat,
            'Real Latitude': true_lat_flat,
            'Real Longitude': true_lon_flat,
            'Distance Difference (Δd) [km]': delta_d,
            'Angular Divergence (Δθ) [degrees]': delta_theta,
            'Score (αΔd + (1-α)Δθ)': score,
            'Abnormal Behavior (1=Abnormal, 0=Normal)': abnormal_behavior
        })

        # Save abnormal behavior dataset as CSV
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', newline='') as tmp_abnormal_file:
            abnormal_behavior_df.to_csv(tmp_abnormal_file, index=False)
            abnormal_csv_path = tmp_abnormal_file.name

        logging.info("Abnormal behavior detection completed.")
        return abnormal_csv_path, None
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return None, str(e)

# ============================
# Define Gradio Interface
# ============================

def main():
    # ============================
    # Define Model and Scaler Paths
    # ============================

    model_paths = {
        'teacher': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256/horizon_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_seq_24/run_1/best_model.pth',
        'student_north': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_North_seq_24/run_1/best_model.pth',
        'student_mid': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_Mid_seq_24/run_1/best_model.pth',
        'student_south': 'LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South/horizon1_data_LSTM_whole_atlantic_horizon1_with_time_decimal_input_batch256_KD_South_seq_24/run_1/best_model.pth'
    }

    scaler_paths = {
        'Teacher': 'scaler_train_wholedata_up.joblib',
        'Student_North': 'scaler_train_North_up.joblib',
        'Student_Mid': 'scaler_train_Mid_up.joblib',
        'Student_South': 'scaler_train_South_up.joblib'
    }

    # ============================
    # Load Models and Scalers
    # ============================

    logging.info("Loading models and scalers...")
    models = load_models(model_paths)
    loaded_scalers = load_scalers(scaler_paths)
    logging.info("All models and scalers loaded successfully.")

    # Define the Gradio components for classical prediction tab
    classical_tab = gr.Interface(
        fn=functools.partial(classical_prediction, models=models, loaded_scalers=loaded_scalers),
        inputs=[
            gr.File(label="Upload CSV File", type='filepath'),
            gr.Dropdown(
                choices=["Auto-Select", "Teacher", "Student_North", "Student_Mid", "Student_South"],
                value="Auto-Select",
                label="Choose Model"
            ),
            gr.Number(label="Min MMSI", value=0),
            gr.Number(label="Max MMSI", value=999999999)
        ],
        outputs=[
            gr.JSON(label="Classical Metrics (Degrees)"),
            gr.File(label="Download Predicted & Real Positions CSV"),
            gr.Number(label="Inference Time (seconds)"),
            gr.Textbox(label="Error Message", lines=2, visible=False)
        ],
        title="Classical Prediction & Metrics",
        description=(
            "Upload a CSV file and select a model to get classical evaluation metrics such as MAE, MSE, and RMSE. "
            "The inference time is also provided."
        )
    )

    # Define the Gradio components for abnormal behavior detection tab
    abnormal_tab = gr.Interface(
        fn=abnormal_behavior_detection,
        inputs=[
            gr.File(label="Upload Predicted Positions CSV", type='filepath'),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Alpha (α)"),
            gr.Number(label="Threshold", value=10.0)
        ],
        outputs=[
            gr.File(label="Download Abnormal Behavior CSV"),
            gr.Textbox(label="Error Message", lines=2, visible=False)
        ],
        title="Abnormal Behavior Detection",
        description=(
            "Upload the CSV file containing real and predicted positions from the Classical Prediction tab. "
            "Adjust the Alpha and Threshold parameters to compute abnormal behavior."
        )
    )


    # Combine the two tabs using Gradio Tabs component
    with gr.Blocks() as demo:
        gr.Markdown("# Vessel Trajectory Prediction and Abnormal Behavior Detection")
        with gr.Tabs():
            with gr.TabItem("Classical Prediction"):
                classical_tab.render()
            with gr.TabItem("Abnormal Behavior Detection"):
                abnormal_tab.render()

    # Launch the Gradio interface
    logging.info("Launching Gradio interface...")
    demo.launch()
    logging.info("Gradio interface launched successfully.")

# Run the app
if __name__ == "__main__":
    main()