#! -*- coding: utf-8 -*-
# Chit-chat (dialogue) task with a NEZHA model
# Test script
# Tested with: tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
import os

os.environ['TF_KERAS'] = "1"

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
||
class ChatBot(AutoRegressiveDecoder):
    """Chatbot based on random-sampling autoregressive decoding.

    Encodes the dialogue history as one token sequence whose segment ids
    alternate 0/1 per utterance, then samples the reply token-by-token
    from an LM-style NEZHA model.
    """

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        super().__init__(start_id, end_id, maxlen)
        self.model = model          # keras LM: [token_ids, segment_ids] -> per-token probas
        self.tokenizer = tokenizer  # bert4keras Tokenizer

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Return next-token probabilities for the current decoding step."""
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        # Generated tokens belong to the "other" speaker, so their segment id
        # is the complement (1 - x) of the last input segment id.
        curr_segment_ids = np.ones_like(output_ids) - segment_ids[0, -1]
        segment_ids = np.concatenate([segment_ids, curr_segment_ids], 1)
        return self.model.predict([token_ids, segment_ids])[:, -1]

    def response(self, texts, topk=5):
        """Generate a reply to the dialogue history ``texts``.

        Args:
            texts: list of utterances, oldest first; speakers alternate.
            topk: sample each next token from the top-k candidates.

        Returns:
            The decoded reply string.
        """
        token_ids, segment_ids = [self.tokenizer._token_start_id], [0]
        for i, text in enumerate(texts):
            # Drop the leading [CLS] of each encoded utterance; keep its [SEP].
            ids = self.tokenizer.encode(text)[0][1:]
            token_ids.extend(ids)
            segment_ids.extend([i % 2] * len(ids))
        results = self.random_sample([token_ids, segment_ids], 1, topk)
        return self.tokenizer.decode(results[0])


def build_chat_model(model_path, tokenizer):
    """Build the NEZHA LM model and wrap it in a ChatBot.

    Args:
        model_path: directory holding ``config.json`` and ``model.ckpt``
            (with or without a trailing path separator).
        tokenizer: bert4keras Tokenizer matching the model's vocab.

    Returns:
        A ready-to-use ChatBot instance (maxlen=32).
    """
    # NEZHA configuration / checkpoint paths. os.path.join tolerates a
    # missing trailing slash in model_path (the old f-string concatenation
    # produced e.g. '...dialogconfig.json' without one).
    config_path = os.path.join(model_path, 'config.json')
    checkpoint_path = os.path.join(model_path, 'model.ckpt')

    # Build the model and load the pretrained weights in LM (causal) mode.
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        model='nezha',
        application='lm',
    )
    chatbot = ChatBot(
        start_id=None,
        end_id=tokenizer._token_end_id,
        maxlen=32,
        model=model,
        tokenizer=tokenizer,
    )
    return chatbot


if __name__ == '__main__':
    # Vocab and weights live in the same directory; the trailing slash keeps
    # path concatenation inside build_chat_model correct either way.
    tokenizer = Tokenizer("../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
    chatbot = build_chat_model('../models/nezha_gpt_dialog/', tokenizer)
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))

