SAM/scripts/namg.py

325 lines
10 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2 # type: ignore
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import argparse
import json
import os
from typing import Any, Dict, List
import numpy as np
# Command-line interface for the script.
# NOTE: the defaults for --input, --output and --checkpoint are absolute Windows
# paths from the original author's machine; pass explicit values when running
# anywhere else.
parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
        "as well as pycocotools if saving in RLE format."
    )
)
parser.add_argument(
    "--input",
    type=str,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\crops\warning\231314.jpg',
    required=False,
    help="Path to either a single input image or folder of images.",
)
parser.add_argument(
    "--output",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\crops',
    help=(
        "Path to the directory where masks will be output. Output will be either a folder "
        "of PNGs per image or a single json with COCO-style masks."
    ),
)
parser.add_argument(
    "--model-type",
    type=str,
    required=False,
    default='vit_b',
    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)
parser.add_argument(
    "--checkpoint",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth',
    help="The path to the SAM checkpoint to use for mask generation.",
)
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

# Optional generator overrides. Every option defaults to None ("not supplied").
# get_amg_kwargs() drops the None entries before the values reach the generator.
amg_settings = parser.add_argument_group("AMG Settings")
amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)
amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)
amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)
amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)
amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)
amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)
amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)
amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)
amg_settings.add_argument(
    "--crop-overlap-ratio",
    # This is a fractional ratio. Parsing it as int rejected values such as 0.5,
    # while float still accepts integer input.
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)
amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)
amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Write each mask as ``<i>.png`` plus a ``metadata.csv`` summary into *path*.

    Args:
        masks: mask records, one dict per mask. Each record must provide the keys
            read below: "segmentation" (a mask array), "area", "bbox",
            "point_coords", "predicted_iou", "stability_score" and "crop_box".
        path: directory to write into. It is not created here, so it must
            already exist.

    The CSV has one header row followed by one comma-separated row per mask,
    in the same order as *masks*. Row ``i`` describes ``<i>.png``.
    """
    header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    metadata = [header]
    for i, mask_data in enumerate(masks):
        mask = mask_data["segmentation"]
        filename = f"{i}.png"
        # Cast to uint8 before scaling, as main() does, so the PNG data is
        # 8-bit 0/255. A bool array times 255 yields a platform-int array instead.
        cv2.imwrite(os.path.join(path, filename), mask.astype(np.uint8) * 255)
        mask_metadata = [
            str(i),
            str(mask_data["area"]),
            *[str(x) for x in mask_data["bbox"]],
            # Only the first prompt point is recorded.
            *[str(x) for x in mask_data["point_coords"][0]],
            str(mask_data["predicted_iou"]),
            str(mask_data["stability_score"]),
            *[str(x) for x in mask_data["crop_box"]],
        ]
        metadata.append(",".join(mask_metadata))
    metadata_path = os.path.join(path, "metadata.csv")
    with open(metadata_path, "w") as f:
        f.write("\n".join(metadata))
def get_amg_kwargs(args):
    """Build the generator keyword arguments from parsed command-line options.

    Reads the eleven AMG option attributes from *args*, in a fixed order. It
    returns a dict containing only those whose value is not ``None``, i.e. the
    options the user actually supplied. Omitted options are left out entirely,
    presumably so the generator's own defaults apply.
    """
    option_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    # Read every attribute up front, so a missing one raises AttributeError.
    supplied = [(name, getattr(args, name)) for name in option_names]
    return {name: value for name, value in supplied if value is not None}
def show_mask(mask, ax, random_color=True):
    """Overlay *mask* on *ax* as a translucent RGBA image.

    The last two dimensions of *mask* are taken as (height, width). With
    ``random_color`` the RGB part is drawn from ``np.random.random(3)``;
    otherwise it is a fixed blue. Alpha is always 0.6. Pixels where the mask
    is 0/False become fully transparent, because the colour is multiplied by
    the mask value.
    """
    rgba = (
        np.append(np.random.random(3), 0.6)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on *ax* as star markers.

    Rows of *coords* whose label equals 1 are drawn in green, and rows whose
    label equals 0 in red; other labels are not drawn. *labels* is used as a
    boolean index into *coords*, so it must align with the rows of *coords*.
    The green points are plotted first.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Outline *box* on *ax* as an unfilled green rectangle with line width 2.

    *box* is read as corner coordinates ``(x0, y0, x1, y1)``. The rectangle's
    width and height are the differences between the opposite corners.
    """
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
# def main(args: argparse.Namespace) -> None:
# print("Loading model...")
# sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
# _ = sam.to(device=args.device)
# # output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
# output_mode = 'binary_mask'
# amg_kwargs = get_amg_kwargs(args)
# generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
#
# floader_path = r"D:\Program Files\Pycharm items\segment-anything-model\scripts\input\crops\Guide" # folder that holds the batch of images
# file_path = os.listdir(floader_path)
# for im in file_path:
# args.input = os.path.join(floader_path, im)
# if not os.path.isdir(args.input):
# targets = [args.input]
# else:
# targets = [
# f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
# ]
# targets = [os.path.join(args.input, f) for f in targets]
#
# os.makedirs(args.output, exist_ok=True)
#
# for t in targets:
# print(f"Processing '{t}'...")
# image = cv2.imread(t)
# if image is None:
# print(f"Could not load '{t}' as an image, skipping...")
# continue
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# masks = generator.generate(image)
#
#
# plt.imshow(image) # optional: draw the source image underneath the masks
# for mask in masks:
# show_mask(mask['segmentation'], plt.gca())
# # show_box(mask['bbox'],plt.gca())
# # plt.axis('off')  # choose whether the axes are saved with the figure
# plt.savefig(r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\crops'+im) # replace with your own output path (note: no path separator is inserted before im)
# plt.close()
# print("Done!")
def main(args: argparse.Namespace) -> None:
    """Generate SAM masks for one image or a folder of images, display them, and save them.

    Args:
        args: options parsed by ``parser``. ``args.input`` may be a single
            image path or a directory. For a directory, every non-directory
            entry directly inside it is tried, and files OpenCV cannot read are
            skipped.

    Each image is shown in a matplotlib figure with its masks and their boxes,
    and the results are then written under ``args.output`` (created if missing):

    * default (``binary_mask``): one folder per image, named after the file
      stem, holding ``mask_<idx>.png`` images (0/255 for a boolean
      segmentation). The folder is created with ``exist_ok=False``, so it must
      not already exist; re-running into the same output raises
      ``FileExistsError``.
    * ``--convert-to-rle`` (``coco_rle``): one ``<stem>.json`` file containing
      the generator's mask records.
    """
    print("Loading model...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    # Honour --convert-to-rle. Previously the flag was parsed but ignored, which
    # left the JSON branch below unreachable.
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    if not os.path.isdir(args.input):
        targets = [args.input]
    else:
        targets = [
            os.path.join(args.input, f)
            for f in os.listdir(args.input)
            if not os.path.isdir(os.path.join(args.input, f))
        ]
    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        image = cv2.imread(t)
        if image is None:
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = generator.generate(image)

        # For visualization (optional). Use a fresh figure per image and close it
        # afterwards. On non-interactive backends plt.show() returns without
        # closing anything, so overlays would otherwise pile up across images.
        plt.figure()
        plt.imshow(image)
        for mask in masks:
            # In coco_rle mode the segmentation is an encoded RLE, not an array,
            # so only binary masks can be drawn as an overlay.
            if output_mode == "binary_mask":
                show_mask(mask['segmentation'], plt.gca())
            # The generator reports 'bbox' as XYWH, while show_box expects XYXY corners.
            x0, y0, w, h = mask['bbox']
            show_box([x0, y0, x0 + w, y0 + h], plt.gca())
        plt.show()
        plt.close()

        # For saving masks
        base = os.path.splitext(os.path.basename(t))[0]
        save_base = os.path.join(args.output, base)
        if output_mode == "binary_mask":
            # exist_ok=False: refuse to mix new masks into an existing folder.
            os.makedirs(save_base, exist_ok=False)
            for idx, mask in enumerate(masks):
                mask_image = mask['segmentation'].astype('uint8') * 255
                save_path = os.path.join(save_base, f"mask_{idx}.png")
                cv2.imwrite(save_path, mask_image)
        else:
            save_file = save_base + ".json"
            with open(save_file, "w") as f:
                json.dump(masks, f)
    print("Done!")
if __name__ == "__main__":
    # Script entry point: parse the command line and run the pipeline.
    main(parser.parse_args())