# SAM/salt/SAM_JSON_多类别.py
#
# 239 lines
# 9.4 KiB
# Python
# Raw Permalink Normal View History
#
# 2024-06-19 08:51:04 +08:00
import cv2
import os
import numpy as np
import json
from segment_anything import sam_model_registry, SamPredictor
# Folder holding the images to annotate, and folder that receives the
# "<name>_masked.png" overlays and "<name>_masked.json" annotation files.
input_dir = r'C:\Users\t2581\Desktop\222\images'
output_dir = r'C:\Users\t2581\Desktop\222\2'
crop_mode = True
# Usage hint printed for the operator (the message recommends pressing "w"
# to predict after every added point).
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)
# File names are lower-cased before the suffix test, so lower-case suffixes
# already cover every case variant. The listing is sorted because os.listdir
# returns entries in arbitrary order, which would make the previous/next
# navigation order unpredictable.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))
)
# Load the ViT-B SAM checkpoint and move it to the CUDA device.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Centres the window on a 1920x1080 screen (the screen size is hard-coded).
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
# Category definitions: numeric id -> label name, and numeric id -> overlay
# colour used when the annotated image is saved.
categories = {1: "category1", 2: "category2", 3: "category3"}
category_colors = {1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 0, 0)}
current_label = 1  # Default category.
def mouse_click(event, x, y, flags, param):
    """Record a SAM point prompt for a mouse click on the image window.

    A left click adds a foreground point and a right click adds a background
    point at ``(x, y)``. While ``input_stop`` is set (mask selection mode),
    clicks are rejected with a console message instead.

    The stored point label is always 1 (foreground) or 0 (background), the
    only two values SAM defines for point prompts. The annotation category is
    tracked separately through ``current_label`` when a mask is kept, so it
    must not be written into the prompt labels.

    Args:
        event: OpenCV mouse event code.
        x: Click column in image coordinates.
        y: Click row in image coordinates.
        flags: Unused; part of the OpenCV callback signature.
        param: Unused; part of the OpenCV callback signature.
    """
    if input_stop:
        if event in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
            print('此时不能添加点,按w退出mask选择模式')
        return
    if event == cv2.EVENT_LBUTTONDOWN:
        input_points.append([x, y])
        input_labels.append(1)  # Foreground prompt.
    elif event == cv2.EVENT_RBUTTONDOWN:
        input_points.append([x, y])
        input_labels.append(0)  # Background prompt.
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` tinted with ``color`` where ``mask`` equals 1.

    Args:
        image: Array indexed as ``[row, column, channel]``; only the first
            three channels are blended. The input is not modified.
        mask: Array compared against 1 to select the pixels to tint.
        color: Sequence giving one tint value per channel (indices 0-2).
        color_dark: Weight of the tint; the original pixel keeps a weight of
            ``1 - color_dark``.

    Returns:
        A new array of the same shape and dtype as ``image``.
    """
    tinted = image.copy()
    selected = mask == 1
    keep = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * keep + color_dark * color[channel]
        # Writing into the copy casts the blended values back to its dtype.
        tinted[:, :, channel] = np.where(selected, blended, plane)
    return tinted
def draw_external_rectangle(image, mask, pv):
    """Box every outer contour of ``mask`` on ``image`` and caption it.

    The drawing happens in place on ``image``; nothing is returned.

    Args:
        image: Image to draw on.
        mask: Mask whose external contours are boxed; it is converted to
            ``uint8`` before contour extraction.
        pv: Caption text written just above each box.
    """
    box_color = (0, 255, 255)  # Yellow in OpenCV's BGR channel order.
    text_color = (255, 255, 255)
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for left, top, width, height in map(cv2.boundingRect, contours):
        bottom_right = (left + width, top + height)
        cv2.rectangle(image, (left, top), bottom_right, box_color, 2)
        cv2.putText(image, pv, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)
def save_masked_image_and_json(image, masks, output_dir, filename, crop_mode_, pv):
    """Write the annotated overlay image and a polygon JSON file for one image.

    Each mask is blended into a copy of ``image`` in its category colour and
    boxed with a "<label>: <score>" caption. Its external contours are stored
    as polygon shapes in a JSON document. The outputs are
    ``<stem>_masked.png`` and ``<stem>_masked.json`` inside ``output_dir``.

    Args:
        image: Source image; it is copied, not modified.
        masks: Iterable of ``(mask, label, score)`` tuples, where ``label`` is
            a category name from ``categories``.
        output_dir: Directory that receives both output files.
        filename: Name of the source image; its stem names the outputs and it
            is stored as ``imagePath`` in the JSON.
        crop_mode_: Accepted for interface compatibility; not used here.
        pv: Accepted for interface compatibility; not used here.
    """
    # Resolve the colour through the category tables instead of parsing the
    # last character of the label name, which only works for names that end
    # in a single-digit id.
    label_to_id = {name: cid for cid, name in categories.items()}
    fallback_color = (0, 255, 0)  # Used for labels missing from the tables.

    masked_image = image.copy()
    json_shapes = []
    for mask, label, score in masks:
        color = category_colors.get(label_to_id.get(label), fallback_color)
        masked_image = apply_color_mask(masked_image, mask, color)
        draw_external_rectangle(masked_image, mask, f"{label}: {score:.2f}")
        # Convert the mask to polygons: one polygon per external contour.
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        json_shapes.extend(
            {
                "label": label,
                "points": contour.reshape(-1, 2).tolist(),
                "group_id": None,
                "shape_type": "polygon",
                "flags": {},
            }
            for contour in contours
        )

    # splitext keeps the whole name when it has no extension; slicing at
    # rfind('.') would cut off the last character in that case.
    stem = os.path.splitext(filename)[0]

    masked_filename = stem + '_masked.png'
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image):
        print(f"Saved image as {masked_filename}")
    else:
        print(f"Failed to save image {masked_filename}")

    json_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": json_shapes,
        "imagePath": filename,
        "imageData": None,
        "imageHeight": image.shape[0],
        "imageWidth": image.shape[1],
    }
    json_filename = stem + '_masked.json'
    with open(os.path.join(output_dir, json_filename), 'w', encoding='utf-8') as json_file:
        json.dump(json_data, json_file, indent=4)
    print(f"Saved JSON as {json_filename}")
# Interactive annotation loop.
#
# Keys while placing points:
#   left / right click  add a foreground / background point
#   space               clear the points and the mask hint
#   w                   run SAM on the points and enter mask selection
#   q                   remove the last point
#   1-3                 switch the current category
#   r                   commit the kept masks and start a new object
#   s                   commit the kept masks and save image + JSON
#   a / d               previous / next image
#   Esc                 quit
# Keys while selecting a mask:
#   a / d               previous / next candidate mask
#   s                   keep the shown mask under the current category
#   q                   remove the last point
#   w                   confirm and return to point placement
#   space               clear the points and return to point placement
current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_points = []
input_labels = []
input_stop = False
masks = []
all_masks = []  # Committed annotations of every category for the current image.
# Key code -> category id for single-digit ids, built once instead of on
# every displayed frame.
category_keys = {ord(str(cid)): cid for cid in categories if len(str(cid)) == 1}

while True:
    if not image_files:
        print(f"No readable images in {input_dir}")
        break
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None instead of raising when a file cannot be
        # decoded; drop the entry so navigation never lands on it again.
        print(f"Could not read image, skipping: {filename}")
        image_files.pop(current_index)
        current_index = max(0, min(current_index, len(image_files) - 1))
        continue
    image_crop = image_orign.copy()
    image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
    selected_mask = None
    logit_input = None
    embedding_ready = False  # The SAM image embedding is computed on first use.
    while True:
        input_stop = False
        image_display = image_orign.copy()
        display_info = f'{filename} | Current label: {categories[current_label]}'
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_points, input_labels):
            color = (0, 255, 0) if label > 0 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True
            if len(input_points) > 0 and len(input_labels) > 0:
                try:
                    if not embedding_ready:
                        # The embedding depends only on the image, so compute
                        # it once per image rather than on every prediction.
                        predictor.set_image(image)
                        embedding_ready = True
                    input_point_np = np.array(input_points)
                    # SAM point labels are 1 (foreground) or 0 (background);
                    # collapse any other positive value to 1.
                    input_label_np = (np.array(input_labels) > 0).astype(int)
                    masks_pred, scores, logits = predictor.predict(
                        point_coords=input_point_np,
                        point_labels=input_label_np,
                        mask_input=logit_input[None, :, :] if logit_input is not None else None,
                        multimask_output=True,
                    )
                    mask_idx = 0
                    num_masks = len(masks_pred)  # Number of candidate masks.
                    confirmed = True
                    while True:
                        color = tuple(np.random.randint(0, 256, 3).tolist())  # Random overlay colour.
                        image_select = image_orign.copy()
                        selected_mask = masks_pred[mask_idx]  # Candidate shown; a/d switch between candidates.
                        selected_image = apply_color_mask(image_select, selected_mask, color)
                        mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                        cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                                    2, cv2.LINE_AA)
                        # TODO: show on the current image.
                        cv2.imshow("image", selected_image)
                        key = cv2.waitKey(10)
                        if key == ord('q') and len(input_points) > 0:
                            input_points.pop(-1)
                            input_labels.pop(-1)
                        elif key == ord('s'):
                            masks.append((selected_mask, categories[current_label], scores[mask_idx]))
                        elif key == ord('a'):
                            if mask_idx > 0:
                                mask_idx -= 1
                            else:
                                mask_idx = num_masks - 1
                        elif key == ord('d'):
                            if mask_idx < num_masks - 1:
                                mask_idx += 1
                            else:
                                mask_idx = 0
                        elif key == ord('w'):
                            break
                        elif key == ord(" "):
                            input_points = []
                            input_labels = []
                            selected_mask = None
                            logit_input = None
                            confirmed = False
                            break
                    if confirmed:
                        # Keep the chosen logits as the hint for the next
                        # prediction. After a reset with space the hint must
                        # stay cleared, so it is not overwritten here.
                        logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
                except Exception as e:
                    print(f"Error during prediction: {e}")
        elif key == ord('a') or key == ord('d'):
            if key == ord('a'):
                new_index = max(0, current_index - 1)
            else:
                new_index = min(len(image_files) - 1, current_index + 1)
            if new_index != current_index:
                # Unsaved masks belong to the image being left; carrying them
                # over would save them under the next image's filename.
                if masks or all_masks:
                    print(f"Discarded {len(masks) + len(all_masks)} unsaved mask(s) of {filename}")
                masks = []
                all_masks = []
            current_index = new_index
            input_points = []
            input_labels = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_points) > 0:
            input_points.pop(-1)
            input_labels.pop(-1)
        elif key == ord('r'):
            if masks:
                all_masks.extend(masks)  # Move the kept masks into all_masks.
                masks = []  # Clear the kept masks.
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord('s'):
            if masks:
                all_masks.extend(masks)  # Move the kept masks into all_masks.
            if all_masks:
                save_masked_image_and_json(image_crop, all_masks, output_dir, filename, crop_mode_=crop_mode, pv="")
                all_masks = []  # Clear everything that was just saved.
            masks = []  # Clear the kept masks.
        elif key in category_keys:
            current_label = category_keys[key]
            print(f"Switched to label: {categories[current_label]}")
    if key == 27:
        break
cv2.destroyAllWindows()
# 2024-06-19 08:57:04 +08:00
# Commit note: multi-category annotation (多类别标注)
# 2024-06-19 08:51:04 +08:00