commit 62bad68b870af0b8f2bf45d01dbef2055b55a5bf Author: tanzk Date: Wed Jun 19 08:51:04 2024 +0800 SAM_Project diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..6b07595 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 +max-line-length = 100 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 +per-file-ignores = + **/__init__.py:F401,F403,E402 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..f637c16 --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,98 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..f1146ff --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,26 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..4bdd18a --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..39fb201 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/segment-anything-model.iml b/.idea/segment-anything-model.iml new file mode 100644 index 0000000..95337de --- /dev/null +++ b/.idea/segment-anything-model.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..08b500a --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..263991c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to segment-anything +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to segment-anything, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4f5efb9 --- /dev/null +++ b/README.md @@ -0,0 +1,171 @@ +# Segment Anything + +**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** + +[Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/) + +[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)] + +![SAM design](assets/model_diagram.png?raw=true) + +The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. + +

+ + +

+ +## Installation + +The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. + +Install Segment Anything: + +``` +pip install git+https://github.com/facebookresearch/segment-anything.git +``` + +or clone the repository locally and install with + +``` +git clone git@github.com:facebookresearch/segment-anything.git +cd segment-anything; pip install -e . +``` + +The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks. + +``` +pip install opencv-python pycocotools matplotlib onnxruntime onnx +``` + +## Getting Started + +First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt: + +``` +from segment_anything import SamPredictor, sam_model_registry +sam = sam_model_registry[""](checkpoint="") +predictor = SamPredictor(sam) +predictor.set_image() +masks, _, _ = predictor.predict() +``` + +or generate masks for an entire image: + +``` +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +sam = sam_model_registry[""](checkpoint="") +mask_generator = SamAutomaticMaskGenerator(sam) +masks = mask_generator.generate() +``` + +Additionally, masks can be generated for images from the command line: + +``` +python scripts/amg.py --checkpoint --model-type --input --output +``` + +See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details. + +

+ + +

+ +## ONNX Export + +SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with + +``` +python scripts/export_onnx_model.py --checkpoint --model-type --output +``` + +See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export. + +### Web demo + +The `demo/` folder has a simple one page React app which shows how to run mask prediction with the exported ONNX model in a web browser with multithreading. Please see [`demo/README.md`](https://github.com/facebookresearch/segment-anything/blob/main/demo/README.md) for more details. + +## Model Checkpoints + +Three model versions of the model are available with different backbone sizes. These models can be instantiated by running + +``` +from segment_anything import sam_model_registry +sam = sam_model_registry[""](checkpoint="") +``` + +Click the links below to download the checkpoint for the corresponding model type. + +- **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)** +- `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) +- `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) + +## Dataset + +See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the datastet. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the datasets you agree that you have read and accepted the terms of the SA-1B Dataset Research License. + +We save masks per image as a json file. It can be loaded as a dictionary in python in the below format. + +```python +{ + "image" : image_info, + "annotations" : [annotation], +} + +image_info { + "image_id" : int, # Image id + "width" : int, # Image width + "height" : int, # Image height + "file_name" : str, # Image filename +} + +annotation { + "id" : int, # Annotation id + "segmentation" : dict, # Mask saved in COCO RLE format. + "bbox" : [x, y, w, h], # The box around the mask, in XYWH format + "area" : int, # The area in pixels of the mask + "predicted_iou" : float, # The model's own prediction of the mask's quality + "stability_score" : float, # A measure of the mask's quality + "crop_box" : [x, y, w, h], # The crop of the image used to generate the mask, in XYWH format + "point_coords" : [[x, y]], # The point coordinates input to the model to generate the mask +} +``` + +Image ids can be found in sa_images_ids.txt which can be downloaded using the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/) as well. + +To decode a mask in COCO RLE format into binary: + +``` +from pycocotools import mask as mask_utils +mask = mask_utils.decode(annotation["segmentation"]) +``` + +See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format. + +## License + +The model is licensed under the [Apache 2.0 license](LICENSE). + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Contributors + +The Segment Anything project was made possible with the help of many contributors (alphabetical): + +Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom + +## Citing Segment Anything + +If you use SAM or SA-1B in your research, please use the following BibTeX entry. + +``` +@article{kirillov2023segany, + title={Segment Anything}, + author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross}, + journal={arXiv:2304.02643}, + year={2023} +} +``` diff --git a/SAM_Mask.py b/SAM_Mask.py new file mode 100644 index 0000000..db4f84e --- /dev/null +++ b/SAM_Mask.py @@ -0,0 +1,175 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = 'scripts/pv1/0.8/wuding/jpg' +output_dir = 'scripts/pv1/0.8/wuding/mask' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tif'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def mouse_click(event, x, y, flags, param): # 鼠标点击事件 + global input_point, input_label, input_stop # 全局变量,输入点, + if not input_stop: # 判定标志是否停止输入响应了! + if event == cv2.EVENT_LBUTTONDOWN: # 鼠标左键 + input_point.append([x, y]) + input_label.append(1) # 1表示前景点 + elif event == cv2.EVENT_RBUTTONDOWN: # 鼠标右键 + input_point.append([x, y]) + input_label.append(0) # 0表示背景点 + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: # 提示添加不了 + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(mask) + alpha[mask == 1] = 255 + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] # 这里用绿色表示mask的区域,可以根据需求修改 + return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0) + else: + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] # 这里用绿色表示mask的区域,可以根据需求修改 + return masked_image + +def apply_color_mask(image, mask, color, color_dark=0.5): + masked_image = image.copy() + for c in range(3): + masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return masked_image + +def save_masked_image(image, mask, output_dir, filename, crop_mode_): + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '_masked.png' # 修改保存的文件名 + cv2.imwrite(os.path.join(output_dir, filename), masked_image) + print(f"Saved as {filename}") + + + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() # 原图裁剪 + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) # 原图色彩转变 + selected_mask = None + logit_input = None + while True: + # print(input_point) + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_point, input_label): # 输入点和输入类型 + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) +# S保存,w预测,ad切换,esc退出 + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) # 设置输入图像 + input_point_np = np.array(input_point) # 输入暗示点,需要转变array类型才可以输入 + input_label_np = np.array(input_label) # 输入暗示点的类型 + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) # masks的数量 + while (1): + color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是 + image_select = image_orign.copy() + selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换 + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, + cv2.LINE_AA) + # todo 显示在当前的图片, + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + diff --git a/SAM_YY.py b/SAM_YY.py new file mode 100644 index 0000000..dd16d46 --- /dev/null +++ b/SAM_YY.py @@ -0,0 +1,180 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = r'C:\Users\t2581\Desktop\222\images' +output_dir = r'C:\Users\t2581\Desktop\222\json' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def mouse_click(event, x, y, flags, param): + global input_point, input_label, input_stop + if not input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(mask) + alpha[mask == 1] = 255 + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] + return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0) + else: + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] + return masked_image + +def apply_color_mask(image, mask, color, color_dark=0.5): + masked_image = image.copy() + for c in range(3): + masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return masked_image + +def draw_external_rectangle(image, mask, pv): + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle + cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2) + +def save_masked_image(image, mask, output_dir, filename, crop_mode_, pv): + masked_image = apply_mask(image, mask) + draw_external_rectangle(masked_image, mask, pv) + filename = filename[:filename.rfind('.')] + '_masked.png' + cv2.imwrite(os.path.join(output_dir, filename), masked_image) + print(f"Saved as {filename}") + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + # print(input_point) + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) # masks的数量 + while (1): + color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是 + image_select = image_orign.copy() + selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换 + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} ' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, + cv2.LINE_AA) + # todo 显示在当前的图片, + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}") + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}") + + if key == 27: + break diff --git a/SAM_YY_JSON.py b/SAM_YY_JSON.py new file mode 100644 index 0000000..860b4ae --- /dev/null +++ b/SAM_YY_JSON.py @@ -0,0 +1,209 @@ +import cv2 +import os +import numpy as np +import json +from segment_anything import sam_model_registry, SamPredictor + +input_dir = r'C:\Users\t2581\Desktop\222\images' +output_dir = r'C:\Users\t2581\Desktop\222\2' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def mouse_click(event, x, y, flags, param): + global input_point, input_label, input_stop + if not input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(mask) + alpha[mask == 1] = 255 + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] + return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0) + else: + masked_image = image.copy() + masked_image[mask == 1] = [0, 255, 0] + return masked_image + +def apply_color_mask(image, mask, color, color_dark=0.5): + masked_image = image.copy() + for c in range(3): + masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return masked_image + +def draw_external_rectangle(image, mask, pv): + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle + cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2) + +def save_masked_image_and_json(image, mask, output_dir, filename, crop_mode_, pv): + masked_image = apply_mask(image, mask) + draw_external_rectangle(masked_image, mask, pv) + masked_filename = filename[:filename.rfind('.')] + '_masked.png' + cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image) + print(f"Saved image as {masked_filename}") + + # Convert mask to polygons + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + polygons = [contour.reshape(-1, 2).tolist() for contour in contours] + + # Create JSON data + json_data = { + "version": "5.1.1", + "flags": {}, + "shapes": [ + { + "label": "pv", + "points": polygon, + "group_id": None, + "shape_type": "polygon", + "flags": {} + } for polygon in polygons + ], + "imagePath": filename, + "imageData": None, + "imageHeight": mask.shape[0], + "imageWidth": mask.shape[1] + } + + # Save JSON file + json_filename = filename[:filename.rfind('.')] + '_masked.json' + with open(os.path.join(output_dir, json_filename), 'w') as json_file: + json.dump(json_data, json_file, indent=4) + print(f"Saved JSON as {json_filename}") + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) # masks的数量 + while (1): + color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是 + image_select = image_orign.copy() + selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换 + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} ' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, + cv2.LINE_AA) + # todo 显示在当前的图片, + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}") + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}") + + if key == 27: + break diff --git a/UI.py b/UI.py new file mode 100644 index 0000000..a7c5636 --- /dev/null +++ b/UI.py @@ -0,0 +1,161 @@ +import sys +import os +from PyQt5 import QtCore, QtGui, QtWidgets + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1140, 450) + MainWindow.setMinimumSize(QtCore.QSize(1140, 450)) + MainWindow.setMaximumSize(QtCore.QSize(1140, 450)) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_5.setObjectName("pushButton_5") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_2 = QtWidgets.QLabel(self.centralwidget) + self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401)) + self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_2.setObjectName("label_2") + self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_w_2.setObjectName("pushButton_w_2") + self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) + self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21)) + self.lineEdit.setObjectName("lineEdit") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22)) + self.horizontalSlider.setSliderPosition(50) + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setTickInterval(0) + self.horizontalSlider.setObjectName("horizontalSlider") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "Predict")) + self.pushButton_a.setText(_translate("MainWindow", "Pre")) + self.pushButton_d.setText(_translate("MainWindow", "Next")) + self.pushButton_s.setText(_translate("MainWindow", "Save")) + self.pushButton_5.setText(_translate("MainWindow", "背景图")) + self.label_orign.setText(_translate("MainWindow", "

原始图像

")) + self.label_2.setText(_translate("MainWindow", "

预测图像

")) + self.pushButton_w_2.setText(_translate("MainWindow", "Openimg")) + self.lineEdit.setText(_translate("MainWindow", "改变mask大小")) + + +class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.image_path = "" + self.image_folder = "" + self.image_files = [] + self.current_index = 0 + + self.pushButton_w_2.clicked.connect(self.open_image_folder) + self.pushButton_a.clicked.connect(self.load_previous_image) + self.pushButton_d.clicked.connect(self.load_next_image) + + def open_image_folder(self): + folder_dialog = QtWidgets.QFileDialog() + folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '') + if folder_path: + self.image_folder = folder_path + self.image_files = self.get_image_files(self.image_folder) + if self.image_files: + self.show_image_selection_dialog() + + def get_image_files(self, folder_path): + image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))] + return image_files + + def show_image_selection_dialog(self): + dialog = QtWidgets.QDialog(self) + dialog.setWindowTitle("Select Image") + layout = QtWidgets.QVBoxLayout() + + self.listWidget = QtWidgets.QListWidget() + for image_file in self.image_files: + item = QtWidgets.QListWidgetItem(image_file) + pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100) + item.setIcon(QtGui.QIcon(pixmap)) + self.listWidget.addItem(item) + self.listWidget.itemDoubleClicked.connect(self.image_selected) + layout.addWidget(self.listWidget) + + buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttonBox.accepted.connect(self.image_selected) + buttonBox.rejected.connect(dialog.reject) + layout.addWidget(buttonBox) + + dialog.setLayout(layout) + + dialog.exec_() + + def image_selected(self): + selected_item = self.listWidget.currentItem() + if selected_item: + selected_index = self.listWidget.currentRow() + if selected_index >= 0: + self.current_index = selected_index + self.show_image() + + def show_image(self): + file_path = os.path.join(self.image_folder, self.image_files[self.current_index]) + pixmap = QtGui.QPixmap(file_path) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + def load_previous_image(self): + if self.image_files: + if self.current_index > 0: + self.current_index -= 1 + else: + self.current_index = len(self.image_files) - 1 + self.show_image() + + def load_next_image(self): + if self.image_files: + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + else: + self.current_index = 0 + self.show_image() + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + mainWindow = MyMainWindow() + mainWindow.show() + sys.exit(app.exec_()) + + + + diff --git a/UI.ui b/UI.ui new file mode 100644 index 0000000..f3ed17d --- /dev/null +++ b/UI.ui @@ -0,0 +1,186 @@ + + + MainWindow + + + + 0 + 0 + 1140 + 450 + + + + + 1140 + 450 + + + + + 1140 + 450 + + + + MainWindow + + + + + + 10 + 90 + 151 + 51 + + + + Predict + + + + + + 10 + 160 + 71 + 51 + + + + Pre + + + + + + 90 + 160 + 71 + 51 + + + + Next + + + + + + 10 + 360 + 151 + 51 + + + + Save + + + + + + 10 + 230 + 151 + 51 + + + + 背景图 + + + + + + 180 + 20 + 471 + 401 + + + + background-color: rgb(255, 255, 255); + + + <html><head/><body><p align="center">原始图像</p></body></html> + + + + + + 660 + 20 + 471 + 401 + + + + background-color: rgb(255, 255, 255); + + + <html><head/><body><p align="center">预测图像</p></body></html> + + + + + + 10 + 20 + 151 + 51 + + + + Openimg + + + + + + 50 + 290 + 81 + 21 + + + + 改变mask大小 + + + + + + 10 + 320 + 141 + 22 + + + + 50 + + + Qt::Horizontal + + + 0 + + + + + + + 0 + 0 + 1140 + 23 + + + + + + + + diff --git a/__pycache__/predict_mask.cpython-39.pyc b/__pycache__/predict_mask.cpython-39.pyc new file mode 100644 index 0000000..138c5af Binary files /dev/null and b/__pycache__/predict_mask.cpython-39.pyc differ diff --git a/biao.py b/biao.py new file mode 100644 index 0000000..4ce89d8 --- /dev/null +++ b/biao.py @@ -0,0 +1,734 @@ +import math +import os +import cv2 + +""" +标注关键点只能存在一个框和多个点,并且不能删除点和删除框,读取本地文件的关键点要保证其中的关键点 +数和key_point_num的值是一样的,本地标签中如果只存在框的信息就不要使用该脚本标注,不然会出错, +本地文件夹中可以有标签,如果有会优先加载本地标签,没有才会创建一个 +关键点标注 +这个是默认的标注就是普通标注 +按Q切换下一张,R清除干扰再标注(把之前的框屏蔽) +按Y从本地把图像和标签一切删掉 +按T 退出 +鼠标双击可以删除框 +单机就是拖拽 + +""" +draw_line_circle = True # True/None 是否在框上绘制点(8个点) +key_point_is = None # 是否标记关键点 设置为None标注普通yolo标签 +# 可以自定义得参数 +image_path = R'C:\Users\lengdan\Desktop\data1\images\images' # 标注完成保存到的文件夹 +label_path = R'C:\Users\lengdan\Desktop\data1\images\labels' # 要标注的图像所在文件夹 +circle_distance = 10 # 半径范围:鼠标进入点的半径范围内会出现光圈 +key_point_num = 5 # 关键点个数 +box_thickness = 2 # 框的粗细 +small_box_thickness = 1 # 框的8个点的粗细 +label_thickness = 1 # 框上面的类别字体的粗细 +label_fontScale = 0.4 # 框上面的类别字体的倍数 +key_thick = -1 # 关键点的粗细 +key_text_thick = 2 # 关键点上文字粗细 +key_text_scale = 0.6 # 关键点上文字的放大倍数 +key_radius = 4 # 关键点绘制半径 +dot = 6 # 选择保留几位小数 + +key_color = { + 0: (0, 0, 200), + 1: (255, 0, 0), + 2: (0, 222, 0) +} # 关键点的颜色 +key_text_color = { + 0: (0, 100, 200), + 1: (255, 0, 0), + 2: (0, 255, 125) +} # 关键点上文本的颜色 +box_color = { + 0: (125, 125, 125), + 1: (0, 255, 0), + 2: (0, 255, 0), + 3: (255, 0, 0), + 4: (0, 255, 255), + 5: (255, 255, 0), + 6: (255, 0, 255), + 7: (0, 125, 125), + 8: (125, 125, 125), + 9: (125, 125, 125), + 10: (125, 0, 125), + 11: (125, 0, 125), + 12: (125, 0, 125), + 13: (125, 0, 125), + 14: (125, 0, 125), + 15: (125, 0, 125) +} # 每个不同类别框的颜色 +my_cls = { + 0: '0', + 1: '1', + 2: '10', + 3: '11', + 4: '12', + 5: '13', + 6: '14', + 7: '15', + 8: '2', + 9: '3', + 10: '4', + 11: '5', + 12: '6', + 13: '7', + 14: '8', + 15: '9' +} # 添加自己的框的标签,如果没有就用i:'i'替代 +final_class = { + i: my_cls[i] if i in my_cls else str(i) for i in range(16) +} # 框的默认名字 + +# 不要修改的参数 +position = None # 这里判断鼠标放到了哪个点上,方便后面移动的时候做计算 +label = None # 操作图像对应的标签 +img = None # 操作的图像 +Mouse_move = None # 选择移动框的标志位 +label_index = None # 鼠标选中的框在标签中的位置 +label_index_pos = None # 记录选中了框的8个点位的哪一个 +Mouse_insert = None # 用来记录是否进入删除状态 +draw_rectangle = None # 用来记录开始添加新框 +end_draw_rectangle = None # 用来记录结束绘制新框 +append_str_temp = None # 用来保存新增加的框的信息 +empty_label = None # 本地是否存在标签文件标志 +# 关键点相关的参数 +key_points = None +key_points_move = None +key_x = None # 移动关键点的时候记录其每个关键点的x +key_y = None # 移动关键点的时候记录其每个关键点的y +key_v = None # 移动关键点的时候记录其每个关键点的状态 +key_box = None +box_move = None # 移动的是框的时候的标志位 +key_insert = None # 对某个关键点双击,切换其状态 +move_key_point = None # 把其他位置的关键点移动到这个地方 +la_path = None +key_point_one = None # 使用双击移动关键点的时候,记录第一个按下的键 +key_point_two = None # 使用双击移动关键点的时候,记录第二个按下的键 +append_new_key_point = None # 增加第二个关键点 +append_new_key_point_index = 0 # 增加第二个关键点 +window_w = None # 获取创建窗口的宽度 +window_h = None # 获取创建窗口的高度 + + +def flag_init(): + # 初始化下参数 + global position, label, img, Mouse_insert, Mouse_move, label_index, draw_rectangle, end_draw_rectangle, append_str_temp, empty_label, \ + label_index_pos, key_v, key_x, key_y, key_x, key_box, key_points_move, key_points, box_move, move_key_point, window_w, window_h + position = None + Mouse_move = None + label_index = None + label_index_pos = None + Mouse_insert = None + draw_rectangle = None + end_draw_rectangle = None + append_str_temp = None + empty_label = None + key_points = None + key_points_move = None + key_x = None + key_y = None + key_v = None + key_box = None + box_move = None + move_key_point = None + window_w = None + window_h = None + + +# 用来绘制小的填充矩形框 +def draw_rect_box(img, center, length_1, color=(0, 0, 255)): + x1, y1 = center[0] - length_1, center[1] - length_1 + x2, y2 = center[0] + length_1, center[1] + length_1 + cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness=-1) + + +# 用来读取本地图像 +def img_read(img_path, scale_): + global window_w, window_h + # scale_填写屏幕的最小尺寸 + image = cv2.imread(img_path) + scale_x, scale_y, _ = image.shape + if max(scale_x, scale_y) > scale_ and window_w is None: + scale = max(scale_x, scale_y) / scale_ + image = cv2.resize(image, (int(image.shape[1] / scale), int(image.shape[0] / scale))) + if window_w is not None: + image = cv2.resize(image, (window_w, window_h)) + return image + + +# 判断两点的间距,用来判断鼠标所在位置是否进入了8个点所在的区域 +def distance(p1, p2): + global circle_distance + if math.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2) < circle_distance: + return True + else: + return False + + +# 绘制虚线矩形框,当切换到删除时,由实线框转为虚线框 +def draw_dotted_rectangle(img, pt1, pt2, length_1=5, gap=6, thick=2, color=(100, 254, 100)): + (x1, y1), (x2, y2) = pt1, pt2 + temp1, temp2 = x1, y1 + while x1 + length_1 < x2: + cv2.line(img, (x1, y1), (x1 + length_1, y1), color, thickness=thick) + cv2.line(img, (x1, y2), (x1 + length_1, y2), color, thickness=thick) + x1 += length_1 + gap + while y1 + length_1 < y2: + cv2.line(img, (temp1, y1), (temp1, y1 + length_1), color, thickness=thick) + cv2.line(img, (x1, y1), (x1, y1 + length_1), color, thickness=thick) + y1 += length_1 + gap + + +# 把本地标签展示到图像中 +def label_show(img1, label_path, index): + global small_box_thickness, box_thickness, label_fontScale, label_thickness, key_point_is, key_points, \ + key_radius, key_color, key_thick, key_text_scale, key_text_thick, key_text_color, label, draw_line_circle + with open(la_path) as f: + label = f.readlines() + if len(label) == 0: + return + for i, points in enumerate(label): + if key_point_is: + # 获取关键点参数 + key_points = points.split(' ')[5:] + points = points.split(' ')[0:5] + classify = int(float(points[0])) + points.pop(0) + point = [float(s.strip('\n')) for s in points] + # point = list(map(float, points)) + scale_y, scale_x, _ = img1.shape + x, y, w, h = int((point[0] - point[2] / 2) * scale_x), int( + (point[1] - point[3] / 2) * scale_y), int( + point[2] * scale_x), int(point[3] * scale_y) + if i == index: + draw_dotted_rectangle(img1, (x, y), (x + w, y + h), box_thickness) + else: + cv2.rectangle(img1, (x, y), (x + w, y + h), box_color[classify], thickness=box_thickness) + if draw_line_circle: + # 绘制边上中心点,与四个顶点,矩形框中心点 + draw_rect_box(img1, (x, int(0.5 * (y + y + h))), length_1=small_box_thickness) + draw_rect_box(img1, (x + w - 1, int(0.5 * (y + y + h))), length_1=small_box_thickness) + draw_rect_box(img1, (int(0.5 * (x + x + w)), y), length_1=small_box_thickness) + draw_rect_box(img1, (int(0.5 * (x + x + w)), y + h), length_1=small_box_thickness) + draw_rect_box(img1, (x, y), length_1=small_box_thickness) + draw_rect_box(img1, (x + w, y), length_1=small_box_thickness) + draw_rect_box(img1, (x + w, y + h), length_1=small_box_thickness) + draw_rect_box(img1, (x, y + h), length_1=small_box_thickness) + draw_rect_box(img1, (int(x + 0.5 * w), int(y + 0.5 * h)), length_1=small_box_thickness) + cv2.putText(img1, str(final_class[classify]), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, label_fontScale, + (255, 0, 255), label_thickness) + if key_point_is: + # 依次获取每个关键点 + key_x = [float(i) for i in key_points[::3]] + key_y = [float(i) for i in key_points[1::3]] + key_v = [int(float(i)) for i in key_points[2::3]] + index = 0 + key_point = zip(key_x, key_y) + for p in key_point: + cv2.circle(img, (int(p[0] * scale_x), int(p[1] * scale_y)), key_radius, key_color[key_v[index]], + thickness=key_thick, + lineType=cv2.LINE_AA) + cv2.putText(img, str(index), (int(p[0] * scale_x - 5), int(p[1] * scale_y - 10)), + cv2.FONT_HERSHEY_SIMPLEX, + key_text_scale, key_text_color[0], key_text_thick) + index += 1 + key_points = None + + +# 回调函数,用于记录鼠标操作 +def mouse_event(event, x, y, flag, param): + global label, img, position, Mouse_move, label_index, label_index_pos, dot, Mouse_insert, draw_rectangle, \ + end_draw_rectangle, key_points, key_v, key_x, key_y, key_x, key_box, key_points_move, box_move, \ + key_insert, label_path, move_key_point + scale_y, scale_x, _ = img.shape + # 鼠标如果位于8个点左右,即通过position记录当前位置,通过主函数在鼠标附近绘制空心圈 + # 通过label_index记录鼠标选择了第几个框,通过label_index_pos记录该框第几个点被选中了 + with open(la_path) as f: + label = f.readlines() + if move_key_point is None and key_insert is None and Mouse_insert is None and empty_label is None and event == cv2.EVENT_MOUSEMOVE and img is not None and label is not None and \ + Mouse_move is None: + for i, la in enumerate(label): + la = la.strip('\n').split(' ') + if key_point_is: + key_points = list(map(float, la))[5:] + la = list(map(float, la))[0:5] + x1, y1 = int((la[1] - la[3] / 2) * scale_x), int((la[2] - la[4] / 2) * scale_y) + x2, y2 = x1 + int(la[3] * scale_x), y1 + int(la[4] * scale_y) + # 这里判断鼠标放到了哪个点上,方便后面移动的时候做计算 + if distance((x, y), (x1, y1)): + label_index_pos = 0 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (x2, y2)): + label_index_pos = 1 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (x1, int(0.5 * y1 + 0.5 * y2))): + label_index_pos = 2 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (int((x1 + x2) / 2), y2)): + label_index_pos = 3 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (int((x1 + x2) / 2), y1)): + label_index_pos = 4 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (x2, int(0.5 * y1 + 0.5 * y2))): + label_index_pos = 5 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (x1, y2)): + label_index_pos = 6 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), (x2, y1)): + label_index_pos = 7 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + elif distance((x, y), ((x1 + x2) / 2, (y1 + y2) / 2)): + # 框中心 + label_index_pos = 8 + position = (x, y) + label_index = i + box_move = True + key_points_move = None + break + else: + label_index_pos = None + position = None + label_index = None + if key_point_is: + # 判断鼠标是不是放到了关键点上 + key_x = [float(i) for i in key_points[::3]] + key_y = [float(i) for i in key_points[1::3]] + key_v = [float(i) for i in key_points[2::3]] # 能见度 + if len(key_x) == len(key_v) and len(key_x) == len(key_y): + for index, key_ in enumerate(key_x): + if distance((x, y), (int(key_ * scale_x), int(key_y[index] * scale_y))): + position = (x, y) + label_index, label_index_pos = i, index + key_box = la + key_points_move = True + box_move = None + break + + # 这里到下一个注释都是为了移动已有的框做准备 + if position is not None and event == cv2.EVENT_LBUTTONDOWN: + Mouse_move = True + position = None + + # 首先判断鼠标选择了该框的第几个点,然后移动鼠标的时候只负责移动该点 + if Mouse_move and box_move: + # 先把要移动的框的标签记录下来,然后删除,添加到末尾,不断修改末尾标签来达到移动框的目的 + # temp_label用来记录标签 + temp_label = label[label_index] + label.pop(label_index) + temp_label = temp_label.strip('\n').split(' ') + temp_label = [float(i) for i in temp_label] + x_1, y_1 = (temp_label[1] - 0.5 * temp_label[3]), (temp_label[2] - 0.5 * temp_label[4]) + x_2, y_2 = x_1 + temp_label[3], y_1 + temp_label[4] + # 判断移动的是8个点中的哪个 + if label_index_pos == 0: + x_1, y_1 = x / scale_x, y / scale_y + elif label_index_pos == 1: + x_2, y_2 = x / scale_x, y / scale_y + elif label_index_pos == 2: + x_1 = x / scale_x + elif label_index_pos == 3: + y_2 = y / scale_y + elif label_index_pos == 4: + y_1 = y / scale_y + elif label_index_pos == 5: + x_2 = x / scale_x + elif label_index_pos == 6: + x_1, y_2 = x / scale_x, y / scale_y + elif label_index_pos == 7: + y_1, x_2 = y / scale_y, x / scale_x + elif label_index_pos == 8: + x_1, y_1 = x / scale_x - (abs(temp_label[3]) / 2), y / scale_y - (abs(temp_label[4]) / 2) + x_2, y_2 = x / scale_x + (abs(temp_label[3]) / 2), y / scale_y + (abs(temp_label[4]) / 2) + # 把移动后的点信息保存下来添加到标签中,以此形成动态绘制一个框的效果 + temp_label[0], temp_label[1], temp_label[2], temp_label[3], temp_label[4] = str( + round((int(temp_label[0])), dot)), \ + str(round(((x_1 + x_2) * 0.5), dot)), str(round(((y_1 + y_2) * 0.5), dot)), str( + round((abs(x_1 - x_2)), dot)), str(round((abs(y_1 - y_2)), dot)) + temp_label = [str(i) for i in temp_label] + str_temp = ' '.join(temp_label) + '\n' + label.append(str_temp) + label_index = len(label) - 1 + elif Mouse_move and key_points_move: + label.pop(label_index) + key_x[label_index_pos] = round(x / scale_x, dot) + key_y[label_index_pos] = round(y / scale_y, dot) + key_box[0] = int(key_box[0]) + str_temp = ' '.join([str(j) for j in key_box]) + for index, kx in enumerate(key_x): + str_temp += ' ' + str(kx) + ' ' + str(key_y[index]) + ' ' + str(int(key_v[index])) + label.append(str_temp) + label_index = len(label) - 1 + + if Mouse_move and event == cv2.EVENT_LBUTTONUP: + flag_init() + + # 这里是为了删除框 + if key_point_is is None and Mouse_insert is None and position is not None and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None: + Mouse_insert = label_index + + if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None and key_points_move and box_move is None: + key_insert = label_index_pos + + if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_insert is None and key_insert is None and position is None: + move_key_point = (x, y) + + # 这里是为了增加新的框 + if key_point_is is None and Mouse_insert is None and position is None and Mouse_move is None and event == cv2.EVENT_LBUTTONDOWN and end_draw_rectangle is None: + draw_rectangle = [(x, y), (x, y)] + + # 如果鼠标左键一直没有松开,则不断更新第二个点的位置 + elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_MOUSEMOVE and end_draw_rectangle is None: + draw_rectangle[1] = (x, y) + + # 鼠标松开了,最后记录松开时鼠标的位置,现在则记录了开始和松开鼠标的两个位置 + # 如果两个位置太近,则不添加 + elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_LBUTTONUP: + if end_draw_rectangle is None: + draw_rectangle[1] = (x, y) + if not distance(draw_rectangle[0], draw_rectangle[1]): + end_draw_rectangle = True + else: + draw_rectangle = None + + +def create_file_key(img_path, label_path): + empty_la = None + if not os.path.exists(label_path): + with open(label_path, 'w') as f: + pass + empty_la = True + with open(label_path) as f: + label_ = f.readlines() + if len(label_) == 0 or label_[0] == '\n': + empty_la = True + img_s = img_read(img_path, 950) # 950调整图像的大小 + if key_point_is and empty_la: + box_create = '0 0.5 0.5 0.3 0.3 ' + len_t = img_s.shape[1] // key_point_num + key_num_x = [str(round((i * len_t + 20) / img_s.shape[1], dot)) + ' ' + str(0.5) + ' ' + '2' for i in + range(key_point_num)] + with open(label_path, 'w') as f: + f.write(box_create + ' '.join(key_num_x)) + + +def main(img_path, label_path): + global img, position, label, Mouse_insert, draw_rectangle, end_draw_rectangle, append_str_temp, empty_label, \ + Mouse_move, dot, box_move, key_insert, key_point_one, key_point_two, key_x, key_y, key_v, \ + move_key_point, append_new_key_point, append_new_key_point_index, window_w, window_h + # 判断本地是否存在文件,或者文件中是否为空或者存在一个换行符,就先把标签删除,添加'0 0 0 0 0\n' + # 如果不预先添加一个处理起来有点麻烦,这里就先加一个,然后后面删掉就行了 + if not os.path.exists(label_path): + empty_label = True + with open(label_path, 'w') as f: + pass + with open(label_path) as f: + label = f.readlines() + if len(label) == 0 or label[0] == '\n': + empty_label = True + # 这里的2是将原图缩小为2分之一 + print(img_path) + img_s = img_read(img_path, 900) + if key_point_is and empty_label: + box_create = '0 0.5 0.5 0.3 0.3 ' + len_t = img_s.shape[1] // key_point_num + key_num_x = [str(round((i * len_t + 20) / img_s.shape[1], dot)) + ' ' + str(0.5) + ' ' + '2' for i in + range(key_point_num)] + with open(label_path, 'w') as f: + f.write(box_create + ' '.join(key_num_x)) + label = box_create + ' '.join(key_num_x) + # 创建回调函数,绑定窗口 + cv2.namedWindow('image', cv2.WINDOW_NORMAL) + _, _, window_w, window_h = cv2.getWindowImageRect('image') + cv2.resizeWindow('image', img_s.shape[1], img_s.shape[0]) + cv2.setMouseCallback('image', mouse_event) + # 刷新图像的地方 + while True: + # 首先读取下标签,用来初始化显示 + with open(label_path, 'w') as f: + for i in label: + f.write(i) + # 如果鼠标选中了框的8个点之一,就在鼠标周围绘制空心圈 + if Mouse_insert is None and draw_rectangle is None and position is not None and key_insert is None: + img = img_s.copy() + label_show(img, label_path, Mouse_insert) + cv2.circle(img, position, 10, (0, 255, 100), 2) + # 如果选择开始增加新的框,则不断绘制鼠标起始点和移动过程之间形成的框 + elif draw_rectangle is not None and end_draw_rectangle is None: + img = img_s.copy() + label_show(img, label_path, Mouse_insert) + cv2.rectangle(img, draw_rectangle[0], draw_rectangle[1], color=box_color[1], thickness=2) + # 当松开鼠标后,记录两点位置,并提示选择类别 + elif draw_rectangle is not None and end_draw_rectangle: + scale_y, scale_x, _ = img.shape + x1, y1 = draw_rectangle[0] + x2, y2 = draw_rectangle[1] + w1, h1 = abs(x2 - x1), abs(y2 - y1) + append_str_temp = str(round((x1 + x2) / 2 / scale_x, dot)) + ' ' + str( + round((y1 + y2) / 2 / scale_y, dot)) + ' ' + \ + str(round((w1 / scale_x), dot)) + ' ' + str(round((h1 / scale_y), dot)) + '\n' + cv2.putText(img, 'choose your classify', (0, img.shape[0] // 2 - 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, (255, 0, 255), 2) + cv2.putText(img, ' '.join([str(i) + ':' + my_cls[i] for i in my_cls]), (0, img.shape[0] // 2 + 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, (100, 255, 255), 2) + elif key_insert is not None: + position, Mouse_move, box_move = None, None, None # 禁用其他操作 + cv2.putText(img, 'Switching visibility: 0 1 2', (0, img.shape[0] // 2 - 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1, (100, 255, 255), 2, + lineType=cv2.LINE_AA) + elif move_key_point is not None: + position, Mouse_move, box_move = None, None, None # 禁用其他操作 + cv2.putText(img, 'choose point: 0 - {}'.format(key_point_num - 1), (0, img.shape[0] // 2 - 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1, (100, 255, 255), 2, + lineType=cv2.LINE_AA) + # 如果什么标志都没有,就正常显示一个图 + else: + img = img_s.copy() + if Mouse_insert is not None: + position, Mouse_move = None, None + cv2.putText(img, 'delete: W, exit: E', (0, img.shape[0] // 2 - 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1, (100, 255, 255), 2, + lineType=cv2.LINE_AA) + label_show(img, label_path, Mouse_insert) + cv2.imshow('image', img) + + # key用来获取键盘输入 + key = cv2.waitKey(10) + # 输入为Q则退出 + if key == ord('Q'): + append_new_key_point = None + # 退出按键 + break + if move_key_point is not None and key_point_one is None and 48 <= key <= 57: + key_point_one = int(chr(key)) + key = 0 + if move_key_point is not None and key_point_two is None and 48 <= key <= 57: + key_point_two = int(chr(key)) + key = 0 + if (move_key_point is not None) and (key_point_one is not None) and (key_point_two is not None): + with open(la_path) as f: + label = f.readlines() + for i, la in enumerate(label): + la = la.strip('\n').split(' ') + key_points_ = list(map(float, la))[5:] + key_box_ = list(map(float, la))[0:5] + key_x_ = [float(i) for i in key_points_[::3]] + key_y_ = [float(i) for i in key_points_[1::3]] + key_v_ = [float(i) for i in key_points_[2::3]] # 能见度 + key_box_[0] = int(key_box_[0]) + index_ = key_point_one * 10 + key_point_two + if index_ >= key_point_num: + break + key_x_[index_] = round(move_key_point[0] / img.shape[1], dot) + key_y_[index_] = round(move_key_point[1] / img.shape[0], dot) + str_temp = ' '.join([str(j) for j in key_box_]) + for index, kx in enumerate(key_x_): + str_temp += ' ' + str(kx) + ' ' + str(key_y_[index]) + ' ' + str(int(key_v_[index])) + label = str_temp + with open(la_path, 'w') as f: + f.write(str_temp) + move_key_point, key_point_one, key_point_two = None, None, None + break + move_key_point, key_point_one, key_point_two = None, None, None + # 如果按键输入为W则删除选中的框 + if Mouse_insert is not None and key == ord('W'): + label.pop(Mouse_insert) + Mouse_insert = None + elif key_insert is not None and key == ord('0'): + with open(label_path, 'r') as f: + label_temp = f.read() + str_temp = label_temp.split(' ') + str_temp[3 * int(key_insert) + 7] = '0' + str_temp[3 * int(key_insert) + 7 - 1] = '0' + str_temp[3 * int(key_insert) + 7 - 2] = '0' + with open(label_path, 'w') as f: + f.write(' '.join(str_temp)) + label = ' '.join(str_temp) + key_insert = None + elif key_insert is not None and key == ord('1'): + with open(label_path, 'r') as f: + label_temp = f.read() + str_temp = label_temp.split(' ') + str_temp[3 * int(key_insert) + 7] = '1' + with open(label_path, 'w') as f: + f.write(' '.join(str_temp)) + label = ' '.join(str_temp) + key_insert = None + elif key_insert is not None and key == ord('2'): + with open(label_path, 'r') as f: + label_temp = f.read() + str_temp = label_temp.split(' ') + str_temp[3 * int(key_insert) + 7] = '2' + with open(label_path, 'w') as f: + f.write(' '.join(str_temp)) + label = ' '.join(str_temp) + key_insert = None + # 如果输入为E则从选中框的状态退出 + elif key == ord('E'): + Mouse_insert = None + # 通过键盘获取输入的类别 + elif Mouse_move is None and Mouse_insert is None and draw_rectangle is not None and end_draw_rectangle is not None \ + and (48 <= key <= 57 or key == ord('Z') or key == ord('X') or key == ord('C') or key == ord('V') or key==ord('B') or key==ord('N')): + if 48 <= key <= 57: + str_temp = str(chr(key)) + ' ' + append_str_temp + elif key == ord('Z'): + str_temp = str(10) + ' ' + append_str_temp + elif key == ord('X'): + str_temp = str(11) + ' ' + append_str_temp + elif key == ord('C'): + str_temp = str(12) + ' ' + append_str_temp + elif key == ord('V'): + str_temp = str(13) + ' ' + append_str_temp + elif key == ord('B'): + str_temp = str(14) + ' ' + append_str_temp + elif key == ord('N'): + str_temp = str(15) + ' ' + append_str_temp + label.append(str_temp) + append_str_temp, draw_rectangle, end_draw_rectangle, empty_label = None, None, None, None + elif key == ord('R'): + flag_init() + append_new_key_point = True + break + elif key == ord('T'): + exit(0) + elif key == ord('Y'): + os.remove(img_path) + os.remove(label_path) + break + + +def delete_line_feed(label_path): + # 去掉最后一行的换行符'\n',保存的时候需要 + if os.path.exists(label_path): + with open(label_path) as f: + label_ = f.read() + label_ = label_.rstrip('\n') + with open(label_path, 'w') as f: + f.write(label_) + + +def append__line_feed(label_path): + # 加上最后一行的换行符'\n',标注的时候增加新的框的时候需要 + with open(label_path) as f: + label_ = f.read() + if len(label_) < 4: + with open(label_path, 'w') as f: + pass + return + label_ = label_.rstrip('\n') + '\n' + with open(label_path, 'w') as f: + f.write(label_) + + +def key_check(label_path): + # 检查开启关键点之后本地标签是否满足要求, 如果本地标签中和预设关键点数不等以及关键点数量不是3的倍数都会将原有标签重置 + if os.path.exists(label_path): + with open(label_path) as f: + label_ = f.readlines() + for label_ in label_: + label_ = label_.strip('\n').split(' ') + if ((len(label_) - 5) % 3) or ((len(label_) - 5) // 3 - key_point_num): + with open(label_path, 'w') as f: + pass + + +def label_check(label_path): + # 检查普通标签,判断每行是包含5个数值 + if os.path.exists(label_path): + with open(label_path) as f: + label_ = f.readlines() + for i in label_: + i = i.strip('\n').split(' ') + if len(i) - 5 != 0: + with open(label_path, 'w'): + pass + + +def merge_file_key(la_path, index): + with open(la_path) as f: + text = f.read().strip('\n') + for i in range(index): + with open(la_path.split('.')[0] + str(i) + '.txt') as f: + text += '\n' + f.read().strip('\n') + os.remove(la_path.split('.')[0] + str(i) + '.txt') + with open(la_path, 'w') as f: + f.write(text) + + +if __name__ == '__main__': + image_ = os.listdir(image_path) + for im in image_: + flag_init() + im_path = os.path.join(image_path, im) + la_path = os.path.join(label_path, im.split('.')[0] + '.txt') + if key_point_is: + key_check(la_path) # 检查本地标签的关键点数量是否和预设的关键点数量相等,以及去除框的5点后点数是否满足为3的倍数 + create_file_key(im_path, la_path) + else: + delete_line_feed(la_path) + label_check(la_path) + if os.path.exists(la_path): + # 先增加一个换行符为了后面的增加框的操作 + append__line_feed(la_path) + while True: + main(im_path, la_path) + if append_new_key_point is None: + break + else: + la_path = os.path.join(label_path, im.split('.')[0] + str(append_new_key_point_index) + '.txt') + with open(la_path, 'w') as f: + pass + if key_point_is: + key_check(la_path) # 检查本地标签的关键点数量是否和预设的关键点数量相等,以及去除框的5点后点数是否满足为3的倍数 + create_file_key(im_path, la_path) + else: + delete_line_feed(la_path) + label_check(la_path) + append_new_key_point_index += 1 + if append_new_key_point_index != 0: + merge_file_key(os.path.join(label_path, im.split('.')[0] + '.txt'), append_new_key_point_index) + append_new_key_point_index = 0 + if os.path.exists(la_path): + # 去掉最后一行的换行符 + delete_line_feed(la_path) diff --git a/cut.py b/cut.py new file mode 100644 index 0000000..c1cd88a --- /dev/null +++ b/cut.py @@ -0,0 +1,40 @@ +from PIL import Image + + +def crop_image_into_nine_parts(image_path): + # 打开图片 + image = Image.open(image_path) + + # 获取图片的宽度和高度 + width, height = image.size + + # 计算裁剪后小图片的宽度和高度 + part_width = width // 3 + part_height = height // 3 + + parts = [] + + # 循环裁剪九等分的小图片 + for i in range(3): + for j in range(3): + left = j * part_width + top = i * part_height + right = (j + 1) * part_width + bottom = (i + 1) * part_height + + # 裁剪小图片 + part = image.crop((left, top, right, bottom)) + parts.append(part) + + return parts + + +# 图片路径 +image_path = "scripts/input/images/725.jpg" + +# 调用函数裁剪图片 +nine_parts = crop_image_into_nine_parts(image_path) + +# 保存裁剪后的小图片 +for idx, part in enumerate(nine_parts): + part.save(f"part_{idx + 1}.jpg") diff --git a/detect_c.py b/detect_c.py new file mode 100644 index 0000000..816197d --- /dev/null +++ b/detect_c.py @@ -0,0 +1,91 @@ +import time +from pathlib import Path +import cv2 +import torch +from PIL import ImageGrab +import numpy as np +import os +import glob +from models.common import DetectMultiBackend +from utils.general import non_max_suppression, scale_boxes +from utils.plots import Annotator, colors +from utils.augmentations import letterbox +from utils.torch_utils import select_device + +FILE = Path(__file__).resolve() +folder_path = R'C:\Users\lengdan\Desktop\yolov5-master\data\images' # 本地文件夹 +camera = 2 # 0调用本地相机, 1检测文件夹中的图像, 2检测屏幕上内容 + + +class mydataload: + def __init__(self): + self.count = 0 + if camera == 0: + self.cap = cv2.VideoCapture(0) + + def __iter__(self): + return self + + def __next__(self): + if camera == 0: + _, im0 = self.cap.read() + elif camera == 1: + file_list = glob.glob(os.path.join(folder_path, '*.jpg')) + # 对文件列表按时间戳排序,以确保最新添加的图像排在最后面 + file_list.sort(key=os.path.getmtime) + im0 = cv2.imread(file_list[self.count]) + self.count = self.count if self.count == len(file_list) - 1 else self.count + 1 + else: # camera 检测屏幕上的内容 + # 指定截图区域的左上角和右下角坐标 + x1, y1 = 1000, 100 # 左上角 + x2, y2 = 1900, 1000 # 右下角 + # 截取屏幕区域 + img = ImageGrab.grab(bbox=(x1, y1, x2, y2)) + im0 = np.array(img) + im = letterbox(im0, 640, auto=True)[0] # padded resize + im = im.transpose((2, 0, 1)) # HWC to CHW, BGR to RGB + if camera != 2: + im = im[::-1] + im = np.ascontiguousarray(im) # contiguous + return im, im0 + + +def get_image(model, im, im0s, conf_thres=0.5, iou_thres=0.5, line_thickness=3): + # temp_list = [] + pred = model(im, visualize=False) + pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000) + for i, det in enumerate(pred): + im0, names = im0s.copy(), model.names + annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + if len(det): + det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + for *xyxy, conf, cls in reversed(det): + c = int(cls) + # temp_list.append((xyxy[0].item(), xyxy[1].item(), xyxy[2].item(), xyxy[3].item())) + label = f'{names[c]} {conf:.2f}' + annotator.box_label(xyxy, label, color=colors(c, True)) + return annotator.result() + + +if __name__ == "__main__": + device = select_device('0') + model = DetectMultiBackend('7_29_last.pt', device=device, dnn=False, data='', fp16=False) + dataset = mydataload() + model.warmup(imgsz=(1, 3, 640, 640)) # warmup + for im, im0s in dataset: + t0 = time.time() + im = torch.from_numpy(im).to(model.device) + im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + if len(im.shape) == 3: + im = im[None] # expand for batch dim + im0 = get_image(model, im, im0s) + + if camera == 2: + im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB) + cv2.namedWindow('1', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + cv2.resizeWindow('1', im0.shape[1] // 2, im0.shape[0] // 2) + cv2.imshow('1', im0) + if cv2.waitKey(1) == ord('Q'): + exit(0) + print(time.time() - t0) \ No newline at end of file diff --git a/display.py b/display.py new file mode 100644 index 0000000..829d458 --- /dev/null +++ b/display.py @@ -0,0 +1,169 @@ +import sys +import os +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1333, 657) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_init = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41)) + self.pushButton_init.setObjectName("pushButton_init") + self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41)) + self.pushButton_openimg.setObjectName("pushButton_openimg") + self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41)) + self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg") + self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41)) + self.pushButton_exit.setObjectName("pushButton_exit") + self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41)) + self.pushButton_Transparency.setObjectName("pushButton_Transparency") + self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41)) + self.pushButton_copymask.setObjectName("pushButton_copymask") + self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41)) + self.pushButton_saveimg.setObjectName("pushButton_saveimg") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22)) + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setObjectName("horizontalSlider") + self.horizontalSlider.setValue(50) + self.label_Originalimg = QtWidgets.QLabel(self.centralwidget) + self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581)) + self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Originalimg.setObjectName("label_Originalimg") + self.label_Maskimg = QtWidgets.QLabel(self.centralwidget) + self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581)) + self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Maskimg.setObjectName("label_Maskimg") + self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41)) + self.pushButton_shang.setObjectName("pushButton_shang") + self.pushButton_xia = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_xia.setGeometry(QtCore.QRect(10, 210, 141, 41)) + self.pushButton_xia.setObjectName("pushButton_openimg_3") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_init.setText(_translate("MainWindow", "重置选择")) + self.pushButton_openimg.setText(_translate("MainWindow", "打开图片")) + self.pushButton_shang.setText(_translate("MainWindow", "上一张")) + self.pushButton_xia.setText(_translate("MainWindow", "下一张")) + self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片")) + self.pushButton_exit.setText(_translate("MainWindow", "退出")) + self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度")) + self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码")) + self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片")) + self.label_Originalimg.setText( + _translate("MainWindow", "

原始图像

")) + self.label_Maskimg.setText( + _translate("MainWindow", "

掩码图像

")) + + + def init_slots(self): + self.pushButton_openimg.clicked.connect(self.button_image_open) + self.pushButton_init.clicked.connect(self.button_image_init) + self.pushButton_shang.clicked.connect(self.button_image_shang) + self.pushButton_xia.clicked.connect(self.button_image_xia) + self.pushButton_Fusionimg.clicked.connect(self.button_image_Fusionimg) + self.pushButton_copymask.clicked.connect(self.button_image_copymask) + self.pushButton_saveimg.clicked.connect(self.button_image_saveimg) + self.pushButton_exit.clicked.connect(self.button_image_exit) + self.horizontalSlider.valueChanged.connect(self.slider_value_changed) + self.pushButton_openimg.clicked.connect(self.button_image_open) + # 连接label_Originalimg点击事件 + self.label_Originalimg.mousePressEvent = self.label_Originalimg_click_event + + def label_Originalimg_click_event(self, event): + # 捕获label_Originalimg中的鼠标点击事件 + point = event.pos() + x, y = point.x(), point.y() + print("Clicked at:", x, y) # 您可以使用这些坐标进行预测 + def button_image_open(self): + choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?", + QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel) + if choice == QtWidgets.QMessageBox.Open: + folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "") + if folder_path: + image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) + if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))] + if image_files: + self.image_files = image_files + self.current_index = 0 + self.display_image() + elif choice == QtWidgets.QMessageBox.Cancel: + selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "", + "Image files (*.png *.jpg *.jpeg *.bmp)") + if selected_image: + self.image_files = [selected_image] + self.current_index = 0 + self.display_image() + + def display_image(self): + if hasattr(self, 'image_files') and self.image_files: + pixmap = QtGui.QPixmap(self.image_files[self.current_index]) + self.label_Originalimg.setPixmap(pixmap) + self.label_Originalimg.setScaledContents(True) + + def button_image_shang(self): + if hasattr(self, 'image_files') and self.image_files: + self.current_index = (self.current_index - 1) % len(self.image_files) + self.display_image() + + def button_image_xia(self): + if hasattr(self, 'image_files') and self.image_files: + self.current_index = (self.current_index + 1) % len(self.image_files) + self.display_image() + + def button_image_exit(self): + sys.exit() + + def slider_value_changed(self, value): + print("Slider value changed:", value) + + def button_image_saveimg(self): + pass + + def button_image_Fusionimg(self): + pass + + def button_image_copymask(self): + pass + + def button_image_init(self): + pass + + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + MainWindow = QtWidgets.QMainWindow() + ui = Ui_MainWindow() + ui.setupUi(MainWindow) + ui.init_slots() # 调用init_slots以连接信号和槽 + MainWindow.show() + sys.exit(app.exec_()) + diff --git a/linter.sh b/linter.sh new file mode 100644 index 0000000..df2e174 --- /dev/null +++ b/linter.sh @@ -0,0 +1,32 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +{ + black --version | grep -E "23\." > /dev/null +} || { + echo "Linter requires 'black==23.*' !" + exit 1 +} + +ISORT_VERSION=$(isort --version-number) +if [[ "$ISORT_VERSION" != 5.12* ]]; then + echo "Linter requires isort==5.12.0 !" + exit 1 +fi + +echo "Running isort ..." +isort . --atomic + +echo "Running black ..." +black -l 100 . + +echo "Running flake8 ..." +if [ -x "$(command -v flake8)" ]; then + flake8 . +else + python3 -m flake8 . +fi + +echo "Running mypy..." + +mypy --exclude 'setup.py|notebooks' . diff --git a/mask6.ui b/mask6.ui new file mode 100644 index 0000000..a806444 --- /dev/null +++ b/mask6.ui @@ -0,0 +1,113 @@ + + + MainWindow + + + + 0 + 0 + 889 + 600 + + + + MainWindow + + + + + + 10 + 120 + 421 + 331 + + + + background-color: rgb(255, 255, 255); + + + 原始图像 + + + Qt::AlignCenter + + + + + + 440 + 120 + 421 + 331 + + + + background-color: rgb(255, 255, 255); + + + 预测图像 + + + Qt::AlignCenter + + + + + + 150 + 470 + 131 + 51 + + + + 打开图像 + + + + + + 570 + 470 + 131 + 51 + + + + 预测图像 + + + + + + 320 + 20 + 221 + 41 + + + + <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> +<html><head><meta name="qrichtext" content="1" /><style type="text/css"> +p, li { white-space: pre-wrap; } +</style></head><body style=" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;"> +<p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:16pt;">分割图像GUI界面</span></p></body></html> + + + + + + + 0 + 0 + 889 + 26 + + + + + + + + diff --git a/maskui.py b/maskui.py new file mode 100644 index 0000000..3cb14b8 --- /dev/null +++ b/maskui.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'mask6.ui' +# +# Created by: PyQt5 UI code generator 5.15.9 +# +# WARNING: Any manual changes made to this file will be lost when pyuic5 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt5 import QtCore, QtGui, QtWidgets + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(889, 600) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.label = QtWidgets.QLabel(self.centralwidget) + self.label.setGeometry(QtCore.QRect(10, 120, 421, 331)) + self.label.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label.setAlignment(QtCore.Qt.AlignCenter) + self.label.setObjectName("label") + self.label_2 = QtWidgets.QLabel(self.centralwidget) + self.label_2.setGeometry(QtCore.QRect(440, 120, 421, 331)) + self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_2.setAlignment(QtCore.Qt.AlignCenter) + self.label_2.setObjectName("label_2") + self.pushButton = QtWidgets.QPushButton(self.centralwidget) + self.pushButton.setGeometry(QtCore.QRect(150, 470, 131, 51)) + self.pushButton.setObjectName("pushButton") + self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_2.setGeometry(QtCore.QRect(570, 470, 131, 51)) + self.pushButton_2.setObjectName("pushButton_2") + self.textEdit = QtWidgets.QTextEdit(self.centralwidget) + self.textEdit.setGeometry(QtCore.QRect(320, 20, 221, 41)) + self.textEdit.setObjectName("textEdit") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 889, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.label.setText(_translate("MainWindow", "原始图像")) + self.label_2.setText(_translate("MainWindow", "预测图像")) + self.pushButton.setText(_translate("MainWindow", "打开图像")) + self.pushButton_2.setText(_translate("MainWindow", "预测图像")) + self.textEdit.setHtml(_translate("MainWindow", "\n" +"\n" +"

分割图像GUI界面

")) + + +if __name__ == "__main__": + import sys + app = QtWidgets.QApplication(sys.argv) + MainWindow = QtWidgets.QMainWindow() + ui = Ui_MainWindow() + ui.setupUi(MainWindow) + MainWindow.show() + sys.exit(app.exec_()) diff --git a/modeltest.py b/modeltest.py new file mode 100644 index 0000000..a8ae979 --- /dev/null +++ b/modeltest.py @@ -0,0 +1,426 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +# 这个代码定义了 SAM 的图像编码器 ImageEncoderViT。它包含以下主要部分: +# 1. patch_embed: 这是 ViT 的 patch embedding 层,用于将输入图像划分为 patch,并获得 patch 的 embedding。 +# 2. pos_embed: 这是 ViT的绝对位置 embedding,用于为每个patch提供位置信息。 +# 3. blocks: 这是 ViT 的 transformer encoder 块的列表,每个块包含多头自注意力层和前馈神经网络。 +# 4. neck: 这是图像编码器的“颈部”,包含几个卷积层和 LayerNorm 层,用于从 transformer encoder 块的输出中提取特征。 +# 5. forward(): 这是图像编码器的前向传播过程。首先通过 patch_embed 层获得 patch embedding, 然后加上 pos_embed。 +# 接着,patch embedding通过transformer encoder块。最后, neck 层从 transformer encoder 块的输出中提取特征。 +# 所以,这个 ImageEncoderViT 类定义了 SAM 的图像编码器,它基于 ViT,包含 patch embedding、位置 embedding、 +# transformer encoder块以及 neck, 可以从输入图像中提取特征。 +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x +# 这个 Block 类实现了 transformer block, 可以选择使用全局注意力或局部窗口注意力,同时包含残差连接。它包含: +# __init__方法: +# 1. 输入参数: +# - dim: 输入通道数 +# - num_heads: 注意力头数 +# - mlp_ratio: mlp 隐藏层与输入 embedding 维度的比例 +# - qkv_bias: 是否为 query、key、value 添加偏置 +# - norm_layer: 归一化层 +# - act_layer: 激活层 +# - use_rel_pos: 是否使用相对位置 embedding +# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为 0 +# - window_size: 窗口注意力的窗口大小,如果为 0 则使用全局注意力 +# - input_size: 计算相对位置 embedding 大小所需的输入分辨率 +# 2. 实例化第 1 次和第 2 次归一化层 norm1 和 norm2。 +# 3. 实例化 Attention 层和 MLPBlock 层。Attention 层的输入大小根据是否使用窗口注意力进行了调整。 +# 4. 记录窗口注意力的窗口大小 window_size。 +# forward方法: +# 1. 提取 shortcut 并对 x 进行第 1 次归一化。 +# 2. 如果使用窗口注意力, 则调用 window_partition 对 x 进行窗口划分。 +# 3. 将 x 输入 Attention 层。 +# 4. 如果使用窗口注意力,则调用 window_unpartition 对 x 进行窗口反划分。 +# 5. x = shortcut + x,实现第 1 次残差连接。 +# 6. x = x + mlp(norm2(x)),实现第 2 次残差连接和 MLPBlock。 +# 7. 返回最终的 x。 +# 所以,这个 Block 类实现了带有可选的窗口注意力和双残差连接的transformer block。 +# 窗口注意力可以更好地建模局部结构,双残差连接可以提高梯度流动,都是transformer结构的重要改进。 +# 这个 Block 类实现了 transformer 的关键组成部分,同时提供了窗口注意力和残差连接等重要变体,可以显著提高其表现力和泛化能力。 + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + +# 这个Attention类实现了多头注意力机制,可以加入相对位置 embedding。它包含: +# __init__方法: +# 1. 输入参数: +# - dim: 输入通道数 +# - num_heads: 注意力头数 +# - qkv_bias: 是否为查询、键、值添加偏置 +# - use_rel_pos: 是否使用相对位置 embedding +# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为0 +# - input_size: 计算相对位置 embedding 大小所需的输入分辨率 +# 2. 计算每个注意力头的维度 head_dim。 +# 3. 实例化 self.qkv和 输出投影 self.proj。 +# 4. 如果使用相对位置 embedding, 则初始化 rel_pos_h 和 rel_pos_w。 +# forward方法: +# 1. 从输入 x 中提取批次大小 B、高度 H、宽度 W 和通道数 C。 +# 2. 计算 qkv,形状为 (3, B, nHead, H * W, C), 包含 query、key 和 value。 +# 3. 提取 q、 k 和 v, 形状为 (B * nHead, H * W, C)。 +# 4. 计算注意力图 attn,形状为 (B * nHead, H * W, H * W)。 +# 5. 如果使用相对位置 embedding, 则调用 add_decomposed_rel_pos 函数将其加入 attn。 +# 6. 对 attn 进行 softmax 归一化。 +# 7. 计算输出 x , (attn @ v), 形状为 (B, nHead, H, W, C), 然后合并注意力头, 形状为(B, H, W, C)。 +# 8. 对 x 进行投影, 返回最终的输出。 +# 所以,这个 Attention 类实现了带有相对位置 embedding 的多头注意力机制。 +# 它可以高效地建模图像和视频等二维结构数据,是 transformer 在这些领域得到广泛应用的关键。 +# 这个 Attention 类提供了相对位置 embedding 和多头注意力机制的实现, +# 是理解 transformer 在图像和视频建模中的重要组成部分。 +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x +# 这个 window_partition 函数的作用是将输入张量划分为非重叠的窗口。它包含: +# 1. 输入参数: + # - x: 输入的张量,形状为 [B, H, W, C] + # - window_size: 窗口大小 +# 2. 首先计算输入需要 padding 的高度和宽度,将x进行padding。 +# 3. 然后将 x 的形状变化为 [B, Hp//window_size, window_size, Wp//window_size, window_size, C], +# 表示将图像划分为 Hp//window_size * Wp//window_size 个 window_size * window_size 的 patch。 +# 4. 最后,通过 permute 和 view 操作,得到 windows 的形状为 [B * num_windows, window_size, window_size, C], +# 表示将所有 patch 打平, num_windows 是 patch 的总数 +# 5. 返回windows和原来的高度和宽度(包含padding)Hp和Wp。 +# 所以,这个 window_partition 函数的作用是,将输入的图像划分为 window_size * window_size 的 patch, +# 并将所有的 patch 打平, 输出可以输入到 transformer encoder 中的 token 序列。 +# 这个函数实现了将二维图像转化为一维 token 序列的过程,是 transformer 用于处理图像的一个关键步骤。 +# 通过这个函数,图像可以被 transformer encoder 所处理,就像处理文本序列一样。 + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + +# 这个 window_unpartition 函数的作用是将 window_partition 函数的输出进行反划分, 恢复成原始的图像形状。它包含: +# 1. 输入参数: + # - windows: window_partition的输出,形状为 [B * num_windows, window_size, window_size, C] + # - window_size: 窗口大小 + # - pad_hw: padding后的高度和宽度 (Hp, Wp) + # - hw: padding前的原始高度和宽度 (H, W) +# 2. 首先根据窗口大小和 padding 后的 hw 计算原始的 batch_size B。 +# 3. 然后将 windows 的形状变回 [B, Hp//window_size, Wp//window_size, window_size, window_size, C], 表示每个patch的位置。 +# 4. 接着通过permute和view操作,得到x的形状为 [B, Hp, Wp, C], 恢复成图像的形状。 +# 5. 最后,如果进行了padding,则截取x到原始的高度H和宽度W。 +# 6. 返回恢复后的图像x。 +# 所以,这个 window_unpartition 函数的作用是将通过 window_partition 函数得到的 patch 序列恢复成原始的图像。 +# 它实现了从一维 patch token 序列到二维图像的反过程。 +# 这个函数与 window_partition 函数相反,使得 transformer 能够最终从 patch token 序列恢复成图像,完成对图像的建模。 +# 总的来说,这个 window_unpartition 函数实现了从 patch token 序列恢复成原始图像的过程,与 window_partition 函数相对应, +# 是使得 transformer 可以处理图像的另一个关键步骤 +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + +# 这个 get_rel_pos 函数的作用是根据 query 和 key 的相对位置获取相对位置 embedding。它包含: +# 1. 输入参数: +# - q_size: query 的大小 +# - k_size: key 的大小 +# - rel_pos: 相对位置 embedding, 形状为[L, C] +# 2. 首先计算最大的相对距离 max_rel_dist, 它等于 query 和 key 大小的 2 倍减 1。 +# 3. 如果相对位置 embedding 的长度小于 max_rel_dist, 则通过线性插值将其调整到 max_rel_dist 的长度。 +# 4. 如果 q_size 和 k_size 不同, 则将 q_size 和 k_size 的坐标按比例缩放,使它们之间的相对距离保持不变。 +# 5. 根据调整后的 q_size 和 k_size 坐标计算相对坐标 relative_coords。 +# 6. 根据 relative_coords 从 rel_pos_resized 中提取相对位置 embedding。 +# 7. 返回提取出的相对位置 embedding。 +# 所以,这个 get_rel_pos 函数的主要作用是,当 query 和 key 的大小不同时,根据它们的相对位置关系提取相应的相对位置 embedding。 +# 它实现了相对位置 embedding 的可变长度和可缩放性。 +# 这个函数使得相对位置 embedding 可以用于 query 和 key 大小不同的 attention 中,是相对位置表示的一个关键步骤。 +# 总的来说,这个 get_rel_pos 函数实现了根据 query 和 key 的相对位置关系提取相应相对位置 embedding 的过程。 +# 它提供了相对位置 embedding 的可变长度和可缩放性,使其可以支持不同的 query 和 key 大小,从而应用到更加灵活的 attention 机制中。 +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + +# 这个 add_decomposed_rel_pos 函数的作用是根据 query q 和 key k 的空间尺寸, 添加分解的相对位置 embedding 到注意力图 attn 中。它包含: +# 1. 输入参数: +# - attn: 注意力图,形状为 [B, q_h * q_w, k_h * k_w] +# - q: 查询 q,形状为 [B, q_h * q_w, C] +# - rel_pos_h: 高度轴的相对位置 embedding, 形状为[Lh, C] +# - rel_pos_w: 宽度轴的相对位置 embedding, 形状为[Lw, C] +# - q_size: 查询 q的空间尺寸 (q_h, q_w) +# - k_size: 键 k的空间尺寸 (k_h, k_w) +# 2. 从 q_size 和 k_size 中提取高度 q_h、宽度 q_w 以及高度 k_h、宽度 k_w。 +# 3. 调用 get_rel_pos 函数获取高度轴 Rh 和宽度轴 Rw 的相对位置 embedding。 +# 4. 重塑 q 为 [B, q_h, q_w, C]。 +# 5. 计算高度轴 rel_h 和宽度轴 rel_w 的相对位置图, 形状为 [B, q_h, q_w, k_h] 和 [B, q_h, q_w, k_w]。 +# 6. 将 attn 的形状变为 [B, q_h, q_w, k_h, k_w], 并加上 rel_h 和 rel_w。 +# 7. 将 attn 的形状变回 [B, q_h * q_w, k_h * k_w]。 +# 8. 返回加了相对位置 embedding 的 attn。 +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + +# 这个 PatchEmbed 类定义了 ViT 的 patch embedding 层。它包含: +# 1. __init__: 初始化,设置卷积层的 kernel size、stride、padding以 及输入通道数和 embedding 维度。 +# 2. proj: 这是一个卷积层,用于将输入图像划分为 patch, 并获得每个 patch 的 embedding。 +# 3. forward: 前向传播过程。首先通过 proj 卷积层获得 patch embedding ,然后将维度从 [B, C, H, W] 转置成 [B, H, W, C]。 + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/predict_mask.py b/predict_mask.py new file mode 100644 index 0000000..d359183 --- /dev/null +++ b/predict_mask.py @@ -0,0 +1,197 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = 'scripts/input/images' +output_dir = 'scripts/output/mask' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + + +def mouse_click(event, x, y, flags, param): + global input_point, input_label, input_stop + if not input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + +def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return image + + +def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + +def save_masked_image(image, mask, output_dir, filename, crop_mode_): + if crop_mode_: + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + else: + print("Could not save the image. Too many variations exist.") + + +current_index = 0 +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False + +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + + while True: + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + while (1): + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image_orign.copy() + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 | q 移除最后一个 | s 保存' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, + cv2.LINE_AA) + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break diff --git a/salt/GUI.py b/salt/GUI.py new file mode 100644 index 0000000..9638687 --- /dev/null +++ b/salt/GUI.py @@ -0,0 +1,200 @@ +import sys +from PyQt5 import QtCore, QtGui, QtWidgets +from PyQt5.QtGui import QPixmap, QImage +from PyQt5.QtCore import QTimer +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images' +output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt' +crop_mode = True + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1170, 486) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_q = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51)) + self.pushButton_q.setObjectName("pushButton_q") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_pre = QtWidgets.QLabel(self.centralwidget) + self.label_pre.setGeometry(QtCore.QRect(660, 20, 471, 401)) + self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_pre.setObjectName("label_pre") + self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_opimg.setObjectName("pushButton_opimg") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + self.pushButton_opimg.clicked.connect(self.open_image) + self.pushButton_w.clicked.connect(self.predict_image) + + self.timer = QTimer() + self.timer.timeout.connect(self.update_original_image) + self.timer.start(100) # Update every 100 milliseconds + + self.image_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + self.current_index = 0 + self.input_point = [] + self.input_label = [] + self.input_stop = False + + cv2.namedWindow("image") + cv2.setMouseCallback("image", self.mouse_click) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "w")) + self.pushButton_a.setText(_translate("MainWindow", "a")) + self.pushButton_d.setText(_translate("MainWindow", "d")) + self.pushButton_s.setText(_translate("MainWindow", "s")) + self.pushButton_q.setText(_translate("MainWindow", "q")) + self.label_orign.setText(_translate("MainWindow", "

原始图像

")) + self.label_pre.setText(_translate("MainWindow", "

预测图像

")) + self.pushButton_opimg.setText(_translate("MainWindow", "打开图像")) + + def open_image(self): + filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, "Open Image File", "", "Image files (*.jpg *.png)") + if filename: + self.current_index = 0 + pixmap = QPixmap(filename) + pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio) + self.label_orign.setPixmap(pixmap) + self.label_orign.setAlignment(QtCore.Qt.AlignCenter) + self.input_point = [] + self.input_label = [] + + def predict_image(self): + if self.current_index < len(self.image_files): + filename = self.image_files[self.current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_display = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + + for point, label in zip(self.input_point, self.input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + + cv2.imshow("image", image_display) + + def update_original_image(self): + if self.current_index < len(self.image_files): + filename = self.image_files[self.current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_display = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + + for point, label in zip(self.input_point, self.input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + + height, width, channel = image_display.shape + bytesPerLine = 3 * width + qImg = QImage(image_display.data, width, height, bytesPerLine, QImage.Format_RGB888) + pixmap = QPixmap.fromImage(qImg) + pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio) + self.label_orign.setPixmap(pixmap) + self.label_orign.setAlignment(QtCore.Qt.AlignCenter) + + def mouse_click(self, event, x, y, flags, param): + if not self.input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + self.input_point.append([x, y]) + self.input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + self.input_point.append([x, y]) + self.input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + +def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return image + + +def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + +def save_masked_image(image, mask, output_dir, filename, crop_mode_): + if crop_mode_: + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + else: + print("Could not save the image. Too many variations exist.") + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + MainWindow = QtWidgets.QMainWindow() + ui = Ui_MainWindow() + ui.setupUi(MainWindow) + MainWindow.show() + sys.exit(app.exec_()) + diff --git a/salt/SAM_JSON_多类别.py b/salt/SAM_JSON_多类别.py new file mode 100644 index 0000000..ce31298 --- /dev/null +++ b/salt/SAM_JSON_多类别.py @@ -0,0 +1,236 @@ +import cv2 +import os +import numpy as np +import json +from segment_anything import sam_model_registry, SamPredictor + +input_dir = r'C:\Users\t2581\Desktop\222\images' +output_dir = r'C:\Users\t2581\Desktop\222\2' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +# 定义类别 +categories = {1: "category1", 2: "category2", 3: "category3"} +category_colors = {1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 0, 0)} +current_label = 1 # 默认类别 + +def mouse_click(event, x, y, flags, param): + global input_points, input_labels, input_stop + if not input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_points.append([x, y]) + input_labels.append(current_label) + elif event == cv2.EVENT_RBUTTONDOWN: + input_points.append([x, y]) + input_labels.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + +def apply_color_mask(image, mask, color, color_dark=0.5): + masked_image = image.copy() + for c in range(3): + masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return masked_image + +def draw_external_rectangle(image, mask, pv): + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle + cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2) + +def save_masked_image_and_json(image, masks, output_dir, filename, crop_mode_, pv): + masked_image = image.copy() + json_shapes = [] + + for mask, label, score in masks: + color = category_colors[int(label[-1])] # 获取类别对应的颜色 + masked_image = apply_color_mask(masked_image, mask, color) + draw_external_rectangle(masked_image, mask, f"{label}: {score:.2f}") + + # Convert mask to polygons + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + polygons = [contour.reshape(-1, 2).tolist() for contour in contours] + + # Append JSON shapes + json_shapes.extend([ + { + "label": label, + "points": polygon, + "group_id": None, + "shape_type": "polygon", + "flags": {} + } for polygon in polygons + ]) + + masked_filename = filename[:filename.rfind('.')] + '_masked.png' + cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image) + print(f"Saved image as {masked_filename}") + + # Create JSON data + json_data = { + "version": "5.1.1", + "flags": {}, + "shapes": json_shapes, + "imagePath": filename, + "imageData": None, + "imageHeight": image.shape[0], + "imageWidth": image.shape[1] + } + + # Save JSON file + json_filename = filename[:filename.rfind('.')] + '_masked.json' + with open(os.path.join(output_dir, json_filename), 'w') as json_file: + json.dump(json_data, json_file, indent=4) + print(f"Saved JSON as {json_filename}") + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_points = [] +input_labels = [] +input_stop = False +masks = [] +all_masks = [] # 用于保存所有类别的标注 + +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} | Current label: {categories[current_label]}' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_points, input_labels): + color = (0, 255, 0) if label > 0 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_points = [] + input_labels = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_points) > 0 and len(input_labels) > 0: + try: + predictor.set_image(image) + input_point_np = np.array(input_points) + input_label_np = np.array(input_labels) + + masks_pred, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks_pred) # masks的数量 + while True: + color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色 + image_select = image_orign.copy() + selected_mask = masks_pred[mask_idx] # 选择msks也就是,a,d切换 + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} ' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, cv2.LINE_AA) + # todo 显示在当前的图片, + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_points) > 0: + input_points.pop(-1) + input_labels.pop(-1) + elif key == ord('s'): + masks.append((selected_mask, categories[current_label], scores[mask_idx])) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_points = [] + input_labels = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + except Exception as e: + print(f"Error during prediction: {e}") + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_points = [] + input_labels = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_points = [] + input_labels = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_points) > 0: + input_points.pop(-1) + input_labels.pop(-1) + elif key == ord('r'): + if masks: + all_masks.extend(masks) # 保存当前的masks到all_masks + masks = [] # 清空当前的masks + input_points = [] + input_labels = [] + selected_mask = None + logit_input = None + elif key == ord('s'): + if masks: + all_masks.extend(masks) # 保存当前的masks到all_masks + if all_masks: + save_masked_image_and_json(image_crop, all_masks, output_dir, filename, crop_mode_=crop_mode, pv="") + all_masks = [] # 清空所有保存的masks + masks = [] # 清空当前的masks + + elif key in [ord(str(i)) for i in categories.keys()]: + current_label = int(chr(key)) + print(f"Switched to label: {categories[current_label]}") + + if key == 27: + break + +cv2.destroyAllWindows() + diff --git a/salt/__pycache__/segment.cpython-39.pyc b/salt/__pycache__/segment.cpython-39.pyc new file mode 100644 index 0000000..691ee93 Binary files /dev/null and b/salt/__pycache__/segment.cpython-39.pyc differ diff --git a/salt/banben1.py b/salt/banben1.py new file mode 100644 index 0000000..e376c0a --- /dev/null +++ b/salt/banben1.py @@ -0,0 +1,202 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images' +output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def mouse_click(event, x, y, flags, param): + global input_point, input_label, input_stop + if not input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + +def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return image + + +def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + +def save_masked_image(image, mask, output_dir, filename, crop_mode_): + if crop_mode_: + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + else: + print("Could not save the image. Too many variations exist.") + + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + prediction_window_name = "Prediction" + cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + + while True: + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image_orign.copy() + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个 | s 保存' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + + cv2.imshow(prediction_window_name, selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + input_stop = False # Allow adding points again + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + +cv2.destroyAllWindows() # Close all windows before exiting + diff --git a/salt/banben2.py b/salt/banben2.py new file mode 100644 index 0000000..bc96e4e --- /dev/null +++ b/salt/banben2.py @@ -0,0 +1,300 @@ +import cv2 +import os +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1170, 486) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_q = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51)) + self.pushButton_q.setObjectName("pushButton_q") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 450, 450)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_pre = QtWidgets.QLabel(self.centralwidget) + self.label_pre.setGeometry(QtCore.QRect(660, 20, 450, 450)) + self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_pre.setObjectName("label_pre") + self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_opimg.setObjectName("pushButton_opimg") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "w")) + self.pushButton_a.setText(_translate("MainWindow", "a")) + self.pushButton_d.setText(_translate("MainWindow", "d")) + self.pushButton_s.setText(_translate("MainWindow", "s")) + self.pushButton_q.setText(_translate("MainWindow", "q")) + self.label_orign.setText( + _translate("MainWindow", "

Original Image

")) + self.label_pre.setText( + _translate("MainWindow", "

Predicted Image

")) + self.pushButton_opimg.setText(_translate("MainWindow", "Open Image")) + + +class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.pushButton_opimg.clicked.connect(self.open_image) + self.pushButton_w.clicked.connect(self.predict_and_interact) + + self.image_files = [] + self.current_index = 0 + self.input_point = [] + self.input_label = [] + self.input_stop = False + self.interaction_count = 0 # 记录交互次数 + self.sam = sam_model_registry["vit_b"]( + checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + _ = self.sam.to(device="cuda") + self.predictor = SamPredictor(self.sam) + + # Calculate coordinate scaling factors + self.scale_x = 1.0 + self.scale_y = 1.0 + self.label_pre_width = self.label_pre.width() + self.label_pre_height = self.label_pre.height() + + # Set mouse click event for original image label + self.set_mouse_click_event() + + def open_image(self): + options = QtWidgets.QFileDialog.Options() + filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "", + "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)", + options=options) + if filename: + image = cv2.imread(filename) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + self.image_files.append(image) + self.display_original_image() + + def display_original_image(self): + if self.image_files: + image = self.image_files[self.current_index] + height, width, channel = image.shape + bytesPerLine = 3 * width + qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888) + pixmap = QtGui.QPixmap.fromImage(qImg) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + # Add mouse click event + self.label_orign.mousePressEvent = self.mouse_click + + # Draw marked points on the original image + painter = QtGui.QPainter(self.label_orign.pixmap()) # Use label_orign for drawing points + pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points + pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points + pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + painter.setPen(pen_foreground) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 1: # Foreground point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.setPen(pen_background) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 0: # Background point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.end() + + # Calculate coordinate scaling factors + self.scale_x = width / self.label_orign.width() + self.scale_y = height / self.label_orign.height() + + def convert_to_label_coords(self, point): + x = point[0] / self.scale_x + y = point[1] / self.scale_y + return x, y + + def mouse_click(self, event): + if not self.input_stop: + x = int(event.pos().x() * self.scale_x) + y = int(event.pos().y() * self.scale_y) + if event.button() == QtCore.Qt.LeftButton: # If left-clicked, mark as foreground + self.input_label.append(1) # Foreground label is 1 + elif event.button() == QtCore.Qt.RightButton: # If right-clicked, mark as background + self.input_label.append(0) # Background label is 0 + + self.input_point.append([x, y]) + + # Update the original image with marked points + self.display_original_image() + + def predict_and_interact(self): + if not self.image_files: + return + + image = self.image_files[self.current_index].copy() + filename = f"image_{self.current_index}.png" + image_crop = image.copy() + + while True: # Outer loop for prediction + # Prediction logic + if not self.input_stop: # If not in interaction mode + if len(self.input_point) > 0 and len(self.input_label) > 0: + self.predictor.set_image(image) + input_point_np = np.array(self.input_point) + input_label_np = np.array(self.input_label) + + masks, scores, logits = self.predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + while True: # Inner loop for interaction + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image.copy() + selected_mask = masks[mask_idx] + selected_image = self.apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, cv2.LINE_AA) + + # Display the predicted result in label_pre area + self.display_prediction_image(selected_image) + + key = cv2.waitKey(10) + + # Handle key press events + if key == ord('q') and len(self.input_point) > 0: + self.input_point.pop(-1) + self.input_label.pop(-1) + self.display_original_image() + elif key == ord('s'): + self.save_masked_image(image_crop, selected_mask, filename) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord(" "): + break + + if cv2.getWindowProperty("Prediction", cv2.WND_PROP_VISIBLE) < 1: + break + + # If 'w' is pressed, toggle interaction mode + if key == ord('w'): + self.input_stop = not self.input_stop # Toggle interaction mode + if not self.input_stop: # If entering interaction mode + self.interaction_count += 1 + if self.interaction_count % 2 == 0: # If even number of interactions, call the interaction function + self.input_point = [] # Reset input points for the next interaction + self.input_label = [] # Reset input labels for the next interaction + self.display_original_image() # Display original image + self.set_mouse_click_event() # Set mouse click event + break # Exit outer loop + else: + continue # Continue prediction + + # Exit the outer loop if not in interaction mode + if not self.input_stop: + break + + def set_mouse_click_event(self): + self.label_orign.mousePressEvent = self.mouse_click + + def display_prediction_image(self, image): + height, width, channel = image.shape + bytesPerLine = 3 * width + qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888) + pixmap = QtGui.QPixmap.fromImage(qImg) + self.label_pre.setPixmap(pixmap.scaled(self.label_pre.size(), QtCore.Qt.KeepAspectRatio)) + + # Draw marked points on the predicted image + painter = QtGui.QPainter(self.label_pre.pixmap()) + pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points + pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points + pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + painter.setPen(pen_foreground) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 1: # Foreground point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.setPen(pen_background) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 0: # Background point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.end() + + def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return image + + def save_masked_image(self, image, mask, filename): + output_dir = os.path.dirname(filename) + filename = os.path.basename(filename) + filename = filename[:filename.rfind('.')] + '_masked.png' + new_filename = os.path.join(output_dir, filename) + + masked_image = self.apply_color_mask(image, mask) + cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)) + print(f"Saved as {new_filename}") + + def previous_image(self): + if self.current_index > 0: + self.current_index -= 1 + self.display_original_image() + + def next_image(self): + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + self.display_original_image() + +if __name__ == "__main__": + import sys + app = QtWidgets.QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) + + diff --git a/salt/banben3.py b/salt/banben3.py new file mode 100644 index 0000000..532fd8a --- /dev/null +++ b/salt/banben3.py @@ -0,0 +1,378 @@ +import os +import sys +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1140, 450) + MainWindow.setMinimumSize(QtCore.QSize(1140, 450)) + MainWindow.setMaximumSize(QtCore.QSize(1140, 450)) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_5.setObjectName("pushButton_5") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_2 = QtWidgets.QLabel(self.centralwidget) + self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401)) + self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_2.setObjectName("label_2") + self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_w_2.setObjectName("pushButton_w_2") + self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) + self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21)) + self.lineEdit.setObjectName("lineEdit") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22)) + self.horizontalSlider.setSliderPosition(50) + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setTickInterval(0) + self.horizontalSlider.setObjectName("horizontalSlider") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "Predict")) + self.pushButton_a.setText(_translate("MainWindow", "Pre")) + self.pushButton_d.setText(_translate("MainWindow", "Next")) + self.pushButton_s.setText(_translate("MainWindow", "Save")) + self.pushButton_5.setText(_translate("MainWindow", "背景图")) + self.label_orign.setText(_translate("MainWindow", "

原始图像

")) + self.label_2.setText(_translate("MainWindow", "

预测图像

")) + self.pushButton_w_2.setText(_translate("MainWindow", "Openimg")) + self.lineEdit.setText(_translate("MainWindow", "改变mask大小")) + +class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.image_path = "" + self.image_folder = "" + self.image_files = [] + self.current_index = 0 + self.input_stop = False # 在这里初始化 input_stop + + self.pushButton_w_2.clicked.connect(self.open_image_folder) + self.pushButton_a.clicked.connect(self.load_previous_image) + self.pushButton_d.clicked.connect(self.load_next_image) + + + def open_image_folder(self): + folder_dialog = QtWidgets.QFileDialog() + folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '') + if folder_path: + self.image_folder = folder_path + self.image_files = self.get_image_files(self.image_folder) + if self.image_files: + self.show_image_selection_dialog() + + def load_previous_image(self): + if self.image_files: + if self.current_index > 0: + self.current_index -= 1 + else: + self.current_index = len(self.image_files) - 1 + self.show_image() + + def load_next_image(self): + if self.image_files: + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + else: + self.current_index = 0 + self.show_image() + + def get_image_files(self, folder_path): + image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))] + return image_files + + def show_image_selection_dialog(self): + dialog = QtWidgets.QDialog(self) + dialog.setWindowTitle("Select Image") + layout = QtWidgets.QVBoxLayout() + + self.listWidget = QtWidgets.QListWidget() + for image_file in self.image_files: + item = QtWidgets.QListWidgetItem(image_file) + pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100) + item.setIcon(QtGui.QIcon(pixmap)) + self.listWidget.addItem(item) + self.listWidget.itemDoubleClicked.connect(self.image_selected) + layout.addWidget(self.listWidget) + + buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttonBox.accepted.connect(self.image_selected) + buttonBox.rejected.connect(dialog.reject) + layout.addWidget(buttonBox) + + dialog.setLayout(layout) + + dialog.exec_() + + def image_selected(self): + selected_item = self.listWidget.currentItem() + if selected_item: + selected_index = self.listWidget.currentRow() + if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内 + self.current_index = selected_index + self.show_image() # 显示所选图像 + # 调用OpenCV窗口显示 + self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index])) + + def show_image(self): + if self.image_files and self.current_index < len(self.image_files): + file_path = os.path.join(self.image_folder, self.image_files[self.current_index]) + pixmap = QtGui.QPixmap(file_path) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + def call_opencv_interaction(self, image_path): + input_dir = os.path.dirname(image_path) + image_orign = cv2.imread(image_path) + output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt' + crop_mode = True + + print('最好是每加一个点就按w键predict一次') + os.makedirs(output_dir, exist_ok=True) + image_files = [self.image_files[self.current_index]] + + sam = sam_model_registry["vit_b"]( + checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + _ = sam.to(device="cuda") + predictor = SamPredictor(sam) + + WINDOW_WIDTH = 1280 + WINDOW_HEIGHT = 720 + + def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return image + + def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + def save_masked_image(image, mask, output_dir, filename, crop_mode_): + # 保存图像到指定路径 + if crop_mode_: + # 如果采用了裁剪模式,则裁剪图像 + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, + [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + + # 读取保存的图像文件 + saved_image_path = os.path.join(output_dir, new_filename) + saved_image_pixmap = QtGui.QPixmap(saved_image_path) + + # 将保存的图像显示在预测图像区域 + mainWindow.label_2.setPixmap( + saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio)) + else: + print("Could not save the image. Too many variations exist.") + + current_index = 0 + + cv2.namedWindow("image", cv2.WINDOW_NORMAL) + cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + + def mouse_click(event, x, y, flags, param): + if not self.input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + cv2.setMouseCallback("image", mouse_click) + input_point = [] + input_label = [] + input_stop = False + while True: + filename = self.image_files[self.current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, + cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + prediction_window_name = "Prediction" + cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, + (1080 - WINDOW_HEIGHT) // 2) + + while True: + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image_orign.copy() + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, + (0, 255, 255), 2, cv2.LINE_AA) + + cv2.imshow(prediction_window_name, selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, + crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + input_stop = False # Allow adding points again + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + + cv2.destroyAllWindows() # Close all windows before exiting + if key == 27: + break + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + mainWindow = MyMainWindow() + mainWindow.show() + sys.exit(app.exec_()) + + diff --git a/salt/banben4.py b/salt/banben4.py new file mode 100644 index 0000000..1783a05 --- /dev/null +++ b/salt/banben4.py @@ -0,0 +1,453 @@ +import os +import sys +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1140, 450) + MainWindow.setMinimumSize(QtCore.QSize(1140, 450)) + MainWindow.setMaximumSize(QtCore.QSize(1140, 450)) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_5.setObjectName("pushButton_5") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_2 = QtWidgets.QLabel(self.centralwidget) + self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401)) + self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_2.setObjectName("label_2") + self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_w_2.setObjectName("pushButton_w_2") + self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) + self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21)) + self.lineEdit.setObjectName("lineEdit") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22)) + self.horizontalSlider.setRange(0, 10) # 将范围设置为从0到最大值 + self.horizontalSlider.setSingleStep(1) + self.horizontalSlider.setValue(0) # 初始值设为0 + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setTickInterval(0) + self.horizontalSlider.setObjectName("horizontalSlider") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "获取json文件")) + self.pushButton_a.setText(_translate("MainWindow", "Pre")) + self.pushButton_d.setText(_translate("MainWindow", "Next")) + self.pushButton_s.setText(_translate("MainWindow", "Save")) + self.pushButton_5.setText(_translate("MainWindow", "背景图")) + self.label_orign.setText(_translate("MainWindow", "

原始图像

")) + self.label_2.setText(_translate("MainWindow", "

预测图像

")) + self.pushButton_w_2.setText(_translate("MainWindow", "Openimg")) + self.lineEdit.setText(_translate("MainWindow", "改变mask大小")) + +class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.k = 0 + self.last_value = 0 # 保存上一次滑块值 + + self.image_path = "" + self.image_folder = "" + self.image_files = [] + self.current_index = 0 + self.input_stop = False # 在这里初始化 input_stop + + self.pushButton_w_2.clicked.connect(self.open_image_folder) + self.pushButton_a.clicked.connect(self.load_previous_image) + self.pushButton_d.clicked.connect(self.load_next_image) + self.pushButton_s.clicked.connect(self.save_prediction) + self.pushButton_5.clicked.connect(self.select_background_image) + self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size) # 连接水平滑块的值改变信号 + + def adjust_pixmap_size(self, pixmap, scale_factor): + scaled_size = QtCore.QSize(pixmap.size().width() * scale_factor / 100, + pixmap.size().height() * scale_factor / 100) + return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio) + + + def open_image_folder(self): + folder_dialog = QtWidgets.QFileDialog() + folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '') + if folder_path: + self.image_folder = folder_path + self.image_files = self.get_image_files(self.image_folder) + if self.image_files: + self.show_image_selection_dialog() + + def load_previous_image(self): + if self.image_files: + if self.current_index > 0: + self.current_index -= 1 + else: + self.current_index = len(self.image_files) - 1 + self.show_image() + + def load_next_image(self): + if self.image_files: + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + else: + self.current_index = 0 + self.show_image() + + def get_image_files(self, folder_path): + image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))] + return image_files + + def show_image_selection_dialog(self): + dialog = QtWidgets.QDialog(self) + dialog.setWindowTitle("Select Image") + layout = QtWidgets.QVBoxLayout() + + self.listWidget = QtWidgets.QListWidget() + for image_file in self.image_files: + item = QtWidgets.QListWidgetItem(image_file) + pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100) + item.setIcon(QtGui.QIcon(pixmap)) + self.listWidget.addItem(item) + self.listWidget.itemDoubleClicked.connect(self.image_selected) + layout.addWidget(self.listWidget) + + buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttonBox.accepted.connect(self.image_selected) + buttonBox.rejected.connect(dialog.reject) + layout.addWidget(buttonBox) + + dialog.setLayout(layout) + + dialog.exec_() + + def image_selected(self): + selected_item = self.listWidget.currentItem() + if selected_item: + selected_index = self.listWidget.currentRow() + if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内 + self.current_index = selected_index + self.show_image() # 显示所选图像 + # 调用OpenCV窗口显示 + self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index])) + + def select_background_image(self): + file_dialog = QtWidgets.QFileDialog() + image_path, _ = file_dialog.getOpenFileName(self, 'Select Background Image', '', + 'Image Files (*.png *.jpg *.jpeg *.bmp)') + if image_path: + self.show_background_image(image_path) + + def show_background_image(self, image_path): + pixmap = QtGui.QPixmap(image_path) + current_pixmap = self.label_2.pixmap() + if current_pixmap: + current_pixmap = QtGui.QPixmap(current_pixmap) + scene = QtWidgets.QGraphicsScene() + scene.addPixmap(pixmap) + scene.addPixmap(current_pixmap) + merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize()) + merged_pixmap.fill(QtCore.Qt.transparent) + painter = QtGui.QPainter(merged_pixmap) + scene.render(painter) + painter.end() + self.label_2.setPixmap(merged_pixmap) + else: + self.label_2.setPixmap(pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio)) + + def show_image(self): + if self.image_files and self.current_index < len(self.image_files): + file_path = os.path.join(self.image_folder, self.image_files[self.current_index]) + pixmap = QtGui.QPixmap(file_path) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + def call_opencv_interaction(self, image_path): + input_dir = os.path.dirname(image_path) + image_orign = cv2.imread(image_path) + output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt' + crop_mode = True + + print('最好是每加一个点就按w键predict一次') + os.makedirs(output_dir, exist_ok=True) + image_files = [self.image_files[self.current_index]] + + sam = sam_model_registry["vit_b"]( + checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + _ = sam.to(device="cuda") + predictor = SamPredictor(sam) + + WINDOW_WIDTH = 1280 + WINDOW_HEIGHT = 720 + + def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return image + + def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + def save_masked_image(image, mask, output_dir, filename, crop_mode_): + # 保存图像到指定路径 + if crop_mode_: + # 如果采用了裁剪模式,则裁剪图像 + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, + [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + + # 读取保存的图像文件 + saved_image_path = os.path.join(output_dir, new_filename) + saved_image_pixmap = QtGui.QPixmap(saved_image_path) + + # 将保存的图像显示在预测图像区域 + mainWindow.label_2.setPixmap( + saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio)) + else: + print("Could not save the image. Too many variations exist.") + + current_index = 0 + + cv2.namedWindow("image", cv2.WINDOW_NORMAL) + cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + + def mouse_click(event, x, y, flags, param): + if not self.input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + cv2.setMouseCallback("image", mouse_click) + input_point = [] + input_label = [] + input_stop = False + while True: + filename = self.image_files[self.current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, + cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + prediction_window_name = "Prediction" + cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, + (1080 - WINDOW_HEIGHT) // 2) + + while True: + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image_orign.copy() + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, + (0, 255, 255), 2, cv2.LINE_AA) + + cv2.imshow(prediction_window_name, selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, + crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + input_stop = False # Allow adding points again + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + + cv2.destroyAllWindows() # Close all windows before exiting + if key == 27: + break + + def save_prediction(self): + if self.label_2.pixmap(): # 检查预测图像区域是否有图像 + # 保存预测结果的部分,这里假设你已经有了保存预测结果的代码,我用 placeholer 代替 + # placeholder: 这里假设 save_prediction_result 是一个保存预测结果的函数,它接受预测结果的图像数据以及保存路径作为参数 + # 这里假设预测结果图像数据为 prediction_image,保存路径为 save_path + prediction_image = self.label_2.pixmap().toImage() + save_path = "prediction_result.png" + prediction_image.save(save_path) + + # 调用 adjust_prediction_size 方法来根据 horizontalSlider 的值调整预测区域的大小 + self.adjust_prediction_size(self.horizontalSlider.value()) + + def adjust_prediction_size(self, value): + if self.image_files and self.current_index < len(self.image_files): + # 获取预测图像区域的原始大小 + pixmap = self.label_2.pixmap() + if pixmap.isNull(): + return + + original_size = pixmap.size() + + # 判断是缩小还是还原图像 + if value < self.last_value: + # 缩小掩码 + scale_factor = 1.0 + (self.last_value - value) * 0.1 + else: + # 放大掩码 + scale_factor = 1.0 - (value - self.last_value) * 0.1 + + self.last_value = value # 更新上一次的滑块值 + + # 根据缩放比例调整预测图像区域的大小,并保持纵横比例 + scaled_size = QtCore.QSize(original_size.width() * scale_factor, original_size.height() * scale_factor) + scaled_pixmap = pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio) + + # 更新预测图像区域的大小并显示 + self.label_2.setPixmap(scaled_pixmap) + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + mainWindow = MyMainWindow() + mainWindow.show() + sys.exit(app.exec_()) + diff --git a/salt/display1.py b/salt/display1.py new file mode 100644 index 0000000..76f904f --- /dev/null +++ b/salt/display1.py @@ -0,0 +1,146 @@ +import sys +import os +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + +class ImageLabel(QtWidgets.QLabel): + clicked = QtCore.pyqtSignal(QtCore.QPoint) + + def __init__(self, *args, **kwargs): + super(ImageLabel, self).__init__(*args, **kwargs) + self.foreground_points = [] + self.background_points = [] + + def mousePressEvent(self, event): + self.clicked.emit(event.pos()) + + if event.button() == QtCore.Qt.LeftButton: + self.foreground_points.append(event.pos()) + elif event.button() == QtCore.Qt.RightButton: + self.background_points.append(event.pos()) + + self.update() + + def paintEvent(self, event): + super().paintEvent(event) + + painter = QtGui.QPainter(self) + painter.setPen(QtGui.QPen(QtGui.QColor("green"), 5)) + + for point in self.foreground_points: + painter.drawPoint(point) + + painter.setPen(QtGui.QPen(QtGui.QColor("red"), 5)) + + for point in self.background_points: + painter.drawPoint(point) + + +class MainWindow(QtWidgets.QMainWindow): + def __init__(self, predictor): + super().__init__() + self.predictor = predictor + self.current_points = [] + self.current_image = None + self.setupUi() + + def setupUi(self): + self.setObjectName("MainWindow") + self.resize(1333, 657) + self.centralwidget = QtWidgets.QWidget(self) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_init = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41)) + self.pushButton_init.setObjectName("pushButton_init") + self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41)) + self.pushButton_openimg.setObjectName("pushButton_openimg") + self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41)) + self.pushButton_save_mask.setObjectName("pushButton_save_mask") + self.label_Originalimg = ImageLabel(self.centralwidget) + self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581)) + self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Originalimg.setObjectName("label_Originalimg") + self.label_Maskimg = QtWidgets.QLabel(self.centralwidget) + self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581)) + self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Maskimg.setObjectName("label_Maskimg") + self.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(self) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26)) + self.menubar.setObjectName("menubar") + self.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(self) + self.statusbar.setObjectName("statusbar") + self.setStatusBar(self.statusbar) + self.retranslateUi() + QtCore.QMetaObject.connectSlotsByName(self) + + def retranslateUi(self): + _translate = QtCore.QCoreApplication.translate + self.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_init.setText(_translate("MainWindow", "重置选择")) + self.pushButton_openimg.setText(_translate("MainWindow", "打开图片")) + self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码")) + + self.pushButton_openimg.clicked.connect(self.button_image_open) + self.pushButton_save_mask.clicked.connect(self.button_save_mask) + self.label_Originalimg.clicked.connect(self.mouse_click) + + def button_image_open(self): + choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?", + QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel) + if choice == QtWidgets.QMessageBox.Open: + folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "") + if folder_path: + image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) + if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))] + if image_files: + self.image_files = image_files + self.current_index = 0 + self.display_image() + elif choice == QtWidgets.QMessageBox.Cancel: + selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "", + "Image files (*.png *.jpg *.jpeg *.bmp)") + if selected_image: + self.image_files = [selected_image] + self.current_index = 0 + self.display_image() + + def display_image(self): + if hasattr(self, 'image_files') and self.image_files: + pixmap = QtGui.QPixmap(self.image_files[self.current_index]) + self.label_Originalimg.setPixmap(pixmap) + self.label_Originalimg.setScaledContents(True) + + def mouse_click(self, pos): + x, y = pos.x(), pos.y() + print("Mouse clicked at position:", x, y) + self.current_points.append([x, y]) + + def button_save_mask(self): + if self.current_image is not None and len(self.current_points) > 0: + masks, _, _ = self.predictor.predict(point_coords=np.array(self.current_points), + point_labels=np.ones(len(self.current_points), dtype=np.uint8), + multimask_output=True) + if masks: + mask_image = masks[0] + mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB) + q_image = QtGui.QImage(mask_image.data, mask_image.shape[1], mask_image.shape[0], mask_image.strides[0], + QtGui.QImage.Format_RGB888) + pixmap = QtGui.QPixmap.fromImage(q_image) + self.label_Maskimg.setPixmap(pixmap) + self.label_Maskimg.setScaledContents(True) + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + sam = sam.to(device="cuda") + predictor = SamPredictor(sam) + window = MainWindow(predictor) + window.show() + sys.exit(app.exec_()) diff --git a/salt/interface.py b/salt/interface.py new file mode 100644 index 0000000..662c8ff --- /dev/null +++ b/salt/interface.py @@ -0,0 +1,144 @@ +import sys +import os +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1333, 657) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_init = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41)) + self.pushButton_init.setObjectName("pushButton_init") + self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41)) + self.pushButton_openimg.setObjectName("pushButton_openimg") + self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41)) + self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg") + self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41)) + self.pushButton_exit.setObjectName("pushButton_exit") + self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41)) + self.pushButton_Transparency.setObjectName("pushButton_Transparency") + self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41)) + self.pushButton_copymask.setObjectName("pushButton_copymask") + self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41)) + self.pushButton_saveimg.setObjectName("pushButton_saveimg") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22)) + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setObjectName("horizontalSlider") + self.horizontalSlider.setValue(50) + self.label_Originalimg = QtWidgets.QLabel(self.centralwidget) + self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581)) + self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Originalimg.setObjectName("label_Originalimg") + self.label_Maskimg = QtWidgets.QLabel(self.centralwidget) + self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581)) + self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Maskimg.setObjectName("label_Maskimg") + self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41)) + self.pushButton_shang.setObjectName("pushButton_shang") + self.pushButton_xia = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_xia.setGeometry(QtCore.QRect(10, 210, 141, 41)) + self.pushButton_xia.setObjectName("pushButton_openimg_3") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_init.setText(_translate("MainWindow", "重置选择")) + self.pushButton_openimg.setText(_translate("MainWindow", "打开图片")) + self.pushButton_shang.setText(_translate("MainWindow", "上一张")) + self.pushButton_xia.setText(_translate("MainWindow", "下一张")) + self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片")) + self.pushButton_exit.setText(_translate("MainWindow", "退出")) + self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度")) + self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码")) + self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片")) + self.label_Originalimg.setText( + _translate("MainWindow", "

原始图像

")) + self.label_Maskimg.setText( + _translate("MainWindow", "

掩码图像

")) + + +class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super(MyMainWindow, self).__init__() + self.setupUi(self) + self.init_slots() + self.image_files = [] + self.current_index = 0 + + def init_slots(self): + self.pushButton_openimg.clicked.connect(self.button_image_open) + self.pushButton_shang.clicked.connect(self.button_image_shang) + self.pushButton_xia.clicked.connect(self.button_image_xia) + self.pushButton_exit.clicked.connect(self.button_image_exit) + + def button_image_open(self): + choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?", + QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel) + if choice == QtWidgets.QMessageBox.Open: + folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "") + if folder_path: + self.image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) + if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))] + if self.image_files: + self.current_index = 0 + self.display_image() + elif choice == QtWidgets.QMessageBox.Cancel: + selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "", + "Image files (*.png *.jpg *.jpeg *.bmp)") + if selected_image: + self.image_files = [selected_image] + self.current_index = 0 + self.display_image() + + def display_image(self): + if self.image_files: + pixmap = QtGui.QPixmap(self.image_files[self.current_index]) + self.label_Originalimg.setPixmap(pixmap) + self.label_Originalimg.setScaledContents(True) + + def button_image_shang(self): + if self.image_files: + self.current_index = (self.current_index - 1) % len(self.image_files) + self.display_image() + + def button_image_xia(self): + if self.image_files: + self.current_index = (self.current_index + 1) % len(self.image_files) + self.display_image() + + def button_image_exit(self): + sys.exit() + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + my_main_window = MyMainWindow() + my_main_window.show() + sys.exit(app.exec_()) + diff --git a/salt/prediction_result.png b/salt/prediction_result.png new file mode 100644 index 0000000..9662d41 Binary files /dev/null and b/salt/prediction_result.png differ diff --git a/salt/segment1.py b/salt/segment1.py new file mode 100644 index 0000000..d3f51e1 --- /dev/null +++ b/salt/segment1.py @@ -0,0 +1,299 @@ +import cv2 +import os +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1170, 486) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_q = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51)) + self.pushButton_q.setObjectName("pushButton_q") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 450, 450)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_pre = QtWidgets.QLabel(self.centralwidget) + self.label_pre.setGeometry(QtCore.QRect(660, 20, 450, 450)) + self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_pre.setObjectName("label_pre") + self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_opimg.setObjectName("pushButton_opimg") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "w")) + self.pushButton_a.setText(_translate("MainWindow", "a")) + self.pushButton_d.setText(_translate("MainWindow", "d")) + self.pushButton_s.setText(_translate("MainWindow", "s")) + self.pushButton_q.setText(_translate("MainWindow", "q")) + self.label_orign.setText( + _translate("MainWindow", "

Original Image

")) + self.label_pre.setText( + _translate("MainWindow", "

Predicted Image

")) + self.pushButton_opimg.setText(_translate("MainWindow", "Open Image")) + + +class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.pushButton_opimg.clicked.connect(self.open_image) + self.pushButton_w.clicked.connect(self.predict_and_interact) + + self.image_files = [] + self.current_index = 0 + self.input_point = [] + self.input_label = [] + self.input_stop = False + self.interaction_count = 0 # 记录交互次数 + self.sam = sam_model_registry["vit_b"]( + checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + _ = self.sam.to(device="cuda") + self.predictor = SamPredictor(self.sam) + + # Calculate coordinate scaling factors + self.scale_x = 1.0 + self.scale_y = 1.0 + self.label_pre_width = self.label_pre.width() + self.label_pre_height = self.label_pre.height() + + # Set mouse click event for original image label + self.set_mouse_click_event() + + def open_image(self): + options = QtWidgets.QFileDialog.Options() + filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "", + "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)", + options=options) + if filename: + image = cv2.imread(filename) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + self.image_files.append(image) + self.display_original_image() + + def display_original_image(self): + if self.image_files: + image = self.image_files[self.current_index] + height, width, channel = image.shape + bytesPerLine = 3 * width + qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888) + pixmap = QtGui.QPixmap.fromImage(qImg) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + # Add mouse click event + self.label_orign.mousePressEvent = self.mouse_click + + # Draw marked points on the original image + painter = QtGui.QPainter(self.label_orign.pixmap()) # Use label_orign for drawing points + pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points + pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points + pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + painter.setPen(pen_foreground) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 1: # Foreground point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.setPen(pen_background) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 0: # Background point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.end() + + # Calculate coordinate scaling factors + self.scale_x = width / self.label_orign.width() + self.scale_y = height / self .label_orign.height() + + def convert_to_label_coords(self, point): + x = point[0] / self.scale_x + y = point[1] / self.scale_y + return x, y + + def mouse_click(self, event): + if not self.input_stop: + x = int(event.pos().x() * self.scale_x) + y = int(event.pos().y() * self.scale_y) + if event.button() == QtCore.Qt.LeftButton: # If left-clicked, mark as foreground + self.input_label.append(1) # Foreground label is 1 + elif event.button() == QtCore.Qt.RightButton: # If right-clicked, mark as background + self.input_label.append(0) # Background label is 0 + + self.input_point.append([x, y]) + + # Update the original image with marked points + self.display_original_image() + + def predict_and_interact(self): + if not self.image_files: + return + + image = self.image_files[self.current_index].copy() + filename = f"image_{self.current_index}.png" + image_crop = image.copy() + + while True: # Outer loop for prediction + # Prediction logic + if not self.input_stop: # If not in interaction mode + if len(self.input_point) > 0 and len(self.input_label) > 0: + self.predictor.set_image(image) + input_point_np = np.array(self.input_point) + input_label_np = np.array(self.input_label) + + masks, scores, logits = self.predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + while True: # Inner loop for interaction + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image.copy() + selected_mask = masks[mask_idx] + selected_image = self.apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, cv2.LINE_AA) + + # Display the predicted result in label_pre area + self.display_prediction_image(selected_image) + + key = cv2.waitKey(10) + + # Handle key press events + if key == ord('q') and len(self.input_point) > 0: + self.input_point.pop(-1) + self.input_label.pop(-1) + self.display_original_image() + elif key == ord('s'): + self.save_masked_image(image_crop, selected_mask, filename) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord(" "): + break + + if cv2.getWindowProperty("Prediction", cv2.WND_PROP_VISIBLE) < 1: + break + + # If 'w' is pressed, toggle interaction mode + if key == ord('w'): + self.input_stop = not self.input_stop # Toggle interaction mode + if not self.input_stop: # If entering interaction mode + self.interaction_count += 1 + if self.interaction_count % 2 == 0: # If even number of interactions, call the interaction function + self.input_point = [] # Reset input points for the next interaction + self.input_label = [] # Reset input labels for the next interaction + self.display_original_image() # Display original image + self.set_mouse_click_event() # Set mouse click event + break # Exit outer loop + else: + continue # Continue prediction + + # Exit the outer loop if not in interaction mode + if not self.input_stop: + break + + def set_mouse_click_event(self): + self.label_orign.mousePressEvent = self.mouse_click + + def display_prediction_image(self, image): + height, width, channel = image.shape + bytesPerLine = 3 * width + qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888) + pixmap = QtGui.QPixmap.fromImage(qImg) + self.label_pre.setPixmap(pixmap.scaled(self.label_pre.size(), QtCore.Qt.KeepAspectRatio)) + + # Draw marked points on the predicted image + painter = QtGui.QPainter(self.label_pre.pixmap()) + pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points + pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points + pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed) + painter.setPen(pen_foreground) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 1: # Foreground point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.setPen(pen_background) + for point, label in zip(self.input_point, self.input_label): + x, y = self.convert_to_label_coords(point) + if label == 0: # Background point + painter.drawPoint(QtCore.QPoint(x, y)) + painter.end() + + def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return image + + def save_masked_image(self, image, mask, filename): + output_dir = os.path.dirname(filename) + filename = os.path.basename(filename) + filename = filename[:filename.rfind('.')] + '_masked.png' + new_filename = os.path.join(output_dir, filename) + + masked_image = self.apply_color_mask(image, mask) + cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)) + print(f"Saved as { new_filename}") + + def previous_image(self): + if self.current_index > 0: + self.current_index -= 1 + self.display_original_image() + + def next_image(self): + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + self.display_original_image() + +if __name__ == "__main__": + import sys + app = QtWidgets.QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) + diff --git a/salt/suibian.py b/salt/suibian.py new file mode 100644 index 0000000..a015f05 --- /dev/null +++ b/salt/suibian.py @@ -0,0 +1,453 @@ +import os +import sys +import cv2 +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from segment_anything import sam_model_registry, SamPredictor + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1140, 450) + MainWindow.setMinimumSize(QtCore.QSize(1140, 450)) + MainWindow.setMaximumSize(QtCore.QSize(1140, 450)) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_w = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51)) + self.pushButton_w.setObjectName("pushButton_w") + self.pushButton_a = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51)) + self.pushButton_a.setObjectName("pushButton_a") + self.pushButton_d = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51)) + self.pushButton_d.setObjectName("pushButton_d") + self.pushButton_s = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51)) + self.pushButton_s.setObjectName("pushButton_s") + self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51)) + self.pushButton_5.setObjectName("pushButton_5") + self.label_orign = QtWidgets.QLabel(self.centralwidget) + self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401)) + self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_orign.setObjectName("label_orign") + self.label_2 = QtWidgets.QLabel(self.centralwidget) + self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401)) + self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_2.setObjectName("label_2") + self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51)) + self.pushButton_w_2.setObjectName("pushButton_w_2") + self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) + self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21)) + self.lineEdit.setObjectName("lineEdit") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22)) + self.horizontalSlider.setRange(0, 10) # 将范围设置为从0到最大值 + self.horizontalSlider.setSingleStep(1) + self.horizontalSlider.setValue(0) # 初始值设为0 + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setTickInterval(0) + self.horizontalSlider.setObjectName("horizontalSlider") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_w.setText(_translate("MainWindow", "Predict")) + self.pushButton_a.setText(_translate("MainWindow", "Pre")) + self.pushButton_d.setText(_translate("MainWindow", "Next")) + self.pushButton_s.setText(_translate("MainWindow", "Save")) + self.pushButton_5.setText(_translate("MainWindow", "背景图")) + self.label_orign.setText(_translate("MainWindow", "

原始图像

")) + self.label_2.setText(_translate("MainWindow", "

预测图像

")) + self.pushButton_w_2.setText(_translate("MainWindow", "Openimg")) + self.lineEdit.setText(_translate("MainWindow", "改变mask大小")) + +class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super().__init__() + self.setupUi(self) + + self.k = 0 + self.last_value = 0 # 保存上一次滑块值 + + self.image_path = "" + self.image_folder = "" + self.image_files = [] + self.current_index = 0 + self.input_stop = False # 在这里初始化 input_stop + + self.pushButton_w_2.clicked.connect(self.open_image_folder) + self.pushButton_a.clicked.connect(self.load_previous_image) + self.pushButton_d.clicked.connect(self.load_next_image) + self.pushButton_s.clicked.connect(self.save_prediction) + self.pushButton_5.clicked.connect(self.select_background_image) + self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size) # 连接水平滑块的值改变信号 + + def adjust_pixmap_size(self, pixmap, scale_factor): + scaled_size = QtCore.QSize(pixmap.size().width() * scale_factor / 100, + pixmap.size().height() * scale_factor / 100) + return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio) + + + def open_image_folder(self): + folder_dialog = QtWidgets.QFileDialog() + folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '') + if folder_path: + self.image_folder = folder_path + self.image_files = self.get_image_files(self.image_folder) + if self.image_files: + self.show_image_selection_dialog() + + def load_previous_image(self): + if self.image_files: + if self.current_index > 0: + self.current_index -= 1 + else: + self.current_index = len(self.image_files) - 1 + self.show_image() + + def load_next_image(self): + if self.image_files: + if self.current_index < len(self.image_files) - 1: + self.current_index += 1 + else: + self.current_index = 0 + self.show_image() + + def get_image_files(self, folder_path): + image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))] + return image_files + + def show_image_selection_dialog(self): + dialog = QtWidgets.QDialog(self) + dialog.setWindowTitle("Select Image") + layout = QtWidgets.QVBoxLayout() + + self.listWidget = QtWidgets.QListWidget() + for image_file in self.image_files: + item = QtWidgets.QListWidgetItem(image_file) + pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100) + item.setIcon(QtGui.QIcon(pixmap)) + self.listWidget.addItem(item) + self.listWidget.itemDoubleClicked.connect(self.image_selected) + layout.addWidget(self.listWidget) + + buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttonBox.accepted.connect(self.image_selected) + buttonBox.rejected.connect(dialog.reject) + layout.addWidget(buttonBox) + + dialog.setLayout(layout) + + dialog.exec_() + + def image_selected(self): + selected_item = self.listWidget.currentItem() + if selected_item: + selected_index = self.listWidget.currentRow() + if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内 + self.current_index = selected_index + self.show_image() # 显示所选图像 + # 调用OpenCV窗口显示 + self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index])) + + def select_background_image(self): + file_dialog = QtWidgets.QFileDialog() + image_path, _ = file_dialog.getOpenFileName(self, 'Select Background Image', '', + 'Image Files (*.png *.jpg *.jpeg *.bmp)') + if image_path: + self.show_background_image(image_path) + + def show_background_image(self, image_path): + pixmap = QtGui.QPixmap(image_path) + current_pixmap = self.label_2.pixmap() + if current_pixmap: + current_pixmap = QtGui.QPixmap(current_pixmap) + scene = QtWidgets.QGraphicsScene() + scene.addPixmap(pixmap) + scene.addPixmap(current_pixmap) + merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize()) + merged_pixmap.fill(QtCore.Qt.transparent) + painter = QtGui.QPainter(merged_pixmap) + scene.render(painter) + painter.end() + self.label_2.setPixmap(merged_pixmap) + else: + self.label_2.setPixmap(pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio)) + + def show_image(self): + if self.image_files and self.current_index < len(self.image_files): + file_path = os.path.join(self.image_folder, self.image_files[self.current_index]) + pixmap = QtGui.QPixmap(file_path) + self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)) + + def call_opencv_interaction(self, image_path): + input_dir = os.path.dirname(image_path) + image_orign = cv2.imread(image_path) + output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt' + crop_mode = True + + print('最好是每加一个点就按w键predict一次') + os.makedirs(output_dir, exist_ok=True) + image_files = [self.image_files[self.current_index]] + + sam = sam_model_registry["vit_b"]( + checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") + _ = sam.to(device="cuda") + predictor = SamPredictor(sam) + + WINDOW_WIDTH = 1280 + WINDOW_HEIGHT = 720 + + def apply_mask(image, mask, alpha_channel=True): + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) + alpha[mask == 1] = 255 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + def apply_color_mask(image, mask, color, color_dark=0.5): + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], + image[:, :, c]) + return image + + def get_next_filename(base_path, filename): + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + def save_masked_image(image, mask, output_dir, filename, crop_mode_): + # 保存图像到指定路径 + if crop_mode_: + # 如果采用了裁剪模式,则裁剪图像 + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, + [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + + # 读取保存的图像文件 + saved_image_path = os.path.join(output_dir, new_filename) + saved_image_pixmap = QtGui.QPixmap(saved_image_path) + + # 将保存的图像显示在预测图像区域 + mainWindow.label_2.setPixmap( + saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio)) + else: + print("Could not save the image. Too many variations exist.") + + current_index = 0 + + cv2.namedWindow("image", cv2.WINDOW_NORMAL) + cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + + def mouse_click(event, x, y, flags, param): + if not self.input_stop: + if event == cv2.EVENT_LBUTTONDOWN: + input_point.append([x, y]) + input_label.append(1) + elif event == cv2.EVENT_RBUTTONDOWN: + input_point.append([x, y]) + input_label.append(0) + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: + print('此时不能添加点,按w退出mask选择模式') + + cv2.setMouseCallback("image", mouse_click) + input_point = [] + input_label = [] + input_stop = False + while True: + filename = self.image_files[self.current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) + selected_mask = None + logit_input = None + while True: + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), + 2, + cv2.LINE_AA) + for point, label in zip(input_point, input_label): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) + + cv2.imshow("image", image_display) + + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) + input_point_np = np.array(input_point) + input_label_np = np.array(input_label) + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) + + prediction_window_name = "Prediction" + cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT) + cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, + (1080 - WINDOW_HEIGHT) // 2) + + while True: + color = tuple(np.random.randint(0, 256, 3).tolist()) + image_select = image_orign.copy() + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, + (0, 255, 255), 2, cv2.LINE_AA) + + cv2.imshow(prediction_window_name, selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, + crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + input_stop = False # Allow adding points again + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + + cv2.destroyAllWindows() # Close all windows before exiting + if key == 27: + break + + def save_prediction(self): + if self.label_2.pixmap(): # 检查预测图像区域是否有图像 + # 保存预测结果的部分,这里假设你已经有了保存预测结果的代码,我用 placeholer 代替 + # placeholder: 这里假设 save_prediction_result 是一个保存预测结果的函数,它接受预测结果的图像数据以及保存路径作为参数 + # 这里假设预测结果图像数据为 prediction_image,保存路径为 save_path + prediction_image = self.label_2.pixmap().toImage() + save_path = "prediction_result.png" + prediction_image.save(save_path) + + # 调用 adjust_prediction_size 方法来根据 horizontalSlider 的值调整预测区域的大小 + self.adjust_prediction_size(self.horizontalSlider.value()) + + def adjust_prediction_size(self, value): + if self.image_files and self.current_index < len(self.image_files): + # 获取预测图像区域的原始大小 + pixmap = self.label_2.pixmap() + if pixmap.isNull(): + return + + original_size = pixmap.size() + + # 判断是缩小还是还原图像 + if value < self.last_value: + # 缩小掩码 + scale_factor = 1.0 + (self.last_value - value) * 0.1 + else: + # 放大掩码 + scale_factor = 1.0 - (value - self.last_value) * 0.1 + + self.last_value = value # 更新上一次的滑块值 + + # 根据缩放比例调整预测图像区域的大小,并保持纵横比例 + scaled_size = QtCore.QSize(original_size.width() * scale_factor, original_size.height() * scale_factor) + scaled_pixmap = pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio) + + # 更新预测图像区域的大小并显示 + self.label_2.setPixmap(scaled_pixmap) + + +if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) + mainWindow = MyMainWindow() + mainWindow.show() + sys.exit(app.exec_()) + diff --git a/scripts/amg.py b/scripts/amg.py new file mode 100644 index 0000000..75693a6 --- /dev/null +++ b/scripts/amg.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 # type: ignore +import matplotlib.pyplot as plt + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + +import argparse +import json +import os +from typing import Any, Dict, List +import numpy as np + +parser = argparse.ArgumentParser( + description=( + "Runs automatic mask generation on an input image or directory of images, " + "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " + "as well as pycocotools if saving in RLE format." + ) +) +# --input scripts/input/crops/Guide/23140.jpg --output scripts/output/crops --model-type vit_b --checkpoint D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth +parser.add_argument( + "--input", + type=str, + default=r'scripts/input/crops/Guide/23140.jpg', + required=False, + help="Path to either a single input image or folder of images.", +) + +parser.add_argument( + "--output", + type=str, + required=False, + default=r'scripts/output/crops/Guide/', + help=( + "Path to the directory where masks will be output. Output will be either a folder " + "of PNGs per image or a single json with COCO-style masks." + ), +) + +parser.add_argument( + "--model-type", + type=str, + required=False, + default='vit_b', + help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", +) + +parser.add_argument( + "--checkpoint", + type=str, + required=False, + default=r'D:/Program Files/Pycharm items/segment-anything-model/weights/vit_b.pth', + help="The path to the SAM checkpoint to use for mask generation.", +) + +parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") + +parser.add_argument( + "--convert-to-rle", + action="store_true", + help=( + "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " + "Requires pycocotools." + ), +) + +amg_settings = parser.add_argument_group("AMG Settings") + +amg_settings.add_argument( + "--points-per-side", + type=int, + default=None, + help="Generate masks by sampling a grid over the image with this many points to a side.", +) + +amg_settings.add_argument( + "--points-per-batch", + type=int, + default=None, + help="How many input points to process simultaneously in one batch.", +) + +amg_settings.add_argument( + "--pred-iou-thresh", + type=float, + default=None, + help="Exclude masks with a predicted score from the model that is lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-thresh", + type=float, + default=None, + help="Exclude masks with a stability score lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-offset", + type=float, + default=None, + help="Larger values perturb the mask more when measuring stability score.", +) + +amg_settings.add_argument( + "--box-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding a duplicate mask.", +) + +amg_settings.add_argument( + "--crop-n-layers", + type=int, + default=None, + help=( + "If >0, mask generation is run on smaller crops of the image to generate more masks. " + "The value sets how many different scales to crop at." + ), +) + +amg_settings.add_argument( + "--crop-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding duplicate masks across different crops.", +) + +amg_settings.add_argument( + "--crop-overlap-ratio", + type=int, + default=None, + help="Larger numbers mean image crops will overlap more.", +) + +amg_settings.add_argument( + "--crop-n-points-downscale-factor", + type=int, + default=None, + help="The number of points-per-side in each layer of crop is reduced by this factor.", +) + +amg_settings.add_argument( + "--min-mask-region-area", + type=int, + default=None, + help=( + "Disconnected mask regions or holes with area smaller than this value " + "in pixels are removed by postprocessing." + ), +) + + +def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: + header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa + metadata = [header] + for i, mask_data in enumerate(masks): + mask = mask_data["segmentation"] + filename = f"{i}.png" + cv2.imwrite(os.path.join(path, filename), mask * 255) + mask_metadata = [ + str(i), + str(mask_data["area"]), + *[str(x) for x in mask_data["bbox"]], + *[str(x) for x in mask_data["point_coords"][0]], + str(mask_data["predicted_iou"]), + str(mask_data["stability_score"]), + *[str(x) for x in mask_data["crop_box"]], + ] + row = ",".join(mask_metadata) + metadata.append(row) + metadata_path = os.path.join(path, "metadata.csv") + with open(metadata_path, "w") as f: + f.write("\n".join(metadata)) + + return + + +def get_amg_kwargs(args): + amg_kwargs = { + "points_per_side": args.points_per_side, + "points_per_batch": args.points_per_batch, + "pred_iou_thresh": args.pred_iou_thresh, + "stability_score_thresh": args.stability_score_thresh, + "stability_score_offset": args.stability_score_offset, + "box_nms_thresh": args.box_nms_thresh, + "crop_n_layers": args.crop_n_layers, + "crop_nms_thresh": args.crop_nms_thresh, + "crop_overlap_ratio": args.crop_overlap_ratio, + "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, + "min_mask_region_area": args.min_mask_region_area, + } + amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} + return amg_kwargs + + +def main(args: argparse.Namespace) -> None: + print("Loading model...") + sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) + _ = sam.to(device=args.device) + output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" + amg_kwargs = get_amg_kwargs(args) + generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) + + if not os.path.isdir(args.input): + targets = [args.input] + else: + targets = [ + f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) + ] + targets = [os.path.join(args.input, f) for f in targets] + + os.makedirs(args.output, exist_ok=True) + + for t in targets: + print(f"Processing '{t}'...") + image = cv2.imread(t) + if image is None: + print(f"Could not load '{t}' as an image, skipping...") + continue + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + masks = generator.generate(image) + + base = os.path.basename(t) + base = os.path.splitext(base)[0] + save_base = os.path.join(args.output, base) + if output_mode == "binary_mask": + os.makedirs(save_base, exist_ok=False) + write_masks_to_folder(masks, save_base) + else: + save_file = save_base + ".json" + with open(save_file, "w") as f: + json.dump(masks, f) + print("Done!") + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) +""" +D:\anaconda3\envs\pytorch\python.exe "D:/Program Files/Pycharm items/segment-anything-model/scripts/amg.py" --input "scripts/input/crops/Guide/23140.jpg" --output "scripts/output/crops" --model-type vit_b --checkpoint "D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth" + +""" \ No newline at end of file diff --git a/scripts/amg1.py b/scripts/amg1.py new file mode 100644 index 0000000..5a5a4ee --- /dev/null +++ b/scripts/amg1.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 # type: ignore + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + +import argparse +import json +import os +from typing import Any, Dict, List + +parser = argparse.ArgumentParser( + description=( + "Runs automatic mask generation on an input image or directory of images, " + "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " + "as well as pycocotools if saving in RLE format." + ) +) +parser.add_argument( + "--input", + type=str, + required=False, + default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images', + help="Path to either a single input image or folder of images.", +) + +parser.add_argument( + "--output", + type=str, + required=False, + default=r'output/mask', + help=( + "Path to the directory where masks will be output. Output will be either a folder " + "of PNGs per image or a single json with COCO-style masks." + ), +) + +parser.add_argument( + "--model-type", + type=str, + required=False, + default='vit_b', + help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", +) + +parser.add_argument( + "--checkpoint", + type=str, + required=False, + default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth', + help="The path to the SAM checkpoint to use for mask generation.", +) + +# parser.add_argument( +# "--input", +# type=str, +# required=True, +# help="Path to either a single input image or folder of images.", +# ) +# +# parser.add_argument( +# "--output", +# type=str, +# required=True, +# help=( +# "Path to the directory where masks will be output. Output will be either a folder " +# "of PNGs per image or a single json with COCO-style masks." +# ), +# ) +# +# parser.add_argument( +# "--model-type", +# type=str, +# required=True, +# help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", +# ) +# +# parser.add_argument( +# "--checkpoint", +# type=str, +# required=True, +# help="The path to the SAM checkpoint to use for mask generation.", +# ) + +parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") + +parser.add_argument( + "--convert-to-rle", + action="store_true", + help=( + "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " + "Requires pycocotools." + ), +) + +amg_settings = parser.add_argument_group("AMG Settings") + +amg_settings.add_argument( + "--points-per-side", + type=int, + default=None, + help="Generate masks by sampling a grid over the image with this many points to a side.", +) + +amg_settings.add_argument( + "--points-per-batch", + type=int, + default=None, + help="How many input points to process simultaneously in one batch.", +) + +amg_settings.add_argument( + "--pred-iou-thresh", + type=float, + default=None, + help="Exclude masks with a predicted score from the model that is lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-thresh", + type=float, + default=None, + help="Exclude masks with a stability score lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-offset", + type=float, + default=None, + help="Larger values perturb the mask more when measuring stability score.", +) + +amg_settings.add_argument( + "--box-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding a duplicate mask.", +) + +amg_settings.add_argument( + "--crop-n-layers", + type=int, + default=None, + help=( + "If >0, mask generation is run on smaller crops of the image to generate more masks. " + "The value sets how many different scales to crop at." + ), +) + +amg_settings.add_argument( + "--crop-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding duplicate masks across different crops.", +) + +amg_settings.add_argument( + "--crop-overlap-ratio", + type=int, + default=None, + help="Larger numbers mean image crops will overlap more.", +) + +amg_settings.add_argument( + "--crop-n-points-downscale-factor", + type=int, + default=None, + help="The number of points-per-side in each layer of crop is reduced by this factor.", +) + +amg_settings.add_argument( + "--min-mask-region-area", + type=int, + default=None, + help=( + "Disconnected mask regions or holes with area smaller than this value " + "in pixels are removed by postprocessing." + ), +) + + +def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: + header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa + metadata = [header] + for i, mask_data in enumerate(masks): + mask = mask_data["segmentation"] + filename = f"{i}.png" + cv2.imwrite(os.path.join(path, filename), mask * 255) + mask_metadata = [ + str(i), + str(mask_data["area"]), + *[str(x) for x in mask_data["bbox"]], + *[str(x) for x in mask_data["point_coords"][0]], + str(mask_data["predicted_iou"]), + str(mask_data["stability_score"]), + *[str(x) for x in mask_data["crop_box"]], + ] + row = ",".join(mask_metadata) + metadata.append(row) + metadata_path = os.path.join(path, "metadata.csv") + with open(metadata_path, "w") as f: + f.write("\n".join(metadata)) + + return + + +def get_amg_kwargs(args): + amg_kwargs = { + "points_per_side": args.points_per_side, + "points_per_batch": args.points_per_batch, + "pred_iou_thresh": args.pred_iou_thresh, + "stability_score_thresh": args.stability_score_thresh, + "stability_score_offset": args.stability_score_offset, + "box_nms_thresh": args.box_nms_thresh, + "crop_n_layers": args.crop_n_layers, + "crop_nms_thresh": args.crop_nms_thresh, + "crop_overlap_ratio": args.crop_overlap_ratio, + "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, + "min_mask_region_area": args.min_mask_region_area, + } + amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} + return amg_kwargs + + +def main(args: argparse.Namespace) -> None: + print("Loading model...") + sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) + _ = sam.to(device=args.device) + output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" + amg_kwargs = get_amg_kwargs(args) + generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) + + if not os.path.isdir(args.input): + targets = [args.input] + else: + targets = [ + f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) + ] + targets = [os.path.join(args.input, f) for f in targets] + + os.makedirs(args.output, exist_ok=True) + + for t in targets: + print(f"Processing '{t}'...") + image = cv2.imread(t) + if image is None: + print(f"Could not load '{t}' as an image, skipping...") + continue + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + masks = generator.generate(image) + + base = os.path.basename(t) + base = os.path.splitext(base)[0] + save_base = os.path.join(args.output, base) + if output_mode == "binary_mask": + os.makedirs(save_base, exist_ok=False) + write_masks_to_folder(masks, save_base) + else: + save_file = save_base + ".json" + with open(save_file, "w") as f: + json.dump(masks, f) + print("Done!") + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py new file mode 100644 index 0000000..5c6f838 --- /dev/null +++ b/scripts/export_onnx_model.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from segment_anything import sam_model_registry +from segment_anything.utils.onnx import SamOnnxModel + +import argparse +import warnings + +try: + import onnxruntime # type: ignore + + onnxruntime_exists = True +except ImportError: + onnxruntime_exists = False + +parser = argparse.ArgumentParser( + description="Export the SAM prompt encoder and mask decoder to an ONNX model." +) + +parser.add_argument( + "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." +) + +parser.add_argument( + "--output", type=str, required=True, help="The filename to save the ONNX model to." +) + +parser.add_argument( + "--model-type", + type=str, + required=True, + help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.", +) + +parser.add_argument( + "--return-single-mask", + action="store_true", + help=( + "If true, the exported ONNX model will only return the best mask, " + "instead of returning multiple masks. For high resolution images " + "this can improve runtime when upscaling masks is expensive." + ), +) + +parser.add_argument( + "--opset", + type=int, + default=17, + help="The ONNX opset version to use. Must be >=11", +) + +parser.add_argument( + "--quantize-out", + type=str, + default=None, + help=( + "If set, will quantize the model and save it with this name. " + "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." + ), +) + +parser.add_argument( + "--gelu-approximate", + action="store_true", + help=( + "Replace GELU operations with approximations using tanh. Useful " + "for some runtimes that have slow or unimplemented erf ops, used in GELU." + ), +) + +parser.add_argument( + "--use-stability-score", + action="store_true", + help=( + "Replaces the model's predicted mask quality score with the stability " + "score calculated on the low resolution masks using an offset of 1.0. " + ), +) + +parser.add_argument( + "--return-extra-metrics", + action="store_true", + help=( + "The model will return five results: (masks, scores, stability_scores, " + "areas, low_res_logits) instead of the usual three. This can be " + "significantly slower for high resolution outputs." + ), +) + + +def run_export( + model_type: str, + checkpoint: str, + output: str, + opset: int, + return_single_mask: bool, + gelu_approximate: bool = False, + use_stability_score: bool = False, + return_extra_metrics=False, +): + print("Loading model...") + sam = sam_model_registry[model_type](checkpoint=checkpoint) + + onnx_model = SamOnnxModel( + model=sam, + return_single_mask=return_single_mask, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + ) + + if gelu_approximate: + for n, m in onnx_model.named_modules(): + if isinstance(m, torch.nn.GELU): + m.approximate = "tanh" + + dynamic_axes = { + "point_coords": {1: "num_points"}, + "point_labels": {1: "num_points"}, + } + + embed_dim = sam.prompt_encoder.embed_dim + embed_size = sam.prompt_encoder.image_embedding_size + mask_input_size = [4 * x for x in embed_size] + dummy_inputs = { + "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), + "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), + "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), + "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), + "has_mask_input": torch.tensor([1], dtype=torch.float), + "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), + } + + _ = onnx_model(**dummy_inputs) + + output_names = ["masks", "iou_predictions", "low_res_masks"] + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + with open(output, "wb") as f: + print(f"Exporting onnx model to {output}...") + torch.onnx.export( + onnx_model, + tuple(dummy_inputs.values()), + f, + export_params=True, + verbose=False, + opset_version=opset, + do_constant_folding=True, + input_names=list(dummy_inputs.keys()), + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + if onnxruntime_exists: + ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} + # set cpu provider default + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession(output, providers=providers) + _ = ort_session.run(None, ort_inputs) + print("Model has successfully been run with ONNXRuntime.") + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +if __name__ == "__main__": + args = parser.parse_args() + run_export( + model_type=args.model_type, + checkpoint=args.checkpoint, + output=args.output, + opset=args.opset, + return_single_mask=args.return_single_mask, + gelu_approximate=args.gelu_approximate, + use_stability_score=args.use_stability_score, + return_extra_metrics=args.return_extra_metrics, + ) + + if args.quantize_out is not None: + assert onnxruntime_exists, "onnxruntime is required to quantize the model." + from onnxruntime.quantization import QuantType # type: ignore + from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore + + print(f"Quantizing model and writing to {args.quantize_out}...") + quantize_dynamic( + model_input=args.output, + model_output=args.quantize_out, + optimize_model=True, + per_channel=False, + reduce_range=False, + weight_type=QuantType.QUInt8, + ) + print("Done!") diff --git a/scripts/input/crops/Guide/20194.jpg b/scripts/input/crops/Guide/20194.jpg new file mode 100644 index 0000000..dcf81fa Binary files /dev/null and b/scripts/input/crops/Guide/20194.jpg differ diff --git a/scripts/input/crops/Guide/201942.jpg b/scripts/input/crops/Guide/201942.jpg new file mode 100644 index 0000000..0a1b3d1 Binary files /dev/null and b/scripts/input/crops/Guide/201942.jpg differ diff --git a/scripts/input/crops/Guide/23120.jpg b/scripts/input/crops/Guide/23120.jpg new file mode 100644 index 0000000..1e09791 Binary files /dev/null and b/scripts/input/crops/Guide/23120.jpg differ diff --git a/scripts/input/crops/Guide/23140.jpg b/scripts/input/crops/Guide/23140.jpg new file mode 100644 index 0000000..31d7b2c Binary files /dev/null and b/scripts/input/crops/Guide/23140.jpg differ diff --git a/scripts/input/crops/Guide/23149.jpg b/scripts/input/crops/Guide/23149.jpg new file mode 100644 index 0000000..0fdef7b Binary files /dev/null and b/scripts/input/crops/Guide/23149.jpg differ diff --git a/scripts/input/crops/Guide/23192.jpg b/scripts/input/crops/Guide/23192.jpg new file mode 100644 index 0000000..26279af Binary files /dev/null and b/scripts/input/crops/Guide/23192.jpg differ diff --git a/scripts/input/crops/Guide/23196.jpg b/scripts/input/crops/Guide/23196.jpg new file mode 100644 index 0000000..914b0d7 Binary files /dev/null and b/scripts/input/crops/Guide/23196.jpg differ diff --git a/scripts/input/crops/Guide/231962.jpg b/scripts/input/crops/Guide/231962.jpg new file mode 100644 index 0000000..91e8c05 Binary files /dev/null and b/scripts/input/crops/Guide/231962.jpg differ diff --git a/scripts/input/crops/auxiliary/17359.jpg b/scripts/input/crops/auxiliary/17359.jpg new file mode 100644 index 0000000..13caa58 Binary files /dev/null and b/scripts/input/crops/auxiliary/17359.jpg differ diff --git a/scripts/input/crops/auxiliary/23140.jpg b/scripts/input/crops/auxiliary/23140.jpg new file mode 100644 index 0000000..a12bbe6 Binary files /dev/null and b/scripts/input/crops/auxiliary/23140.jpg differ diff --git a/scripts/input/crops/auxiliary/231402.jpg b/scripts/input/crops/auxiliary/231402.jpg new file mode 100644 index 0000000..1dc6b20 Binary files /dev/null and b/scripts/input/crops/auxiliary/231402.jpg differ diff --git a/scripts/input/crops/auxiliary/231403.jpg b/scripts/input/crops/auxiliary/231403.jpg new file mode 100644 index 0000000..0cd440d Binary files /dev/null and b/scripts/input/crops/auxiliary/231403.jpg differ diff --git a/scripts/input/crops/auxiliary/231404.jpg b/scripts/input/crops/auxiliary/231404.jpg new file mode 100644 index 0000000..e0b2f15 Binary files /dev/null and b/scripts/input/crops/auxiliary/231404.jpg differ diff --git a/scripts/input/crops/auxiliary/231405.jpg b/scripts/input/crops/auxiliary/231405.jpg new file mode 100644 index 0000000..188a75e Binary files /dev/null and b/scripts/input/crops/auxiliary/231405.jpg differ diff --git a/scripts/input/crops/indicative/11536.jpg b/scripts/input/crops/indicative/11536.jpg new file mode 100644 index 0000000..49094ee Binary files /dev/null and b/scripts/input/crops/indicative/11536.jpg differ diff --git a/scripts/input/crops/indicative/115362.jpg b/scripts/input/crops/indicative/115362.jpg new file mode 100644 index 0000000..6a99d63 Binary files /dev/null and b/scripts/input/crops/indicative/115362.jpg differ diff --git a/scripts/input/crops/indicative/115363.jpg b/scripts/input/crops/indicative/115363.jpg new file mode 100644 index 0000000..4cd2f97 Binary files /dev/null and b/scripts/input/crops/indicative/115363.jpg differ diff --git a/scripts/input/crops/indicative/17359.jpg b/scripts/input/crops/indicative/17359.jpg new file mode 100644 index 0000000..0cf1e77 Binary files /dev/null and b/scripts/input/crops/indicative/17359.jpg differ diff --git a/scripts/input/crops/indicative/17589.jpg b/scripts/input/crops/indicative/17589.jpg new file mode 100644 index 0000000..775b550 Binary files /dev/null and b/scripts/input/crops/indicative/17589.jpg differ diff --git a/scripts/input/crops/indicative/20670.jpg b/scripts/input/crops/indicative/20670.jpg new file mode 100644 index 0000000..e6d005b Binary files /dev/null and b/scripts/input/crops/indicative/20670.jpg differ diff --git a/scripts/input/crops/indicative/23109.jpg b/scripts/input/crops/indicative/23109.jpg new file mode 100644 index 0000000..4404664 Binary files /dev/null and b/scripts/input/crops/indicative/23109.jpg differ diff --git a/scripts/input/crops/indicative/231092.jpg b/scripts/input/crops/indicative/231092.jpg new file mode 100644 index 0000000..ca8636c Binary files /dev/null and b/scripts/input/crops/indicative/231092.jpg differ diff --git a/scripts/input/crops/indicative/231093.jpg b/scripts/input/crops/indicative/231093.jpg new file mode 100644 index 0000000..284089f Binary files /dev/null and b/scripts/input/crops/indicative/231093.jpg differ diff --git a/scripts/input/crops/indicative/23131.jpg b/scripts/input/crops/indicative/23131.jpg new file mode 100644 index 0000000..90671f9 Binary files /dev/null and b/scripts/input/crops/indicative/23131.jpg differ diff --git a/scripts/input/crops/indicative/231312.jpg b/scripts/input/crops/indicative/231312.jpg new file mode 100644 index 0000000..803f992 Binary files /dev/null and b/scripts/input/crops/indicative/231312.jpg differ diff --git a/scripts/input/crops/indicative/23136.jpg b/scripts/input/crops/indicative/23136.jpg new file mode 100644 index 0000000..6e971c7 Binary files /dev/null and b/scripts/input/crops/indicative/23136.jpg differ diff --git a/scripts/input/crops/indicative/23170.jpg b/scripts/input/crops/indicative/23170.jpg new file mode 100644 index 0000000..4e7394d Binary files /dev/null and b/scripts/input/crops/indicative/23170.jpg differ diff --git a/scripts/input/crops/indicative/23192.jpg b/scripts/input/crops/indicative/23192.jpg new file mode 100644 index 0000000..8096a32 Binary files /dev/null and b/scripts/input/crops/indicative/23192.jpg differ diff --git a/scripts/input/crops/indicative/231922.jpg b/scripts/input/crops/indicative/231922.jpg new file mode 100644 index 0000000..c94dfde Binary files /dev/null and b/scripts/input/crops/indicative/231922.jpg differ diff --git a/scripts/input/crops/indicative/23196.jpg b/scripts/input/crops/indicative/23196.jpg new file mode 100644 index 0000000..451d685 Binary files /dev/null and b/scripts/input/crops/indicative/23196.jpg differ diff --git a/scripts/input/crops/notice/23131.jpg b/scripts/input/crops/notice/23131.jpg new file mode 100644 index 0000000..0ad98a0 Binary files /dev/null and b/scripts/input/crops/notice/23131.jpg differ diff --git a/scripts/input/crops/notice/231312.jpg b/scripts/input/crops/notice/231312.jpg new file mode 100644 index 0000000..90671f9 Binary files /dev/null and b/scripts/input/crops/notice/231312.jpg differ diff --git a/scripts/input/crops/notice/23136.jpg b/scripts/input/crops/notice/23136.jpg new file mode 100644 index 0000000..ba6c644 Binary files /dev/null and b/scripts/input/crops/notice/23136.jpg differ diff --git a/scripts/input/crops/prohibit/15577.jpg b/scripts/input/crops/prohibit/15577.jpg new file mode 100644 index 0000000..a4b24c2 Binary files /dev/null and b/scripts/input/crops/prohibit/15577.jpg differ diff --git a/scripts/input/crops/prohibit/15974.jpg b/scripts/input/crops/prohibit/15974.jpg new file mode 100644 index 0000000..e44779b Binary files /dev/null and b/scripts/input/crops/prohibit/15974.jpg differ diff --git a/scripts/input/crops/prohibit/17359.jpg b/scripts/input/crops/prohibit/17359.jpg new file mode 100644 index 0000000..b3033fb Binary files /dev/null and b/scripts/input/crops/prohibit/17359.jpg differ diff --git a/scripts/input/crops/prohibit/173592.jpg b/scripts/input/crops/prohibit/173592.jpg new file mode 100644 index 0000000..e6bc7b9 Binary files /dev/null and b/scripts/input/crops/prohibit/173592.jpg differ diff --git a/scripts/input/crops/prohibit/17589.jpg b/scripts/input/crops/prohibit/17589.jpg new file mode 100644 index 0000000..dc16878 Binary files /dev/null and b/scripts/input/crops/prohibit/17589.jpg differ diff --git a/scripts/input/crops/prohibit/20194.jpg b/scripts/input/crops/prohibit/20194.jpg new file mode 100644 index 0000000..0a8a081 Binary files /dev/null and b/scripts/input/crops/prohibit/20194.jpg differ diff --git a/scripts/input/crops/prohibit/22430.jpg b/scripts/input/crops/prohibit/22430.jpg new file mode 100644 index 0000000..72f8547 Binary files /dev/null and b/scripts/input/crops/prohibit/22430.jpg differ diff --git a/scripts/input/crops/prohibit/23109.jpg b/scripts/input/crops/prohibit/23109.jpg new file mode 100644 index 0000000..25ec71b Binary files /dev/null and b/scripts/input/crops/prohibit/23109.jpg differ diff --git a/scripts/input/crops/prohibit/23120.jpg b/scripts/input/crops/prohibit/23120.jpg new file mode 100644 index 0000000..261939d Binary files /dev/null and b/scripts/input/crops/prohibit/23120.jpg differ diff --git a/scripts/input/crops/prohibit/23131.jpg b/scripts/input/crops/prohibit/23131.jpg new file mode 100644 index 0000000..65f01ec Binary files /dev/null and b/scripts/input/crops/prohibit/23131.jpg differ diff --git a/scripts/input/crops/prohibit/231312.jpg b/scripts/input/crops/prohibit/231312.jpg new file mode 100644 index 0000000..fdc84ac Binary files /dev/null and b/scripts/input/crops/prohibit/231312.jpg differ diff --git a/scripts/input/crops/prohibit/23136.jpg b/scripts/input/crops/prohibit/23136.jpg new file mode 100644 index 0000000..b5bcd97 Binary files /dev/null and b/scripts/input/crops/prohibit/23136.jpg differ diff --git a/scripts/input/crops/prohibit/231362.jpg b/scripts/input/crops/prohibit/231362.jpg new file mode 100644 index 0000000..1155e2a Binary files /dev/null and b/scripts/input/crops/prohibit/231362.jpg differ diff --git a/scripts/input/crops/prohibit/231363.jpg b/scripts/input/crops/prohibit/231363.jpg new file mode 100644 index 0000000..b430ae4 Binary files /dev/null and b/scripts/input/crops/prohibit/231363.jpg differ diff --git a/scripts/input/crops/prohibit/231364.jpg b/scripts/input/crops/prohibit/231364.jpg new file mode 100644 index 0000000..687298e Binary files /dev/null and b/scripts/input/crops/prohibit/231364.jpg differ diff --git a/scripts/input/crops/prohibit/23140.jpg b/scripts/input/crops/prohibit/23140.jpg new file mode 100644 index 0000000..a12bbe6 Binary files /dev/null and b/scripts/input/crops/prohibit/23140.jpg differ diff --git a/scripts/input/crops/prohibit/231402.jpg b/scripts/input/crops/prohibit/231402.jpg new file mode 100644 index 0000000..bfba263 Binary files /dev/null and b/scripts/input/crops/prohibit/231402.jpg differ diff --git a/scripts/input/crops/prohibit/231403.jpg b/scripts/input/crops/prohibit/231403.jpg new file mode 100644 index 0000000..1c0a3c9 Binary files /dev/null and b/scripts/input/crops/prohibit/231403.jpg differ diff --git a/scripts/input/crops/prohibit/231404.jpg b/scripts/input/crops/prohibit/231404.jpg new file mode 100644 index 0000000..a3eae6a Binary files /dev/null and b/scripts/input/crops/prohibit/231404.jpg differ diff --git a/scripts/input/crops/prohibit/231405.jpg b/scripts/input/crops/prohibit/231405.jpg new file mode 100644 index 0000000..42a064d Binary files /dev/null and b/scripts/input/crops/prohibit/231405.jpg differ diff --git a/scripts/input/crops/prohibit/23149.jpg b/scripts/input/crops/prohibit/23149.jpg new file mode 100644 index 0000000..bfee290 Binary files /dev/null and b/scripts/input/crops/prohibit/23149.jpg differ diff --git a/scripts/input/crops/prohibit/23170.jpg b/scripts/input/crops/prohibit/23170.jpg new file mode 100644 index 0000000..c3481d7 Binary files /dev/null and b/scripts/input/crops/prohibit/23170.jpg differ diff --git a/scripts/input/crops/prohibit/23192.jpg b/scripts/input/crops/prohibit/23192.jpg new file mode 100644 index 0000000..aeaf3fb Binary files /dev/null and b/scripts/input/crops/prohibit/23192.jpg differ diff --git a/scripts/input/crops/prohibit/231922.jpg b/scripts/input/crops/prohibit/231922.jpg new file mode 100644 index 0000000..d8c3267 Binary files /dev/null and b/scripts/input/crops/prohibit/231922.jpg differ diff --git a/scripts/input/crops/prohibit/231923.jpg b/scripts/input/crops/prohibit/231923.jpg new file mode 100644 index 0000000..76f8936 Binary files /dev/null and b/scripts/input/crops/prohibit/231923.jpg differ diff --git a/scripts/input/crops/prohibit/231924.jpg b/scripts/input/crops/prohibit/231924.jpg new file mode 100644 index 0000000..6d3e9e0 Binary files /dev/null and b/scripts/input/crops/prohibit/231924.jpg differ diff --git a/scripts/input/crops/prohibit/231925.jpg b/scripts/input/crops/prohibit/231925.jpg new file mode 100644 index 0000000..7bd852b Binary files /dev/null and b/scripts/input/crops/prohibit/231925.jpg differ diff --git a/scripts/input/crops/prohibit/23196.jpg b/scripts/input/crops/prohibit/23196.jpg new file mode 100644 index 0000000..7348910 Binary files /dev/null and b/scripts/input/crops/prohibit/23196.jpg differ diff --git a/scripts/input/crops/warning/23131.jpg b/scripts/input/crops/warning/23131.jpg new file mode 100644 index 0000000..be348a3 Binary files /dev/null and b/scripts/input/crops/warning/23131.jpg differ diff --git a/scripts/input/crops/warning/231312.jpg b/scripts/input/crops/warning/231312.jpg new file mode 100644 index 0000000..732cee7 Binary files /dev/null and b/scripts/input/crops/warning/231312.jpg differ diff --git a/scripts/input/crops/warning/231313.jpg b/scripts/input/crops/warning/231313.jpg new file mode 100644 index 0000000..ff1d931 Binary files /dev/null and b/scripts/input/crops/warning/231313.jpg differ diff --git a/scripts/input/crops/warning/231314.jpg b/scripts/input/crops/warning/231314.jpg new file mode 100644 index 0000000..56f2640 Binary files /dev/null and b/scripts/input/crops/warning/231314.jpg differ diff --git a/scripts/input/images/1.png b/scripts/input/images/1.png new file mode 100644 index 0000000..5f79e27 Binary files /dev/null and b/scripts/input/images/1.png differ diff --git a/scripts/input/images/1048.jpg b/scripts/input/images/1048.jpg new file mode 100644 index 0000000..b999327 Binary files /dev/null and b/scripts/input/images/1048.jpg differ diff --git a/scripts/input/images/1204.jpg b/scripts/input/images/1204.jpg new file mode 100644 index 0000000..116c5ae Binary files /dev/null and b/scripts/input/images/1204.jpg differ diff --git a/scripts/input/images/1220.jpg b/scripts/input/images/1220.jpg new file mode 100644 index 0000000..410d96f Binary files /dev/null and b/scripts/input/images/1220.jpg differ diff --git a/scripts/input/images/2.jpg b/scripts/input/images/2.jpg new file mode 100644 index 0000000..576993d Binary files /dev/null and b/scripts/input/images/2.jpg differ diff --git a/scripts/input/images/285.jpg b/scripts/input/images/285.jpg new file mode 100644 index 0000000..6949fd0 Binary files /dev/null and b/scripts/input/images/285.jpg differ diff --git a/scripts/input/images/3700.jpg b/scripts/input/images/3700.jpg new file mode 100644 index 0000000..239fa91 Binary files /dev/null and b/scripts/input/images/3700.jpg differ diff --git a/scripts/input/images/725.jpg b/scripts/input/images/725.jpg new file mode 100644 index 0000000..f18fcb6 Binary files /dev/null and b/scripts/input/images/725.jpg differ diff --git a/scripts/input/img/725.jpg b/scripts/input/img/725.jpg new file mode 100644 index 0000000..f18fcb6 Binary files /dev/null and b/scripts/input/img/725.jpg differ diff --git a/scripts/namg.py b/scripts/namg.py new file mode 100644 index 0000000..89c0c11 --- /dev/null +++ b/scripts/namg.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 # type: ignore +import matplotlib.pyplot as plt + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + +import argparse +import json +import os +from typing import Any, Dict, List +import numpy as np + +parser = argparse.ArgumentParser( + description=( + "Runs automatic mask generation on an input image or directory of images, " + "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " + "as well as pycocotools if saving in RLE format." + ) +) + +parser.add_argument( + "--input", + type=str, + default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\crops\warning\231314.jpg', + required=False, + help="Path to either a single input image or folder of images.", +) + +parser.add_argument( + "--output", + type=str, + required=False, + default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\crops', + help=( + "Path to the directory where masks will be output. Output will be either a folder " + "of PNGs per image or a single json with COCO-style masks." + ), +) + +parser.add_argument( + "--model-type", + type=str, + required=False, + default='vit_b', + help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", +) + +parser.add_argument( + "--checkpoint", + type=str, + required=False, + default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth', + help="The path to the SAM checkpoint to use for mask generation.", +) + +parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") + +parser.add_argument( + "--convert-to-rle", + action="store_true", + help=( + "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " + "Requires pycocotools." + ), +) + +amg_settings = parser.add_argument_group("AMG Settings") + +amg_settings.add_argument( + "--points-per-side", + type=int, + default=None, + help="Generate masks by sampling a grid over the image with this many points to a side.", +) + +amg_settings.add_argument( + "--points-per-batch", + type=int, + default=None, + help="How many input points to process simultaneously in one batch.", +) + +amg_settings.add_argument( + "--pred-iou-thresh", + type=float, + default=None, + help="Exclude masks with a predicted score from the model that is lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-thresh", + type=float, + default=None, + help="Exclude masks with a stability score lower than this threshold.", +) + +amg_settings.add_argument( + "--stability-score-offset", + type=float, + default=None, + help="Larger values perturb the mask more when measuring stability score.", +) + +amg_settings.add_argument( + "--box-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding a duplicate mask.", +) + +amg_settings.add_argument( + "--crop-n-layers", + type=int, + default=None, + help=( + "If >0, mask generation is run on smaller crops of the image to generate more masks. " + "The value sets how many different scales to crop at." + ), +) + +amg_settings.add_argument( + "--crop-nms-thresh", + type=float, + default=None, + help="The overlap threshold for excluding duplicate masks across different crops.", +) + +amg_settings.add_argument( + "--crop-overlap-ratio", + type=int, + default=None, + help="Larger numbers mean image crops will overlap more.", +) + +amg_settings.add_argument( + "--crop-n-points-downscale-factor", + type=int, + default=None, + help="The number of points-per-side in each layer of crop is reduced by this factor.", +) + +amg_settings.add_argument( + "--min-mask-region-area", + type=int, + default=None, + help=( + "Disconnected mask regions or holes with area smaller than this value " + "in pixels are removed by postprocessing." + ), +) + + +def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: + header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa + metadata = [header] + for i, mask_data in enumerate(masks): + mask = mask_data["segmentation"] + filename = f"{i}.png" + cv2.imwrite(os.path.join(path, filename), mask * 255) + mask_metadata = [ + str(i), + str(mask_data["area"]), + *[str(x) for x in mask_data["bbox"]], + *[str(x) for x in mask_data["point_coords"][0]], + str(mask_data["predicted_iou"]), + str(mask_data["stability_score"]), + *[str(x) for x in mask_data["crop_box"]], + ] + row = ",".join(mask_metadata) + metadata.append(row) + metadata_path = os.path.join(path, "metadata.csv") + with open(metadata_path, "w") as f: + f.write("\n".join(metadata)) + + return + + +def get_amg_kwargs(args): + amg_kwargs = { + "points_per_side": args.points_per_side, + "points_per_batch": args.points_per_batch, + "pred_iou_thresh": args.pred_iou_thresh, + "stability_score_thresh": args.stability_score_thresh, + "stability_score_offset": args.stability_score_offset, + "box_nms_thresh": args.box_nms_thresh, + "crop_n_layers": args.crop_n_layers, + "crop_nms_thresh": args.crop_nms_thresh, + "crop_overlap_ratio": args.crop_overlap_ratio, + "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, + "min_mask_region_area": args.min_mask_region_area, + } + amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} + return amg_kwargs + + +def show_mask(mask, ax, random_color=True): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels == 1] + neg_points = coords[labels == 0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', + linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', + linewidth=1.25) + + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) + + +# def main(args: argparse.Namespace) -> None: +# print("Loading model...") +# sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) +# _ = sam.to(device=args.device) +# # output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" +# output_mode = 'binary_mask' +# amg_kwargs = get_amg_kwargs(args) +# generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) +# +# floader_path = r"D:\Program Files\Pycharm items\segment-anything-model\scripts\input\crops\Guide" # 这里为一批图像所在的文件夹 +# file_path = os.listdir(floader_path) +# for im in file_path: +# args.input = os.path.join(floader_path, im) +# if not os.path.isdir(args.input): +# targets = [args.input] +# else: +# targets = [ +# f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) +# ] +# targets = [os.path.join(args.input, f) for f in targets] +# +# os.makedirs(args.output, exist_ok=True) +# +# for t in targets: +# print(f"Processing '{t}'...") +# image = cv2.imread(t) +# if image is None: +# print(f"Could not load '{t}' as an image, skipping...") +# continue +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +# masks = generator.generate(image) +# +# +# plt.imshow(image) # 这里自行选择加不加 +# for mask in masks: +# show_mask(mask['segmentation'], plt.gca()) +# # show_box(mask['bbox'],plt.gca()) +# # plt.axis('off') 保存轴线或不保存 +# plt.savefig(r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\crops'+im) # 这里要替换为自己的路径 +# plt.close() +# print("Done!") + + + +def main(args: argparse.Namespace) -> None: + print("Loading model...") + sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) + _ = sam.to(device=args.device) + output_mode = 'binary_mask' + amg_kwargs = get_amg_kwargs(args) + generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) + + if not os.path.isdir(args.input): + targets = [args.input] + else: + targets = [ + f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) + ] + targets = [os.path.join(args.input, f) for f in targets] + + os.makedirs(args.output, exist_ok=True) + + for t in targets: + print(f"Processing '{t}'...") + image = cv2.imread(t) + if image is None: + print(f"Could not load '{t}' as an image, skipping...") + continue + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + masks = generator.generate(image) + + # For visualization (optional) + plt.imshow(image) + for mask in masks: + show_mask(mask['segmentation'], plt.gca()) + show_box(mask['bbox'], plt.gca()) + plt.show() + + # For saving masks + base = os.path.basename(t) + base = os.path.splitext(base)[0] + save_base = os.path.join(args.output, base) + if output_mode == "binary_mask": + os.makedirs(save_base, exist_ok=False) + for idx, mask in enumerate(masks): + mask_image = mask['segmentation'].astype('uint8') * 255 + save_path = os.path.join(save_base, f"mask_{idx}.png") + cv2.imwrite(save_path, mask_image) + else: + save_file = save_base + ".json" + with open(save_file, "w") as f: + json.dump(masks, f) + + print("Done!") + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/output/crops/231314/mask_0.png b/scripts/output/crops/231314/mask_0.png new file mode 100644 index 0000000..82024ca Binary files /dev/null and b/scripts/output/crops/231314/mask_0.png differ diff --git a/scripts/output/crops/231314/mask_1.png b/scripts/output/crops/231314/mask_1.png new file mode 100644 index 0000000..b0cbee3 Binary files /dev/null and b/scripts/output/crops/231314/mask_1.png differ diff --git a/scripts/output/crops/231314/mask_2.png b/scripts/output/crops/231314/mask_2.png new file mode 100644 index 0000000..9892cb6 Binary files /dev/null and b/scripts/output/crops/231314/mask_2.png differ diff --git a/scripts/output/crops/231314/mask_3.png b/scripts/output/crops/231314/mask_3.png new file mode 100644 index 0000000..384031e Binary files /dev/null and b/scripts/output/crops/231314/mask_3.png differ diff --git a/scripts/output/crops/231314/mask_4.png b/scripts/output/crops/231314/mask_4.png new file mode 100644 index 0000000..613592b Binary files /dev/null and b/scripts/output/crops/231314/mask_4.png differ diff --git a/scripts/output/crops/231314/mask_5.png b/scripts/output/crops/231314/mask_5.png new file mode 100644 index 0000000..a67bbbb Binary files /dev/null and b/scripts/output/crops/231314/mask_5.png differ diff --git a/scripts/output/crops/warning/23131_1.png b/scripts/output/crops/warning/23131_1.png new file mode 100644 index 0000000..a27d1f6 Binary files /dev/null and b/scripts/output/crops/warning/23131_1.png differ diff --git a/scripts/output/crops/warning/23131_2.png b/scripts/output/crops/warning/23131_2.png new file mode 100644 index 0000000..a554042 Binary files /dev/null and b/scripts/output/crops/warning/23131_2.png differ diff --git a/scripts/output/crops/warning/23131_3.png b/scripts/output/crops/warning/23131_3.png new file mode 100644 index 0000000..590ff72 Binary files /dev/null and b/scripts/output/crops/warning/23131_3.png differ diff --git a/scripts/output/crops/warning/23131_4.png b/scripts/output/crops/warning/23131_4.png new file mode 100644 index 0000000..53ec622 Binary files /dev/null and b/scripts/output/crops/warning/23131_4.png differ diff --git a/scripts/output/crops20194.jpg b/scripts/output/crops20194.jpg new file mode 100644 index 0000000..28d8501 Binary files /dev/null and b/scripts/output/crops20194.jpg differ diff --git a/scripts/output/crops201942.jpg b/scripts/output/crops201942.jpg new file mode 100644 index 0000000..8ee504e Binary files /dev/null and b/scripts/output/crops201942.jpg differ diff --git a/scripts/output/crops23120.jpg b/scripts/output/crops23120.jpg new file mode 100644 index 0000000..60fb6e2 Binary files /dev/null and b/scripts/output/crops23120.jpg differ diff --git a/scripts/output/crops23140.jpg b/scripts/output/crops23140.jpg new file mode 100644 index 0000000..ec2ed8d Binary files /dev/null and b/scripts/output/crops23140.jpg differ diff --git a/scripts/output/crops23149.jpg b/scripts/output/crops23149.jpg new file mode 100644 index 0000000..33bbea8 Binary files /dev/null and b/scripts/output/crops23149.jpg differ diff --git a/scripts/output/crops23192.jpg b/scripts/output/crops23192.jpg new file mode 100644 index 0000000..0d461d2 Binary files /dev/null and b/scripts/output/crops23192.jpg differ diff --git a/scripts/output/crops23196.jpg b/scripts/output/crops23196.jpg new file mode 100644 index 0000000..383c32c Binary files /dev/null and b/scripts/output/crops23196.jpg differ diff --git a/scripts/output/crops231962.jpg b/scripts/output/crops231962.jpg new file mode 100644 index 0000000..bb3c9a5 Binary files /dev/null and b/scripts/output/crops231962.jpg differ diff --git a/scripts/output/mask/1048_masked.png b/scripts/output/mask/1048_masked.png new file mode 100644 index 0000000..dcc3fa9 Binary files /dev/null and b/scripts/output/mask/1048_masked.png differ diff --git a/scripts/output/mask/1204_masked.png b/scripts/output/mask/1204_masked.png new file mode 100644 index 0000000..6775033 Binary files /dev/null and b/scripts/output/mask/1204_masked.png differ diff --git a/scripts/output/mask/285_masked.png b/scripts/output/mask/285_masked.png new file mode 100644 index 0000000..d50b4a7 Binary files /dev/null and b/scripts/output/mask/285_masked.png differ diff --git a/scripts/output/mask/3700_1.png b/scripts/output/mask/3700_1.png new file mode 100644 index 0000000..14b6bf3 Binary files /dev/null and b/scripts/output/mask/3700_1.png differ diff --git a/scripts/output/mask/3700_2.png b/scripts/output/mask/3700_2.png new file mode 100644 index 0000000..7acc840 Binary files /dev/null and b/scripts/output/mask/3700_2.png differ diff --git a/scripts/output/mask/3700_masked.png b/scripts/output/mask/3700_masked.png new file mode 100644 index 0000000..62075f9 Binary files /dev/null and b/scripts/output/mask/3700_masked.png differ diff --git a/scripts/output/mask/725_masked.png b/scripts/output/mask/725_masked.png new file mode 100644 index 0000000..cb7e1ff Binary files /dev/null and b/scripts/output/mask/725_masked.png differ diff --git a/scripts/output/mask/725_masked_with_semantic_info.png b/scripts/output/mask/725_masked_with_semantic_info.png new file mode 100644 index 0000000..1950e36 Binary files /dev/null and b/scripts/output/mask/725_masked_with_semantic_info.png differ diff --git a/scripts/output/maskt/1_1.png b/scripts/output/maskt/1_1.png new file mode 100644 index 0000000..1608551 Binary files /dev/null and b/scripts/output/maskt/1_1.png differ diff --git a/scripts/output/maskt/1_2.png b/scripts/output/maskt/1_2.png new file mode 100644 index 0000000..7a4adef Binary files /dev/null and b/scripts/output/maskt/1_2.png differ diff --git a/scripts/output/maskt/1_3.png b/scripts/output/maskt/1_3.png new file mode 100644 index 0000000..3ae1479 Binary files /dev/null and b/scripts/output/maskt/1_3.png differ diff --git a/scripts/output/maskt/2_1.png b/scripts/output/maskt/2_1.png new file mode 100644 index 0000000..8f421fc Binary files /dev/null and b/scripts/output/maskt/2_1.png differ diff --git a/scripts/output/maskt/725_1.png b/scripts/output/maskt/725_1.png new file mode 100644 index 0000000..4c4171f Binary files /dev/null and b/scripts/output/maskt/725_1.png differ diff --git a/scripts/output/maskt/725_10.png b/scripts/output/maskt/725_10.png new file mode 100644 index 0000000..c8230d9 Binary files /dev/null and b/scripts/output/maskt/725_10.png differ diff --git a/scripts/output/maskt/725_11.png b/scripts/output/maskt/725_11.png new file mode 100644 index 0000000..4a6be62 Binary files /dev/null and b/scripts/output/maskt/725_11.png differ diff --git a/scripts/output/maskt/725_12.png b/scripts/output/maskt/725_12.png new file mode 100644 index 0000000..cfbf217 Binary files /dev/null and b/scripts/output/maskt/725_12.png differ diff --git a/scripts/output/maskt/725_13.png b/scripts/output/maskt/725_13.png new file mode 100644 index 0000000..eac527e Binary files /dev/null and b/scripts/output/maskt/725_13.png differ diff --git a/scripts/output/maskt/725_14.png b/scripts/output/maskt/725_14.png new file mode 100644 index 0000000..eb7f297 Binary files /dev/null and b/scripts/output/maskt/725_14.png differ diff --git a/scripts/output/maskt/725_15.png b/scripts/output/maskt/725_15.png new file mode 100644 index 0000000..3b4c74c Binary files /dev/null and b/scripts/output/maskt/725_15.png differ diff --git a/scripts/output/maskt/725_16.png b/scripts/output/maskt/725_16.png new file mode 100644 index 0000000..b589405 Binary files /dev/null and b/scripts/output/maskt/725_16.png differ diff --git a/scripts/output/maskt/725_17.png b/scripts/output/maskt/725_17.png new file mode 100644 index 0000000..23a83a2 Binary files /dev/null and b/scripts/output/maskt/725_17.png differ diff --git a/scripts/output/maskt/725_18.png b/scripts/output/maskt/725_18.png new file mode 100644 index 0000000..d9ce234 Binary files /dev/null and b/scripts/output/maskt/725_18.png differ diff --git a/scripts/output/maskt/725_19.png b/scripts/output/maskt/725_19.png new file mode 100644 index 0000000..f901410 Binary files /dev/null and b/scripts/output/maskt/725_19.png differ diff --git a/scripts/output/maskt/725_2.png b/scripts/output/maskt/725_2.png new file mode 100644 index 0000000..1177316 Binary files /dev/null and b/scripts/output/maskt/725_2.png differ diff --git a/scripts/output/maskt/725_20.png b/scripts/output/maskt/725_20.png new file mode 100644 index 0000000..104244a Binary files /dev/null and b/scripts/output/maskt/725_20.png differ diff --git a/scripts/output/maskt/725_21.png b/scripts/output/maskt/725_21.png new file mode 100644 index 0000000..ede0da5 Binary files /dev/null and b/scripts/output/maskt/725_21.png differ diff --git a/scripts/output/maskt/725_22.png b/scripts/output/maskt/725_22.png new file mode 100644 index 0000000..5bcc88f Binary files /dev/null and b/scripts/output/maskt/725_22.png differ diff --git a/scripts/output/maskt/725_23.png b/scripts/output/maskt/725_23.png new file mode 100644 index 0000000..7fc01d9 Binary files /dev/null and b/scripts/output/maskt/725_23.png differ diff --git a/scripts/output/maskt/725_24.png b/scripts/output/maskt/725_24.png new file mode 100644 index 0000000..421194b Binary files /dev/null and b/scripts/output/maskt/725_24.png differ diff --git a/scripts/output/maskt/725_25.png b/scripts/output/maskt/725_25.png new file mode 100644 index 0000000..f5a9da7 Binary files /dev/null and b/scripts/output/maskt/725_25.png differ diff --git a/scripts/output/maskt/725_26.png b/scripts/output/maskt/725_26.png new file mode 100644 index 0000000..c97f3c9 Binary files /dev/null and b/scripts/output/maskt/725_26.png differ diff --git a/scripts/output/maskt/725_27.png b/scripts/output/maskt/725_27.png new file mode 100644 index 0000000..cab8c9e Binary files /dev/null and b/scripts/output/maskt/725_27.png differ diff --git a/scripts/output/maskt/725_28.png b/scripts/output/maskt/725_28.png new file mode 100644 index 0000000..2a89a19 Binary files /dev/null and b/scripts/output/maskt/725_28.png differ diff --git a/scripts/output/maskt/725_29.png b/scripts/output/maskt/725_29.png new file mode 100644 index 0000000..6d7e1c4 Binary files /dev/null and b/scripts/output/maskt/725_29.png differ diff --git a/scripts/output/maskt/725_3.png b/scripts/output/maskt/725_3.png new file mode 100644 index 0000000..c09f373 Binary files /dev/null and b/scripts/output/maskt/725_3.png differ diff --git a/scripts/output/maskt/725_30.png b/scripts/output/maskt/725_30.png new file mode 100644 index 0000000..6654e60 Binary files /dev/null and b/scripts/output/maskt/725_30.png differ diff --git a/scripts/output/maskt/725_31.png b/scripts/output/maskt/725_31.png new file mode 100644 index 0000000..c3e8915 Binary files /dev/null and b/scripts/output/maskt/725_31.png differ diff --git a/scripts/output/maskt/725_4.png b/scripts/output/maskt/725_4.png new file mode 100644 index 0000000..9bf3a44 Binary files /dev/null and b/scripts/output/maskt/725_4.png differ diff --git a/scripts/output/maskt/725_5.png b/scripts/output/maskt/725_5.png new file mode 100644 index 0000000..3d6f872 Binary files /dev/null and b/scripts/output/maskt/725_5.png differ diff --git a/scripts/output/maskt/725_6.png b/scripts/output/maskt/725_6.png new file mode 100644 index 0000000..c54adef Binary files /dev/null and b/scripts/output/maskt/725_6.png differ diff --git a/scripts/output/maskt/725_7.png b/scripts/output/maskt/725_7.png new file mode 100644 index 0000000..37e198f Binary files /dev/null and b/scripts/output/maskt/725_7.png differ diff --git a/scripts/output/maskt/725_8.png b/scripts/output/maskt/725_8.png new file mode 100644 index 0000000..b7a787d Binary files /dev/null and b/scripts/output/maskt/725_8.png differ diff --git a/scripts/output/maskt/725_9.png b/scripts/output/maskt/725_9.png new file mode 100644 index 0000000..854224f Binary files /dev/null and b/scripts/output/maskt/725_9.png differ diff --git a/scripts/pv/0.1/gangwa/image/1.jpg b/scripts/pv/0.1/gangwa/image/1.jpg new file mode 100644 index 0000000..a2d80c1 Binary files /dev/null and b/scripts/pv/0.1/gangwa/image/1.jpg differ diff --git a/scripts/pv/0.1/gangwa/image/2.jpg b/scripts/pv/0.1/gangwa/image/2.jpg new file mode 100644 index 0000000..4949295 Binary files /dev/null and b/scripts/pv/0.1/gangwa/image/2.jpg differ diff --git a/scripts/pv/0.1/gangwa/image/3.jpg b/scripts/pv/0.1/gangwa/image/3.jpg new file mode 100644 index 0000000..82dc892 Binary files /dev/null and b/scripts/pv/0.1/gangwa/image/3.jpg differ diff --git a/scripts/pv/0.1/gangwa/mask/1_masked.png b/scripts/pv/0.1/gangwa/mask/1_masked.png new file mode 100644 index 0000000..39df3e6 Binary files /dev/null and b/scripts/pv/0.1/gangwa/mask/1_masked.png differ diff --git a/scripts/pv/0.1/gangwa/mask/2_masked - 副本.png b/scripts/pv/0.1/gangwa/mask/2_masked - 副本.png new file mode 100644 index 0000000..b5d1e19 Binary files /dev/null and b/scripts/pv/0.1/gangwa/mask/2_masked - 副本.png differ diff --git a/scripts/pv/0.1/gangwa/mask/2_masked.png b/scripts/pv/0.1/gangwa/mask/2_masked.png new file mode 100644 index 0000000..b5d1e19 Binary files /dev/null and b/scripts/pv/0.1/gangwa/mask/2_masked.png differ diff --git a/scripts/pv/0.1/gangwa/mask/3_masked.png b/scripts/pv/0.1/gangwa/mask/3_masked.png new file mode 100644 index 0000000..83b3adb Binary files /dev/null and b/scripts/pv/0.1/gangwa/mask/3_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/image/1.jpg b/scripts/pv/0.1/hunningtu/image/1.jpg new file mode 100644 index 0000000..cd62388 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/1.jpg differ diff --git a/scripts/pv/0.1/hunningtu/image/2.jpg b/scripts/pv/0.1/hunningtu/image/2.jpg new file mode 100644 index 0000000..08d2ae2 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/2.jpg differ diff --git a/scripts/pv/0.1/hunningtu/image/3.jpg b/scripts/pv/0.1/hunningtu/image/3.jpg new file mode 100644 index 0000000..18147af Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/3.jpg differ diff --git a/scripts/pv/0.1/hunningtu/image/4.jpg b/scripts/pv/0.1/hunningtu/image/4.jpg new file mode 100644 index 0000000..3bb2755 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/4.jpg differ diff --git a/scripts/pv/0.1/hunningtu/image/5.jpg b/scripts/pv/0.1/hunningtu/image/5.jpg new file mode 100644 index 0000000..ad879e0 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/5.jpg differ diff --git a/scripts/pv/0.1/hunningtu/image/6.jpg b/scripts/pv/0.1/hunningtu/image/6.jpg new file mode 100644 index 0000000..1c506f2 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/image/6.jpg differ diff --git a/scripts/pv/0.1/hunningtu/mask/1_masked.png b/scripts/pv/0.1/hunningtu/mask/1_masked.png new file mode 100644 index 0000000..57b6853 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/1_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/mask/2_masked.png b/scripts/pv/0.1/hunningtu/mask/2_masked.png new file mode 100644 index 0000000..38cd1a8 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/2_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/mask/3_masked.png b/scripts/pv/0.1/hunningtu/mask/3_masked.png new file mode 100644 index 0000000..e0a1fc9 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/3_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/mask/4_masked.png b/scripts/pv/0.1/hunningtu/mask/4_masked.png new file mode 100644 index 0000000..36da5e6 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/4_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/mask/5_masked.png b/scripts/pv/0.1/hunningtu/mask/5_masked.png new file mode 100644 index 0000000..4ba1a14 Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/5_masked.png differ diff --git a/scripts/pv/0.1/hunningtu/mask/6_masked.png b/scripts/pv/0.1/hunningtu/mask/6_masked.png new file mode 100644 index 0000000..9531cfb Binary files /dev/null and b/scripts/pv/0.1/hunningtu/mask/6_masked.png differ diff --git a/scripts/pv/0.1/zhuanwa/image/1.jpg b/scripts/pv/0.1/zhuanwa/image/1.jpg new file mode 100644 index 0000000..968741a Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/image/1.jpg differ diff --git a/scripts/pv/0.1/zhuanwa/image/2.jpg b/scripts/pv/0.1/zhuanwa/image/2.jpg new file mode 100644 index 0000000..1275471 Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/image/2.jpg differ diff --git a/scripts/pv/0.1/zhuanwa/image/3.jpg b/scripts/pv/0.1/zhuanwa/image/3.jpg new file mode 100644 index 0000000..07c051f Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/image/3.jpg differ diff --git a/scripts/pv/0.1/zhuanwa/image/4.jpg b/scripts/pv/0.1/zhuanwa/image/4.jpg new file mode 100644 index 0000000..d922434 Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/image/4.jpg differ diff --git a/scripts/pv/0.1/zhuanwa/mask/1_masked.png b/scripts/pv/0.1/zhuanwa/mask/1_masked.png new file mode 100644 index 0000000..d6d034f Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/mask/1_masked.png differ diff --git a/scripts/pv/0.1/zhuanwa/mask/2_masked.png b/scripts/pv/0.1/zhuanwa/mask/2_masked.png new file mode 100644 index 0000000..12074a5 Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/mask/2_masked.png differ diff --git a/scripts/pv/0.1/zhuanwa/mask/3_masked.png b/scripts/pv/0.1/zhuanwa/mask/3_masked.png new file mode 100644 index 0000000..4044265 Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/mask/3_masked.png differ diff --git a/scripts/pv/0.1/zhuanwa/mask/4_masked.png b/scripts/pv/0.1/zhuanwa/mask/4_masked.png new file mode 100644 index 0000000..5c7d2ee Binary files /dev/null and b/scripts/pv/0.1/zhuanwa/mask/4_masked.png differ diff --git a/scripts/pv1/0.1/gangwa/image/1.jpg b/scripts/pv1/0.1/gangwa/image/1.jpg new file mode 100644 index 0000000..a2d80c1 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/image/1.jpg differ diff --git a/scripts/pv1/0.1/gangwa/image/2.jpg b/scripts/pv1/0.1/gangwa/image/2.jpg new file mode 100644 index 0000000..4949295 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/image/2.jpg differ diff --git a/scripts/pv1/0.1/gangwa/image/3.jpg b/scripts/pv1/0.1/gangwa/image/3.jpg new file mode 100644 index 0000000..82dc892 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/image/3.jpg differ diff --git a/scripts/pv1/0.1/gangwa/image/4.jpg b/scripts/pv1/0.1/gangwa/image/4.jpg new file mode 100644 index 0000000..2a39320 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/image/4.jpg differ diff --git a/scripts/pv1/0.1/gangwa/mask/1_masked.png b/scripts/pv1/0.1/gangwa/mask/1_masked.png new file mode 100644 index 0000000..a542967 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/mask/1_masked.png differ diff --git a/scripts/pv1/0.1/gangwa/mask/2_masked.png b/scripts/pv1/0.1/gangwa/mask/2_masked.png new file mode 100644 index 0000000..8508fbe Binary files /dev/null and b/scripts/pv1/0.1/gangwa/mask/2_masked.png differ diff --git a/scripts/pv1/0.1/gangwa/mask/3_masked.png b/scripts/pv1/0.1/gangwa/mask/3_masked.png new file mode 100644 index 0000000..434b066 Binary files /dev/null and b/scripts/pv1/0.1/gangwa/mask/3_masked.png differ diff --git a/scripts/pv1/0.1/gangwa/mask/4_masked.png b/scripts/pv1/0.1/gangwa/mask/4_masked.png new file mode 100644 index 0000000..f5e55fb Binary files /dev/null and b/scripts/pv1/0.1/gangwa/mask/4_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/image/1.jpg b/scripts/pv1/0.1/hunningtu/image/1.jpg new file mode 100644 index 0000000..cd62388 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/1.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/image/2.jpg b/scripts/pv1/0.1/hunningtu/image/2.jpg new file mode 100644 index 0000000..08d2ae2 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/2.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/image/3.jpg b/scripts/pv1/0.1/hunningtu/image/3.jpg new file mode 100644 index 0000000..18147af Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/3.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/image/4.jpg b/scripts/pv1/0.1/hunningtu/image/4.jpg new file mode 100644 index 0000000..3bb2755 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/4.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/image/5.jpg b/scripts/pv1/0.1/hunningtu/image/5.jpg new file mode 100644 index 0000000..ad879e0 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/5.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/image/6.jpg b/scripts/pv1/0.1/hunningtu/image/6.jpg new file mode 100644 index 0000000..1c506f2 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/image/6.jpg differ diff --git a/scripts/pv1/0.1/hunningtu/mask/1_masked.png b/scripts/pv1/0.1/hunningtu/mask/1_masked.png new file mode 100644 index 0000000..7476c7b Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/1_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/mask/2_masked.png b/scripts/pv1/0.1/hunningtu/mask/2_masked.png new file mode 100644 index 0000000..8b14b13 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/2_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/mask/3_masked.png b/scripts/pv1/0.1/hunningtu/mask/3_masked.png new file mode 100644 index 0000000..722c780 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/3_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/mask/4_masked.png b/scripts/pv1/0.1/hunningtu/mask/4_masked.png new file mode 100644 index 0000000..9567788 Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/4_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/mask/5_masked.png b/scripts/pv1/0.1/hunningtu/mask/5_masked.png new file mode 100644 index 0000000..8326a6d Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/5_masked.png differ diff --git a/scripts/pv1/0.1/hunningtu/mask/6_masked.png b/scripts/pv1/0.1/hunningtu/mask/6_masked.png new file mode 100644 index 0000000..5ff02bc Binary files /dev/null and b/scripts/pv1/0.1/hunningtu/mask/6_masked.png differ diff --git a/scripts/pv1/0.1/zhuanwa/image/1.jpg b/scripts/pv1/0.1/zhuanwa/image/1.jpg new file mode 100644 index 0000000..968741a Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/image/1.jpg differ diff --git a/scripts/pv1/0.1/zhuanwa/image/2.jpg b/scripts/pv1/0.1/zhuanwa/image/2.jpg new file mode 100644 index 0000000..1275471 Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/image/2.jpg differ diff --git a/scripts/pv1/0.1/zhuanwa/image/3.jpg b/scripts/pv1/0.1/zhuanwa/image/3.jpg new file mode 100644 index 0000000..07c051f Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/image/3.jpg differ diff --git a/scripts/pv1/0.1/zhuanwa/image/4.jpg b/scripts/pv1/0.1/zhuanwa/image/4.jpg new file mode 100644 index 0000000..d922434 Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/image/4.jpg differ diff --git a/scripts/pv1/0.1/zhuanwa/mask/1_masked.png b/scripts/pv1/0.1/zhuanwa/mask/1_masked.png new file mode 100644 index 0000000..00c7c2e Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/mask/1_masked.png differ diff --git a/scripts/pv1/0.1/zhuanwa/mask/2_masked.png b/scripts/pv1/0.1/zhuanwa/mask/2_masked.png new file mode 100644 index 0000000..8c06b7d Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/mask/2_masked.png differ diff --git a/scripts/pv1/0.1/zhuanwa/mask/3_masked.png b/scripts/pv1/0.1/zhuanwa/mask/3_masked.png new file mode 100644 index 0000000..ab752e6 Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/mask/3_masked.png differ diff --git a/scripts/pv1/0.1/zhuanwa/mask/4_masked.png b/scripts/pv1/0.1/zhuanwa/mask/4_masked.png new file mode 100644 index 0000000..ec5b0bd Binary files /dev/null and b/scripts/pv1/0.1/zhuanwa/mask/4_masked.png differ diff --git a/scripts/pv1/0.3/caodi/jpg/1.jpg b/scripts/pv1/0.3/caodi/jpg/1.jpg new file mode 100644 index 0000000..de7e0e2 Binary files /dev/null and b/scripts/pv1/0.3/caodi/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/caodi/jpg/2.jpg b/scripts/pv1/0.3/caodi/jpg/2.jpg new file mode 100644 index 0000000..cdfe273 Binary files /dev/null and b/scripts/pv1/0.3/caodi/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/caodi/mask/1447_masked.png b/scripts/pv1/0.3/caodi/mask/1447_masked.png new file mode 100644 index 0000000..187978a Binary files /dev/null and b/scripts/pv1/0.3/caodi/mask/1447_masked.png differ diff --git a/scripts/pv1/0.3/caodi/mask/1854_masked.png b/scripts/pv1/0.3/caodi/mask/1854_masked.png new file mode 100644 index 0000000..ec0a4ba Binary files /dev/null and b/scripts/pv1/0.3/caodi/mask/1854_masked.png differ diff --git a/scripts/pv1/0.3/caodi/mask/1_masked.png b/scripts/pv1/0.3/caodi/mask/1_masked.png new file mode 100644 index 0000000..c39e48a Binary files /dev/null and b/scripts/pv1/0.3/caodi/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/caodi/mask/2_masked.png b/scripts/pv1/0.3/caodi/mask/2_masked.png new file mode 100644 index 0000000..b83bcd6 Binary files /dev/null and b/scripts/pv1/0.3/caodi/mask/2_masked.png differ diff --git a/scripts/pv1/0.3/guanmudi/jpg/1.jpg b/scripts/pv1/0.3/guanmudi/jpg/1.jpg new file mode 100644 index 0000000..2c60edd Binary files /dev/null and b/scripts/pv1/0.3/guanmudi/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/guanmudi/jpg/2.jpg b/scripts/pv1/0.3/guanmudi/jpg/2.jpg new file mode 100644 index 0000000..36061e5 Binary files /dev/null and b/scripts/pv1/0.3/guanmudi/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/guanmudi/jpg/3.jpg b/scripts/pv1/0.3/guanmudi/jpg/3.jpg new file mode 100644 index 0000000..06c4d93 Binary files /dev/null and b/scripts/pv1/0.3/guanmudi/jpg/3.jpg differ diff --git a/scripts/pv1/0.3/guanmudi/mask/1_masked.png b/scripts/pv1/0.3/guanmudi/mask/1_masked.png new file mode 100644 index 0000000..bc3b6af Binary files /dev/null and b/scripts/pv1/0.3/guanmudi/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/guanmudi/mask/2_masked.png b/scripts/pv1/0.3/guanmudi/mask/2_masked.png new file mode 100644 index 0000000..e5d2d9e Binary files /dev/null and b/scripts/pv1/0.3/guanmudi/mask/2_masked.png differ diff --git a/scripts/pv1/0.3/nongtian/jpg/1.jpg b/scripts/pv1/0.3/nongtian/jpg/1.jpg new file mode 100644 index 0000000..38c9a08 Binary files /dev/null and b/scripts/pv1/0.3/nongtian/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/nongtian/jpg/2.jpg b/scripts/pv1/0.3/nongtian/jpg/2.jpg new file mode 100644 index 0000000..16849f4 Binary files /dev/null and b/scripts/pv1/0.3/nongtian/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/nongtian/jpg/3.jpg b/scripts/pv1/0.3/nongtian/jpg/3.jpg new file mode 100644 index 0000000..699a09e Binary files /dev/null and b/scripts/pv1/0.3/nongtian/jpg/3.jpg differ diff --git a/scripts/pv1/0.3/nongtian/mask/1_masked.png b/scripts/pv1/0.3/nongtian/mask/1_masked.png new file mode 100644 index 0000000..41c21f3 Binary files /dev/null and b/scripts/pv1/0.3/nongtian/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/nongtian/mask/2_masked.png b/scripts/pv1/0.3/nongtian/mask/2_masked.png new file mode 100644 index 0000000..da7985c Binary files /dev/null and b/scripts/pv1/0.3/nongtian/mask/2_masked.png differ diff --git a/scripts/pv1/0.3/shuimian/jpg/1.jpg b/scripts/pv1/0.3/shuimian/jpg/1.jpg new file mode 100644 index 0000000..f08d00e Binary files /dev/null and b/scripts/pv1/0.3/shuimian/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/shuimian/jpg/2.jpg b/scripts/pv1/0.3/shuimian/jpg/2.jpg new file mode 100644 index 0000000..23de944 Binary files /dev/null and b/scripts/pv1/0.3/shuimian/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/shuimian/jpg/3.jpg b/scripts/pv1/0.3/shuimian/jpg/3.jpg new file mode 100644 index 0000000..7e062b0 Binary files /dev/null and b/scripts/pv1/0.3/shuimian/jpg/3.jpg differ diff --git a/scripts/pv1/0.3/shuimian/mask/1_masked.png b/scripts/pv1/0.3/shuimian/mask/1_masked.png new file mode 100644 index 0000000..0c65c90 Binary files /dev/null and b/scripts/pv1/0.3/shuimian/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/shuimian/mask/3_masked.png b/scripts/pv1/0.3/shuimian/mask/3_masked.png new file mode 100644 index 0000000..f2c999b Binary files /dev/null and b/scripts/pv1/0.3/shuimian/mask/3_masked.png differ diff --git a/scripts/pv1/0.3/uding/jpg/1.jpg b/scripts/pv1/0.3/uding/jpg/1.jpg new file mode 100644 index 0000000..51c1a13 Binary files /dev/null and b/scripts/pv1/0.3/uding/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/uding/jpg/2.jpg b/scripts/pv1/0.3/uding/jpg/2.jpg new file mode 100644 index 0000000..cf0682d Binary files /dev/null and b/scripts/pv1/0.3/uding/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/uding/jpg/3.jpg b/scripts/pv1/0.3/uding/jpg/3.jpg new file mode 100644 index 0000000..30edbaf Binary files /dev/null and b/scripts/pv1/0.3/uding/jpg/3.jpg differ diff --git a/scripts/pv1/0.3/uding/mask/1_masked.png b/scripts/pv1/0.3/uding/mask/1_masked.png new file mode 100644 index 0000000..b95bd0a Binary files /dev/null and b/scripts/pv1/0.3/uding/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/uding/mask/2_masked.png b/scripts/pv1/0.3/uding/mask/2_masked.png new file mode 100644 index 0000000..854df9d Binary files /dev/null and b/scripts/pv1/0.3/uding/mask/2_masked.png differ diff --git a/scripts/pv1/0.3/uding/mask/3_masked.png b/scripts/pv1/0.3/uding/mask/3_masked.png new file mode 100644 index 0000000..0407ec9 Binary files /dev/null and b/scripts/pv1/0.3/uding/mask/3_masked.png differ diff --git a/scripts/pv1/0.3/yanjiandi/jpg/1.jpg b/scripts/pv1/0.3/yanjiandi/jpg/1.jpg new file mode 100644 index 0000000..45179b4 Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/jpg/1.jpg differ diff --git a/scripts/pv1/0.3/yanjiandi/jpg/2.jpg b/scripts/pv1/0.3/yanjiandi/jpg/2.jpg new file mode 100644 index 0000000..47dfbeb Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/jpg/2.jpg differ diff --git a/scripts/pv1/0.3/yanjiandi/jpg/3.jpg b/scripts/pv1/0.3/yanjiandi/jpg/3.jpg new file mode 100644 index 0000000..7081b49 Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/jpg/3.jpg differ diff --git a/scripts/pv1/0.3/yanjiandi/mask/1_masked.png b/scripts/pv1/0.3/yanjiandi/mask/1_masked.png new file mode 100644 index 0000000..dd95867 Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/mask/1_masked.png differ diff --git a/scripts/pv1/0.3/yanjiandi/mask/2_masked.png b/scripts/pv1/0.3/yanjiandi/mask/2_masked.png new file mode 100644 index 0000000..01eab19 Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/mask/2_masked.png differ diff --git a/scripts/pv1/0.3/yanjiandi/mask/3_masked.png b/scripts/pv1/0.3/yanjiandi/mask/3_masked.png new file mode 100644 index 0000000..85ac387 Binary files /dev/null and b/scripts/pv1/0.3/yanjiandi/mask/3_masked.png differ diff --git a/scripts/pv1/0.6/caodi/jpg/1447.jpg b/scripts/pv1/0.6/caodi/jpg/1447.jpg new file mode 100644 index 0000000..31f980e Binary files /dev/null and b/scripts/pv1/0.6/caodi/jpg/1447.jpg differ diff --git a/scripts/pv1/0.6/caodi/jpg/1854.jpg b/scripts/pv1/0.6/caodi/jpg/1854.jpg new file mode 100644 index 0000000..dc8ecad Binary files /dev/null and b/scripts/pv1/0.6/caodi/jpg/1854.jpg differ diff --git a/scripts/pv1/0.6/gengdi/jpg/2483.jpg b/scripts/pv1/0.6/gengdi/jpg/2483.jpg new file mode 100644 index 0000000..5efcac5 Binary files /dev/null and b/scripts/pv1/0.6/gengdi/jpg/2483.jpg differ diff --git a/scripts/pv1/0.6/gengdi/jpg/33.jpg b/scripts/pv1/0.6/gengdi/jpg/33.jpg new file mode 100644 index 0000000..dffeb60 Binary files /dev/null and b/scripts/pv1/0.6/gengdi/jpg/33.jpg differ diff --git a/scripts/pv1/0.6/gengdi/mask/2483_masked.png b/scripts/pv1/0.6/gengdi/mask/2483_masked.png new file mode 100644 index 0000000..89e807b Binary files /dev/null and b/scripts/pv1/0.6/gengdi/mask/2483_masked.png differ diff --git a/scripts/pv1/0.6/gengdi/mask/33_masked.png b/scripts/pv1/0.6/gengdi/mask/33_masked.png new file mode 100644 index 0000000..8b2819d Binary files /dev/null and b/scripts/pv1/0.6/gengdi/mask/33_masked.png differ diff --git a/scripts/pv1/0.6/humian/jpg/3499.jpg b/scripts/pv1/0.6/humian/jpg/3499.jpg new file mode 100644 index 0000000..872b4f2 Binary files /dev/null and b/scripts/pv1/0.6/humian/jpg/3499.jpg differ diff --git a/scripts/pv1/0.6/humian/jpg/3504.jpg b/scripts/pv1/0.6/humian/jpg/3504.jpg new file mode 100644 index 0000000..bc00168 Binary files /dev/null and b/scripts/pv1/0.6/humian/jpg/3504.jpg differ diff --git a/scripts/pv1/0.6/humian/mask/3499_masked.png b/scripts/pv1/0.6/humian/mask/3499_masked.png new file mode 100644 index 0000000..932ebb2 Binary files /dev/null and b/scripts/pv1/0.6/humian/mask/3499_masked.png differ diff --git a/scripts/pv1/0.6/humian/mask/3504_masked.png b/scripts/pv1/0.6/humian/mask/3504_masked.png new file mode 100644 index 0000000..1cb441e Binary files /dev/null and b/scripts/pv1/0.6/humian/mask/3504_masked.png differ diff --git a/scripts/pv1/0.6/shamo/jpg/3399.jpg b/scripts/pv1/0.6/shamo/jpg/3399.jpg new file mode 100644 index 0000000..d96586c Binary files /dev/null and b/scripts/pv1/0.6/shamo/jpg/3399.jpg differ diff --git a/scripts/pv1/0.6/shamo/jpg/3400.jpg b/scripts/pv1/0.6/shamo/jpg/3400.jpg new file mode 100644 index 0000000..7c17493 Binary files /dev/null and b/scripts/pv1/0.6/shamo/jpg/3400.jpg differ diff --git a/scripts/pv1/0.6/shamo/mask/3399_masked.png b/scripts/pv1/0.6/shamo/mask/3399_masked.png new file mode 100644 index 0000000..b603c20 Binary files /dev/null and b/scripts/pv1/0.6/shamo/mask/3399_masked.png differ diff --git a/scripts/pv1/0.6/shamo/mask/3400_masked.png b/scripts/pv1/0.6/shamo/mask/3400_masked.png new file mode 100644 index 0000000..b21b771 Binary files /dev/null and b/scripts/pv1/0.6/shamo/mask/3400_masked.png differ diff --git a/scripts/pv1/0.6/shandi/jpg/6.jpg b/scripts/pv1/0.6/shandi/jpg/6.jpg new file mode 100644 index 0000000..61ec3b1 Binary files /dev/null and b/scripts/pv1/0.6/shandi/jpg/6.jpg differ diff --git a/scripts/pv1/0.6/shandi/jpg/7.jpg b/scripts/pv1/0.6/shandi/jpg/7.jpg new file mode 100644 index 0000000..67447fd Binary files /dev/null and b/scripts/pv1/0.6/shandi/jpg/7.jpg differ diff --git a/scripts/pv1/0.6/shandi/mask/6_masked.png b/scripts/pv1/0.6/shandi/mask/6_masked.png new file mode 100644 index 0000000..e271950 Binary files /dev/null and b/scripts/pv1/0.6/shandi/mask/6_masked.png differ diff --git a/scripts/pv1/0.6/shandi/mask/7_masked.png b/scripts/pv1/0.6/shandi/mask/7_masked.png new file mode 100644 index 0000000..2b3a272 Binary files /dev/null and b/scripts/pv1/0.6/shandi/mask/7_masked.png differ diff --git a/scripts/pv1/0.6/wuding/jpg/279.jpg b/scripts/pv1/0.6/wuding/jpg/279.jpg new file mode 100644 index 0000000..59094e2 Binary files /dev/null and b/scripts/pv1/0.6/wuding/jpg/279.jpg differ diff --git a/scripts/pv1/0.6/wuding/jpg/458.jpg b/scripts/pv1/0.6/wuding/jpg/458.jpg new file mode 100644 index 0000000..7993197 Binary files /dev/null and b/scripts/pv1/0.6/wuding/jpg/458.jpg differ diff --git a/scripts/pv1/0.6/wuding/mask/279_masked.png b/scripts/pv1/0.6/wuding/mask/279_masked.png new file mode 100644 index 0000000..20ac4c2 Binary files /dev/null and b/scripts/pv1/0.6/wuding/mask/279_masked.png differ diff --git a/scripts/pv1/0.6/wuding/mask/458_masked.png b/scripts/pv1/0.6/wuding/mask/458_masked.png new file mode 100644 index 0000000..ba7574f Binary files /dev/null and b/scripts/pv1/0.6/wuding/mask/458_masked.png differ diff --git a/scripts/pv1/0.8/dimian/jpg/1.jpg b/scripts/pv1/0.8/dimian/jpg/1.jpg new file mode 100644 index 0000000..b50fdd7 Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/1.jpg differ diff --git a/scripts/pv1/0.8/dimian/jpg/2.jpg b/scripts/pv1/0.8/dimian/jpg/2.jpg new file mode 100644 index 0000000..e9a8d38 Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/2.jpg differ diff --git a/scripts/pv1/0.8/dimian/jpg/3.jpg b/scripts/pv1/0.8/dimian/jpg/3.jpg new file mode 100644 index 0000000..fc47b5c Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/3.jpg differ diff --git a/scripts/pv1/0.8/dimian/jpg/4.jpg b/scripts/pv1/0.8/dimian/jpg/4.jpg new file mode 100644 index 0000000..bc599d6 Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/4.jpg differ diff --git a/scripts/pv1/0.8/dimian/jpg/5.jpg b/scripts/pv1/0.8/dimian/jpg/5.jpg new file mode 100644 index 0000000..adbf8c8 Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/5.jpg differ diff --git a/scripts/pv1/0.8/dimian/jpg/6.jpg b/scripts/pv1/0.8/dimian/jpg/6.jpg new file mode 100644 index 0000000..ef41cb1 Binary files /dev/null and b/scripts/pv1/0.8/dimian/jpg/6.jpg differ diff --git a/scripts/pv1/0.8/dimian/mask/1_masked.png b/scripts/pv1/0.8/dimian/mask/1_masked.png new file mode 100644 index 0000000..b702258 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/1_masked.png differ diff --git a/scripts/pv1/0.8/dimian/mask/2_masked.png b/scripts/pv1/0.8/dimian/mask/2_masked.png new file mode 100644 index 0000000..4f20a59 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/2_masked.png differ diff --git a/scripts/pv1/0.8/dimian/mask/3_masked.png b/scripts/pv1/0.8/dimian/mask/3_masked.png new file mode 100644 index 0000000..21ec285 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/3_masked.png differ diff --git a/scripts/pv1/0.8/dimian/mask/4_masked.png b/scripts/pv1/0.8/dimian/mask/4_masked.png new file mode 100644 index 0000000..e690246 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/4_masked.png differ diff --git a/scripts/pv1/0.8/dimian/mask/5_masked.png b/scripts/pv1/0.8/dimian/mask/5_masked.png new file mode 100644 index 0000000..e4b8d50 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/5_masked.png differ diff --git a/scripts/pv1/0.8/dimian/mask/6_masked.png b/scripts/pv1/0.8/dimian/mask/6_masked.png new file mode 100644 index 0000000..cab77a3 Binary files /dev/null and b/scripts/pv1/0.8/dimian/mask/6_masked.png differ diff --git a/scripts/pv1/0.8/wuding/jpg/1.jpg b/scripts/pv1/0.8/wuding/jpg/1.jpg new file mode 100644 index 0000000..223aee1 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/1.jpg differ diff --git a/scripts/pv1/0.8/wuding/jpg/2.jpg b/scripts/pv1/0.8/wuding/jpg/2.jpg new file mode 100644 index 0000000..ff89f98 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/2.jpg differ diff --git a/scripts/pv1/0.8/wuding/jpg/3.jpg b/scripts/pv1/0.8/wuding/jpg/3.jpg new file mode 100644 index 0000000..693a368 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/3.jpg differ diff --git a/scripts/pv1/0.8/wuding/jpg/4.jpg b/scripts/pv1/0.8/wuding/jpg/4.jpg new file mode 100644 index 0000000..a5230c1 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/4.jpg differ diff --git a/scripts/pv1/0.8/wuding/jpg/5..jpg b/scripts/pv1/0.8/wuding/jpg/5..jpg new file mode 100644 index 0000000..1f09dd2 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/5..jpg differ diff --git a/scripts/pv1/0.8/wuding/jpg/6.jpg b/scripts/pv1/0.8/wuding/jpg/6.jpg new file mode 100644 index 0000000..daa7fd8 Binary files /dev/null and b/scripts/pv1/0.8/wuding/jpg/6.jpg differ diff --git a/scripts/pv1/0.8/wuding/mask/1_masked.png b/scripts/pv1/0.8/wuding/mask/1_masked.png new file mode 100644 index 0000000..897fb23 Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/1_masked.png differ diff --git a/scripts/pv1/0.8/wuding/mask/2_masked.png b/scripts/pv1/0.8/wuding/mask/2_masked.png new file mode 100644 index 0000000..9f7bd4b Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/2_masked.png differ diff --git a/scripts/pv1/0.8/wuding/mask/3_masked.png b/scripts/pv1/0.8/wuding/mask/3_masked.png new file mode 100644 index 0000000..9c08b01 Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/3_masked.png differ diff --git a/scripts/pv1/0.8/wuding/mask/4_masked.png b/scripts/pv1/0.8/wuding/mask/4_masked.png new file mode 100644 index 0000000..cd3caf0 Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/4_masked.png differ diff --git a/scripts/pv1/0.8/wuding/mask/5._masked.png b/scripts/pv1/0.8/wuding/mask/5._masked.png new file mode 100644 index 0000000..19fa95a Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/5._masked.png differ diff --git a/scripts/pv1/0.8/wuding/mask/6_masked.png b/scripts/pv1/0.8/wuding/mask/6_masked.png new file mode 100644 index 0000000..344fe02 Binary files /dev/null and b/scripts/pv1/0.8/wuding/mask/6_masked.png differ diff --git a/segment.py b/segment.py new file mode 100644 index 0000000..71f6e0b --- /dev/null +++ b/segment.py @@ -0,0 +1,59 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = 'scripts/input/images' +output_dir = 'scripts/output/mask' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def button_image_open(): + # 实现打开图片功能 + pass + +def button_image_init(): + # 实现重置选择功能 + pass + +def button_image_shang(): + # 实现切换至上一张图像功能 + pass + +def button_image_xia(): + # 实现切换至下一张图像功能 + pass + +def button_image_exit(): + # 实现退出程序功能 + pass + +def button_image_Transparency(): + # 实现调整透明度功能 + pass + +def button_image_copymask(): + # 实现复制掩码功能 + pass + +def button_image_saveimg(): + # 实现保存图片功能 + pass + +def button_image_Fusionimg(): + # 实现融合背景图片功能 + pass diff --git a/segment_anything/__init__.py b/segment_anything/__init__.py new file mode 100644 index 0000000..34383d8 --- /dev/null +++ b/segment_anything/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/segment_anything/__pycache__/__init__.cpython-39.pyc b/segment_anything/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..c3507b2 Binary files /dev/null and b/segment_anything/__pycache__/__init__.cpython-39.pyc differ diff --git a/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc b/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc new file mode 100644 index 0000000..bf2c0bc Binary files /dev/null and b/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc differ diff --git a/segment_anything/__pycache__/build_sam.cpython-39.pyc b/segment_anything/__pycache__/build_sam.cpython-39.pyc new file mode 100644 index 0000000..88b4910 Binary files /dev/null and b/segment_anything/__pycache__/build_sam.cpython-39.pyc differ diff --git a/segment_anything/__pycache__/predictor.cpython-39.pyc b/segment_anything/__pycache__/predictor.cpython-39.pyc new file mode 100644 index 0000000..5dedb32 Binary files /dev/null and b/segment_anything/__pycache__/predictor.cpython-39.pyc differ diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000..d5a8c96 --- /dev/null +++ b/segment_anything/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py new file mode 100644 index 0000000..37cd245 --- /dev/null +++ b/segment_anything/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/segment_anything/modeling/__init__.py b/segment_anything/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc b/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..bcd7ebd Binary files /dev/null and b/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/common.cpython-39.pyc b/segment_anything/modeling/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000..231e51c Binary files /dev/null and b/segment_anything/modeling/__pycache__/common.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc b/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc new file mode 100644 index 0000000..384362a Binary files /dev/null and b/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc b/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc new file mode 100644 index 0000000..2b30e9b Binary files /dev/null and b/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc b/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc new file mode 100644 index 0000000..d813e32 Binary files /dev/null and b/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/sam.cpython-39.pyc b/segment_anything/modeling/__pycache__/sam.cpython-39.pyc new file mode 100644 index 0000000..25c726e Binary files /dev/null and b/segment_anything/modeling/__pycache__/sam.cpython-39.pyc differ diff --git a/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc b/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000..6b30407 Binary files /dev/null and b/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc differ diff --git a/segment_anything/modeling/common.py b/segment_anything/modeling/common.py new file mode 100644 index 0000000..10e5e15 --- /dev/null +++ b/segment_anything/modeling/common.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000..a8ae979 --- /dev/null +++ b/segment_anything/modeling/image_encoder.py @@ -0,0 +1,426 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +# 这个代码定义了 SAM 的图像编码器 ImageEncoderViT。它包含以下主要部分: +# 1. patch_embed: 这是 ViT 的 patch embedding 层,用于将输入图像划分为 patch,并获得 patch 的 embedding。 +# 2. pos_embed: 这是 ViT的绝对位置 embedding,用于为每个patch提供位置信息。 +# 3. blocks: 这是 ViT 的 transformer encoder 块的列表,每个块包含多头自注意力层和前馈神经网络。 +# 4. neck: 这是图像编码器的“颈部”,包含几个卷积层和 LayerNorm 层,用于从 transformer encoder 块的输出中提取特征。 +# 5. forward(): 这是图像编码器的前向传播过程。首先通过 patch_embed 层获得 patch embedding, 然后加上 pos_embed。 +# 接着,patch embedding通过transformer encoder块。最后, neck 层从 transformer encoder 块的输出中提取特征。 +# 所以,这个 ImageEncoderViT 类定义了 SAM 的图像编码器,它基于 ViT,包含 patch embedding、位置 embedding、 +# transformer encoder块以及 neck, 可以从输入图像中提取特征。 +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x +# 这个 Block 类实现了 transformer block, 可以选择使用全局注意力或局部窗口注意力,同时包含残差连接。它包含: +# __init__方法: +# 1. 输入参数: +# - dim: 输入通道数 +# - num_heads: 注意力头数 +# - mlp_ratio: mlp 隐藏层与输入 embedding 维度的比例 +# - qkv_bias: 是否为 query、key、value 添加偏置 +# - norm_layer: 归一化层 +# - act_layer: 激活层 +# - use_rel_pos: 是否使用相对位置 embedding +# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为 0 +# - window_size: 窗口注意力的窗口大小,如果为 0 则使用全局注意力 +# - input_size: 计算相对位置 embedding 大小所需的输入分辨率 +# 2. 实例化第 1 次和第 2 次归一化层 norm1 和 norm2。 +# 3. 实例化 Attention 层和 MLPBlock 层。Attention 层的输入大小根据是否使用窗口注意力进行了调整。 +# 4. 记录窗口注意力的窗口大小 window_size。 +# forward方法: +# 1. 提取 shortcut 并对 x 进行第 1 次归一化。 +# 2. 如果使用窗口注意力, 则调用 window_partition 对 x 进行窗口划分。 +# 3. 将 x 输入 Attention 层。 +# 4. 如果使用窗口注意力,则调用 window_unpartition 对 x 进行窗口反划分。 +# 5. x = shortcut + x,实现第 1 次残差连接。 +# 6. x = x + mlp(norm2(x)),实现第 2 次残差连接和 MLPBlock。 +# 7. 返回最终的 x。 +# 所以,这个 Block 类实现了带有可选的窗口注意力和双残差连接的transformer block。 +# 窗口注意力可以更好地建模局部结构,双残差连接可以提高梯度流动,都是transformer结构的重要改进。 +# 这个 Block 类实现了 transformer 的关键组成部分,同时提供了窗口注意力和残差连接等重要变体,可以显著提高其表现力和泛化能力。 + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + +# 这个Attention类实现了多头注意力机制,可以加入相对位置 embedding。它包含: +# __init__方法: +# 1. 输入参数: +# - dim: 输入通道数 +# - num_heads: 注意力头数 +# - qkv_bias: 是否为查询、键、值添加偏置 +# - use_rel_pos: 是否使用相对位置 embedding +# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为0 +# - input_size: 计算相对位置 embedding 大小所需的输入分辨率 +# 2. 计算每个注意力头的维度 head_dim。 +# 3. 实例化 self.qkv和 输出投影 self.proj。 +# 4. 如果使用相对位置 embedding, 则初始化 rel_pos_h 和 rel_pos_w。 +# forward方法: +# 1. 从输入 x 中提取批次大小 B、高度 H、宽度 W 和通道数 C。 +# 2. 计算 qkv,形状为 (3, B, nHead, H * W, C), 包含 query、key 和 value。 +# 3. 提取 q、 k 和 v, 形状为 (B * nHead, H * W, C)。 +# 4. 计算注意力图 attn,形状为 (B * nHead, H * W, H * W)。 +# 5. 如果使用相对位置 embedding, 则调用 add_decomposed_rel_pos 函数将其加入 attn。 +# 6. 对 attn 进行 softmax 归一化。 +# 7. 计算输出 x , (attn @ v), 形状为 (B, nHead, H, W, C), 然后合并注意力头, 形状为(B, H, W, C)。 +# 8. 对 x 进行投影, 返回最终的输出。 +# 所以,这个 Attention 类实现了带有相对位置 embedding 的多头注意力机制。 +# 它可以高效地建模图像和视频等二维结构数据,是 transformer 在这些领域得到广泛应用的关键。 +# 这个 Attention 类提供了相对位置 embedding 和多头注意力机制的实现, +# 是理解 transformer 在图像和视频建模中的重要组成部分。 +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x +# 这个 window_partition 函数的作用是将输入张量划分为非重叠的窗口。它包含: +# 1. 输入参数: + # - x: 输入的张量,形状为 [B, H, W, C] + # - window_size: 窗口大小 +# 2. 首先计算输入需要 padding 的高度和宽度,将x进行padding。 +# 3. 然后将 x 的形状变化为 [B, Hp//window_size, window_size, Wp//window_size, window_size, C], +# 表示将图像划分为 Hp//window_size * Wp//window_size 个 window_size * window_size 的 patch。 +# 4. 最后,通过 permute 和 view 操作,得到 windows 的形状为 [B * num_windows, window_size, window_size, C], +# 表示将所有 patch 打平, num_windows 是 patch 的总数 +# 5. 返回windows和原来的高度和宽度(包含padding)Hp和Wp。 +# 所以,这个 window_partition 函数的作用是,将输入的图像划分为 window_size * window_size 的 patch, +# 并将所有的 patch 打平, 输出可以输入到 transformer encoder 中的 token 序列。 +# 这个函数实现了将二维图像转化为一维 token 序列的过程,是 transformer 用于处理图像的一个关键步骤。 +# 通过这个函数,图像可以被 transformer encoder 所处理,就像处理文本序列一样。 + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + +# 这个 window_unpartition 函数的作用是将 window_partition 函数的输出进行反划分, 恢复成原始的图像形状。它包含: +# 1. 输入参数: + # - windows: window_partition的输出,形状为 [B * num_windows, window_size, window_size, C] + # - window_size: 窗口大小 + # - pad_hw: padding后的高度和宽度 (Hp, Wp) + # - hw: padding前的原始高度和宽度 (H, W) +# 2. 首先根据窗口大小和 padding 后的 hw 计算原始的 batch_size B。 +# 3. 然后将 windows 的形状变回 [B, Hp//window_size, Wp//window_size, window_size, window_size, C], 表示每个patch的位置。 +# 4. 接着通过permute和view操作,得到x的形状为 [B, Hp, Wp, C], 恢复成图像的形状。 +# 5. 最后,如果进行了padding,则截取x到原始的高度H和宽度W。 +# 6. 返回恢复后的图像x。 +# 所以,这个 window_unpartition 函数的作用是将通过 window_partition 函数得到的 patch 序列恢复成原始的图像。 +# 它实现了从一维 patch token 序列到二维图像的反过程。 +# 这个函数与 window_partition 函数相反,使得 transformer 能够最终从 patch token 序列恢复成图像,完成对图像的建模。 +# 总的来说,这个 window_unpartition 函数实现了从 patch token 序列恢复成原始图像的过程,与 window_partition 函数相对应, +# 是使得 transformer 可以处理图像的另一个关键步骤 +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + +# 这个 get_rel_pos 函数的作用是根据 query 和 key 的相对位置获取相对位置 embedding。它包含: +# 1. 输入参数: +# - q_size: query 的大小 +# - k_size: key 的大小 +# - rel_pos: 相对位置 embedding, 形状为[L, C] +# 2. 首先计算最大的相对距离 max_rel_dist, 它等于 query 和 key 大小的 2 倍减 1。 +# 3. 如果相对位置 embedding 的长度小于 max_rel_dist, 则通过线性插值将其调整到 max_rel_dist 的长度。 +# 4. 如果 q_size 和 k_size 不同, 则将 q_size 和 k_size 的坐标按比例缩放,使它们之间的相对距离保持不变。 +# 5. 根据调整后的 q_size 和 k_size 坐标计算相对坐标 relative_coords。 +# 6. 根据 relative_coords 从 rel_pos_resized 中提取相对位置 embedding。 +# 7. 返回提取出的相对位置 embedding。 +# 所以,这个 get_rel_pos 函数的主要作用是,当 query 和 key 的大小不同时,根据它们的相对位置关系提取相应的相对位置 embedding。 +# 它实现了相对位置 embedding 的可变长度和可缩放性。 +# 这个函数使得相对位置 embedding 可以用于 query 和 key 大小不同的 attention 中,是相对位置表示的一个关键步骤。 +# 总的来说,这个 get_rel_pos 函数实现了根据 query 和 key 的相对位置关系提取相应相对位置 embedding 的过程。 +# 它提供了相对位置 embedding 的可变长度和可缩放性,使其可以支持不同的 query 和 key 大小,从而应用到更加灵活的 attention 机制中。 +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + +# 这个 add_decomposed_rel_pos 函数的作用是根据 query q 和 key k 的空间尺寸, 添加分解的相对位置 embedding 到注意力图 attn 中。它包含: +# 1. 输入参数: +# - attn: 注意力图,形状为 [B, q_h * q_w, k_h * k_w] +# - q: 查询 q,形状为 [B, q_h * q_w, C] +# - rel_pos_h: 高度轴的相对位置 embedding, 形状为[Lh, C] +# - rel_pos_w: 宽度轴的相对位置 embedding, 形状为[Lw, C] +# - q_size: 查询 q的空间尺寸 (q_h, q_w) +# - k_size: 键 k的空间尺寸 (k_h, k_w) +# 2. 从 q_size 和 k_size 中提取高度 q_h、宽度 q_w 以及高度 k_h、宽度 k_w。 +# 3. 调用 get_rel_pos 函数获取高度轴 Rh 和宽度轴 Rw 的相对位置 embedding。 +# 4. 重塑 q 为 [B, q_h, q_w, C]。 +# 5. 计算高度轴 rel_h 和宽度轴 rel_w 的相对位置图, 形状为 [B, q_h, q_w, k_h] 和 [B, q_h, q_w, k_w]。 +# 6. 将 attn 的形状变为 [B, q_h, q_w, k_h, k_w], 并加上 rel_h 和 rel_w。 +# 7. 将 attn 的形状变回 [B, q_h * q_w, k_h * k_w]。 +# 8. 返回加了相对位置 embedding 的 attn。 +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + +# 这个 PatchEmbed 类定义了 ViT 的 patch embedding 层。它包含: +# 1. __init__: 初始化,设置卷积层的 kernel size、stride、padding以 及输入通道数和 embedding 维度。 +# 2. proj: 这是一个卷积层,用于将输入图像划分为 patch, 并获得每个 patch 的 embedding。 +# 3. forward: 前向传播过程。首先通过 proj 卷积层获得 patch embedding ,然后将维度从 [B, C, H, W] 转置成 [B, H, W, C]。 + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000..fbeb0f6 --- /dev/null +++ b/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + # __init__方法: + # 1. 输入参数: + # - transformer_dim: transformer 的通道维度 + # - transformer: 使用的 transformer + # - num_multimask_outputs: 在消除掩码歧义时预测的掩码数量。 + # - activation: 上采样掩码时使用的激活函数类型 + # - iou_head_depth: 用于预测掩码质量的 MLP 的深度 + # - iou_head_hidden_dim: 用于预测掩码质量的 MLP 的隐藏维度 + # 2. 记录 transformer_dim 和 transformer。 + # 3. 记录 num_multimask_outputs。 + # 4. 嵌入 iou_token 和 mask_tokens。 + # 5. 定义 output_upscaling 为上采样器,用于上采样 transformer 的输出以得到掩码。 + # 6. 定义 output_hypernetworks_mlps 为 MLP 列表,个数为 num_mask_tokens, 用于从 transformer 的输出生成掩码通道。 + # 7. 定义 iou_prediction_head 为 MLP,用于从 transformer 的输出预测掩码的 IOU。 + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + # 这个 forward 方法的作用是根据图像和 prompt 的 embedding 预测掩码。它包含: + # 1. 输入参数: + # - image_embeddings: 图像编码器的输出 + # - image_pe: 与 image_embeddings 形状相同的位置编码 + # - sparse_prompt_embeddings: 点和框的 embedding + # - dense_prompt_embeddings: 掩码输入的 embedding + # - multimask_output: 是否返回多个掩码或单个掩码 + # 2. 调用 predict_masks 根据图像和 prompt 的 embedding 预测掩码 masks 和掩码质量 iou_pred。 + # 3. 如果 multimask_output 为 True,则选择 masks 的第 1 个维度后的全部切片。否则选择第一个切片。 + # 4. 相应地选择 iou_pred 的切片。 + # 5. 准备输出,返回 masks 和 iou_pred。 + # 所以,这个 forward 方法实现了根据图像和 prompt 的 embedding 预测掩码的功能。 + # 它可以根据输入的 prompt 学习掩码生成的高度非线性映射,为 prompt 驱动生成模型提供掩码预测的关键能力。 + # 这个 forward 方法提供了根据 prompt 预测掩码的具体实现。它发挥了 MaskDecoder 类的强大功能, + # 可以解码出复杂的定制化掩码,为实现高质量的 prompt 驱动生成模型提供强有力的支持。 + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + # 这个 predict_masks 方法的作用是预测掩码。它包含: + # 1. 输入参数: + # - image_embeddings: 图像编码器的输出 + # - image_pe: 与 image_embeddings 形状相同的位置编码 + # - sparse_prompt_embeddings: 点和框的 embedding + # - dense_prompt_embeddings: 掩码输入的 embedding + # 2. 拼接 iou_token 和 mask_tokens 作为输出 tokens, 扩展至 batch 大小, 与 sparse_prompt_embeddings 拼接作为 tokens。 + # 3. 通过 torch.repeat_interleave 扩展 src 和 pos_src 至与 tokens 相同的 batch 大小。 + # 4. 将 src 和 pos_src 以及 tokens 输入 transformer, 获得 hs 和 src。 + # 5. 获得 iou_token_out 和 mask_tokens_out 作为 transformer 的输出。 + # 6. 上采样 src 得到 upscaled_embedding。 + # 7. 对 mask_tokens_out 中的每个 token, 使用对应 MLP 得到 hyper_in_list 中的 tensor。 + # 8. 使用 torch.stack 将 hyper_in_list 拼接为 hyper_in。 + # 9. 计算 masks=(hyper_in @ upscaled_embedding.view(b, c, h * w)), 形状为 (b, num_mask_tokens, h, w)。 + # 10. 使用 iou_prediction_head 从 iou_token_out 预测 iou_pred。 + # 11. 返回 masks 和 iou_pred。 + # 所以,这个 predict_masks 方法实现了根据prompt预测掩码的功能。 + # 它发挥 transformer 和上采样器的功能,可以从 prompt 学习生成模型的参数 + # 这个 predict_masks 方法提供了根据 prompt 预测掩码的具体实现。 + # 它利用 MaskDecoder 的强大功能,可以解码出复杂的定制化掩码,为实现高质量的 prompt 驱动生成模型提供关键支持。 + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +class MLP(nn.Module): + # __init__方法: + # 1. 输入参数: + # - input_dim: 输入维度 + # - hidden_dim: 隐藏层维度 + # - output_dim: 输出维度 + # - num_layers: 隐藏层数 + # - sigmoid_output: 是否使用 sigmoid 激活函数 + # 2. 记录 num_layers 和 h 为 num_layers-1 个隐藏层维度。 + # 3. 实例化 nn.ModuleList 由 nn.Linear 组成的列表,用于实现 MLP 的线性变换。 + # 4. 记录 sigmoid_output 以决定是否使用 sigmoid 激活函数。 + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + # forward 方法: + # 1. 对输入 x 重复 num_layers 次线性变换和激活。 + # 2. 最后一层只使用线性变换,不使用激活函数。 + # 3. 如果 sigmoid_output 为 True, 使用 sigmoid 激活函数。 + # 4. 返回 MLP 的输出。 + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/segment_anything/modeling/prompt_encoder.py b/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000..7370a5b --- /dev/null +++ b/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + +# 这个 PromptEncoder 类实现了 prompt 的编码,为 mask解码器 提供prompt输入。它包含: +# __init__ 方法: +# 1. 输入参数: +# - embed_dim: prompt 的 embedding 维度 +# - image_embedding_size: 图像 embedding 的空间大小,表示为(H, W) +# - input_image_size: 输入到图像编码器的填充后的图像尺寸,表示为 (H, W)。 +# - mask_in_chans: 用于编码输入掩码的隐藏通道数 +# - activation: 用于编码输入掩码的激活函数 +# 2. 记录 embed_dim、image_size 和 image_embedding_size。 +# 3. 实例化 PositionEmbeddingRandom 作为位置 embedding 层 pe_layer。 +# 4. 实例化 4个 Embedding 层作为点 prompt 的 embedding,以及 not_a_point_embed 用于非点 prompt。 +# 5. 计算掩码输入大小 mask_input_size 为 (4 * image_embedding_size[0], 4 * image_embedding_size[1])。 +# 6. 实例化 mask_downscaling 为多个 Conv2d 和 LayerNorm2d 层,用于下采样和编码输入掩码。 +# 7. 实例化 no_mask_embed 用于无掩码 prompt 的 embedding。 +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + # 这个 get_dense_pe 方法的作用是返回用于对点 prompt 进行编码的密集位置编码。它包含: + # 1. 调用 pe_layer(image_embedding_size) 得到形状为 (embed_dim)x(embedding_h)x(embedding_w) 的位置编码, + # image_embedding_size 是图像 embedding 的空间大小。 + # 2. 使用 unsqueeze(0) 增加 batch 维度,得到形状为 1x(embed_dim)x(embedding_h)x(embedding_w) 的位置编码。 + # 3. 返回该位置编码用于对点 prompt 进行编码。 + # 所以,这个 get_dense_pe 方法的作用就是返回一个密集的位置编码,该位置编码具有和图像 embedding 相同的空间尺寸, + # 用于对点 prompt 进行位置编码,从而得到丰富的 prompt 表达。 + def get_dense_pe(self) -> torch.Tensor: + + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + # 这个 _embed_points 方法的作用是对点 prompt 进行 embedding。它包含: + # 1. 将 points 中的坐标增加 0.5,将其移至像素中心。 + # 2. 如果 pad 为 True,则会在 points 上追加一个坐标为 [0,0] 和 label 为 -1 的额外点, 并相应地扩充labels。这是用于当未提供 bbox 时的补齐。 + # 3. 调用 pe_layer.forward_with_coords 对 points 进行位置编码,得到 point_embedding。 + # 4. 将 point_embedding 中 label 为 -1 的点 embedding 设置为0。 + # 5. 将 point_embedding 中label为 -1 的点 embedding 增加 not_a_point_embed 的权重。 + # 6. 根据 label 为 0 或 1, 将相应的 point_embedding 增加 point_embeddings[0] 或 point_embeddings[1] 的权重。 + # 7. 返回 point_embedding 作为点 prompt 的 embedding。 + # 所以, 这个 _embed_points 方法实现了对点 prompt 的完整 embedding 过程。 + # 它包含位置编码、分类 Embedding 和类别偏置, 可以得到表达丰富的点 prompt embedding。 + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + # 这个 _embed_boxes 方法的作用是对框 prompt 进行 embedding。它包含: + # 1. 将 boxes 中的坐标增加 0.5, 将其移至像素中心。 + # 2. 将 boxes reshape 为形状为 (-1, 2, 2) 的张量 coords, 包含框的左上角和右下角坐标。 + # 3. 调用 pe_layer.forward_with_coords 对 coords 进行位置编码,得到 corner_embedding。 + # 4. 将 corner_embedding 中的第 0 维(左上角)增加 point_embeddings[2] 的权重。 + # 5. 将 corner_embedding 中的第 1 维(右下角)增加 point_embeddings[3] 的权重。 + # 6. 返回 corner_embedding 作为框 prompt 的 embedding。 + # 所以, 这个 _embed_boxes 方法实现了对框 prompt 的 embedding。它对框的左上角和右下角坐标进行了位置编码, + # 并增加相应的角点 Embedding, 可以得到表达丰富的框 prompt embedding。 + # 这个 _embed_boxes 方法提供了框 prompt embedding 的详细实现, + # 包含位置编码和框角点 Embedding, 是理解框 prompt 表达的基础。 + # 总的来说,这个 _embed_boxes 方法实现了框 prompt 的 EMBEDDING 过程, + # 可以获取表达丰富的框 prompt embedding, 为掩码解码器提供有效的 prompt 输入。 + # 这个 _embed_boxes 方法与 _embed_points 方法一起, + # 实现了对点 prompt 和框 prompt 的完整 embedding 流程, + # 可以为掩码解码器提供丰富多样的 prompt 表达, generate高质量的掩码输出。 + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + # 这个 _embed_masks 方法的作用是对掩码输入进行 embedding 。它包含: + # 1. 将 masks 输入 mask_downscaling, 得到 mask_embedding。 + # 2. 返回 mask_embedding 作为掩码输入的 embedding。 + # 这个 _embed_masks 方法与 _embed_points 和 _embed_boxes 方法一起, + # 实现了对点 prompt、框 prompt 和掩码输入的完整 embedding, + # 可以为掩码解码器提供丰富的多模态 prompt 和上下文表达,推动生成高质量的掩码输出。 + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + # 这个 _get_batch_size 方法的作用是根据 prompt 输入计算输出的 batch size。它包含: + # 1. 如果 points 不为 None,则返回 points[0] 的第 0 维作为 batch size。points[0] 中包含 prompt 坐标。 + # 2. 如果 boxes 不为 None,则返回 boxes 的第 0 维作为 batch size。boxes 中包含 prompt 框选坐标。 + # 3. 如果 masks 不为 None,则返回 masks 的第 0 维作为 batch size。masks 中包含 prompt 掩码输入。 + # 4. 否则返回 1 作为 batch size。 + # 所以,这个 _get_batch_size 方法根据是否输入了点 prompt、框 prompt或掩码 prompt,返回相应的batch size。 + # 如果未输入任何 prompt,则返回 1 作为 batch size。 + # 这个 _get_batch_size 方法提供了根据 prompt 输入推断输出 batch size的 简单实现,是设计基于 prompt 的生成模型的常用技巧。 + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + # 这个 _get_device 方法的作用很简单,就是返回点 prompt 的第一个 Embedding 层 point_embeddings[0] + # 的权重参数 weight 所在的设备,作为 PromptEncoder 的设备。 + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + # 这个 forward 方法的作用是对各种 prompt 进行 embedding, 并返回稀疏 embedding 和密集 embedding。它包含: + # 1. 调用 _get_batch_siz e根据点 prompt、框 prompt和掩码 prompt计算输出的 batch size bs。 + # 2. 初始化稀疏 embedding为形状为 (bs, 0, self.embed_dim) 的空张量, 设备为 _get_device() 的返回设备。 + # 3. 如果 points 不为 None,则调用 _embed_points 对点 prompt 进行 embedding, + # 得到 point_embeddings, 并将其拼接到 sparse_embeddings。 + # 4. 如果 boxes 不为 None,则调用 _embed_boxes 对框 prompt 进行 embedding, + # 得到 box_embeddings, 并将其拼接到 sparse_embeddings。 + # 5. 如果 masks 不为 None, 则调用 _embed_masks 对掩码 prompt 进行 embedding, 得到 dense_embeddings。 + # 6. 否则, 将 no_mask_embed 的权重 reshape 并扩展为形状为 (bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + # 的张量作为 dense_embeddings。 + # 7. 返回 sparse_embeddings 和 dense_embeddings 作为稀疏 embedding 和密集 embedding。 + # 所以,这个 forward 方法实现了对点 prompt、框 prompt 和掩码 prompt 的 embedding, + # 可以得到表达丰富的稀疏 embedding 和密集 embedding, 为下游的解码器提供复杂的 prompt 表达。 + # 这个 forward 方法提供了 prompt 的完整 embedding 流程,包含对三种 prompt 的处理, + # 可以获得多模态的 prompt 表达,为实现高质量的 prompt 驱动生成模型打下了基础。 + # 总的来说,这个 forward 方法实现了 prompt 的 ENCODING 过程, 可以获取稀疏 embedding 和密集 embedding + # 两种 prompt 表达,为实现高质量的多模态 prompt 驱动生成模型提供支持。 + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings +# 这个 PositionEmbeddingRandom 类实现了随机空间频率的位置编码。它包含: + +# 所以,这个PositionEmbeddingRandom类实现了对图像坐标点的随机位置编码。它可以对归一化坐标和非归一化坐标进行编码,为PromptEncoder提供位置编码能力,显著丰富prompt的表达。 +# 这个PositionEmbeddingRandom类提供了随机位置编码的实现,为PromptEncoder类带来位置表达的能力,可以丰富prompt的表达,提高prompt驱动生成的质量。 +# 总的来说,这个PositionEmbeddingRandom类实现了用于prompt位置编码的随机位置编码器,为PromptEncoder类提供位置编码能力,可以丰富prompt的表达,显著提高prompt驱动生成的质量。 + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + # __init__ 方法: + # 1. 输入参数: + # - num_pos_feats: 位置编码的特征数 + # - scale: 位置编码的 scale, 默认为 1.0 + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + # _pe_encoding方法: + # 1. 输入参数 coords 为归一化到 [0,1] 的坐标点。 + # 2. 将 coords 映射到 [-1,1] 区间。 + # 3. 将 coords 与 positional_encoding_gaussian_matrix 相乘。 + # 4. 将结果乘以 2*π。 + # 5. 拼接 sin 和 cos 作为位置编码,返回形状为 (d_1, ..., d_n, C) 的张量。 + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + # forward 方法: + # 1. 输入参数 size 为 (H, W) 的网格大小。 + # 2. 生成一个形状为 (H, W) 的网格,并获得 y 和 x 轴的顺序编码。 + # 3. 归一化 y_embed 和 x_embed 到 [0,1] 区间。 + # 4. 调用 _pe_encoding 对 x_embed 和 y_embed 进行位置编码。 + # 5. 返回位置编码,形状为 (C, H, W) 。 + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + # forward_with_coords方法: + # 1. 输入参数 coords_input 为未归一化到 [0,1] 的坐标, image_size 为 (H, W) 的图像大小。 + # 2. 归一化 coords_input 到 [0,1] 区间。 + # 3. 调用 _pe_encoding 对 coords 进行位置编码。 + # 4. 返回位置编码,形状为 (B, N, C)。 + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/segment_anything/modeling/sam.py b/segment_anything/modeling/sam.py new file mode 100644 index 0000000..6d6d960 --- /dev/null +++ b/segment_anything/modeling/sam.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/segment_anything/modeling/transformer.py b/segment_anything/modeling/transformer.py new file mode 100644 index 0000000..28fafea --- /dev/null +++ b/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py new file mode 100644 index 0000000..8a6e6d8 --- /dev/null +++ b/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/segment_anything/utils/__init__.py b/segment_anything/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/segment_anything/utils/__pycache__/__init__.cpython-39.pyc b/segment_anything/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..1432c0a Binary files /dev/null and b/segment_anything/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/segment_anything/utils/__pycache__/amg.cpython-39.pyc b/segment_anything/utils/__pycache__/amg.cpython-39.pyc new file mode 100644 index 0000000..2c620fe Binary files /dev/null and b/segment_anything/utils/__pycache__/amg.cpython-39.pyc differ diff --git a/segment_anything/utils/__pycache__/transforms.cpython-39.pyc b/segment_anything/utils/__pycache__/transforms.cpython-39.pyc new file mode 100644 index 0000000..6215b55 Binary files /dev/null and b/segment_anything/utils/__pycache__/transforms.cpython-39.pyc differ diff --git a/segment_anything/utils/amg.py b/segment_anything/utils/amg.py new file mode 100644 index 0000000..be06407 --- /dev/null +++ b/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py new file mode 100644 index 0000000..3196bdf --- /dev/null +++ b/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/segment_anything/utils/transforms.py b/segment_anything/utils/transforms.py new file mode 100644 index 0000000..c08ba1e --- /dev/null +++ b/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0eee130 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,11 @@ +[isort] +line_length=100 +multi_line_output=3 +include_trailing_comma=True +known_standard_library=numpy,setuptools +skip_glob=*/__init__.py +known_myself=segment_anything +known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort +no_lines_before=STDLIB,THIRDPARTY +sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER +default_section=FIRSTPARTY diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2c09863 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import find_packages, setup + +setup( + name="segment_anything", + version="1.0", + install_requires=[], + packages=find_packages(exclude="notebooks"), + extras_require={ + "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], + "dev": ["flake8", "isort", "black", "mypy"], + }, +) diff --git a/test.py b/test.py new file mode 100644 index 0000000..a072e0e --- /dev/null +++ b/test.py @@ -0,0 +1,196 @@ +import cv2 +import os +import numpy as np +from segment_anything import sam_model_registry, SamPredictor + +input_dir = 'scripts/input/images' +output_dir = 'scripts/output/mask' +crop_mode = True + +print('最好是每加一个点就按w键predict一次') +os.makedirs(output_dir, exist_ok=True) +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))] + +sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth") +_ = sam.to(device="cuda") +predictor = SamPredictor(sam) + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +cv2.namedWindow("image", cv2.WINDOW_NORMAL) +cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT) +cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2) + +def mouse_click(event, x, y, flags, param): # 鼠标点击事件 + global input_point, input_label, input_stop # 全局变量,输入点, + if not input_stop: # 判定标志是否停止输入响应了! + if event == cv2.EVENT_LBUTTONDOWN: # 鼠标左键 + input_point.append([x, y]) + input_label.append(1) # 1表示前景点 + elif event == cv2.EVENT_RBUTTONDOWN: # 鼠标右键 + input_point.append([x, y]) + input_label.append(0) # 0表示背景点 + else: + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: # 提示添加不了 + print('此时不能添加点,按w退出mask选择模式') + + +def apply_mask(image, mask, alpha_channel=True): # 应用并且响应mask + if alpha_channel: + alpha = np.zeros_like(image[..., 0]) # 制作掩体 + alpha[mask == 1] = 255 # 兴趣地方标记为1,且为白色 + image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) # 融合图像 + else: + image = np.where(mask[..., None] == 1, image, 0) + return image + + +def apply_color_mask(image, mask, color, color_dark=0.5): # 对掩体进行赋予颜色 + for c in range(3): + image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) + return image + + +def get_next_filename(base_path, filename): # 进行下一个图像 + name, ext = os.path.splitext(filename) + for i in range(1, 101): + new_name = f"{name}_{i}{ext}" + if not os.path.exists(os.path.join(base_path, new_name)): + return new_name + return None + + +def save_masked_image(image, mask, output_dir, filename, crop_mode_): # 保存掩盖部分的图像(感兴趣的图像) + if crop_mode_: + y, x = np.where(mask) + y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_image = apply_mask(cropped_image, cropped_mask) + else: + masked_image = apply_mask(image, mask) + filename = filename[:filename.rfind('.')] + '.png' + new_filename = get_next_filename(output_dir, filename) + + if new_filename: + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") + else: + print("Could not save the image. Too many variations exist.") + + +current_index = 0 + +cv2.namedWindow("image") +cv2.setMouseCallback("image", mouse_click) +input_point = [] +input_label = [] +input_stop = False +while True: + filename = image_files[current_index] + image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_crop = image_orign.copy() # 原图裁剪 + image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) # 原图色彩转变 + selected_mask = None + logit_input = None + while True: + # print(input_point) + input_stop = False + image_display = image_orign.copy() + display_info = f'{filename} ' + cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + for point, label in zip(input_point, input_label): # 输入点和输入类型 + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(image_display, tuple(point), 5, color, -1) + if selected_mask is not None: + color = tuple(np.random.randint(0, 256, 3).tolist()) + selected_image = apply_color_mask(image_display, selected_mask, color) +# S保存,w预测,ad切换,esc退出 + cv2.imshow("image", image_display) + key = cv2.waitKey(1) + + if key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + elif key == ord("w"): + input_stop = True + if len(input_point) > 0 and len(input_label) > 0: + + predictor.set_image(image) # 设置输入图像 + input_point_np = np.array(input_point) # 输入暗示点,需要转变array类型才可以输入 + input_label_np = np.array(input_label) # 输入暗示点的类型 + + masks, scores, logits = predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=logit_input[None, :, :] if logit_input is not None else None, + multimask_output=True, + ) + + mask_idx = 0 + num_masks = len(masks) # masks的数量 + while (1): + color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是 + image_select = image_orign.copy() + selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换 + selected_image = apply_color_mask(image_select, selected_mask, color) + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 | q 移除最后一个 | s 保存' + cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, + cv2.LINE_AA) + # todo 显示在当前的图片, + cv2.imshow("image", selected_image) + + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s'): + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 + else: + mask_idx = num_masks - 1 + elif key == ord('d'): + if mask_idx < num_masks - 1: + mask_idx += 1 + else: + mask_idx = 0 + elif key == ord('w'): + break + elif key == ord(" "): + input_point = [] + input_label = [] + selected_mask = None + logit_input = None + break + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) + + elif key == ord('a'): + current_index = max(0, current_index - 1) + input_point = [] + input_label = [] + break + elif key == ord('d'): + current_index = min(len(image_files) - 1, current_index + 1) + input_point = [] + input_label = [] + break + elif key == 27: + break + elif key == ord('q') and len(input_point) > 0: + input_point.pop(-1) + input_label.pop(-1) + elif key == ord('s') and selected_mask is not None: + save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + + if key == 27: + break + diff --git a/untitled.py b/untitled.py new file mode 100644 index 0000000..a671d73 --- /dev/null +++ b/untitled.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'untitled.ui' +# +# Created by: PyQt5 UI code generator 5.15.9 +# +# WARNING: Any manual changes made to this file will be lost when pyuic5 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt5 import QtCore, QtGui, QtWidgets + + +class Ui_MainWindow(object): + def setupUi(self, MainWindow): + MainWindow.setObjectName("MainWindow") + MainWindow.resize(1333, 657) + self.centralwidget = QtWidgets.QWidget(MainWindow) + self.centralwidget.setObjectName("centralwidget") + self.pushButton_init = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41)) + self.pushButton_init.setObjectName("pushButton_init") + self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41)) + self.pushButton_openimg.setObjectName("pushButton_openimg") + self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41)) + self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg") + self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41)) + self.pushButton_exit.setObjectName("pushButton_exit") + self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41)) + self.pushButton_Transparency.setObjectName("pushButton_Transparency") + self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41)) + self.pushButton_copymask.setObjectName("pushButton_copymask") + self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41)) + self.pushButton_saveimg.setObjectName("pushButton_saveimg") + self.horizontalSlider = QtWidgets.QSlider(self.centralwidget) + self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22)) + self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal) + self.horizontalSlider.setObjectName("horizontalSlider") + self.horizontalSlider.setValue(50) + self.label_Originalimg = QtWidgets.QLabel(self.centralwidget) + self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581)) + self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Originalimg.setObjectName("label_Originalimg") + self.label_Maskimg = QtWidgets.QLabel(self.centralwidget) + self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581)) + self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);") + self.label_Maskimg.setObjectName("label_Maskimg") + self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41)) + self.pushButton_shang.setObjectName("pushButton_shang") + self.pushButton_openimg_3 = QtWidgets.QPushButton(self.centralwidget) + self.pushButton_openimg_3.setGeometry(QtCore.QRect(10, 210, 141, 41)) + self.pushButton_openimg_3.setObjectName("pushButton_openimg_3") + MainWindow.setCentralWidget(self.centralwidget) + self.menubar = QtWidgets.QMenuBar(MainWindow) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26)) + self.menubar.setObjectName("menubar") + MainWindow.setMenuBar(self.menubar) + self.statusbar = QtWidgets.QStatusBar(MainWindow) + self.statusbar.setObjectName("statusbar") + MainWindow.setStatusBar(self.statusbar) + + self.retranslateUi(MainWindow) + QtCore.QMetaObject.connectSlotsByName(MainWindow) + + def retranslateUi(self, MainWindow): + _translate = QtCore.QCoreApplication.translate + MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) + self.pushButton_init.setText(_translate("MainWindow", "重置选择")) + self.pushButton_openimg.setText(_translate("MainWindow", "打开图片")) + self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片")) + self.pushButton_exit.setText(_translate("MainWindow", "退出")) + self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度")) + self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码")) + self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片")) + self.label_Originalimg.setText(_translate("MainWindow", "

原始图像

")) + self.label_Maskimg.setText(_translate("MainWindow", "

掩码图像

")) + self.pushButton_shang.setText(_translate("MainWindow", "上一张")) + self.pushButton_openimg_3.setText(_translate("MainWindow", "下一张")) + +if __name__ == "__main__": + import sys + app = QtWidgets.QApplication(sys.argv) + MainWindow = QtWidgets.QMainWindow() + ui = Ui_MainWindow() + ui.setupUi(MainWindow) + MainWindow.show() + sys.exit(app.exec_()) diff --git a/untitled.ui b/untitled.ui new file mode 100644 index 0000000..3470d73 --- /dev/null +++ b/untitled.ui @@ -0,0 +1,194 @@ + + + MainWindow + + + + 0 + 0 + 1333 + 657 + + + + MainWindow + + + + + + 10 + 30 + 141 + 41 + + + + 重置选择 + + + + + + 10 + 90 + 141 + 41 + + + + 打开图片 + + + + + + 10 + 270 + 141 + 41 + + + + 融合背景图片 + + + + + + 10 + 570 + 141 + 41 + + + + 退出 + + + + + + 10 + 380 + 141 + 41 + + + + 调整透明度 + + + + + + 10 + 450 + 141 + 41 + + + + 复制掩码 + + + + + + 10 + 510 + 141 + 41 + + + + 保存图片 + + + + + + 10 + 330 + 141 + 22 + + + + Qt::Horizontal + + + + + + 160 + 30 + 571 + 581 + + + + background-color: rgb(255, 255, 255); + + + <html><head/><body><p align="center">原始图像</p></body></html> + + + + + + 740 + 30 + 581 + 581 + + + + background-color: rgb(255, 255, 255); + + + <html><head/><body><p align="center">掩码图像</p></body></html> + + + + + + 10 + 150 + 141 + 41 + + + + 上一张 + + + + + + 10 + 210 + 141 + 41 + + + + 下一张 + + + + + + + 0 + 0 + 1333 + 26 + + + + + + + + diff --git a/weights/vit_b.pth b/weights/vit_b.pth new file mode 100644 index 0000000..538f0f2 Binary files /dev/null and b/weights/vit_b.pth differ diff --git a/yourfile.py b/yourfile.py new file mode 100644 index 0000000..95bbc71 --- /dev/null +++ b/yourfile.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'yourfile.ui' +# +# Created by: PyQt5 UI code generator 5.15.9 +# +# WARNING: Any manual changes made to this file will be lost when pyuic5 is +# run again. Do not edit this file unless you know what you are doing. + + diff --git a/命令.txt b/命令.txt new file mode 100644 index 0000000..8a6f724 --- /dev/null +++ b/命令.txt @@ -0,0 +1,18 @@ +# 接下来我们开始运行开源的demo,有两种方式: + +# cmd命令:注意notebooks/images/是指你的输入图片路径,output是指的输出mask的路径,后面的--device cpu如果加了,就会采用cpu跑,不然会默认GPU。 +# python scripts/amg.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --input data/img/ --output data/mask --device cuda:0 +# python scripts/amg.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --input data/img/ --output data/mask --device cpu +# python scripts/export_onnx_model.py --checkpoint --model-type --output + + +# python scripts/amg.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --input data/img/ --output data/mask --device cuda:0 +# python scripts/export_onnx_model.py --checkpoint --model-type --output + + +# D:\anaconda3\envs\pytorch\Scripts\pyuic5.exe -x untitled.ui -o display1.py + +# D:\anaconda3\envs\pytorch\Scripts\pyuic5.exe -x untitled.ui -o display1.py +# 我现在设置的滑块范围是0-10 起始位置是0,步长是1,每一个步长缩放比例是20%, +# 往右拖动滑块1、2、3、4、5、6、7、8、9、10 这是缩小 +# 往回拖动滑块10、9、8、7、6、5、4、3、2、1、0这个是还原