Modify response structure
This commit is contained in: parent 8c33ecc8ca, commit eedd13fbea
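For context on the change: every Flask handler in run.py below stops assigning raw model output to resp.response and instead returns a serialized JSON envelope. A minimal sketch of the two shapes a client can now expect, with field names taken from the diff and illustrative values (the exact payload depends on the model serving the request):

# Success: payload under "data", with "dtype" naming its type ("text" here).
{"code": 200, "data": "<model output>", "dtype": "text"}
# Failure: no payload, a human-readable "msg" instead.
{"code": 406, "msg": "Input is None, please check !"}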
@@ -3,6 +3,7 @@
# Test script
# Test environment: tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
import os

os.environ['TF_KERAS'] = "1"

import numpy as np

@@ -15,6 +16,7 @@ from bert4keras.snippets import AutoRegressiveDecoder
class ChatBot(AutoRegressiveDecoder):
    """Chatbot based on random sampling.
    """

    def __init__(self, start_id, end_id, maxlen, model, tokenizer):
        super().__init__(start_id, end_id, maxlen)
        self.model = model

@@ -53,11 +55,9 @@ def build_chat_model(model_path, tokenizer):
    chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32, model=model, tokenizer=tokenizer)
    return chatbot


if __name__ == '__main__':
    tokenizer = Tokenizer("../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
    chatbot = build_chat_model('../models/nezha_gpt_dialog', tokenizer)
    text_list = ['绿遍山原白满川', '子规声里雨如烟']
    print(chatbot.response(text_list))

@@ -1,6 +1,7 @@
import hanlp
from logzero import logger
from hanlp_common.document import Document

tok = hanlp.load('./.hanlp/tok/coarse_electra_small_20220616_012050/')
dep = hanlp.load('./.hanlp/dep/ctb9_dep_electra_small_20220216_100306/')
sts = hanlp.load('./.hanlp/sts/sts_electra_base_zh_20210530_200109/')

@@ -17,6 +18,7 @@ def text_analysis(text):
    logger.info(rst)
    return rst


def text_simi(src, tgt):
    score = sts([(src, tgt)])[0]
    result = ["negative", "positive"][round(score)]

@@ -25,4 +27,3 @@ def text_simi(src, tgt):

if __name__ == '__main__':
    print(text_analysis("台湾省是中国不可分割的一部分。"))

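A side note on text_simi above: the STS model returns a similarity score, and the label is picked by indexing with round(score). A minimal sketch of that mapping (illustrative scores; note that Python 3's round() uses banker's rounding, so exactly 0.5 maps to "negative"):

# Illustrative only: how round(score) selects the label in text_simi.
for score in (0.2, 0.5, 0.51, 0.9):
    print(score, ["negative", "positive"][round(score)])
# -> 0.2 negative, 0.5 negative, 0.51 positive, 0.9 positive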
@@ -3,6 +3,7 @@
# (Accuracy=0.7282149325820084 F1=0.8207266829447049 Final=0.7744708077633566)

import json, os, re

os.environ['TF_KERAS'] = '1'
import numpy as np
from bert4keras.backend import keras, K

@@ -16,6 +17,7 @@ from keras.models import Model
from tqdm import tqdm

import tensorflow as tf

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)

@@ -27,6 +29,7 @@ max_a_len = 32
batch_size = 32
epochs = 10


# # BERT config
# config_path = '../models/nezha_gpt/config.json'
# checkpoint_path = '../models/nezha_gpt/gpt.ckpt'

@@ -113,7 +116,6 @@ def masked_cross_entropy(y_true, y_pred):
    return cross_entropy



def build_reading_model(config_path: str, ckpt_path: str, keep_tokens: str, weight_path: str):
    model = build_transformer_model(
        config_path,

@@ -127,6 +129,7 @@ def build_reading_model(config_path: str, ckpt_path: str, keep_tokens: str, weight_path: str):
    model.load_weights(weight_path)
    return model


def get_ngram_set(x, n):
    """Build the n-gram lookup for x; the result has the form:
    {(n-1)-gram: set([nth token of each matching n-gram])}

@@ -221,7 +224,8 @@ if __name__ == '__main__':
    model = build_reading_model()
    model.load_weights('../models/qa/best_model.weights')
    questions = "嬴政出生在哪里?"
    passages = ["秦始皇嬴政(前259年—前210年),嬴姓,赵氏 ,名政(一说名“正”),又称赵政 、祖龙 ,也有吕政一说(详见“人物争议-姓名之争”目录)。秦庄襄王和赵姬之子。中国古代杰出的政治家、战略家、改革家,首次完成中国大一统的政治人物,也是中国第一个称皇帝的君主。",
    passages = [
        "秦始皇嬴政(前259年—前210年),嬴姓,赵氏 ,名政(一说名“正”),又称赵政 、祖龙 ,也有吕政一说(详见“人物争议-姓名之争”目录)。秦庄襄王和赵姬之子。中国古代杰出的政治家、战略家、改革家,首次完成中国大一统的政治人物,也是中国第一个称皇帝的君主。",
        "公元前221年,秦统一六国之后,秦王嬴政认为自己“德兼三皇,功过五帝”,遂采用三皇之“皇”、五帝之“帝”构成“皇帝”的称号,是中国历史上第一个使用“皇帝”称号的君主,所以自称“始皇帝”。",
        "秦始皇有二十余子。长子扶苏,少子胡亥。",
        "嬴政出生在当时赵国的邯郸廓城(在今城内中街以东,丛台西南的朱家巷一带),是当时的秦国王孙异人之子。"]

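To make the get_ngram_set contract above concrete, here is a minimal re-implementation of the documented mapping (hypothetical token ids; the actual function body lies outside this hunk):

def get_ngram_set_sketch(x, n):
    # Map each (n-1)-gram prefix to the set of tokens that follow it.
    result = {}
    for i in range(len(x) - n + 1):
        prefix = tuple(x[i:i + n - 1])
        result.setdefault(prefix, set()).add(x[i + n - 1])
    return result

print(get_ngram_set_sketch([5, 7, 9, 7, 8], 2))
# -> {(5,): {7}, (7,): {8, 9}, (9,): {7}}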
@@ -22,16 +22,15 @@ config.gpu_options.allow_growth = True  # allocate GPU memory on demand
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
tf.compat.v1.keras.backend.set_session(tf_session)


from sklearn.model_selection import train_test_split


set_gelu('tanh')  # switch gelu version

MAX_LEN = 128
BATCH_SIZE = 32
AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')


# config_path = '/home/zhaojh/pretrain_models/roberta_base/bert_config.json'
# checkpoint_path = '/home/zhaojh/pretrain_models/roberta_base/bert_model.ckpt'
# dict_path = '/home/zhaojh/pretrain_models/roberta_base/vocab.txt'

@@ -191,8 +190,10 @@ def load_cls_model(config_path: str, ckpt_path: str, weight_path: str):
    model.load_weights(weight_path)
    return model


if __name__ == '__main__':
    from keras.models import load_model

    dict_path = '../models/tokenizer/vocab.txt'
    tokenizer = Tokenizer(dict_path, do_lower_case=True)
    model = load_model('../models/text_classifier/best_model.h5')

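A caveat on the __main__ block above: the classifier is built with the custom AdamLR optimizer (extend_with_piecewise_linear_lr), so deserializing the .h5 checkpoint with load_model usually requires registering it via custom_objects. A hedged sketch, assuming the checkpoint was saved with the optimizer state:

# Assumption: AdamLR must be registered for Keras to deserialize the checkpoint.
model = load_model('../models/text_classifier/best_model.h5',
                   custom_objects={'AdamLR': AdamLR})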
@@ -1,5 +1,6 @@
#! -*- coding: utf-8 -*-
import os

os.environ["TF_KERAS"] = "1"
import numpy as np
from bert4keras.models import build_transformer_model

@@ -10,6 +11,7 @@ from bert4keras.snippets import AutoRegressiveDecoder
class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation based on random sampling.
    """

    def __init__(self, start_id, end_id, maxlen, minlen, config_path, ckpt_path):
        super().__init__(start_id, end_id, maxlen, minlen)
        self.model = build_transformer_model(

run.py (110 changed lines)
@@ -4,11 +4,11 @@ import sys
import json
from flask import Flask, request, make_response
from logzero import logger

# current_path = os.path.dirname(os.path.abspath(__file__))  # for local
current_path = "/app"  # for docker
logger.info(f"{current_path}")


app = Flask(__name__)

sys.path.append(f"{current_path}/nlp/")

@@ -16,7 +16,6 @@ os.environ["TF_KERAS"] = "1"
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


from bert4keras.tokenizers import Tokenizer, load_vocab
from keras.models import load_model

@@ -27,7 +26,6 @@ from nlp.text_classification import run_cls as run_class, AdamLR
from nlp.hanlp_tools import text_analysis, text_simi
from nlp.reading import build_reading_model, gen_answer


general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True)  # general-purpose tokenizer
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)  # dialogue tokenizer

@@ -66,49 +64,52 @@ reading_model = build_reading_model(
    f"{current_path}/models/qa/best_model.weights"
)

TEXT = "text"

app = Flask(__name__)


@app.route('/text_cls/', methods=["POST"])
def run_cls():
    resp = make_response()
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is not None and text != '':
            resp.response = run_class(cls_model, general_tokenizer, text)
            rst = run_class(cls_model, general_tokenizer, text)
            resp_info["code"] = 200
            resp_info["data"] = rst
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
        else:
            resp.status_code = 406
            return resp
    else:
        resp.status_code = 405
        return resp


@app.route('/text_gen/', methods=["POST"])
def run_gen():
    resp = make_response()
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        logger.info(f"Generating a continuation for text '{text}'")
        if text != "":
            rest = gen_model.generate(text, general_tokenizer)
            logger.info(rest)
            resp.response = rest
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
        else:
            resp.status_code = 406
            return resp
    else:
        resp.status_code = 405
        return resp


@app.route('/chat/', methods=["POST"])
def run_chat():
    # todo: this module could be served as a streaming gRPC service instead.
    # If sticking with Flask, pass the whole dialogue history in as a list;
    # the history can be sent as JSON.
    resp = make_response()
    resp_info = dict()
    if request.method == "POST":
        dialog_history = request.form.get("dialog_history")
        dialog_history = dialog_history.split('。')

@@ -116,15 +117,16 @@ def run_chat():
        if len(dialog_history) > 0:
            rest = chatbot.response(dialog_history)
            logger.info(rest)
            resp.response = rest
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Please separate the dialogue history with Chinese full stops (。)"
            resp_info["code"] = 406
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
        else:
            resp.status_code = 406
            return resp
    else:
        resp.status_code = 405
        return resp


# @app.route('/translate/', methods=["POST"])
# def run_translate():

@@ -144,48 +146,56 @@ def run_chat():

@app.route('/simi/', methods=["POST"])
def run_match():
    resp = make_response()
    resp_info = dict()

    if request.method == "POST":
        src = request.form.get('text_1')
        tgt = request.form.get('text_2')
        resp.response = str(text_simi(src, tgt))
        resp_info["code"] = 200
        resp_info["data"] = str(text_simi(src, tgt))
        resp_info["dtype"] = TEXT
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
    else:
        resp.status_code = 405
        return resp


@app.route('/dependency/', methods=["POST"])
def run_depend():
    resp = make_response()
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is None or text.strip() == "":
            resp.status_code = 406
            return resp
        resp.response = str(text_analysis(text))
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
        else:
            resp_info["code"] = 200
            resp_info["data"] = str(text_analysis(text))
            resp_info["dtype"] = TEXT
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
    else:
        resp.status_code = 405
        return resp


@app.route('/reading/', methods=["POST"])
def run_reading():
    resp = make_response()
    resp_info = dict()
    if request.method == "POST":
        question = request.form.get("question")
        passages = request.form.get("passages")
        passages = [x + '。' for x in passages.split('。')]
        if len(passages) == 0 or question is None or question.strip() == "":
            resp.status_code = 406
            return resp
        resp.response = json.dumps(gen_answer(question, passages, reading_model, reading_tokenizer))
        resp.status_code = 200
        return resp
        if question is None or question.strip() == "":
            resp_info["msg"] = "Question is None, please check!"
            resp_info["code"] = 406
        elif len(passages) == 0:
            resp_info["msg"] = "Please separate the passages with Chinese full stops (。)"
            resp_info["code"] = 406
        else:
            resp.status_code = 405
            rest = gen_answer(question, passages, reading_model, reading_tokenizer)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        resp = make_response(json.dumps(resp_info))
        resp.status_code = 200
        return resp
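Finally, a hypothetical smoke test for the reworked /reading/ endpoint (assumes the requests package and that the app is listening on Flask's default localhost:5000; neither is part of this commit). Note that run_reading splits passages on the Chinese full stop:

import requests

payload = {
    "question": "嬴政出生在哪里?",
    "passages": "嬴政出生在当时赵国的邯郸廓城。秦始皇有二十余子。",
}
r = requests.post("http://localhost:5000/reading/", data=payload)
print(r.json())  # expected shape: {"code": 200, "data": "...", "dtype": "text"}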