update function
This commit is contained in:
parent
b8fe3320a9
commit
e872462870
22
Dockerfile
22
Dockerfile
|
@ -1,17 +1,21 @@
|
|||
FROM python:3.7.13-slim
|
||||
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
|
||||
|
||||
WORKDIR /app
|
||||
ADD . /app/
|
||||
ADD ./requirements.txt /app/
|
||||
|
||||
RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
||||
RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
|
||||
# RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
|
||||
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
||||
RUN mim install mmcv-full
|
||||
RUN pip install -e ./text2image/CLIP/
|
||||
RUN pip install -e ./backbone/mmpose/
|
||||
RUN pip install -e ./ocr/tr/
|
||||
RUN rm -rf ./text2image/CLIP/
|
||||
RUN rm -rf ./ocr/tr/
|
||||
RUN rm -rf ./backbone/mmpose/
|
||||
|
||||
ADD . /app
|
||||
RUN pip install -e ./text2image/CLIP/ --no-cache-dir
|
||||
RUN pip install -e ./backbone/mmpose/ --no-cache-dir
|
||||
RUN pip install -e ./ocr/tr/ --no-cache-dir
|
||||
# RUN rm -rf ./text2image/CLIP/
|
||||
# RUN rm -rf ./ocr/tr/
|
||||
# RUN rm -rf ./backbone/mmpose/
|
||||
RUN pip uninstall opencv-python -y
|
||||
RUN pip install opencv-python-headless --force-reinstall
|
||||
|
||||
CMD ["python3", "run.py"]
|
|
@ -1,6 +1,4 @@
|
|||
from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result
|
||||
import cv2
|
||||
import os
|
||||
|
||||
def run_backbone_infer(img_ndarr, pose_model):
|
||||
# test a single image
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# -*-coding:utf-8-*-
|
||||
def detector(img, model):
|
||||
"""_summary_
|
||||
|
||||
|
@ -17,6 +15,5 @@ def detector(img, model):
|
|||
return result.render()[0], result.pred[0].cpu().numpy()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
|
||||
pass
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
from asyncio.log import logger
|
||||
from cmath import log
|
||||
# -*-coding:utf-8-*-
|
||||
from logzero import logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
import numpy as np
|
||||
|
||||
batch_size = 256
|
||||
random_seed = 1884
|
||||
DEVICE = torch.device('cuda')
|
||||
DEVICE = torch.device('cpu')
|
||||
torch.manual_seed(random_seed)
|
||||
|
||||
|
||||
|
@ -122,13 +121,4 @@ def run_mnist_infer(img, model:CNN):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model=load_model('./models/MNIST_torch.pth')
|
||||
# a, b = load_data()
|
||||
# for i, l in a:
|
||||
# print(model(i)[0])
|
||||
# print(np.argmax(model(i).detach().numpy()[0]), l)
|
||||
# break
|
||||
# trl, tel = load_data()
|
||||
# cnn = train(10, 0.01, trl, tel)
|
||||
# torch.save(cnn, './models/MNIST_torch.pth')
|
||||
pass
|
|
@ -10,7 +10,6 @@ Pillow==9.2.0
|
|||
scipy==1.4.1
|
||||
six==1.15.0
|
||||
tqdm==4.64.0
|
||||
opencv-python==4.6.0.66
|
||||
seaborn==0.11.2
|
||||
segmentation-models-pytorch==0.2.1
|
||||
albumentations==1.2.1
|
||||
|
|
220
run.py
220
run.py
|
@ -2,118 +2,124 @@
|
|||
import os
|
||||
import sys
|
||||
from logzero import logger
|
||||
current_path = os.path.dirname(__file__)
|
||||
logger.info(current_path)
|
||||
|
||||
# current_path = os.path.dirname(__file__) # for local
|
||||
current_path = "/app" # for docker
|
||||
logger.info(f"{current_path}")
|
||||
sys.path.append(f"{current_path}/text2image/")
|
||||
sys.path.append(f"{current_path}/text2image/BigGAN_utils/")
|
||||
|
||||
import json
|
||||
import base64
|
||||
from PIL import Image
|
||||
from flask import Flask, request, make_response
|
||||
import cv2
|
||||
# from io import BytesIO
|
||||
# import torch
|
||||
from io import BytesIO
|
||||
import torch
|
||||
from mmpose.apis import init_pose_model
|
||||
# from text2image.run_text2img import text2image
|
||||
# from detection.detection import detector
|
||||
# from segmentation.segment_pred import run_seg
|
||||
# from ocr.ocr import run_tr
|
||||
from text2image.run_text2img import text2image
|
||||
from detection.detection import detector
|
||||
from segmentation.segment_pred import run_seg
|
||||
from ocr.ocr import run_tr
|
||||
from backbone.backbone_infer import run_backbone_infer
|
||||
from mnist.mnist_torch import run_mnist_infer, CNN
|
||||
|
||||
DEVICE = 'cpu'
|
||||
|
||||
# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True)
|
||||
# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True)
|
||||
# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
||||
model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5x.pt', source='local')
|
||||
model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5s.pt', source='local')
|
||||
model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
||||
pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
|
||||
pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
|
||||
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0'
|
||||
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE)
|
||||
mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE)
|
||||
|
||||
app=Flask(__name__)
|
||||
|
||||
# @app.route('/text2image/',methods=["POST"])
|
||||
# def run_text2img():
|
||||
# if request.method == "POST":
|
||||
# text = request.form.get('text')
|
||||
# logger.info(f"{text}")
|
||||
# img = text2image(text)
|
||||
# output_buffer = BytesIO()
|
||||
# img.save(output_buffer, format='png')
|
||||
# byte_data = output_buffer.getvalue()
|
||||
# b64_code = base64.b64encode(byte_data).decode('utf-8')
|
||||
# resp = make_response(b64_code)
|
||||
# resp.status_code = 200
|
||||
# return resp
|
||||
# else:
|
||||
# resp = make_response()
|
||||
# resp.status_code=405
|
||||
# return resp
|
||||
@app.route('/text2image/',methods=["POST"])
|
||||
def run_text2img():
|
||||
if request.method == "POST":
|
||||
text = request.form.get('text')
|
||||
logger.info(f"{text}")
|
||||
img = text2image(text)
|
||||
output_buffer = BytesIO()
|
||||
img.save(output_buffer, format='png')
|
||||
byte_data = output_buffer.getvalue()
|
||||
b64_code = base64.b64encode(byte_data).decode('utf-8')
|
||||
resp = make_response(b64_code)
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp = make_response()
|
||||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
|
||||
# @app.route('/detection/', methods=["POST"])
|
||||
# def run_detection():
|
||||
# if request.method == "POST":
|
||||
# img = request.files.get('image')
|
||||
# model_type = request.form.get('model_type')
|
||||
# try:
|
||||
# img = cv2.imread(img)
|
||||
# except:
|
||||
# resp = make_response()
|
||||
# resp.status_code = 406
|
||||
# return resp
|
||||
# if model_type.lower().strip() == 'yolov5x':
|
||||
# rst, _ = detector(img, model_5x)
|
||||
# else:
|
||||
# rst, _ = detector(img, model_5s)
|
||||
# logger.info(rst.shape)
|
||||
# img_str = cv2.imencode('.png', rst)[1].tobytes()
|
||||
# b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||
# resp = make_response(b64_code)
|
||||
# resp.status_code = 200
|
||||
# return b64_code
|
||||
# else:
|
||||
# resp = make_response()
|
||||
# resp.status_code=405
|
||||
# return resp
|
||||
@app.route('/detection/', methods=["POST"])
|
||||
def run_detection():
|
||||
if request.method == "POST":
|
||||
img = request.files.get('image')
|
||||
model_type = request.form.get('model_type')
|
||||
try:
|
||||
img = cv2.imread(img)
|
||||
except:
|
||||
resp = make_response()
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
if model_type.lower().strip() == 'yolov5x':
|
||||
rst, _ = detector(img, model_5x)
|
||||
else:
|
||||
rst, _ = detector(img, model_5s)
|
||||
logger.info(rst.shape)
|
||||
img_str = cv2.imencode('.png', rst)[1].tobytes()
|
||||
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||
resp = make_response(b64_code)
|
||||
resp.status_code = 200
|
||||
return b64_code
|
||||
else:
|
||||
resp = make_response()
|
||||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
# @app.route('/ocr/', methods=["POST"])
|
||||
# def run_ocr():
|
||||
# resp = make_response()
|
||||
# if request.method == "POST":
|
||||
# img = request.files.get('image')
|
||||
# try:
|
||||
# img = cv2.imread(img)
|
||||
# except:
|
||||
# resp.status_code = 406
|
||||
# return resp
|
||||
# text = run_tr(img)
|
||||
# resp.status_code = 200
|
||||
# resp.data = json.dumps({'result':text})
|
||||
# return resp
|
||||
# else:
|
||||
# resp.status_code=405
|
||||
# return resp
|
||||
@app.route('/ocr/', methods=["POST"])
|
||||
def run_ocr():
|
||||
resp = make_response()
|
||||
if request.method == "POST":
|
||||
img = request.files.get('image')
|
||||
try:
|
||||
img = cv2.imread(img)
|
||||
except:
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
text = run_tr(img)
|
||||
resp.status_code = 200
|
||||
resp.data = json.dumps({'result':text})
|
||||
return resp
|
||||
else:
|
||||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
|
||||
# @app.route('/segmentation/', methods=["POST"])
|
||||
# def run_segmentation():
|
||||
# if request.method == "POST":
|
||||
# img_upload = request.files.get('image')
|
||||
# try:
|
||||
# img = cv2.imread(img_upload)
|
||||
# except:
|
||||
# resp = make_response()
|
||||
# resp.status_code = 406
|
||||
# return resp
|
||||
# result = run_seg(img, model_seg)
|
||||
# img_str = cv2.imencode('.png', result)[1].tobytes()
|
||||
# b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||
# resp = make_response(b64_code)
|
||||
# resp.status_code = 200
|
||||
# return resp
|
||||
# else:
|
||||
# resp = make_response()
|
||||
# resp.status_code=405
|
||||
# return resp
|
||||
@app.route('/segmentation/', methods=["POST"])
|
||||
def run_segmentation():
|
||||
if request.method == "POST":
|
||||
img_upload = request.files.get('image')
|
||||
try:
|
||||
img = cv2.imread(img_upload)
|
||||
except:
|
||||
resp = make_response()
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
result = run_seg(img, model_seg)
|
||||
img_str = cv2.imencode('.png', result)[1].tobytes()
|
||||
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||
resp = make_response(b64_code)
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp = make_response()
|
||||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
@app.route('/backbone/', methods=["POST"])
|
||||
def run_backbone():
|
||||
|
@ -125,7 +131,7 @@ def run_backbone():
|
|||
resp = make_response()
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
pose, result = run_backbone_infer(img, pose_model)
|
||||
_, result = run_backbone_infer(img, pose_model)
|
||||
img_str = cv2.imencode('.png', result)[1].tobytes()
|
||||
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||
resp = make_response(b64_code)
|
||||
|
@ -136,8 +142,32 @@ def run_backbone():
|
|||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
@app.route('/mnist/', methods=["POST"])
|
||||
def run_mnist():
|
||||
if request.method == "POST":
|
||||
img_upload = request.files.get('image')
|
||||
try:
|
||||
img = cv2.imread(img_upload, 1)
|
||||
# 使用全局阈值,降噪
|
||||
ret,th1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
|
||||
# 把opencv图像转化为PIL图像
|
||||
im = Image.fromarray(cv2.cvtColor(th1,cv2.COLOR_BGR2RGB))
|
||||
# 灰度化
|
||||
im = im.convert('L')
|
||||
Im = im.resize((28, 28), Image.ANTIALIAS)
|
||||
except:
|
||||
resp = make_response()
|
||||
resp.status_code = 406
|
||||
return resp
|
||||
result = run_mnist_infer(Im, mnist_model)
|
||||
resp = make_response(str(result))
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
else:
|
||||
resp = make_response()
|
||||
resp.status_code=405
|
||||
return resp
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = cv2.imread('./1.jpg')
|
||||
pose, rst = run_backbone_infer(img, pose_model)
|
||||
cv2.imwrite('./1_bb.jpg', rst)
|
||||
app.run(host='0.0.0.0', port='8902')
|
|
@ -1,14 +1,8 @@
|
|||
from asyncio.log import logger
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
# -*-coding:utf-8-*-
|
||||
from logzero import logger
|
||||
import albumentations as albu
|
||||
import torch
|
||||
import segmentation_models_pytorch as smp
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset as BaseDataset
|
||||
from PIL import Image
|
||||
|
||||
DEVICE = 'cpu'
|
||||
ENCODER = 'se_resnext50_32x4d'
|
||||
|
@ -20,7 +14,7 @@ ENCODER_WEIGHTS = 'imagenet'
|
|||
def get_validation_augmentation():
|
||||
"""调整图像使得图片的分辨率长宽能被32整除"""
|
||||
test_transform = [
|
||||
albu.PadIfNeeded(256, 256)
|
||||
albu.PadIfNeeded(1024, 1024)
|
||||
]
|
||||
return albu.Compose(test_transform)
|
||||
|
||||
|
@ -54,7 +48,7 @@ def run_seg(img, best_model):
|
|||
# ---------------------------------------------------------------
|
||||
img = augmentator(image=img)['image']
|
||||
img = preprocessor(image=img)['image']
|
||||
# 加载最佳模型
|
||||
logger.info(f"img.shape {img.shape}")
|
||||
x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
|
||||
pr_mask = best_model.predict(x_tensor)
|
||||
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
|
||||
|
@ -62,5 +56,4 @@ def run_seg(img, best_model):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
||||
input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp')
|
||||
pass
|
|
@ -7,7 +7,6 @@ import clip
|
|||
import torch.nn.functional as F
|
||||
from DiffAugment_pytorch import DiffAugment
|
||||
import numpy as np
|
||||
import lpips
|
||||
import os
|
||||
current_path = os.path.dirname(__file__)
|
||||
|
||||
|
|
Loading…
Reference in New Issue