ai_platform_nlu/run.py

# -*-coding:utf-8-*-
import os
import sys
import json
from flask import Flask, request, make_response
from logzero import logger
# current_path = os.path.dirname(os.path.abspath(__file__)) # for local
current_path = "/app" # for docker
logger.info(f"{current_path}")
sys.path.append(f"{current_path}/nlp/")
os.environ["TF_KERAS"] = "1"
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from bert4keras.tokenizers import Tokenizer, load_vocab
from keras.models import load_model
from nlp.text_gen import ArticleCompletion
from nlp.chat import build_chat_model
from nlp.text_classification import run_cls as run_class, AdamLR
# from utils.translate import load_model as load_translator, run_test as run_translator
from nlp.hanlp_tools import text_analysis, text_simi
from nlp.reading import build_reading_model, gen_answer

general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True)  # general-purpose tokenizer
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)  # dialogue tokenizer

token_dict, keep_tokens = load_vocab(
    # load and prune the vocabulary to build the reading-comprehension vocab
    dict_path=f"{current_path}/models/tokenizer/vocab.txt",
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
)

reading_tokenizer = Tokenizer(token_dict, do_lower_case=True)  # reading-comprehension tokenizer
cls_model = load_model(f"{current_path}/models/text_classifier/best_model.h5")  # load the text-classification model

gen_model = ArticleCompletion(
    # load the text-generation model
    start_id=None,
    end_id=511,  # 511 is the token id of the Chinese full stop (。)
    maxlen=256,
    minlen=128,
    config_path=f"{current_path}/models/nezha_gpt/config.json",
    ckpt_path=f"{current_path}/models/nezha_gpt/gpt.ckpt"
)
chatbot = build_chat_model(
    # load the dialogue model
    f"{current_path}/models/nezha_gpt_dialog/",
    dialog_tokenizer)
# translator, trans_data = load_translator(
#     # load the translation model
#     f"{current_path}/models/translator/translation.h5",
#     f"{current_path}/data/translator/train.txt",
#     f"{current_path}/data/translator/dev.txt"
# )
reading_model = build_reading_model(
    # load the reading-comprehension model
    f"{current_path}/models/nezha_gpt/config.json",
    f"{current_path}/models/nezha_gpt/gpt.ckpt",
    keep_tokens,
    f"{current_path}/models/qa/best_model.weights"
)

TEXT = "text"

app = Flask(__name__)


@app.route('/text_cls/', methods=["POST"])
def run_cls():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is not None and text != '':
            rst = run_class(cls_model, general_tokenizer, text)
            resp_info["code"] = 200
            resp_info["data"] = rst
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check!"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
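
# A minimal client sketch for the endpoint above (assumption: the service is
# reachable at http://localhost:8903, matching app.run() at the bottom of this
# file); kept as comments so it does not run at import time:
#
#   import requests
#   r = requests.post("http://localhost:8903/text_cls/", data={"text": "这部电影的剧情非常精彩"})
#   print(r.json())  # e.g. {"code": 200, "data": ..., "dtype": "text"}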


@app.route('/text_gen/', methods=["POST"])
def run_gen():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        logger.info(f"Generating a continuation for text '{text}'")
        if text is not None and text != "":
            rest = gen_model.generate(text, general_tokenizer)
            logger.info(rest)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Input is None, please check!"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp


@app.route('/chat/', methods=["POST"])
def run_chat():
    resp_info = dict()
    if request.method == "POST":
        dialog_history = request.form.get("dialog_history")
        # split the history into turns on the Chinese full stop, dropping empty segments
        dialog_history = [x for x in (dialog_history or '').rstrip('。').split('。') if x.strip()]
        logger.info(f"Generating a reply for dialogue history '{dialog_history}'")
        if len(dialog_history) > 0:
            rest = chatbot.response(dialog_history)
            logger.info(rest)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
        else:
            resp_info["msg"] = "Please separate the dialogue history with the Chinese full stop (。)"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
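
# A minimal client sketch for /chat/ (assumption: service at
# http://localhost:8903); the dialogue history is sent as one string with
# turns separated by the Chinese full stop (。), matching the parsing above:
#
#   import requests
#   history = "你好。你叫什么名字"
#   r = requests.post("http://localhost:8903/chat/", data={"dialog_history": history})
#   print(r.json())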


# @app.route('/translate/', methods=["POST"])
# def run_translate():
#     resp = make_response()
#     if request.method == "POST":
#         text = request.json.get('text')
#         if text is None or text.strip() == "":
#             resp.status_code = 406
#             return resp
#         rest = run_translator(text, translator, trans_data)
#         resp.status_code = 200
#         resp.response = rest
#         return resp
#     else:
#         resp.status_code = 405
#         return resp


@app.route('/simi/', methods=["POST"])
def run_match():
    resp_info = dict()
    if request.method == "POST":
        src = request.form.get('text_1')
        tgt = request.form.get('text_2')
        if src is None or tgt is None:
            resp_info["msg"] = "Input is None, please check!"
            resp_info["code"] = 406
        else:
            resp_info["code"] = 200
            resp_info["data"] = str(text_simi(src, tgt))
            resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
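
# A minimal client sketch for /simi/ (assumption: service at
# http://localhost:8903); the two sentences to compare go in the form fields
# text_1 and text_2:
#
#   import requests
#   r = requests.post("http://localhost:8903/simi/",
#                     data={"text_1": "今天天气很好", "text_2": "今天天气不错"})
#   print(r.json())  # "data" holds the result of text_simi as a string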


@app.route('/dependency/', methods=["POST"])
def run_depend():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        if text is None or text.strip() == "":
            resp_info["msg"] = "Input is None, please check!"
            resp_info["code"] = 406
        else:
            resp_info["code"] = 200
            resp_info["data"] = str(text_analysis(text))
            resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp


@app.route('/reading/', methods=["POST"])
def run_reading():
    resp_info = dict()
    if request.method == "POST":
        question = request.form.get("question")
        passages = request.form.get("passages")
        # split the input into passages on the Chinese full stop, dropping empty segments
        passages = [x.strip() + '。' for x in (passages or '').split('。') if x.strip()]
        if question is None or question.strip() == "":
            resp_info["msg"] = "Question is None, please check!"
            resp_info["code"] = 406
        elif len(passages) == 0:
            resp_info["msg"] = "Please separate the passages with the Chinese full stop (。)"
            resp_info["code"] = 406
        else:
            rest = gen_answer(question, passages, reading_model, reading_tokenizer)
            resp_info["code"] = 200
            resp_info["data"] = rest
            resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
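
# A minimal client sketch for /reading/ (assumption: service at
# http://localhost:8903); the passages are sent as one string whose sentences
# are separated by the Chinese full stop (。), matching the parsing above:
#
#   import requests
#   payload = {
#       "question": "珠穆朗玛峰有多高",
#       "passages": "珠穆朗玛峰是世界最高峰。它的海拔约为8848米",
#   }
#   r = requests.post("http://localhost:8903/reading/", data=payload)
#   print(r.json())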


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8903, debug=False)