import io
import os

import torch

import utils.save as saver
from models.transformer import Transformer
from modules.config import find_all_config
#import models  # only needed when selecting the model class dynamically (see initialize)

class TransformerHandlerClass:
  """TorchServe-style handler that loads a Transformer checkpoint from the model directory and translates batches of sentences."""
  def __init__(self):
    self.model = None
    self.device = None
    self.initialized = False

  def _find_checkpoint(self, model_dir, best_model_prefix="best_model", model_prefix="model", validate=True):
    """Retrieve the best model checkpoint from model_dir; failing that, the checkpoint of the latest iteration.
    Args:
      model_dir: directory to search for checkpoints. str
      best_model_prefix: filename prefix of the best-model checkpoint. str
      model_prefix: filename prefix of per-iteration checkpoints. str
      validate: if True, assert that the best-model file actually exists. bool
    Returns:
      a single str denoting the checkpoint path
    """
    score_file_path = os.path.join(model_dir, saver.BEST_MODEL_FILE)
    if(os.path.isfile(score_file_path)): # score exist -> best model
      best_model_path = os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(best_model_prefix, 0, saver.MODEL_EXTENSION))
      if(validate):
        assert os.path.isfile(best_model_path), "Score file is available, but file {:s} is missing.".format(best_model_path)
      return best_model_path
    else: # score not exist -> latest model
      last_checkpoint_idx = saver.check_model_in_path(name_prefix=model_prefix)
      if(last_checkpoint_idx == 0):
        raise ValueError("No checkpoint found in folder {:s} with prefix {:s}.".format(model_dir, model_prefix))
      else:
        return os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(model_prefix, last_checkpoint_idx, saver.MODEL_EXTENSION))


  def initialize(self, ctx):
    manifest = ctx.manifest
    properties = ctx.system_properties

    self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
    self.model_dir = model_dir = properties.get("model_dir")

    # extract checkpoint location, config & model name
    model_serve_file = os.path.join(model_dir, saver.MODEL_SERVE_FILE)
    with io.open(model_serve_file, "r") as serve_config:
      model_name = serve_config.read().strip()
#    model_cls = models.AvailableModels[model_name]
    model_cls = Transformer # hardcoded: cannot select dynamically due to the nature of the model file
    checkpoint_path = manifest['model'].get('serializedFile', self._find_checkpoint(model_dir)) # attempt to use the checkpoint fed from archiver; else use the best checkpoint found
    config_path = find_all_config(model_dir)

    # load model with inbuilt config + vocab & without pretraining data
    self.model = model = model_cls(config=config_path, model_dir=model_dir, mode="infer")
    model.load_checkpoint(model_dir, checkpoint=checkpoint_path) # TODO: _find_checkpoint may be redundant here since load_checkpoint already searches for the latest checkpoint
    
    print("Model {:s} loaded successfully at location {:s}.".format(model_name, model_dir))
    self.initialized = True

  def handle(self, data):
    """The main bulk of handling. Process a batch of data received from client.
    Args: 
      data: the object received from client. Should contain something in [batch_size] of str
    Returns:
      the expected translation, [batch_size] of str
    """
    batch_sentences = data[0].get("data")
#    assert batch_sentences is not None, "data is {}".format(data)
    
    # make sure that sentences are detokenized before returning
    translated_sentences = self.model.translate_batch(batch_sentences, output_tokens=False)

    return translated_sentences

class BeamSearchHandlerClass:
  """Alternative handler that loads a TorchScript model and translates batches with beam search."""
  def __init__(self):
    self.model = None
    self.inferrer = None
    self.initialized = False

  def initialize(self, ctx):
    manifest = ctx.manifest
    properties = ctx.system_properties

    model_dir = properties['model_dir']
    ts_modelpath = manifest['model']['serializedFile'] 
    self.model = ts_model = torch.jit.load(os.path.join(model_dir, ts_modelpath))

    from modules.inference.beam_search import BeamSearch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.inferrer = BeamSearch(ts_model, 160, device, beam_size=5)
    
    self.initialized = True

  def handle(self, data):
    """Process a batch of sentences received from the client and return their translations."""
    batch_sentences = data[0].get("data")
#    assert batch_sentences is not None, "data is {}".format(data)
    
    translated_sentences = self.inferrer.translate_batch_sentence(batch_sentences, output_tokens=False)
    return translated_sentences

RUNNING_MODEL = BeamSearchHandlerClass()

def handle(data, context):
  if(not RUNNING_MODEL.initialized): # Lazy init
    RUNNING_MODEL.initialize(context)

  if(data is None):
    return None

  return RUNNING_MODEL.handle(data)
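
# A minimal local smoke-test sketch, not part of the TorchServe runtime. TorchServe normally
# builds the context object and invokes handle(data, context) itself; the mock context below
# and the "example.pt" / "." paths are hypothetical placeholders that must be pointed at a
# real model archive for this to actually run.
if __name__ == "__main__":
  class _MockContext:
    manifest = {"model": {"serializedFile": "example.pt"}}  # placeholder TorchScript file name
    system_properties = {"model_dir": ".", "gpu_id": 0}     # placeholder model directory / GPU id

  sample = [{"data": ["Hello world .", "How are you ?"]}]   # a batch of raw source sentences
  print(handle(sample, _MockContext()))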