ai_platform_cv/detection/detection.py

22 lines
630 B
Python
Raw Normal View History

2022-08-03 10:16:48 +08:00
import torch
import matplotlib.pyplot as plt
def detector(img, model):
"""_summary_
Args:
img (str or numpy.ndarray): 图片路径或者像素矩阵
model (_type_): 预加载的模型
Returns:
rtn(numpy.ndarray): 渲染后的图片像素点
pred(numpy.ndarray): 检测而出的目标的坐标点置信度和类别shape=[n, 6]
"""
result = model(img)
return result.render()[0], result.pred[0].cpu().numpy()
if __name__ == '__main__':
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
pass