merge_tokenizers.py

# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Merge a domain-specific SentencePiece model (and optionally the Baichuan
vocab and top-frequency jieba words) into the LLaMA tokenizer, then save the merged
tokenizer in both SentencePiece and Hugging Face formats.
"""

import argparse
import os

# Force the pure-Python protobuf implementation before any protobuf-backed module
# (sentencepiece's model proto, transformers) is imported.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers import LlamaTokenizer


def is_chinese(uchar):
    """Check whether a single unicode character is a common CJK (Chinese) character."""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_chinese_string(string):
    """Check whether every character in the string is a Chinese character."""
    return all(is_chinese(c) for c in string)
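
# Quick sanity check for the helpers above (illustrative, not run by the script):
#   is_chinese('中') -> True, is_chinese('a') -> False
#   is_chinese_string('汉字') -> True, is_chinese_string('汉a') -> False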


def load_baichuan_vocab(vocab_file):
    """Load a vocab file and return the set of tokens (first field of each non-empty line)."""
    words = set()
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                words.add(line.strip().split()[0])
    return words
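
# Expected baichuan_vocab.txt layout (an assumption inferred from the parsing above),
# e.g. one token per line, optionally followed by extra whitespace-separated fields:
#   你好 1234
#   世界 987
# Only the first field of each non-empty line is kept as a vocab token.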


def load_jieba_vocab(jieba_vocab_file):
    # Read jieba vocab and sort by freq
    with open(jieba_vocab_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    word_freqs = [line.strip().split() for line in lines]
    word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
    return word_freqs
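
# Expected word_freq.txt layout (an assumption based on the parsing above): one
# "<word> <count>" pair per line, e.g.
#   自然语言 12345
#   深度学习 9876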


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
    parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
    parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
    parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
    parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
    parser.add_argument('--jieba_word_size', default=20000, type=int)
    args = parser.parse_args()
    print(args)

    # Load the base LLaMA tokenizer and the domain SentencePiece model
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_sp_model = spm.SentencePieceProcessor()
    chinese_sp_model.Load(args.domain_sp_model_file)

    # Deserialize both into SentencePiece ModelProto objects so their piece lists can be merged
    llama_spm = sp_pb2_model.ModelProto()
    llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
    chinese_spm = sp_pb2_model.ModelProto()
    chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

    # Print token counts and the special-token setup of the base tokenizer
    print(len(llama_tokenizer), len(chinese_sp_model))
    print(llama_tokenizer.all_special_tokens)
    print(llama_tokenizer.all_special_ids)
    print(llama_tokenizer.special_tokens_map)

    # Add domain (Chinese) tokens to the LLaMA tokenizer
    llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
    print(f"Before: {len(llama_spm_tokens_set)}")
    added_set = set()
    for p in chinese_spm.pieces:
        piece = p.piece
        if piece not in llama_spm_tokens_set:
            # print('piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0  # neutral score for newly added pieces
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add domain tokens] New model pieces: {len(llama_spm.pieces)}")

    # Add Chinese words from the Baichuan vocab that are not already in the tokenizer
    vocab = load_baichuan_vocab(args.baichuan_vocab_file)
    print('baichuan vocab len:', len(vocab))
    baichuan_vocab_set = set(i for i in vocab if is_chinese_string(i))
    print('baichuan chinese vocab size:', len(baichuan_vocab_set))
    print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
    for piece in baichuan_vocab_set:
        if piece not in llama_spm_tokens_set and piece not in added_set:
            # print('baichuan piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add baichuan tokens] New model pieces: {len(llama_spm.pieces)}")

    if args.add_jieba:
        # Optionally add the top-frequency jieba words
        word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
        top_words = word_freqs[:args.jieba_word_size]
        print('jieba top10 freq words:', top_words[:10])
        jieba_vocab_set = set(i[0] for i in top_words if i)
        print('jieba_vocab_set size:', len(jieba_vocab_set))
        print('jieba_vocab head:', list(jieba_vocab_set)[:3])
        for piece in jieba_vocab_set:
            if piece not in llama_spm_tokens_set and piece not in added_set:
                # print('jieba piece', piece)
                new_p = sp_pb2_model.ModelProto().SentencePiece()
                new_p.piece = piece
                new_p.score = 0
                llama_spm.pieces.append(new_p)
                added_set.add(piece)  # keep added_set consistent with the other merge loops
        print(f"[add jieba tokens] New model pieces: {len(llama_spm.pieces)}")

    # Save the merged tokenizer in SentencePiece and Hugging Face formats
    output_sp_dir = 'merged_tokenizer_sp'
    output_hf_dir = 'merged_tokenizer_hf'  # the path to save the merged Chinese-LLaMA tokenizer
    os.makedirs(output_sp_dir, exist_ok=True)
    with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
        f.write(llama_spm.SerializeToString())
    tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')
    tokenizer.save_pretrained(output_hf_dir)
    print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

    # Test: compare tokenization of the base and merged tokenizers
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
    print(chinese_llama_tokenizer.all_special_tokens)
    print(chinese_llama_tokenizer.all_special_ids)
    print(chinese_llama_tokenizer.special_tokens_map)
    print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
    text = '''this is a test, hello world. thisisatesthelloworld,
慕容复来到河边,姑苏慕容氏在外面丢了人。
1号店一周岁了,我们一古脑儿买了10斤零食。
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
    print("Test text:\n", text)
    print(f"Tokenized by LLaMA tokenizer: {llama_tokenizer.tokenize(text)}")
    print(f"Tokenized by Chinese-LLaMA tokenizer: {chinese_llama_tokenizer.tokenize(text)}")


if __name__ == '__main__':
    main()
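
# After merging, if the new tokenizer is used with an existing LLaMA model, the model's
# embedding matrix must be resized to the new vocab size. A minimal sketch (the model
# path is illustrative; this is not part of the script above):
#
#   from transformers import LlamaForCausalLM, LlamaTokenizer
#   model = LlamaForCausalLM.from_pretrained('path/to/llama')          # assumed local model path
#   tokenizer = LlamaTokenizer.from_pretrained('merged_tokenizer_hf')  # output of this script
#   model.resize_token_embeddings(len(tokenizer))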