22 lines
630 B
Python
22 lines
630 B
Python
import torch
|
||
import matplotlib.pyplot as plt
|
||
|
||
def detector(img, model):
|
||
"""_summary_
|
||
|
||
Args:
|
||
img (str or numpy.ndarray): 图片路径或者像素矩阵
|
||
model (_type_): 预加载的模型
|
||
|
||
Returns:
|
||
rtn(numpy.ndarray): 渲染后的图片像素点
|
||
pred(numpy.ndarray): 检测而出的目标的坐标点、置信度和类别,shape=[n, 6]
|
||
"""
|
||
result = model(img)
|
||
|
||
return result.render()[0], result.pred[0].cpu().numpy()
|
||
|
||
if __name__ == '__main__':
|
||
# model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
|
||
pass
|
||
|