---
language: ja
tags:
- ja
- japanese
- gpt2
- text-generation
- lm
- nlp
license: mit
widget:
- text: 桜が咲く
datasets:
- skytnt/japanese-lyric
---

# Japanese GPT2 Lyric Model

## Model description

This is a GPT-2 language model that generates Japanese song lyrics from a short prompt.

You can try it on my website: [lyric.fab.moe](https://lyric.fab.moe/#/).

## How to use

The example below needs `torch`, `transformers`, and `sentencepiece` (required by `T5Tokenizer`).

```python
import torch
from transformers import T5Tokenizer, GPT2LMHeadModel

# Run on a GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-small")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-small")
model.to(device)
model.eval()


def gen_lyric(prompt_text: str) -> str:
    # Line breaks are encoded as a literal "\n" followed by a space
    prompt_text = "<s>" + prompt_text.replace("\n", "\\n ")
    prompt_tokens = tokenizer.tokenize(prompt_text)
    prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
    prompt_tensor = torch.LongTensor(prompt_token_ids).to(device)
    prompt_tensor = prompt_tensor.view(1, -1)

    # Sample a continuation with top-k and nucleus (top-p) sampling
    output_sequences = model.generate(
        input_ids=prompt_tensor,
        max_length=512,
        top_p=0.95,
        top_k=40,
        temperature=1.0,
        do_sample=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

    # Convert the generated token ids back into readable text
    generated_sequence = output_sequences.tolist()[0]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
    generated_text = tokenizer.convert_tokens_to_string(generated_tokens)

    # Restore real line breaks, turn remaining spaces into full-width spaces,
    # drop the start token and mark the end token
    generated_text = "\n".join(s.strip() for s in generated_text.split("\\n"))
    generated_text = (
        generated_text.replace(" ", "\u3000")
        .replace("<s>", "")
        .replace("</s>", "\n\n---end---")
    )
    return generated_text


print(gen_lyric("桜が咲く"))
```
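
The prompt and output handling in `gen_lyric` relies on a plain-text convention: a line break is written as a literal `\n` followed by a space, and `<s>`/`</s>` mark the start and end of a lyric. The sketch below pulls those string steps into two standalone helpers so they can be checked without downloading the model. The names `encode_prompt` and `decode_output` are hypothetical (they are not part of this model or of `transformers`), and the sketch assumes `convert_tokens_to_string` returns the decoded text with the `\n` markers still in place, as the example above does.

```python
def encode_prompt(prompt_text: str) -> str:
    """Prepare raw lyric text for the model (hypothetical helper).

    Mirrors gen_lyric: prepend the <s> start token and replace each real
    newline with a literal backslash-n followed by a space.
    """
    return "<s>" + prompt_text.replace("\n", "\\n ")


def decode_output(generated_text: str) -> str:
    """Turn decoded model text back into readable lyrics (hypothetical helper).

    Mirrors gen_lyric: split on the literal backslash-n marker, strip each
    line, convert remaining half-width spaces to full-width spaces (U+3000),
    drop <s>, and replace </s> with an end marker.
    """
    lines = [s.strip() for s in generated_text.split("\\n")]
    text = "\n".join(lines)
    return (
        text.replace(" ", "\u3000")
        .replace("<s>", "")
        .replace("</s>", "\n\n---end---")
    )


if __name__ == "__main__":
    encoded = encode_prompt("桜が咲く\n春の風")
    print(encoded)  # prints: <s>桜が咲く\n 春の風
    # Simulated decoder text for illustration, not real model output
    print(decode_output("<s>桜が咲く\\n 春の風\\n 君と歩く</s>"))
```

Keeping these steps in plain functions makes it easy to adjust the formatting (for example, keeping half-width spaces instead of `\u3000`) without touching the `model.generate` call.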

## Training data

The [training data](https://huggingface.co/datasets/skytnt/japanese-lyric/blob/main/lyric_clean.pkl) contains 143,587 Japanese lyrics collected from [uta-net](https://www.uta-net.com/) with [lyric_download](https://github.com/SkyTNT/lyric_downlowd).