# -*- coding: utf-8 -*-
import os
import sys
import json

from flask import Flask, request, make_response
from logzero import logger

# current_path = os.path.dirname(os.path.abspath(__file__))  # for local
current_path = "/app"  # for docker
logger.info(f"{current_path}")

app = Flask(__name__)

sys.path.append(f"{current_path}/nlp/")

# These flags must be set before bert4keras / TensorFlow are imported below
os.environ["TF_KERAS"] = "1"               # tell bert4keras to use tf.keras
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide all GPUs: run inference on CPU

from bert4keras.tokenizers import Tokenizer, load_vocab
from keras.models import load_model

from nlp.text_gen import ArticleCompletion
from nlp.chat import build_chat_model
from nlp.text_classification import run_cls as run_class, AdamLR
# from utils.translate import load_model as load_translator, run_test as run_translator
from nlp.hanlp_tools import text_analysis, text_simi
from nlp.reading import build_reading_model, gen_answer

general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True)  # general-purpose tokenizer
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)  # dialogue tokenizer

# Load and simplify the vocabulary to build the reading-comprehension vocab
token_dict, keep_tokens = load_vocab(
    dict_path=f"{current_path}/models/tokenizer/vocab.txt",
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
)
reading_tokenizer = Tokenizer(token_dict, do_lower_case=True)  # reading-comprehension tokenizer

cls_model = load_model(f"{current_path}/models/text_classifier/best_model.h5")  # load the text-classification model

# Load the text-generation model
gen_model = ArticleCompletion(
    start_id=None,
    end_id=511,  # 511 is the token id of the Chinese full stop
    maxlen=256,
    minlen=128,
    config_path=f"{current_path}/models/nezha_gpt/config.json",
    ckpt_path=f"{current_path}/models/nezha_gpt/gpt.ckpt"
)

# Load the dialogue model
chatbot = build_chat_model(
    f"{current_path}/models/nezha_gpt_dialog/",
    dialog_tokenizer)

# translator, trans_data = load_translator(
#     # Load the translation model
#     f"{current_path}/models/translator/translation.h5",
#     f"{current_path}/data/translator/train.txt",
#     f"{current_path}/data/translator/dev.txt"
# )

# Load the reading-comprehension model
reading_model = build_reading_model(
    f"{current_path}/models/nezha_gpt/config.json",
    f"{current_path}/models/nezha_gpt/gpt.ckpt",
    keep_tokens,
    f"{current_path}/models/qa/best_model.weights"
)

TEXT = "text"  # dtype tag used in all JSON responses


@app.route('/text_cls/', methods=["POST"])
def run_cls():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is not None and text != '':
            rst = run_class(cls_model, general_tokenizer, text)
            resp_info["code"] = 200
            resp_info["data"] = rst
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
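
# Illustrative request for the /text_cls/ endpoint (the localhost address is an assumption;
# port 8903 and the "text" form field come from the code in this file):
#   curl -X POST -F "text=这部电影的剧情非常精彩" http://localhost:8903/text_cls/
# Expected JSON shape: {"code": 200, "data": <classifier output>, "dtype": "text"}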


@app.route('/text_gen/', methods=["POST"])
def run_gen():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        logger.info(f"将对文本'{text}'进行续写")
        if text is not None and text != "":
            rest = gen_model.generate(text, general_tokenizer)
            logger.info(rest)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
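
# Illustrative request for the /text_gen/ continuation endpoint (localhost address is an
# assumption; port 8903 and the "text" form field match the code above):
#   curl -X POST -F "text=今天天气很好,我们" http://localhost:8903/text_gen/
# The generated continuation is returned in the "data" field of the JSON response.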


@app.route('/chat/', methods=["POST"])
def run_chat():
    resp_info = dict()
    if request.method == "POST":
        dialog_history = request.form.get("dialog_history") or ""
        # Split the dialogue history on the Chinese full stop and drop empty segments
        dialog_history = [seg.strip() for seg in dialog_history.split('。') if seg.strip()]
        logger.info(f"将对文本'{dialog_history}'进行对话")
        if len(dialog_history) > 0:
            rest = chatbot.response(dialog_history)
            logger.info(rest)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "请将历史对话以中文句号分隔"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
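
# Illustrative request for the /chat/ endpoint. The handler expects the dialogue history as a
# single "dialog_history" form field with turns separated by the Chinese full stop (。):
#   curl -X POST -F "dialog_history=你好。你叫什么名字。" http://localhost:8903/chat/
# The chatbot's reply is returned in the "data" field.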


# @app.route('/translate/', methods=["POST"])
# def run_translate():
#     resp = make_response()
#     if request.method == "POST":
#         text = request.json.get('text')
#         if text is None or text.strip() == "":
#             resp.status_code = 406
#             return resp
#         rest = run_translator(text, translator, trans_data)
#         resp.status_code = 200
#         resp.response = rest
#         return resp
#     else:
#         resp.status_code = 405
#         return resp


@app.route('/simi/', methods=["POST"])
def run_match():
    resp_info = dict()
    if request.method == "POST":
        src = request.form.get('text_1')
        tgt = request.form.get('text_2')
        resp_info["code"] = 200
        resp_info["data"] = str(text_simi(src, tgt))
        resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
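
# Illustrative request for the /simi/ similarity endpoint (form fields "text_1" and "text_2"
# as read by the handler above; the localhost address is an assumption):
#   curl -X POST -F "text_1=我喜欢看电影" -F "text_2=我爱看电影" http://localhost:8903/simi/
# The similarity score computed by text_simi is returned as a string in the "data" field.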


@app.route('/dependency/', methods=["POST"])
def run_depend():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is None or text.strip() == "":
            resp_info["msg"] = "Input is None, please check !"
            resp_info["code"] = 406
        else:
            resp_info["code"] = 200
            resp_info["data"] = str(text_analysis(text))
            resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
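
# Illustrative request for the /dependency/ analysis endpoint (the "text" form field matches
# the handler above; the localhost address is an assumption):
#   curl -X POST -F "text=小明今天去北京出差" http://localhost:8903/dependency/
# The result of text_analysis is stringified and returned in the "data" field.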


@app.route('/reading/', methods=["POST"])
def run_reading():
    resp_info = dict()
    if request.method == "POST":
        question = request.form.get("question")
        passages = request.form.get("passages") or ""
        # Split the passages on the Chinese full stop and drop empty segments
        passages = [x.strip() + '。' for x in passages.split('。') if x.strip()]
        if question is None or question.strip() == "":
            resp_info["msg"] = "Question is None, please check!"
            resp_info["code"] = 406
        elif len(passages) == 0:
            resp_info["msg"] = "请将文章段落以中文句号分隔"
            resp_info["code"] = 406
        else:
            rest = gen_answer(question, passages, reading_model, reading_tokenizer)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
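
# Illustrative request for the /reading/ comprehension endpoint. Passages are passed as one
# "passages" form field and are split on the Chinese full stop (。) by the handler above:
#   curl -X POST \
#     -F "question=小明去了哪里" \
#     -F "passages=小明今天去了北京。他参加了一个会议。" \
#     http://localhost:8903/reading/
# The generated answer is returned in the "data" field.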


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8903, debug=False)
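
# Note: app.run() above starts Flask's built-in development server on port 8903. For production,
# the app could instead be served by a WSGI server, e.g. (assuming this file is saved as app.py):
#   gunicorn -w 1 -b 0.0.0.0:8903 app:app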