"""Interactive multi-category mask annotation with Segment Anything (SAM).

Browsing mode:
    left click / right click   add a foreground / background point
    w                          run SAM on the points and enter mask-selection mode
    q                          remove the last point
    space                      clear points, selected mask and previous logits
    r                          commit the accepted masks and start a new object
    s                          save an overlay PNG and a JSON file for this image
    a / d                      previous / next image (unsaved masks are discarded)
    digit keys                 switch the current category
    Esc                        quit

Mask-selection mode (after pressing w):
    a / d                      cycle through the candidate masks
    s                          accept the shown mask for the current category
    q                          remove the last point
    space                      clear everything and leave selection mode
    w                          leave selection mode; the chosen mask's logits
                               are fed into the next prediction
"""
import json
import os

import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

input_dir = r'C:\Users\t2581\Desktop\222\images'
output_dir = r'C:\Users\t2581\Desktop\222\2'
crop_mode = True  # forwarded to save_masked_image_and_json, which does not use it yet

print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
# Sorted so that a/d navigation follows a stable, predictable order.
image_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(IMAGE_EXTENSIONS))
if not image_files:
    raise SystemExit(f"No images found in {input_dir}")

# Fall back to the CPU when no CUDA device is available instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
sam.to(device=device)
predictor = SamPredictor(sam)

WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Centre the window, assuming a 1920x1080 screen.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)

# Category id -> name. Pressing the id's digit key selects the category.
categories = {1: "category1", 2: "category2", 3: "category3"}
# Category id -> overlay colour; drawn onto images loaded by cv2.imread, i.e. BGR order.
category_colors = {1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 0, 0)}
current_label = 1  # default category

# Mouse/annotation state shared with mouse_click through module globals.
input_points = []   # [x, y] click positions for the object being segmented
input_labels = []   # SAM point labels for input_points: 1 = foreground, 0 = background
input_stop = False  # True while a prediction is being reviewed; clicks are ignored then
masks = []          # (mask, category name, score) accepted for the current object
all_masks = []      # accepted masks of every committed object on the current image

ESC_KEY = 27
# Key code -> category id. Only single-digit ids can be bound (ord() needs one character).
category_keys = {ord(str(i)): i for i in categories if 0 <= i <= 9}


def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that collects SAM point prompts.

    Left click appends a foreground point (label 1), right click a background
    point (label 0). While ``input_stop`` is set, clicks are rejected with a
    message instead.
    """
    global input_points, input_labels, input_stop
    if input_stop:
        if event in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
            print('此时不能添加点,按w退出mask选择模式')
        return
    if event == cv2.EVENT_LBUTTONDOWN:
        input_points.append([x, y])
        # SAM only understands 1 (foreground) and 0 (background). The category
        # lives in current_label and is attached when a mask is accepted, so it
        # must not be passed to SAM as the point label.
        input_labels.append(1)
    elif event == cv2.EVENT_RBUTTONDOWN:
        input_points.append([x, y])
        input_labels.append(0)


def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with the pixels where ``mask == 1`` blended towards ``color``.

    Args:
        image: 3-channel image; it is not modified.
        mask: 2-D array matching the image's height and width.
        color: 3-tuple in the same channel order as ``image``.
        color_dark: blend weight of ``color`` (0 keeps the image, 1 paints solid).
    """
    masked_image = image.copy()
    selected = mask == 1
    # Assigning the float blend back into the image's dtype truncates, as before.
    masked_image[selected] = image[selected] * (1 - color_dark) + color_dark * np.asarray(color)
    return masked_image


def draw_external_rectangle(image, mask, pv):
    """Draw, in place, a yellow bounding box and the text ``pv`` around each outer contour of ``mask``."""
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2)  # yellow rectangle
        # Keep the caption inside the image for boxes touching the top edge.
        cv2.putText(image, pv, (x, max(y - 5, 12)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)


def save_masked_image_and_json(image, masks, output_dir, filename, crop_mode_, pv):
    """Write ``<stem>_masked.png`` (overlay) and ``<stem>_masked.json`` (polygons) to ``output_dir``.

    Args:
        image: image the masks were predicted on (drawn on a copy).
        masks: iterable of ``(mask, category name, score)`` tuples.
        output_dir: destination directory.
        filename: source image file name; its stem names both outputs.
        crop_mode_, pv: accepted for compatibility; currently unused.
    """
    # Look colours up by category name. Parsing the last character of the name
    # only worked for names that end in a single digit.
    label_ids = {name: cid for cid, name in categories.items()}
    masked_image = image.copy()
    json_shapes = []
    for mask, label, score in masks:
        color = category_colors.get(label_ids.get(label), (0, 255, 255))
        masked_image = apply_color_mask(masked_image, mask, color)
        draw_external_rectangle(masked_image, mask, f"{label}: {score:.2f}")

        # Convert each outer contour of the mask into a polygon shape.
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            polygon = contour.reshape(-1, 2).tolist()
            if len(polygon) < 3:
                continue  # a polygon needs at least three vertices
            json_shapes.append({
                "label": label,
                "points": polygon,
                "group_id": None,
                "shape_type": "polygon",
                "flags": {},
            })

    stem = os.path.splitext(filename)[0]
    masked_filename = stem + '_masked.png'
    if cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image):
        print(f"Saved image as {masked_filename}")
    else:
        print(f"Failed to write {masked_filename}")

    # NOTE(review): imagePath is the bare source file name, while this JSON is
    # written to output_dir, not input_dir — confirm the consumer resolves it.
    json_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": json_shapes,
        "imagePath": filename,
        "imageData": None,
        "imageHeight": image.shape[0],
        "imageWidth": image.shape[1],
    }
    json_filename = stem + '_masked.json'
    with open(os.path.join(output_dir, json_filename), 'w', encoding='utf-8') as json_file:
        json.dump(json_data, json_file, indent=4)
    print(f"Saved JSON as {json_filename}")


cv2.setMouseCallback("image", mouse_click)

current_index = 0
while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # Unreadable or corrupt file: drop it instead of crashing further down.
        print(f"Could not read {filename}, skipping it")
        del image_files[current_index]
        if not image_files:
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # copy handed to save_masked_image_and_json
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB copy given to predictor.set_image
    image_embedded = False  # whether predictor.set_image has run for this image
    selected_mask = None
    logit_input = None

    while True:
        input_stop = False
        image_display = image_orign.copy()
        if selected_mask is not None:
            # Show the last chosen mask (it used to be computed but never displayed).
            image_display = apply_color_mask(image_display, selected_mask, category_colors[current_label])
        display_info = f'{filename} | Current label: {categories[current_label]}'
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                    (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_points, input_labels):
            color = (0, 255, 0) if label > 0 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)

        if key == ord(" "):
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True
            if not input_points:
                continue
            try:
                if not image_embedded:
                    # The embedding depends only on the image, so compute it once per image.
                    predictor.set_image(image)
                    image_embedded = True
                masks_pred, scores, logits = predictor.predict(
                    point_coords=np.array(input_points),
                    point_labels=np.array(input_labels),
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
            except Exception as e:  # keep the annotation session alive on SAM errors
                print(f"Error during prediction: {e}")
                continue

            mask_idx = 0
            num_masks = len(masks_pred)
            # One random colour per prediction; re-picking it every frame made the overlay flicker.
            color = tuple(np.random.randint(0, 256, 3).tolist())
            cleared = False
            while True:
                selected_mask = masks_pred[mask_idx]  # a/d cycle through the candidates
                selected_image = apply_color_mask(image_orign, selected_mask, color)
                mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 255, 255), 2, cv2.LINE_AA)
                cv2.imshow("image", selected_image)
                key = cv2.waitKey(10)
                if key == ord('q') and input_points:
                    input_points.pop(-1)
                    input_labels.pop(-1)
                elif key == ord('s'):
                    masks.append((selected_mask, categories[current_label], scores[mask_idx]))
                elif key == ord('a'):
                    mask_idx = (mask_idx - 1) % num_masks
                elif key == ord('d'):
                    mask_idx = (mask_idx + 1) % num_masks
                elif key == ord('w'):
                    break
                elif key == ord(" "):
                    input_points = []
                    input_labels = []
                    selected_mask = None
                    logit_input = None
                    cleared = True
                    break
            if not cleared:
                # Feed the chosen mask's logits into the next prediction. Skipped after a
                # space reset, which previously got overwritten here with stale logits.
                logit_input = logits[mask_idx, :, :]
                print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key in (ord('a'), ord('d')):
            if key == ord('a'):
                new_index = max(0, current_index - 1)
            else:
                new_index = min(len(image_files) - 1, current_index + 1)
            if new_index != current_index:
                # Masks are pixel masks of this image; carrying them over would save them
                # onto another image (or crash on a size mismatch).
                if masks or all_masks:
                    print(f"Discarded {len(masks) + len(all_masks)} unsaved mask(s) of {filename}")
                masks = []
                all_masks = []
            current_index = new_index
            input_points = []
            input_labels = []
            break
        elif key == ESC_KEY:
            break
        elif key == ord('q') and input_points:
            input_points.pop(-1)
            input_labels.pop(-1)
        elif key == ord('r'):
            # Commit the current object's masks and start a new object.
            all_masks.extend(masks)
            masks = []
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord('s'):
            all_masks.extend(masks)
            masks = []
            if all_masks:
                save_masked_image_and_json(image_crop, all_masks, output_dir, filename,
                                           crop_mode_=crop_mode, pv="")
                all_masks = []
        elif key in category_keys:
            current_label = category_keys[key]
            print(f"Switched to label: {categories[current_label]}")

    if key == ESC_KEY:
        break

cv2.destroyAllWindows()