---
language: ja
tags:
- ja
- japanese
- gpt2
- text-generation
- lm
- nlp
license: mit
widget:
- text: <s>桜[CLS]
datasets:
- skytnt/japanese-lyric
---

# Japanese GPT2 Lyric Model

## Model description

This is a GPT-2 model that generates Japanese lyrics. Given a song title, and optionally the opening lines of the lyric, it generates the rest of the lyric.

## How to use

```python
import torch
from transformers import T5Tokenizer, GPT2LMHeadModel

# Use the GPU when one is available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# The model uses a SentencePiece vocabulary, loaded through T5Tokenizer.
tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = model.to(device)


def gen_lyric(title: str, prompt_text: str):
    if len(title) != 0 or len(prompt_text) != 0:
        # Prompt format: <s>{title}[CLS]{opening lines of the lyric}
        prompt_text = "<s>" + title + "[CLS]" + prompt_text
        # Line breaks are encoded as a literal "\n" followed by a space.
        prompt_text = prompt_text.replace("\n", "\\n ")
        prompt_tokens = tokenizer.tokenize(prompt_text)
        prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
        prompt_tensor = torch.LongTensor(prompt_token_ids)
        prompt_tensor = prompt_tensor.view(1, -1).to(device)
    else:
        # With no prompt at all, generation starts from the BOS token.
        prompt_tensor = None

    # Sample a lyric (early_stopping is omitted: it only affects beam search).
    output_sequences = model.generate(
        input_ids=prompt_tensor,
        max_length=512,
        top_p=0.95,
        top_k=40,
        temperature=1.0,
        do_sample=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

    # Convert the model output back into readable text.
    generated_sequence = output_sequences.tolist()[0]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
    generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
    # Restore line breaks, use full-width spaces, drop <s>, and mark the end at </s>.
    generated_text = (
        "\n".join(s.strip() for s in generated_text.split("\\n"))
        .replace(" ", "\u3000")
        .replace("<s>", "")
        .replace("</s>", "\n\n---end---")
    )
    # The title is everything before the first [CLS].
    title_and_lyric = generated_text.split("[CLS]", 1)
    if len(title_and_lyric) == 1:
        title, lyric = "", title_and_lyric[0].strip()
    else:
        title, lyric = title_and_lyric[0].strip(), title_and_lyric[1].strip()
    return f"---{title}---\n\n{lyric}"


print(gen_lyric("桜", ""))
```
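
The string handling inside `gen_lyric` can be tried without downloading the model. The sketch below factors it into two hypothetical helpers, `build_prompt` and `parse_output`; these names are illustrative only and are not part of this model or of `transformers`. It assumes, as `gen_lyric` does, that a prompt has the form `<s>{title}[CLS]{opening lines}` and that line breaks are stored as a literal `\n` followed by a space.

```python
from typing import Tuple


def build_prompt(title: str, opening: str = "") -> str:
    """Hypothetical helper: build the prompt string gen_lyric tokenizes."""
    prompt = "<s>" + title + "[CLS]" + opening
    # Encode real line breaks as a literal backslash-n plus a space.
    return prompt.replace("\n", "\\n ")


def parse_output(decoded: str) -> Tuple[str, str]:
    """Hypothetical helper: split decoded text into (title, lyric) like gen_lyric."""
    # Turn literal "\n" markers back into real line breaks.
    text = "\n".join(s.strip() for s in decoded.split("\\n"))
    # Use full-width spaces, drop <s>, and mark the end of the lyric at </s>.
    text = text.replace(" ", "\u3000").replace("<s>", "").replace("</s>", "\n\n---end---")
    # Everything before the first [CLS] is the title; without [CLS] there is no title.
    parts = text.split("[CLS]", 1)
    if len(parts) == 1:
        return "", parts[0].strip()
    return parts[0].strip(), parts[1].strip()


if __name__ == "__main__":
    print(repr(build_prompt("桜", "春の風\n")))
    title, lyric = parse_output("<s>桜[CLS]春の風\\n 花が舞う</s>")
    print(f"---{title}---\n\n{lyric}")
```

Keeping these steps as plain string functions makes the prompt format easy to inspect, and they can be reused if you call `model.generate` with different sampling settings.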

## Training data

The [training data](https://huggingface.co/datasets/skytnt/japanese-lyric/blob/main/lyric_clean.pkl) contains 143,587 Japanese lyrics collected from [uta-net](https://www.uta-net.com/) with [lyric_download](https://github.com/SkyTNT/lyric_downlowd).