import torch
import torch.nn as nn
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import gradio as gr
import cv2
import mediapipe as mp
import numpy as np
import spaces

# Define the ASLClassifier model
class ASLClassifier(nn.Module):
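    # Fully connected classifier over 63 flattened hand-landmark features
    # (21 MediaPipe landmarks x (x, y, z)). The 28 output classes presumably
    # cover A-Z plus a couple of extra tokens (e.g. "space"/"nothing") from the training set.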
    def __init__(self, input_size=63, hidden_size=256, num_classes=28):
        super(ASLClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_size, hidden_size * 2)
        self.bn2 = nn.BatchNorm1d(hidden_size * 2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(hidden_size * 2, hidden_size)
        self.bn3 = nn.BatchNorm1d(hidden_size)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.3)
        self.fc4 = nn.Linear(hidden_size, hidden_size // 2)
        self.bn4 = nn.BatchNorm1d(hidden_size // 2)
        self.relu4 = nn.ReLU()
        self.dropout4 = nn.Dropout(0.3)
        self.fc5 = nn.Linear(hidden_size // 2, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.dropout3(x)
        x = self.fc4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.dropout4(x)
        x = self.fc5(x)
        return x

# Load the model and label encoder (CPU initially, GPU handled by decorator)
device = torch.device('cpu')  # Default to CPU; GPU inference handled by @spaces.GPU
model = ASLClassifier().to(device)
model.load_state_dict(torch.load('data/asl_classifier.pth', map_location=device))
model.eval()

df = pd.read_csv('data/asl_landmarks_final.csv')
label_encoder = LabelEncoder()
label_encoder.fit(df['label'].values)
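# Note: re-fitting the LabelEncoder on the training CSV reproduces the label-to-index mapping used at training time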

# Initialize MediaPipe (runs on CPU)
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5)
mp_drawing = mp.solutions.drawing_utils

# Prediction function with GPU offloading
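# (@spaces.GPU requests a GPU for the duration of the call on Hugging Face ZeroGPU Spaces;
# outside Spaces the decorator should be a no-op.)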
@spaces.GPU
def predict_letter(landmarks, model, label_encoder):
    with torch.no_grad():
        # Run inference on the GPU when one is available (allocated by @spaces.GPU); otherwise stay on CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        landmarks = torch.tensor(landmarks, dtype=torch.float32).unsqueeze(0).to(device)
        model = model.to(device)
        output = model(landmarks)
        _, predicted_idx = torch.max(output, 1)
        letter = label_encoder.inverse_transform([predicted_idx.item()])[0]
        # Move the model back to CPU to free GPU memory between calls
        model.to('cpu')
    return letter

# Video processing function (CPU for video processing, GPU for prediction)
def process_video(video_path):
    # Open video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Error: Could not open video."

    # Variables to store output
    text_output = ""
    last_letter = None  # last predicted label, used to suppress immediate repeats
    out_frames = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Process frame with MediaPipe (CPU)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = hands.process(frame_rgb)

        if results.multi_hand_landmarks:
            for hand_landmarks in results.multi_hand_landmarks:
                # Draw landmarks
                mp_drawing.draw_landmarks(frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)

                # Extract landmarks and predict (GPU via decorator)
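                # (21 landmarks x 3 coordinates -> the 63-dim feature vector the classifier expects)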
                landmarks = []
                for lm in hand_landmarks.landmark:
                    landmarks.extend([lm.x, lm.y, lm.z])
                landmarks = np.array(landmarks, dtype=np.float32)
                predicted_letter = predict_letter(landmarks, model, label_encoder)

                # Append to the transcript, skipping immediate repeats of the same prediction
                if predicted_letter != last_letter:
                    text_output += predicted_letter
                    last_letter = predicted_letter

                # Overlay predicted letter on frame
                cv2.putText(frame, f"Letter: {predicted_letter}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

        # Store processed frame
        out_frames.append(frame)

    # Preserve the source frame rate for the output video (fall back to 20 FPS if it can't be read)
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
    cap.release()

    if not out_frames:
        return None, "Error: No frames could be read from the video."

    # Write processed video to a temporary file
    out_path = "processed_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(out_path, fourcc, fps, (out_frames[0].shape[1], out_frames[0].shape[0]))
    for frame in out_frames:
        out.write(frame)
    out.release()
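    # Note: mp4v-encoded output may not play in all browsers; re-encoding to H.264 could be needed for the Gradio player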

    return out_path, text_output

# Create Gradio interface with sample input
with gr.Blocks(title="Sign Language Translation") as demo:
    gr.Markdown("## Sign Language Translation")
    video_input = gr.Video(label="Input Video", sources=["upload", "webcam"])
    video_output = gr.Video(label="Processed Video with Landmarks")
    text_output = gr.Textbox(label="Predicted Text", interactive=False)

    # Button to process video
    btn = gr.Button("Translate")
    btn.click(
        fn=process_video,
        inputs=video_input,
        outputs=[video_output, text_output]
    )

    # Add sample input video
    gr.Examples(
        examples=[["data/letters_seq.mp4"]],
        inputs=[video_input],
        outputs=[video_output, text_output],
        fn=process_video,
        cache_examples=True  # Cache the output for faster loading
    )

# Launch the app
demo.launch()