# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Merge a domain SentencePiece model, the Baichuan vocab and (optionally)
    a jieba word-frequency vocab into a base LLaMA tokenizer, then save the merged
    tokenizer in both SentencePiece and Hugging Face formats.
"""
import argparse
import os

# Force the pure-Python protobuf implementation; this must be set before any
# module that imports protobuf (sentencepiece, transformers) is loaded.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers import LlamaTokenizer


def is_chinese(uchar):
    """Return True if the unicode character is a Chinese (CJK) character."""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_chinese_string(string):
    """Return True if the string consists entirely of Chinese characters."""
    return all(is_chinese(c) for c in string)


def load_baichuan_vocab(vocab_file):
    """Load the Baichuan vocab file; the first field of each non-empty line is a token."""
    words = set()
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                words.add(line.strip().split()[0])
    return words


def load_jieba_vocab(jieba_vocab_file):
    # Read jieba vocab and sort by freq
    with open(jieba_vocab_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    word_freqs = [line.strip().split() for line in lines]
    word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
    return word_freqs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
    parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
    parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
    parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
    parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
    parser.add_argument('--jieba_word_size', default=20000, type=int)
    args = parser.parse_args()
    print(args)

    # Load the base LLaMA tokenizer and the domain sentencepiece model
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_sp_model = spm.SentencePieceProcessor()
    chinese_sp_model.Load(args.domain_sp_model_file)
    llama_spm = sp_pb2_model.ModelProto()
    llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
    chinese_spm = sp_pb2_model.ModelProto()
    chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

    # Print token counts and special tokens of the source tokenizers
    print(len(llama_tokenizer), len(chinese_sp_model))
    print(llama_tokenizer.all_special_tokens)
    print(llama_tokenizer.all_special_ids)
    print(llama_tokenizer.special_tokens_map)

    # Add Chinese tokens to the LLaMA tokenizer
    llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
    print(f"Before merge, LLaMA tokenizer pieces: {len(llama_spm_tokens_set)}")
    added_set = set()
    for p in chinese_spm.pieces:
        piece = p.piece
        if piece not in llama_spm_tokens_set:
            # print('piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add domain tokens] New model pieces: {len(llama_spm.pieces)}")

    # Add Chinese tokens from the Baichuan vocab
    vocab = load_baichuan_vocab(args.baichuan_vocab_file)
    print('baichuan vocab len:', len(vocab))
    baichuan_vocab_set = set([i for i in vocab if is_chinese_string(i)])
    print('baichuan chinese vocab size:', len(baichuan_vocab_set))
    print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
    for p in baichuan_vocab_set:
        piece = p
        if piece not in llama_spm_tokens_set and piece not in added_set:
            # print('baichuan piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add baichuan tokens] New model pieces: {len(llama_spm.pieces)}")

    # Optionally add the top-frequency jieba words
    if args.add_jieba:
        word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
        top_words = word_freqs[:args.jieba_word_size]
        print('jieba top10 freq words:', top_words[:10])
        jieba_vocab_set = set([i[0] for i in top_words if i])
        print('jieba_vocab_set size:', len(jieba_vocab_set))
        print('jieba_vocab head:', list(jieba_vocab_set)[:3])
        for p in jieba_vocab_set:
            piece = p
            if piece not in llama_spm_tokens_set and piece not in added_set:
                # print('jieba piece', piece)
                new_p = sp_pb2_model.ModelProto().SentencePiece()
                new_p.piece = piece
                new_p.score = 0
                llama_spm.pieces.append(new_p)
                added_set.add(piece)
        print(f"[add jieba tokens] New model pieces: {len(llama_spm.pieces)}")

    # Save the merged sentencepiece model and export a Hugging Face tokenizer
    output_sp_dir = 'merged_tokenizer_sp'
    output_hf_dir = 'merged_tokenizer_hf'  # the path to save the Chinese-LLaMA tokenizer
    os.makedirs(output_sp_dir, exist_ok=True)
    with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
        f.write(llama_spm.SerializeToString())
    tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')
    tokenizer.save_pretrained(output_hf_dir)
    print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

    # Test
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
    print(chinese_llama_tokenizer.all_special_tokens)
    print(chinese_llama_tokenizer.all_special_ids)
    print(chinese_llama_tokenizer.special_tokens_map)
    print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
    text = '''this is a test, hello world. thisisatesthelloworld,
慕容复来到河边,姑苏慕容氏在外面丢了人。
1号店一周岁了,我们一古脑儿买了10斤零食。
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
    print("Test text:\n", text)
    print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
    print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")


if __name__ == '__main__':
    main()