ai_platform_nlu/nlp/chat.py

64 lines
2.2 KiB
Python
Raw Permalink Normal View History

2022-12-07 10:49:21 +08:00
#! -*- coding: utf-8 -*-
# Chit-chat (open-domain dialogue) task with the NEZHA model
# Test script
# Tested with tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
import os
2022-12-08 15:16:57 +08:00
2022-12-07 10:49:21 +08:00
# Set before the bert4keras imports below. NOTE(review): presumably bert4keras
# reads TF_KERAS at import time to select tf.keras — confirm against its README.
os.environ['TF_KERAS'] = "1"
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
class ChatBot(AutoRegressiveDecoder):
    """Dialogue bot that produces replies by random sampling.

    Holds a language model and a tokenizer; ``response`` encodes a dialogue
    history and draws one reply through the base class's ``random_sample``.
    """

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        """Forward the decoding settings to the base class and keep the model/tokenizer.

        Args:
            start_id: Passed unchanged to ``AutoRegressiveDecoder.__init__``.
            end_id: Passed unchanged to ``AutoRegressiveDecoder.__init__``.
            maxlen: Passed unchanged to ``AutoRegressiveDecoder.__init__``.
            model: Object whose ``predict`` is called in ``predict`` below.
            tokenizer: Object used by ``response`` to encode and decode text.
        """
        super().__init__(start_id, end_id, maxlen)
        self.tokenizer = tokenizer
        self.model = model

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Run the model on context + generated tokens and return the last step.

        Args:
            inputs: Pair ``(token_ids, segment_ids)`` of 2-D arrays for the context.
            output_ids: 2-D array of the tokens generated so far.
            states: Unused here; part of the decoder's predict signature.

        Returns:
            The model output sliced at the last position along axis 1.
        """
        context_tokens, context_segments = inputs
        # Generated tokens get segment id (1 - last context segment id), taken
        # from the first row of the context segments.
        reply_segments = np.ones_like(output_ids) - context_segments[0, -1]
        full_tokens = np.concatenate([context_tokens, output_ids], axis=1)
        full_segments = np.concatenate([context_segments, reply_segments], axis=1)
        scores = self.model.predict([full_tokens, full_segments])
        return scores[:, -1]

    def response(self, texts, topk=5):
        """Sample one reply for a dialogue history.

        Args:
            texts: Iterable of utterance strings, in order.
            topk: Forwarded as the third positional argument of ``random_sample``.

        Returns:
            The decoded text of the first sampled result.
        """
        # The sequence opens with a single start token in segment 0.
        token_ids = [self.tokenizer._token_start_id]
        segment_ids = [0]
        for turn, utterance in enumerate(texts):
            encoded = self.tokenizer.encode(utterance)
            # Drop the first id of each encoding (presumably the per-text start
            # token — verify against the tokenizer); it was added once above.
            ids = encoded[0][1:]
            token_ids += ids
            # Segment ids alternate 0/1 with the utterance index.
            segment_ids += [turn % 2] * len(ids)
        samples = self.random_sample([token_ids, segment_ids], 1, topk)
        return self.tokenizer.decode(samples[0])
def build_chat_model(model_path, tokenizer):
    """Build the NEZHA language model and wrap it in a ChatBot.

    Args:
        model_path: Directory holding ``config.json`` and ``model.ckpt``.
            A trailing path separator is accepted but not required.
        tokenizer: Tokenizer handed to the ChatBot; its ``_token_end_id`` is
            used as the decoder's end id.

    Returns:
        A ChatBot created with ``start_id=None`` and ``maxlen=32``.
    """
    # NEZHA configuration. os.path.join adds the separator only when it is
    # missing, so 'dir' and 'dir/' resolve to the same files (plain string
    # concatenation produced 'dirconfig.json' for the former).
    config_path = os.path.join(model_path, 'config.json')
    checkpoint_path = os.path.join(model_path, 'model.ckpt')
    # Build the model and load the checkpoint weights.
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        model='nezha',
        application='lm',
    )
    return ChatBot(
        start_id=None,
        end_id=tokenizer._token_end_id,
        maxlen=32,
        model=model,
        tokenizer=tokenizer,
    )
2022-12-08 15:16:57 +08:00
2022-12-07 10:49:21 +08:00
if __name__ == '__main__':
    # Demo: load the dialogue model and print one sampled reply.
    # Keep the trailing separator so the model file names resolve inside this
    # directory rather than being glued onto the directory name.
    model_dir = '../models/nezha_gpt_dialog/'
    tokenizer = Tokenizer(os.path.join(model_dir, 'vocab.txt'), do_lower_case=True)
    chatbot = build_chat_model(model_dir, tokenizer)
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))