update function

This commit is contained in:
赵敬皓 2022-12-07 10:46:43 +08:00
parent b8fe3320a9
commit e872462870
8 changed files with 148 additions and 138 deletions

View File

@ -1,17 +1,21 @@
FROM python:3.7.13-slim
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
WORKDIR /app
ADD . /app/
ADD ./requirements.txt /app/
RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
# RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
RUN mim install mmcv-full
RUN pip install -e ./text2image/CLIP/
RUN pip install -e ./backbone/mmpose/
RUN pip install -e ./ocr/tr/
RUN rm -rf ./text2image/CLIP/
RUN rm -rf ./ocr/tr/
RUN rm -rf ./backbone/mmpose/
ADD . /app
RUN pip install -e ./text2image/CLIP/ --no-cache-dir
RUN pip install -e ./backbone/mmpose/ --no-cache-dir
RUN pip install -e ./ocr/tr/ --no-cache-dir
# RUN rm -rf ./text2image/CLIP/
# RUN rm -rf ./ocr/tr/
# RUN rm -rf ./backbone/mmpose/
RUN pip uninstall opencv-python -y
RUN pip install opencv-python-headless --force-reinstall
CMD ["python3", "run.py"]

View File

@ -1,6 +1,4 @@
from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result
import cv2
import os
def run_backbone_infer(img_ndarr, pose_model):
# test a single image

View File

@ -1,6 +1,4 @@
import torch
import matplotlib.pyplot as plt
# -*-coding:utf-8-*-
def detector(img, model):
"""_summary_
@ -17,6 +15,5 @@ def detector(img, model):
return result.render()[0], result.pred[0].cpu().numpy()
if __name__ == '__main__':
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
pass

View File

@ -1,16 +1,15 @@
from asyncio.log import logger
from cmath import log
# -*-coding:utf-8-*-
from logzero import logger
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
import tqdm
import numpy as np
batch_size = 256
random_seed = 1884
DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')
torch.manual_seed(random_seed)
@ -122,13 +121,4 @@ def run_mnist_infer(img, model:CNN):
if __name__ == '__main__':
# model=load_model('./models/MNIST_torch.pth')
# a, b = load_data()
# for i, l in a:
# print(model(i)[0])
# print(np.argmax(model(i).detach().numpy()[0]), l)
# break
# trl, tel = load_data()
# cnn = train(10, 0.01, trl, tel)
# torch.save(cnn, './models/MNIST_torch.pth')
pass

View File

@ -10,7 +10,6 @@ Pillow==9.2.0
scipy==1.4.1
six==1.15.0
tqdm==4.64.0
opencv-python==4.6.0.66
seaborn==0.11.2
segmentation-models-pytorch==0.2.1
albumentations==1.2.1

220
run.py
View File

@ -2,118 +2,124 @@
import os
import sys
from logzero import logger
current_path = os.path.dirname(__file__)
logger.info(current_path)
# current_path = os.path.dirname(__file__) # for local
current_path = "/app" # for docker
logger.info(f"{current_path}")
sys.path.append(f"{current_path}/text2image/")
sys.path.append(f"{current_path}/text2image/BigGAN_utils/")
import json
import base64
from PIL import Image
from flask import Flask, request, make_response
import cv2
# from io import BytesIO
# import torch
from io import BytesIO
import torch
from mmpose.apis import init_pose_model
# from text2image.run_text2img import text2image
# from detection.detection import detector
# from segmentation.segment_pred import run_seg
# from ocr.ocr import run_tr
from text2image.run_text2img import text2image
from detection.detection import detector
from segmentation.segment_pred import run_seg
from ocr.ocr import run_tr
from backbone.backbone_infer import run_backbone_infer
from mnist.mnist_torch import run_mnist_infer, CNN
DEVICE = 'cpu'
# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True)
# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True)
# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5x.pt', source='local')
model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5s.pt', source='local')
model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0'
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE)
mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE)
app=Flask(__name__)
# @app.route('/text2image/',methods=["POST"])
# def run_text2img():
# if request.method == "POST":
# text = request.form.get('text')
# logger.info(f"{text}")
# img = text2image(text)
# output_buffer = BytesIO()
# img.save(output_buffer, format='png')
# byte_data = output_buffer.getvalue()
# b64_code = base64.b64encode(byte_data).decode('utf-8')
# resp = make_response(b64_code)
# resp.status_code = 200
# return resp
# else:
# resp = make_response()
# resp.status_code=405
# return resp
@app.route('/text2image/',methods=["POST"])
def run_text2img():
    """Turn the posted ``text`` form field into an image and return it base64-encoded.

    The text is passed to ``text2image``; the object it returns is saved as
    PNG into an in-memory buffer (presumably a PIL image -- confirm against
    ``text2image``), and the PNG bytes are sent back as a base64 string.

    Returns:
        A Flask response with status 200 and the base64 PNG as body, or an
        empty response with status 405 when the request is not a POST.
    """
    # Guard clause: reject anything that is not a POST up front.
    if request.method != "POST":
        rejected = make_response()
        rejected.status_code = 405
        return rejected

    prompt = request.form.get('text')
    logger.info(f"{prompt}")
    picture = text2image(prompt)

    # Serialise the generated image to PNG entirely in memory.
    png_buffer = BytesIO()
    picture.save(png_buffer, format='png')
    encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

    reply = make_response(encoded)
    reply.status_code = 200
    return reply
# @app.route('/detection/', methods=["POST"])
# def run_detection():
# if request.method == "POST":
# img = request.files.get('image')
# model_type = request.form.get('model_type')
# try:
# img = cv2.imread(img)
# except:
# resp = make_response()
# resp.status_code = 406
# return resp
# if model_type.lower().strip() == 'yolov5x':
# rst, _ = detector(img, model_5x)
# else:
# rst, _ = detector(img, model_5s)
# logger.info(rst.shape)
# img_str = cv2.imencode('.png', rst)[1].tobytes()
# b64_code = base64.b64encode(img_str).decode('utf-8')
# resp = make_response(b64_code)
# resp.status_code = 200
# return b64_code
# else:
# resp = make_response()
# resp.status_code=405
# return resp
@app.route('/detection/', methods=["POST"])
def run_detection():
    """Run YOLOv5 detection on an uploaded image and return the rendered result.

    Form fields:
        image: the uploaded image file (multipart).
        model_type: ``'yolov5x'`` (case-insensitive, surrounding whitespace
            ignored) selects ``model_5x``; any other or missing value
            selects ``model_5s``.

    Returns:
        200 with the rendered detection image as a base64-encoded PNG,
        406 when no image was uploaded or it cannot be decoded,
        405 when the request is not a POST.
    """
    # Local import so this handler does not depend on a module-level numpy import.
    import numpy as np

    if request.method != "POST":
        resp = make_response()
        resp.status_code = 405
        return resp

    upload = request.files.get('image')
    model_type = request.form.get('model_type')

    # cv2.imread() only accepts a filesystem path, so the uploaded file
    # object has to be decoded from its raw bytes instead.
    img = None
    if upload is not None:
        try:
            raw = np.frombuffer(upload.read(), dtype=np.uint8)
            img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        except (cv2.error, ValueError, OSError):
            logger.exception("failed to decode uploaded image")
            img = None
    if img is None:
        # imdecode signals undecodable data by returning None.
        resp = make_response()
        resp.status_code = 406
        return resp

    # A missing model_type falls back to the small model instead of raising.
    if (model_type or '').lower().strip() == 'yolov5x':
        rst, _ = detector(img, model_5x)
    else:
        rst, _ = detector(img, model_5s)
    logger.info(rst.shape)

    img_str = cv2.imencode('.png', rst)[1].tobytes()
    b64_code = base64.b64encode(img_str).decode('utf-8')
    resp = make_response(b64_code)
    resp.status_code = 200
    # Return the response object that was built, as the other endpoints do.
    return resp
# @app.route('/ocr/', methods=["POST"])
# def run_ocr():
# resp = make_response()
# if request.method == "POST":
# img = request.files.get('image')
# try:
# img = cv2.imread(img)
# except:
# resp.status_code = 406
# return resp
# text = run_tr(img)
# resp.status_code = 200
# resp.data = json.dumps({'result':text})
# return resp
# else:
# resp.status_code=405
# return resp
@app.route('/ocr/', methods=["POST"])
def run_ocr():
    """Run OCR on an uploaded image and return the recognised text as JSON.

    Form fields:
        image: the uploaded image file (multipart).

    Returns:
        200 with body ``{"result": <output of run_tr>}``,
        406 when no image was uploaded or it cannot be decoded,
        405 when the request is not a POST.
    """
    # Local import so this handler does not depend on a module-level numpy import.
    import numpy as np

    resp = make_response()
    if request.method != "POST":
        resp.status_code = 405
        return resp

    upload = request.files.get('image')

    # cv2.imread() only accepts a filesystem path, so the uploaded file
    # object has to be decoded from its raw bytes instead.
    img = None
    if upload is not None:
        try:
            raw = np.frombuffer(upload.read(), dtype=np.uint8)
            img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        except (cv2.error, ValueError, OSError):
            logger.exception("failed to decode uploaded image")
            img = None
    if img is None:
        # imdecode signals undecodable data by returning None.
        resp.status_code = 406
        return resp

    text = run_tr(img)
    resp.status_code = 200
    resp.data = json.dumps({'result':text})
    return resp
# @app.route('/segmentation/', methods=["POST"])
# def run_segmentation():
# if request.method == "POST":
# img_upload = request.files.get('image')
# try:
# img = cv2.imread(img_upload)
# except:
# resp = make_response()
# resp.status_code = 406
# return resp
# result = run_seg(img, model_seg)
# img_str = cv2.imencode('.png', result)[1].tobytes()
# b64_code = base64.b64encode(img_str).decode('utf-8')
# resp = make_response(b64_code)
# resp.status_code = 200
# return resp
# else:
# resp = make_response()
# resp.status_code=405
# return resp
@app.route('/segmentation/', methods=["POST"])
def run_segmentation():
    """Segment an uploaded image with ``model_seg`` and return the mask image.

    Form fields:
        image: the uploaded image file (multipart).

    Returns:
        200 with the output of ``run_seg`` encoded as a base64 PNG,
        406 when no image was uploaded or it cannot be decoded,
        405 when the request is not a POST.
    """
    # Local import so this handler does not depend on a module-level numpy import.
    import numpy as np

    if request.method != "POST":
        resp = make_response()
        resp.status_code = 405
        return resp

    img_upload = request.files.get('image')

    # cv2.imread() only accepts a filesystem path, so the uploaded file
    # object has to be decoded from its raw bytes instead.
    img = None
    if img_upload is not None:
        try:
            raw = np.frombuffer(img_upload.read(), dtype=np.uint8)
            img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        except (cv2.error, ValueError, OSError):
            logger.exception("failed to decode uploaded image")
            img = None
    if img is None:
        # imdecode signals undecodable data by returning None.
        resp = make_response()
        resp.status_code = 406
        return resp

    result = run_seg(img, model_seg)
    img_str = cv2.imencode('.png', result)[1].tobytes()
    b64_code = base64.b64encode(img_str).decode('utf-8')
    resp = make_response(b64_code)
    resp.status_code = 200
    return resp
@app.route('/backbone/', methods=["POST"])
def run_backbone():
@ -125,7 +131,7 @@ def run_backbone():
resp = make_response()
resp.status_code = 406
return resp
pose, result = run_backbone_infer(img, pose_model)
_, result = run_backbone_infer(img, pose_model)
img_str = cv2.imencode('.png', result)[1].tobytes()
b64_code = base64.b64encode(img_str).decode('utf-8')
resp = make_response(b64_code)
@ -136,8 +142,32 @@ def run_backbone():
resp.status_code=405
return resp
@app.route('/mnist/', methods=["POST"])
def run_mnist():
    """Classify an uploaded image with ``mnist_model``.

    The upload is decoded as a colour image, binarised with a global
    threshold, converted to a greyscale PIL image and resized to 28x28
    before being handed to ``run_mnist_infer``.

    Form fields:
        image: the uploaded image file (multipart).

    Returns:
        200 with ``str(result)`` of ``run_mnist_infer`` as body,
        406 when no image was uploaded or it cannot be decoded/preprocessed,
        405 when the request is not a POST.
    """
    # Local import so this handler does not depend on a module-level numpy import.
    import numpy as np

    if request.method != "POST":
        resp = make_response()
        resp.status_code = 405
        return resp

    img_upload = request.files.get('image')
    if img_upload is None:
        resp = make_response()
        resp.status_code = 406
        return resp

    try:
        # cv2.imread() only accepts a filesystem path, so decode the upload
        # from its raw bytes (IMREAD_COLOR is the former flag value 1).
        raw = np.frombuffer(img_upload.read(), dtype=np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals undecodable data by returning None.
            raise ValueError("uploaded data is not a decodable image")
        # Global threshold to suppress noise.
        _, th1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
        # Convert the OpenCV (BGR) image into a PIL (RGB) image.
        im = Image.fromarray(cv2.cvtColor(th1, cv2.COLOR_BGR2RGB))
        # Convert to greyscale.
        im = im.convert('L')
        # LANCZOS is the same filter as the deprecated ANTIALIAS alias.
        Im = im.resize((28, 28), Image.LANCZOS)
    except Exception:
        # Request boundary: any preprocessing failure is reported as 406.
        logger.exception("failed to preprocess uploaded image")
        resp = make_response()
        resp.status_code = 406
        return resp

    result = run_mnist_infer(Im, mnist_model)
    resp = make_response(str(result))
    resp.status_code = 200
    return resp
if __name__ == '__main__':
img = cv2.imread('./1.jpg')
pose, rst = run_backbone_infer(img, pose_model)
cv2.imwrite('./1_bb.jpg', rst)
app.run(host='0.0.0.0', port='8902')

View File

@ -1,14 +1,8 @@
from asyncio.log import logger
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
# -*-coding:utf-8-*-
from logzero import logger
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from PIL import Image
DEVICE = 'cpu'
ENCODER = 'se_resnext50_32x4d'
@ -20,7 +14,7 @@ ENCODER_WEIGHTS = 'imagenet'
def get_validation_augmentation():
"""调整图像使得图片的分辨率长宽能被32整除"""
test_transform = [
albu.PadIfNeeded(256, 256)
albu.PadIfNeeded(1024, 1024)
]
return albu.Compose(test_transform)
@ -54,13 +48,12 @@ def run_seg(img, best_model):
# ---------------------------------------------------------------
img = augmentator(image=img)['image']
img = preprocessor(image=img)['image']
# 加载最佳模型
logger.info(f"img.shape {img.shape}")
x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
pr_mask = best_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
return (pr_mask - 1) * (-220)
if __name__ == '__main__':
best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp')
if __name__ == '__main__':
pass

View File

@ -7,7 +7,6 @@ import clip
import torch.nn.functional as F
from DiffAugment_pytorch import DiffAugment
import numpy as np
import lpips
import os
current_path = os.path.dirname(__file__)