import os
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from sklearn.preprocessing import OneHotEncoder
from transformers import BertTokenizer

class AttentionPool(nn.Module):
    """Pools token embeddings into one vector using learned per-token attention weights."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state):
        # (batch, seq_len, hidden) -> (batch, seq_len): one scalar score per token.
        # Note: padding positions are not masked out here, so they also receive weight.
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=1)
        # Weighted sum over the sequence dimension -> (batch, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output
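
# Shape sketch (illustrative only; the tensors below are hypothetical):
#
#   pool = AttentionPool(hidden_size=768)
#   demo = torch.randn(2, 64, 768)   # (batch=2, seq_len=64, hidden=768)
#   pool(demo).shape                 # -> torch.Size([2, 768])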

class MultiSampleDropout(nn.Module):
    """Averages the output of several independent dropout masks (multi-sample dropout)."""

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        # Apply num_samples independent dropout masks and average the results.
        return torch.mean(torch.stack([self.dropout(x) for _ in range(self.num_samples)]), dim=0)
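
# Note: nn.Dropout is the identity in eval() mode, so at inference time this
# module returns x unchanged; the multi-sample averaging only matters during
# training, where it reduces the variance introduced by dropout.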


class ImprovedBERTClass(nn.Module):
    """bert-base-uncased encoder with attention pooling, multi-sample dropout, and a linear head."""

    def __init__(self, num_classes=13):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.attention_pool = AttentionPool(768)  # 768 = bert-base hidden size
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(768)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
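
# Minimal smoke test (hypothetical; downloads bert-base-uncased on first run):
#
#   model = ImprovedBERTClass(num_classes=13)
#   ids = torch.randint(0, 30522, (1, 64))       # 30522 = bert-base vocab size
#   mask = torch.ones(1, 64, dtype=torch.long)
#   types = torch.zeros(1, 64, dtype=torch.long)
#   model(ids, mask, types).shape                # -> torch.Size([1, 13])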

def handler(data, context):
    """Handle incoming requests to the SageMaker endpoint."""
    
    if context.request_content_type != 'application/json':
        raise ValueError("This model only supports application/json input")

    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model and tokenizer on every request (slow; see the caching sketch
    # after load_model_and_tokenizer for a way to avoid this)
    model, tokenizer = load_model_and_tokenizer(context)
    
    # Process the input data
    input_data = json.loads(data.read().decode('utf-8'))
    query = input_data.get('text', '')
    k = input_data.get('k', 3)  # Default to top 3 if not specified
    
    # Tokenize and prepare the input
    inputs = tokenizer.encode_plus(
        query,
        add_special_tokens=True,
        max_length=64,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    ids = inputs['input_ids'].to(device, dtype=torch.long)
    mask = inputs['attention_mask'].to(device, dtype=torch.long)
    token_type_ids = inputs['token_type_ids'].to(device, dtype=torch.long)
    
    # Make the prediction
    model.eval()
    with torch.no_grad():
        outputs = model(ids, mask, token_type_ids)
    
    # Apply sigmoid for multi-label classification
    probabilities = torch.sigmoid(outputs)
    
    # Move to CPU and flatten to a 1-D numpy array of per-class probabilities
    # (detach is unnecessary inside torch.no_grad())
    probabilities = probabilities.cpu().numpy().flatten()
    
    # Get top k predictions (clamp k so it never exceeds the number of classes)
    k = min(k, len(probabilities))
    top_k_indices = np.argsort(probabilities)[-k:][::-1]
    top_k_probs = probabilities[top_k_indices]
    
    # Create one-hot encodings for top k indices
    top_k_one_hot = np.zeros((k, len(probabilities)))
    for i, idx in enumerate(top_k_indices):
        top_k_one_hot[i, idx] = 1
    
    # Decode the top k predictions
    top_k_cards = [decode_vector(one_hot.reshape(1, -1)) for one_hot in top_k_one_hot]
    
    # Create a list of tuples (card, probability) for top k predictions
    top_k_predictions = list(zip(top_k_cards, top_k_probs.tolist()))
    
    # Determine the most likely card: fall back to "Answer" when no label
    # clears the 0.5 threshold; otherwise take the single highest-probability
    # label. (A multi-hot vector is not a valid one-hot input for the
    # encoder's inverse_transform when several labels exceed 0.5.)
    if (probabilities > 0.5).sum() == 0:
        most_likely_card = "Answer"
    else:
        best = np.zeros(len(probabilities), dtype=int)
        best[np.argmax(probabilities)] = 1
        most_likely_card = decode_vector(best.reshape(1, -1))
    
    # Prepare the response
    result = {
        "most_likely_card": most_likely_card,
        "top_k_predictions": top_k_predictions
    }
    
    return json.dumps(result), 'application/json'
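
# Example request/response (illustrative; the values are hypothetical):
#
#   request body : {"text": "compare flight prices to tokyo", "k": 3}
#   response body: {"most_likely_card": "Flights",
#                   "top_k_predictions": [["Flights", 0.91],
#                                         ["Shopping Product Comparison", 0.12],
#                                         ["Answer", 0.07]]}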


def load_model_and_tokenizer(context):
    """Load the PyTorch model, tokenizer, and label encoder."""
    global global_encoder
    labels = ['Videos', 'Unit Conversion', 'Translation', 'Shopping Product Comparison', 'Restaurants', 'Product', 'Information', 'Images', 'Gift', 'General Comparison', 'Flights', 'Answer', 'Aircraft Seat Map']
    
    model_dir = context.model_dir if hasattr(context, 'model_dir') else os.environ.get('SM_MODEL_DIR', '/opt/ml/model')
    
    # Load config and model weights
    config_path = os.path.join(model_dir, 'config.json')
    model_path = os.path.join(model_dir, 'model.pth')
    
    with open(config_path, 'r') as f:
        config = json.load(f)  # loaded but not currently used
    
    # Fit the one-hot encoder on the label set so decode_vector can invert predictions
    labels_np = np.array(labels).reshape(-1, 1)
    global_encoder = OneHotEncoder(sparse_output=False)
    global_encoder.fit(labels_np)

    model = ImprovedBERTClass()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.eval()
    
    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    return model, tokenizer
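
# Caching sketch (an assumption, not part of the original code): handler()
# reloads the model and tokenizer on every request, which is slow. A simple
# module-level cache avoids the repeated disk reads and device transfers:
#
#   _cache = {}
#
#   def load_model_and_tokenizer_cached(context):
#       if 'model' not in _cache:
#           _cache['model'], _cache['tokenizer'] = load_model_and_tokenizer(context)
#       return _cache['model'], _cache['tokenizer']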


def decode_vector(vector):
    """Map a one-hot row vector back to its label string via the fitted encoder."""
    original_label = global_encoder.inverse_transform(vector)
    return original_label[0][0]  # the label as a string