update function
This commit is contained in:
parent
b8fe3320a9
commit
e872462870
22
Dockerfile
22
Dockerfile
|
@ -1,17 +1,21 @@
|
||||||
FROM python:3.7.13-slim
|
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
ADD . /app/
|
ADD ./requirements.txt /app/
|
||||||
|
|
||||||
RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
||||||
RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
|
# RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
|
||||||
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir
|
||||||
RUN mim install mmcv-full
|
RUN mim install mmcv-full
|
||||||
RUN pip install -e ./text2image/CLIP/
|
|
||||||
RUN pip install -e ./backbone/mmpose/
|
ADD . /app
|
||||||
RUN pip install -e ./ocr/tr/
|
RUN pip install -e ./text2image/CLIP/ --no-cache-dir
|
||||||
RUN rm -rf ./text2image/CLIP/
|
RUN pip install -e ./backbone/mmpose/ --no-cache-dir
|
||||||
RUN rm -rf ./ocr/tr/
|
RUN pip install -e ./ocr/tr/ --no-cache-dir
|
||||||
RUN rm -rf ./backbone/mmpose/
|
# RUN rm -rf ./text2image/CLIP/
|
||||||
|
# RUN rm -rf ./ocr/tr/
|
||||||
|
# RUN rm -rf ./backbone/mmpose/
|
||||||
|
RUN pip uninstall opencv-python -y
|
||||||
|
RUN pip install opencv-python-headless --force-reinstall
|
||||||
|
|
||||||
CMD ["python3", "run.py"]
|
CMD ["python3", "run.py"]
|
|
@ -1,6 +1,4 @@
|
||||||
from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result
|
from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result
|
||||||
import cv2
|
|
||||||
import os
|
|
||||||
|
|
||||||
def run_backbone_infer(img_ndarr, pose_model):
|
def run_backbone_infer(img_ndarr, pose_model):
|
||||||
# test a single image
|
# test a single image
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
import torch
|
# -*-coding:utf-8-*-
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
def detector(img, model):
|
def detector(img, model):
|
||||||
"""_summary_
|
"""_summary_
|
||||||
|
|
||||||
|
@ -17,6 +15,5 @@ def detector(img, model):
|
||||||
return result.render()[0], result.pred[0].cpu().numpy()
|
return result.render()[0], result.pred[0].cpu().numpy()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
|
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,16 +1,15 @@
|
||||||
from asyncio.log import logger
|
# -*-coding:utf-8-*-
|
||||||
from cmath import log
|
from logzero import logger
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import tqdm
|
import tqdm
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
batch_size = 256
|
batch_size = 256
|
||||||
random_seed = 1884
|
random_seed = 1884
|
||||||
DEVICE = torch.device('cuda')
|
DEVICE = torch.device('cpu')
|
||||||
torch.manual_seed(random_seed)
|
torch.manual_seed(random_seed)
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,13 +121,4 @@ def run_mnist_infer(img, model:CNN):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# model=load_model('./models/MNIST_torch.pth')
|
|
||||||
# a, b = load_data()
|
|
||||||
# for i, l in a:
|
|
||||||
# print(model(i)[0])
|
|
||||||
# print(np.argmax(model(i).detach().numpy()[0]), l)
|
|
||||||
# break
|
|
||||||
# trl, tel = load_data()
|
|
||||||
# cnn = train(10, 0.01, trl, tel)
|
|
||||||
# torch.save(cnn, './models/MNIST_torch.pth')
|
|
||||||
pass
|
pass
|
|
@ -10,7 +10,6 @@ Pillow==9.2.0
|
||||||
scipy==1.4.1
|
scipy==1.4.1
|
||||||
six==1.15.0
|
six==1.15.0
|
||||||
tqdm==4.64.0
|
tqdm==4.64.0
|
||||||
opencv-python==4.6.0.66
|
|
||||||
seaborn==0.11.2
|
seaborn==0.11.2
|
||||||
segmentation-models-pytorch==0.2.1
|
segmentation-models-pytorch==0.2.1
|
||||||
albumentations==1.2.1
|
albumentations==1.2.1
|
||||||
|
|
220
run.py
220
run.py
|
@ -2,118 +2,124 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from logzero import logger
|
from logzero import logger
|
||||||
current_path = os.path.dirname(__file__)
|
|
||||||
logger.info(current_path)
|
# current_path = os.path.dirname(__file__) # for local
|
||||||
|
current_path = "/app" # for docker
|
||||||
|
logger.info(f"{current_path}")
|
||||||
sys.path.append(f"{current_path}/text2image/")
|
sys.path.append(f"{current_path}/text2image/")
|
||||||
sys.path.append(f"{current_path}/text2image/BigGAN_utils/")
|
sys.path.append(f"{current_path}/text2image/BigGAN_utils/")
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
|
from PIL import Image
|
||||||
from flask import Flask, request, make_response
|
from flask import Flask, request, make_response
|
||||||
import cv2
|
import cv2
|
||||||
# from io import BytesIO
|
from io import BytesIO
|
||||||
# import torch
|
import torch
|
||||||
from mmpose.apis import init_pose_model
|
from mmpose.apis import init_pose_model
|
||||||
# from text2image.run_text2img import text2image
|
from text2image.run_text2img import text2image
|
||||||
# from detection.detection import detector
|
from detection.detection import detector
|
||||||
# from segmentation.segment_pred import run_seg
|
from segmentation.segment_pred import run_seg
|
||||||
# from ocr.ocr import run_tr
|
from ocr.ocr import run_tr
|
||||||
from backbone.backbone_infer import run_backbone_infer
|
from backbone.backbone_infer import run_backbone_infer
|
||||||
|
from mnist.mnist_torch import run_mnist_infer, CNN
|
||||||
|
|
||||||
DEVICE = 'cpu'
|
DEVICE = 'cpu'
|
||||||
|
|
||||||
# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True)
|
model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5x.pt', source='local')
|
||||||
# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True)
|
model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','custom', path=f'{current_path}/detection/models/yolov5s.pt', source='local')
|
||||||
# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
||||||
pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
|
pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py'
|
||||||
pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
|
pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth'
|
||||||
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0'
|
pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device=DEVICE)
|
||||||
|
mnist_model = torch.load(f'{current_path}/mnist/models/MNIST_torch.pth', map_location=DEVICE)
|
||||||
|
|
||||||
app=Flask(__name__)
|
app=Flask(__name__)
|
||||||
|
|
||||||
# @app.route('/text2image/',methods=["POST"])
|
@app.route('/text2image/',methods=["POST"])
|
||||||
# def run_text2img():
|
def run_text2img():
|
||||||
# if request.method == "POST":
|
if request.method == "POST":
|
||||||
# text = request.form.get('text')
|
text = request.form.get('text')
|
||||||
# logger.info(f"{text}")
|
logger.info(f"{text}")
|
||||||
# img = text2image(text)
|
img = text2image(text)
|
||||||
# output_buffer = BytesIO()
|
output_buffer = BytesIO()
|
||||||
# img.save(output_buffer, format='png')
|
img.save(output_buffer, format='png')
|
||||||
# byte_data = output_buffer.getvalue()
|
byte_data = output_buffer.getvalue()
|
||||||
# b64_code = base64.b64encode(byte_data).decode('utf-8')
|
b64_code = base64.b64encode(byte_data).decode('utf-8')
|
||||||
# resp = make_response(b64_code)
|
resp = make_response(b64_code)
|
||||||
# resp.status_code = 200
|
resp.status_code = 200
|
||||||
# return resp
|
return resp
|
||||||
# else:
|
else:
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# resp.status_code=405
|
resp.status_code=405
|
||||||
# return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
# @app.route('/detection/', methods=["POST"])
|
@app.route('/detection/', methods=["POST"])
|
||||||
# def run_detection():
|
def run_detection():
|
||||||
# if request.method == "POST":
|
if request.method == "POST":
|
||||||
# img = request.files.get('image')
|
img = request.files.get('image')
|
||||||
# model_type = request.form.get('model_type')
|
model_type = request.form.get('model_type')
|
||||||
# try:
|
try:
|
||||||
# img = cv2.imread(img)
|
img = cv2.imread(img)
|
||||||
# except:
|
except:
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# resp.status_code = 406
|
resp.status_code = 406
|
||||||
# return resp
|
return resp
|
||||||
# if model_type.lower().strip() == 'yolov5x':
|
if model_type.lower().strip() == 'yolov5x':
|
||||||
# rst, _ = detector(img, model_5x)
|
rst, _ = detector(img, model_5x)
|
||||||
# else:
|
else:
|
||||||
# rst, _ = detector(img, model_5s)
|
rst, _ = detector(img, model_5s)
|
||||||
# logger.info(rst.shape)
|
logger.info(rst.shape)
|
||||||
# img_str = cv2.imencode('.png', rst)[1].tobytes()
|
img_str = cv2.imencode('.png', rst)[1].tobytes()
|
||||||
# b64_code = base64.b64encode(img_str).decode('utf-8')
|
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||||
# resp = make_response(b64_code)
|
resp = make_response(b64_code)
|
||||||
# resp.status_code = 200
|
resp.status_code = 200
|
||||||
# return b64_code
|
return b64_code
|
||||||
# else:
|
else:
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# resp.status_code=405
|
resp.status_code=405
|
||||||
# return resp
|
return resp
|
||||||
|
|
||||||
# @app.route('/ocr/', methods=["POST"])
|
@app.route('/ocr/', methods=["POST"])
|
||||||
# def run_ocr():
|
def run_ocr():
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# if request.method == "POST":
|
if request.method == "POST":
|
||||||
# img = request.files.get('image')
|
img = request.files.get('image')
|
||||||
# try:
|
try:
|
||||||
# img = cv2.imread(img)
|
img = cv2.imread(img)
|
||||||
# except:
|
except:
|
||||||
# resp.status_code = 406
|
resp.status_code = 406
|
||||||
# return resp
|
return resp
|
||||||
# text = run_tr(img)
|
text = run_tr(img)
|
||||||
# resp.status_code = 200
|
resp.status_code = 200
|
||||||
# resp.data = json.dumps({'result':text})
|
resp.data = json.dumps({'result':text})
|
||||||
# return resp
|
return resp
|
||||||
# else:
|
else:
|
||||||
# resp.status_code=405
|
resp.status_code=405
|
||||||
# return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
# @app.route('/segmentation/', methods=["POST"])
|
@app.route('/segmentation/', methods=["POST"])
|
||||||
# def run_segmentation():
|
def run_segmentation():
|
||||||
# if request.method == "POST":
|
if request.method == "POST":
|
||||||
# img_upload = request.files.get('image')
|
img_upload = request.files.get('image')
|
||||||
# try:
|
try:
|
||||||
# img = cv2.imread(img_upload)
|
img = cv2.imread(img_upload)
|
||||||
# except:
|
except:
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# resp.status_code = 406
|
resp.status_code = 406
|
||||||
# return resp
|
return resp
|
||||||
# result = run_seg(img, model_seg)
|
result = run_seg(img, model_seg)
|
||||||
# img_str = cv2.imencode('.png', result)[1].tobytes()
|
img_str = cv2.imencode('.png', result)[1].tobytes()
|
||||||
# b64_code = base64.b64encode(img_str).decode('utf-8')
|
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||||
# resp = make_response(b64_code)
|
resp = make_response(b64_code)
|
||||||
# resp.status_code = 200
|
resp.status_code = 200
|
||||||
# return resp
|
return resp
|
||||||
# else:
|
else:
|
||||||
# resp = make_response()
|
resp = make_response()
|
||||||
# resp.status_code=405
|
resp.status_code=405
|
||||||
# return resp
|
return resp
|
||||||
|
|
||||||
@app.route('/backbone/', methods=["POST"])
|
@app.route('/backbone/', methods=["POST"])
|
||||||
def run_backbone():
|
def run_backbone():
|
||||||
|
@ -125,7 +131,7 @@ def run_backbone():
|
||||||
resp = make_response()
|
resp = make_response()
|
||||||
resp.status_code = 406
|
resp.status_code = 406
|
||||||
return resp
|
return resp
|
||||||
pose, result = run_backbone_infer(img, pose_model)
|
_, result = run_backbone_infer(img, pose_model)
|
||||||
img_str = cv2.imencode('.png', result)[1].tobytes()
|
img_str = cv2.imencode('.png', result)[1].tobytes()
|
||||||
b64_code = base64.b64encode(img_str).decode('utf-8')
|
b64_code = base64.b64encode(img_str).decode('utf-8')
|
||||||
resp = make_response(b64_code)
|
resp = make_response(b64_code)
|
||||||
|
@ -136,8 +142,32 @@ def run_backbone():
|
||||||
resp.status_code=405
|
resp.status_code=405
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
@app.route('/mnist/', methods=["POST"])
|
||||||
|
def run_mnist():
|
||||||
|
if request.method == "POST":
|
||||||
|
img_upload = request.files.get('image')
|
||||||
|
try:
|
||||||
|
img = cv2.imread(img_upload, 1)
|
||||||
|
# 使用全局阈值,降噪
|
||||||
|
ret,th1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
|
||||||
|
# 把opencv图像转化为PIL图像
|
||||||
|
im = Image.fromarray(cv2.cvtColor(th1,cv2.COLOR_BGR2RGB))
|
||||||
|
# 灰度化
|
||||||
|
im = im.convert('L')
|
||||||
|
Im = im.resize((28, 28), Image.ANTIALIAS)
|
||||||
|
except:
|
||||||
|
resp = make_response()
|
||||||
|
resp.status_code = 406
|
||||||
|
return resp
|
||||||
|
result = run_mnist_infer(Im, mnist_model)
|
||||||
|
resp = make_response(str(result))
|
||||||
|
resp.status_code = 200
|
||||||
|
return resp
|
||||||
|
else:
|
||||||
|
resp = make_response()
|
||||||
|
resp.status_code=405
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
img = cv2.imread('./1.jpg')
|
app.run(host='0.0.0.0', port='8902')
|
||||||
pose, rst = run_backbone_infer(img, pose_model)
|
|
||||||
cv2.imwrite('./1_bb.jpg', rst)
|
|
|
@ -1,14 +1,8 @@
|
||||||
from asyncio.log import logger
|
# -*-coding:utf-8-*-
|
||||||
import os
|
from logzero import logger
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import albumentations as albu
|
import albumentations as albu
|
||||||
import torch
|
import torch
|
||||||
import segmentation_models_pytorch as smp
|
import segmentation_models_pytorch as smp
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data import Dataset as BaseDataset
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
DEVICE = 'cpu'
|
DEVICE = 'cpu'
|
||||||
ENCODER = 'se_resnext50_32x4d'
|
ENCODER = 'se_resnext50_32x4d'
|
||||||
|
@ -20,7 +14,7 @@ ENCODER_WEIGHTS = 'imagenet'
|
||||||
def get_validation_augmentation():
|
def get_validation_augmentation():
|
||||||
"""调整图像使得图片的分辨率长宽能被32整除"""
|
"""调整图像使得图片的分辨率长宽能被32整除"""
|
||||||
test_transform = [
|
test_transform = [
|
||||||
albu.PadIfNeeded(256, 256)
|
albu.PadIfNeeded(1024, 1024)
|
||||||
]
|
]
|
||||||
return albu.Compose(test_transform)
|
return albu.Compose(test_transform)
|
||||||
|
|
||||||
|
@ -54,13 +48,12 @@ def run_seg(img, best_model):
|
||||||
# ---------------------------------------------------------------
|
# ---------------------------------------------------------------
|
||||||
img = augmentator(image=img)['image']
|
img = augmentator(image=img)['image']
|
||||||
img = preprocessor(image=img)['image']
|
img = preprocessor(image=img)['image']
|
||||||
# 加载最佳模型
|
logger.info(f"img.shape {img.shape}")
|
||||||
x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
|
x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
|
||||||
pr_mask = best_model.predict(x_tensor)
|
pr_mask = best_model.predict(x_tensor)
|
||||||
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
|
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
|
||||||
return (pr_mask - 1) * (-220)
|
return (pr_mask - 1) * (-220)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
|
pass
|
||||||
input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp')
|
|
|
@ -7,7 +7,6 @@ import clip
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from DiffAugment_pytorch import DiffAugment
|
from DiffAugment_pytorch import DiffAugment
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import lpips
|
|
||||||
import os
|
import os
|
||||||
current_path = os.path.dirname(__file__)
|
current_path = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue