ai_platform_cv/detection/detection.py

19 lines
488 B
Python
Raw Normal View History

2022-12-07 10:46:43 +08:00
# -*-coding:utf-8-*-
2022-08-03 10:16:48 +08:00
def detector(img, model):
"""_summary_
Args:
img (str or numpy.ndarray): 图片路径或者像素矩阵
model (_type_): 预加载的模型
Returns:
rtn(numpy.ndarray): 渲染后的图片像素点
pred(numpy.ndarray): 检测而出的目标的坐标点置信度和类别shape=[n, 6]
"""
result = model(img)
return result.render()[0], result.pred[0].cpu().numpy()
if __name__ == '__main__':
pass