DinoLiu commited on
Commit
9bd5d2c
1 Parent(s): 73ff1a7
Files changed (1) hide show
  1. handler.py +75 -26
handler.py CHANGED
@@ -1,15 +1,11 @@
1
  import os
2
- import sys
3
  import json
4
  import torch
5
- from ts.torch_handler.base_handler import BaseHandler
6
  from transformers import BertTokenizer
7
-
8
- # Add the model directory to the Python path
9
- model_dir = os.path.dirname(os.path.abspath(__file__))
10
- sys.path.append(model_dir)
11
-
12
- from model import ImprovedBERTClass # Ensure this import matches your model file name
13
 
14
  class UICardMappingHandler(BaseHandler):
15
  def __init__(self):
@@ -20,43 +16,96 @@ class UICardMappingHandler(BaseHandler):
20
  self.manifest = context.manifest
21
  properties = context.system_properties
22
  model_dir = properties.get("model_dir")
23
- self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
24
-
25
- self.tokenizer = BertTokenizer.from_pretrained(model_dir)
 
 
 
 
 
 
 
 
 
 
 
26
  self.model = ImprovedBERTClass()
27
  self.model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pth'), map_location=self.device))
28
  self.model.to(self.device)
29
  self.model.eval()
30
-
 
 
 
31
  self.initialized = True
32
 
33
  def preprocess(self, data):
34
- text = data[0].get("data")
35
- if text is None:
36
- text = data[0].get("body")
37
- inputs = self.tokenizer(text, return_tensors="pt", max_length=64, padding='max_length', truncation=True)
38
- return inputs.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def inference(self, inputs):
41
  with torch.no_grad():
42
- outputs = self.model(**inputs)
43
- return torch.sigmoid(outputs.logits)
 
44
 
45
  def postprocess(self, inference_output):
46
- probabilities = inference_output.cpu().numpy().flatten()
47
- labels = ['Videos', 'Unit Conversion', 'Translation', 'Shopping Product Comparison', 'Restaurants', 'Product', 'Information', 'Images', 'Gift', 'General Comparison', 'Flights', 'Answer', 'Aircraft Seat Map']
48
 
49
- top_k = 3 # You can adjust this value
50
- top_k_indices = probabilities.argsort()[-top_k:][::-1]
51
  top_k_probs = probabilities[top_k_indices]
52
 
53
- top_k_predictions = [{"card": labels[i], "probability": float(p)} for i, p in zip(top_k_indices, top_k_probs)]
 
 
 
54
 
55
- most_likely_card = "Answer" if sum(probabilities > 0.5) == 0 else labels[probabilities.argmax()]
 
56
 
 
 
 
 
 
 
 
 
 
 
 
57
  result = {
58
  "most_likely_card": most_likely_card,
59
  "top_k_predictions": top_k_predictions
60
  }
61
 
62
  return [result]
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import json
3
  import torch
4
+ import numpy as np
5
  from transformers import BertTokenizer
6
+ from ts.torch_handler.base_handler import BaseHandler
7
+ from model import ImprovedBERTClass
8
+ from sklearn.preprocessing import OneHotEncoder
 
 
 
9
 
10
  class UICardMappingHandler(BaseHandler):
11
  def __init__(self):
 
16
  self.manifest = context.manifest
17
  properties = context.system_properties
18
  model_dir = properties.get("model_dir")
19
+
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+ # Load config
23
+ with open(os.path.join(model_dir, 'config.json'), 'r') as f:
24
+ self.config = json.load(f)
25
+
26
+ # Initialize encoder and labels
27
+ self.labels = ['Videos', 'Unit Conversion', 'Translation', 'Shopping Product Comparison', 'Restaurants', 'Product', 'Information', 'Images', 'Gift', 'General Comparison', 'Flights', 'Answer', 'Aircraft Seat Map']
28
+ labels_np = np.array(self.labels).reshape(-1, 1)
29
+ self.encoder = OneHotEncoder(sparse_output=False)
30
+ self.encoder.fit(labels_np)
31
+
32
+ # Load model
33
  self.model = ImprovedBERTClass()
34
  self.model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pth'), map_location=self.device))
35
  self.model.to(self.device)
36
  self.model.eval()
37
+
38
+ # Load tokenizer
39
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
40
+
41
  self.initialized = True
42
 
43
  def preprocess(self, data):
44
+ text = data[0].get("body").get("text", "")
45
+ k = data[0].get("body").get("k", 3)
46
+
47
+ inputs = self.tokenizer.encode_plus(
48
+ text,
49
+ add_special_tokens=True,
50
+ max_length=64,
51
+ padding='max_length',
52
+ return_tensors='pt',
53
+ truncation=True
54
+ )
55
+
56
+ return {
57
+ "ids": inputs['input_ids'].to(self.device, dtype=torch.long),
58
+ "mask": inputs['attention_mask'].to(self.device, dtype=torch.long),
59
+ "token_type_ids": inputs['token_type_ids'].to(self.device, dtype=torch.long),
60
+ "k": k
61
+ }
62
 
63
+ def inference(self, data):
64
  with torch.no_grad():
65
+ outputs = self.model(data["ids"], data["mask"], data["token_type_ids"])
66
+ probabilities = torch.sigmoid(outputs)
67
+ return probabilities.cpu().detach().numpy().flatten(), data["k"]
68
 
69
  def postprocess(self, inference_output):
70
+ probabilities, k = inference_output
 
71
 
72
+ # Get top k predictions
73
+ top_k_indices = np.argsort(probabilities)[-k:][::-1]
74
  top_k_probs = probabilities[top_k_indices]
75
 
76
+ # Create one-hot encodings for top k indices
77
+ top_k_one_hot = np.zeros((k, len(probabilities)))
78
+ for i, idx in enumerate(top_k_indices):
79
+ top_k_one_hot[i, idx] = 1
80
 
81
+ # Decode the top k predictions
82
+ top_k_cards = [self.decode_vector(one_hot.reshape(1, -1)) for one_hot in top_k_one_hot]
83
 
84
+ # Create a list of tuples (card, probability) for top k predictions
85
+ top_k_predictions = list(zip(top_k_cards, top_k_probs.tolist()))
86
+
87
+ # Determine the most likely card
88
+ predicted_labels = (probabilities > 0.5).astype(int)
89
+ if sum(predicted_labels) == 0:
90
+ most_likely_card = "Answer"
91
+ else:
92
+ most_likely_card = self.decode_vector(predicted_labels.reshape(1, -1))
93
+
94
+ # Prepare the response
95
  result = {
96
  "most_likely_card": most_likely_card,
97
  "top_k_predictions": top_k_predictions
98
  }
99
 
100
  return [result]
101
+
102
+ def decode_vector(self, vector):
103
+ original_label = self.encoder.inverse_transform(vector)
104
+ return original_label[0][0] # Returns the label as a string
105
+
106
+ def handle(self, data, context):
107
+ self.context = context
108
+ data = self.preprocess(data)
109
+ data = self.inference(data)
110
+ data = self.postprocess(data)
111
+ return data