22 lines
630 B
Python
22 lines
630 B
Python
|
import torch
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
|
|||
|
def detector(img, model):
|
|||
|
"""_summary_
|
|||
|
|
|||
|
Args:
|
|||
|
img (str or numpy.ndarray): 图片路径或者像素矩阵
|
|||
|
model (_type_): 预加载的模型
|
|||
|
|
|||
|
Returns:
|
|||
|
rtn(numpy.ndarray): 渲染后的图片像素点
|
|||
|
pred(numpy.ndarray): 检测而出的目标的坐标点、置信度和类别,shape=[n, 6]
|
|||
|
"""
|
|||
|
result = model(img)
|
|||
|
|
|||
|
return result.render()[0], result.pred[0].cpu().numpy()
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
|
|||
|
pass
|
|||
|
|