181 lines
7.3 KiB
Python
181 lines
7.3 KiB
Python
import cv2
|
|
import os
|
|
import numpy as np
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
# Folder with the images to annotate, and folder that receives the *_masked.png results.
input_dir = r'C:\Users\t2581\Desktop\222\images'
output_dir = r'C:\Users\t2581\Desktop\222\json'
# Forwarded to save_masked_image as crop_mode_.
crop_mode = True

# Usage hint (Chinese): "It is best to press w to predict after adding each point."
print('最好是每加一个点就按w键predict一次')

os.makedirs(output_dir, exist_ok=True)

# Names are lower-cased before the suffix test, so only lower-case extensions are needed.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
# os.listdir returns entries in arbitrary order; sort so a/d navigation walks a stable sequence.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(os.path.join(input_dir, f))
)

import torch  # sam is a PyTorch module (it is moved with .to(device=...)), so torch is available

# Fall back to the CPU when no CUDA device is present instead of failing inside sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device=device)
predictor = SamPredictor(sam)

WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Centre the window assuming a 1920x1080 screen.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
|
|
|
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records SAM prompt points.

    A left click appends ``[x, y]`` to the global ``input_point`` with label 1
    (foreground); a right click appends it with label 0 (background). While the
    global ``input_stop`` is true, clicks are not recorded and a console message
    is printed instead. All other mouse events are ignored.
    """
    global input_point, input_label, input_stop
    labels_by_event = {cv2.EVENT_LBUTTONDOWN: 1, cv2.EVENT_RBUTTONDOWN: 0}
    if event not in labels_by_event:
        return
    if input_stop:
        # Message (Chinese): "Points cannot be added now; press w to leave mask-selection mode."
        print('此时不能添加点,按w退出mask选择模式')
        return
    input_point.append([x, y])
    input_label.append(labels_by_event[event])
|
|
|
|
|
|
def apply_mask(image, mask, alpha_channel=True, color=(0, 255, 0), opacity=0.5):
    """Highlight the pixels selected by ``mask``.

    Pixels where ``mask == 1`` (True for boolean masks) are painted with
    ``color``. When ``alpha_channel`` is true (the default), the painted copy is
    blended with the original through ``cv2.addWeighted``, which makes the
    highlight translucent. Otherwise the opaquely painted copy is returned.
    Despite its name, ``alpha_channel`` does not add an alpha channel to the
    output.

    Args:
        image: array indexed ``[row, col, channel]``; ``color`` must broadcast
            against a single pixel.
        mask: array covering the image's first two dimensions.
        alpha_channel: blend translucently (True) or paint opaquely (False).
        color: paint colour; the default (0, 255, 0) is green in OpenCV's BGR order.
        opacity: weight of the painted copy in the blend (the original gets
            ``1 - opacity``). Ignored when ``alpha_channel`` is false.

    Returns:
        A new array; ``image`` itself is not modified.
    """
    painted = image.copy()
    painted[mask == 1] = color
    if not alpha_channel:
        return painted
    return cv2.addWeighted(image, 1 - opacity, painted, opacity, 0)
|
|
|
|
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with the ``mask`` region tinted toward ``color``.

    For every pixel where ``mask == 1``, each of the first three channels becomes
    ``value * (1 - color_dark) + color_dark * color[channel]``. All other pixels,
    and any channel beyond the third, are copied unchanged. The float result is
    stored back into an array of ``image``'s dtype, so integer images lose the
    fractional part.

    Args:
        image: array indexed ``[row, col, channel]`` with at least three channels.
        mask: array matching one channel plane of ``image``; pixels equal to 1 are tinted.
        color: sequence of at least three per-channel tint values.
        color_dark: tint strength; 0 leaves pixels as-is, 1 replaces them with ``color``.
    """
    keep = 1 - color_dark
    inside = mask == 1
    tinted_image = image.copy()
    for ch in range(3):
        plane = image[:, :, ch]
        tinted_image[:, :, ch] = np.where(inside, plane * keep + color_dark * color[ch], plane)
    return tinted_image
|
|
|
|
def draw_external_rectangle(image, mask, pv):
    """Draw a labelled bounding box around every outer contour of ``mask``, in place.

    Args:
        image: image array drawn on directly (it is modified).
        mask: 2-D array; after conversion to uint8, non-zero pixels form the regions.
        pv: label text drawn next to each box.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    # findContours returns (contours, hierarchy) in OpenCV 4 and
    # (image, contours, hierarchy) in OpenCV 3; [-2] picks the contours in both.
    contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    (_, text_h), _ = cv2.getTextSize(pv, font, font_scale, thickness)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2)  # yellow in BGR
        # putText's origin is the text baseline. Put the label just above the box,
        # or just inside it when the box is too close to the top edge for the text to fit.
        text_y = y - 5 if y - 5 - text_h >= 0 else y + text_h + 5
        cv2.putText(image, pv, (x, text_y), font, font_scale, (255, 255, 255), thickness)
|
|
|
|
def save_masked_image(image, mask, output_dir, filename, crop_mode_, pv):
    """Save a copy of ``image`` with ``mask`` highlighted and boxed.

    The result is written to ``output_dir`` as ``<stem of filename>_masked.png``.
    ``image`` itself is left unmodified, because ``apply_mask`` returns a new array.

    Args:
        image: source image (the caller passes the frame read with cv2.imread).
        mask: mask to highlight; pixels equal to 1 are painted.
        output_dir: destination directory.
        filename: original file name; only its stem is reused.
        crop_mode_: currently unused; kept so existing call sites keep working.
        pv: label text drawn next to each bounding box.

    Returns:
        True if ``cv2.imwrite`` reported success, False otherwise.
    """
    masked_image = apply_mask(image, mask)
    draw_external_rectangle(masked_image, mask, pv)
    # splitext leaves extension-less names intact; rfind('.') returned -1 there
    # and silently chopped the last character.
    out_name = os.path.splitext(filename)[0] + '_masked.png'
    out_path = os.path.join(output_dir, out_name)
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(out_path, masked_image):
        print(f"Failed to save {out_path}")
        return False
    print(f"Saved as {out_name}")
    return True
|
|
|
|
current_index = 0

# "image" was already created during setup; namedWindow on an existing name does
# nothing, so this only makes sure the window exists before the callback is attached.
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)

# Prompt state shared with mouse_click, which appends to these module-level lists.
input_point = []  # clicked [x, y] pixel coordinates
input_label = []  # 1 = foreground (left click), 0 = background (right click)
input_stop = False  # True while cycling through predicted masks; clicks are rejected then

if not image_files:
    # Indexing the empty list below would otherwise raise IndexError.
    print(f"No images found in {input_dir}")
    cv2.destroyAllWindows()
    raise SystemExit(1)

# Outer loop: one iteration per displayed image.
while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread signals an unreadable file by returning None rather than raising.
        print(f"Could not read {filename}, skipping it")
        del image_files[current_index]
        if not image_files:
            print("No readable images left")
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # untouched copy handed to save_masked_image
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB copy for predictor.set_image
    selected_mask = None  # mask accepted in the last selection round
    logit_input = None  # low-res logits of that mask, fed back as mask_input
    overlay_color = (0, 255, 0)  # tint used to keep showing the accepted mask
    image_embedded = False  # set_image is expensive, so embed each image only once

    # Inner loop: one iteration per frame while prompting on this image.
    while True:
        input_stop = False
        # Draw the accepted mask first so the file name and prompt points stay on top of it.
        if selected_mask is not None:
            image_display = apply_color_mask(image_orign, selected_mask, overlay_color)
        else:
            image_display = image_orign.copy()
        display_info = f'{filename} '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)  # BGR: green = foreground, red = background
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        # Keep only the low byte: some backends set extra modifier bits in the result.
        key = cv2.waitKey(1) & 0xFF

        if key == ord(" "):
            # Clear all prompts and the accepted mask for this image.
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True
            if input_point and input_label:
                if not image_embedded:
                    predictor.set_image(image)
                    image_embedded = True
                masks, scores, logits = predictor.predict(
                    point_coords=np.array(input_point),
                    point_labels=np.array(input_label),
                    # Refine from the previously accepted mask when there is one.
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                # Choose one random tint per prediction; picking a new one every frame made the overlay flicker.
                mask_color = tuple(np.random.randint(0, 256, 3).tolist())
                cleared = False
                # Mask-selection mode: a/d cycle masks, s saves, q removes the last point,
                # w accepts the current mask, space clears everything.
                while True:
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_orign, selected_mask, mask_color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow("image", selected_image)

                    key = cv2.waitKey(10) & 0xFF
                    if key == ord('q') and input_point:
                        input_point.pop()
                        input_label.pop()
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
                    elif key == ord('a'):
                        mask_idx = (mask_idx - 1) % num_masks  # wrap to the last mask
                    elif key == ord('d'):
                        mask_idx = (mask_idx + 1) % num_masks  # wrap to the first mask
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        cleared = True
                        break
                if cleared:
                    # Keep the reset: the discarded mask's logits must not seed the next prediction.
                    logit_input = None
                else:
                    logit_input = logits[mask_idx, :, :]
                    overlay_color = mask_color
                print('max score:', np.argmax(scores), ' select:', mask_idx)

        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:  # ESC
            break
        elif key == ord('q') and input_point:
            input_point.pop()
            input_label.pop()
        elif key == ord('s') and selected_mask is not None:
            # scores and mask_idx come from the prediction that produced selected_mask.
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")

    if key == 27:
        break

cv2.destroyAllWindows()
|