#! -*- coding: utf-8 -*-
# NEZHA model for chit-chat dialogue -- inference script.
# Tested with: tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
import os

# Must be set before bert4keras is imported so it uses tf.keras.
os.environ['TF_KERAS'] = "1"

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder


class ChatBot(AutoRegressiveDecoder):
    """Dialogue bot decoding via random (top-k) sampling."""

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        super().__init__(start_id, end_id, maxlen)
        self.model = model
        self.tokenizer = tokenizer

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Return next-token probabilities for the current decoding step."""
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        # Generated tokens take the segment id opposite to the last input
        # segment, continuing the alternating-speaker scheme built in
        # `response` (segments are 0/1, so 1 - last flips them).
        curr_segment_ids = np.ones_like(output_ids) - segment_ids[0, -1]
        segment_ids = np.concatenate([segment_ids, curr_segment_ids], 1)
        return self.model.predict([token_ids, segment_ids])[:, -1]

    def response(self, texts, topk=5):
        """Generate a reply given the dialogue history.

        Args:
            texts: list of utterances, alternating between the two speakers.
            topk: sample from the top-k most probable tokens at each step.

        Returns:
            The decoded reply string.
        """
        token_ids, segment_ids = [self.tokenizer._token_start_id], [0]
        for i, text in enumerate(texts):
            # encode() returns (token_ids, segment_ids); drop the leading
            # [CLS] of each utterance and alternate segment ids per turn.
            ids = self.tokenizer.encode(text)[0][1:]
            token_ids.extend(ids)
            segment_ids.extend([i % 2] * len(ids))
        results = self.random_sample([token_ids, segment_ids], 1, topk)
        return self.tokenizer.decode(results[0])


def build_chat_model(model_path, tokenizer, maxlen=32):
    """Build the NEZHA language model and wrap it in a ChatBot.

    Args:
        model_path: directory containing ``config.json`` and ``model.ckpt``
            (trailing slash optional).
        tokenizer: a bert4keras ``Tokenizer`` instance.
        maxlen: maximum number of tokens to generate (default 32, matching
            the original hard-coded value).

    Returns:
        A ready-to-use ``ChatBot``.
    """
    # BUG FIX: the original used f'{model_path}config.json' (bare string
    # concatenation), which produced a wrong path whenever model_path lacked
    # a trailing slash -- exactly how __main__ below calls this function.
    # os.path.join is correct in both cases.
    config_path = os.path.join(model_path, 'config.json')
    checkpoint_path = os.path.join(model_path, 'model.ckpt')

    # Build the NEZHA transformer with a unidirectional LM head and load
    # the pretrained checkpoint.
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        model='nezha',
        application='lm',
    )

    chatbot = ChatBot(
        start_id=None,
        end_id=tokenizer._token_end_id,
        maxlen=maxlen,
        model=model,
        tokenizer=tokenizer,
    )
    return chatbot


if __name__ == '__main__':
    tokenizer = Tokenizer("../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
    chatbot = build_chat_model('../models/nezha_gpt_dialog', tokenizer)
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))