|
|
|
|
|
# Name under which the tokenizer (save_pretrained) and the trained model
# (train.py's output_dir / save_model) are written.
tgt = "modernbert-base-japanese-char"
|
import os |
|
# Bootstrap a local `transformers` directory on first run: shallow-clone the
# upstream repository into transformers-all and symlink its package source
# into the working directory. Nothing happens when ./transformers already
# exists. (Presumably this makes scripts run from here, including the
# generated train.py, import this checkout -- confirm.)
_SETUP_TRANSFORMERS = "\n".join([
    "",
    "if test -d transformers",
    "then :",
    "else git clone --depth=1 https://github.com/huggingface/transformers transformers-all",
    "ln -s transformers-all/src/transformers transformers",
    "fi",
])
os.system(_SETUP_TRANSFORMERS)
|
# Augmentation table: each pair maps a character to an alternative glyph
# (mostly traditional / variant forms, e.g. 教->敎, 黒->黑).
_AUG_PAIRS = (
    ("侠", "俠"), ("倶", "俱"), ("洗", "冼"), ("剥", "剝"), ("即", "卽"),
    ("呑", "吞"), ("呉", "吳"), ("填", "塡"), ("巣", "巢"), ("徴", "徵"),
    ("徳", "德"), ("掲", "揭"), ("撃", "擊"), ("教", "敎"), ("晩", "晚"),
    ("横", "橫"), ("歩", "步"), ("歴", "歷"), ("毎", "每"), ("冷", "泠"),
    ("渉", "涉"), ("涙", "淚"), ("清", "淸"), ("渇", "渴"), ("温", "溫"),
    ("状", "狀"), ("産", "產"), ("痩", "瘦"), ("禰", "祢"), ("箪", "簞"),
    ("緑", "綠"), ("緒", "緖"), ("縁", "緣"), ("繋", "繫"), ("莱", "萊"),
    ("薫", "薰"), ("虚", "虛"), ("蝉", "蟬"), ("説", "說"), ("躯", "軀"),
    ("郎", "郞"), ("醤", "醬"), ("録", "錄"), ("錬", "鍊"), ("間", "閒"),
    ("頬", "頰"), ("顛", "顚"), ("鴎", "鷗"), ("麺", "麵"), ("黄", "黃"),
    ("黒", "黑"), ("叱", "𠮟"),
)
# Every key and value is a single code point and no value is also a key, so
# one str.translate pass yields exactly what a chain of str.replace calls
# would, in one scan instead of 52.
_AUG_TABLE = str.maketrans(dict(_AUG_PAIRS))


def aug(x):
    """Return x with every character listed in _AUG_PAIRS replaced by its variant."""
    return x.translate(_AUG_TABLE)


def _sentences(text):
    """Split text after each "。" and turn ideographic spaces (U+3000) into " "."""
    return text.replace("。", "。\n").replace("\u3000", " ").split("\n")


class _LineWriter:
    """Pack consecutive sentences into output lines shorter than `limit` chars.

    Adding a sentence that would bring the pending line to `limit` characters
    or more first writes the pending line, then starts a new one with that
    sentence; a sentence of `limit` characters or more therefore ends up on
    a line of its own, unsplit.
    """

    def __init__(self, out, limit=8200):
        self.out = out        # text file the packed lines are printed to
        self.limit = limit    # exclusive upper bound on a packed line's length
        self.pending = ""     # sentences collected for the current line

    def add(self, sentence):
        """Append sentence to the pending line, writing that line out first if full."""
        if len(sentence) + len(self.pending) < self.limit:
            self.pending += sentence
        else:
            self.flush()
            self.pending = sentence

    def flush(self):
        """Write the pending line (if non-empty) and reset it."""
        # Blank lines carry no text, and train.py's ReadLineDS drops them anyway.
        if self.pending:
            print(self.pending, file=self.out)
        self.pending = ""


if not os.path.isfile("train.txt"):
    from datasets import load_dataset

    with open("train.txt", "w", encoding="utf-8") as w:
        # Aozora Bunko: every sentence is packed into `plain`. A sentence that
        # aug() changes is additionally packed, in its variant form, into a
        # separate `variant` stream. The variant stream is handled first so
        # that line order matches the original per-sentence processing.
        plain, variant = _LineWriter(w), _LineWriter(w)
        for t in load_dataset("globis-university/aozorabunko-clean")["train"]:
            for s in _sentences(t["text"]):
                r = aug(s)
                if r != s:
                    variant.add(r)
                plain.add(s)
        # Flush each tail on its own line so neither can exceed the limit or
        # be merged with unrelated text.
        plain.flush()
        variant.flush()
        # Japanese Wikipedia (20231101.ja), packed the same way, no augmentation.
        wiki = _LineWriter(w)
        for t in load_dataset("wikimedia/wikipedia", "20231101.ja")["train"]:
            for s in _sentences(t["text"]):
                wiki.add(s)
        wiki.flush()
|
|
|
from transformers import BertTokenizerFast |
|
from tokenizers.pre_tokenizers import Sequence,Whitespace,Split |
|
from tokenizers import Regex |
|
# Special tokens; they are written first to vocab.txt, so they receive
# ids 0..4 in this order.
s = ["[CLS]", "[PAD]", "[SEP]", "[UNK]", "[MASK]"]
if not os.path.isfile("vocab.txt"):
    import urllib.request

    # Vocabulary, part 1: every non-whitespace character occurring in
    # train.txt. The file is streamed line by line instead of read() whole,
    # so memory stays bounded by the character set rather than the corpus.
    # set.update also scans each line in C instead of a per-character
    # Python loop.
    v = set()
    with open("train.txt", "r", encoding="utf-8") as r:
        for line in r:
            v.update(line)
    v = {c for c in v if not c.isspace()}
    # Part 2: the JapaneseCoreKanji list, parsed as one hexadecimal code
    # point per line with "#" comment lines. Blank lines are skipped rather
    # than crashing int().
    with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
        for t in r.read().decode().splitlines():
            t = t.strip()
            if t and not t.startswith("#"):
                v.add(chr(int(t, 16)))
    with open("vocab.txt", "w", encoding="utf-8") as w:
        print("\n".join(s + sorted(v)), file=w)
# Character-level tokenizer built on BertTokenizerFast, case and accents kept.
tkz = BertTokenizerFast(vocab_file="vocab.txt", never_split=s, do_lower_case=False, strip_accents=False, tokenize_chinese_chars=True)
# Split on whitespace, then isolate every remaining character as its own piece.
tkz.backend_tokenizer.pre_tokenizer = Sequence([Whitespace(), Split(Regex("."), "isolated")])
# Drop the "##" continuation prefix in both model and decoder: every
# character is a standalone token.
tkz.backend_tokenizer.decoder.prefix = tkz.backend_tokenizer.model.continuing_subword_prefix = ""
tkz.save_pretrained(tgt)
|
|
|
# Generate train.py. It loads the tokenizer saved under `tgt` and takes the
# model config from KoichiYasuoka/modernbert-base-japanese-wikipedia, with
# vocab_size and special-token ids overridden from that tokenizer. It then
# builds a freshly initialised ModernBertForMaskedLM and trains it with
# Trainer for 3 epochs on the non-blank lines of train.txt, finally saving
# the model under `tgt`. The shebang hands the script to the deepspeed
# launcher.
_train_header = "#! /usr/bin/env deepspeed\n" + 'tgt="' + tgt + '"'
_train_body = '''
from transformers import BertTokenizerFast,ModernBertForMaskedLM,AutoConfig,DataCollatorForLanguageModeling,TrainingArguments,Trainer
tkz=BertTokenizerFast.from_pretrained(tgt)
c={"trust_remote_code":True,"vocab_size":len(tkz),"tokenizer_class":type(tkz).__name__}
for k,v in tkz.special_tokens_map.items():
  c[k+"_id"]=tkz.convert_tokens_to_ids(v)
cfg=AutoConfig.from_pretrained("KoichiYasuoka/modernbert-base-japanese-wikipedia",**c)
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=1,output_dir=tgt,overwrite_output_dir=True,save_total_limit=2,save_safetensors=False)
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=8190)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=ModernBertForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model(tgt)'''
with open("train.py", "w", encoding="utf-8") as w:
    # Same bytes print(..., file=w) would emit: the text plus one newline.
    w.write(_train_header + _train_body + "\n")
# Mark executable and launch it via its shebang.
os.system("chmod 755 train.py ; ./train.py")
|
|