# ai_platform_nlu/run.py
# -*-coding:utf-8-*-
"""NLU inference server: loads tokenizers and models once at import time,
then serves classification, generation, chat, similarity, dependency-parsing
and reading-comprehension endpoints over HTTP (Flask)."""
import os
import sys
import json
from flask import Flask, request, make_response
from logzero import logger

# current_path = os.path.dirname(os.path.abspath(__file__)) # for local
current_path = "/app"  # for docker
logger.info(f"{current_path}")
app = Flask(__name__)
sys.path.append(f"{current_path}/nlp/")
os.environ["TF_KERAS"] = "1"
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force CPU-only inference
# These imports must come after the env vars / sys.path tweaks above.
from bert4keras.tokenizers import Tokenizer, load_vocab
from keras.models import load_model
from nlp.text_gen import ArticleCompletion
from nlp.chat import build_chat_model
from nlp.text_classification import run_cls as run_class, AdamLR
# from utils.translate import load_model as load_translator, run_test as run_translator
from nlp.hanlp_tools import text_analysis, text_simi
from nlp.reading import build_reading_model, gen_answer

# General-purpose tokenizer
general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True)
# Dialogue tokenizer
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
# Load and simplify the vocabulary, building the reading-comprehension vocab
token_dict, keep_tokens = load_vocab(
    dict_path=f"{current_path}/models/tokenizer/vocab.txt",
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
)
reading_tokenizer = Tokenizer(token_dict, do_lower_case=True)  # reading-comprehension tokenizer
# Load the text-classification model
cls_model = load_model(f"{current_path}/models/text_classifier/best_model.h5")
# Load the text-generation model
gen_model = ArticleCompletion(
    start_id=None,
    end_id=511,  # 511 is the Chinese full stop "。"
    maxlen=256,
    minlen=128,
    config_path=f"{current_path}/models/nezha_gpt/config.json",
    ckpt_path=f"{current_path}/models/nezha_gpt/gpt.ckpt"
)
# Load the dialogue model
chatbot = build_chat_model(
    f"{current_path}/models/nezha_gpt_dialog/",
    dialog_tokenizer)
# translator, trans_data = load_translator(
# # load the translation model
# f"{current_path}/models/translator/translation.h5",
# f"{current_path}/data/translator/train.txt",
# f"{current_path}/data/translator/dev.txt"
# )
# Load the reading-comprehension model
reading_model = build_reading_model(
    f"{current_path}/models/nezha_gpt/config.json",
    f"{current_path}/models/nezha_gpt/gpt.ckpt",
    keep_tokens,
    f"{current_path}/models/qa/best_model.weights"
)
# NOTE(review): the original created a second Flask(__name__) here, silently
# discarding the first instance; the duplicate has been removed so routes
# register on the one app configured above.
@app.route('/text_cls/', methods=["POST"])
def run_cls():
    """Classify the ``text`` form field.

    Returns 200 with the classifier output, 406 on missing/empty text,
    405 for any non-POST method.
    """
    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp
    text = request.form.get('text')
    if text is None or text == '':
        resp.status_code = 406
        return resp
    resp.response = run_class(cls_model, general_tokenizer, text)
    resp.status_code = 200
    return resp
@app.route('/text_gen/', methods=["POST"])
def run_gen():
    """Continue/complete the ``text`` form field with the generation model.

    Returns 200 with the generated continuation, 406 on missing/empty text,
    405 for any non-POST method.
    """
    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp
    text = request.form.get('text')
    logger.info(f"将对文本'{text}'进行续写")
    # Fix: the original only tested `text != ""`, so a missing form field
    # (None) fell through to gen_model.generate(None, ...) and produced a
    # 500 instead of the intended 406 (run_cls already guards None).
    if text is None or text == "":
        resp.status_code = 406
        return resp
    rest = gen_model.generate(text, general_tokenizer)
    logger.info(rest)
    resp.response = rest
    resp.status_code = 200
    return resp
@app.route('/chat/', methods=["POST"])
def run_chat():
    """Generate a chatbot reply for a delimited dialogue history.

    Returns 200 with the reply, 406 on missing/empty history, 405 otherwise.
    """
    # todo: this module could be served as a streaming gRPC service instead;
    # with flask the whole dialogue history is passed in per request
    # (it could also be passed as JSON).
    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp
    dialog_history = request.form.get("dialog_history")
    # Guard before splitting: a missing field (None) would otherwise crash.
    if dialog_history is None or dialog_history == "":
        resp.status_code = 406
        return resp
    # Fix: the original called split('') which raises
    # "ValueError: empty separator" on every request — the delimiter
    # character was evidently lost. '\t' is assumed here as the utterance
    # separator — TODO(review): confirm against the client protocol.
    dialog_history = dialog_history.split('\t')
    logger.info(f"将对文本'{dialog_history}'进行对话")
    rest = chatbot.response(dialog_history)
    logger.info(rest)
    resp.response = rest
    resp.status_code = 200
    return resp
# @app.route('/translate/', methods=["POST"])
# def run_translate():
# resp = make_response()
# if request.method == "POST":
# text = request.json.get('text')
# if text is None or text.strip() == "":
# resp.status_code = 406
# return resp
# rest = run_translator(text, translator, trans_data)
# resp.status_code = 200
# resp.response = rest
# return resp
# else:
# resp.status_code=405
# return resp
@app.route('/simi/', methods=["POST"])
def run_match():
    """Compute similarity between the ``text_1`` and ``text_2`` form fields.

    Returns 200 with the similarity score as a string, 406 on missing/blank
    input, 405 for any non-POST method.
    """
    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp
    src = request.form.get('text_1')
    tgt = request.form.get('text_2')
    # Fix: the original passed unvalidated (possibly None) fields straight
    # into text_simi, yielding a 500 on bad input; every sibling endpoint
    # answers 406 instead — made consistent with /dependency/.
    if src is None or src.strip() == "" or tgt is None or tgt.strip() == "":
        resp.status_code = 406
        return resp
    resp.response = str(text_simi(src, tgt))
    resp.status_code = 200
    return resp
@app.route('/dependency/', methods=["POST"])
def run_depend():
    """Run dependency analysis over the ``text`` form field.

    Returns 200 with the analysis result as a string, 406 on missing/blank
    text, 405 for any non-POST method.
    """
    resp = make_response()
    if request.method == "POST":
        text = request.form.get('text')
        has_text = text is not None and text.strip() != ""
        if has_text:
            resp.response = str(text_analysis(text))
            resp.status_code = 200
        else:
            resp.status_code = 406
        return resp
    resp.status_code = 405
    return resp
@app.route('/reading/', methods=["POST"])
def run_reading():
    """Answer ``question`` against the delimited ``passages`` form field.

    Returns 200 with a JSON answer, 406 on missing/blank question or
    passages, 405 for any non-POST method.
    """
    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp
    question = request.form.get("question")
    passages = request.form.get("passages")
    # Validate BEFORE splitting: the original checked len(passages) == 0
    # after str.split, which can never be an empty list, and crashed on a
    # missing (None) field.
    if passages is None or passages.strip() == "" or question is None or question.strip() == "":
        resp.status_code = 406
        return resp
    # Fix: the original called split('') which raises "ValueError: empty
    # separator" on every request. The lost delimiter was presumably the
    # Chinese full stop '。' (the re-append pattern and end_id=511 "Chinese
    # full stop" both point to it) — TODO(review): confirm with the client.
    passages = [x + '。' for x in passages.split('。') if x]
    resp.response = json.dumps(gen_answer(question, passages, reading_model, reading_tokenizer))
    resp.status_code = 200
    return resp
if __name__ == '__main__':
    # WARNING(review): debug=True enables the Werkzeug interactive debugger
    # (arbitrary code execution) and auto-reload; it must never be exposed on
    # 0.0.0.0 in production. Default preserves the original behavior but can
    # now be disabled via FLASK_DEBUG=0 in the container environment.
    debug = os.environ.get("FLASK_DEBUG", "1") == "1"
    app.run(host='0.0.0.0', port=8903, debug=debug)