# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Run SAM automatic mask generation on an image or a directory of images.

For every readable input image the script generates masks, optionally shows
them as an overlay on the image, and saves them either as one PNG per mask or
as a single JSON file of COCO-style RLE masks.
"""

import argparse
import json
import os
from typing import Any, Dict, List

import cv2  # type: ignore
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
        "as well as pycocotools if saving in RLE format."
    )
)

# NOTE: the defaults for --input, --output and --checkpoint are absolute paths
# from the original author's machine; pass these options explicitly elsewhere.
parser.add_argument(
    "--input",
    type=str,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\crops\warning\231314.jpg',
    required=False,
    help="Path to either a single input image or folder of images.",
)

parser.add_argument(
    "--output",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\crops',
    help=(
        "Path to the directory where masks will be output. Output will be either a folder "
        "of PNGs per image or a single json with COCO-style masks."
    ),
)

parser.add_argument(
    "--model-type",
    type=str,
    required=False,
    default='vit_b',
    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth',
    help="The path to the SAM checkpoint to use for mask generation.",
)

parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

amg_settings.add_argument(
    "--crop-overlap-ratio",
    # The option is a ratio, so it is parsed as float; with int, argparse
    # rejected fractional values such as 0.3.
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)


def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Write each mask as ``<index>.png`` plus a ``metadata.csv`` into *path*.

    Args:
        masks: Mask records; each must provide the keys "segmentation",
            "area", "bbox", "point_coords", "predicted_iou",
            "stability_score" and "crop_box".
        path: Existing directory that receives the PNG files and the CSV.
    """
    header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    metadata = [header]
    for i, mask_data in enumerate(masks):
        mask = mask_data["segmentation"]
        filename = f"{i}.png"
        # Cast to uint8 before scaling: a boolean array times a Python int can
        # promote to int64 (NumPy 2 promotion rules), which cv2.imwrite rejects.
        cv2.imwrite(os.path.join(path, filename), mask.astype(np.uint8) * 255)
        mask_metadata = [
            str(i),
            str(mask_data["area"]),
            *[str(x) for x in mask_data["bbox"]],
            *[str(x) for x in mask_data["point_coords"][0]],
            str(mask_data["predicted_iou"]),
            str(mask_data["stability_score"]),
            *[str(x) for x in mask_data["crop_box"]],
        ]
        metadata.append(",".join(mask_metadata))
    metadata_path = os.path.join(path, "metadata.csv")
    with open(metadata_path, "w") as f:
        f.write("\n".join(metadata))


def get_amg_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect the AMG options that were explicitly set on the command line.

    Options left at their ``None`` default are omitted, so the mask
    generator's own defaults apply to them.
    """
    amg_kwargs = {
        "points_per_side": args.points_per_side,
        "points_per_batch": args.points_per_batch,
        "pred_iou_thresh": args.pred_iou_thresh,
        "stability_score_thresh": args.stability_score_thresh,
        "stability_score_offset": args.stability_score_offset,
        "box_nms_thresh": args.box_nms_thresh,
        "crop_n_layers": args.crop_n_layers,
        "crop_nms_thresh": args.crop_nms_thresh,
        "crop_overlap_ratio": args.crop_overlap_ratio,
        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
        "min_mask_region_area": args.min_mask_region_area,
    }
    return {k: v for k, v in amg_kwargs.items() if v is not None}


def show_mask(mask: np.ndarray, ax: Any, random_color: bool = True) -> None:
    """Draw *mask* on *ax* as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Matplotlib axes to draw on.
        random_color: Use a random RGB colour when true, a fixed blue
            otherwise; the alpha channel is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to get an (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords: np.ndarray, labels: np.ndarray, ax: Any, marker_size: int = 375) -> None:
    """Scatter prompt points on *ax*: label 1 in green, label 0 in red.

    Args:
        coords: Point array indexed as ``coords[:, 0]`` (x) and
            ``coords[:, 1]`` (y).
        labels: Array aligned with *coords* holding 1 or 0 per point.
        ax: Matplotlib axes to draw on.
        marker_size: Marker size passed to ``ax.scatter``.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size,
               edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size,
               edgecolor='white', linewidth=1.25)


def show_box(box: Any, ax: Any) -> None:
    """Draw a green rectangle outline on *ax*.

    Args:
        box: Corner coordinates in ``(x0, y0, x1, y1)`` order.
        ax: Matplotlib axes to draw on.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def _show_overlay(image: np.ndarray, masks: List[Dict[str, Any]]) -> None:
    """Display *image* with every binary mask and its bounding box drawn on top."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    for mask in masks:
        show_mask(mask['segmentation'], ax)
        # Mask boxes are (x, y, w, h) -- see the bbox_* columns written by
        # write_masks_to_folder -- while show_box expects (x0, y0, x1, y1).
        x0, y0, w, h = mask['bbox']
        show_box([x0, y0, x0 + w, y0 + h], ax)
    plt.show()
    # Close the figure so overlays do not pile up on backends where
    # plt.show() returns without destroying it.
    plt.close(fig)


def main(args: argparse.Namespace) -> None:
    """Generate, optionally display, and save masks for every input image.

    Args:
        args: Parsed command-line options produced by the module's ``parser``.

    Raises:
        FileExistsError: In PNG mode, if the per-image output directory
            already exists.
    """
    print("Loading model...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    # Honour --convert-to-rle; without the flag the mode is "binary_mask".
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    if not os.path.isdir(args.input):
        targets = [args.input]
    else:
        targets = [
            f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
        ]
        targets = [os.path.join(args.input, f) for f in targets]

    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        image = cv2.imread(t)
        if image is None:
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks = generator.generate(image)

        # The overlay needs array masks, so it is only shown in PNG mode.
        if output_mode == "binary_mask":
            _show_overlay(image, masks)

        base = os.path.splitext(os.path.basename(t))[0]
        save_base = os.path.join(args.output, base)
        if output_mode == "binary_mask":
            # exist_ok=False: raises FileExistsError when this image already
            # has an output directory instead of mixing old and new masks.
            os.makedirs(save_base, exist_ok=False)
            for idx, mask in enumerate(masks):
                mask_image = mask['segmentation'].astype('uint8') * 255
                save_path = os.path.join(save_base, f"mask_{idx}.png")
                cv2.imwrite(save_path, mask_image)
        else:
            save_file = save_base + ".json"
            with open(save_file, "w") as f:
                json.dump(masks, f)
    print("Done!")


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)