update function

赵敬皓 2022-12-07 10:46:43 +08:00
parent b8fe3320a9
commit e872462870
8 changed files with 148 additions and 138 deletions

Dockerfile

@@ -1,17 +1,21 @@
-FROM python:3.7.13-slim
+FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
 WORKDIR /app
-ADD . /app/
+ADD ./requirements.txt /app/
 RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
-RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
+# RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
 RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
 RUN mim install mmcv-full
-RUN pip install -e ./text2image/CLIP/
-RUN pip install -e ./backbone/mmpose/
-RUN pip install -e ./ocr/tr/
-RUN rm -rf ./text2image/CLIP/
-RUN rm -rf ./ocr/tr/
-RUN rm -rf ./backbone/mmpose/
+ADD . /app
+RUN pip install -e ./text2image/CLIP/ --no-cache-dir
+RUN pip install -e ./backbone/mmpose/ --no-cache-dir
+RUN pip install -e ./ocr/tr/ --no-cache-dir
+# RUN rm -rf ./text2image/CLIP/
+# RUN rm -rf ./ocr/tr/
+# RUN rm -rf ./backbone/mmpose/
+RUN pip uninstall opencv-python -y
+RUN pip install opencv-python-headless --force-reinstall
 CMD ["python3", "run.py"]

backbone/backbone_infer.py

@@ -1,6 +1,4 @@
 from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result
-import cv2
-import os
 def run_backbone_infer(img_ndarr, pose_model):
     # test a single image

detection/detection.py

@@ -1,6 +1,4 @@
-import torch
-import matplotlib.pyplot as plt
+# -*-coding:utf-8-*-
 def detector(img, model):
     """_summary_
@@ -17,6 +15,5 @@ def detector(img, model):
     return result.render()[0], result.pred[0].cpu().numpy()
 if __name__ == '__main__':
-    # model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
     pass

mnist/mnist_torch.py

@@ -1,16 +1,15 @@
-from asyncio.log import logger
-from cmath import log
+# -*-coding:utf-8-*-
+from logzero import logger
 import torch
 import torch.nn as nn
 import torchvision
 from torch.utils.data import DataLoader
 import torch.optim as optim
 import tqdm
-import numpy as np
 batch_size = 256
 random_seed = 1884
-DEVICE = torch.device('cuda')
+DEVICE = torch.device('cpu')
 torch.manual_seed(random_seed)
@@ -122,13 +121,4 @@ def run_mnist_infer(img, model:CNN):
 if __name__ == '__main__':
-    # model=load_model('./models/MNIST_torch.pth')
-    # a, b = load_data()
-    # for i, l in a:
-    #     print(model(i)[0])
-    #     print(np.argmax(model(i).detach().numpy()[0]), l)
-    #     break
-    # trl, tel = load_data()
-    # cnn = train(10, 0.01, trl, tel)
-    # torch.save(cnn, './models/MNIST_torch.pth')
     pass

requirements.txt

@@ -10,7 +10,6 @@ Pillow==9.2.0
 scipy==1.4.1
 six==1.15.0
 tqdm==4.64.0
-opencv-python==4.6.0.66
 seaborn==0.11.2
 segmentation-models-pytorch==0.2.1
 albumentations==1.2.1

run.py

@@ -2,118 +2,124 @@
 import os
 import sys
 from logzero import logger
-current_path = os.path.dirname(__file__)
-logger.info(current_path)
+# current_path = os.path.dirname(__file__) # for local
+current_path = "/app" # for docker
+logger.info(f"{current_path}")
 sys.path.append(f"{current_path}/text2image/")
 sys.path.append(f"{current_path}/text2image/BigGAN_utils/")
 import json
 import base64
+from PIL import Image
 from flask import Flask, request, make_response
 import cv2
-# from io import BytesIO
-# import torch
+from io import BytesIO
+import torch
 from mmpose.apis import init_pose_model
-# from text2image.run_text2img import text2image
-# from detection.detection import detector
-# from segmentation.segment_pred import run_seg
-# from ocr.ocr import run_tr
+from text2image.run_text2img import text2image
+from detection.detection import detector
+from segmentation.segment_pred import run_seg
+from ocr.ocr import run_tr
 from backbone.backbone_infer import run_backbone_infer
+from mnist.mnist_torch import run_mnist_infer, CNN
 DEVICE = 'cpu'
-# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True)
-# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True)
-# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
+model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5x.pt', source='local')
+model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5s.pt', source='local')
+model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
 pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
 pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
-pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0'
+pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE)
+mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE)
 app=Flask(__name__)
-# @app.route('/text2image/',methods=["POST"])
-# def run_text2img():
-# if request.method == "POST":
-# text = request.form.get('text')
-# logger.info(f"{text}")
-# img = text2image(text)
-# output_buffer = BytesIO()
-# img.save(output_buffer, format='png')
-# byte_data = output_buffer.getvalue()
-# b64_code = base64.b64encode(byte_data).decode('utf-8')
-# resp = make_response(b64_code)
-# resp.status_code = 200
-# return resp
-# else:
-# resp = make_response()
-# resp.status_code=405
-# return resp
+@app.route('/text2image/',methods=["POST"])
+def run_text2img():
+    if request.method == "POST":
+        text = request.form.get('text')
+        logger.info(f"{text}")
+        img = text2image(text)
+        output_buffer = BytesIO()
+        img.save(output_buffer, format='png')
+        byte_data = output_buffer.getvalue()
+        b64_code = base64.b64encode(byte_data).decode('utf-8')
+        resp = make_response(b64_code)
+        resp.status_code = 200
+        return resp
+    else:
+        resp = make_response()
+        resp.status_code=405
+        return resp
-# @app.route('/detection/', methods=["POST"])
-# def run_detection():
-# if request.method == "POST":
-# img = request.files.get('image')
-# model_type = request.form.get('model_type')
-# try:
-# img = cv2.imread(img)
-# except:
-# resp = make_response()
-# resp.status_code = 406
-# return resp
-# if model_type.lower().strip() == 'yolov5x':
-# rst, _ = detector(img, model_5x)
-# else:
-# rst, _ = detector(img, model_5s)
-# logger.info(rst.shape)
-# img_str = cv2.imencode('.png', rst)[1].tobytes()
-# b64_code = base64.b64encode(img_str).decode('utf-8')
-# resp = make_response(b64_code)
-# resp.status_code = 200
-# return b64_code
-# else:
-# resp = make_response()
-# resp.status_code=405
-# return resp
+@app.route('/detection/', methods=["POST"])
+def run_detection():
+    if request.method == "POST":
+        img = request.files.get('image')
+        model_type = request.form.get('model_type')
+        try:
+            img = cv2.imread(img)
+        except:
+            resp = make_response()
+            resp.status_code = 406
+            return resp
+        if model_type.lower().strip() == 'yolov5x':
+            rst, _ = detector(img, model_5x)
+        else:
+            rst, _ = detector(img, model_5s)
+        logger.info(rst.shape)
+        img_str = cv2.imencode('.png', rst)[1].tobytes()
+        b64_code = base64.b64encode(img_str).decode('utf-8')
+        resp = make_response(b64_code)
+        resp.status_code = 200
+        return b64_code
+    else:
+        resp = make_response()
+        resp.status_code=405
+        return resp
-# @app.route('/ocr/', methods=["POST"])
-# def run_ocr():
-# resp = make_response()
-# if request.method == "POST":
-# img = request.files.get('image')
-# try:
-# img = cv2.imread(img)
-# except:
-# resp.status_code = 406
-# return resp
-# text = run_tr(img)
-# resp.status_code = 200
-# resp.data = json.dumps({'result':text})
-# return resp
-# else:
-# resp.status_code=405
-# return resp
+@app.route('/ocr/', methods=["POST"])
+def run_ocr():
+    resp = make_response()
+    if request.method == "POST":
+        img = request.files.get('image')
+        try:
+            img = cv2.imread(img)
+        except:
+            resp.status_code = 406
+            return resp
+        text = run_tr(img)
+        resp.status_code = 200
+        resp.data = json.dumps({'result':text})
+        return resp
+    else:
+        resp.status_code=405
+        return resp
-# @app.route('/segmentation/', methods=["POST"])
-# def run_segmentation():
-# if request.method == "POST":
-# img_upload = request.files.get('image')
-# try:
-# img = cv2.imread(img_upload)
-# except:
-# resp = make_response()
-# resp.status_code = 406
-# return resp
-# result = run_seg(img, model_seg)
-# img_str = cv2.imencode('.png', result)[1].tobytes()
-# b64_code = base64.b64encode(img_str).decode('utf-8')
-# resp = make_response(b64_code)
-# resp.status_code = 200
-# return resp
-# else:
-# resp = make_response()
-# resp.status_code=405
-# return resp
+@app.route('/segmentation/', methods=["POST"])
+def run_segmentation():
+    if request.method == "POST":
+        img_upload = request.files.get('image')
+        try:
+            img = cv2.imread(img_upload)
+        except:
+            resp = make_response()
+            resp.status_code = 406
+            return resp
+        result = run_seg(img, model_seg)
+        img_str = cv2.imencode('.png', result)[1].tobytes()
+        b64_code = base64.b64encode(img_str).decode('utf-8')
+        resp = make_response(b64_code)
+        resp.status_code = 200
+        return resp
+    else:
+        resp = make_response()
+        resp.status_code=405
+        return resp
 @app.route('/backbone/', methods=["POST"])
 def run_backbone():
@@ -125,7 +131,7 @@ def run_backbone():
             resp = make_response()
             resp.status_code = 406
            return resp
-    pose, result = run_backbone_infer(img, pose_model)
+    _, result = run_backbone_infer(img, pose_model)
     img_str = cv2.imencode('.png', result)[1].tobytes()
     b64_code = base64.b64encode(img_str).decode('utf-8')
     resp = make_response(b64_code)
@@ -136,8 +142,32 @@ def run_backbone():
         resp.status_code=405
         return resp
+@app.route('/mnist/', methods=["POST"])
+def run_mnist():
+    if request.method == "POST":
+        img_upload = request.files.get('image')
+        try:
+            img = cv2.imread(img_upload, 1)
+            # Apply a global threshold to reduce noise
+            ret,th1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
+            # Convert the OpenCV image to a PIL image
+            im = Image.fromarray(cv2.cvtColor(th1,cv2.COLOR_BGR2RGB))
+            # Convert to grayscale
+            im = im.convert('L')
+            Im = im.resize((28, 28), Image.ANTIALIAS)
+        except:
+            resp = make_response()
+            resp.status_code = 406
+            return resp
+        result = run_mnist_infer(Im, mnist_model)
+        resp = make_response(str(result))
+        resp.status_code = 200
+        return resp
+    else:
+        resp = make_response()
+        resp.status_code=405
+        return resp
 if __name__ == '__main__':
-    img = cv2.imread('./1.jpg')
-    pose, rst = run_backbone_infer(img, pose_model)
-    cv2.imwrite('./1_bb.jpg', rst)
+    app.run(host='0.0.0.0', port='8902')
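With this change the service listens on 0.0.0.0:8902, so the enabled routes can be exercised over HTTP. Below is a minimal client sketch, not part of the repository: it assumes the container is reachable at localhost:8902, that each route reads its upload from a multipart field named 'image' (as the route bodies above do), and that digit.png / person.jpg are placeholder paths for local test images; requests is an extra dependency used only here.

# Hypothetical smoke test for the Flask service started by run.py (assumed reachable on localhost:8902).
import base64
import requests

BASE = "http://localhost:8902"

# /mnist/ takes a multipart upload named 'image' and returns the predicted digit as plain text.
with open("digit.png", "rb") as f:  # placeholder test image
    r = requests.post(f"{BASE}/mnist/", files={"image": f})
print("mnist:", r.text)

# /backbone/ returns the pose visualisation as a base64-encoded PNG.
with open("person.jpg", "rb") as f:  # placeholder test image
    r = requests.post(f"{BASE}/backbone/", files={"image": f})
with open("pose.png", "wb") as out:
    out.write(base64.b64decode(r.text))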

segmentation/segment_pred.py

@@ -1,14 +1,8 @@
-from asyncio.log import logger
-import os
-import numpy as np
-import cv2
-import matplotlib.pyplot as plt
+# -*-coding:utf-8-*-
+from logzero import logger
 import albumentations as albu
 import torch
 import segmentation_models_pytorch as smp
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset as BaseDataset
-from PIL import Image
 DEVICE = 'cpu'
 ENCODER = 'se_resnext50_32x4d'
@@ -20,7 +14,7 @@ ENCODER_WEIGHTS = 'imagenet'
 def get_validation_augmentation():
     """Pad the image so that its height and width are divisible by 32."""
     test_transform = [
-        albu.PadIfNeeded(256, 256)
+        albu.PadIfNeeded(1024, 1024)
     ]
     return albu.Compose(test_transform)
@@ -54,13 +48,12 @@ def run_seg(img, best_model):
     # ---------------------------------------------------------------
     img = augmentator(image=img)['image']
     img = preprocessor(image=img)['image']
-    # Load the best model
+    logger.info(f"img.shape {img.shape}")
     x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
     pr_mask = best_model.predict(x_tensor)
     pr_mask = (pr_mask.squeeze().cpu().numpy().round())
     return (pr_mask - 1) * (-220)
 if __name__ == '__main__':
-    best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
-    input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp')
+    pass
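Since the module's local test block now just passes, here is a minimal sketch of driving run_seg by hand; it assumes the script runs from the repository root, reuses the checkpoint path that run.py loads, and uses a placeholder image path (cv2 must now be imported by the caller, as it is no longer imported in this module).

# Hypothetical local test for run_seg; checkpoint path mirrors run.py, image path is a placeholder.
import cv2
import torch

from segmentation.segment_pred import run_seg, DEVICE

best_model = torch.load("segmentation/models/best_model_pvgc.pth", map_location=DEVICE)
img = cv2.imread("path/to/test_image.bmp")  # placeholder input image
mask = run_seg(img, best_model)             # mask values come back as 0 or 220
cv2.imwrite("mask.png", mask.astype("uint8"))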

text2image/

@@ -7,7 +7,6 @@ import clip
 import torch.nn.functional as F
 from DiffAugment_pytorch import DiffAugment
 import numpy as np
-import lpips
 import os
 current_path = os.path.dirname(__file__)