64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
|
#! -*- coding: utf-8 -*-
|
|||
|
# Chit-chat (casual dialogue) task with the NEZHA model
|
|||
|
# Test / inference script
|
|||
|
# Tested with: tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
|
|||
|
import os
|
|||
|
os.environ['TF_KERAS'] = "1"
|
|||
|
|
|||
|
import numpy as np
|
|||
|
from bert4keras.backend import keras, K
|
|||
|
from bert4keras.models import build_transformer_model
|
|||
|
from bert4keras.tokenizers import Tokenizer
|
|||
|
from bert4keras.snippets import AutoRegressiveDecoder
|
|||
|
|
|||
|
|
|||
|
class ChatBot(AutoRegressiveDecoder):
    """Chit-chat bot that generates replies via random sampling.

    Wraps a causal-LM transformer and a tokenizer; decoding is delegated
    to the bert4keras ``AutoRegressiveDecoder`` machinery.
    """

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        super().__init__(start_id, end_id, maxlen)
        self.model = model
        self.tokenizer = tokenizer

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Return next-token probabilities for the current prefix."""
        token_ids, segment_ids = inputs
        # Append everything generated so far to the conversation context.
        full_token_ids = np.concatenate([token_ids, output_ids], 1)
        # The reply belongs to the "other" speaker: flip the last segment
        # id (0 -> 1, 1 -> 0) for all newly generated positions.
        flipped_segments = np.ones_like(output_ids) - segment_ids[0, -1]
        full_segment_ids = np.concatenate([segment_ids, flipped_segments], 1)
        probas = self.model.predict([full_token_ids, full_segment_ids])
        return probas[:, -1]

    def response(self, texts, topk=5):
        """Sample one reply to the dialogue history ``texts``.

        Each utterance alternates segment id 0/1 by turn index; ``topk``
        bounds the random sampling to the k most likely tokens.
        """
        token_ids = [self.tokenizer._token_start_id]
        segment_ids = [0]
        for turn, text in enumerate(texts):
            # Drop the leading [CLS] token from each utterance encoding.
            utterance_ids = self.tokenizer.encode(text)[0][1:]
            token_ids += utterance_ids
            segment_ids += [turn % 2] * len(utterance_ids)
        sampled = self.random_sample([token_ids, segment_ids], 1, topk)
        return self.tokenizer.decode(sampled[0])
|
|||
|
|
|||
|
|
|||
|
def build_chat_model(model_path, tokenizer):
    """Build the NEZHA causal-LM model and wrap it in a ChatBot.

    Args:
        model_path: directory containing ``config.json``, ``model.ckpt``
            (trailing path separator optional).
        tokenizer: tokenizer whose ``_token_end_id`` terminates decoding.

    Returns:
        A ChatBot decoding up to 32 tokens per reply.
    """
    # NEZHA configuration / checkpoint paths. os.path.join tolerates a
    # missing trailing slash in model_path; the previous f-string
    # concatenation silently produced broken paths such as
    # '../models/nezha_gpt_dialogconfig.json' for slash-less inputs.
    config_path = os.path.join(model_path, 'config.json')
    checkpoint_path = os.path.join(model_path, 'model.ckpt')

    # Build NEZHA with a language-model (causal) head and load weights.
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        model='nezha',
        application='lm',
    )
    chatbot = ChatBot(
        start_id=None,
        end_id=tokenizer._token_end_id,
        maxlen=32,
        model=model,
        tokenizer=tokenizer,
    )
    return chatbot
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Model directory must end with '/': build_chat_model joins it with
    # the config/checkpoint file names, and a slash-less path previously
    # yielded '../models/nezha_gpt_dialogconfig.json'.
    tokenizer = Tokenizer(
        "../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True
    )
    chatbot = build_chat_model('../models/nezha_gpt_dialog/', tokenizer)
    # Two opening lines of a classical Chinese poem as the dialogue history.
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))
|
|||
|
|
|||
|
|
|||
|
|