修改response结构
This commit is contained in:
parent
8c33ecc8ca
commit
eedd13fbea
|
@ -3,6 +3,7 @@
|
|||
# 测试脚本
|
||||
# 测试环境:tensorflow 2.5.3 + keras 2.3.1 + bert4keras 0.11
|
||||
import os
|
||||
|
||||
os.environ['TF_KERAS'] = "1"
|
||||
|
||||
import numpy as np
|
||||
|
@ -15,6 +16,7 @@ from bert4keras.snippets import AutoRegressiveDecoder
|
|||
class ChatBot(AutoRegressiveDecoder):
|
||||
"""基于随机采样对话机器人
|
||||
"""
|
||||
|
||||
def __init__(self, start_id, end_id, maxlen, model, tokenizer):
|
||||
super().__init__(start_id, end_id, maxlen)
|
||||
self.model = model
|
||||
|
@ -53,11 +55,9 @@ def build_chat_model(model_path, tokenizer):
|
|||
chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32, model=model, tokenizer=tokenizer)
|
||||
return chatbot
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tokenizer = Tokenizer("../models/nezha_gpt_dialog/vocab.txt", do_lower_case=True)
|
||||
chatbot = build_chat_model('../models/nezha_gpt_dialog', tokenizer)
|
||||
text_list = ['绿遍山原白满川', '子规声里雨如烟']
|
||||
print(chatbot.response(text_list))
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import hanlp
|
||||
from logzero import logger
|
||||
from hanlp_common.document import Document
|
||||
|
||||
tok = hanlp.load('./.hanlp/tok/coarse_electra_small_20220616_012050/')
|
||||
dep = hanlp.load('./.hanlp/dep/ctb9_dep_electra_small_20220216_100306/')
|
||||
sts = hanlp.load('./.hanlp/sts/sts_electra_base_zh_20210530_200109/')
|
||||
|
@ -10,13 +11,14 @@ def text_analysis(text):
|
|||
segments = tok(text)
|
||||
logger.info(segments)
|
||||
doc = Document(
|
||||
tok=segments,
|
||||
dep=dep(segments, conll=False),
|
||||
tok=segments,
|
||||
dep=dep(segments, conll=False),
|
||||
)
|
||||
rst = doc.to_pretty()
|
||||
logger.info(rst)
|
||||
return rst
|
||||
|
||||
|
||||
def text_simi(src, tgt):
|
||||
score = sts([(src, tgt)])[0]
|
||||
result = ["negative", "positive"][round(score)]
|
||||
|
@ -25,4 +27,3 @@ def text_simi(src, tgt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
print(text_analysis("台湾省是中国不可分割的一部分。"))
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
# (Accuracy=0.7282149325820084 F1=0.8207266829447049 Final=0.7744708077633566)
|
||||
|
||||
import json, os, re
|
||||
|
||||
os.environ['TF_KERAS'] = '1'
|
||||
import numpy as np
|
||||
from bert4keras.backend import keras, K
|
||||
|
@ -16,9 +17,10 @@ from keras.models import Model
|
|||
from tqdm import tqdm
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth=True # 按需分配显存
|
||||
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(),config=config)
|
||||
config.gpu_options.allow_growth = True # 按需分配显存
|
||||
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
|
||||
tf.compat.v1.keras.backend.set_session(tf_session)
|
||||
|
||||
max_p_len = 256
|
||||
|
@ -27,6 +29,7 @@ max_a_len = 32
|
|||
batch_size = 32
|
||||
epochs = 10
|
||||
|
||||
|
||||
# # bert配置
|
||||
# config_path = '../models/nezha_gpt/config.json'
|
||||
# checkpoint_path = '../models/nezha_gpt/gpt.ckpt'
|
||||
|
@ -113,8 +116,7 @@ def masked_cross_entropy(y_true, y_pred):
|
|||
return cross_entropy
|
||||
|
||||
|
||||
|
||||
def build_reading_model(config_path:str, ckpt_path:str, keep_tokens:str, weight_path:str):
|
||||
def build_reading_model(config_path: str, ckpt_path: str, keep_tokens: str, weight_path: str):
|
||||
model = build_transformer_model(
|
||||
config_path,
|
||||
ckpt_path,
|
||||
|
@ -127,6 +129,7 @@ def build_reading_model(config_path:str, ckpt_path:str, keep_tokens:str, weight_
|
|||
model.load_weights(weight_path)
|
||||
return model
|
||||
|
||||
|
||||
def get_ngram_set(x, n):
|
||||
"""生成ngram合集,返回结果格式是:
|
||||
{(n-1)-gram: set([n-gram的第n个字集合])}
|
||||
|
@ -176,7 +179,7 @@ def gen_answer(question, passages, model, tokenizer):
|
|||
if a:
|
||||
results[a] = results.get(a, []) + [score]
|
||||
results = {
|
||||
k: (np.array(v)**2).sum() / (sum(v) + 1)
|
||||
k: (np.array(v) ** 2).sum() / (sum(v) + 1)
|
||||
for k, v in results.items()
|
||||
}
|
||||
return results
|
||||
|
@ -221,8 +224,9 @@ if __name__ == '__main__':
|
|||
model = build_reading_model()
|
||||
model.load_weights('../models/qa/best_model.weights')
|
||||
questions = "嬴政出生在哪里?"
|
||||
passages = ["秦始皇嬴政(前259年—前210年),嬴姓,赵氏 ,名政(一说名“正”),又称赵政 、祖龙 ,也有吕政一说(详见“人物争议-姓名之争”目录)。秦庄襄王和赵姬之子。中国古代杰出的政治家、战略家、改革家,首次完成中国大一统的政治人物,也是中国第一个称皇帝的君主。",
|
||||
"公元前221年,秦统一六国之后,秦王嬴政认为自己“德兼三皇,功过五帝”,遂采用三皇之“皇”、五帝之“帝”构成“皇帝”的称号,是中国历史上第一个使用“皇帝”称号的君主,所以自称“始皇帝”。",
|
||||
"秦始皇有二十余子。长子扶苏,少子胡亥。",
|
||||
"嬴政出生在当时赵国的邯郸廓城(在今城内中街以东,丛台西南的朱家巷一带),是当时的秦国王孙异人之子。"]
|
||||
print(gen_answer(questions, passages))
|
||||
passages = [
|
||||
"秦始皇嬴政(前259年—前210年),嬴姓,赵氏 ,名政(一说名“正”),又称赵政 、祖龙 ,也有吕政一说(详见“人物争议-姓名之争”目录)。秦庄襄王和赵姬之子。中国古代杰出的政治家、战略家、改革家,首次完成中国大一统的政治人物,也是中国第一个称皇帝的君主。",
|
||||
"公元前221年,秦统一六国之后,秦王嬴政认为自己“德兼三皇,功过五帝”,遂采用三皇之“皇”、五帝之“帝”构成“皇帝”的称号,是中国历史上第一个使用“皇帝”称号的君主,所以自称“始皇帝”。",
|
||||
"秦始皇有二十余子。长子扶苏,少子胡亥。",
|
||||
"嬴政出生在当时赵国的邯郸廓城(在今城内中街以东,丛台西南的朱家巷一带),是当时的秦国王孙异人之子。"]
|
||||
print(gen_answer(questions, passages))
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
|
||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" #(保证程序cuda序号与实际cuda序号对应)
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = "0, 1" #(代表仅使用第0,1号GPU)
|
||||
os.environ['TF_KERAS'] ='1'
|
||||
os.environ['TF_KERAS'] = '1'
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -18,20 +18,19 @@ from bert4keras.snippets import open
|
|||
from keras.layers import Lambda, Dense
|
||||
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth=True # 按需分配显存
|
||||
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(),config=config)
|
||||
config.gpu_options.allow_growth = True # 按需分配显存
|
||||
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
|
||||
tf.compat.v1.keras.backend.set_session(tf_session)
|
||||
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
set_gelu('tanh') # 切换gelu版本
|
||||
|
||||
MAX_LEN = 128
|
||||
BATCH_SIZE = 32
|
||||
AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
|
||||
|
||||
|
||||
# config_path = '/home/zhaojh/pretrain_models/roberta_base/bert_config.json'
|
||||
# checkpoint_path = '/home/zhaojh/pretrain_models/roberta_base/bert_model.ckpt'
|
||||
# dict_path = '/home/zhaojh/pretrain_models/roberta_base/vocab.txt'
|
||||
|
@ -161,7 +160,7 @@ def evaluate(data, model):
|
|||
# return model
|
||||
|
||||
|
||||
def run_cls(model, tokenizer, test_text:str, labels=["negative", "positive"]):
|
||||
def run_cls(model, tokenizer, test_text: str, labels=["negative", "positive"]):
|
||||
token_ids, segment_ids = tokenizer.encode(test_text, maxlen=MAX_LEN)
|
||||
tok_ids = sequence_padding([token_ids])
|
||||
seg_ids = sequence_padding([segment_ids])
|
||||
|
@ -176,7 +175,7 @@ def run_cls(model, tokenizer, test_text:str, labels=["negative", "positive"]):
|
|||
# model.save('../models/text_classifier/best_model.h5')
|
||||
# return model, labels
|
||||
|
||||
def load_cls_model(config_path:str, ckpt_path:str, weight_path:str):
|
||||
def load_cls_model(config_path: str, ckpt_path: str, weight_path: str):
|
||||
"""load local classification model
|
||||
|
||||
Args:
|
||||
|
@ -191,10 +190,12 @@ def load_cls_model(config_path:str, ckpt_path:str, weight_path:str):
|
|||
model.load_weights(weight_path)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from keras.models import load_model
|
||||
|
||||
dict_path = '../models/tokenizer/vocab.txt'
|
||||
tokenizer = Tokenizer(dict_path, do_lower_case=True)
|
||||
model = load_model('../models/text_classifier/best_model.h5')
|
||||
rst = run_cls(model, tokenizer, "这部电影太棒了")
|
||||
print(rst)
|
||||
print(rst)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#! -*- coding: utf-8 -*-
|
||||
import os
|
||||
|
||||
os.environ["TF_KERAS"] = "1"
|
||||
import numpy as np
|
||||
from bert4keras.models import build_transformer_model
|
||||
|
@ -10,6 +11,7 @@ from bert4keras.snippets import AutoRegressiveDecoder
|
|||
class ArticleCompletion(AutoRegressiveDecoder):
|
||||
"""基于随机采样的文章续写
|
||||
"""
|
||||
|
||||
def __init__(self, start_id, end_id, maxlen, minlen, config_path, ckpt_path):
|
||||
super().__init__(start_id, end_id, maxlen, minlen)
|
||||
self.model = build_transformer_model(
|
||||
|
@ -51,4 +53,4 @@ if __name__ == '__main__':
|
|||
ckpt_path="../models/nezha_gpt/gpt.ckpt"
|
||||
)
|
||||
tokenizer = Tokenizer(f"../models/tokenizer/vocab.txt", do_lower_case=True)
|
||||
print(article_completion.generate(u'中国科学院青岛生物能源与过程研究所泛能源大数据与战略研究中心', tokenizer))
|
||||
print(article_completion.generate(u'中国科学院青岛生物能源与过程研究所泛能源大数据与战略研究中心', tokenizer))
|
||||
|
|
146
run.py
146
run.py
|
@ -4,10 +4,10 @@ import sys
|
|||
import json
|
||||
from flask import Flask, request, make_response
|
||||
from logzero import logger
|
||||
# current_path = os.path.dirname(os.path.abspath(__file__)) # for local
|
||||
current_path = "/app" # for docker
|
||||
logger.info(f"{current_path}")
|
||||
|
||||
# current_path = os.path.dirname(os.path.abspath(__file__)) # for local
|
||||
current_path = "/app" # for docker
|
||||
logger.info(f"{current_path}")
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
@ -16,7 +16,6 @@ os.environ["TF_KERAS"] = "1"
|
|||
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
|
||||
|
||||
from bert4keras.tokenizers import Tokenizer, load_vocab
|
||||
from keras.models import load_model
|
||||
|
||||
|
@ -27,9 +26,8 @@ from nlp.text_classification import run_cls as run_class, AdamLR
|
|||
from nlp.hanlp_tools import text_analysis, text_simi
|
||||
from nlp.reading import build_reading_model, gen_answer
|
||||
|
||||
|
||||
general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True) # 通用分词器
|
||||
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True) # 对话分词器
|
||||
general_tokenizer = Tokenizer(f"{current_path}/models/tokenizer/vocab.txt", do_lower_case=True) # 通用分词器
|
||||
dialog_tokenizer = Tokenizer(f"{current_path}/models/nezha_gpt_dialog/vocab.txt", do_lower_case=True) # 对话分词器
|
||||
|
||||
token_dict, keep_tokens = load_vocab(
|
||||
# 加载并精简词表,建立阅读理解词表
|
||||
|
@ -37,8 +35,8 @@ token_dict, keep_tokens = load_vocab(
|
|||
simplified=True,
|
||||
startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
|
||||
)
|
||||
reading_tokenizer = Tokenizer(token_dict, do_lower_case=True) # 阅读理解分词器
|
||||
cls_model = load_model(f"{current_path}/models/text_classifier/best_model.h5") # 加载分类模型
|
||||
reading_tokenizer = Tokenizer(token_dict, do_lower_case=True) # 阅读理解分词器
|
||||
cls_model = load_model(f"{current_path}/models/text_classifier/best_model.h5") # 加载分类模型
|
||||
gen_model = ArticleCompletion(
|
||||
# 加载文本生成模型
|
||||
start_id=None,
|
||||
|
@ -50,7 +48,7 @@ gen_model = ArticleCompletion(
|
|||
)
|
||||
chatbot = build_chat_model(
|
||||
# 加载对话模型
|
||||
f"{current_path}/models/nezha_gpt_dialog/",
|
||||
f"{current_path}/models/nezha_gpt_dialog/",
|
||||
dialog_tokenizer)
|
||||
# translator, trans_data = load_translator(
|
||||
# # 加载翻译模型
|
||||
|
@ -60,55 +58,58 @@ chatbot = build_chat_model(
|
|||
# )
|
||||
reading_model = build_reading_model(
|
||||
# 加载阅读理解模型
|
||||
f"{current_path}/models/nezha_gpt/config.json",
|
||||
f"{current_path}/models/nezha_gpt/config.json",
|
||||
f"{current_path}/models/nezha_gpt/gpt.ckpt",
|
||||
keep_tokens,
|
||||
f"{current_path}/models/qa/best_model.weights"
|
||||
)
|
||||
)
|
||||
|
||||
TEXT = "text"
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/text_cls/', methods=["POST"])
|
||||
def run_cls():
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
if request.method == "POST":
|
||||
text = request.form.get('text')
|
||||
if text is not None and text != '':
|
||||
resp.response = run_class(cls_model, general_tokenizer, text)
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
rst = run_class(cls_model, general_tokenizer, text)
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = rst
|
||||
resp_info["dtype"] = TEXT
|
||||
else:
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
resp_info["msg"] = "Input is None, please check !"
|
||||
resp_info["code"] = 406
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
@app.route('/text_gen/', methods=["POST"])
|
||||
def run_gen():
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
if request.method == "POST":
|
||||
text = request.form.get('text')
|
||||
logger.info(f"将对文本'{text}'进行续写")
|
||||
if text != "":
|
||||
rest = gen_model.generate(text, general_tokenizer)
|
||||
logger.info(rest)
|
||||
resp.response = rest
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = rest
|
||||
resp_info["dtype"] = TEXT
|
||||
else:
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
resp_info["msg"] = "Input is None, please check !"
|
||||
resp_info["code"] = 406
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
@app.route('/chat/', methods=["POST"])
|
||||
def run_chat():
|
||||
# todo: 这个模块可以用grpc流式服务做。
|
||||
# 如果用flask 就把历史对话都按照list的方式传进来
|
||||
# 历史对话可以用json传
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
if request.method == "POST":
|
||||
dialog_history = request.form.get("dialog_history")
|
||||
dialog_history = dialog_history.split('。')
|
||||
|
@ -116,15 +117,16 @@ def run_chat():
|
|||
if len(dialog_history) > 0:
|
||||
rest = chatbot.response(dialog_history)
|
||||
logger.info(rest)
|
||||
resp.response = rest
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = rest
|
||||
resp_info["dtype"] = TEXT
|
||||
else:
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
resp_info["msg"] = "请将历史对话以中文句号分隔"
|
||||
resp_info["code"] = 406
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
# @app.route('/translate/', methods=["POST"])
|
||||
# def run_translate():
|
||||
|
@ -144,49 +146,57 @@ def run_chat():
|
|||
|
||||
@app.route('/simi/', methods=["POST"])
|
||||
def run_match():
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
|
||||
if request.method == "POST":
|
||||
src = request.form.get('text_1')
|
||||
tgt = request.form.get('text_2')
|
||||
resp.response = str(text_simi(src, tgt))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = str(text_simi(src, tgt))
|
||||
resp_info["dtype"] = TEXT
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
@app.route('/dependency/', methods=["POST"])
|
||||
def run_depend():
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
if request.method == "POST":
|
||||
text = request.form.get('text')
|
||||
if text is None or text.strip() == "":
|
||||
resp.status_code=406
|
||||
return resp
|
||||
resp.response = str(text_analysis(text))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
resp_info["msg"] = "Input is None, please check !"
|
||||
resp_info["code"] = 406
|
||||
else:
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = str(text_analysis(text))
|
||||
resp_info["dtype"] = TEXT
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
@app.route('/reading/', methods=["POST"])
|
||||
def run_reading():
|
||||
resp = make_response()
|
||||
resp_info = dict()
|
||||
if request.method == "POST":
|
||||
question = request.form.get("question")
|
||||
passages = request.form.get("passages")
|
||||
passages = [x + '。' for x in passages.split('。')]
|
||||
if len(passages) == 0 or question is None or question.strip() == "":
|
||||
resp.status_code=406
|
||||
return resp
|
||||
resp.response = json.dumps(gen_answer(question, passages, reading_model, reading_tokenizer))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
if question is None or question.strip() == "":
|
||||
resp_info["msg"] = "Question is None, please check!"
|
||||
resp_info["code"] = 406
|
||||
elif len(passages) == 0:
|
||||
resp_info["msg"] = "请将文章段落以中文句号分隔"
|
||||
resp_info["code"] = 406
|
||||
else:
|
||||
rest = gen_answer(question, passages, reading_model, reading_tokenizer)
|
||||
resp_info["code"] = 200
|
||||
resp_info["data"] = rest
|
||||
resp_info["dtype"] = TEXT
|
||||
resp = make_response(json.dumps(resp_info))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue