# -*- coding: utf-8 -*-
import os
import sys

from logzero import logger

# current_path = os.path.dirname(__file__)  # for local
current_path = "/app"  # for docker
logger.info(f"{current_path}")
sys.path.append(f"{current_path}/text2image/")
sys.path.append(f"{current_path}/text2image/BigGAN_utils/")

import base64
import json
from io import BytesIO

import cv2
import numpy as np
import torch
from PIL import Image
from flask import Flask, request, make_response
from mmpose.apis import init_pose_model

from text2image.run_text2img import text2image
from detection.detection import detector
from segmentation.segment_pred import run_seg
from ocr.ocr import run_tr
from backbone.backbone_infer import run_backbone_infer
from mnist.mnist_torch import run_mnist_infer, CNN  # CNN must be importable so torch.load can unpickle the MNIST model

DEVICE = 'cpu'
TEXT = "text"
BASE64_IMG = "base64_img"

# Load all models once at startup so each request only runs inference.
model_5x = torch.hub.load(f'{current_path}/detection/yolov5/', 'custom',
                          path=f'{current_path}/detection/models/yolov5x.pt', source='local')
model_5s = torch.hub.load(f'{current_path}/detection/yolov5/', 'custom',
                          path=f'{current_path}/detection/models/yolov5s.pt', source='local')
model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)

pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE)

mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE)

app = Flask(__name__)


@app.route('/text2image/', methods=["POST"])
def run_text2img():
    resp_info = dict()
    if request.method == "POST":
        text = request.form.get('text')
        logger.info(f"{text}")
        img = text2image(text)
        # Encode the generated PIL image as a base64 PNG for the JSON response.
        output_buffer = BytesIO()
        img.save(output_buffer, format='png')
        byte_data = output_buffer.getvalue()
        b64_code = base64.b64encode(byte_data).decode('utf-8')
        resp_info["code"] = 200
        resp_info["data"] = b64_code
        resp_info["dtype"] = BASE64_IMG
    # HTTP status is always 200; errors are reported via the "code" field in the JSON body.
    resp = make_response(resp_info)
    resp.status_code = 200
    return resp


@app.route('/detection/', methods=["POST"])
def run_detection():
    resp_info = dict()
    if request.method == "POST":
        img = request.files.get('image')
        model_type = request.form.get('model_type')
        try:
            # cv2.imread expects a file path, so decode the uploaded bytes in memory instead.
            img = cv2.imdecode(np.frombuffer(img.read(), np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("uploaded file is not a valid image")
        except Exception as e:
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
        else:
            if (model_type or '').lower().strip() == 'yolov5x':
                rst, _ = detector(img, model_5x)
            else:
                rst, _ = detector(img, model_5s)
            logger.info(rst.shape)
            img_str = cv2.imencode('.png', rst)[1].tobytes()
            b64_code = base64.b64encode(img_str).decode('utf-8')
            resp_info["code"] = 200
            resp_info["data"] = b64_code
            resp_info["dtype"] = BASE64_IMG
    resp = make_response(resp_info)
    resp.status_code = 200
    return resp


@app.route('/ocr/', methods=["POST"])
def run_ocr():
    resp_info = dict()
    if request.method == "POST":
        img = request.files.get('image')
        try:
            img = cv2.imdecode(np.frombuffer(img.read(), np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("uploaded file is not a valid image")
        except Exception as e:
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
        else:
            text = run_tr(img)
            resp_info["code"] = 200
            resp_info["data"] = text
            resp_info["dtype"] = TEXT
    resp = make_response(resp_info)
    resp.status_code = 200
    return resp


@app.route('/segmentation/', methods=["POST"])
def run_segmentation():
    resp_info = dict()
    if request.method == "POST":
        img_upload = request.files.get('image')
        try:
            img = cv2.imdecode(np.frombuffer(img_upload.read(), np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("uploaded file is not a valid image")
        except Exception as e:
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
        else:
            result = run_seg(img, model_seg)
            img_str = cv2.imencode('.png', result)[1].tobytes()
            b64_code = base64.b64encode(img_str).decode('utf-8')
resp_info["code"] = 200 resp_info["data"] = b64_code resp_info["dtype"] = BASE64_IMG resp = make_response(resp_info) resp.status_code = 200 return resp @app.route('/backbone/', methods=["POST"]) def run_backbone(): resp_info = dict() if request.method == "POST": img_upload = request.files.get('image') try: img = cv2.imread(img_upload) except Exception as e: resp_info["msg"] = e resp_info["code"] = 406 else: _, result = run_backbone_infer(img, pose_model) img_str = cv2.imencode('.png', result)[1].tobytes() b64_code = base64.b64encode(img_str).decode('utf-8') resp_info["code"] = 200 resp_info["data"] = b64_code resp_info["dtype"] = BASE64_IMG resp = make_response(resp_info) resp.status_code = 200 return resp @app.route('/mnist/', methods=["POST"]) def run_mnist(): resp_info = dict() if request.method == "POST": img_upload = request.files.get('image') try: img = cv2.imread(img_upload, 1) # 使用全局阈值,降噪 ret, th1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY) # 把opencv图像转化为PIL图像 im = Image.fromarray(cv2.cvtColor(th1, cv2.COLOR_BGR2RGB)) # 灰度化 im = im.convert('L') Im = im.resize((28, 28), Image.ANTIALIAS) except Exception as e: resp_info["msg"] = e resp_info["code"] = 406 else: result = run_mnist_infer(Im, mnist_model) resp_info["code"] = 200 resp_info["data"] = str(result) resp_info["dtype"] = TEXT resp = make_response(resp_info) resp.status_code = 200 return resp if __name__ == '__main__': app.run(host='0.0.0.0', port='8902')