import io
import os

import torch

import utils.save as saver
# import models  # the dynamic models.AvailableModels lookup below is disabled
from models.transformer import Transformer
from modules.config import find_all_config


class TransformerHandlerClass:
    def __init__(self):
        self.model = None
        self.device = None
        self.initialized = False
    def _find_checkpoint(self, model_dir, best_model_prefix="best_model", model_prefix="model", validate=True):
        """Attempt to retrieve the best model checkpoint from model_dir; failing that, the checkpoint of the latest iteration.

        Args:
            model_dir: directory to search for checkpoints. str
            best_model_prefix: file name prefix of the best-model checkpoint. str
            model_prefix: file name prefix of the per-iteration checkpoints. str
            validate: if True, assert that the resolved best-model file actually exists. bool

        Returns:
            a single str denoting the checkpoint path."""
        score_file_path = os.path.join(model_dir, saver.BEST_MODEL_FILE)
        if os.path.isfile(score_file_path):  # score file exists -> use the best model
            best_model_path = os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(best_model_prefix, 0, saver.MODEL_EXTENSION))
            if validate:
                assert os.path.isfile(best_model_path), "Score file is available, but file {:s} is missing.".format(best_model_path)
            return best_model_path
        else:  # no score file -> use the latest checkpoint
            last_checkpoint_idx = saver.check_model_in_path(name_prefix=model_prefix)
            if last_checkpoint_idx == 0:
                raise ValueError("No checkpoint found in folder {:s} with prefix {:s}.".format(model_dir, model_prefix))
            else:
                return os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(model_prefix, last_checkpoint_idx, saver.MODEL_EXTENSION))
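
    # Sketch of the directory layout _find_checkpoint assumes. The concrete
    # file names are illustrative only; the real patterns come from
    # saver.BEST_MODEL_FILE, saver.MODEL_FILE_FORMAT and saver.MODEL_EXTENSION
    # in utils.save:
    #
    #   model_dir/
    #       <saver.BEST_MODEL_FILE>   score file; its presence selects the best model
    #       best_model.0.<ext>        returned when the score file exists
    #       model.<n>.<ext>           otherwise the highest-numbered checkpoint wins
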
    def initialize(self, ctx):
        manifest = ctx.manifest
        properties = ctx.system_properties
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        self.model_dir = model_dir = properties.get("model_dir")
        # extract checkpoint location, config & model name
        model_serve_file = os.path.join(model_dir, saver.MODEL_SERVE_FILE)
        with io.open(model_serve_file, "r") as serve_config:
            model_name = serve_config.read().strip()
        # model_cls = models.AvailableModels[model_name]
        model_cls = Transformer  # cannot select the class dynamically due to the nature of the model file
        # prefer the checkpoint fed in by the archiver; else fall back to the best checkpoint found in model_dir
        checkpoint_path = manifest['model'].get('serializedFile', self._find_checkpoint(model_dir))
        config_path = find_all_config(model_dir)
        # load the model with its bundled config + vocab and without pretraining data
        self.model = model = model_cls(config=config_path, model_dir=model_dir, mode="infer")
        # TODO _find_checkpoint may do some redundant work here, since load_checkpoint already searches for the latest checkpoint
        model.load_checkpoint(model_dir, checkpoint=checkpoint_path)
        print("Model {:s} loaded successfully at location {:s}.".format(model_name, model_dir))
        self.initialized = True
    def handle(self, data):
        """Process a batch of data received from the client.

        Args:
            data: the object received from the client. Should contain [batch_size] str entries under the "data" key.

        Returns:
            the expected translations, [batch_size] of str.
        """
        batch_sentences = data[0].get("data")
        # assert batch_sentences is not None, "data is {}".format(data)
        # output_tokens=False makes sure the sentences are detokenized before returning
        translated_sentences = self.model.translate_batch(batch_sentences, output_tokens=False)
        return translated_sentences
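
    # Example round trip, with the request envelope assumed from the
    # data[0].get("data") access above (the sentences are illustrative):
    #   request:  [{"data": ["source sentence 1", "source sentence 2"]}]
    #   response: ["translation 1", "translation 2"]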


class BeamSearchHandlerClass:
    def __init__(self):
        self.model = None
        self.inferrer = None
        self.initialized = False
    def initialize(self, ctx):
        manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties['model_dir']
        ts_modelpath = manifest['model']['serializedFile']
        self.model = ts_model = torch.jit.load(os.path.join(model_dir, ts_modelpath))
        from modules.inference.beam_search import BeamSearch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # 160 and beam_size=5 are serving-time decoding settings
        self.inferrer = BeamSearch(ts_model, 160, device, beam_size=5)
        self.initialized = True
    def handle(self, data):
        batch_sentences = data[0].get("data")
        # assert batch_sentences is not None, "data is {}".format(data)
        translated_sentences = self.inferrer.translate_batch_sentence(batch_sentences, output_tokens=False)
        return translated_sentences


RUNNING_MODEL = BeamSearchHandlerClass()


def handle(data, context):
    if not RUNNING_MODEL.initialized:  # lazy init
        RUNNING_MODEL.initialize(context)
    if data is None:
        return None
    return RUNNING_MODEL.handle(data)
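

# Minimal local smoke-test sketch, assuming a TorchServe-like context object.
# MockContext and the file/directory values below are hypothetical stand-ins;
# in production TorchServe builds the real context and invokes handle() itself.
if __name__ == "__main__":
    class MockContext:
        manifest = {"model": {"serializedFile": "model.pt"}}  # hypothetical TorchScript file
        system_properties = {"model_dir": ".", "gpu_id": 0}

    print(handle([{"data": ["source sentence"]}], MockContext()))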