File size: 6,208 Bytes
04b3693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#! /usr/bin/python3
#pip3 install transformers accelerate deepspeed triton datasets fugashi unidic-lite
# Pretrain a Japanese ModernBERT: build a corpus from Aozora Bunko and Japanese
# Wikipedia, train a Unigram tokenizer, then run masked-LM training.  The final
# tokenizer and model are saved into the directory named by `tgt`.
tgt="modernbert-base-japanese-wikipedia"
import os,json
# Fetch what the rest of the script needs; every step is skipped when its
# output already exists:
#  * a shallow transformers clone, exposed as ./transformers (presumably so that
#    `import transformers` picks up this checkout when run from this directory);
#  * the ModernBERT-base reference repo, whose config.json is reused below;
#  * configuration_/modeling_modernbert.py copied into ModernBERT-base with the
#    `from ...` relative imports rewritten to `from transformers.`, and the
#    is_triton_available import replaced by an importlib probe, so both files
#    can be loaded as remote code.
# `set -e` aborts at the first failing command.  `test -s` (exists and is
# non-empty) is used instead of `test -f` because a failed sed still creates an
# empty output file via its `>` redirect; that file must be regenerated on the
# next run rather than trusted.  `import importlib.util` is spelled out since
# `import importlib` alone does not guarantee the `util` submodule is loaded.
rc=os.system("""
set -e
if test -d transformers
then :
else git clone --depth=1 https://github.com/huggingface/transformers transformers-all
     ln -s transformers-all/src/transformers transformers
fi
test -d ModernBERT-base || git clone --depth=1 https://huggingface.co/answerdotai/ModernBERT-base
test -s ModernBERT-base/configuration_modernbert.py || sed 's/^from \\.\\.\\./from transformers./' transformers/models/modernbert/configuration_modernbert.py > ModernBERT-base/configuration_modernbert.py
test -s ModernBERT-base/modeling_modernbert.py || sed -e 's/^from \\.\\.\\./from transformers./' -e 's/^from .* import is_triton_available/import importlib.util\\nis_triton_available = lambda: importlib.util.find_spec("triton") is not None/' transformers/models/modernbert/modeling_modernbert.py > ModernBERT-base/modeling_modernbert.py
""")
if rc!=0:
  raise SystemExit(f"fetching transformers / ModernBERT-base failed (os.system status {rc})")
# Register the copied remote-code modules in ModernBERT-base/config.json.
# The file is rewritten only when it has no "auto_map" entry yet.  Each entry
# maps an Auto* class name to "<module>.<class>" inside that directory.
cfg_json="ModernBERT-base/config.json"
with open(cfg_json,"r",encoding="utf-8") as r:
  d=json.load(r)
if "auto_map" not in d:
  d["auto_map"]=dict(
    AutoConfig="configuration_modernbert.ModernBertConfig",
    AutoModel="modeling_modernbert.ModernBertModel",
    AutoModelForMaskedLM="modeling_modernbert.ModernBertForMaskedLM",
    AutoModelForSequenceClassification="modeling_modernbert.ModernBertForSequenceClassification",
    AutoModelForTokenClassification="modeling_modernbert.ModernBertForTokenClassification",
  )
  with open(cfg_json,"w",encoding="utf-8") as w:
    json.dump(d,w,indent=2)
if not os.path.isfile("train.txt"):
  # Build train.txt (one packed text chunk per line) from Aozora Bunko and
  # Japanese Wikipedia.  Skipped when train.txt already exists.
  from datasets import load_dataset
  # Character-substitution augmentation.  Each two-character item "XY" below
  # maps X to Y (mostly variant/older glyphs such as 徳→德, 黒→黑).  A single
  # str.translate pass is equivalent to the former chain of 52 str.replace
  # calls because no target character is also a source character.
  _variants=str.maketrans(dict("""
    侠俠 倶俱 洗冼 剥剝 即卽 呑吞 呉吳 填塡 巣巢 徴徵 徳德 掲揭 撃擊 教敎 晩晚
    横橫 歩步 歴歷 毎每 冷泠 渉涉 涙淚 清淸 渇渴 温溫 状狀 産產 痩瘦 禰祢 箪簞
    緑綠 緒緖 縁緣 繋繫 莱萊 薫薰 虚虛 蝉蟬 説說 躯軀 郎郞 醤醬 録錄 錬鍊 間閒
    頬頰 顛顚 鴎鷗 麺麵 黄黃 黒黑 叱𠮟
  """.split()))
  aug=lambda x:x.translate(_variants)
  # Split a document after every "。" and turn full-width spaces into ASCII spaces.
  split=lambda text:text.replace("。","。\n").replace("\u3000"," ").split("\n")
  class _Packer:
    """Concatenate consecutive sentences into lines shorter than `limit` characters.

    A sentence that would bring the buffer to `limit` characters or more first
    flushes the buffer; a lone sentence at or over the limit becomes its own
    line.  Empty buffers are never written.
    """
    def __init__(self,out,limit=10000):
      self.out,self.limit,self.buf=out,limit,""
    def add(self,s):
      if len(s)+len(self.buf)>=self.limit:
        self.flush()
      self.buf+=s
    def flush(self):
      if self.buf:
        print(self.buf,file=self.out)
      self.buf=""
  # Write to a temporary name and rename at the end.  Otherwise an interrupted
  # run would leave a truncated train.txt that the isfile() check above accepts.
  tmp="train.txt.tmp"
  with open(tmp,"w",encoding="utf-8") as w:
    # Aozora Bunko: every sentence goes to `plain`.  Sentences that aug()
    # changes also go, in augmented form, to a separate `augmented` packer.
    d=load_dataset("globis-university/aozorabunko-clean")
    plain,augmented=_Packer(w),_Packer(w)
    for t in d["train"]:
      for s in split(t["text"]):
        r=aug(s)
        if r!=s:
          augmented.add(r)
        plain.add(s)
    # Each leftover buffer gets its own line (they used to be joined by a space).
    plain.flush()
    augmented.flush()
    # Japanese Wikipedia: no augmentation.
    d=load_dataset("wikimedia/wikipedia","20231101.ja")
    plain=_Packer(w)
    for t in d["train"]:
      for s in split(t["text"]):
        plain.add(s)
    plain.flush()
  os.replace(tmp,"train.txt")
# Produce token.txt from train.txt with `fugashi -Owakati` (skipped when a
# non-empty token.txt exists).  Output goes to a temporary file that is renamed
# only on success.  Otherwise an interrupted or failed run would leave a
# non-empty partial token.txt that `test -s` would accept on every later run.
if os.system("test -s token.txt || { fugashi -Owakati < train.txt > token.txt.tmp && mv token.txt.tmp token.txt; }")!=0:
  raise SystemExit("fugashi -Owakati on train.txt failed")

from transformers import DebertaV2TokenizerFast
if not os.path.isfile("tokenizer.json"):
  # Train a Unigram tokenizer on token.txt (skipped when tokenizer.json exists).
  import urllib.request
  from tokenizers import Tokenizer,models,pre_tokenizers,normalizers,processors,decoders,trainers
  # Seed alphabet: every non-"#" line of the Unicode WG2 JapaneseCoreKanji.txt
  # file, parsed as a hexadecimal code point.
  with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
    rows=r.read().decode().strip().split("\n")
  joyo=[]
  for row in rows:
    if row.startswith("#"):
      continue
    joyo.append(chr(int(row,16)))
  special=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"]
  spt=Tokenizer(models.Unigram())
  # NMT-style cleanup followed by NFKC; split on whitespace, then punctuation.
  spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
  spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Punctuation()])
  # Ids 0 and 2 are the positions of [CLS] and [SEP] in `special`.  This
  # presumes the trainer assigns special-token ids in list order.
  spt.post_processor=processors.TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[("[CLS]",0),("[SEP]",2)],
  )
  spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
  unigram=trainers.UnigramTrainer(
    vocab_size=65000,
    max_piece_length=4,
    initial_alphabet=joyo,
    special_tokens=special,
    unk_token="[UNK]",
    n_sub_iterations=2,
  )
  spt.train(files=["token.txt"],trainer=unigram)
  spt.save("tokenizer.json")
# Wrap tokenizer.json as a DebertaV2TokenizerFast and save it into `tgt`.
tkz=DebertaV2TokenizerFast(
  tokenizer_file="tokenizer.json",
  split_by_punct=True,
  do_lower_case=False,
  keep_accents=True,
  vocab_file="/dev/null",
)
tkz.save_pretrained(tgt)
# Generate train.py: a masked-LM pretraining script for ModernBertForMaskedLM
# on train.txt.  Its shebang launches it through `deepspeed`, and it saves the
# model into `tgt`.
with open("train.py","w",encoding="utf-8") as w:
  print(f'#! /usr/bin/env deepspeed\ntgt="{tgt}"'+'''
from transformers import DebertaV2TokenizerFast,ModernBertForMaskedLM,AutoConfig,DataCollatorForLanguageModeling,TrainingArguments,Trainer
tkz=DebertaV2TokenizerFast.from_pretrained(tgt)
c={"trust_remote_code":True,"vocab_size":len(tkz),"tokenizer_class":type(tkz).__name__}
for k,v in tkz.special_tokens_map.items():
  c[k+"_id"]=tkz.convert_tokens_to_ids(v)
cfg=AutoConfig.from_pretrained("ModernBERT-base",**c)
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=1,output_dir=tgt,overwrite_output_dir=True,save_total_limit=2,save_safetensors=False)
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=8190)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=ModernBertForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model(tgt)''',file=w)
# Run training only if chmod succeeded, and stop if training fails instead of
# going on to package an incomplete model.
rc=os.system("chmod 755 train.py && ./train.py")
if rc!=0:
  raise SystemExit(f"./train.py failed (os.system status {rc})")
# Ship the remote-code modules (configuration_/modeling_modernbert.py) next to
# the saved model.  The config is loaded from ModernBERT-base, whose auto_map
# refers to these modules.
rc=os.system(f"cp ModernBERT-base/*.py {tgt}")
if rc!=0:
  raise SystemExit(f"copying ModernBERT-base/*.py into {tgt} failed (os.system status {rc})")