SAM/SAM_Mask.py

176 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# --- Configuration -----------------------------------------------------------
# Folder the images are read from and folder the masked previews are written to.
input_dir = 'scripts/pv1/0.8/wuding/jpg'
output_dir = 'scripts/pv1/0.8/wuding/mask'
# Forwarded to save_masked_image() as ``crop_mode_``.
crop_mode = True
# Usage hint (Chinese): "it is best to press w to predict after every added point".
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# Names are lower-cased before the suffix test, so only lower-case suffixes are
# needed here. os.listdir() returns entries in arbitrary order; sorting makes the
# a/d (previous/next) navigation order deterministic.
IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.tif')
image_files = sorted(
    f for f in os.listdir(input_dir) if f.lower().endswith(IMAGE_SUFFIXES)
)

# --- Model -------------------------------------------------------------------
# NOTE(review): the checkpoint path and the "cuda" device are machine-specific.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# --- Window ------------------------------------------------------------------
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Centres the window on a 1920x1080 screen.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """Mouse callback that records prompt points for the predictor.

    A left click appends ``[x, y]`` to ``input_point`` with label 1
    (foreground); a right click appends it with label 0 (background). While
    ``input_stop`` is true no point is recorded and a hint is printed instead.
    Any other mouse event is ignored. ``flags`` and ``param`` are unused.
    """
    left_pressed = event == cv2.EVENT_LBUTTONDOWN
    right_pressed = event == cv2.EVENT_RBUTTONDOWN
    if not (left_pressed or right_pressed):
        return

    if input_stop:
        # Message (Chinese): "points cannot be added now, press w to leave
        # mask-selection mode".
        print('此时不能添加点,按w退出mask选择模式')
        return

    # The module-level lists are only mutated in place, never rebound here.
    input_point.append([x, y])
    input_label.append(1 if left_pressed else 0)
def apply_mask(image, mask, alpha_channel=True):
    """Return a copy of ``image`` with the masked region highlighted in green.

    Args:
        image: image array; it is not modified.
        mask: array indexing the pixels of ``image``; pixels where
            ``mask == 1`` are highlighted.
        alpha_channel: when true, the green-painted copy is blended 50/50 with
            the original via ``cv2.addWeighted``; when false, the masked
            pixels are replaced by opaque ``[0, 255, 0]``.

    Returns:
        A new image array.
    """
    masked_image = image.copy()
    # Green marks the masked region; change the colour here if needed.
    masked_image[mask == 1] = [0, 255, 0]
    if not alpha_channel:
        return masked_image
    return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with the masked pixels tinted by ``color``.

    Every pixel where ``mask == 1`` becomes
    ``pixel * (1 - color_dark) + color_dark * color`` on its first three
    channels; all other pixels are copied unchanged.

    Args:
        image: image array indexed as ``[row, col, channel]``; not modified.
        mask: 2-D array selecting the pixels to tint.
        color: sequence whose first three entries are the per-channel tint.
        color_dark: weight of the tint in the blend (0 keeps the image,
            1 replaces it with ``color``).

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    masked_image = image.copy()
    selected = mask == 1
    tint = color_dark * np.asarray(color[:3], dtype=np.float64)
    # Blend only the selected pixels in a single vectorised step; the float
    # result is cast back to the image dtype on assignment.
    masked_image[selected, :3] = image[selected, :3] * (1 - color_dark) + tint
    return masked_image
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write the green-highlighted preview of ``mask`` over ``image`` to disk.

    The output name is ``<stem>_masked.png`` inside ``output_dir``, where
    ``<stem>`` is ``filename`` without its extension.

    Args:
        image: image passed to ``apply_mask``.
        mask: mask passed to ``apply_mask``.
        output_dir: directory the PNG is written to.
        filename: name of the source image; only its stem is used.
        crop_mode_: accepted for call compatibility; currently unused.
    """
    masked_image = apply_mask(image, mask)
    # splitext() is safe for names without a dot, where slicing at rfind('.')
    # would drop the last character.
    stem, _ = os.path.splitext(filename)
    filename = stem + '_masked.png'  # name of the saved file
    target = os.path.join(output_dir, filename)
    # imwrite() reports failure through its return value instead of raising.
    if cv2.imwrite(target, masked_image):
        print(f"Saved as {filename}")
    else:
        print(f"Failed to save {target}")
def _pick_mask(base_image, save_image, masks, scores, filename):
    """Show the candidate masks one at a time until the user decides.

    Keys handled while the preview is shown:
        a / d   previous / next candidate (wrapping around)
        s       save the candidate currently shown
        q       drop the most recently added prompt point
        w       confirm the candidate currently shown
        space   abandon the selection

    Args:
        base_image: image the candidates are drawn over; not modified.
        save_image: image handed to ``save_masked_image`` when s is pressed.
        masks: candidate masks returned by the predictor.
        scores: per-candidate scores, shown in the overlay text.
        filename: name of the image being annotated.

    Returns:
        ``(mask_idx, color)`` for the confirmed candidate, where ``color`` is
        the tint it was last drawn with, or ``None`` if space was pressed.
    """
    mask_idx = 0
    num_masks = len(masks)
    while True:
        # A new random tint is drawn for every frame.
        color = tuple(np.random.randint(0, 256, 3).tolist())
        candidate = masks[mask_idx]
        # apply_color_mask() returns a fresh array, so base_image stays clean.
        preview = apply_color_mask(base_image, candidate, color)
        mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save'
        cv2.putText(preview, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                    cv2.LINE_AA)
        cv2.imshow("image", preview)
        key = cv2.waitKey(10)
        if key == ord('q') and input_point:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s'):
            save_masked_image(save_image, candidate, output_dir, filename, crop_mode_=crop_mode)
        elif key == ord('a'):
            mask_idx = (mask_idx - 1) % num_masks
        elif key == ord('d'):
            mask_idx = (mask_idx + 1) % num_masks
        elif key == ord('w'):
            return mask_idx, color
        elif key == ord(' '):
            return None


current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
# Prompt state shared with the mouse_click callback.
input_point = []
input_label = []
input_stop = False

if not image_files:
    print(f'No images found in {input_dir}')

while image_files:
    current_index = min(current_index, len(image_files) - 1)
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # imread() returns None for unreadable files; drop the entry so that
        # navigation cannot land on it again.
        print(f'Cannot read {filename}, skipping')
        image_files.pop(current_index)
        continue
    image_crop = image_orign.copy()  # untouched copy used when saving
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB input for the predictor
    selected_mask = None
    selected_color = None
    logit_input = None
    # set_image() computes the image embedding; do it once per image, and only
    # when a prediction is actually requested.
    embedding_ready = False

    while True:
        input_stop = False
        image_display = image_orign.copy()
        display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        # Foreground points are drawn green, background points red (BGR).
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        if selected_mask is not None:
            # Show the confirmed mask, in the tint it had when it was confirmed.
            image_display = apply_color_mask(image_display, selected_mask, selected_color)

        # Keys: s save, w predict, a/d switch image, esc quit.
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)

        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # mouse clicks are rejected during mask selection
            if input_point and input_label:
                if not embedding_ready:
                    predictor.set_image(image)
                    embedding_ready = True
                masks, scores, logits = predictor.predict(
                    point_coords=np.array(input_point),
                    point_labels=np.array(input_label),
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                choice = _pick_mask(image_orign, image_crop, masks, scores, filename)
                if choice is None:
                    # Selection abandoned: reset the prompt and do not keep the
                    # logits of a mask the user rejected.
                    input_point = []
                    input_label = []
                    selected_mask = None
                    logit_input = None
                else:
                    mask_idx, selected_color = choice
                    selected_mask = masks[mask_idx]
                    # Fed back as mask_input on the next prediction.
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:  # esc
            break
        elif key == ord('q') and input_point:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)

    if key == 27:
        break

cv2.destroyAllWindows()