ai_platform_nlu/nlp/text_gen.py

54 lines
2.0 KiB
Python

#! -*- coding: utf-8 -*-
import os
os.environ["TF_KERAS"] = "1"
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation (text completion) based on random sampling.

    Wraps a bert4keras language model built with ``application='lm'`` and
    uses the sampling loop inherited from ``AutoRegressiveDecoder`` to extend
    an input text with sampled tokens.
    """

    def __init__(self, start_id, end_id, maxlen, minlen, config_path, ckpt_path):
        """Configure the decoder and build the underlying language model.

        Args:
            start_id: Start token id, forwarded unchanged to
                ``AutoRegressiveDecoder.__init__``.
            end_id: End token id, forwarded to ``AutoRegressiveDecoder``.
                Presumably sampling of a sequence stops once this id is
                produced; confirm against bert4keras.
            maxlen: Forwarded to ``AutoRegressiveDecoder``. Presumably the
                maximum generated length.
            minlen: Forwarded to ``AutoRegressiveDecoder``. Presumably the
                minimum generated length.
            config_path (str): Path to the model config file passed to
                ``build_transformer_model``.
            ckpt_path (str): Path to the checkpoint passed to
                ``build_transformer_model``.
        """
        super().__init__(start_id, end_id, maxlen, minlen)
        self.model = build_transformer_model(
            config_path=config_path,
            checkpoint_path=ckpt_path,
            # No segment embeddings: predict() feeds token ids only.
            segment_vocab_size=0,
            application='lm',
        )

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Return next-token probabilities for the current prefixes.

        Args:
            inputs: Inputs given to the sampler; ``inputs[0]`` holds the
                prompt token ids (assumed 2-D, batch x length -- TODO confirm).
            output_ids: Token ids generated so far, concatenated after the
                prompt along axis 1.
            states: Decoder state supplied by ``AutoRegressiveDecoder``;
                not used here.
        """
        # The model must see the full prefix: prompt followed by generated ids.
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        # last_token() presumably restricts the model output to the final
        # position, which is what the sampler needs for the next token.
        return self.last_token(self.model).predict(token_ids)

    def generate(self, text, tokenizer, n=1, topp=0.95, topk=None):
        """Continue ``text`` with ``n`` randomly sampled completions.

        Args:
            text (str): Prompt text to continue.
            tokenizer (Tokenizer): Tokenizer matching the model vocabulary.
            n (int, optional): Number of completions to sample. Defaults to 1.
            topp (float, optional): Nucleus (top-p) sampling threshold passed
                to ``random_sample``. Defaults to 0.95.
            topk (int, optional): Optional top-k limit passed to
                ``random_sample``. Defaults to None, which applies no top-k
                filtering and matches the previous behaviour.

        Returns:
            list[str]: One string per sample, each being ``text`` followed by
            the decoded sampled continuation.
        """
        # encode() returns a tuple whose first item is the token-id list;
        # dropping the last id (the closing special token) lets the model
        # continue the prompt instead of treating it as finished.
        token_ids = tokenizer.encode(text)[0][:-1]
        results = self.random_sample([token_ids], n, topk=topk, topp=topp)
        return [text + tokenizer.decode(ids) for ids in results]
if __name__ == '__main__':
    # Demo: load the nezha_gpt checkpoint and continue a sample Chinese prompt.
    # NOTE: all model/vocab paths are relative to the current working
    # directory, not to this file.
    article_completion = ArticleCompletion(
        start_id=None,
        end_id=511,  # 511 is the Chinese full stop "。" in this vocabulary
        maxlen=256,
        minlen=128,
        config_path='../models/nezha_gpt/config.json',
        ckpt_path='../models/nezha_gpt/gpt.ckpt',
    )
    tokenizer = Tokenizer('../models/tokenizer/vocab.txt', do_lower_case=True)
    prompt = '中国科学院青岛生物能源与过程研究所泛能源大数据与战略研究中心'
    print(article_completion.generate(prompt, tokenizer))