ai_platform_nlu/nlp/chat.py


#! -*- coding: utf-8 -*-
# Chit-chat (open-domain dialogue) task with the NEZHA model
# Inference/test script
# Tested with tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
import os
os.environ['TF_KERAS'] = "1"
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder


class ChatBot(AutoRegressiveDecoder):
    """Dialogue bot based on random-sampling decoding.
    """

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        super().__init__(start_id, end_id, maxlen)
        self.model = model
        self.tokenizer = tokenizer

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        # Append the tokens generated so far to the input context.
        token_ids = np.concatenate([token_ids, output_ids], 1)
        # The reply gets the opposite segment id of the last input token.
        curr_segment_ids = np.ones_like(output_ids) - segment_ids[0, -1]
        segment_ids = np.concatenate([segment_ids, curr_segment_ids], 1)
        # Keep only the probability distribution of the next token.
        return self.model.predict([token_ids, segment_ids])[:, -1]

    def response(self, texts, topk=5):
        token_ids, segment_ids = [self.tokenizer._token_start_id], [0]
        for i, text in enumerate(texts):
            # Drop the leading [CLS] of each utterance before appending it.
            ids = self.tokenizer.encode(text)[0][1:]
            token_ids.extend(ids)
            # Alternate segment ids (0/1) between the two speakers.
            segment_ids.extend([i % 2] * len(ids))
        results = self.random_sample([token_ids, segment_ids], 1, topk)
        return self.tokenizer.decode(results[0])


def build_chat_model(model_path, tokenizer):
    # NEZHA configuration and checkpoint paths.
    config_path = os.path.join(model_path, 'config.json')
    checkpoint_path = os.path.join(model_path, 'model.ckpt')
    # Build the NEZHA language model and load the pretrained weights.
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        model='nezha',
        application='lm',
    )
    chatbot = ChatBot(
        start_id=None,
        end_id=tokenizer._token_end_id,
        maxlen=32,
        model=model,
        tokenizer=tokenizer,
    )
    return chatbot


if __name__ == '__main__':
    tokenizer = Tokenizer("../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
    chatbot = build_chat_model('../models/nezha_gpt_dialog', tokenizer)
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))
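

# Usage sketch (not in the original file): a minimal interactive loop built on
# ChatBot.response above. The helper name chat_loop is hypothetical, and it
# assumes the model files referenced in __main__ exist locally. The growing
# dialogue history is passed on every turn so that response() keeps alternating
# segment ids between the two speakers.
def chat_loop(chatbot):
    history = []
    while True:
        user_text = input('user: ').strip()
        if not user_text:  # empty input ends the session
            break
        history.append(user_text)
        reply = chatbot.response(history)
        history.append(reply)
        print('bot:', reply)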