From 7b4e850f9f7b115e681f27df7386ad2bbbca10f1 Mon Sep 17 00:00:00 2001 From: zhaojinghao Date: Thu, 8 Dec 2022 14:39:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9response=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 179 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 95 insertions(+), 84 deletions(-) diff --git a/run.py b/run.py index 8685636..9ac667a 100644 --- a/run.py +++ b/run.py @@ -3,8 +3,9 @@ import os import sys from logzero import logger + # current_path = os.path.dirname(__file__) # for local -current_path = "/app" # for docker +current_path = "/app" # for docker logger.info(f"{current_path}") sys.path.append(f"{current_path}/text2image/") sys.path.append(f"{current_path}/text2image/BigGAN_utils/") @@ -24,20 +25,27 @@ from ocr.ocr import run_tr from backbone.backbone_infer import run_backbone_infer from mnist.mnist_torch import run_mnist_infer, CNN -DEVICE = 'cpu' -model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5x.pt', source='local') -model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5s.pt', source='local') +DEVICE = 'cpu' +TEXT = "text" +BASE64_IMG = "base64_img" + +model_5x = torch.hub.load(f'{current_path}/detection/yolov5/', 'custom', + path=f'{current_path}/detection/models/yolov5x.pt', source='local') +model_5s = torch.hub.load(f'{current_path}/detection/yolov5/', 'custom', + path=f'{current_path}/detection/models/yolov5s.pt', source='local') model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE) pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py' pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth' pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE) mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE) -app=Flask(__name__) +app = Flask(__name__) -@app.route('/text2image/',methods=["POST"]) + +@app.route('/text2image/', methods=["POST"]) def run_text2img(): + resp_info = dict() if request.method == "POST": text = request.form.get('text') logger.info(f"{text}") @@ -46,128 +54,131 @@ def run_text2img(): img.save(output_buffer, format='png') byte_data = output_buffer.getvalue() b64_code = base64.b64encode(byte_data).decode('utf-8') - resp = make_response(b64_code) - resp.status_code = 200 - return resp - else: - resp = make_response() - resp.status_code=405 - return resp + resp_info["code"] = 200 + resp_info["data"] = b64_code + resp_info["dtype"] = BASE64_IMG + resp = make_response(resp_info) + resp.status_code = 200 + return resp @app.route('/detection/', methods=["POST"]) def run_detection(): + resp_info = dict() if request.method == "POST": img = request.files.get('image') model_type = request.form.get('model_type') try: img = cv2.imread(img) - except: - resp = make_response() - resp.status_code = 406 - return resp - if model_type.lower().strip() == 'yolov5x': - rst, _ = detector(img, model_5x) + except Exception as e: + resp_info["msg"] = e + resp_info["code"] = 406 else: - rst, _ = detector(img, model_5s) - logger.info(rst.shape) - img_str = cv2.imencode('.png', rst)[1].tobytes() - b64_code = base64.b64encode(img_str).decode('utf-8') - resp = make_response(b64_code) - resp.status_code = 200 - return b64_code - else: - resp = make_response() - resp.status_code=405 - return resp + if model_type.lower().strip() == 'yolov5x': + rst, _ = detector(img, model_5x) + else: + rst, _ = detector(img, model_5s) + logger.info(rst.shape) + img_str = cv2.imencode('.png', rst)[1].tobytes() + b64_code = base64.b64encode(img_str).decode('utf-8') + resp_info["code"] = 200 + resp_info["data"] = b64_code + resp_info["dtype"] = BASE64_IMG + resp = make_response(resp_info) + resp.status_code = 200 + return resp + @app.route('/ocr/', methods=["POST"]) def run_ocr(): - resp = make_response() + resp_info = dict() if request.method == "POST": img = request.files.get('image') try: img = cv2.imread(img) - except: - resp.status_code = 406 - return resp - text = run_tr(img) - resp.status_code = 200 - resp.data = json.dumps({'result':text}) - return resp - else: - resp.status_code=405 - return resp + except Exception as e: + resp_info["msg"] = e + resp_info["code"] = 406 + else: + text = run_tr(img) + resp_info["code"] = 200 + resp_info["data"] = text + resp_info["dtype"] = TEXT + resp = make_response(resp_info) + resp.status_code = 200 + return resp @app.route('/segmentation/', methods=["POST"]) def run_segmentation(): + resp_info = dict() if request.method == "POST": img_upload = request.files.get('image') try: img = cv2.imread(img_upload) - except: - resp = make_response() - resp.status_code = 406 - return resp - result = run_seg(img, model_seg) - img_str = cv2.imencode('.png', result)[1].tobytes() - b64_code = base64.b64encode(img_str).decode('utf-8') - resp = make_response(b64_code) - resp.status_code = 200 - return resp - else: - resp = make_response() - resp.status_code=405 - return resp + except Exception as e: + resp_info["msg"] = e + resp_info["code"] = 406 + else: + result = run_seg(img, model_seg) + img_str = cv2.imencode('.png', result)[1].tobytes() + b64_code = base64.b64encode(img_str).decode('utf-8') + resp_info["code"] = 200 + resp_info["data"] = b64_code + resp_info["dtype"] = BASE64_IMG + resp = make_response(resp_info) + resp.status_code = 200 + return resp + @app.route('/backbone/', methods=["POST"]) def run_backbone(): + resp_info = dict() if request.method == "POST": img_upload = request.files.get('image') try: img = cv2.imread(img_upload) - except: - resp = make_response() - resp.status_code = 406 - return resp - _, result = run_backbone_infer(img, pose_model) - img_str = cv2.imencode('.png', result)[1].tobytes() - b64_code = base64.b64encode(img_str).decode('utf-8') - resp = make_response(b64_code) - resp.status_code = 200 - return resp - else: - resp = make_response() - resp.status_code=405 - return resp + except Exception as e: + resp_info["msg"] = e + resp_info["code"] = 406 + else: + _, result = run_backbone_infer(img, pose_model) + img_str = cv2.imencode('.png', result)[1].tobytes() + b64_code = base64.b64encode(img_str).decode('utf-8') + resp_info["code"] = 200 + resp_info["data"] = b64_code + resp_info["dtype"] = BASE64_IMG + resp = make_response(resp_info) + resp.status_code = 200 + return resp + @app.route('/mnist/', methods=["POST"]) def run_mnist(): + resp_info = dict() if request.method == "POST": img_upload = request.files.get('image') try: img = cv2.imread(img_upload, 1) # 使用全局阈值,降噪 - ret,th1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY) + ret, th1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY) # 把opencv图像转化为PIL图像 - im = Image.fromarray(cv2.cvtColor(th1,cv2.COLOR_BGR2RGB)) + im = Image.fromarray(cv2.cvtColor(th1, cv2.COLOR_BGR2RGB)) # 灰度化 im = im.convert('L') Im = im.resize((28, 28), Image.ANTIALIAS) - except: - resp = make_response() - resp.status_code = 406 - return resp - result = run_mnist_infer(Im, mnist_model) - resp = make_response(str(result)) - resp.status_code = 200 - return resp - else: - resp = make_response() - resp.status_code=405 - return resp + except Exception as e: + resp_info["msg"] = e + resp_info["code"] = 406 + else: + result = run_mnist_infer(Im, mnist_model) + resp_info["code"] = 200 + resp_info["data"] = str(result) + resp_info["dtype"] = TEXT + resp = make_response(resp_info) + resp.status_code = 200 + return resp if __name__ == '__main__': - app.run(host='0.0.0.0', port='8902') \ No newline at end of file + app.run(host='0.0.0.0', port='8902')