#! -*- coding: utf-8 -*-
import os
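# Must be set before importing bert4keras so it uses tf.keras
# instead of standalone keras.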
os.environ["TF_KERAS"] = "1"
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder


class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation based on random sampling."""

    def __init__(self, start_id, end_id, maxlen, minlen, config_path, ckpt_path):
        super().__init__(start_id, end_id, maxlen, minlen)
        # GPT-style model: no segment embeddings, with a language-model head.
        self.model = build_transformer_model(
            config_path=config_path,
            checkpoint_path=ckpt_path,
            segment_vocab_size=0,
            application='lm',
        )

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # Concatenate the prompt with the tokens generated so far and
        # return the model's probability distribution over the next token.
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        return self.last_token(self.model).predict(token_ids)

    def generate(self, text, tokenizer, n=1, topp=0.95):
        """Generate continuations of the input text.

        Args:
            text (str): Input text.
            tokenizer (Tokenizer): Tokenizer.
            n (int, optional): Number of continuations. Defaults to 1.
            topp (float, optional): Nucleus (top-p) sampling threshold. Defaults to 0.95.

        Returns:
            list[str]: The generated texts.
        """
        token_ids = tokenizer.encode(text)[0][:-1]  # drop the trailing [SEP]
        results = self.random_sample([token_ids], n, topp=topp)  # random (nucleus) sampling
        return [text + tokenizer.decode(ids) for ids in results]
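

# Illustration only (not part of the original script): a minimal sketch of
# what nucleus (top-p) sampling does at each decoding step, assuming a 1-D
# probability vector `probas` that sums to 1. bert4keras's `random_sample`
# applies this idea internally; the helper name here is hypothetical.
def _nucleus_sample_sketch(probas, topp=0.95):
    order = np.argsort(probas)[::-1]                  # most probable tokens first
    cumsum = np.cumsum(probas[order])
    keep = order[:np.searchsorted(cumsum, topp) + 1]  # smallest set with mass >= topp
    p = probas[keep] / probas[keep].sum()             # renormalize over the kept set
    return np.random.choice(keep, p=p)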


if __name__ == '__main__':
    article_completion = ArticleCompletion(
        start_id=None,
        end_id=511,  # 511 is the token id of the Chinese full stop (。)
        maxlen=256,
        minlen=128,
        config_path='../models/nezha_gpt/config.json',
        ckpt_path='../models/nezha_gpt/gpt.ckpt',
    )
    tokenizer = Tokenizer('../models/tokenizer/vocab.txt', do_lower_case=True)
    print(article_completion.generate(u'中国科学院青岛生物能源与过程研究所泛能源大数据与战略研究中心', tokenizer))
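
# Possible variant (a sketch, not in the original script): AutoRegressiveDecoder
# also provides `beam_search`, so a deterministic continuation could look
# roughly like:
#
#     token_ids = tokenizer.encode(text)[0][:-1]
#     ids = article_completion.beam_search([token_ids], topk=5)
#     print(text + tokenizer.decode(ids))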