import torch
from transformers import T5Tokenizer, GPT2LMHeadModel
from flask import Flask, request, jsonify
import cutlet

# One cutlet romanizer per supported system, keyed by the system name that
# the /romaji endpoint looks up.
convertors = {
    romaji_sys: cutlet.Cutlet(romaji_sys)
    for romaji_sys in ["hepburn", "kunrei", "nippon"]
}

# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = model.to(device)

# Text substituted for each end-of-sequence token in the generated output.
_END_MARKER = "\n\n---end---"


def gen_lyric(title: str, prompt_text: str, *, max_length: int = 512,
              top_p: float = 0.95, top_k: int = 40,
              temperature: float = 1.0):
    """Sample a lyric from the model, optionally seeded with a title/opening.

    Args:
        title: Song title to condition on; may be empty.
        prompt_text: Opening lines of the lyric to continue; may be empty.
        max_length: ``generate`` max_length (prompt tokens included).
        top_p, top_k, temperature: Sampling parameters passed to ``generate``.

    Returns:
        ``(title, lyric)``. The generated text is split at the first
        ``"[CLS]"``; if none was generated, the title is ``""`` and the whole
        text is returned as the lyric.
    """
    # Special tokens come back verbatim in the decoded text. ``or ""`` covers
    # a tokenizer that defines no such token; empty strings are never used as
    # a replace() target below.
    bos = tokenizer.bos_token or ""
    eos = tokenizer.eos_token or ""

    if title or prompt_text:
        # Prompt layout: <bos> title [CLS] lyric-so-far. Newlines are encoded
        # as a literal backslash-n plus a space; the decoding below undoes it.
        # NOTE(review): presumably the training-data format -- confirm.
        prompt = (bos + title + "[CLS]" + prompt_text).replace("\n", "\\n ")
        prompt_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
        input_ids = torch.LongTensor(prompt_ids).view(1, -1).to(device)
    else:
        # No seed: generate() is given bos_token_id to start from.
        input_ids = None

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        do_sample=True,
        early_stopping=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

    # Decode the single returned sequence back into text.
    generated_ids = output_sequences.tolist()[0]
    generated_text = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(generated_ids)
    )

    # Undo the newline encoding and turn the remaining ASCII spaces into
    # full-width (U+3000) spaces. Then drop <bos> and mark where <eos>
    # appeared. The old code replaced "" here, which inserts the replacement
    # between every character.
    generated_text = "\n".join(
        segment.strip() for segment in generated_text.split("\\n")
    ).replace(" ", "\u3000")
    if bos:
        generated_text = generated_text.replace(bos, "")
    if eos:
        generated_text = generated_text.replace(eos, _END_MARKER)

    head, sep, tail = generated_text.partition("[CLS]")
    if sep:
        return head.strip(), tail.strip()
    return "", head.strip()


app = Flask(__name__, static_url_path="", static_folder="frontend/dist")


@app.route('/')
def index_page():
    """Serve index.html from the app's static folder (frontend/dist)."""
    return app.send_static_file("index.html")
def _read_str_fields(*names):
    """Return the string values of *names* from the JSON request body, in order.

    Raises:
        ValueError: if the body is not a JSON object, or a field is missing
            or not a string. The message is meant to be shown to the client.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        raise ValueError("request body must be a JSON object")
    values = []
    for name in names:
        if name not in data:
            raise ValueError(f"missing field: {name!r}")
        if not isinstance(data[name], str):
            raise ValueError(f"field {name!r} must be a string")
        values.append(data[name])
    return values


def _error_response(status, err):
    """Build the ``{"state", "msg"}`` error payload plus HTTP status tuple."""
    return jsonify({"state": status, "msg": f"{err}"}), status


@app.route('/gen', methods=["POST"])
def generate():
    """Generate a lyric from a JSON body ``{"title": str, "text": str}``.

    Responds 200 with ``{"state": 200, "title": ..., "lyric": ...}``.
    Responds 400 for a malformed request and 500 if generation fails; both
    carry ``{"state": <code>, "msg": <error text>}``.
    """
    # The route only accepts POST, so request.method needs no check here.
    try:
        title, text = _read_str_fields("title", "text")
    except ValueError as e:
        return _error_response(400, e)
    try:
        title, lyric = gen_lyric(title, text)
    except Exception as e:  # request boundary: log and report, don't crash
        app.logger.exception("lyric generation failed")
        return _error_response(500, e)
    return jsonify({"state": 200, "title": title, "lyric": lyric}), 200


@app.route('/romaji', methods=["POST"])
def romaji():
    """Romanize text from a JSON body ``{"text": str, "system": str}``.

    ``system`` must be a key of ``convertors``. Responds 200 with
    ``{"state": 200, "romaji": ...}``. Responds 400 for a malformed request
    or unknown system, and 500 if conversion fails.
    """
    try:
        text, system = _read_str_fields("text", "system")
        converter = convertors.get(system)
        if converter is None:
            raise ValueError(
                f"unknown romanization system {system!r}; "
                f"expected one of: {', '.join(convertors)}"
            )
    except ValueError as e:
        return _error_response(400, e)
    try:
        # Text containing newlines cannot be converted directly, so
        # romanize each line separately and rejoin them.
        romaji_text = "\n".join(
            converter.romaji(line) for line in text.split("\n")
        )
    except Exception as e:  # request boundary: log and report, don't crash
        app.logger.exception("romanization failed")
        return _error_response(500, e)
    return jsonify({"state": 200, "romaji": romaji_text}), 200


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False)