ai_platform_cv/detection/detection.py

22 lines
630 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import matplotlib.pyplot as plt
def detector(img, model, index: int = 0):
    """Run a detection model on an image; return the rendered image and raw predictions.

    Args:
        img (str or numpy.ndarray): Image file path or pixel array. It is
            passed to ``model`` unchanged. (Presumably a list of these is
            also accepted by yolov5 hub models for batch inference; confirm
            for your model.)
        model: Pre-loaded detection model, e.g. a yolov5 ``torch.hub`` model.
            It must be callable on ``img`` and return an object exposing
            ``render()`` (a sequence of rendered images) and ``pred`` (a
            sequence of per-image prediction tensors).
        index (int): Which image of the batch to return results for.
            Defaults to 0, the first image. For a single-image input this is
            the only image.

    Returns:
        tuple: ``(rtn, pred)`` where
            rtn (numpy.ndarray): Rendered image pixels with the detections
                drawn on them.
            pred (numpy.ndarray): Coordinates, confidence and class of each
                detected object, shape ``[n, 6]``. The column order is
                presumably yolov5's ``x1, y1, x2, y2, conf, cls``; confirm.

    Raises:
        IndexError: If ``index`` is outside the batch returned by ``model``.
    """
    result = model(img)
    rendered = result.render()[index]
    # detach() lets tensors that still track gradients be converted;
    # cpu() is required because Tensor.numpy() only works on CPU tensors.
    boxes = result.pred[index].detach().cpu().numpy()
    return rendered, boxes
if __name__ == "__main__":
    # Intentionally does nothing when run as a script. Example usage: load a
    # local yolov5 checkout through torch.hub, then call detector(img, model).
    # The path below is machine-specific; adjust it before uncommenting.
    # model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True)
    ...