SAM_Project

This commit is contained in:
tanzk 2024-06-19 08:51:04 +08:00
commit 62bad68b87
350 changed files with 10575 additions and 0 deletions

7
.flake8 Normal file
View File

@ -0,0 +1,7 @@
[flake8]
ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
per-file-ignores =
**/__init__.py:F401,F403,E402

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

98
.idea/deployment.xml Normal file
View File

@ -0,0 +1,98 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="root@123.125.240.150:45809">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@222.187.226.110:28961">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:15907">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749 (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749 (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749 (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749 (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@connect.east.seetacloud.com:26749 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-42.seetacloud.com:14975">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-42.seetacloud.com:34252">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-8.seetacloud.com:35693">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-8.seetacloud.com:35693 (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>

View File

@ -0,0 +1,26 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="11">
<item index="0" class="java.lang.String" itemvalue="sklearn" />
<item index="1" class="java.lang.String" itemvalue="tqdm" />
<item index="2" class="java.lang.String" itemvalue="scipy" />
<item index="3" class="java.lang.String" itemvalue="h5py" />
<item index="4" class="java.lang.String" itemvalue="matplotlib" />
<item index="5" class="java.lang.String" itemvalue="torch" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="torchvision" />
<item index="8" class="java.lang.String" itemvalue="opencv_python" />
<item index="9" class="java.lang.String" itemvalue="Pillow" />
<item index="10" class="java.lang.String" itemvalue="lxml" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

4
.idea/misc.xml Normal file
View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/segment-anything-model.iml" filepath="$PROJECT_DIR$/.idea/segment-anything-model.iml" />
</modules>
</component>
</project>

View File

@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
</content>
<orderEntry type="jdk" jdkName="Python 3.9 (pytorch)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

80
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

31
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,31 @@
# Contributing to segment-anything
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to segment-anything, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

171
README.md Normal file
View File

@ -0,0 +1,171 @@
# Segment Anything
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
[Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/)
[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)]
![SAM design](assets/model_diagram.png?raw=true)
The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
<p float="left">
<img src="assets/masks1.png?raw=true" width="37.25%" />
<img src="assets/masks2.jpg?raw=true" width="61.5%" />
</p>
## Installation
The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
Install Segment Anything:
```
pip install git+https://github.com/facebookresearch/segment-anything.git
```
or clone the repository locally and install with
```
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
```
The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks.
```
pip install opencv-python pycocotools matplotlib onnxruntime onnx
```
## <a name="GettingStarted"></a>Getting Started
First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt:
```
from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
```
or generate masks for an entire image:
```
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(<your_image>)
```
Additionally, masks can be generated for images from the command line:
```
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
```
See the example notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.
<p float="left">
<img src="assets/notebook1.png?raw=true" width="49.1%" />
<img src="assets/notebook2.png?raw=true" width="48.9%" />
</p>
## ONNX Export
SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
```
python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
```
See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
### Web demo
The `demo/` folder has a simple one page React app which shows how to run mask prediction with the exported ONNX model in a web browser with multithreading. Please see [`demo/README.md`](https://github.com/facebookresearch/segment-anything/blob/main/demo/README.md) for more details.
## <a name="Models"></a>Model Checkpoints
Three versions of the model are available with different backbone sizes. These models can be instantiated by running
```
from segment_anything import sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
```
Click the links below to download the checkpoint for the corresponding model type.
- **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
- `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
- `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
## Dataset
See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the dataset. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the datasets you agree that you have read and accepted the terms of the SA-1B Dataset Research License.
We save masks per image as a json file. It can be loaded as a dictionary in python in the below format.
```python
{
"image" : image_info,
"annotations" : [annotation],
}
image_info {
"image_id" : int, # Image id
"width" : int, # Image width
"height" : int, # Image height
"file_name" : str, # Image filename
}
annotation {
"id" : int, # Annotation id
"segmentation" : dict, # Mask saved in COCO RLE format.
"bbox" : [x, y, w, h], # The box around the mask, in XYWH format
"area" : int, # The area in pixels of the mask
"predicted_iou" : float, # The model's own prediction of the mask's quality
"stability_score" : float, # A measure of the mask's quality
"crop_box" : [x, y, w, h], # The crop of the image used to generate the mask, in XYWH format
"point_coords" : [[x, y]], # The point coordinates input to the model to generate the mask
}
```
Image ids can be found in sa_images_ids.txt which can be downloaded using the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/) as well.
To decode a mask in COCO RLE format into binary:
```
from pycocotools import mask as mask_utils
mask = mask_utils.decode(annotation["segmentation"])
```
See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format.
## License
The model is licensed under the [Apache 2.0 license](LICENSE).
## Contributing
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
## Contributors
The Segment Anything project was made possible with the help of many contributors (alphabetical):
Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom
## Citing Segment Anything
If you use SAM or SA-1B in your research, please use the following BibTeX entry.
```
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
```

175
SAM_Mask.py Normal file
View File

@ -0,0 +1,175 @@
import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# Folder that is scanned for images, and folder that receives the saved overlays.
input_dir = 'scripts/pv1/0.8/wuding/jpg'
output_dir = 'scripts/pv1/0.8/wuding/mask'
crop_mode = True
# Startup hint for the user (roughly: "best to press w to predict after every added point").
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# File names are lower-cased before the suffix test, so only lower-case
# suffixes are needed to match files regardless of their case.
IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.tif')
# os.listdir() returns entries in arbitrary order; sorting makes the
# previous/next navigation order reproducible between runs.
image_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(IMAGE_SUFFIXES))
if not image_files:
    # Fail early with a clear message instead of an IndexError once the
    # first image is looked up.
    raise SystemExit(f'No image files found in {input_dir}')

# NOTE(review): the checkpoint path is an absolute, machine-specific Windows
# path and the device is fixed to "cuda" — adjust both for other machines.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# Resizable window with a fixed initial size.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# The offsets centre the window on a 1920x1080 area.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """Mouse callback that records prompt points for the current image.

    A left click stores ``[x, y]`` with label 1 (foreground) and a right
    click stores it with label 0 (background). While ``input_stop`` is set,
    clicks are rejected and a hint is printed instead. All other mouse
    events are ignored.

    Args:
        event: Mouse event code, compared against ``cv2.EVENT_LBUTTONDOWN``
            and ``cv2.EVENT_RBUTTONDOWN``.
        x: Horizontal coordinate of the event.
        y: Vertical coordinate of the event.
        flags: Unused.
        param: Unused.
    """
    global input_point, input_label, input_stop
    left_pressed = event == cv2.EVENT_LBUTTONDOWN
    right_pressed = event == cv2.EVENT_RBUTTONDOWN

    # Only button presses are of interest.
    if not (left_pressed or right_pressed):
        return

    if input_stop:
        # Point input is currently locked; tell the user why nothing happened.
        print('此时不能添加点,按w退出mask选择模式')
        return

    input_point.append([x, y])
    input_label.append(1 if left_pressed else 0)
def apply_mask(image, mask, alpha_channel=True):
    """Highlight the masked region of an image in green.

    Args:
        image: Image array; it is copied, never modified in place.
        mask: Array used as ``mask == 1`` to select the pixels to paint.
        alpha_channel: When true, the painted copy is blended 50/50 with the
            original via ``cv2.addWeighted``; when false, the painted copy is
            returned as is.

    Returns:
        A new image with the selected pixels painted ``[0, 255, 0]``
        (optionally blended with the original).
    """
    painted = image.copy()
    painted[mask == 1] = [0, 255, 0]  # highlight colour; change here if another colour is wanted
    if alpha_channel:
        # Equal weights keep the underlying image visible through the overlay.
        return cv2.addWeighted(image, 0.5, painted, 0.5, 0)
    return painted
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Tint the masked pixels of an image towards ``color``.

    For each of the first three channels, pixels where ``mask == 1`` become
    ``pixel * (1 - color_dark) + color_dark * color[channel]``; all other
    pixels keep their original value.

    Args:
        image: Image array indexed as ``image[:, :, channel]``.
        mask: Array compared with 1 to select the pixels to tint.
        color: Sequence providing one value per channel (indices 0..2).
        color_dark: Weight of ``color`` in the blend.

    Returns:
        A tinted copy of ``image``; the input is left untouched.
    """
    tinted = image.copy()
    inside = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        mixed = plane * (1 - color_dark) + color_dark * color[channel]
        # Writing into the copy casts the blended values back to its dtype.
        tinted[:, :, channel] = np.where(inside, mixed, plane)
    return tinted
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write the green mask overlay of ``image`` to ``<stem>_masked.png``.

    Args:
        image: source image passed to ``apply_mask``.
        mask: mask passed to ``apply_mask``.
        output_dir: directory that receives the PNG.
        filename: name of the source image; its extension is replaced.
        crop_mode_: accepted for interface compatibility; not used here.
    """
    masked_image = apply_mask(image, mask)
    # splitext keeps the whole name when there is no extension, whereas
    # slicing at rfind('.') would drop the last character in that case.
    out_name = os.path.splitext(filename)[0] + '_masked.png'
    out_path = os.path.join(output_dir, out_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(out_path, masked_image):
        print(f"Saved as {out_name}")
    else:
        print(f"Failed to save {out_path}")
# ---------------------------------------------------------------------------
# Interactive annotation loop.
#
# Main view keys:   left/right click = add foreground/background point,
#                   w = predict, s = save confirmed mask, a/d = previous/next
#                   image, q = remove last point, space = clear, Esc = quit.
# Mask-select keys: a/d = cycle candidate masks, s = save, w = confirm,
#                   q = remove last point, space = clear and leave.
# ---------------------------------------------------------------------------
if not image_files:
    raise SystemExit(f'No supported images found in {input_dir}')

CONFIRMED_MASK_COLOR = (255, 128, 0)  # BGR tint used for the confirmed mask in the main view

current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None for unreadable files; drop the entry so
        # navigation never lands on it again.
        print(f'Could not read {filename}; skipping it')
        image_files.pop(current_index)
        if not image_files:
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # untouched copy used when saving
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB copy fed to the predictor
    selected_mask = None
    logit_input = None
    image_embedded = False  # predictor.set_image() runs once per image, on first predict
    while True:
        input_stop = False
        image_display = image_orign.copy()
        if selected_mask is not None:
            # Show the confirmed mask so it is visible before saving.
            image_display = apply_color_mask(image_display, selected_mask, CONFIRMED_MASK_COLOR)
        display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)  # green = foreground, red = background
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # freeze point input while candidate masks are browsed
            if len(input_point) > 0 and len(input_label) > 0:
                if not image_embedded:
                    predictor.set_image(image)
                    image_embedded = True
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                confirmed = True
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow("image", selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
                    elif key == ord('a'):
                        mask_idx = mask_idx - 1 if mask_idx > 0 else num_masks - 1
                    elif key == ord('d'):
                        mask_idx = mask_idx + 1 if mask_idx < num_masks - 1 else 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        confirmed = False
                        break
                if confirmed:
                    # Only a confirmed mask seeds the next prediction; after a
                    # clear, the logits must not leak into the next prompt.
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
    if key == 27:
        break
cv2.destroyAllWindows()

180
SAM_YY.py Normal file
View File

@ -0,0 +1,180 @@
import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# --- Configuration and one-time setup ----------------------------------------
input_dir = r'C:\Users\t2581\Desktop\222\images'  # folder scanned for images to annotate
output_dir = r'C:\Users\t2581\Desktop\222\json'   # folder that receives the saved overlays
crop_mode = True                                  # forwarded to the save helper as crop_mode_
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# File names are lower-cased before matching, so only lower-case suffixes are
# needed. Sorting makes the a/d navigation order deterministic, because the
# order returned by os.listdir() is arbitrary.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
image_files = sorted(f for f in os.listdir(input_dir)
                     if f.lower().endswith(IMAGE_EXTENSIONS))

# Load the ViT-B SAM checkpoint, move it to the GPU and wrap it in a predictor.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# Resizable window, centred on an assumed 1920x1080 desktop.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records prompt points.

    A left click appends ``[x, y]`` to ``input_point`` with label 1
    (foreground); a right click appends it with label 0 (background).
    While ``input_stop`` is set, clicks are rejected with a console hint.
    All other mouse events are ignored.
    """
    global input_point, input_label, input_stop
    left_click = event == cv2.EVENT_LBUTTONDOWN
    right_click = event == cv2.EVENT_RBUTTONDOWN
    if not (left_click or right_click):
        return
    if input_stop:
        # Mask-selection mode is active: points are frozen.
        print('此时不能添加点,按w退出mask选择模式')
        return
    input_point.append([x, y])
    input_label.append(1 if left_click else 0)
def apply_mask(image, mask, alpha_channel=True):
    """Return a copy of ``image`` with the masked pixels highlighted in green.

    Args:
        image: image array; it is not modified.
        mask: array compared against 1 to select the pixels to paint.
        alpha_channel: when true, the painted copy is blended 50/50 with the
            original via ``cv2.addWeighted``; when false, the masked pixels
            are painted solid ``[0, 255, 0]``.

    Returns:
        A new image array.
    """
    masked_image = image.copy()
    masked_image[mask == 1] = [0, 255, 0]  # highlight colour for the mask region
    if alpha_channel:
        return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
    return masked_image
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with the masked pixels tinted towards ``color``.

    Where ``mask == 1`` each of the first three channels becomes
    ``pixel * (1 - color_dark) + color_dark * color[c]``; all other pixels
    keep their value. ``image`` itself is left untouched.
    """
    tinted = image.copy()
    channels = image[:, :, :3]
    target = np.array([color[0], color[1], color[2]])
    blended = channels * (1 - color_dark) + color_dark * target
    chosen = (mask == 1)[..., None]  # broadcast the 2-D selection over channels
    tinted[:, :, :3] = np.where(chosen, blended, channels)
    return tinted
def draw_external_rectangle(image, mask, pv):
    """Draw a captioned bounding box around every outer contour of ``mask``.

    Args:
        image: image drawn on in place.
        mask: mask converted to ``uint8`` before contour extraction.
        pv: caption text written just above each rectangle.
    """
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2)  # yellow rectangle
        # Keep the caption baseline inside the image when the box touches
        # the top edge; otherwise the text would be drawn off-screen.
        text_y = max(y - 5, 12)
        cv2.putText(image, pv, (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
def save_masked_image(image, mask, output_dir, filename, crop_mode_, pv):
    """Write the captioned mask overlay of ``image`` to ``<stem>_masked.png``.

    Args:
        image: source image passed to ``apply_mask``.
        mask: mask used for the overlay and for the bounding rectangles.
        output_dir: directory that receives the PNG.
        filename: name of the source image; its extension is replaced.
        crop_mode_: accepted for interface compatibility; not used here.
        pv: caption drawn above each bounding rectangle.
    """
    masked_image = apply_mask(image, mask)
    draw_external_rectangle(masked_image, mask, pv)
    # splitext keeps the whole name when there is no extension, whereas
    # slicing at rfind('.') would drop the last character in that case.
    out_name = os.path.splitext(filename)[0] + '_masked.png'
    out_path = os.path.join(output_dir, out_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(out_path, masked_image):
        print(f"Saved as {out_name}")
    else:
        print(f"Failed to save {out_path}")
# ---------------------------------------------------------------------------
# Interactive annotation loop.
#
# Main view keys:   left/right click = add foreground/background point,
#                   w = predict, s = save confirmed mask, a/d = previous/next
#                   image, q = remove last point, space = clear, Esc = quit.
# Mask-select keys: a/d = cycle candidate masks, s = save, w = confirm,
#                   q = remove last point, space = clear and leave.
# ---------------------------------------------------------------------------
if not image_files:
    raise SystemExit(f'No supported images found in {input_dir}')

CONFIRMED_MASK_COLOR = (255, 128, 0)  # BGR tint used for the confirmed mask in the main view

current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None for unreadable files; drop the entry so
        # navigation never lands on it again.
        print(f'Could not read {filename}; skipping it')
        image_files.pop(current_index)
        if not image_files:
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # untouched copy used when saving
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB copy fed to the predictor
    selected_mask = None
    logit_input = None
    image_embedded = False  # predictor.set_image() runs once per image, on first predict
    while True:
        input_stop = False
        image_display = image_orign.copy()
        if selected_mask is not None:
            # Show the confirmed mask so it is visible before saving.
            image_display = apply_color_mask(image_display, selected_mask, CONFIRMED_MASK_COLOR)
        display_info = f'{filename} '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)  # green = foreground, red = background
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # freeze point input while candidate masks are browsed
            if len(input_point) > 0 and len(input_label) > 0:
                if not image_embedded:
                    predictor.set_image(image)
                    image_embedded = True
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                confirmed = True
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow("image", selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
                    elif key == ord('a'):
                        mask_idx = mask_idx - 1 if mask_idx > 0 else num_masks - 1
                    elif key == ord('d'):
                        mask_idx = mask_idx + 1 if mask_idx < num_masks - 1 else 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        confirmed = False
                        break
                if confirmed:
                    # Only a confirmed mask seeds the next prediction; after a
                    # clear, the logits must not leak into the next prompt.
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
    if key == 27:
        break
cv2.destroyAllWindows()

209
SAM_YY_JSON.py Normal file
View File

@ -0,0 +1,209 @@
import cv2
import os
import numpy as np
import json
from segment_anything import sam_model_registry, SamPredictor
# --- Configuration and one-time setup ----------------------------------------
input_dir = r'C:\Users\t2581\Desktop\222\images'  # folder scanned for images to annotate
output_dir = r'C:\Users\t2581\Desktop\222\2'      # folder that receives the PNG and JSON output
crop_mode = True                                  # forwarded to the save helper as crop_mode_
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# File names are lower-cased before matching, so only lower-case suffixes are
# needed. Sorting makes the a/d navigation order deterministic, because the
# order returned by os.listdir() is arbitrary.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
image_files = sorted(f for f in os.listdir(input_dir)
                     if f.lower().endswith(IMAGE_EXTENSIONS))

# Load the ViT-B SAM checkpoint, move it to the GPU and wrap it in a predictor.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# Resizable window, centred on an assumed 1920x1080 desktop.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records prompt points.

    A left click appends ``[x, y]`` to ``input_point`` with label 1
    (foreground); a right click appends it with label 0 (background).
    While ``input_stop`` is set, clicks are rejected with a console hint.
    All other mouse events are ignored.
    """
    global input_point, input_label, input_stop
    left_click = event == cv2.EVENT_LBUTTONDOWN
    right_click = event == cv2.EVENT_RBUTTONDOWN
    if not (left_click or right_click):
        return
    if input_stop:
        # Mask-selection mode is active: points are frozen.
        print('此时不能添加点,按w退出mask选择模式')
        return
    input_point.append([x, y])
    input_label.append(1 if left_click else 0)
def apply_mask(image, mask, alpha_channel=True):
    """Return a copy of ``image`` with the masked pixels highlighted in green.

    Args:
        image: image array; it is not modified.
        mask: array compared against 1 to select the pixels to paint.
        alpha_channel: when true, the painted copy is blended 50/50 with the
            original via ``cv2.addWeighted``; when false, the masked pixels
            are painted solid ``[0, 255, 0]``.

    Returns:
        A new image array.
    """
    masked_image = image.copy()
    masked_image[mask == 1] = [0, 255, 0]  # highlight colour for the mask region
    if alpha_channel:
        return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
    return masked_image
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with the masked pixels tinted towards ``color``.

    Where ``mask == 1`` each of the first three channels becomes
    ``pixel * (1 - color_dark) + color_dark * color[c]``; all other pixels
    keep their value. ``image`` itself is left untouched.
    """
    tinted = image.copy()
    channels = image[:, :, :3]
    target = np.array([color[0], color[1], color[2]])
    blended = channels * (1 - color_dark) + color_dark * target
    chosen = (mask == 1)[..., None]  # broadcast the 2-D selection over channels
    tinted[:, :, :3] = np.where(chosen, blended, channels)
    return tinted
def draw_external_rectangle(image, mask, pv):
    """Draw a captioned bounding box around every outer contour of ``mask``.

    Args:
        image: image drawn on in place.
        mask: mask converted to ``uint8`` before contour extraction.
        pv: caption text written just above each rectangle.
    """
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2)  # yellow rectangle
        # Keep the caption baseline inside the image when the box touches
        # the top edge; otherwise the text would be drawn off-screen.
        text_y = max(y - 5, 12)
        cv2.putText(image, pv, (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
def save_masked_image_and_json(image, mask, output_dir, filename, crop_mode_, pv):
    """Save the captioned mask overlay as PNG plus a polygon annotation JSON.

    Two files are written to ``output_dir``: ``<stem>_masked.png`` (overlay
    with bounding rectangles) and ``<stem>_masked.json`` (one ``"pv"``
    polygon per outer contour of ``mask``).

    Args:
        image: source image passed to ``apply_mask``.
        mask: mask used for the overlay, rectangles and polygons.
        output_dir: directory that receives both files.
        filename: name of the source image; stored as ``imagePath``.
        crop_mode_: accepted for interface compatibility; not used here.
        pv: caption drawn above each bounding rectangle.
    """
    # splitext keeps the whole name when there is no extension, whereas
    # slicing at rfind('.') would drop the last character in that case.
    stem = os.path.splitext(filename)[0]

    masked_image = apply_mask(image, mask)
    draw_external_rectangle(masked_image, mask, pv)
    masked_filename = stem + '_masked.png'
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image):
        print(f"Saved image as {masked_filename}")
    else:
        print(f"Failed to save image {masked_filename}")

    # Convert the mask's outer contours to lists of [x, y] points.
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polygons = [contour.reshape(-1, 2).tolist() for contour in contours]

    json_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": [
            {
                "label": "pv",
                "points": polygon,
                "group_id": None,
                "shape_type": "polygon",
                "flags": {}
            } for polygon in polygons
        ],
        "imagePath": filename,
        "imageData": None,
        "imageHeight": mask.shape[0],
        "imageWidth": mask.shape[1]
    }

    json_filename = stem + '_masked.json'
    # Explicit encoding so the output does not depend on the platform default.
    with open(os.path.join(output_dir, json_filename), 'w', encoding='utf-8') as json_file:
        json.dump(json_data, json_file, indent=4)
    print(f"Saved JSON as {json_filename}")
# ---------------------------------------------------------------------------
# Interactive annotation loop.
#
# Main view keys:   left/right click = add foreground/background point,
#                   w = predict, s = save confirmed mask, a/d = previous/next
#                   image, q = remove last point, space = clear, Esc = quit.
# Mask-select keys: a/d = cycle candidate masks, s = save, w = confirm,
#                   q = remove last point, space = clear and leave.
# ---------------------------------------------------------------------------
if not image_files:
    raise SystemExit(f'No supported images found in {input_dir}')

CONFIRMED_MASK_COLOR = (255, 128, 0)  # BGR tint used for the confirmed mask in the main view

current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None for unreadable files; drop the entry so
        # navigation never lands on it again.
        print(f'Could not read {filename}; skipping it')
        image_files.pop(current_index)
        if not image_files:
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # untouched copy used when saving
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)  # RGB copy fed to the predictor
    selected_mask = None
    logit_input = None
    image_embedded = False  # predictor.set_image() runs once per image, on first predict
    while True:
        input_stop = False
        image_display = image_orign.copy()
        if selected_mask is not None:
            # Show the confirmed mask so it is visible before saving.
            image_display = apply_color_mask(image_display, selected_mask, CONFIRMED_MASK_COLOR)
        display_info = f'{filename} '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)  # green = foreground, red = background
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # freeze point input while candidate masks are browsed
            if len(input_point) > 0 and len(input_label) > 0:
                if not image_embedded:
                    predictor.set_image(image)
                    image_embedded = True
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                confirmed = True
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow("image", selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
                    elif key == ord('a'):
                        mask_idx = mask_idx - 1 if mask_idx > 0 else num_masks - 1
                    elif key == ord('d'):
                        mask_idx = mask_idx + 1 if mask_idx < num_masks - 1 else 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        confirmed = False
                        break
                if confirmed:
                    # Only a confirmed mask seeds the next prediction; after a
                    # clear, the logits must not leak into the next prompt.
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
    if key == 27:
        break
cv2.destroyAllWindows()

161
UI.py Normal file
View File

@ -0,0 +1,161 @@
import sys
import os
from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_MainWindow(object):
    """Widget layout for the main window.

    The widgets, names and geometries created here match the form described
    in UI.ui. NOTE(review): this looks like uic-generated code; if so,
    regenerate it from UI.ui rather than editing by hand (confirm).
    """

    def setupUi(self, MainWindow):
        """Create all child widgets on ``MainWindow`` and fix its size at 1140x450."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Minimum and maximum are equal, so the window cannot be resized.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        # Left-hand column of control buttons.
        self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
        self.pushButton_w.setObjectName("pushButton_w")
        self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
        self.pushButton_a.setObjectName("pushButton_a")
        self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
        self.pushButton_d.setObjectName("pushButton_d")
        self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
        self.pushButton_s.setObjectName("pushButton_s")
        self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
        self.pushButton_5.setObjectName("pushButton_5")
        # Two white image panes: original image (left) and prediction (right).
        self.label_orign = QtWidgets.QLabel(self.centralwidget)
        self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
        self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_orign.setObjectName("label_orign")
        self.label_2 = QtWidgets.QLabel(self.centralwidget)
        self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
        self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_2.setObjectName("label_2")
        self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
        self.pushButton_w_2.setObjectName("pushButton_w_2")
        # Line edit and slider placed between the background and save buttons.
        self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
        self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
        self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        self.horizontalSlider.setSliderPosition(50)
        self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
        self.horizontalSlider.setTickInterval(0)
        self.horizontalSlider.setObjectName("horizontalSlider")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and the visible text of every widget."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "Predict"))
        self.pushButton_a.setText(_translate("MainWindow", "Pre"))
        self.pushButton_d.setText(_translate("MainWindow", "Next"))
        self.pushButton_s.setText(_translate("MainWindow", "Save"))
        self.pushButton_5.setText(_translate("MainWindow", "背景图"))
        self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
        self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
        self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Main window that lets the user pick an image folder and browse its images."""

    # Lower-case suffixes (including the dot) that are treated as images.
    IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self):
        super().__init__()
        self.setupUi(self)
        self.image_path = ""
        self.image_folder = ""
        self.image_files = []
        self.current_index = 0
        self.pushButton_w_2.clicked.connect(self.open_image_folder)
        self.pushButton_a.clicked.connect(self.load_previous_image)
        self.pushButton_d.clicked.connect(self.load_next_image)

    def open_image_folder(self):
        """Ask for a folder and, if it contains images, open the picker dialog."""
        folder_dialog = QtWidgets.QFileDialog()
        folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
        if folder_path:
            self.image_folder = folder_path
            self.image_files = self.get_image_files(self.image_folder)
            if self.image_files:
                self.show_image_selection_dialog()

    def get_image_files(self, folder_path):
        """Return the sorted image file names found directly in ``folder_path``.

        Matching is case-insensitive and requires a real extension, so
        ``A.PNG`` is accepted while a name that merely ends in ``png`` is not.
        """
        return sorted(name for name in os.listdir(folder_path)
                      if name.lower().endswith(self.IMAGE_SUFFIXES))

    def show_image_selection_dialog(self):
        """Show a modal list of thumbnails; OK displays the selection and closes."""
        dialog = QtWidgets.QDialog(self)
        dialog.setWindowTitle("Select Image")
        layout = QtWidgets.QVBoxLayout()
        self.listWidget = QtWidgets.QListWidget()
        for image_file in self.image_files:
            item = QtWidgets.QListWidgetItem(image_file)
            pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
            item.setIcon(QtGui.QIcon(pixmap))
            self.listWidget.addItem(item)
        # Double-click previews the image and leaves the dialog open.
        self.listWidget.itemDoubleClicked.connect(self.image_selected)
        layout.addWidget(self.listWidget)
        buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
        # Slots run in connection order: apply the selection, then close.
        # Without the accept() connection, OK would leave the dialog open.
        buttonBox.accepted.connect(self.image_selected)
        buttonBox.accepted.connect(dialog.accept)
        buttonBox.rejected.connect(dialog.reject)
        layout.addWidget(buttonBox)
        dialog.setLayout(layout)
        dialog.exec_()

    def image_selected(self):
        """Make the list widget's current row the current image and show it."""
        selected_item = self.listWidget.currentItem()
        if selected_item:
            selected_index = self.listWidget.currentRow()
            if selected_index >= 0:
                self.current_index = selected_index
                self.show_image()

    def show_image(self):
        """Display the current image in the left pane, scaled to fit."""
        if not self.image_files:
            return
        file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
        pixmap = QtGui.QPixmap(file_path)
        if pixmap.isNull():
            # The file could not be decoded; keep the previous picture.
            self.statusbar.showMessage(f'Cannot load image: {file_path}')
            return
        self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))

    def load_previous_image(self):
        """Step to the previous image, wrapping from the first to the last."""
        if self.image_files:
            if self.current_index > 0:
                self.current_index -= 1
            else:
                self.current_index = len(self.image_files) - 1
            self.show_image()

    def load_next_image(self):
        """Step to the next image, wrapping from the last to the first."""
        if self.image_files:
            if self.current_index < len(self.image_files) - 1:
                self.current_index += 1
            else:
                self.current_index = 0
            self.show_image()
# Script entry point: build the application, show the window and hand
# control to the Qt event loop; its return code becomes the exit status.
if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

186
UI.ui Normal file
View File

@ -0,0 +1,186 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
 <class>MainWindow</class>
 <widget class="QMainWindow" name="MainWindow">
  <property name="geometry">
   <rect>
    <x>0</x>
    <y>0</y>
    <width>1140</width>
    <height>450</height>
   </rect>
  </property>
  <property name="minimumSize">
   <size>
    <width>1140</width>
    <height>450</height>
   </size>
  </property>
  <property name="maximumSize">
   <size>
    <width>1140</width>
    <height>450</height>
   </size>
  </property>
  <property name="windowTitle">
   <string>MainWindow</string>
  </property>
  <widget class="QWidget" name="centralwidget">
   <widget class="QPushButton" name="pushButton_w">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>90</y>
      <width>151</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>Predict</string>
    </property>
   </widget>
   <widget class="QPushButton" name="pushButton_a">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>160</y>
      <width>71</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>Pre</string>
    </property>
   </widget>
   <widget class="QPushButton" name="pushButton_d">
    <property name="geometry">
     <rect>
      <x>90</x>
      <y>160</y>
      <width>71</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>Next</string>
    </property>
   </widget>
   <widget class="QPushButton" name="pushButton_s">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>360</y>
      <width>151</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>Save</string>
    </property>
   </widget>
   <widget class="QPushButton" name="pushButton_5">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>230</y>
      <width>151</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>背景图</string>
    </property>
   </widget>
   <widget class="QLabel" name="label_orign">
    <property name="geometry">
     <rect>
      <x>180</x>
      <y>20</y>
      <width>471</width>
      <height>401</height>
     </rect>
    </property>
    <property name="styleSheet">
     <string notr="true">background-color: rgb(255, 255, 255);</string>
    </property>
    <property name="text">
     <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p align=&quot;center&quot;&gt;原始图像&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
    </property>
   </widget>
   <widget class="QLabel" name="label_2">
    <property name="geometry">
     <rect>
      <x>660</x>
      <y>20</y>
      <width>471</width>
      <height>401</height>
     </rect>
    </property>
    <property name="styleSheet">
     <string notr="true">background-color: rgb(255, 255, 255);</string>
    </property>
    <property name="text">
     <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p align=&quot;center&quot;&gt;预测图像&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
    </property>
   </widget>
   <widget class="QPushButton" name="pushButton_w_2">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>20</y>
      <width>151</width>
      <height>51</height>
     </rect>
    </property>
    <property name="text">
     <string>Openimg</string>
    </property>
   </widget>
   <widget class="QLineEdit" name="lineEdit">
    <property name="geometry">
     <rect>
      <x>50</x>
      <y>290</y>
      <width>81</width>
      <height>21</height>
     </rect>
    </property>
    <property name="text">
     <string>改变mask大小</string>
    </property>
   </widget>
   <widget class="QSlider" name="horizontalSlider">
    <property name="geometry">
     <rect>
      <x>10</x>
      <y>320</y>
      <width>141</width>
      <height>22</height>
     </rect>
    </property>
    <property name="sliderPosition">
     <number>50</number>
    </property>
    <property name="orientation">
     <enum>Qt::Horizontal</enum>
    </property>
    <property name="tickInterval">
     <number>0</number>
    </property>
   </widget>
  </widget>
  <widget class="QMenuBar" name="menubar">
   <property name="geometry">
    <rect>
     <x>0</x>
     <y>0</y>
     <width>1140</width>
     <height>23</height>
    </rect>
   </property>
  </widget>
  <widget class="QStatusBar" name="statusbar"/>
 </widget>
 <resources/>
 <connections/>
</ui>

Binary file not shown.

734
biao.py Normal file
View File

@ -0,0 +1,734 @@
import math
import os
import cv2
"""
标注关键点只能存在一个框和多个点并且不能删除点和删除框读取本地文件的关键点要保证其中的关键点
数和key_point_num的值是一样的本地标签中如果只存在框的信息就不要使用该脚本标注不然会出错
本地文件夹中可以有标签如果有会优先加载本地标签没有才会创建一个
关键点标注
这个是默认的标注就是普通标注
按Q切换下一张,R清除干扰再标注把之前的框屏蔽
按Y从本地把图像和标签一切删掉
按T 退出
鼠标双击可以删除框
单机就是拖拽
"""
draw_line_circle = True  # True/None: whether label_show draws the small handle squares on each box
key_point_is = None  # whether key points are annotated; None means plain YOLO box labels
# User-configurable parameters
# NOTE(review): the two original path comments read as if swapped ("folder
# the finished annotations are saved to" was on image_path, "folder holding
# the images to annotate" on label_path) - confirm which is intended.
image_path = R'C:\Users\lengdan\Desktop\data1\images\images'  # presumably the folder with the images to annotate
label_path = R'C:\Users\lengdan\Desktop\data1\images\labels'  # presumably the folder the label files are kept in
circle_distance = 10  # radius: distance() reports a hit when the mouse is closer than this to a point
key_point_num = 5  # number of key points
box_thickness = 2  # line thickness of a box
small_box_thickness = 1  # size of the handle squares on a box (passed as length_1 to draw_rect_box)
label_thickness = 1  # thickness of the class text drawn above a box
label_fontScale = 0.4  # font scale of the class text drawn above a box
key_thick = -1  # thickness used when drawing a key point circle
key_text_thick = 2  # thickness of the text drawn at a key point
key_text_scale = 0.6  # font scale of the text drawn at a key point
key_radius = 4  # radius used when drawing a key point
dot = 6  # number of decimal places to keep
key_color = {
    0: (0, 0, 200),
    1: (255, 0, 0),
    2: (0, 222, 0)
}  # key point colours, looked up by the key point's state value
key_text_color = {
    0: (0, 100, 200),
    1: (255, 0, 0),
    2: (0, 255, 125)
}  # colours for the text drawn at key points
box_color = {
    0: (125, 125, 125),
    1: (0, 255, 0),
    2: (0, 255, 0),
    3: (255, 0, 0),
    4: (0, 255, 255),
    5: (255, 255, 0),
    6: (255, 0, 255),
    7: (0, 125, 125),
    8: (125, 125, 125),
    9: (125, 125, 125),
    10: (125, 0, 125),
    11: (125, 0, 125),
    12: (125, 0, 125),
    13: (125, 0, 125),
    14: (125, 0, 125),
    15: (125, 0, 125)
}  # box colour for each class id
my_cls = {
    0: '0',
    1: '1',
    2: '10',
    3: '11',
    4: '12',
    5: '13',
    6: '14',
    7: '15',
    8: '2',
    9: '3',
    10: '4',
    11: '5',
    12: '6',
    13: '7',
    14: '8',
    15: '9'
}  # your own box label names; ids missing here fall back to str(id) in final_class
final_class = {
    i: my_cls[i] if i in my_cls else str(i) for i in range(16)
}  # display name of each class id 0..15
# Internal state - do not modify
position = None  # records which point the mouse is over, used later when moving
label = None  # label lines belonging to the image being edited
img = None  # image being edited
Mouse_move = None  # flag: a box has been selected for moving
label_index = None  # index within the labels of the box selected by the mouse
label_index_pos = None  # which of the box's handle points is selected
Mouse_insert = None  # records whether delete mode has been entered
draw_rectangle = None  # records that adding a new box has started
end_draw_rectangle = None  # records that drawing the new box has ended
append_str_temp = None  # holds the information of the newly added box
empty_label = None  # flag: whether a local label file exists
# Key-point related state
key_points = None
key_points_move = None
key_x = None  # x of each key point, recorded while key points are moved
key_y = None  # y of each key point, recorded while key points are moved
key_v = None  # state of each key point, recorded while key points are moved
key_box = None
box_move = None  # flag: the thing being moved is a box
key_insert = None  # double-clicking a key point toggles its state
move_key_point = None  # move a key point from elsewhere to this position
la_path = None
key_point_one = None  # first key pressed when moving a key point by double-click
key_point_two = None  # second key pressed when moving a key point by double-click
append_new_key_point = None  # add a second key point
append_new_key_point_index = 0  # add a second key point
window_w = None  # width of the created window
window_h = None  # height of the created window
def flag_init():
    """Reset the module-level interaction state to ``None``.

    Only the names listed below are touched; ``img``, ``label`` and every
    other module global keep their current value.
    """
    module_state = globals()
    for name in (
        'position', 'Mouse_move', 'label_index', 'label_index_pos',
        'Mouse_insert', 'draw_rectangle', 'end_draw_rectangle',
        'append_str_temp', 'empty_label',
        'key_points', 'key_points_move', 'key_x', 'key_y', 'key_v',
        'key_box', 'box_move', 'move_key_point',
        'window_w', 'window_h',
    ):
        module_state[name] = None
# Draws a small filled square.
def draw_rect_box(img, center, length_1, color=(0, 0, 255)):
    """Fill a square on ``img`` spanning ``length_1`` pixels either side of ``center``."""
    cx, cy = center[0], center[1]
    top_left = (cx - length_1, cy - length_1)
    bottom_right = (cx + length_1, cy + length_1)
    cv2.rectangle(img, top_left, bottom_right, color, thickness=-1)
# Reads an image from disk and sizes it for display.
def img_read(img_path, scale_):
    """Load ``img_path`` and resize it for the annotation window.

    When the module-level ``window_w`` is set, the image is resized to exactly
    ``(window_w, window_h)``. Otherwise it is shrunk, keeping its aspect
    ratio, so that its longer side does not exceed ``scale_``.

    Args:
        img_path: path of the image file.
        scale_: maximum size in pixels of the longer image side.

    Returns:
        The loaded, possibly resized, image array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read image: {img_path}')
    height, width = image.shape[:2]
    if window_w is not None:
        return cv2.resize(image, (window_w, window_h))
    longest = max(height, width)
    if longest > scale_:
        scale = longest / scale_
        # Clamp to one pixel so extreme aspect ratios never yield a zero size.
        image = cv2.resize(image, (max(1, int(width / scale)), max(1, int(height / scale))))
    return image
# Hit test between two points, used to tell whether the mouse has entered
# the area around one of a box's handle points.
def distance(p1, p2):
    """Return True when ``p1`` and ``p2`` are strictly closer than ``circle_distance``."""
    global circle_distance
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    return math.sqrt(dx ** 2 + dy ** 2) < circle_distance
# Draws a dashed rectangle; used instead of the solid box for a selected box.
def draw_dotted_rectangle(img, pt1, pt2, length_1=5, gap=6, thick=2, color=(100, 254, 100)):
    """Draw a dashed rectangle on ``img`` between two corner points.

    Args:
        img: image drawn on in place.
        pt1: ``(x, y)`` of the top-left corner.
        pt2: ``(x, y)`` of the bottom-right corner.
        length_1: length of each dash in pixels.
        gap: space between consecutive dashes in pixels.
        thick: line thickness.
        color: line colour.
    """
    (x1, y1), (x2, y2) = pt1, pt2
    step = length_1 + gap
    # Top and bottom edges.
    x = x1
    while x + length_1 < x2:
        cv2.line(img, (x, y1), (x + length_1, y1), color, thickness=thick)
        cv2.line(img, (x, y2), (x + length_1, y2), color, thickness=thick)
        x += step
    # Left and right edges. The right edge sits at x2, not at wherever the
    # horizontal dash counter happened to stop.
    y = y1
    while y + length_1 < y2:
        cv2.line(img, (x1, y), (x1, y + length_1), color, thickness=thick)
        cv2.line(img, (x2, y), (x2, y + length_1), color, thickness=thick)
        y += step
# Draws the labels stored on disk onto the image.
def label_show(img1, label_path, index):
    """Read the current label file and draw its boxes and key points on ``img1``.

    The file is read from the module-level ``la_path`` and its lines are
    stored in the global ``label``. Each non-blank line is parsed as
    ``class cx cy w h`` with coordinates relative to the image size,
    optionally followed by ``x y v`` triples when ``key_point_is`` is set.

    Args:
        img1: image drawn on in place.
        label_path: not used by this function; the path comes from the global
            ``la_path``. NOTE(review): confirm callers do not expect this
            argument to be honoured.
        index: position in the label list of the selected box, which is
            drawn dashed; all other boxes are drawn solid.
    """
    global small_box_thickness, box_thickness, label_fontScale, label_thickness, key_point_is, key_points, \
        key_radius, key_color, key_thick, key_text_scale, key_text_thick, key_text_color, label, draw_line_circle
    with open(la_path) as f:
        label = f.readlines()
    if len(label) == 0:
        return
    scale_y, scale_x, _ = img1.shape
    for i, line in enumerate(label):
        fields = line.split()
        if not fields:
            # Skip blank lines; enumerate keeps i aligned with the label list.
            continue
        classify = int(float(fields[0]))
        cx, cy, bw, bh = (float(value) for value in fields[1:5])
        # Convert the relative centre/size form to a pixel top-left corner and size.
        x, y = int((cx - bw / 2) * scale_x), int((cy - bh / 2) * scale_y)
        w, h = int(bw * scale_x), int(bh * scale_y)
        if i == index:
            # NOTE(review): box_thickness lands in the dash-length slot here;
            # thick=box_thickness may have been intended - confirm.
            draw_dotted_rectangle(img1, (x, y), (x + w, y + h), length_1=box_thickness)
        else:
            # Classes without a configured colour fall back to grey.
            cv2.rectangle(img1, (x, y), (x + w, y + h), box_color.get(classify, (125, 125, 125)),
                          thickness=box_thickness)
        if draw_line_circle:
            # Edge midpoints, the four corners and the box centre.
            mid_x, mid_y = int(0.5 * (x + x + w)), int(0.5 * (y + y + h))
            handles = (
                (x, mid_y), (x + w - 1, mid_y), (mid_x, y), (mid_x, y + h),
                (x, y), (x + w, y), (x + w, y + h), (x, y + h),
                (int(x + 0.5 * w), int(y + 0.5 * h)),
            )
            for handle in handles:
                draw_rect_box(img1, handle, length_1=small_box_thickness)
        cv2.putText(img1, str(final_class.get(classify, str(classify))), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX,
                    label_fontScale, (255, 0, 255), label_thickness)
        if key_point_is:
            point_fields = fields[5:]
            key_x = [float(value) for value in point_fields[::3]]
            key_y = [float(value) for value in point_fields[1::3]]
            key_v = [int(float(value)) for value in point_fields[2::3]]
            # A separate counter is used so the ``index`` argument is not
            # overwritten, and drawing targets img1 like the boxes do.
            for point_no, (px, py) in enumerate(zip(key_x, key_y)):
                cv2.circle(img1, (int(px * scale_x), int(py * scale_y)), key_radius, key_color[key_v[point_no]],
                           thickness=key_thick,
                           lineType=cv2.LINE_AA)
                cv2.putText(img1, str(point_no), (int(px * scale_x - 5), int(py * scale_y - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            key_text_scale, key_text_color[0], key_text_thick)
    key_points = None
# Mouse callback (registered with cv2.setMouseCallback in main); records mouse activity in module-level flags.
# On every event it re-reads the label file at `la_path` into the global `label`, then:
#   - on mouse move with no other mode active, checks whether the cursor is on one of a box's nine
#     handles (8 border points + center) or, in key-point mode, on a key point;
#   - while a drag is active (Mouse_move), rewrites the selected label entry from the cursor position;
#   - on double click, sets Mouse_insert / key_insert / move_key_point;
#   - on press / move / release with nothing selected, tracks a new rectangle in draw_rectangle.
# NOTE(review): `distance` is defined outside this view; it is used here as a truthy "close enough" test - confirm.
# NOTE(review): `key_x` is listed twice in the global statement, and `label_path` is declared global but the
# body reads `la_path` instead - confirm which name is intended.
def mouse_event(event, x, y, flag, param):
global label, img, position, Mouse_move, label_index, label_index_pos, dot, Mouse_insert, draw_rectangle, \
end_draw_rectangle, key_points, key_v, key_x, key_y, key_x, key_box, key_points_move, box_move, \
key_insert, label_path, move_key_point
scale_y, scale_x, _ = img.shape
# If the mouse is near one of the 8 handle points, record the current location in `position`; the main function then draws a hollow circle around the mouse
# label_index records which box the mouse selected; label_index_pos records which point of that box was selected
with open(la_path) as f:
label = f.readlines()
if move_key_point is None and key_insert is None and Mouse_insert is None and empty_label is None and event == cv2.EVENT_MOUSEMOVE and img is not None and label is not None and \
Mouse_move is None:
for i, la in enumerate(label):
la = la.strip('\n').split(' ')
if key_point_is:
key_points = list(map(float, la))[5:]
la = list(map(float, la))[0:5]
x1, y1 = int((la[1] - la[3] / 2) * scale_x), int((la[2] - la[4] / 2) * scale_y)
x2, y2 = x1 + int(la[3] * scale_x), y1 + int(la[4] * scale_y)
# Determine which point the mouse is on, to simplify the calculations when it is moved later
if distance((x, y), (x1, y1)):
label_index_pos = 0
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (x2, y2)):
label_index_pos = 1
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (x1, int(0.5 * y1 + 0.5 * y2))):
label_index_pos = 2
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (int((x1 + x2) / 2), y2)):
label_index_pos = 3
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (int((x1 + x2) / 2), y1)):
label_index_pos = 4
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (x2, int(0.5 * y1 + 0.5 * y2))):
label_index_pos = 5
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (x1, y2)):
label_index_pos = 6
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), (x2, y1)):
label_index_pos = 7
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
elif distance((x, y), ((x1 + x2) / 2, (y1 + y2) / 2)):
# Box center
label_index_pos = 8
position = (x, y)
label_index = i
box_move = True
key_points_move = None
break
else:
label_index_pos = None
position = None
label_index = None
if key_point_is:
# Check whether the mouse is on a key point
key_x = [float(i) for i in key_points[::3]]
key_y = [float(i) for i in key_points[1::3]]
key_v = [float(i) for i in key_points[2::3]] # visibility
if len(key_x) == len(key_v) and len(key_x) == len(key_y):
for index, key_ in enumerate(key_x):
if distance((x, y), (int(key_ * scale_x), int(key_y[index] * scale_y))):
position = (x, y)
label_index, label_index_pos = i, index
key_box = la
key_points_move = True
box_move = None
break
# From here to the next comment: preparation for moving an existing box
if position is not None and event == cv2.EVENT_LBUTTONDOWN:
Mouse_move = True
position = None
# First determine which point of the box was selected; moving the mouse then moves only that point
if Mouse_move and box_move:
# Record the label of the box being moved, remove it and append it at the end; the last label is rewritten repeatedly to move the box
# temp_label holds that label
temp_label = label[label_index]
label.pop(label_index)
temp_label = temp_label.strip('\n').split(' ')
temp_label = [float(i) for i in temp_label]
x_1, y_1 = (temp_label[1] - 0.5 * temp_label[3]), (temp_label[2] - 0.5 * temp_label[4])
x_2, y_2 = x_1 + temp_label[3], y_1 + temp_label[4]
# Determine which of the 8 points is being moved
if label_index_pos == 0:
x_1, y_1 = x / scale_x, y / scale_y
elif label_index_pos == 1:
x_2, y_2 = x / scale_x, y / scale_y
elif label_index_pos == 2:
x_1 = x / scale_x
elif label_index_pos == 3:
y_2 = y / scale_y
elif label_index_pos == 4:
y_1 = y / scale_y
elif label_index_pos == 5:
x_2 = x / scale_x
elif label_index_pos == 6:
x_1, y_2 = x / scale_x, y / scale_y
elif label_index_pos == 7:
y_1, x_2 = y / scale_y, x / scale_x
elif label_index_pos == 8:
x_1, y_1 = x / scale_x - (abs(temp_label[3]) / 2), y / scale_y - (abs(temp_label[4]) / 2)
x_2, y_2 = x / scale_x + (abs(temp_label[3]) / 2), y / scale_y + (abs(temp_label[4]) / 2)
# Save the moved point into the label, which gives the effect of drawing the box dynamically
temp_label[0], temp_label[1], temp_label[2], temp_label[3], temp_label[4] = str(
round((int(temp_label[0])), dot)), \
str(round(((x_1 + x_2) * 0.5), dot)), str(round(((y_1 + y_2) * 0.5), dot)), str(
round((abs(x_1 - x_2)), dot)), str(round((abs(y_1 - y_2)), dot))
temp_label = [str(i) for i in temp_label]
str_temp = ' '.join(temp_label) + '\n'
label.append(str_temp)
label_index = len(label) - 1
elif Mouse_move and key_points_move:
label.pop(label_index)
key_x[label_index_pos] = round(x / scale_x, dot)
key_y[label_index_pos] = round(y / scale_y, dot)
key_box[0] = int(key_box[0])
str_temp = ' '.join([str(j) for j in key_box])
for index, kx in enumerate(key_x):
str_temp += ' ' + str(kx) + ' ' + str(key_y[index]) + ' ' + str(int(key_v[index]))
label.append(str_temp)
label_index = len(label) - 1
if Mouse_move and event == cv2.EVENT_LBUTTONUP:
flag_init()
# This part handles deleting a box
if key_point_is is None and Mouse_insert is None and position is not None and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None:
Mouse_insert = label_index
if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None and key_points_move and box_move is None:
key_insert = label_index_pos
if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_insert is None and key_insert is None and position is None:
move_key_point = (x, y)
# This part handles adding a new box
if key_point_is is None and Mouse_insert is None and position is None and Mouse_move is None and event == cv2.EVENT_LBUTTONDOWN and end_draw_rectangle is None:
draw_rectangle = [(x, y), (x, y)]
# While the left button is still held, keep updating the position of the second point
elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_MOUSEMOVE and end_draw_rectangle is None:
draw_rectangle[1] = (x, y)
# The button was released: record the release position, so both the press and release positions are now known
# If the two positions are too close, nothing is added
elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_LBUTTONUP:
if end_draw_rectangle is None:
draw_rectangle[1] = (x, y)
if not distance(draw_rectangle[0], draw_rectangle[1]):
end_draw_rectangle = True
else:
draw_rectangle = None
def create_file_key(img_path, label_path):
    """Seed a key-point label file with a default box and evenly spaced key points.

    The label file is created if it does not exist. When key-point mode
    (module-level ``key_point_is``) is on and the file is missing, empty, or
    starts with a bare newline, it is overwritten with one line: a default
    box ``0 0.5 0.5 0.3 0.3`` followed by ``key_point_num`` triplets
    ``x 0.5 2`` whose x values are spread across the resized image width.
    """
    needs_template = None
    if not os.path.exists(label_path):
        # Create the empty file so the read below cannot fail.
        open(label_path, 'w').close()
        needs_template = True
    with open(label_path) as handle:
        existing = handle.readlines()
    if not existing or existing[0] == '\n':
        needs_template = True
    resized = img_read(img_path, 950)  # 950: target size used to resize the image
    if not (key_point_is and needs_template):
        return
    width = resized.shape[1]
    step = width // key_point_num
    triplets = []
    for k in range(key_point_num):
        triplets.append(str(round((k * step + 20) / width, dot)) + ' 0.5 2')
    with open(label_path, 'w') as handle:
        handle.write('0 0.5 0.5 0.3 0.3 ' + ' '.join(triplets))
# Interactive annotation loop for one image / label-file pair.
# Loads the label file into the global `label`, opens the 'image' window, registers `mouse_event`, and then
# loops: writes `label` back to label_path, redraws the image according to the module-level flags, and
# handles keys (all compared against upper-case codes):
#   Q  leave the loop (append_new_key_point = None)      W  delete the box selected by Mouse_insert
#   E  leave the delete-selection state                   R  flag_init(), append_new_key_point = True, leave
#   T  exit(0)                                            Y  remove the image and label files, leave
#   0/1/2 while key_insert is set: change a key point's visibility field
#   0-9, Z X C V B N after drawing a rectangle: class 0-9 or 10-15 for the new box
# NOTE(review): several branches read/write `la_path` (a module-level name) instead of the `label_path`
# parameter - confirm they always refer to the same file.
# NOTE(review): some branches assign a plain string to `label` while others call list methods on it - confirm.
def main(img_path, label_path):
global img, position, label, Mouse_insert, draw_rectangle, end_draw_rectangle, append_str_temp, empty_label, \
Mouse_move, dot, box_move, key_insert, key_point_one, key_point_two, key_x, key_y, key_v, \
move_key_point, append_new_key_point, append_new_key_point_index, window_w, window_h
# Check whether the file exists locally, or whether it is empty or holds only a newline; if so, first delete the label and add '0 0 0 0 0\n'
# Without a pre-added entry the handling is awkward, so one is added here and removed again later
# NOTE(review): the code below only creates an empty file and sets empty_label; no '0 0 0 0 0' line is written here.
if not os.path.exists(label_path):
empty_label = True
with open(label_path, 'w') as f:
pass
with open(label_path) as f:
label = f.readlines()
if len(label) == 0 or label[0] == '\n':
empty_label = True
# The 2 here shrinks the original image to one half
# NOTE(review): the comment above looks stale - the call below passes 900, not 2.
print(img_path)
img_s = img_read(img_path, 900)
if key_point_is and empty_label:
box_create = '0 0.5 0.5 0.3 0.3 '
len_t = img_s.shape[1] // key_point_num
key_num_x = [str(round((i * len_t + 20) / img_s.shape[1], dot)) + ' ' + str(0.5) + ' ' + '2' for i in
range(key_point_num)]
with open(label_path, 'w') as f:
f.write(box_create + ' '.join(key_num_x))
label = box_create + ' '.join(key_num_x)
# Create the window and bind the mouse callback to it
cv2.namedWindow('image', cv2.WINDOW_NORMAL)
_, _, window_w, window_h = cv2.getWindowImageRect('image')
cv2.resizeWindow('image', img_s.shape[1], img_s.shape[0])
cv2.setMouseCallback('image', mouse_event)
# Image refresh loop
while True:
# First read the labels, to initialise the display
# NOTE(review): despite the comment, this writes the in-memory `label` back to label_path.
with open(label_path, 'w') as f:
for i in label:
f.write(i)
# If the mouse is on one of a box's 8 points, draw a hollow circle around the mouse
if Mouse_insert is None and draw_rectangle is None and position is not None and key_insert is None:
img = img_s.copy()
label_show(img, label_path, Mouse_insert)
cv2.circle(img, position, 10, (0, 255, 100), 2)
# If a new box is being added, keep drawing the box spanned by the press point and the current mouse position
elif draw_rectangle is not None and end_draw_rectangle is None:
img = img_s.copy()
label_show(img, label_path, Mouse_insert)
cv2.rectangle(img, draw_rectangle[0], draw_rectangle[1], color=box_color[1], thickness=2)
# After the mouse is released, record the two points and prompt for the class
elif draw_rectangle is not None and end_draw_rectangle:
scale_y, scale_x, _ = img.shape
x1, y1 = draw_rectangle[0]
x2, y2 = draw_rectangle[1]
w1, h1 = abs(x2 - x1), abs(y2 - y1)
append_str_temp = str(round((x1 + x2) / 2 / scale_x, dot)) + ' ' + str(
round((y1 + y2) / 2 / scale_y, dot)) + ' ' + \
str(round((w1 / scale_x), dot)) + ' ' + str(round((h1 / scale_y), dot)) + '\n'
cv2.putText(img, 'choose your classify', (0, img.shape[0] // 2 - 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7, (255, 0, 255), 2)
cv2.putText(img, ' '.join([str(i) + ':' + my_cls[i] for i in my_cls]), (0, img.shape[0] // 2 + 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7, (100, 255, 255), 2)
elif key_insert is not None:
position, Mouse_move, box_move = None, None, None # disable other operations
cv2.putText(img, 'Switching visibility: 0 1 2', (0, img.shape[0] // 2 - 30),
cv2.FONT_HERSHEY_SIMPLEX,
1, (100, 255, 255), 2,
lineType=cv2.LINE_AA)
elif move_key_point is not None:
position, Mouse_move, box_move = None, None, None # disable other operations
cv2.putText(img, 'choose point: 0 - {}'.format(key_point_num - 1), (0, img.shape[0] // 2 - 30),
cv2.FONT_HERSHEY_SIMPLEX,
1, (100, 255, 255), 2,
lineType=cv2.LINE_AA)
# If no flag is set, just show the image normally
else:
img = img_s.copy()
if Mouse_insert is not None:
position, Mouse_move = None, None
cv2.putText(img, 'delete: W, exit: E', (0, img.shape[0] // 2 - 30),
cv2.FONT_HERSHEY_SIMPLEX,
1, (100, 255, 255), 2,
lineType=cv2.LINE_AA)
label_show(img, label_path, Mouse_insert)
cv2.imshow('image', img)
# key receives the keyboard input
key = cv2.waitKey(10)
# Q quits
if key == ord('Q'):
append_new_key_point = None
# Exit key
break
# Two digit keys are collected (tens, then units) to pick the key point to relocate
if move_key_point is not None and key_point_one is None and 48 <= key <= 57:
key_point_one = int(chr(key))
key = 0
if move_key_point is not None and key_point_two is None and 48 <= key <= 57:
key_point_two = int(chr(key))
key = 0
if (move_key_point is not None) and (key_point_one is not None) and (key_point_two is not None):
with open(la_path) as f:
label = f.readlines()
for i, la in enumerate(label):
la = la.strip('\n').split(' ')
key_points_ = list(map(float, la))[5:]
key_box_ = list(map(float, la))[0:5]
key_x_ = [float(i) for i in key_points_[::3]]
key_y_ = [float(i) for i in key_points_[1::3]]
key_v_ = [float(i) for i in key_points_[2::3]] # visibility
key_box_[0] = int(key_box_[0])
index_ = key_point_one * 10 + key_point_two
if index_ >= key_point_num:
break
key_x_[index_] = round(move_key_point[0] / img.shape[1], dot)
key_y_[index_] = round(move_key_point[1] / img.shape[0], dot)
str_temp = ' '.join([str(j) for j in key_box_])
for index, kx in enumerate(key_x_):
str_temp += ' ' + str(kx) + ' ' + str(key_y_[index]) + ' ' + str(int(key_v_[index]))
label = str_temp
with open(la_path, 'w') as f:
f.write(str_temp)
move_key_point, key_point_one, key_point_two = None, None, None
break
move_key_point, key_point_one, key_point_two = None, None, None
# If the key is W, delete the selected box
if Mouse_insert is not None and key == ord('W'):
label.pop(Mouse_insert)
Mouse_insert = None
elif key_insert is not None and key == ord('0'):
with open(label_path, 'r') as f:
label_temp = f.read()
str_temp = label_temp.split(' ')
str_temp[3 * int(key_insert) + 7] = '0'
str_temp[3 * int(key_insert) + 7 - 1] = '0'
str_temp[3 * int(key_insert) + 7 - 2] = '0'
with open(label_path, 'w') as f:
f.write(' '.join(str_temp))
label = ' '.join(str_temp)
key_insert = None
elif key_insert is not None and key == ord('1'):
with open(label_path, 'r') as f:
label_temp = f.read()
str_temp = label_temp.split(' ')
str_temp[3 * int(key_insert) + 7] = '1'
with open(label_path, 'w') as f:
f.write(' '.join(str_temp))
label = ' '.join(str_temp)
key_insert = None
elif key_insert is not None and key == ord('2'):
with open(label_path, 'r') as f:
label_temp = f.read()
str_temp = label_temp.split(' ')
str_temp[3 * int(key_insert) + 7] = '2'
with open(label_path, 'w') as f:
f.write(' '.join(str_temp))
label = ' '.join(str_temp)
key_insert = None
# If the key is E, leave the box-selected state
elif key == ord('E'):
Mouse_insert = None
# Read the class for the new box from the keyboard
elif Mouse_move is None and Mouse_insert is None and draw_rectangle is not None and end_draw_rectangle is not None \
and (48 <= key <= 57 or key == ord('Z') or key == ord('X') or key == ord('C') or key == ord('V') or key==ord('B') or key==ord('N')):
if 48 <= key <= 57:
str_temp = str(chr(key)) + ' ' + append_str_temp
elif key == ord('Z'):
str_temp = str(10) + ' ' + append_str_temp
elif key == ord('X'):
str_temp = str(11) + ' ' + append_str_temp
elif key == ord('C'):
str_temp = str(12) + ' ' + append_str_temp
elif key == ord('V'):
str_temp = str(13) + ' ' + append_str_temp
elif key == ord('B'):
str_temp = str(14) + ' ' + append_str_temp
elif key == ord('N'):
str_temp = str(15) + ' ' + append_str_temp
label.append(str_temp)
append_str_temp, draw_rectangle, end_draw_rectangle, empty_label = None, None, None, None
elif key == ord('R'):
flag_init()
append_new_key_point = True
break
elif key == ord('T'):
exit(0)
elif key == ord('Y'):
os.remove(img_path)
os.remove(label_path)
break
def delete_line_feed(label_path):
    """Strip all trailing newlines from the label file; do nothing if it is missing.

    Needed before saving, so the last line does not end with a newline.
    """
    if not os.path.exists(label_path):
        return
    with open(label_path) as src:
        trimmed = src.read().rstrip('\n')
    with open(label_path, 'w') as dst:
        dst.write(trimmed)
def append__line_feed(label_path):
    """Make the label file end with exactly one newline.

    Needed while annotating, so a new box can be appended as its own line.
    A file holding fewer than 4 characters is emptied instead.
    """
    with open(label_path) as src:
        content = src.read()
    if len(content) < 4:
        normalized = ''
    else:
        normalized = content.rstrip('\n') + '\n'
    # Writing an empty string through mode 'w' leaves the file empty.
    with open(label_path, 'w') as dst:
        dst.write(normalized)
def key_check(label_path):
    """Empty the key-point label file if any line has the wrong number of fields.

    After the 5 box fields, each line must hold exactly ``key_point_num``
    triplets (module-level setting). If the count is not a multiple of 3 or
    differs from the preset, the whole file is reset. A missing file is
    left alone.
    """
    if not os.path.exists(label_path):
        return
    with open(label_path) as src:
        rows = src.readlines()

    def _is_bad(row):
        extra = len(row.strip('\n').split(' ')) - 5
        return extra % 3 != 0 or extra // 3 != key_point_num

    if any(_is_bad(row) for row in rows):
        open(label_path, 'w').close()
def label_check(label_path):
    """Empty a plain (box-only) label file unless every line has exactly 5 fields.

    Fields are separated by single spaces. A missing file is left alone.
    """
    if not os.path.exists(label_path):
        return
    with open(label_path) as src:
        rows = src.readlines()
    malformed = any(len(row.strip('\n').split(' ')) != 5 for row in rows)
    if malformed:
        open(label_path, 'w').close()
def merge_file_key(la_path, index):
    """Merge the numbered part files ``<stem>0.txt`` .. ``<stem>{index-1}.txt`` into ``la_path``.

    Each file's content is stripped of leading/trailing newlines, the pieces
    are joined with a single newline, and the result replaces ``la_path``.
    Every part file is deleted after it has been read.

    :param la_path: path of the main label file (``<stem>.txt``).
    :param index: number of part files to merge.
    """
    # Drop only the final extension. Splitting the whole path on the first
    # '.' breaks for './labels/a.txt' or any directory containing a dot.
    stem = os.path.splitext(la_path)[0]
    with open(la_path) as f:
        pieces = [f.read().strip('\n')]
    for i in range(index):
        part_path = stem + str(i) + '.txt'
        with open(part_path) as f:
            pieces.append(f.read().strip('\n'))
        os.remove(part_path)
    with open(la_path, 'w') as f:
        f.write('\n'.join(pieces))
# Script entry point: annotate every file found in image_path, one after another.
# For each image the matching '<stem>.txt' under label_path is validated/prepared, then main() is run
# repeatedly; while main() leaves append_new_key_point set, a new numbered label file
# '<stem><append_new_key_point_index>.txt' is created for the next pass, and the numbered files are
# merged back into '<stem>.txt' afterwards by merge_file_key().
# NOTE(review): la_path is rebound to the numbered file inside the loop, so the final
# delete_line_feed(la_path) may act on that name rather than on '<stem>.txt' - confirm.
if __name__ == '__main__':
image_ = os.listdir(image_path)
for im in image_:
flag_init()
im_path = os.path.join(image_path, im)
la_path = os.path.join(label_path, im.split('.')[0] + '.txt')
if key_point_is:
key_check(la_path) # check that the local label's key-point count equals the preset count and that the values left after the 5 box fields are a multiple of 3
create_file_key(im_path, la_path)
else:
delete_line_feed(la_path)
label_check(la_path)
if os.path.exists(la_path):
# First append a newline, needed for adding boxes later
append__line_feed(la_path)
while True:
main(im_path, la_path)
if append_new_key_point is None:
break
else:
la_path = os.path.join(label_path, im.split('.')[0] + str(append_new_key_point_index) + '.txt')
with open(la_path, 'w') as f:
pass
if key_point_is:
key_check(la_path) # check that the local label's key-point count equals the preset count and that the values left after the 5 box fields are a multiple of 3
create_file_key(im_path, la_path)
else:
delete_line_feed(la_path)
label_check(la_path)
append_new_key_point_index += 1
if append_new_key_point_index != 0:
merge_file_key(os.path.join(label_path, im.split('.')[0] + '.txt'), append_new_key_point_index)
append_new_key_point_index = 0
if os.path.exists(la_path):
# Strip the newline from the last line
delete_line_feed(la_path)

40
cut.py Normal file
View File

@ -0,0 +1,40 @@
from PIL import Image
def crop_image_into_nine_parts(image_path):
    """Open an image and cut it into a 3x3 grid of equally sized tiles.

    Tile width and height are the image width and height floor-divided by 3.
    Returns the nine cropped tiles as a list in row-major order (top row
    first, left to right).
    """
    source = Image.open(image_path)
    full_width, full_height = source.size
    tile_w, tile_h = full_width // 3, full_height // 3
    return [
        source.crop((col * tile_w, row * tile_h, (col + 1) * tile_w, (row + 1) * tile_h))
        for row in range(3)
        for col in range(3)
    ]
# Path of the image to split
image_path = "scripts/input/images/725.jpg"
# Crop the image into nine tiles
nine_parts = crop_image_into_nine_parts(image_path)
# Save the tiles as part_1.jpg .. part_9.jpg in the current working directory
for idx, part in enumerate(nine_parts):
part.save(f"part_{idx + 1}.jpg")

91
detect_c.py Normal file
View File

@ -0,0 +1,91 @@
import time
from pathlib import Path
import cv2
import torch
from PIL import ImageGrab
import numpy as np
import os
import glob
from models.common import DetectMultiBackend
from utils.general import non_max_suppression, scale_boxes
from utils.plots import Annotator, colors
from utils.augmentations import letterbox
from utils.torch_utils import select_device
FILE = Path(__file__).resolve()
folder_path = R'C:\Users\lengdan\Desktop\yolov5-master\data\images' # 本地文件夹
camera = 2 # 0调用本地相机, 1检测文件夹中的图像, 2检测屏幕上内容
class mydataload:
    """Iterator yielding ``(network_input, original_frame)`` pairs.

    The frame source is selected by the module-level ``camera`` switch:
    ``0`` reads from the local camera, ``1`` walks the ``*.jpg`` files in
    ``folder_path`` (oldest first, then keeps returning the newest one),
    any other value grabs a fixed region of the screen.

    Iteration stops when the camera delivers no frame or when the folder
    contains no ``*.jpg`` file.
    """

    def __init__(self):
        self.count = 0  # index of the next file to read in folder mode
        if camera == 0:
            self.cap = cv2.VideoCapture(0)

    def __iter__(self):
        return self

    def __next__(self):
        if camera == 0:
            grabbed, im0 = self.cap.read()
            if not grabbed:
                # read() signals failure through its first return value; end
                # the iteration rather than passing an unusable frame on.
                raise StopIteration
        elif camera == 1:
            file_list = glob.glob(os.path.join(folder_path, '*.jpg'))
            if not file_list:
                # Nothing to show; indexing an empty list would raise IndexError.
                raise StopIteration
            # Sort by modification time so the most recently added image comes last.
            file_list.sort(key=os.path.getmtime)
            # The folder is re-listed on every call, so it may have shrunk.
            self.count = min(self.count, len(file_list) - 1)
            im0 = cv2.imread(file_list[self.count])
            self.count = self.count if self.count == len(file_list) - 1 else self.count + 1
        else:  # grab what is currently shown on the screen
            # Top-left and bottom-right corners of the captured region.
            x1, y1 = 1000, 100
            x2, y2 = 1900, 1000
            img = ImageGrab.grab(bbox=(x1, y1, x2, y2))
            im0 = np.array(img)
        im = letterbox(im0, 640, auto=True)[0]  # padded resize
        im = im.transpose((2, 0, 1))  # HWC to CHW
        if camera != 2:
            # Reverse the first (channel) axis: BGR to RGB. Skipped for screen grabs.
            im = im[::-1]
        im = np.ascontiguousarray(im)  # contiguous
        return im, im0
# Run the model on one preprocessed input and draw the detections.
#   model          - called as model(im, visualize=False); also provides model.names
#   im             - preprocessed input; im.shape[2:] is used as the network input size
#   im0s           - original image; boxes are rescaled to its shape and drawn on a copy of it
#   conf_thres, iou_thres - passed to non_max_suppression
#   line_thickness - line width handed to Annotator
# Returns annotator.result().
# NOTE(review): indentation was lost in this view, so it is unclear whether `return` sits inside or after
# the loop; if it is after the loop and `pred` is empty, `annotator` would be unbound - confirm.
def get_image(model, im, im0s, conf_thres=0.5, iou_thres=0.5, line_thickness=3):
# temp_list = []
pred = model(im, visualize=False)
pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
for i, det in enumerate(pred):
im0, names = im0s.copy(), model.names
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale the box coordinates from the network input size to the original image size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
for *xyxy, conf, cls in reversed(det):
c = int(cls)
# temp_list.append((xyxy[0].item(), xyxy[1].item(), xyxy[2].item(), xyxy[3].item()))
label = f'{names[c]} {conf:.2f}'
annotator.box_label(xyxy, label, color=colors(c, True))
return annotator.result()
# Script entry point: load the weights file '7_29_last.pt' on device '0', then for every frame from
# mydataload() run detection, show the annotated frame in window '1', and print the per-frame time.
# Pressing upper-case Q in the window exits.
if __name__ == "__main__":
device = select_device('0')
model = DetectMultiBackend('7_29_last.pt', device=device, dnn=False, data='', fp16=False)
dataset = mydataload()
model.warmup(imgsz=(1, 3, 640, 640)) # warmup
for im, im0s in dataset:
t0 = time.time()
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
im0 = get_image(model, im, im0s)
# Screen grabs (camera == 2) get a channel-order conversion before being shown with cv2.imshow
if camera == 2:
im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
cv2.namedWindow('1', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow('1', im0.shape[1] // 2, im0.shape[0] // 2)
cv2.imshow('1', im0)
if cv2.waitKey(1) == ord('Q'):
exit(0)
# Seconds spent on this frame
print(time.time() - t0)

169
display.py Normal file
View File

@ -0,0 +1,169 @@
import sys
import os
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Main-window UI: a column of buttons and a slider on the left, with an
    original-image label and a mask-image label on the right.

    Call ``setupUi(window)`` to build the widgets, then ``init_slots()`` to
    wire the signals. Opened images are tracked in ``self.image_files`` and
    ``self.current_index``, which exist only after an image has been opened.
    """

    def setupUi(self, MainWindow):
        """Create and place every widget on ``MainWindow``."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1333, 657)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
        self.pushButton_init.setObjectName("pushButton_init")
        self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
        self.pushButton_openimg.setObjectName("pushButton_openimg")
        self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41))
        self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg")
        self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41))
        self.pushButton_exit.setObjectName("pushButton_exit")
        self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41))
        self.pushButton_Transparency.setObjectName("pushButton_Transparency")
        self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41))
        self.pushButton_copymask.setObjectName("pushButton_copymask")
        self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41))
        self.pushButton_saveimg.setObjectName("pushButton_saveimg")
        self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
        self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22))
        self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
        self.horizontalSlider.setObjectName("horizontalSlider")
        self.horizontalSlider.setValue(50)
        self.label_Originalimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
        self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Originalimg.setObjectName("label_Originalimg")
        self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
        self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Maskimg.setObjectName("label_Maskimg")
        self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41))
        self.pushButton_shang.setObjectName("pushButton_shang")
        self.pushButton_xia = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_xia.setGeometry(QtCore.QRect(10, 210, 141, 41))
        # NOTE(review): the object name does not match the attribute name
        # (pushButton_xia); kept as-is in case something looks it up by name.
        self.pushButton_xia.setObjectName("pushButton_openimg_3")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the user-visible texts of the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
        self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
        self.pushButton_shang.setText(_translate("MainWindow", "上一张"))
        self.pushButton_xia.setText(_translate("MainWindow", "下一张"))
        self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片"))
        self.pushButton_exit.setText(_translate("MainWindow", "退出"))
        self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度"))
        self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码"))
        self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片"))
        self.label_Originalimg.setText(
            _translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_Maskimg.setText(
            _translate("MainWindow", "<html><head/><body><p align=\"center\">掩码图像</p></body></html>"))

    def init_slots(self):
        """Connect each widget signal to its handler.

        Every signal is connected exactly once: a repeated ``connect`` call
        would make the handler run once per connection on each click.
        """
        self.pushButton_openimg.clicked.connect(self.button_image_open)
        self.pushButton_init.clicked.connect(self.button_image_init)
        self.pushButton_shang.clicked.connect(self.button_image_shang)
        self.pushButton_xia.clicked.connect(self.button_image_xia)
        self.pushButton_Fusionimg.clicked.connect(self.button_image_Fusionimg)
        self.pushButton_copymask.clicked.connect(self.button_image_copymask)
        self.pushButton_saveimg.clicked.connect(self.button_image_saveimg)
        self.pushButton_exit.clicked.connect(self.button_image_exit)
        self.horizontalSlider.valueChanged.connect(self.slider_value_changed)
        # Route mouse presses on label_Originalimg to our handler.
        self.label_Originalimg.mousePressEvent = self.label_Originalimg_click_event

    def label_Originalimg_click_event(self, event):
        """Print the position of a mouse press inside label_Originalimg."""
        point = event.pos()
        x, y = point.x(), point.y()
        print("Clicked at:", x, y)  # these coordinates can be used for prediction

    def button_image_open(self):
        """Ask whether to open a folder or a single image, then show the first image.

        The question box offers Open and Cancel: Open picks a folder and loads
        every .png/.jpg/.jpeg/.bmp file in it; Cancel opens a single-file
        chooser instead.
        """
        choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
                                                QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
        if choice == QtWidgets.QMessageBox.Open:
            folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
            if folder_path:
                image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
                               if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
                if image_files:
                    self.image_files = image_files
                    self.current_index = 0
                    self.display_image()
        elif choice == QtWidgets.QMessageBox.Cancel:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
                                                                      "Image files (*.png *.jpg *.jpeg *.bmp)")
            if selected_image:
                self.image_files = [selected_image]
                self.current_index = 0
                self.display_image()

    def display_image(self):
        """Show the current image, scaled to fill label_Originalimg."""
        if hasattr(self, 'image_files') and self.image_files:
            pixmap = QtGui.QPixmap(self.image_files[self.current_index])
            self.label_Originalimg.setPixmap(pixmap)
            self.label_Originalimg.setScaledContents(True)

    def button_image_shang(self):
        """Step to the previous image, wrapping around at the start."""
        if hasattr(self, 'image_files') and self.image_files:
            self.current_index = (self.current_index - 1) % len(self.image_files)
            self.display_image()

    def button_image_xia(self):
        """Step to the next image, wrapping around at the end."""
        if hasattr(self, 'image_files') and self.image_files:
            self.current_index = (self.current_index + 1) % len(self.image_files)
            self.display_image()

    def button_image_exit(self):
        """Terminate the application."""
        sys.exit()

    def slider_value_changed(self, value):
        """Print the new slider value."""
        print("Slider value changed:", value)

    def button_image_saveimg(self):
        """Placeholder: not implemented yet."""
        pass

    def button_image_Fusionimg(self):
        """Placeholder: not implemented yet."""
        pass

    def button_image_copymask(self):
        """Placeholder: not implemented yet."""
        pass

    def button_image_init(self):
        """Placeholder: not implemented yet."""
        pass
# Script entry point: build the main window, wire its signals and start the Qt event loop.
if __name__ == "__main__":
app = QtWidgets.QApplication(sys.argv)
MainWindow = QtWidgets.QMainWindow()
ui = Ui_MainWindow()
ui.setupUi(MainWindow)
ui.init_slots() # call init_slots to connect the signals and slots
MainWindow.show()
sys.exit(app.exec_())

32
linter.sh Normal file
View File

@ -0,0 +1,32 @@
#!/bin/bash -e
# Copyright (c) Facebook, Inc. and its affiliates.
# Runs isort, black, flake8 and mypy over the current directory, after
# checking that the installed black and isort versions are the expected ones.
# Fail unless the output of `black --version` contains "23.".
{
black --version | grep -E "23\." > /dev/null
} || {
echo "Linter requires 'black==23.*' !"
exit 1
}
# Fail unless the isort version string starts with 5.12.
ISORT_VERSION=$(isort --version-number)
if [[ "$ISORT_VERSION" != 5.12* ]]; then
echo "Linter requires isort==5.12.0 !"
exit 1
fi
echo "Running isort ..."
isort . --atomic
echo "Running black ..."
black -l 100 .
echo "Running flake8 ..."
# Use the flake8 executable when one is on PATH, otherwise run it as a Python module.
if [ -x "$(command -v flake8)" ]; then
flake8 .
else
python3 -m flake8 .
fi
echo "Running mypy..."
mypy --exclude 'setup.py|notebooks' .

113
mask6.ui Normal file
View File

@ -0,0 +1,113 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>MainWindow</class>
<widget class="QMainWindow" name="MainWindow">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>889</width>
<height>600</height>
</rect>
</property>
<property name="windowTitle">
<string>MainWindow</string>
</property>
<widget class="QWidget" name="centralwidget">
<widget class="QLabel" name="label">
<property name="geometry">
<rect>
<x>10</x>
<y>120</y>
<width>421</width>
<height>331</height>
</rect>
</property>
<property name="styleSheet">
<string notr="true">background-color: rgb(255, 255, 255);</string>
</property>
<property name="text">
<string>原始图像</string>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
</widget>
<widget class="QLabel" name="label_2">
<property name="geometry">
<rect>
<x>440</x>
<y>120</y>
<width>421</width>
<height>331</height>
</rect>
</property>
<property name="styleSheet">
<string notr="true">background-color: rgb(255, 255, 255);</string>
</property>
<property name="text">
<string>预测图像</string>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
</widget>
<widget class="QPushButton" name="pushButton">
<property name="geometry">
<rect>
<x>150</x>
<y>470</y>
<width>131</width>
<height>51</height>
</rect>
</property>
<property name="text">
<string>打开图像</string>
</property>
</widget>
<widget class="QPushButton" name="pushButton_2">
<property name="geometry">
<rect>
<x>570</x>
<y>470</y>
<width>131</width>
<height>51</height>
</rect>
</property>
<property name="text">
<string>预测图像</string>
</property>
</widget>
<widget class="QTextEdit" name="textEdit">
<property name="geometry">
<rect>
<x>320</x>
<y>20</y>
<width>221</width>
<height>41</height>
</rect>
</property>
<property name="html">
<string>&lt;!DOCTYPE HTML PUBLIC &quot;-//W3C//DTD HTML 4.0//EN&quot; &quot;http://www.w3.org/TR/REC-html40/strict.dtd&quot;&gt;
&lt;html&gt;&lt;head&gt;&lt;meta name=&quot;qrichtext&quot; content=&quot;1&quot; /&gt;&lt;style type=&quot;text/css&quot;&gt;
p, li { white-space: pre-wrap; }
&lt;/style&gt;&lt;/head&gt;&lt;body style=&quot; font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;&quot;&gt;
&lt;p style=&quot; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;&quot;&gt;&lt;span style=&quot; font-size:16pt;&quot;&gt;分割图像GUI界面&lt;/span&gt;&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
</property>
</widget>
</widget>
<widget class="QMenuBar" name="menubar">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>889</width>
<height>26</height>
</rect>
</property>
</widget>
<widget class="QStatusBar" name="statusbar"/>
</widget>
<resources/>
<connections/>
</ui>

72
maskui.py Normal file
View File

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'mask6.ui'
#
# Created by: PyQt5 UI code generator 5.15.9
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_MainWindow(object):
    """Widget tree of the segmentation demo window described by 'mask6.ui'.

    The form holds two white image panels placed side by side, one push
    button below each panel and a rich-text title box at the top.
    """

    def setupUi(self, MainWindow):
        """Create every child widget on ``MainWindow`` and keep it on ``self``.

        Args:
            MainWindow: window object that receives the central widget, the
                menu bar and the status bar built here.
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(889, 600)

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central

        # Two image panels of identical size; only the x offset differs.
        for attr_name, left in (("label", 10), ("label_2", 440)):
            panel = QtWidgets.QLabel(central)
            panel.setGeometry(QtCore.QRect(left, 120, 421, 331))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setAlignment(QtCore.Qt.AlignCenter)
            panel.setObjectName(attr_name)
            setattr(self, attr_name, panel)

        # One button below each panel.
        for attr_name, left in (("pushButton", 150), ("pushButton_2", 570)):
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(left, 470, 131, 51))
            button.setObjectName(attr_name)
            setattr(self, attr_name, button)

        title_box = QtWidgets.QTextEdit(central)
        title_box.setGeometry(QtCore.QRect(320, 20, 221, 41))
        title_box.setObjectName("textEdit")
        self.textEdit = title_box

        MainWindow.setCentralWidget(central)

        menu_bar = QtWidgets.QMenuBar(MainWindow)
        menu_bar.setGeometry(QtCore.QRect(0, 0, 889, 26))
        menu_bar.setObjectName("menubar")
        self.menubar = menu_bar
        MainWindow.setMenuBar(menu_bar)

        status_bar = QtWidgets.QStatusBar(MainWindow)
        status_bar.setObjectName("statusbar")
        self.statusbar = status_bar
        MainWindow.setStatusBar(status_bar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign all user-visible texts through Qt's translation lookup.

        Args:
            MainWindow: window whose title is set; the remaining texts go to
                the widgets stored on ``self`` by ``setupUi``.
        """
        tr = QtCore.QCoreApplication.translate
        context = "MainWindow"

        MainWindow.setWindowTitle(tr(context, "MainWindow"))

        # Captions (Chinese): "original image", "predicted image",
        # "open image", "predicted image".
        captions = (
            (self.label, "原始图像"),
            (self.label_2, "预测图像"),
            (self.pushButton, "打开图像"),
            (self.pushButton_2, "预测图像"),
        )
        for widget, caption in captions:
            widget.setText(tr(context, caption))

        # Rich-text title shown in the text box at the top of the form.
        title_html = (
            "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
            "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
            "p, li { white-space: pre-wrap; }\n"
            "</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
            "<p style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:16pt;\">分割图像GUI界面</span></p></body></html>"
        )
        self.textEdit.setHtml(tr(context, title_html))
if __name__ == "__main__":
    import sys

    # Standalone preview: host the form in a plain QMainWindow and run the
    # Qt event loop until the window is closed.
    app = QtWidgets.QApplication(sys.argv)
    MainWindow = QtWidgets.QMainWindow()
    ui = Ui_MainWindow()
    ui.setupUi(MainWindow)
    MainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

426
modeltest.py Normal file
View File

@ -0,0 +1,426 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from .common import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
# 这个代码定义了 SAM 的图像编码器 ImageEncoderViT。它包含以下主要部分:
# 1. patch_embed: 这是 ViT 的 patch embedding 层,用于将输入图像划分为 patch,并获得 patch 的 embedding。
# 2. pos_embed: 这是 ViT的绝对位置 embedding,用于为每个patch提供位置信息。
# 3. blocks: 这是 ViT 的 transformer encoder 块的列表,每个块包含多头自注意力层和前馈神经网络。
# 4. neck: 这是图像编码器的“颈部”,包含几个卷积层和 LayerNorm 层,用于从 transformer encoder 块的输出中提取特征。
# 5. forward(): 这是图像编码器的前向传播过程。首先通过 patch_embed 层获得 patch embedding, 然后加上 pos_embed。
# 接着,patch embedding通过transformer encoder块。最后, neck 层从 transformer encoder 块的输出中提取特征。
# 所以,这个 ImageEncoderViT 类定义了 SAM 的图像编码器,它基于 ViT,包含 patch embedding、位置 embedding、
# transformer encoder块以及 neck, 可以从输入图像中提取特征。
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """ViT-style image encoder.

    Pipeline: patch embedding -> optional absolute positional embedding ->
    a stack of ``Block`` layers -> convolutional neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size: Side length used to size the positional embedding and
                the ``input_size`` given to every block.
            patch_size: Kernel size and stride of the patch-embedding conv.
            in_chans: Channel count of the input image.
            embed_dim: Channel count of the patch embedding.
            depth: Number of blocks.
            num_heads: Attention heads per block.
            mlp_ratio: Forwarded to each block.
            out_chans: Channel count produced by the neck.
            qkv_bias: Forwarded to each block.
            norm_layer: Forwarded to each block.
            act_layer: Forwarded to each block.
            use_abs_pos: When true, a zero-initialised learnable absolute
                positional embedding is created and added in ``forward``.
            use_rel_pos: Forwarded to each block.
            rel_pos_zero_init: Forwarded to each block.
            window_size: Window size given to blocks whose index is not in
                ``global_attn_indexes``.
            global_attn_indexes: Indexes of blocks built with window size 0.
        """
        super().__init__()
        self.img_size = img_size

        # Number of patches along one side of the image.
        grid = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(torch.zeros(1, grid, grid, embed_dim))

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=use_rel_pos,
                    rel_pos_zero_init=rel_pos_zero_init,
                    window_size=0 if layer_idx in global_attn_indexes else window_size,
                    input_size=(grid, grid),
                )
                for layer_idx in range(depth)
            ]
        )

        # 1x1 conv to out_chans, then a 3x3 conv, each followed by LayerNorm2d.
        reduce_conv = nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False)
        smooth_conv = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False)
        self.neck = nn.Sequential(
            reduce_conv,
            LayerNorm2d(out_chans),
            smooth_conv,
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the output of the neck."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for layer in self.blocks:
            x = layer(x)

        # Move the last axis to position 1 before the conv neck.
        return self.neck(x.permute(0, 3, 1, 2))
# 这个 Block 类实现了 transformer block, 可以选择使用全局注意力或局部窗口注意力,同时包含残差连接。它包含:
# __init__方法:
# 1. 输入参数:
# - dim: 输入通道数
# - num_heads: 注意力头数
# - mlp_ratio: mlp 隐藏层与输入 embedding 维度的比例
# - qkv_bias: 是否为 query、key、value 添加偏置
# - norm_layer: 归一化层
# - act_layer: 激活层
# - use_rel_pos: 是否使用相对位置 embedding
# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为 0
# - window_size: 窗口注意力的窗口大小,如果为 0 则使用全局注意力
# - input_size: 计算相对位置 embedding 大小所需的输入分辨率
# 2. 实例化第 1 次和第 2 次归一化层 norm1 和 norm2。
# 3. 实例化 Attention 层和 MLPBlock 层。Attention 层的输入大小根据是否使用窗口注意力进行了调整。
# 4. 记录窗口注意力的窗口大小 window_size。
# forward方法:
# 1. 提取 shortcut 并对 x 进行第 1 次归一化。
# 2. 如果使用窗口注意力, 则调用 window_partition 对 x 进行窗口划分。
# 3. 将 x 输入 Attention 层。
# 4. 如果使用窗口注意力,则调用 window_unpartition 对 x 进行窗口反划分。
# 5. x = shortcut + x,实现第 1 次残差连接。
# 6. x = x + mlp(norm2(x)),实现第 2 次残差连接和 MLPBlock。
# 7. 返回最终的 x。
# 所以,这个 Block 类实现了带有可选的窗口注意力和双残差连接的transformer block。
# 窗口注意力可以更好地建模局部结构,双残差连接可以提高梯度流动,都是transformer结构的重要改进。
# 这个 Block 类实现了 transformer 的关键组成部分,同时提供了窗口注意力和残差连接等重要变体,可以显著提高其表现力和泛化能力。
class Block(nn.Module):
    """Pre-norm transformer block with optional windowed attention.

    Two residual branches are applied in sequence: attention (run per window
    when ``window_size > 0``) and an MLP.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim: Channel count of the input.
            num_heads: Number of attention heads.
            mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
            qkv_bias: Forwarded to the attention layer.
            norm_layer: Factory called as ``norm_layer(dim)`` for both norms.
            act_layer: Activation class handed to the MLP block.
            use_rel_pos: Forwarded to the attention layer.
            rel_pos_zero_init: Forwarded to the attention layer.
            window_size: Window side length; 0 disables window partitioning.
            input_size: Input resolution given to the attention layer when
                ``window_size`` is 0.
        """
        super().__init__()

        # With windows, attention only ever sees window-sized inputs.
        attn_input_size = (window_size, window_size) if window_size != 0 else input_size

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP residual branches to ``x``."""
        residual = x
        x = self.norm1(x)

        windowed = self.window_size > 0
        if windowed:
            # Remember the unpadded size so the partition can be undone.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if windowed:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = residual + x
        return x + self.mlp(self.norm2(x))
# 这个Attention类实现了多头注意力机制,可以加入相对位置 embedding。它包含:
# __init__方法:
# 1. 输入参数:
# - dim: 输入通道数
# - num_heads: 注意力头数
# - qkv_bias: 是否为查询、键、值添加偏置
# - use_rel_pos: 是否使用相对位置 embedding
# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为0
# - input_size: 计算相对位置 embedding 大小所需的输入分辨率
# 2. 计算每个注意力头的维度 head_dim。
# 3. 实例化 self.qkv和 输出投影 self.proj。
# 4. 如果使用相对位置 embedding, 则初始化 rel_pos_h 和 rel_pos_w。
# forward方法:
# 1. 从输入 x 中提取批次大小 B、高度 H、宽度 W 和通道数 C。
# 2. 计算 qkv,形状为 (3, B, nHead, H * W, C), 包含 query、key 和 value。
# 3. 提取 q、 k 和 v, 形状为 (B * nHead, H * W, C)。
# 4. 计算注意力图 attn,形状为 (B * nHead, H * W, H * W)。
# 5. 如果使用相对位置 embedding, 则调用 add_decomposed_rel_pos 函数将其加入 attn。
# 6. 对 attn 进行 softmax 归一化。
# 7. 计算输出 x , (attn @ v), 形状为 (B, nHead, H, W, C), 然后合并注意力头, 形状为(B, H, W, C)。
# 8. 对 x 进行投影, 返回最终的输出。
# 所以,这个 Attention 类实现了带有相对位置 embedding 的多头注意力机制。
# 它可以高效地建模图像和视频等二维结构数据,是 transformer 在这些领域得到广泛应用的关键。
# 这个 Attention 类提供了相对位置 embedding 和多头注意力机制的实现,
# 是理解 transformer 在图像和视频建模中的重要组成部分。
class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) input, with optional
    decomposed relative position terms added to the attention scores."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim: Channel count of the input.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused q/k/v projection has a bias.
            use_rel_pos: Whether to create and use relative position tables.
            rel_pos_zero_init: Accepted for interface compatibility; it is
                not read here (the tables are always created as zeros).
            input_size: Input resolution (h, w) used to size the relative
                position tables; required when ``use_rel_pos`` is true.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        self.scale = per_head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if not self.use_rel_pos:
            return

        assert (
            input_size is not None
        ), "Input size must be provided if using relative positional encoding."
        # One table per axis with 2 * size - 1 rows.
        height, width = input_size[0], input_size[1]
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * height - 1, per_head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * width - 1, per_head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attended and projected tensor, shaped like ``x``."""
        B, H, W, _ = x.shape
        tokens = H * W
        heads = self.num_heads

        # Fused projection, rearranged to (3, B, nHead, H * W, C).
        qkv = self.qkv(x).reshape(B, tokens, 3, heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # q, k, v each with shape (B * nHead, H * W, C).
        q, k, v = qkv.reshape(3, B * heads, tokens, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        attn = attn.softmax(dim=-1)

        # Merge the heads back into the channel axis.
        merged = attn @ v
        merged = merged.view(B, heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(merged)
# 这个 window_partition 函数的作用是将输入张量划分为非重叠的窗口。它包含:
# 1. 输入参数:
# - x: 输入的张量,形状为 [B, H, W, C]
# - window_size: 窗口大小
# 2. 首先计算输入需要 padding 的高度和宽度,将x进行padding。
# 3. 然后将 x 的形状变化为 [B, Hp//window_size, window_size, Wp//window_size, window_size, C],
# 表示将图像划分为 Hp//window_size * Wp//window_size 个 window_size * window_size 的 patch。
# 4. 最后,通过 permute 和 view 操作,得到 windows 的形状为 [B * num_windows, window_size, window_size, C],
# 表示将所有 patch 打平, num_windows 是 patch 的总数
# 5. 返回windows和原来的高度和宽度(包含padding)Hp和Wp。
# 所以,这个 window_partition 函数的作用是,将输入的特征图(必要时先在右侧和下方补零)划分为互不重叠的
# window_size * window_size 窗口, 并把所有窗口沿 batch 维堆叠, 输出形状为 [B * num_windows, window_size, window_size, C]。
# 注意它并不会把图像展平成一维 token 序列: 每个窗口仍保持二维结构,
# 之后 Block 会在每个窗口内部独立计算注意力, 再通过 window_unpartition 还原为原来的形状。
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Split ``x`` into non-overlapping square windows, padding if needed.

    Args:
        x: Tensor with shape [B, H, W, C].
        window_size: Side length of each window.

    Returns:
        windows: Tensor with shape [B * num_windows, window_size, window_size, C].
        (Hp, Wp): Height and width after padding.
    """
    B, H, W, C = x.shape

    # Amount needed to reach the next multiple of window_size (0 if aligned).
    pad_h = (-H) % window_size
    pad_w = (-W) % window_size
    if pad_h or pad_w:
        # F.pad lists pads from the last axis backwards: C, then W, then H.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    rows, cols = Hp // window_size, Wp // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
# 这个 window_unpartition 函数的作用是将 window_partition 函数的输出进行反划分, 恢复成原始的图像形状。它包含:
# 1. 输入参数:
# - windows: window_partition的输出,形状为 [B * num_windows, window_size, window_size, C]
# - window_size: 窗口大小
# - pad_hw: padding后的高度和宽度 (Hp, Wp)
# - hw: padding前的原始高度和宽度 (H, W)
# 2. 首先根据窗口大小和 padding 后的 hw 计算原始的 batch_size B。
# 3. 然后将 windows 的形状变回 [B, Hp//window_size, Wp//window_size, window_size, window_size, C], 表示每个patch的位置。
# 4. 接着通过permute和view操作,得到x的形状为 [B, Hp, Wp, C], 恢复成图像的形状。
# 5. 最后,如果进行了padding,则截取x到原始的高度H和宽度W。
# 6. 返回恢复后的图像x。
# 所以,这个 window_unpartition 函数的作用是将通过 window_partition 函数得到的 patch 序列恢复成原始的图像。
# 它实现了从一维 patch token 序列到二维图像的反过程。
# 这个函数与 window_partition 函数相反,使得 transformer 能够最终从 patch token 序列恢复成图像,完成对图像的建模。
# 总的来说,这个 window_unpartition 函数实现了从 patch token 序列恢复成原始图像的过程,与 window_partition 函数相对应,
# 是使得 transformer 可以处理图像的另一个关键步骤
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """Reassemble windows into [B, H, W, C] and crop away any padding.

    Args:
        windows: Tensor with shape [B * num_windows, window_size, window_size, C].
        window_size: Side length of each window.
        pad_hw: Padded height and width (Hp, Wp).
        hw: Height and width (H, W) before padding.

    Returns:
        Tensor with shape [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image

    grid = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if not (Hp > H or Wp > W):
        return x
    # Drop the rows/columns beyond the unpadded size.
    return x[:, :H, :W, :].contiguous()
# 这个 get_rel_pos 函数的作用是根据 query 和 key 的相对位置获取相对位置 embedding。它包含:
# 1. 输入参数:
# - q_size: query 的大小
# - k_size: key 的大小
# - rel_pos: 相对位置 embedding, 形状为[L, C]
# 2. 首先计算最大的相对距离 max_rel_dist, 它等于 query 和 key 大小中较大者的 2 倍减 1。
# 3. 如果相对位置 embedding 的长度不等于 max_rel_dist, 则通过线性插值将其调整到 max_rel_dist 的长度。
# 4. 如果 q_size 和 k_size 不同, 则将 q_size 和 k_size 的坐标按比例缩放,使它们之间的相对距离保持不变。
# 5. 根据调整后的 q_size 和 k_size 坐标计算相对坐标 relative_coords。
# 6. 根据 relative_coords 从 rel_pos_resized 中提取相对位置 embedding。
# 7. 返回提取出的相对位置 embedding。
# 所以,这个 get_rel_pos 函数的主要作用是,当 query 和 key 的大小不同时,根据它们的相对位置关系提取相应的相对位置 embedding。
# 它实现了相对位置 embedding 的可变长度和可缩放性。
# 这个函数使得相对位置 embedding 可以用于 query 和 key 大小不同的 attention 中,是相对位置表示的一个关键步骤。
# 总的来说,这个 get_rel_pos 函数实现了根据 query 和 key 的相对位置关系提取相应相对位置 embedding 的过程。
# 它提供了相对位置 embedding 的可变长度和可缩放性,使其可以支持不同的 query 和 key 大小,从而应用到更加灵活的 attention 机制中。
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Look up relative position embeddings for every (query, key) index pair.

    Args:
        q_size: Number of query positions along the axis.
        k_size: Number of key positions along the axis.
        rel_pos: Relative position table with shape (L, C).

    Returns:
        Tensor with shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Linearly resample the table along its length to max_rel_dist rows.
        as_channels = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resampled = F.interpolate(as_channels, size=max_rel_dist, mode="linear")
        table = resampled.reshape(-1, max_rel_dist).permute(1, 0)

    # Scale the coords with short length if shapes for q and k are different.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # The added offset makes the smallest index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[relative_coords.long()]
# 这个 add_decomposed_rel_pos 函数的作用是根据 query q 和 key k 的空间尺寸, 添加分解的相对位置 embedding 到注意力图 attn 中。它包含:
# 1. 输入参数:
# - attn: 注意力图,形状为 [B, q_h * q_w, k_h * k_w]
# - q: 查询 q,形状为 [B, q_h * q_w, C]
# - rel_pos_h: 高度轴的相对位置 embedding, 形状为[Lh, C]
# - rel_pos_w: 宽度轴的相对位置 embedding, 形状为[Lw, C]
# - q_size: 查询 q的空间尺寸 (q_h, q_w)
# - k_size: 键 k的空间尺寸 (k_h, k_w)
# 2. 从 q_size 和 k_size 中提取高度 q_h、宽度 q_w 以及高度 k_h、宽度 k_w。
# 3. 调用 get_rel_pos 函数获取高度轴 Rh 和宽度轴 Rw 的相对位置 embedding。
# 4. 重塑 q 为 [B, q_h, q_w, C]。
# 5. 计算高度轴 rel_h 和宽度轴 rel_w 的相对位置图, 形状为 [B, q_h, q_w, k_h] 和 [B, q_h, q_w, k_w]。
# 6. 将 attn 的形状变为 [B, q_h, q_w, k_h, k_w], 并加上 rel_h 和 rel_w。
# 7. 将 attn 的形状变回 [B, q_h * q_w, k_h * k_w]。
# 8. 返回加了相对位置 embedding 的 attn。
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """Add per-axis relative position terms to the attention scores.

    Args:
        attn: Attention scores with shape (B, q_h * q_w, k_h * k_w).
        q: Query tensor with shape (B, q_h * q_w, C).
        rel_pos_h: Relative position table for the height axis, shape (Lh, C).
        rel_pos_w: Relative position table for the width axis, shape (Lw, C).
        q_size: Spatial size of the query, (q_h, q_w).
        k_size: Spatial size of the key, (k_h, k_w).

    Returns:
        Attention scores with the relative position terms added, same shape
        as ``attn``.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size

    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    q_grid = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", q_grid, height_table)
    rel_w = torch.einsum("bhwc,wkc->bhwk", q_grid, width_table)

    # Broadcast the height term over key columns and the width term over key rows.
    scores = attn.view(B, q_h, q_w, k_h, k_w)
    scores = scores + rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)
    return scores.view(B, q_h * q_w, k_h * k_w)
# 这个 PatchEmbed 类定义了 ViT 的 patch embedding 层。它包含:
# 1. __init__: 初始化,设置卷积层的 kernel size、stride、padding 以及输入通道数和 embedding 维度。
# 2. proj: 这是一个卷积层,用于将输入图像划分为 patch, 并获得每个 patch 的 embedding。
# 3. forward: 前向传播过程。首先通过 proj 卷积层获得 patch embedding ,然后将维度从 [B, C, H, W] 转置成 [B, H, W, C]。
class PatchEmbed(nn.Module):
    """Patch embedding implemented as a single 2-D convolution."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size: Kernel size of the projection conv.
            stride: Stride of the projection conv.
            padding: Padding of the projection conv.
            in_chans: Channel count of the input image.
            embed_dim: Channel count of the produced embedding.
        """
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return it channels-last (B C H W -> B H W C)."""
        return self.proj(x).permute(0, 2, 3, 1)

197
predict_mask.py Normal file
View File

@ -0,0 +1,197 @@
import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# --- Paths and options --------------------------------------------------------
input_dir = 'scripts/input/images'
output_dir = 'scripts/output/mask'
crop_mode = True  # forwarded to save_masked_image as crop_mode_
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720

# Usage hint (Chinese): best to press "w" to predict after every added point.
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# Keep the directory entries whose lower-cased name ends with a listed extension.
image_files = [
    entry
    for entry in os.listdir(input_dir)
    if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))
]

# --- Model --------------------------------------------------------------------
sam = sam_model_registry["vit_b"](
    checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
)
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# --- Display window -----------------------------------------------------------
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Presumably centres the window on a 1920x1080 desktop -- TODO confirm.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records prompt points.

    A left click stores (x, y) with label 1, a right click with label 0.
    While ``input_stop`` is set, clicks are rejected with a console message.

    Args:
        event: OpenCV mouse event code.
        x: Click x coordinate.
        y: Click y coordinate.
        flags: Unused.
        param: Unused.
    """
    global input_point, input_label, input_stop

    left_click = event == cv2.EVENT_LBUTTONDOWN
    right_click = event == cv2.EVENT_RBUTTONDOWN
    if not (left_click or right_click):
        return

    if input_stop:
        # Message (Chinese): points cannot be added now; press w to leave
        # mask-selection mode.
        print('此时不能添加点,按w退出mask选择模式')
        return

    input_point.append([x, y])
    input_label.append(1 if left_click else 0)
def apply_mask(image, mask, alpha_channel=True):
    """Restrict ``image`` to the pixels where ``mask`` equals 1.

    Args:
        image: Array with at least three channels on its last axis.
        mask: Array compared against 1 to select pixels.
        alpha_channel: When true, the first three channels are kept and a
            fourth channel is appended that is 255 inside the mask and 0
            elsewhere. When false, pixels outside the mask are set to 0.

    Returns:
        The masked image as a new array.
    """
    if not alpha_channel:
        return np.where(mask[..., None] == 1, image, 0)

    alpha = np.zeros_like(image[..., 0])
    alpha[mask == 1] = 255
    first, second, third = image[..., 0], image[..., 1], image[..., 2]
    return cv2.merge((first, second, third, alpha))
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Blend ``color`` into ``image`` where ``mask`` equals 1, in place.

    Args:
        image: Array with at least three channels; it is modified in place.
        mask: Array compared against 1 to select pixels.
        color: Sequence with one value for each of the first three channels.
        color_dark: Weight of ``color`` in the blend; the image keeps
            ``1 - color_dark``.

    Returns:
        The same ``image`` object that was passed in.
    """
    inside = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - color_dark) + color_dark * color[channel]
        image[:, :, channel] = np.where(inside, tinted, plane)
    return image
def get_next_filename(base_path, filename):
    """Return the first unused ``<name>_<i><ext>`` in ``base_path``.

    Args:
        base_path: Directory checked for existing files.
        filename: File name whose stem and extension form the candidates.

    Returns:
        The first candidate with ``i`` in 1..100 that does not exist yet, or
        ``None`` when all 100 candidates are taken.
    """
    name, ext = os.path.splitext(filename)
    candidates = (f"{name}_{i}{ext}" for i in range(1, 101))
    unused = (c for c in candidates if not os.path.exists(os.path.join(base_path, c)))
    return next(unused, None)
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write the masked image to ``output_dir`` as a numbered PNG.

    Args:
        image: Source image array.
        mask: Mask array for ``image``.
        output_dir: Directory that receives the PNG.
        filename: Source file name; its extension is replaced by ``.png`` and
            a ``_<i>`` suffix is added by ``get_next_filename``.
        crop_mode_: When true, image and mask are cropped to the bounding box
            of the non-zero mask pixels before masking.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    if crop_mode_:
        ys, xs = np.where(mask)
        if ys.size == 0:
            # An empty mask has no bounding box; min()/max() would raise.
            print("Could not save the image. The mask is empty.")
            return
        y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext keeps the whole name when there is no extension, whereas
    # slicing at rfind('.') would drop the last character in that case.
    filename = os.path.splitext(filename)[0] + '.png'
    new_filename = get_next_filename(output_dir, filename)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    target = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        written = cv2.imwrite(target, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        written = cv2.imwrite(target, masked_image)

    # cv2.imwrite reports failure through its return value, not an exception.
    if written:
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save the image. Writing {target} failed.")
# Interactive state; mouse_click mutates the point/label lists through globals.
current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False

if not image_files:
    print(f"No images found in {input_dir}")

# Outer loop: one iteration per displayed image.
# Keys in point mode: space = clear, w = predict, a/d = previous/next image,
#                     q = remove last point, s = save, Esc = quit.
while image_files:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None for unreadable files; drop the entry
        # instead of crashing on None.copy().
        print(f"Could not read {filename}, skipping it")
        image_files.pop(current_index)
        current_index = max(0, min(current_index, len(image_files) - 1))
        continue
    image_crop = image_orign.copy()
    image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
    selected_mask = None
    logit_input = None
    # The predictor only needs the image once; later predictions reuse it.
    image_embedded = False
    while True:
        input_stop = False
        image_display = image_orign.copy()
        display_info = f'{filename} '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        if selected_mask is not None:
            color = tuple(np.random.randint(0, 256, 3).tolist())
            # apply_color_mask tints image_display in place.
            selected_image = apply_color_mask(image_display, selected_mask, color)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            # Block new clicks while the user chooses among candidate masks.
            input_stop = True
            if len(input_point) > 0 and len(input_label) > 0:
                if not image_embedded:
                    predictor.set_image(image)
                    image_embedded = True
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                # Mask-selection mode: a/d cycle masks, s saves, q removes the
                # last point, w accepts, space clears everything.
                while (1):
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 | q 移除最后一个 | s 保存'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow("image", selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
                    elif key == ord('a'):
                        if mask_idx > 0:
                            mask_idx -= 1
                        else:
                            mask_idx = num_masks - 1
                    elif key == ord('d'):
                        if mask_idx < num_masks - 1:
                            mask_idx += 1
                        else:
                            mask_idx = 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        break
                # Keep the chosen logits only when a mask is still selected;
                # after a space reset the stale logits must not come back.
                if selected_mask is not None:
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
    if key == 27:
        break

cv2.destroyAllWindows()

200
salt/GUI.py Normal file
View File

@ -0,0 +1,200 @@
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import QTimer
import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# --- Paths and options --------------------------------------------------------
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
crop_mode = True
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720

# --- Model --------------------------------------------------------------------
sam = sam_model_registry["vit_b"](
    checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
)
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)
class Ui_MainWindow(object):
    """Main window of the point-prompt GUI.

    The left column holds an "open image" button and five key buttons
    (w, a, d, s, q). Two white panels show the original and the predicted
    image. Prompt points are collected by a mouse callback on a separate
    OpenCV window named "image".
    """

    def setupUi(self, MainWindow):
        """Build the widgets, wire the signals and initialise the state.

        Args:
            MainWindow: window object that receives the central widget, the
                menu bar and the status bar created here.
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        # Column of key buttons, 70 px apart.
        for suffix, top in (("w", 90), ("a", 160), ("d", 230), ("s", 300), ("q", 370)):
            button = QtWidgets.QPushButton(self.centralwidget)
            button.setGeometry(QtCore.QRect(10, top, 151, 51))
            button.setObjectName(f"pushButton_{suffix}")
            setattr(self, f"pushButton_{suffix}", button)

        self.label_orign = QtWidgets.QLabel(self.centralwidget)
        self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
        self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_orign.setObjectName("label_orign")
        self.label_pre = QtWidgets.QLabel(self.centralwidget)
        self.label_pre.setGeometry(QtCore.QRect(660, 20, 471, 401))
        self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_pre.setObjectName("label_pre")
        self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
        self.pushButton_opimg.setObjectName("pushButton_opimg")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

        self.pushButton_opimg.clicked.connect(self.open_image)
        self.pushButton_w.clicked.connect(self.predict_image)

        # Periodically redraw the original-image panel with the current points.
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_original_image)
        self.timer.start(100)  # Update every 100 milliseconds

        self.image_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
        self.current_index = 0
        self.input_point = []
        self.input_label = []
        self.input_stop = False

        # Clicks are collected on a separate OpenCV window.
        cv2.namedWindow("image")
        cv2.setMouseCallback("image", self.mouse_click)

    def retranslateUi(self, MainWindow):
        """Assign all user-visible texts through Qt's translation lookup."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "w"))
        self.pushButton_a.setText(_translate("MainWindow", "a"))
        self.pushButton_d.setText(_translate("MainWindow", "d"))
        self.pushButton_s.setText(_translate("MainWindow", "s"))
        self.pushButton_q.setText(_translate("MainWindow", "q"))
        self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_pre.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
        self.pushButton_opimg.setText(_translate("MainWindow", "打开图像"))

    def open_image(self):
        """Let the user pick an image and show it in the original-image panel.

        Selecting a file also resets ``current_index`` to 0 and clears the
        collected points.
        """
        # NOTE(review): update_original_image runs every 100 ms and redraws the
        # panel from image_files[current_index], so the file chosen here is
        # replaced on the next timer tick whenever image_files is non-empty.
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, "Open Image File", "", "Image files (*.jpg *.png)")
        if filename:
            self.current_index = 0
            pixmap = QPixmap(filename)
            pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
            self.label_orign.setPixmap(pixmap)
            self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
            self.input_point = []
            self.input_label = []

    def predict_image(self):
        """Show the current image with the collected points in the OpenCV window.

        No model prediction is run here; the method only draws the prompt
        points (green for label 1, red otherwise) and displays the result.
        """
        if self.current_index >= len(self.image_files):
            return
        filename = self.image_files[self.current_index]
        image_orign = cv2.imread(os.path.join(input_dir, filename))
        if image_orign is None:
            # cv2.imread returns None for unreadable files.
            print(f"Could not read image: {filename}")
            return
        # cv2.imshow expects BGR data, so draw on the BGR image as loaded
        # instead of an RGB-converted copy (which showed swapped colours).
        image_display = image_orign.copy()
        for point, label in zip(self.input_point, self.input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)

    def update_original_image(self):
        """Timer slot: redraw the original-image panel with the current points."""
        if self.current_index >= len(self.image_files):
            return
        filename = self.image_files[self.current_index]
        image_orign = cv2.imread(os.path.join(input_dir, filename))
        if image_orign is None:
            # Unreadable file: leave the panel unchanged. Nothing is printed
            # because this slot fires every 100 ms.
            return
        image_display = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
        for point, label in zip(self.input_point, self.input_label):
            # The array is RGB here, so red is (255, 0, 0); this matches the
            # red negative points shown in the OpenCV window.
            color = (0, 255, 0) if label == 1 else (255, 0, 0)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        height, width, channel = image_display.shape
        bytesPerLine = 3 * width
        qImg = QImage(image_display.data, width, height, bytesPerLine, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qImg)
        pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
        self.label_orign.setPixmap(pixmap)
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)

    def mouse_click(self, event, x, y, flags, param):
        """OpenCV mouse callback: left click adds a label-1 point, right click
        a label-0 point; clicks are rejected while ``input_stop`` is set."""
        if not self.input_stop:
            if event == cv2.EVENT_LBUTTONDOWN:
                self.input_point.append([x, y])
                self.input_label.append(1)
            elif event == cv2.EVENT_RBUTTONDOWN:
                self.input_point.append([x, y])
                self.input_label.append(0)
        else:
            if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
                # Message (Chinese): points cannot be added now; press w to
                # leave mask-selection mode.
                print('此时不能添加点,按w退出mask选择模式')
def apply_mask(image, mask, alpha_channel=True):
    """Restrict ``image`` to the pixels where ``mask`` equals 1.

    Args:
        image: Array with at least three channels on its last axis.
        mask: Array compared against 1 to select pixels.
        alpha_channel: When true, the first three channels are kept and a
            fourth channel is appended that is 255 inside the mask and 0
            elsewhere. When false, pixels outside the mask are set to 0.

    Returns:
        The masked image as a new array.
    """
    if not alpha_channel:
        return np.where(mask[..., None] == 1, image, 0)

    alpha = np.zeros_like(image[..., 0])
    alpha[mask == 1] = 255
    first, second, third = image[..., 0], image[..., 1], image[..., 2]
    return cv2.merge((first, second, third, alpha))
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Blend ``color`` into ``image`` where ``mask`` equals 1, in place.

    Args:
        image: Array with at least three channels; it is modified in place.
        mask: Array compared against 1 to select pixels.
        color: Sequence with one value for each of the first three channels.
        color_dark: Weight of ``color`` in the blend; the image keeps
            ``1 - color_dark``.

    Returns:
        The same ``image`` object that was passed in.
    """
    inside = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - color_dark) + color_dark * color[channel]
        image[:, :, channel] = np.where(inside, tinted, plane)
    return image
def get_next_filename(base_path, filename):
    """Return the first unused ``<name>_<i><ext>`` in ``base_path``.

    Args:
        base_path: Directory checked for existing files.
        filename: File name whose stem and extension form the candidates.

    Returns:
        The first candidate with ``i`` in 1..100 that does not exist yet, or
        ``None`` when all 100 candidates are taken.
    """
    name, ext = os.path.splitext(filename)
    candidates = (f"{name}_{i}{ext}" for i in range(1, 101))
    unused = (c for c in candidates if not os.path.exists(os.path.join(base_path, c)))
    return next(unused, None)
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write the masked image to ``output_dir`` as a numbered PNG.

    Args:
        image: Source image array.
        mask: Mask array for ``image``.
        output_dir: Directory that receives the PNG.
        filename: Source file name; its extension is replaced by ``.png`` and
            a ``_<i>`` suffix is added by ``get_next_filename``.
        crop_mode_: When true, image and mask are cropped to the bounding box
            of the non-zero mask pixels before masking.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    if crop_mode_:
        ys, xs = np.where(mask)
        if ys.size == 0:
            # An empty mask has no bounding box; min()/max() would raise.
            print("Could not save the image. The mask is empty.")
            return
        y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext keeps the whole name when there is no extension, whereas
    # slicing at rfind('.') would drop the last character in that case.
    filename = os.path.splitext(filename)[0] + '.png'
    new_filename = get_next_filename(output_dir, filename)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    target = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        written = cv2.imwrite(target, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        written = cv2.imwrite(target, masked_image)

    # cv2.imwrite reports failure through its return value, not an exception.
    if written:
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save the image. Writing {target} failed.")
if __name__ == "__main__":
    # Host the form in a plain QMainWindow and run the Qt event loop until
    # the window is closed.
    app = QtWidgets.QApplication(sys.argv)
    MainWindow = QtWidgets.QMainWindow()
    ui = Ui_MainWindow()
    ui.setupUi(MainWindow)
    MainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

236
salt/SAM_JSON_多类别.py Normal file
View File

@ -0,0 +1,236 @@
import cv2
import os
import numpy as np
import json
from segment_anything import sam_model_registry, SamPredictor
# --- Paths and options --------------------------------------------------------
input_dir = r'C:\Users\t2581\Desktop\222\images'
output_dir = r'C:\Users\t2581\Desktop\222\2'
crop_mode = True
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720

# Category table: id -> name, and id -> overlay colour.
categories = {1: "category1", 2: "category2", 3: "category3"}
category_colors = {1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 0, 0)}
current_label = 1  # default category

# Usage hint (Chinese): best to press "w" to predict after every added point.
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# Keep the directory entries whose lower-cased name ends with a listed extension.
image_files = [
    entry
    for entry in os.listdir(input_dir)
    if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))
]

# --- Model --------------------------------------------------------------------
sam = sam_model_registry["vit_b"](
    checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
)
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# --- Display window -----------------------------------------------------------
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Presumably centres the window on a 1920x1080 desktop -- TODO confirm.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records prompt points.

    A left click stores (x, y) with the module-level ``current_label``, a
    right click stores it with label 0. While ``input_stop`` is set, clicks
    are rejected with a console message.

    Args:
        event: OpenCV mouse event code.
        x: Click x coordinate.
        y: Click y coordinate.
        flags: Unused.
        param: Unused.
    """
    global input_points, input_labels, input_stop

    left_click = event == cv2.EVENT_LBUTTONDOWN
    right_click = event == cv2.EVENT_RBUTTONDOWN
    if not (left_click or right_click):
        return

    if input_stop:
        # Message (Chinese): points cannot be added now; press w to leave
        # mask-selection mode.
        print('此时不能添加点,按w退出mask选择模式')
        return

    input_points.append([x, y])
    input_labels.append(current_label if left_click else 0)
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Return a copy of ``image`` with ``color`` blended in where ``mask`` is 1.

    Args:
        image: Array with at least three channels; it is left untouched.
        mask: Array compared against 1 to select pixels.
        color: Sequence with one value for each of the first three channels.
        color_dark: Weight of ``color`` in the blend; the image keeps
            ``1 - color_dark``.

    Returns:
        A new array with the tint applied.
    """
    masked_image = image.copy()
    inside = mask == 1
    for channel in range(3):
        source = image[:, :, channel]
        tinted = source * (1 - color_dark) + color_dark * color[channel]
        masked_image[:, :, channel] = np.where(inside, tinted, source)
    return masked_image
def draw_external_rectangle(image, mask, pv):
    """Draw a labelled bounding rectangle around each outer contour of ``mask``.

    Args:
        image: Image drawn on in place.
        mask: Mask array; it is cast to uint8 before contour extraction.
        pv: Text written just above each rectangle.
    """
    outlines, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for left, top, width, height in map(cv2.boundingRect, outlines):
        top_left = (left, top)
        bottom_right = (left + width, top + height)
        cv2.rectangle(image, top_left, bottom_right, (0, 255, 255), 2)  # Yellow rectangle
        cv2.putText(image, pv, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
def save_masked_image_and_json(image, masks, output_dir, filename, crop_mode_, pv):
    """Write the colour overlay and a polygon JSON file for one image.

    For every ``(mask, label, score)`` entry the mask is tinted with its
    category colour, boxed and captioned on a copy of ``image``, and its outer
    contours are stored as polygons. Two files are written to ``output_dir``:
    ``<stem>_masked.png`` and ``<stem>_masked.json`` (the JSON keys follow the
    LabelMe annotation layout).

    Args:
        image: Source image; not modified.
        masks: Iterable of ``(mask, label, score)`` tuples.
        output_dir: Destination folder.
        filename: Name of the source image; its extension is replaced.
        crop_mode_: Unused; kept for call compatibility.
        pv: Unused; kept for call compatibility.
    """
    # splitext leaves a name without an extension intact, whereas slicing at
    # rfind('.') == -1 would silently drop its last character.
    stem = os.path.splitext(filename)[0]
    # Resolve the colour from the category name rather than from the last
    # character of the label, which only works for names ending in a digit.
    label_to_id = {name: key for key, name in categories.items()}
    fallback_color = (0, 255, 255)

    masked_image = image.copy()
    json_shapes = []
    for mask, label, score in masks:
        color = category_colors.get(label_to_id.get(label), fallback_color)
        masked_image = apply_color_mask(masked_image, mask, color)
        draw_external_rectangle(masked_image, mask, f"{label}: {score:.2f}")
        # Convert the mask to polygons, one per outer contour.
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        json_shapes.extend(
            {
                "label": label,
                "points": contour.reshape(-1, 2).tolist(),
                "group_id": None,
                "shape_type": "polygon",
                "flags": {}
            }
            for contour in contours
        )

    masked_filename = stem + '_masked.png'
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image):
        print(f"Saved image as {masked_filename}")
    else:
        print(f"Failed to save image as {masked_filename}")

    json_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": json_shapes,
        "imagePath": filename,
        "imageData": None,
        "imageHeight": image.shape[0],
        "imageWidth": image.shape[1]
    }
    json_filename = stem + '_masked.json'
    with open(os.path.join(output_dir, json_filename), 'w', encoding='utf-8') as json_file:
        json.dump(json_data, json_file, indent=4)
    print(f"Saved JSON as {json_filename}")
# ---------------------------------------------------------------------------
# Interactive annotation loop.
#
# Point-entry mode (display loop):
#   left / right click  add a foreground / background point
#   w       run SAM on the current points and enter mask-selection mode
#   space   clear the points and the current selection
#   q       remove the most recent point
#   1-3     switch the category given to the next accepted masks
#   r       keep the accepted masks and start a new object
#   s       write the overlay image and JSON for every kept mask
#   a / d   previous / next image
#   Esc     quit
# Mask-selection mode (after w):
#   a / d   cycle through the predicted masks
#   s       accept the displayed mask under the current category
#   q       remove the most recent point
#   w       leave selection; the chosen mask's logits seed the next prediction
#   space   clear the points and leave selection
# ---------------------------------------------------------------------------
current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_points = []
input_labels = []
input_stop = False
masks = []      # masks accepted with 's' for the object being annotated
all_masks = []  # masks kept with 'r', waiting to be written by 's'
category_keys = {ord(str(number)): number for number in categories}
key = -1

if not image_files:
    print(f"No images found in {input_dir}")

while image_files:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread signals an unreadable file by returning None.
        print(f"Could not read {filename}, skipping it")
        image_files.pop(current_index)
        current_index = max(0, min(current_index, len(image_files) - 1))
        continue
    image_crop = image_orign.copy()
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)
    selected_mask = None
    logit_input = None
    # The image embedding is computed once per image, on the first prediction.
    embedding_ready = False

    while True:
        input_stop = False
        image_display = image_orign.copy()
        if selected_mask is not None:
            # Show the last chosen mask; the overlay result must be kept
            # because apply_color_mask may return a new array.
            color = tuple(np.random.randint(0, 256, 3).tolist())
            image_display = apply_color_mask(image_display, selected_mask, color)
        display_info = f'{filename} | Current label: {categories[current_label]}'
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_points, input_labels):
            color = (0, 255, 0) if label > 0 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)

        if key == ord(" "):
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # block clicks while a mask is being chosen
            if len(input_points) > 0 and len(input_labels) > 0:
                try:
                    if not embedding_ready:
                        predictor.set_image(image)
                        embedding_ready = True
                    input_point_np = np.array(input_points)
                    input_label_np = np.array(input_labels)
                    masks_pred, scores, logits = predictor.predict(
                        point_coords=input_point_np,
                        point_labels=input_label_np,
                        mask_input=logit_input[None, :, :] if logit_input is not None else None,
                        multimask_output=True,
                    )
                    mask_idx = 0
                    num_masks = len(masks_pred)  # number of candidate masks
                    while True:
                        color = tuple(np.random.randint(0, 256, 3).tolist())  # random overlay colour
                        selected_mask = masks_pred[mask_idx]  # candidate chosen with a / d
                        selected_image = apply_color_mask(image_orign.copy(), selected_mask, color)
                        mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                        cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                                    2, cv2.LINE_AA)
                        # TODO: show in the current image
                        cv2.imshow("image", selected_image)
                        key = cv2.waitKey(10)
                        if key == ord('q') and len(input_points) > 0:
                            input_points.pop(-1)
                            input_labels.pop(-1)
                        elif key == ord('s'):
                            masks.append((selected_mask, categories[current_label], scores[mask_idx]))
                        elif key == ord('a'):
                            mask_idx = mask_idx - 1 if mask_idx > 0 else num_masks - 1
                        elif key == ord('d'):
                            mask_idx = mask_idx + 1 if mask_idx < num_masks - 1 else 0
                        elif key == ord('w'):
                            # Keep the chosen mask's logits as the prior for
                            # the next prediction.
                            logit_input = logits[mask_idx, :, :]
                            print('max score:', np.argmax(scores), ' select:', mask_idx)
                            break
                        elif key == ord(" "):
                            input_points = []
                            input_labels = []
                            selected_mask = None
                            logit_input = None
                            break
                except Exception as e:
                    print(f"Error during prediction: {e}")
        elif key in (ord('a'), ord('d')):
            step = -1 if key == ord('a') else 1
            new_index = min(max(current_index + step, 0), len(image_files) - 1)
            if new_index != current_index and (masks or all_masks):
                # Masks belong to the image they were predicted on; carrying
                # them over would save them onto a different image.
                print(f"Discarding unsaved masks of {filename}")
                masks = []
                all_masks = []
            current_index = new_index
            input_points = []
            input_labels = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_points) > 0:
            input_points.pop(-1)
            input_labels.pop(-1)
        elif key == ord('r'):
            if masks:
                all_masks.extend(masks)  # keep the accepted masks
                masks = []
            input_points = []
            input_labels = []
            selected_mask = None
            logit_input = None
        elif key == ord('s'):
            if masks:
                all_masks.extend(masks)  # keep the accepted masks
            if all_masks:
                save_masked_image_and_json(image_crop, all_masks, output_dir, filename, crop_mode_=crop_mode, pv="")
                all_masks = []  # everything kept so far has been written
            masks = []
        elif key in category_keys:
            current_label = category_keys[key]
            print(f"Switched to label: {categories[current_label]}")

    if key == 27:
        break

cv2.destroyAllWindows()

Binary file not shown.

202
salt/banben1.py Normal file
View File

@ -0,0 +1,202 @@
import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# ---------------------------------------------------------------------------
# Script configuration and one-time setup
# ---------------------------------------------------------------------------
# Folder scanned for images, and folder that receives the cut-out PNGs.
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
crop_mode = True  # True: crop saved cut-outs to the mask's bounding box
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)

# The file name is lower-cased before matching, so only lower-case suffixes are
# listed. The result is sorted because os.listdir() returns entries in
# arbitrary order, which would make the a/d navigation order unpredictable.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))
)

# Load the ViT-B SAM checkpoint and move it to the GPU (requires CUDA).
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)

# Resizable display window; the position arithmetic centres it on a 1920x1080
# desktop.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records SAM prompt points.

    Left click stores ``[x, y]`` with label 1, right click stores it with
    label 0. While ``input_stop`` is set, either click only prints a hint.

    Args:
        event: OpenCV mouse event code.
        x, y: Click position.
        flags, param: Unused; required by the OpenCV callback signature.
    """
    global input_point, input_label, input_stop
    left_down = event == cv2.EVENT_LBUTTONDOWN
    right_down = event == cv2.EVENT_RBUTTONDOWN
    if input_stop:
        if left_down or right_down:
            print('此时不能添加点,按w退出mask选择模式')
        return
    if left_down:
        input_point.append([x, y])
        input_label.append(1)
    elif right_down:
        input_point.append([x, y])
        input_label.append(0)
def apply_mask(image, mask, alpha_channel=True):
    """Cut ``image`` down to the pixels selected by ``mask``.

    Args:
        image: H x W x C array with at least three channels.
        mask: H x W array; pixels equal to 1 (or True) are kept.
        alpha_channel: When true, return the first three channels plus a
            fourth channel that is 255 inside the mask and 0 elsewhere.
            Otherwise return ``image`` with unmasked pixels set to 0.

    Returns:
        The masked image as a new array.
    """
    if not alpha_channel:
        return np.where(mask[..., None] == 1, image, 0)
    opacity = np.zeros_like(image[..., 0])
    opacity[mask == 1] = 255
    planes = [image[..., index] for index in range(3)]
    return cv2.merge((planes[0], planes[1], planes[2], opacity))
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Blend ``color`` into the masked pixels of ``image`` in place.

    Args:
        image: H x W x C array with at least three channels; modified.
        mask: H x W array; pixels equal to 1 (or True) are tinted.
        color: Sequence whose first three entries are the per-channel tint.
        color_dark: Weight of the tint; ``1 - color_dark`` weights the pixel.

    Returns:
        The same ``image`` object that was passed in.
    """
    selected = mask == 1
    keep = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        # np.where builds a new array, so writing back through the view is safe.
        image[:, :, channel] = np.where(selected, plane * keep + color_dark * color[channel], plane)
    return image
def get_next_filename(base_path, filename):
    """Return the first unused ``<name>_<i><ext>`` variant of ``filename``.

    Suffixes 1 through 100 are tried in order against ``base_path``.

    Returns:
        The free file name (without directory), or ``None`` when all 100
        variants already exist.
    """
    stem, suffix = os.path.splitext(filename)
    candidates = (f"{stem}_{number}{suffix}" for number in range(1, 101))
    return next(
        (name for name in candidates if not os.path.exists(os.path.join(base_path, name))),
        None,
    )
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Save the masked cut-out of ``image`` as a uniquely numbered PNG.

    Args:
        image: Source image.
        mask: H x W mask selecting the pixels to keep.
        output_dir: Destination folder.
        filename: Name of the source image; its extension is replaced by
            ``.png`` and a free ``_<i>`` suffix is added.
        crop_mode_: When true, crop the output to the mask's bounding box.

    Nothing is written, and a message is printed, when the mask selects no
    pixel, when no free file name is left, or when OpenCV fails to write.
    """
    if crop_mode_:
        rows, cols = np.where(mask)
        if rows.size == 0:
            # min()/max() of an empty selection would raise ValueError.
            print("Could not save the image. The mask is empty.")
            return
        y_min, y_max, x_min, x_max = rows.min(), rows.max(), cols.min(), cols.max()
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext leaves a name without an extension intact, whereas slicing at
    # rfind('.') == -1 would silently drop its last character.
    filename = os.path.splitext(filename)[0] + '.png'
    new_filename = get_next_filename(output_dir, filename)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    target = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        saved = cv2.imwrite(target, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        saved = cv2.imwrite(target, masked_image)
    # cv2.imwrite reports failure through its return value, not an exception.
    if saved:
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save the image as {new_filename}.")
# ---------------------------------------------------------------------------
# Interactive cut-out loop.
#
# Point-entry mode ("image" window):
#   left / right click  add a foreground / background point
#   w       run SAM on the current points and open the "Prediction" window
#   space   clear the points and the current selection
#   q       remove the most recent point
#   s       save the last chosen mask
#   a / d   previous / next image
#   Esc     quit
# Mask-selection mode ("Prediction" window):
#   a / d   cycle through the predicted masks
#   s       save the displayed mask
#   q       remove the most recent point
#   w       leave selection; the chosen mask's logits seed the next prediction
#   space   clear the points and leave selection
# ---------------------------------------------------------------------------
current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
prediction_window_name = "Prediction"
key = -1

if not image_files:
    print(f"No images found in {input_dir}")

while image_files:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread signals an unreadable file by returning None.
        print(f"Could not read {filename}, skipping it")
        image_files.pop(current_index)
        current_index = max(0, min(current_index, len(image_files) - 1))
        continue
    image_crop = image_orign.copy()
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)
    selected_mask = None
    logit_input = None
    # The image embedding is computed once per image, on the first prediction.
    embedding_ready = False

    while True:
        image_display = image_orign.copy()
        if selected_mask is not None:
            color = tuple(np.random.randint(0, 256, 3).tolist())
            image_display = apply_color_mask(image_display, selected_mask, color)
        display_info = f'{filename} '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)

        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            if len(input_point) > 0 and len(input_label) > 0:
                if not embedding_ready:
                    predictor.set_image(image)
                    embedding_ready = True
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
                cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
                cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
                # Clicks are blocked only while a mask is being chosen. The
                # flag is set here, not on every 'w', so that pressing 'w'
                # without any point cannot lock point entry for good.
                input_stop = True
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_orign.copy(), selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个 | s 保存'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
                    cv2.imshow(prediction_window_name, selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
                    elif key == ord('a'):
                        mask_idx = mask_idx - 1 if mask_idx > 0 else num_masks - 1
                    elif key == ord('d'):
                        mask_idx = mask_idx + 1 if mask_idx < num_masks - 1 else 0
                    elif key == ord('w'):
                        # Keep the chosen mask's logits as the prior for the
                        # next prediction.
                        logit_input = logits[mask_idx, :, :]
                        print('max score:', np.argmax(scores), ' select:', mask_idx)
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        break
                input_stop = False  # allow adding points again on every exit path
                try:
                    cv2.destroyWindow(prediction_window_name)
                except cv2.error:
                    pass  # the user already closed the window
        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)

    if key == 27:
        break

cv2.destroyAllWindows()  # close all windows before exiting

300
salt/banben2.py Normal file
View File

@ -0,0 +1,300 @@
import cv2
import os
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout of the main window: six buttons and two image panels."""

    def setupUi(self, MainWindow):
        """Create every child widget of ``MainWindow`` with fixed geometry."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        def make_button(name, x, y, width, height):
            # Push button parented to the central widget.
            button = QtWidgets.QPushButton(self.centralwidget)
            button.setGeometry(QtCore.QRect(x, y, width, height))
            button.setObjectName(name)
            return button

        def make_panel(name, x, y, width, height):
            # White label used as an image display area.
            panel = QtWidgets.QLabel(self.centralwidget)
            panel.setGeometry(QtCore.QRect(x, y, width, height))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setObjectName(name)
            return panel

        # Creation order is kept as in the original layout.
        self.pushButton_w = make_button("pushButton_w", 10, 90, 151, 51)
        self.pushButton_a = make_button("pushButton_a", 10, 160, 151, 51)
        self.pushButton_d = make_button("pushButton_d", 10, 230, 151, 51)
        self.pushButton_s = make_button("pushButton_s", 10, 300, 151, 51)
        self.pushButton_q = make_button("pushButton_q", 10, 370, 151, 51)
        self.label_orign = make_panel("label_orign", 180, 20, 450, 450)
        self.label_pre = make_panel("label_pre", 660, 20, 450, 450)
        self.pushButton_opimg = make_button("pushButton_opimg", 10, 20, 151, 51)
        MainWindow.setCentralWidget(self.centralwidget)

        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)

        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the (translatable) captions of the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        context = "MainWindow"
        MainWindow.setWindowTitle(_translate(context, "MainWindow"))
        self.pushButton_w.setText(_translate(context, "w"))
        self.pushButton_a.setText(_translate(context, "a"))
        self.pushButton_d.setText(_translate(context, "d"))
        self.pushButton_s.setText(_translate(context, "s"))
        self.pushButton_q.setText(_translate(context, "q"))
        original_caption = "<html><head/><body><p align=\"center\">Original Image</p></body></html>"
        predicted_caption = "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>"
        self.label_orign.setText(_translate(context, original_caption))
        self.label_pre.setText(_translate(context, predicted_caption))
        self.pushButton_opimg.setText(_translate(context, "Open Image"))
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Qt front end for point-prompted SAM segmentation.

    Workflow: open an image, left/right-click the left panel to add
    foreground/background points, press ``w`` to predict, ``a``/``d`` to cycle
    through the returned masks, ``q`` to drop the last point and ``s`` to save
    the displayed mask next to the source image.

    Prediction is event driven: ``predict_and_interact`` stores the result and
    returns, and the buttons act on the stored result. The earlier version
    polled ``cv2.waitKey`` inside the slot, which receives no key presses
    without an OpenCV window and left the a/d/s/q buttons unconnected.
    """

    def __init__(self):
        """Build the UI, load the SAM model and wire the buttons."""
        super().__init__()
        self.setupUi(self)
        # Centre both panels so the padding left by aspect-preserving scaling
        # is symmetric and click positions can be mapped back to image pixels.
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
        self.label_pre.setAlignment(QtCore.Qt.AlignCenter)

        self.pushButton_opimg.clicked.connect(self.open_image)
        self.pushButton_w.clicked.connect(self.predict_and_interact)
        self.pushButton_a.clicked.connect(self.previous_mask)
        self.pushButton_d.clicked.connect(self.next_mask)
        self.pushButton_s.clicked.connect(self.save_current_mask)
        self.pushButton_q.clicked.connect(self.remove_last_point)

        self.image_files = []   # loaded images as RGB arrays
        self.image_paths = []   # source path of each entry in image_files
        self.current_index = 0
        self.input_point = []   # prompt points in image pixel coordinates
        self.input_label = []   # 1 = foreground, 0 = background, per point
        self.input_stop = False  # when True, clicks on the panel are ignored
        self.interaction_count = 0  # number of predictions run so far

        self.sam = sam_model_registry["vit_b"](
            checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
        _ = self.sam.to(device="cuda")
        self.predictor = SamPredictor(self.sam)

        # Mapping between label_orign coordinates and image pixels; refreshed
        # by display_original_image().
        self.scale_x = 1.0
        self.scale_y = 1.0
        self.offset_x = 0
        self.offset_y = 0
        self.label_pre_width = self.label_pre.width()
        self.label_pre_height = self.label_pre.height()

        # Result of the last prediction for the current image.
        self._masks = None
        self._scores = None
        self._mask_idx = 0
        # Index of the image whose embedding the predictor currently holds.
        self._embedded_index = None

        self.set_mouse_click_event()

    # ------------------------------------------------------------------ images
    def open_image(self):
        """Ask for an image file, load it and make it the current image."""
        options = QtWidgets.QFileDialog.Options()
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "",
                                                            "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
                                                            options=options)
        if not filename:
            return
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            self.statusbar.showMessage(f"Could not read image: {filename}")
            return
        self.image_files.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        self.image_paths.append(filename)
        self._select_image(len(self.image_files) - 1)

    def _select_image(self, index):
        """Switch to image ``index`` and drop points/prediction of the old one."""
        self.current_index = index
        self.input_point = []
        self.input_label = []
        self._clear_prediction()
        self.display_original_image()

    def previous_image(self):
        """Show the previous loaded image, if there is one."""
        if self.current_index > 0:
            self._select_image(self.current_index - 1)

    def next_image(self):
        """Show the next loaded image, if there is one."""
        if self.current_index < len(self.image_files) - 1:
            self._select_image(self.current_index + 1)

    # --------------------------------------------------------------- rendering
    def _render(self, label, image):
        """Show ``image`` with the prompt points on ``label``.

        Returns:
            ``(pixmap, scale_x, scale_y)`` where the scales convert pixmap
            coordinates to image pixels.
        """
        frame = np.ascontiguousarray(image)
        height, width = frame.shape[:2]
        qimage = QtGui.QImage(frame.data, width, height, 3 * width, QtGui.QImage.Format_RGB888)
        pixmap = QtGui.QPixmap.fromImage(qimage).scaled(label.size(), QtCore.Qt.KeepAspectRatio)
        scale_x = width / pixmap.width()
        scale_y = height / pixmap.height()

        painter = QtGui.QPainter(pixmap)
        for (px, py), point_label in zip(self.input_point, self.input_label):
            # Green for foreground points, red for background points.
            pen = QtGui.QPen(QtGui.QColor(0, 255, 0) if point_label == 1 else QtGui.QColor(255, 0, 0))
            pen.setWidth(5)
            painter.setPen(pen)
            painter.drawPoint(QtCore.QPointF(px / scale_x, py / scale_y))
        painter.end()

        label.setPixmap(pixmap)
        return pixmap, scale_x, scale_y

    def display_original_image(self):
        """Redraw the current image and its points in the left panel."""
        if not self.image_files:
            return
        image = self.image_files[self.current_index]
        pixmap, self.scale_x, self.scale_y = self._render(self.label_orign, image)
        # The pixmap is centred, so half of the unused space lies on each side.
        self.offset_x = (self.label_orign.width() - pixmap.width()) // 2
        self.offset_y = (self.label_orign.height() - pixmap.height()) // 2
        self.set_mouse_click_event()

    def display_prediction_image(self, image):
        """Show ``image`` (RGB array) with the prompt points in the right panel."""
        self._render(self.label_pre, image)

    def convert_to_label_coords(self, point):
        """Map an image-pixel ``point`` to coordinates inside ``label_orign``."""
        x = point[0] / self.scale_x + self.offset_x
        y = point[1] / self.scale_y + self.offset_y
        return x, y

    # ------------------------------------------------------------------ points
    def set_mouse_click_event(self):
        """Route mouse presses on the left panel to ``mouse_click``."""
        self.label_orign.mousePressEvent = self.mouse_click

    def mouse_click(self, event):
        """Add a foreground (left) or background (right) point at the click."""
        if self.input_stop or not self.image_files:
            return
        if event.button() == QtCore.Qt.LeftButton:
            point_label = 1
        elif event.button() == QtCore.Qt.RightButton:
            point_label = 0
        else:
            # Other buttons must not add a point without a matching label.
            return
        x = int((event.pos().x() - self.offset_x) * self.scale_x)
        y = int((event.pos().y() - self.offset_y) * self.scale_y)
        height, width = self.image_files[self.current_index].shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            return  # click landed on the padding around the image
        self.input_point.append([x, y])
        self.input_label.append(point_label)
        self.display_original_image()

    def remove_last_point(self):
        """Drop the most recently added point and refresh both panels."""
        if not self.input_point:
            return
        self.input_point.pop()
        self.input_label.pop()
        self.display_original_image()
        if self._masks is not None:
            self._show_selected_mask()

    # -------------------------------------------------------------- prediction
    def predict_and_interact(self):
        """Run SAM on the current points and show the first returned mask."""
        if not self.image_files or not self.input_point:
            return
        image = self.image_files[self.current_index]
        if self._embedded_index != self.current_index:
            # The embedding depends only on the image, so compute it once.
            self.predictor.set_image(image)
            self._embedded_index = self.current_index
        masks, scores, _logits = self.predictor.predict(
            point_coords=np.array(self.input_point),
            point_labels=np.array(self.input_label),
            multimask_output=True,
        )
        self._masks = masks
        self._scores = scores
        self._mask_idx = 0
        self.interaction_count += 1
        self._show_selected_mask()

    def _clear_prediction(self):
        """Forget the stored prediction and empty the right panel."""
        self._masks = None
        self._scores = None
        self._mask_idx = 0
        self.label_pre.clear()

    def _show_selected_mask(self):
        """Render the currently selected mask over the image in the right panel."""
        image = self.image_files[self.current_index]
        color = tuple(np.random.randint(0, 256, 3).tolist())
        overlay = self.apply_color_mask(image.copy(), self._masks[self._mask_idx], color)
        num_masks = len(self._masks)
        mask_info = f'Total: {num_masks} | Current: {self._mask_idx} | Score: {self._scores[self._mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save'
        cv2.putText(overlay, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                    2, cv2.LINE_AA)
        self.display_prediction_image(overlay)

    def _step_mask(self, step):
        """Move the mask selection by ``step``, wrapping around at both ends."""
        if self._masks is None:
            return
        self._mask_idx = (self._mask_idx + step) % len(self._masks)
        self._show_selected_mask()

    def previous_mask(self):
        """Select the previous predicted mask."""
        self._step_mask(-1)

    def next_mask(self):
        """Select the next predicted mask."""
        self._step_mask(1)

    # ------------------------------------------------------------------ saving
    def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
        """Blend ``color`` into the masked pixels of ``image`` in place; return it."""
        for c in range(3):
            image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                                      image[:, :, c])
        return image

    def save_current_mask(self):
        """Save the displayed mask next to the current image's source file."""
        if self._masks is None:
            return
        self.save_masked_image(self.image_files[self.current_index],
                               self._masks[self._mask_idx],
                               self.image_paths[self.current_index])

    def save_masked_image(self, image, mask, filename):
        """Write ``image`` tinted by ``mask`` to ``<stem>_masked.png``.

        Args:
            image: RGB array; not modified.
            mask: H x W mask to overlay.
            filename: Path whose directory and stem determine the output path.
        """
        output_dir = os.path.dirname(filename)
        stem = os.path.splitext(os.path.basename(filename))[0]
        new_filename = os.path.join(output_dir, stem + '_masked.png')
        # Work on a copy so the caller's image is not tinted as a side effect.
        masked_image = self.apply_color_mask(image.copy(), mask)
        # cv2.imwrite reports failure through its return value.
        if cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)):
            print(f"Saved as {new_filename}")
        else:
            print(f"Could not save {new_filename}")
if __name__ == "__main__":
    import sys

    # Create the Qt application, show the main window and hand control to the
    # event loop; its return value becomes the process exit status.
    app = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_status = app.exec_()
    sys.exit(exit_status)

378
salt/banben3.py Normal file
View File

@ -0,0 +1,378 @@
import os
import sys
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout of the fixed-size main window."""

    def setupUi(self, MainWindow):
        """Create every child widget of ``MainWindow`` with fixed geometry."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Identical minimum and maximum sizes make the window non-resizable.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        def make_button(name, x, y, width, height):
            # Push button parented to the central widget.
            button = QtWidgets.QPushButton(self.centralwidget)
            button.setGeometry(QtCore.QRect(x, y, width, height))
            button.setObjectName(name)
            return button

        def make_panel(name, x, y, width, height):
            # White label used as an image display area.
            panel = QtWidgets.QLabel(self.centralwidget)
            panel.setGeometry(QtCore.QRect(x, y, width, height))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setObjectName(name)
            return panel

        # Creation order is kept as in the original layout.
        self.pushButton_w = make_button("pushButton_w", 10, 90, 151, 51)
        self.pushButton_a = make_button("pushButton_a", 10, 160, 71, 51)
        self.pushButton_d = make_button("pushButton_d", 90, 160, 71, 51)
        self.pushButton_s = make_button("pushButton_s", 10, 360, 151, 51)
        self.pushButton_5 = make_button("pushButton_5", 10, 230, 151, 51)
        self.label_orign = make_panel("label_orign", 180, 20, 471, 401)
        self.label_2 = make_panel("label_2", 660, 20, 471, 401)
        self.pushButton_w_2 = make_button("pushButton_w_2", 10, 20, 151, 51)

        self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
        self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
        self.lineEdit.setObjectName("lineEdit")

        self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
        self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        self.horizontalSlider.setSliderPosition(50)
        self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
        self.horizontalSlider.setTickInterval(0)
        self.horizontalSlider.setObjectName("horizontalSlider")
        MainWindow.setCentralWidget(self.centralwidget)

        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)

        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the (translatable) captions of the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        context = "MainWindow"
        MainWindow.setWindowTitle(_translate(context, "MainWindow"))
        self.pushButton_w.setText(_translate(context, "Predict"))
        self.pushButton_a.setText(_translate(context, "Pre"))
        self.pushButton_d.setText(_translate(context, "Next"))
        self.pushButton_s.setText(_translate(context, "Save"))
        self.pushButton_5.setText(_translate(context, "背景图"))
        original_caption = "<html><head/><body><p align=\"center\">原始图像</p></body></html>"
        predicted_caption = "<html><head/><body><p align=\"center\">预测图像</p></body></html>"
        self.label_orign.setText(_translate(context, original_caption))
        self.label_2.setText(_translate(context, predicted_caption))
        self.pushButton_w_2.setText(_translate(context, "Openimg"))
        self.lineEdit.setText(_translate(context, "改变mask大小"))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self):
    """Build the UI, initialise the browsing state and wire the buttons."""
    super().__init__()
    self.setupUi(self)

    # Browsing state.
    self.image_path = ""
    self.image_folder = ""
    self.image_files = []
    self.current_index = 0
    self.input_stop = False  # input_stop is initialised here

    # Button wiring: open folder, previous image, next image.
    wiring = (
        (self.pushButton_w_2, self.open_image_folder),
        (self.pushButton_a, self.load_previous_image),
        (self.pushButton_d, self.load_next_image),
    )
    for button, handler in wiring:
        button.clicked.connect(handler)
def open_image_folder(self):
    """Ask for a folder, list its images and open the selection dialog.

    ``current_index`` is reset when a new folder is loaded. Otherwise an
    index left over from a larger folder could point past the new list, in
    which case ``show_image`` silently displays nothing.
    """
    folder_dialog = QtWidgets.QFileDialog()
    folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
    if not folder_path:
        return
    self.image_folder = folder_path
    self.image_files = self.get_image_files(self.image_folder)
    self.current_index = 0
    if self.image_files:
        self.show_image_selection_dialog()
def load_previous_image(self):
    """Step to the previous image, wrapping from the first to the last."""
    if not self.image_files:
        return
    last_index = len(self.image_files) - 1
    self.current_index = self.current_index - 1 if self.current_index > 0 else last_index
    self.show_image()
def load_next_image(self):
    """Step to the next image, wrapping from the last to the first."""
    if not self.image_files:
        return
    last_index = len(self.image_files) - 1
    self.current_index = self.current_index + 1 if self.current_index < last_index else 0
    self.show_image()
def get_image_files(self, folder_path):
    """Return the image file names found directly in ``folder_path``.

    The extension is matched case-insensitively and including its dot, so
    ``photo.JPG`` is accepted and a name such as ``fakepng`` is not.
    Directories are skipped. The list is sorted by name, ignoring case,
    because ``os.listdir`` returns entries in arbitrary order.

    Args:
        folder_path: Folder to scan (not recursive).

    Returns:
        List of file names without the directory part.
    """
    extensions = ('.png', '.jpg', '.jpeg', '.bmp')
    image_files = [
        name for name in os.listdir(folder_path)
        if name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(folder_path, name))
    ]
    return sorted(image_files, key=str.lower)
def show_image_selection_dialog(self):
    """Show a modal thumbnail list and act on the image the user picks.

    OK and a double click both accept (close) the dialog, and the selection
    is processed only after it has closed. The dialog used to stay open
    because ``accepted`` was wired straight to ``image_selected``, which
    starts a blocking OpenCV session.
    """
    dialog = QtWidgets.QDialog(self)
    dialog.setWindowTitle("Select Image")
    layout = QtWidgets.QVBoxLayout()

    self.listWidget = QtWidgets.QListWidget()
    for image_file in self.image_files:
        item = QtWidgets.QListWidgetItem(image_file)
        pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file))
        if not pixmap.isNull():
            # Files Qt cannot decode are listed without a thumbnail.
            item.setIcon(QtGui.QIcon(pixmap.scaledToWidth(100)))
        self.listWidget.addItem(item)
    self.listWidget.itemDoubleClicked.connect(lambda _item: dialog.accept())
    layout.addWidget(self.listWidget)

    buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
    buttonBox.accepted.connect(dialog.accept)
    buttonBox.rejected.connect(dialog.reject)
    layout.addWidget(buttonBox)

    dialog.setLayout(layout)
    if dialog.exec_() == QtWidgets.QDialog.Accepted:
        self.image_selected()
def image_selected(self):
    """Display the image chosen in the list and start the OpenCV session."""
    selected_item = self.listWidget.currentItem()
    if not selected_item:
        return
    row = self.listWidget.currentRow()
    # Ignore rows outside the valid range of image_files.
    if not (0 <= row < len(self.image_files)):
        return
    self.current_index = row
    self.show_image()  # show the chosen image
    # Hand the chosen file to the OpenCV window.
    chosen_path = os.path.join(self.image_folder, self.image_files[self.current_index])
    self.call_opencv_interaction(chosen_path)
def show_image(self):
    """Show the current image, scaled to fit, in the original-image panel."""
    if not self.image_files or self.current_index >= len(self.image_files):
        return
    file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
    fitted = QtGui.QPixmap(file_path).scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
    self.label_orign.setPixmap(fitted)
def call_opencv_interaction(self, image_path):
input_dir = os.path.dirname(image_path)
image_orign = cv2.imread(image_path)
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
crop_mode = True
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)
image_files = [self.image_files[self.current_index]]
sam = sam_model_registry["vit_b"](
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
def apply_mask(image, mask, alpha_channel=True):
if alpha_channel:
alpha = np.zeros_like(image[..., 0])
alpha[mask == 1] = 255
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
else:
image = np.where(mask[..., None] == 1, image, 0)
return image
def apply_color_mask(image, mask, color, color_dark=0.5):
for c in range(3):
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
image[:, :, c])
return image
def get_next_filename(base_path, filename):
name, ext = os.path.splitext(filename)
for i in range(1, 101):
new_name = f"{name}_{i}{ext}"
if not os.path.exists(os.path.join(base_path, new_name)):
return new_name
return None
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
# 保存图像到指定路径
if crop_mode_:
# 如果采用了裁剪模式,则裁剪图像
y, x = np.where(mask)
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
masked_image = apply_mask(cropped_image, cropped_mask)
else:
masked_image = apply_mask(image, mask)
filename = filename[:filename.rfind('.')] + '.png'
new_filename = get_next_filename(output_dir, filename)
if new_filename:
if masked_image.shape[-1] == 4:
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image,
[cv2.IMWRITE_PNG_COMPRESSION, 9])
else:
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
print(f"Saved as {new_filename}")
# 读取保存的图像文件
saved_image_path = os.path.join(output_dir, new_filename)
saved_image_pixmap = QtGui.QPixmap(saved_image_path)
# 将保存的图像显示在预测图像区域
mainWindow.label_2.setPixmap(
saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio))
else:
print("Could not save the image. Too many variations exist.")
current_index = 0
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
def mouse_click(event, x, y, flags, param):
if not self.input_stop:
if event == cv2.EVENT_LBUTTONDOWN:
input_point.append([x, y])
input_label.append(1)
elif event == cv2.EVENT_RBUTTONDOWN:
input_point.append([x, y])
input_label.append(0)
else:
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
print('此时不能添加点,按w退出mask选择模式')
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
filename = self.image_files[self.current_index]
image_orign = cv2.imread(os.path.join(input_dir, filename))
image_crop = image_orign.copy()
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
selected_mask = None
logit_input = None
while True:
image_display = image_orign.copy()
display_info = f'{filename} '
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
2,
cv2.LINE_AA)
for point, label in zip(input_point, input_label):
color = (0, 255, 0) if label == 1 else (0, 0, 255)
cv2.circle(image_display, tuple(point), 5, color, -1)
if selected_mask is not None:
color = tuple(np.random.randint(0, 256, 3).tolist())
selected_image = apply_color_mask(image_display, selected_mask, color)
cv2.imshow("image", image_display)
key = cv2.waitKey(1)
if key == ord(" "):
input_point = []
input_label = []
selected_mask = None
logit_input = None
elif key == ord("w"):
input_stop = True
if len(input_point) > 0 and len(input_label) > 0:
predictor.set_image(image)
input_point_np = np.array(input_point)
input_label_np = np.array(input_label)
masks, scores, logits = predictor.predict(
point_coords=input_point_np,
point_labels=input_label_np,
mask_input=logit_input[None, :, :] if logit_input is not None else None,
multimask_output=True,
)
mask_idx = 0
num_masks = len(masks)
prediction_window_name = "Prediction"
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2,
(1080 - WINDOW_HEIGHT) // 2)
while True:
color = tuple(np.random.randint(0, 256, 3).tolist())
image_select = image_orign.copy()
selected_mask = masks[mask_idx]
selected_image = apply_color_mask(image_select, selected_mask, color)
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
(0, 255, 255), 2, cv2.LINE_AA)
cv2.imshow(prediction_window_name, selected_image)
key = cv2.waitKey(10)
if key == ord('q') and len(input_point) > 0:
input_point.pop(-1)
elif key == ord('s'):
save_masked_image(image_crop, selected_mask, output_dir, filename,
crop_mode_=crop_mode)
elif key == ord('a'):
if mask_idx > 0:
mask_idx -= 1
else:
mask_idx = num_masks - 1
elif key == ord('d'):
if mask_idx < num_masks - 1:
mask_idx += 1
else:
mask_idx = 0
elif key == ord('w'):
input_stop = False # Allow adding points again
break
elif key == ord(" "):
input_point = []
input_label = []
selected_mask = None
logit_input = None
break
logit_input = logits[mask_idx, :, :]
print('max score:', np.argmax(scores), ' select:', mask_idx)
elif key == ord('a'):
current_index = max(0, current_index - 1)
input_point = []
input_label = []
break
elif key == ord('d'):
current_index = min(len(image_files) - 1, current_index + 1)
input_point = []
input_label = []
break
elif key == 27:
break
elif key == ord('q') and len(input_point) > 0:
input_point.pop(-1)
input_label.pop(-1)
elif key == ord('s') and selected_mask is not None:
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
if key == 27:
break
cv2.destroyAllWindows() # Close all windows before exiting
if key == 27:
break
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window and
    # hand control to the Qt event loop until the window is closed.
    app = QtWidgets.QApplication(sys.argv)
    # The name `mainWindow` must stay a module-level global: save_masked_image
    # looks it up to refresh the prediction label.
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

453
salt/banben4.py Normal file
View File

@ -0,0 +1,453 @@
import os
import sys
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout for the fixed-size (1140x450) segmentation main window.

    ``setupUi`` creates the widgets and stores each one on ``self`` under its
    Qt object name; ``retranslateUi`` assigns the visible texts.
    """

    def setupUi(self, MainWindow):
        """Create all child widgets of ``MainWindow`` and attach them to ``self``."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Equal minimum and maximum sizes pin the window to one size.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central

        def place(widget_cls, name, geometry):
            # Create a child of the central widget, position it, name it and
            # expose it as an attribute with the same name.
            widget = widget_cls(central)
            widget.setGeometry(QtCore.QRect(*geometry))
            widget.setObjectName(name)
            setattr(self, name, widget)
            return widget

        # Left-hand column of action buttons.
        for button_name, button_rect in (
            ("pushButton_w", (10, 90, 151, 51)),
            ("pushButton_a", (10, 160, 71, 51)),
            ("pushButton_d", (90, 160, 71, 51)),
            ("pushButton_s", (10, 360, 151, 51)),
            ("pushButton_5", (10, 230, 151, 51)),
        ):
            place(QtWidgets.QPushButton, button_name, button_rect)

        # Two white image panes: original image (left) and prediction (right).
        for pane_name, pane_left in (("label_orign", 180), ("label_2", 660)):
            pane = QtWidgets.QLabel(central)
            pane.setGeometry(QtCore.QRect(pane_left, 20, 471, 401))
            pane.setStyleSheet("background-color: rgb(255, 255, 255);")
            pane.setObjectName(pane_name)
            setattr(self, pane_name, pane)

        place(QtWidgets.QPushButton, "pushButton_w_2", (10, 20, 151, 51))
        place(QtWidgets.QLineEdit, "lineEdit", (50, 290, 81, 21))

        # Horizontal slider with integer positions 0..10, starting at 0.
        slider = QtWidgets.QSlider(central)
        slider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        slider.setRange(0, 10)
        slider.setSingleStep(1)
        slider.setValue(0)
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setTickInterval(0)
        slider.setObjectName("horizontalSlider")
        self.horizontalSlider = slider

        MainWindow.setCentralWidget(central)

        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)

        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and the text of every button, label and line edit."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        captions = (
            (self.pushButton_w, "获取json文件"),
            (self.pushButton_a, "Pre"),
            (self.pushButton_d, "Next"),
            (self.pushButton_s, "Save"),
            (self.pushButton_5, "背景图"),
            (self.label_orign, "<html><head/><body><p align=\"center\">原始图像</p></body></html>"),
            (self.label_2, "<html><head/><body><p align=\"center\">预测图像</p></body></html>"),
            (self.pushButton_w_2, "Openimg"),
            (self.lineEdit, "改变mask大小"),
        )
        for widget, caption in captions:
            widget.setText(_translate("MainWindow", caption))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self):
    """Build the UI, initialise browsing state and connect widget signals."""
    super().__init__()
    self.setupUi(self)

    # Image browsing state.
    self.image_path = ""
    self.image_folder = ""
    self.image_files = []
    self.current_index = 0

    self.k = 0
    self.last_value = 0  # slider value seen on the previous valueChanged signal
    self.input_stop = False  # read by the OpenCV mouse callback to refuse new points

    # Signal -> slot wiring, connected in the order listed.
    connections = (
        (self.pushButton_w_2.clicked, self.open_image_folder),
        (self.pushButton_a.clicked, self.load_previous_image),
        (self.pushButton_d.clicked, self.load_next_image),
        (self.pushButton_s.clicked, self.save_prediction),
        (self.pushButton_5.clicked, self.select_background_image),
        (self.horizontalSlider.valueChanged, self.adjust_prediction_size),
    )
    for signal, slot in connections:
        signal.connect(slot)
def adjust_pixmap_size(self, pixmap, scale_factor):
    """Return a copy of ``pixmap`` scaled to ``scale_factor`` percent of its size.

    Args:
        pixmap: The QPixmap to scale.
        scale_factor: Target size as a percentage (100 keeps the size).

    Returns:
        The scaled QPixmap; the aspect ratio is preserved.
    """
    source_size = pixmap.size()
    # True division produces floats, but QSize only takes ints; on Python
    # 3.10+ PyQt5 raises TypeError instead of truncating silently.
    target_size = QtCore.QSize(int(source_size.width() * scale_factor / 100),
                               int(source_size.height() * scale_factor / 100))
    return pixmap.scaled(target_size, QtCore.Qt.KeepAspectRatio)
def open_image_folder(self):
    """Ask for a directory, index its images and open the image picker."""
    chooser = QtWidgets.QFileDialog()
    chosen_dir = chooser.getExistingDirectory(self, 'Open Image Folder', '')
    if not chosen_dir:
        # Dialog cancelled: keep the current folder and file list.
        return
    self.image_folder = chosen_dir
    self.image_files = self.get_image_files(self.image_folder)
    if self.image_files:
        self.show_image_selection_dialog()
def load_previous_image(self):
    """Step to the previous image, wrapping from the first to the last one."""
    if not self.image_files:
        return
    last_index = len(self.image_files) - 1
    self.current_index = self.current_index - 1 if self.current_index > 0 else last_index
    self.show_image()
def load_next_image(self):
    """Step to the next image, wrapping from the last back to the first one."""
    if not self.image_files:
        return
    at_end = self.current_index >= len(self.image_files) - 1
    self.current_index = 0 if at_end else self.current_index + 1
    self.show_image()
def get_image_files(self, folder_path):
    """List the image files directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        Sorted list of file names (not full paths) whose extension is
        .png, .jpg, .jpeg or .bmp, compared case-insensitively.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp')
    # The leading dot and lower() matter: a bare 'png' suffix would accept
    # "fakepng" and reject "PHOTO.PNG". Sorting makes the browsing order
    # stable, since os.listdir returns entries in arbitrary order.
    return sorted(
        entry for entry in os.listdir(folder_path)
        if entry.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(folder_path, entry))
    )
def show_image_selection_dialog(self):
    """Show a modal thumbnail list of ``self.image_files`` and open the chosen one.

    Double-clicking an entry or pressing OK closes the dialog first and then
    calls ``image_selected``; Cancel closes it without selecting anything.
    """
    dialog = QtWidgets.QDialog(self)
    dialog.setWindowTitle("Select Image")
    layout = QtWidgets.QVBoxLayout()

    self.listWidget = QtWidgets.QListWidget()
    for image_file in self.image_files:
        item = QtWidgets.QListWidgetItem(image_file)
        pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
        item.setIcon(QtGui.QIcon(pixmap))
        self.listWidget.addItem(item)
    layout.addWidget(self.listWidget)

    buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
    layout.addWidget(buttonBox)
    dialog.setLayout(layout)

    # Accept the dialog instead of calling image_selected directly from the
    # signal: otherwise the modal dialog is never closed and stays on screen
    # during (and after) the blocking OpenCV session.
    self.listWidget.itemDoubleClicked.connect(lambda _item: dialog.accept())
    buttonBox.accepted.connect(dialog.accept)
    buttonBox.rejected.connect(dialog.reject)

    if dialog.exec_() == QtWidgets.QDialog.Accepted:
        self.image_selected()
def image_selected(self):
    """Make the highlighted list entry current, show it and start the OpenCV session."""
    if not self.listWidget.currentItem():
        return
    row = self.listWidget.currentRow()
    # Ignore rows that do not map onto the indexed file list.
    if row < 0 or row >= len(self.image_files):
        return
    self.current_index = row
    self.show_image()
    # Hand the chosen image to the interactive OpenCV window.
    chosen_path = os.path.join(self.image_folder, self.image_files[self.current_index])
    self.call_opencv_interaction(chosen_path)
def select_background_image(self):
    """Let the user pick a background image file and composite it into the prediction pane."""
    picker = QtWidgets.QFileDialog()
    chosen = picker.getOpenFileName(self, 'Select Background Image', '',
                                    'Image Files (*.png *.jpg *.jpeg *.bmp)')
    # getOpenFileName returns (path, selected_filter); the path is empty on cancel.
    background_path = chosen[0]
    if not background_path:
        return
    self.show_background_image(background_path)
def show_background_image(self, image_path):
    """Show ``image_path`` in the prediction pane, underneath any pixmap already there.

    With no pixmap in the pane, the background is shown on its own, scaled to
    the pane with its aspect ratio kept. Otherwise the background and the
    existing pixmap are stacked in a graphics scene (background first) and
    the rendered result replaces the pane content.
    """
    background = QtGui.QPixmap(image_path)
    existing = self.label_2.pixmap()
    if not existing:
        self.label_2.setPixmap(background.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
        return

    overlay = QtGui.QPixmap(existing)
    scene = QtWidgets.QGraphicsScene()
    scene.addPixmap(background)
    scene.addPixmap(overlay)

    # Render the stacked scene into a transparent canvas of the scene's size.
    canvas = QtGui.QPixmap(scene.sceneRect().size().toSize())
    canvas.fill(QtCore.Qt.transparent)
    painter = QtGui.QPainter(canvas)
    scene.render(painter)
    painter.end()
    self.label_2.setPixmap(canvas)
def show_image(self):
    """Display the current image in the left pane, scaled to fit with aspect ratio kept."""
    if not self.image_files or self.current_index >= len(self.image_files):
        return
    current_name = self.image_files[self.current_index]
    pixmap = QtGui.QPixmap(os.path.join(self.image_folder, current_name))
    fitted = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
    self.label_orign.setPixmap(fitted)
def call_opencv_interaction(self, image_path):
    """Run a blocking OpenCV session for point-prompted SAM segmentation of one image.

    "image" window: left click adds a positive point, right click a negative
    one; ``w`` runs a prediction, ``q`` removes the last point, ``s`` saves
    the last chosen mask, space / ``a`` / ``d`` clear the prompt, Esc ends
    the session.

    "Prediction" window (opened by ``w``): ``a`` / ``d`` cycle through the
    candidate masks, ``s`` saves the shown mask, ``q`` removes the last
    point, ``w`` goes back to point editing and keeps the mask, space goes
    back and clears everything. Points cannot be added while this window is
    the active mode.

    Args:
        image_path: Path of the image to segment. Saved cut-outs reuse its
            base name with a numeric suffix and a ``.png`` extension.
    """
    output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
    crop_mode = True
    WINDOW_WIDTH = 1280
    WINDOW_HEIGHT = 720
    filename = os.path.basename(image_path)

    # cv2.imread signals failure by returning None; check before loading the model.
    image_orign = cv2.imread(image_path)
    if image_orign is None:
        print(f"Could not read image: {image_path}")
        return
    image_crop = image_orign.copy()
    image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)

    print('最好是每加一个点就按w键predict一次')
    os.makedirs(output_dir, exist_ok=True)
    sam = sam_model_registry["vit_b"](
        checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
    sam.to(device="cuda")
    predictor = SamPredictor(sam)
    # The image never changes during the session, so hand it to the
    # predictor once here rather than before every prediction.
    predictor.set_image(image)

    def apply_mask(image, mask, alpha_channel=True):
        # Cut-out: either add an alpha channel (opaque inside the mask) or
        # zero out every pixel outside the mask.
        if alpha_channel:
            alpha = np.zeros_like(image[..., 0])
            alpha[mask == 1] = 255
            return cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
        return np.where(mask[..., None] == 1, image, 0)

    def apply_color_mask(image, mask, color, color_dark=0.5):
        # Blend `color` into the masked pixels of `image`, in place.
        for c in range(3):
            image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                                      image[:, :, c])
        return image

    def get_next_filename(base_path, filename):
        # First unused "<name>_<i><ext>" for i in 1..100, or None if all are taken.
        name, ext = os.path.splitext(filename)
        for i in range(1, 101):
            new_name = f"{name}_{i}{ext}"
            if not os.path.exists(os.path.join(base_path, new_name)):
                return new_name
        return None

    def save_masked_image(image, mask, output_dir, filename, crop_mode_):
        # Write the masked cut-out as a PNG and show it in the prediction pane.
        if crop_mode_:
            # Crop to the mask's bounding box; an empty mask has no box.
            y, x = np.where(mask)
            if y.size == 0:
                print("Could not save the image. The selected mask is empty.")
                return
            y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
            cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
            cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
            masked_image = apply_mask(cropped_image, cropped_mask)
        else:
            masked_image = apply_mask(image, mask)
        # splitext keeps names without a dot intact (rfind('.') would return
        # -1 and chop off the last character).
        new_filename = get_next_filename(output_dir, os.path.splitext(filename)[0] + '.png')
        if not new_filename:
            print("Could not save the image. Too many variations exist.")
            return
        saved_image_path = os.path.join(output_dir, new_filename)
        if masked_image.shape[-1] == 4:
            written = cv2.imwrite(saved_image_path, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
        else:
            written = cv2.imwrite(saved_image_path, masked_image)
        if not written:
            print(f"Could not write {saved_image_path}")
            return
        print(f"Saved as {new_filename}")
        # Use self rather than a script-level global so the class also works
        # when the window object is not named `mainWindow`.
        saved_image_pixmap = QtGui.QPixmap(saved_image_path)
        self.label_2.setPixmap(
            saved_image_pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))

    def open_centered_window(name):
        # Resizable window, centred assuming a 1920x1080 screen.
        cv2.namedWindow(name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(name, WINDOW_WIDTH, WINDOW_HEIGHT)
        cv2.moveWindow(name, (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)

    input_point = []
    input_label = []

    def mouse_click(event, x, y, flags, param):
        if event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
            return
        if self.input_stop:
            print('此时不能添加点,按w退出mask选择模式')
            return
        input_point.append([x, y])
        input_label.append(1 if event == cv2.EVENT_LBUTTONDOWN else 0)

    selected_mask = None
    logit_input = None
    prediction_window_name = "Prediction"
    # The lock lives on self because the mouse callback reads self.input_stop.
    self.input_stop = False
    open_centered_window("image")
    cv2.setMouseCallback("image", mouse_click)

    try:
        while True:
            image_display = image_orign.copy()
            display_info = f'{filename} '
            cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                        2, cv2.LINE_AA)
            for point, label in zip(input_point, input_label):
                color = (0, 255, 0) if label == 1 else (0, 0, 255)
                cv2.circle(image_display, tuple(point), 5, color, -1)
            if selected_mask is not None:
                color = tuple(np.random.randint(0, 256, 3).tolist())
                apply_color_mask(image_display, selected_mask, color)
            cv2.imshow("image", image_display)
            key = cv2.waitKey(1)

            if key == 27:
                break
            if key in (ord(" "), ord('a'), ord('d')):
                # Start over on the same image: forget points, mask and logits.
                input_point.clear()
                input_label.clear()
                selected_mask = None
                logit_input = None
            elif key == ord('q') and input_point:
                input_point.pop()
                input_label.pop()
            elif key == ord('s') and selected_mask is not None:
                save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
            elif key == ord("w") and input_point:
                self.input_stop = True  # refuse clicks while a mask is being chosen
                masks, scores, logits = predictor.predict(
                    point_coords=np.array(input_point),
                    point_labels=np.array(input_label),
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)
                prompt_cleared = False
                open_centered_window(prediction_window_name)
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    selected_mask = masks[mask_idx]
                    selected_image = apply_color_mask(image_orign.copy(), selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                                (0, 255, 255), 2, cv2.LINE_AA)
                    cv2.imshow(prediction_window_name, selected_image)
                    sub_key = cv2.waitKey(10)
                    if sub_key == ord('q') and input_point:
                        # Drop the point together with its label so the two
                        # lists stay the same length for the next prediction.
                        input_point.pop()
                        input_label.pop()
                    elif sub_key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename,
                                          crop_mode_=crop_mode)
                    elif sub_key == ord('a'):
                        mask_idx = (mask_idx - 1) % num_masks
                    elif sub_key == ord('d'):
                        mask_idx = (mask_idx + 1) % num_masks
                    elif sub_key == ord('w'):
                        break
                    elif sub_key == ord(" "):
                        input_point.clear()
                        input_label.clear()
                        prompt_cleared = True
                        break
                self.input_stop = False  # allow adding points again
                if prompt_cleared:
                    # A cleared prompt must not seed the next prediction with
                    # the logits of the discarded one.
                    selected_mask = None
                    logit_input = None
                else:
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
    finally:
        # Close the windows exactly once, when the session ends, and never
        # leave the point lock engaged.
        self.input_stop = False
        cv2.destroyAllWindows()
def save_prediction(self):
    """Save the pixmap shown in the prediction pane to ``prediction_result.png``.

    Does nothing when the pane holds no pixmap. After saving, the pane is
    re-adjusted for the slider's current value.
    """
    shown = self.label_2.pixmap()
    if not shown:
        return
    # The output path is relative, so the file lands in the process's
    # current working directory.
    self.label_2.pixmap().toImage().save("prediction_result.png")
    # Re-apply the size adjustment for the slider's current position.
    slider_value = self.horizontalSlider.value()
    self.adjust_prediction_size(slider_value)
def adjust_prediction_size(self, value):
    """Rescale the prediction pixmap by 10% per slider step moved.

    Moving the slider up by ``k`` steps scales the shown pixmap by
    ``1 - 0.1 * k``; moving it down by ``k`` steps scales it by
    ``1 + 0.1 * k``. The scaling is applied to the pixmap currently in the
    pane, so repeated moves compound.

    Args:
        value: New slider position (the ``valueChanged`` payload).
    """
    # Track the slider even when there is nothing to scale, so the next move
    # is measured from the slider's real position.
    step = value - self.last_value
    self.last_value = value

    if not (self.image_files and self.current_index < len(self.image_files)):
        return
    pixmap = self.label_2.pixmap()
    # QLabel.pixmap() is None until a pixmap has been set.
    if pixmap is None or pixmap.isNull():
        return

    scale_factor = 1.0 - step * 0.1
    original_size = pixmap.size()
    # QSize needs ints (floats raise TypeError on Python 3.10+); clamp to one
    # pixel so a full-range move cannot collapse the image to a null pixmap.
    scaled_size = QtCore.QSize(max(1, int(original_size.width() * scale_factor)),
                               max(1, int(original_size.height() * scale_factor)))
    self.label_2.setPixmap(pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio))
if __name__ == "__main__":
    # Script entry point: start Qt, show the main window and run the event
    # loop until the window is closed.
    app = QtWidgets.QApplication(sys.argv)
    # Keep this module-level name: save_masked_image inside
    # call_opencv_interaction refers to `mainWindow`.
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

146
salt/display1.py Normal file
View File

@ -0,0 +1,146 @@
import sys
import os
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class ImageLabel(QtWidgets.QLabel):
    """QLabel that records mouse clicks and paints them as coloured dots.

    Left-button presses are stored in ``foreground_points`` (drawn green),
    right-button presses in ``background_points`` (drawn red). Every press,
    whatever the button, also emits ``clicked`` with the widget-local position.
    """

    clicked = QtCore.pyqtSignal(QtCore.QPoint)

    def __init__(self, *args, **kwargs):
        """Forward all arguments to QLabel and start with no recorded points."""
        super().__init__(*args, **kwargs)
        self.foreground_points = []
        self.background_points = []

    def mousePressEvent(self, event):
        """Emit ``clicked``, store the point by button, then request a repaint."""
        self.clicked.emit(event.pos())
        pressed = event.button()
        if pressed == QtCore.Qt.LeftButton:
            self.foreground_points.append(event.pos())
        elif pressed == QtCore.Qt.RightButton:
            self.background_points.append(event.pos())
        self.update()

    def paintEvent(self, event):
        """Draw the label as usual, then the recorded points on top of it."""
        super().paintEvent(event)
        painter = QtGui.QPainter(self)
        layers = (("green", self.foreground_points), ("red", self.background_points))
        for colour_name, points in layers:
            painter.setPen(QtGui.QPen(QtGui.QColor(colour_name), 5))
            for point in points:
                painter.drawPoint(point)
class MainWindow(QtWidgets.QMainWindow):
    """Main window: click prompt points on an image and preview the SAM mask.

    The left pane (an ``ImageLabel``) shows the image and collects clicks:
    left button = foreground, right button = background. The "save mask"
    button runs the predictor on those points and shows the first returned
    mask in the right pane.
    """

    def __init__(self, predictor):
        """Store ``predictor`` (used via ``set_image`` / ``predict``) and build the UI."""
        super().__init__()
        self.predictor = predictor
        self.current_points = []    # raw label-space [x, y] of every click
        self.current_image = None   # RGB array handed to the predictor, or None
        self.image_files = []
        self.current_index = 0
        self.setupUi()

    def setupUi(self):
        """Create the buttons, the two image panes, the menu bar and the status bar."""
        self.setObjectName("MainWindow")
        self.resize(1333, 657)
        self.centralwidget = QtWidgets.QWidget(self)
        self.centralwidget.setObjectName("centralwidget")

        self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
        self.pushButton_init.setObjectName("pushButton_init")
        self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
        self.pushButton_openimg.setObjectName("pushButton_openimg")
        self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41))
        self.pushButton_save_mask.setObjectName("pushButton_save_mask")

        self.label_Originalimg = ImageLabel(self.centralwidget)
        self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
        self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Originalimg.setObjectName("label_Originalimg")
        self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
        self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Maskimg.setObjectName("label_Maskimg")

        self.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(self)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        self.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)

        self.retranslateUi()
        QtCore.QMetaObject.connectSlotsByName(self)

    def retranslateUi(self):
        """Set the visible texts and connect the button and click signals."""
        _translate = QtCore.QCoreApplication.translate
        self.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
        self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
        self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码"))
        self.pushButton_init.clicked.connect(self.button_reset)
        self.pushButton_openimg.clicked.connect(self.button_image_open)
        self.pushButton_save_mask.clicked.connect(self.button_save_mask)
        self.label_Originalimg.clicked.connect(self.mouse_click)

    def button_reset(self):
        """Forget every clicked point and clear the mask preview."""
        self.current_points = []
        self.label_Originalimg.foreground_points.clear()
        self.label_Originalimg.background_points.clear()
        self.label_Originalimg.update()
        self.label_Maskimg.clear()

    def button_image_open(self):
        """Ask whether to open a folder (Open) or a single file (Cancel), then load it."""
        choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
                                                QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
        if choice == QtWidgets.QMessageBox.Open:
            folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
            if folder_path:
                image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
                               if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
                if image_files:
                    self.image_files = image_files
                    self.current_index = 0
                    self.display_image()
        elif choice == QtWidgets.QMessageBox.Cancel:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
                                                                      "Image files (*.png *.jpg *.jpeg *.bmp)")
            if selected_image:
                self.image_files = [selected_image]
                self.current_index = 0
                self.display_image()

    def display_image(self):
        """Show the current image and register it with the predictor."""
        if not self.image_files:
            return
        path = self.image_files[self.current_index]
        self.label_Originalimg.setPixmap(QtGui.QPixmap(path))
        self.label_Originalimg.setScaledContents(True)
        # Points clicked on the previous image do not apply to this one.
        self.button_reset()

        bgr = cv2.imread(path)
        if bgr is None:
            # Unreadable by OpenCV: disable prediction for this image.
            self.current_image = None
            return
        # The predictor is fed RGB; OpenCV loads BGR.
        self.current_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(self.current_image)

    def mouse_click(self, pos):
        """Log a click reported by the image pane and remember its position."""
        x, y = pos.x(), pos.y()
        print("Mouse clicked at position:", x, y)
        self.current_points.append([x, y])

    def button_save_mask(self):
        """Predict a mask from the clicked points and show it in the right pane.

        Foreground/background labels come from the image pane's two point
        lists. Nothing happens without a loaded image or without any point.
        """
        if self.current_image is None:
            return
        pane = self.label_Originalimg
        foreground = list(pane.foreground_points)
        background = list(pane.background_points)
        if not foreground and not background:
            return

        # The pane stretches the image to its own size (scaled contents), so
        # clicks must be mapped from label pixels to image pixels.
        image_height, image_width = self.current_image.shape[:2]
        scale_x = image_width / max(pane.width(), 1)
        scale_y = image_height / max(pane.height(), 1)
        coords = np.array([[p.x() * scale_x, p.y() * scale_y] for p in foreground + background],
                          dtype=np.float32)
        labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)

        masks, _, _ = self.predictor.predict(point_coords=coords,
                                             point_labels=labels,
                                             multimask_output=True)
        # `if masks:` would raise on a multi-element array; test the length.
        if len(masks) == 0:
            return

        # The mask is 2-D; turn it into a white-on-black 8-bit RGB image
        # before wrapping it in a QImage.
        mask_u8 = masks[0].astype(np.uint8) * 255
        mask_rgb = np.ascontiguousarray(cv2.cvtColor(mask_u8, cv2.COLOR_GRAY2RGB))
        q_image = QtGui.QImage(mask_rgb.data, mask_rgb.shape[1], mask_rgb.shape[0], mask_rgb.strides[0],
                               QtGui.QImage.Format_RGB888)
        pixmap = QtGui.QPixmap.fromImage(q_image)
        self.label_Maskimg.setPixmap(pixmap)
        self.label_Maskimg.setScaledContents(True)
if __name__ == "__main__":
    # Script entry point: load the ViT-B SAM checkpoint onto the GPU, wrap it
    # in a predictor and run the window's event loop.
    app = QtWidgets.QApplication(sys.argv)
    checkpoint_path = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
    build_sam = sam_model_registry["vit_b"]
    sam = build_sam(checkpoint=checkpoint_path)
    sam = sam.to(device="cuda")
    predictor = SamPredictor(sam)
    window = MainWindow(predictor)
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

144
salt/interface.py Normal file
View File

@ -0,0 +1,144 @@
import sys
import os
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout for the 1333x657 image/mask viewer window.

    ``setupUi`` creates the widgets and stores them on ``self``;
    ``retranslateUi`` assigns the visible texts.
    """

    def setupUi(self, MainWindow):
        """Create all child widgets of ``MainWindow`` and attach them to ``self``."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1333, 657)

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central

        def add_button(attr_name, top, object_name=None):
            # Every button sits in the left column and is 141x41 pixels.
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(10, top, 141, 41))
            button.setObjectName(attr_name if object_name is None else object_name)
            setattr(self, attr_name, button)

        def add_pane(attr_name, left, width):
            # White label used as an image display area.
            pane = QtWidgets.QLabel(central)
            pane.setGeometry(QtCore.QRect(left, 30, width, 581))
            pane.setStyleSheet("background-color: rgb(255, 255, 255);")
            pane.setObjectName(attr_name)
            setattr(self, attr_name, pane)

        for attr_name, top in (
            ("pushButton_init", 30),
            ("pushButton_openimg", 90),
            ("pushButton_Fusionimg", 270),
            ("pushButton_exit", 570),
            ("pushButton_Transparency", 380),
            ("pushButton_copymask", 450),
            ("pushButton_saveimg", 510),
        ):
            add_button(attr_name, top)

        slider = QtWidgets.QSlider(central)
        slider.setGeometry(QtCore.QRect(10, 330, 141, 22))
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setObjectName("horizontalSlider")
        slider.setValue(50)
        self.horizontalSlider = slider

        add_pane("label_Originalimg", 160, 571)
        add_pane("label_Maskimg", 740, 581)

        add_button("pushButton_shang", 150)
        # This button's Qt object name differs from its attribute name.
        add_button("pushButton_xia", 210, object_name="pushButton_openimg_3")

        MainWindow.setCentralWidget(central)

        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)

        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and the text of every button and image pane."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        captions = (
            (self.pushButton_init, "重置选择"),
            (self.pushButton_openimg, "打开图片"),
            (self.pushButton_shang, "上一张"),
            (self.pushButton_xia, "下一张"),
            (self.pushButton_Fusionimg, "融合背景图片"),
            (self.pushButton_exit, "退出"),
            (self.pushButton_Transparency, "调整透明度"),
            (self.pushButton_copymask, "复制掩码"),
            (self.pushButton_saveimg, "保存图片"),
            (self.label_Originalimg, "<html><head/><body><p align=\"center\">原始图像</p></body></html>"),
            (self.label_Maskimg, "<html><head/><body><p align=\"center\">掩码图像</p></body></html>"),
        )
        for widget, caption in captions:
            widget.setText(_translate("MainWindow", caption))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Image browser window: open a folder or a file and step through the images."""

    def __init__(self):
        """Build the UI, connect the buttons and start with an empty image list."""
        super().__init__()
        self.setupUi(self)
        self.init_slots()
        self.image_files = []
        self.current_index = 0

    def init_slots(self):
        """Connect the open / previous / next / exit buttons to their handlers."""
        wiring = (
            (self.pushButton_openimg, self.button_image_open),
            (self.pushButton_shang, self.button_image_shang),
            (self.pushButton_xia, self.button_image_xia),
            (self.pushButton_exit, self.button_image_exit),
        )
        for button, handler in wiring:
            button.clicked.connect(handler)

    def button_image_open(self):
        """Ask whether to open a folder (Open) or a single file (Cancel), then show it.

        A folder without matching images replaces the file list with an empty
        one and leaves the display untouched.
        """
        answer = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
                                                QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
        if answer == QtWidgets.QMessageBox.Open:
            folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
            if not folder_path:
                return
            image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp')
            self.image_files = [os.path.join(folder_path, entry) for entry in os.listdir(folder_path)
                                if entry.lower().endswith(image_suffixes)]
            if not self.image_files:
                return
        elif answer == QtWidgets.QMessageBox.Cancel:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
                                                                      "Image files (*.png *.jpg *.jpeg *.bmp)")
            if not selected_image:
                return
            self.image_files = [selected_image]
        else:
            return
        self.current_index = 0
        self.display_image()

    def display_image(self):
        """Show the current image in the left pane, stretched to fill it."""
        if not self.image_files:
            return
        current_path = self.image_files[self.current_index]
        self.label_Originalimg.setPixmap(QtGui.QPixmap(current_path))
        self.label_Originalimg.setScaledContents(True)

    def button_image_shang(self):
        """Go to the previous image, wrapping around at the start."""
        if not self.image_files:
            return
        self.current_index = (self.current_index - 1) % len(self.image_files)
        self.display_image()

    def button_image_xia(self):
        """Go to the next image, wrapping around at the end."""
        if not self.image_files:
            return
        self.current_index = (self.current_index + 1) % len(self.image_files)
        self.display_image()

    def button_image_exit(self):
        """Terminate the process."""
        sys.exit()
if __name__ == "__main__":
    # Script entry point: start Qt, show the viewer window and run the event
    # loop until the window is closed.
    app = QtWidgets.QApplication(sys.argv)
    my_main_window = MyMainWindow()
    my_main_window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

BIN
salt/prediction_result.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 650 KiB

299
salt/segment1.py Normal file
View File

@ -0,0 +1,299 @@
import cv2
import os
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout for the SAM window: a column of buttons and two 450x450 image panels."""

    def setupUi(self, MainWindow):
        """Create every child widget on ``MainWindow`` and attach the menu and status bars."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        central = QtWidgets.QWidget(MainWindow)
        self.centralwidget = central
        central.setObjectName("centralwidget")

        # Action buttons stacked down the left edge: (attribute name, top offset).
        button_rows = (
            ("pushButton_w", 90),
            ("pushButton_a", 160),
            ("pushButton_d", 230),
            ("pushButton_s", 300),
            ("pushButton_q", 370),
        )
        for attr_name, top in button_rows:
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(10, top, 151, 51))
            button.setObjectName(attr_name)
            setattr(self, attr_name, button)

        # White image panels side by side: (attribute name, left offset).
        for attr_name, left in (("label_orign", 180), ("label_pre", 660)):
            panel = QtWidgets.QLabel(central)
            panel.setGeometry(QtCore.QRect(left, 20, 450, 450))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setObjectName(attr_name)
            setattr(self, attr_name, panel)

        open_button = QtWidgets.QPushButton(central)
        open_button.setGeometry(QtCore.QRect(10, 20, 151, 51))
        open_button.setObjectName("pushButton_opimg")
        self.pushButton_opimg = open_button

        MainWindow.setCentralWidget(central)
        menubar = QtWidgets.QMenuBar(MainWindow)
        menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        menubar.setObjectName("menubar")
        self.menubar = menubar
        MainWindow.setMenuBar(menubar)
        statusbar = QtWidgets.QStatusBar(MainWindow)
        statusbar.setObjectName("statusbar")
        self.statusbar = statusbar
        MainWindow.setStatusBar(statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the window title and the visible text of every button and panel."""
        tr = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(tr("MainWindow", "MainWindow"))
        # The five action buttons are captioned with their single-letter key.
        for key in ("w", "a", "d", "s", "q"):
            getattr(self, "pushButton_" + key).setText(tr("MainWindow", key))
        self.label_orign.setText(
            tr("MainWindow", "<html><head/><body><p align=\"center\">Original Image</p></body></html>"))
        self.label_pre.setText(
            tr("MainWindow", "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>"))
        self.pushButton_opimg.setText(tr("MainWindow", "Open Image"))
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self,
             checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth",
             model_type="vit_b",
             device=None):
    """Build the window, load the SAM model and prepare the point-prompt state.

    Args:
        checkpoint: Path to the SAM weights file. Defaults to the path that was
            previously hard-coded, so ``MainWindow()`` behaves as before.
        model_type: Key into ``sam_model_registry`` (``"vit_b"`` by default).
        device: Device string passed to ``sam.to``. When ``None``, CUDA is used
            if available and the CPU otherwise, instead of failing outright on
            machines without a GPU.
    """
    super().__init__()
    self.setupUi(self)
    self.pushButton_opimg.clicked.connect(self.open_image)
    self.pushButton_w.clicked.connect(self.predict_and_interact)

    self.image_files = []
    self.current_index = 0
    self.input_point = []
    self.input_label = []
    self.input_stop = False
    self.interaction_count = 0  # number of interaction rounds so far

    if device is None:
        # torch is already required by segment_anything; import locally so the
        # module-level imports stay untouched.
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
    _ = self.sam.to(device=device)
    self.predictor = SamPredictor(self.sam)

    # Image-pixel per label-pixel factors; refreshed whenever an image is shown.
    self.scale_x = 1.0
    self.scale_y = 1.0
    self.label_pre_width = self.label_pre.width()
    self.label_pre_height = self.label_pre.height()

    # Route mouse clicks on the original-image label to mouse_click().
    self.set_mouse_click_event()
def open_image(self):
    """Prompt for an image file, load it as RGB and make it the current image.

    The new image is appended to ``image_files`` and selected, and the prompt
    points collected for the previous image are discarded. Files that OpenCV
    cannot decode are reported and ignored.
    """
    options = QtWidgets.QFileDialog.Options()
    filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "",
                                                        "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
                                                        options=options)
    if not filename:
        return
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread reports failure by returning None; cvtColor would raise on it.
        QtWidgets.QMessageBox.warning(self, "Open Image File", f"Could not read image: {filename}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    self.image_files.append(image)
    # Select the image that was just opened (previously the index never moved,
    # so only the first image opened was ever displayed).
    self.current_index = len(self.image_files) - 1
    # Points placed on the previous image do not apply to the new one.
    self.input_point = []
    self.input_label = []
    self.display_original_image()
def display_original_image(self):
    """Render the current image and its prompt points into ``label_orign``.

    Also refreshes ``scale_x``/``scale_y`` (image pixels per displayed pixel),
    which ``mouse_click`` and ``convert_to_label_coords`` rely on.
    """
    if not self.image_files:
        return
    image = self.image_files[self.current_index]
    height, width, channel = image.shape
    bytesPerLine = 3 * width
    qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
    pixmap = QtGui.QPixmap.fromImage(qImg).scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)

    # The pixmap keeps the image's aspect ratio, so it may be smaller than the
    # label. Derive the factors from the pixmap itself, and do it BEFORE drawing
    # so the points are placed with up-to-date factors.
    self.scale_x = width / max(1, pixmap.width())
    self.scale_y = height / max(1, pixmap.height())

    pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0))  # green: foreground points
    pen_foreground.setWidth(5)
    pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0))  # red: background points
    pen_background.setWidth(5)

    painter = QtGui.QPainter(pixmap)
    # Foreground points first, then background points, as before.
    for wanted_label, pen in ((1, pen_foreground), (0, pen_background)):
        painter.setPen(pen)
        for point, label in zip(self.input_point, self.input_label):
            if label != wanted_label:
                continue
            x, y = self.convert_to_label_coords(point)
            # QPoint only accepts integers; the conversion yields floats.
            painter.drawPoint(QtCore.QPoint(int(round(x)), int(round(y))))
    painter.end()

    # Pin the pixmap to the label's top-left corner so that label coordinates
    # reported by mouse events coincide with pixmap coordinates.
    self.label_orign.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop)
    self.label_orign.setPixmap(pixmap)
    self.label_orign.mousePressEvent = self.mouse_click
def convert_to_label_coords(self, point):
    """Divide an ``(x, y)`` point by ``scale_x``/``scale_y`` and return the result as a tuple."""
    px, py = point[0], point[1]
    return px / self.scale_x, py / self.scale_y
def mouse_click(self, event):
    """Record a prompt point for a click on the original-image label.

    A left click adds a foreground point (label 1) and a right click a
    background point (label 0). Clicks are ignored while ``input_stop`` is set,
    when no image is loaded, when another mouse button is used, or when the
    click falls outside the image.
    """
    if self.input_stop or not self.image_files:
        return
    button = event.button()
    if button == QtCore.Qt.LeftButton:
        label = 1
    elif button == QtCore.Qt.RightButton:
        label = 0
    else:
        # Any other button used to append a point without a label, leaving
        # input_point and input_label with different lengths.
        return
    x = int(event.pos().x() * self.scale_x)
    y = int(event.pos().y() * self.scale_y)
    height, width = self.image_files[self.current_index].shape[:2]
    if not (0 <= x < width and 0 <= y < height):
        # The click landed on the label's empty margin, not on the image.
        return
    self.input_label.append(label)
    self.input_point.append([x, y])
    # Redraw so the new point becomes visible.
    self.display_original_image()
# Run the SAM predictor on the current image using the collected prompt points
# and show a colour-tinted mask preview in the prediction label.
# Returns None; does nothing when no image has been loaded.
# NOTE(review): the nesting of the loops below cannot be recovered from this
# listing - confirm the notes against the properly indented file.
def predict_and_interact(self):
if not self.image_files:
return
image = self.image_files[self.current_index].copy()
# Synthetic name built from the image index; only used when a mask is saved.
filename = f"image_{self.current_index}.png"
image_crop = image.copy()
while True: # Outer loop for prediction
# Prediction logic
if not self.input_stop: # If not in interaction mode
# Only predict when at least one point and one label have been recorded.
if len(self.input_point) > 0 and len(self.input_label) > 0:
self.predictor.set_image(image)
input_point_np = np.array(self.input_point)
input_label_np = np.array(self.input_label)
masks, scores, logits = self.predictor.predict(
point_coords=input_point_np,
point_labels=input_label_np,
multimask_output=True,
)
# Start on the first returned mask.
mask_idx = 0
num_masks = len(masks)
while True: # Inner loop for interaction
# A new random tint colour is drawn on every pass.
color = tuple(np.random.randint(0, 256, 3).tolist())
image_select = image.copy()
selected_mask = masks[mask_idx]
selected_image = self.apply_color_mask(image_select, selected_mask, color)
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save'
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
2, cv2.LINE_AA)
# Display the predicted result in label_pre area
self.display_prediction_image(selected_image)
# Poll OpenCV for a key press.
# NOTE(review): this method creates no OpenCV window - confirm that key presses can reach waitKey here.
key = cv2.waitKey(10)
# Handle key press events
# 'q': drop the most recent point and its label, then redraw the original image.
if key == ord('q') and len(self.input_point) > 0:
self.input_point.pop(-1)
self.input_label.pop(-1)
self.display_original_image()
# 's': save the currently selected mask.
elif key == ord('s'):
self.save_masked_image(image_crop, selected_mask, filename)
# 'a': previous mask, wrapping to the last one.
elif key == ord('a'):
if mask_idx > 0:
mask_idx -= 1
else:
mask_idx = num_masks - 1
# 'd': next mask, wrapping to the first one.
elif key == ord('d'):
if mask_idx < num_masks - 1:
mask_idx += 1
else:
mask_idx = 0
# Space: leave the loop.
elif key == ord(" "):
break
# NOTE(review): no window named "Prediction" is created in this method - verify what
# getWindowProperty returns for it; a value below 1 ends the loop.
if cv2.getWindowProperty("Prediction", cv2.WND_PROP_VISIBLE) < 1:
break
# If 'w' is pressed, toggle interaction mode
# NOTE(review): `key` is only assigned after a prediction ran; confirm it cannot be read unassigned
# when there are no prompt points.
if key == ord('w'):
self.input_stop = not self.input_stop # Toggle interaction mode
if not self.input_stop: # If entering interaction mode
self.interaction_count += 1
if self.interaction_count % 2 == 0: # If even number of interactions, call the interaction function
self.input_point = [] # Reset input points for the next interaction
self.input_label = [] # Reset input labels for the next interaction
self.display_original_image() # Display original image
self.set_mouse_click_event() # Set mouse click event
break # Exit outer loop
else:
continue # Continue prediction
# Exit the outer loop if not in interaction mode
if not self.input_stop:
break
def set_mouse_click_event(self):
    """Install ``mouse_click`` as the mouse-press handler of the original-image label."""
    setattr(self.label_orign, "mousePressEvent", self.mouse_click)
def display_prediction_image(self, image):
    """Show ``image`` in ``label_pre`` with the prompt points drawn on top.

    Args:
        image: Array whose ``shape`` unpacks to (height, width, channel); it is
            handed to ``QImage`` as RGB888 data.
    """
    height, width, channel = image.shape
    bytesPerLine = 3 * width
    qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
    pixmap = QtGui.QPixmap.fromImage(qImg).scaled(self.label_pre.size(), QtCore.Qt.KeepAspectRatio)

    # Prompt points are stored in image pixels; map them onto this pixmap using
    # its own size rather than factors computed for the other label.
    ratio_x = pixmap.width() / width
    ratio_y = pixmap.height() / height

    pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0))  # green: foreground points
    pen_foreground.setWidth(5)
    pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0))  # red: background points
    pen_background.setWidth(5)

    painter = QtGui.QPainter(pixmap)
    # Foreground points first, then background points, as before.
    for wanted_label, pen in ((1, pen_foreground), (0, pen_background)):
        painter.setPen(pen)
        for point, label in zip(self.input_point, self.input_label):
            if label != wanted_label:
                continue
            # QPoint only accepts integers, so round the scaled coordinates.
            x = int(round(point[0] * ratio_x))
            y = int(round(point[1] * ratio_y))
            painter.drawPoint(QtCore.QPoint(x, y))
    painter.end()
    self.label_pre.setPixmap(pixmap)
def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
    """Tint ``image`` with ``color`` wherever ``mask == 1``.

    The first three channels of ``image`` are modified in place and the same
    array is returned. ``color_dark`` is the weight given to ``color``.
    """
    selected = mask == 1
    keep = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep + color_dark * color[channel]
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
def save_masked_image(self, image, mask, filename):
    """Write a tinted copy of ``image`` next to ``filename`` as ``<stem>_masked.png``.

    Args:
        image: RGB image; it is copied first, so the caller's array is not tinted.
        mask: Mask passed on to ``apply_color_mask``.
        filename: Path whose directory and stem determine the output path.
    """
    output_dir = os.path.dirname(filename)
    # splitext copes with names that have no extension; slicing at rfind('.')
    # cut off the last character in that case.
    stem = os.path.splitext(os.path.basename(filename))[0]
    new_filename = os.path.join(output_dir, stem + '_masked.png')
    masked_image = self.apply_color_mask(image.copy(), mask)
    # cv2.imwrite returns False instead of raising when the file cannot be written.
    if cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)):
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save {new_filename}")
def previous_image(self):
    """Move to the previous loaded image, if there is one, and redraw it."""
    if self.current_index <= 0:
        return
    self.current_index = self.current_index - 1
    self.display_original_image()
def next_image(self):
    """Move to the next loaded image, if there is one, and redraw it."""
    last_index = len(self.image_files) - 1
    if self.current_index >= last_index:
        return
    self.current_index = self.current_index + 1
    self.display_original_image()
if __name__ == "__main__":
    # Script entry point: start Qt, show the SAM window and exit with the
    # event loop's return code.
    import sys
    app = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

453
salt/suibian.py Normal file
View File

@ -0,0 +1,453 @@
import os
import sys
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Fixed-size (1140x450) layout: control buttons, two image panels, a text box and a slider."""

    def setupUi(self, MainWindow):
        """Create every child widget on ``MainWindow`` and attach the menu and status bars."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Identical minimum and maximum sizes make the window non-resizable.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
        central = QtWidgets.QWidget(MainWindow)
        self.centralwidget = central
        central.setObjectName("centralwidget")

        # Control buttons: (attribute name, (x, y, width, height)).
        button_specs = (
            ("pushButton_w", (10, 90, 151, 51)),
            ("pushButton_a", (10, 160, 71, 51)),
            ("pushButton_d", (90, 160, 71, 51)),
            ("pushButton_s", (10, 360, 151, 51)),
            ("pushButton_5", (10, 230, 151, 51)),
        )
        for attr_name, rect in button_specs:
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(*rect))
            button.setObjectName(attr_name)
            setattr(self, attr_name, button)

        # White image panels side by side: (attribute name, left offset).
        for attr_name, left in (("label_orign", 180), ("label_2", 660)):
            panel = QtWidgets.QLabel(central)
            panel.setGeometry(QtCore.QRect(left, 20, 471, 401))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setObjectName(attr_name)
            setattr(self, attr_name, panel)

        open_button = QtWidgets.QPushButton(central)
        open_button.setGeometry(QtCore.QRect(10, 20, 151, 51))
        open_button.setObjectName("pushButton_w_2")
        self.pushButton_w_2 = open_button

        text_box = QtWidgets.QLineEdit(central)
        text_box.setGeometry(QtCore.QRect(50, 290, 81, 21))
        text_box.setObjectName("lineEdit")
        self.lineEdit = text_box

        slider = QtWidgets.QSlider(central)
        slider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        slider.setRange(0, 10)  # slider runs from 0 to 10
        slider.setSingleStep(1)
        slider.setValue(0)  # start at 0
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setTickInterval(0)
        slider.setObjectName("horizontalSlider")
        self.horizontalSlider = slider

        MainWindow.setCentralWidget(central)
        menubar = QtWidgets.QMenuBar(MainWindow)
        menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        menubar.setObjectName("menubar")
        self.menubar = menubar
        MainWindow.setMenuBar(menubar)
        statusbar = QtWidgets.QStatusBar(MainWindow)
        statusbar.setObjectName("statusbar")
        self.statusbar = statusbar
        MainWindow.setStatusBar(statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the window title and the visible text of every widget."""
        tr = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(tr("MainWindow", "MainWindow"))
        button_captions = (
            ("pushButton_w", "Predict"),
            ("pushButton_a", "Pre"),
            ("pushButton_d", "Next"),
            ("pushButton_s", "Save"),
            ("pushButton_5", "背景图"),
        )
        for attr_name, caption in button_captions:
            getattr(self, attr_name).setText(tr("MainWindow", caption))
        self.label_orign.setText(tr("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_2.setText(tr("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
        self.pushButton_w_2.setText(tr("MainWindow", "Openimg"))
        self.lineEdit.setText(tr("MainWindow", "改变mask大小"))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self):
    """Build the UI, reset the browsing state and connect the buttons and the slider."""
    super().__init__()
    self.setupUi(self)

    # Image browsing state.
    self.image_folder = ""
    self.image_path = ""
    self.image_files = []
    self.current_index = 0

    self.k = 0
    self.last_value = 0  # slider value seen by the previous size adjustment
    self.input_stop = False  # initialised here so the mouse callback can read it

    click_handlers = (
        (self.pushButton_w_2, self.open_image_folder),
        (self.pushButton_a, self.load_previous_image),
        (self.pushButton_d, self.load_next_image),
        (self.pushButton_s, self.save_prediction),
        (self.pushButton_5, self.select_background_image),
    )
    for button, handler in click_handlers:
        button.clicked.connect(handler)
    # Moving the slider rescales the prediction pixmap.
    self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size)
def adjust_pixmap_size(self, pixmap, scale_factor):
    """Return ``pixmap`` scaled to ``scale_factor`` percent of its size, keeping the aspect ratio.

    Args:
        pixmap: Object providing ``size()`` and ``scaled()`` (a ``QPixmap``).
        scale_factor: Percentage of the current size, e.g. ``50`` for half.
    """
    current = pixmap.size()
    # QSize only accepts integers; the percentage arithmetic yields floats.
    scaled_size = QtCore.QSize(int(round(current.width() * scale_factor / 100)),
                               int(round(current.height() * scale_factor / 100)))
    return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
def open_image_folder(self):
    """Let the user pick a folder, list its images and open the image chooser if any exist."""
    chosen = QtWidgets.QFileDialog().getExistingDirectory(self, 'Open Image Folder', '')
    if not chosen:
        return
    self.image_folder = chosen
    self.image_files = self.get_image_files(self.image_folder)
    if self.image_files:
        self.show_image_selection_dialog()
def load_previous_image(self):
    """Select the previous image, wrapping from the first to the last, and show it."""
    if not self.image_files:
        return
    at_start = not (self.current_index > 0)
    self.current_index = len(self.image_files) - 1 if at_start else self.current_index - 1
    self.show_image()
def load_next_image(self):
    """Select the next image, wrapping from the last to the first, and show it."""
    if not self.image_files:
        return
    has_next = self.current_index < len(self.image_files) - 1
    self.current_index = self.current_index + 1 if has_next else 0
    self.show_image()
def get_image_files(self, folder_path):
    """Return the sorted names of the image files directly inside ``folder_path``.

    A name qualifies when it is a regular file ending in ``.png``, ``.jpg``,
    ``.jpeg`` or ``.bmp``, compared case-insensitively.
    """
    extensions = ('.png', '.jpg', '.jpeg', '.bmp')
    # Match on the dotted, lower-cased suffix: the previous check missed
    # "A.PNG" and accepted names that merely end in the letters "png".
    # Sorting makes previous/next navigation independent of os.listdir order.
    return sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(folder_path, name))
    )
def show_image_selection_dialog(self):
    """Show a modal thumbnail list of ``image_files`` and open the image the user picks.

    Double-clicking an entry or pressing OK closes the dialog and then calls
    ``image_selected``; Cancel closes it without selecting anything.
    """
    dialog = QtWidgets.QDialog(self)
    dialog.setWindowTitle("Select Image")
    layout = QtWidgets.QVBoxLayout()
    self.listWidget = QtWidgets.QListWidget()
    for image_file in self.image_files:
        item = QtWidgets.QListWidgetItem(image_file)
        pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
        item.setIcon(QtGui.QIcon(pixmap))
        self.listWidget.addItem(item)
    # Accept the dialog instead of running the selection handler directly: the
    # handler starts the OpenCV loop, which previously ran with this modal
    # dialog still open because nothing ever closed it.
    self.listWidget.itemDoubleClicked.connect(lambda _item: dialog.accept())
    layout.addWidget(self.listWidget)
    buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
    buttonBox.accepted.connect(dialog.accept)
    buttonBox.rejected.connect(dialog.reject)
    layout.addWidget(buttonBox)
    dialog.setLayout(layout)
    if dialog.exec_() == QtWidgets.QDialog.Accepted:
        # self.listWidget still holds the current row after the dialog closes.
        self.image_selected()
def image_selected(self):
    """Make the highlighted list entry the current image, show it and start the OpenCV session."""
    if not self.listWidget.currentItem():
        return
    row = self.listWidget.currentRow()
    # Ignore a row that falls outside the loaded file list.
    if not (0 <= row < len(self.image_files)):
        return
    self.current_index = row
    self.show_image()
    # Hand the chosen file to the OpenCV window.
    chosen_path = os.path.join(self.image_folder, self.image_files[self.current_index])
    self.call_opencv_interaction(chosen_path)
def select_background_image(self):
    """Ask for a background image file and pass the chosen path to ``show_background_image``."""
    picker = QtWidgets.QFileDialog()
    chosen = picker.getOpenFileName(self, 'Select Background Image', '',
                                    'Image Files (*.png *.jpg *.jpeg *.bmp)')
    image_path = chosen[0]
    if not image_path:
        return
    self.show_background_image(image_path)
def show_background_image(self, image_path):
    """Show the image at ``image_path`` in ``label_2``, underneath any pixmap already there.

    When the label has no pixmap, the background is scaled to the label with its
    aspect ratio kept. Otherwise the background and a copy of the existing
    pixmap are added to a graphics scene, in that order, and the rendered scene
    replaces the label's pixmap.
    """
    background = QtGui.QPixmap(image_path)
    existing = self.label_2.pixmap()
    if not existing:
        self.label_2.setPixmap(background.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
        return

    foreground = QtGui.QPixmap(existing)
    scene = QtWidgets.QGraphicsScene()
    scene.addPixmap(background)
    scene.addPixmap(foreground)

    # Render the scene onto a transparent canvas sized to the scene rectangle.
    canvas = QtGui.QPixmap(scene.sceneRect().size().toSize())
    canvas.fill(QtCore.Qt.transparent)
    painter = QtGui.QPainter(canvas)
    scene.render(painter)
    painter.end()
    self.label_2.setPixmap(canvas)
def show_image(self):
    """Display the current file in ``label_orign``, scaled to the label with aspect ratio kept."""
    if not self.image_files or not (self.current_index < len(self.image_files)):
        return
    file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
    loaded = QtGui.QPixmap(file_path)
    fitted = loaded.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
    self.label_orign.setPixmap(fitted)
# Run an OpenCV point-prompt session for the image at image_path.
# Loads a SAM "vit_b" model, opens an OpenCV window named "image" where mouse
# clicks add prompt points, and reacts to keys: space clears the points, w
# predicts, a/d step a local index, q removes the last point, s saves the
# selected mask, Esc ends the session.
# NOTE(review): the nesting of the loops below cannot be recovered from this
# listing - confirm the notes against the properly indented file.
def call_opencv_interaction(self, image_path):
input_dir = os.path.dirname(image_path)
image_orign = cv2.imread(image_path)
# Hard-coded output folder for saved masks.
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
crop_mode = True
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)
# Local one-element list; only the 'd' key handler below reads it.
image_files = [self.image_files[self.current_index]]
# The model is loaded from a hard-coded checkpoint on every call of this method.
sam = sam_model_registry["vit_b"](
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
_ = sam.to(device="cuda")
predictor = SamPredictor(sam)
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
# Helper: with alpha_channel, append an alpha plane that is 255 where mask == 1;
# otherwise zero out the pixels where the mask is not 1.
def apply_mask(image, mask, alpha_channel=True):
if alpha_channel:
alpha = np.zeros_like(image[..., 0])
alpha[mask == 1] = 255
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
else:
image = np.where(mask[..., None] == 1, image, 0)
return image
# Helper: blend `color` into the first three channels where mask == 1; writes into `image`.
def apply_color_mask(image, mask, color, color_dark=0.5):
for c in range(3):
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
image[:, :, c])
return image
# Helper: first "<name>_<i><ext>" (i = 1..100) that does not exist in base_path, else None.
def get_next_filename(base_path, filename):
name, ext = os.path.splitext(filename)
for i in range(1, 101):
new_name = f"{name}_{i}{ext}"
if not os.path.exists(os.path.join(base_path, new_name)):
return new_name
return None
# Helper: write the masked image (cropped to the mask's bounding box in crop mode)
# as a PNG and show the saved file in the prediction label.
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
# Save the image to the given path
if crop_mode_:
# In crop mode, crop image and mask to the mask's bounding box
y, x = np.where(mask)
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
masked_image = apply_mask(cropped_image, cropped_mask)
else:
masked_image = apply_mask(image, mask)
filename = filename[:filename.rfind('.')] + '.png'
new_filename = get_next_filename(output_dir, filename)
if new_filename:
if masked_image.shape[-1] == 4:
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image,
[cv2.IMWRITE_PNG_COMPRESSION, 9])
else:
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
print(f"Saved as {new_filename}")
# Read back the file that was just saved
saved_image_path = os.path.join(output_dir, new_filename)
saved_image_pixmap = QtGui.QPixmap(saved_image_path)
# Show the saved image in the prediction label
# NOTE(review): `mainWindow` is not defined in this method; it appears to rely on the
# module-level name assigned in the `__main__` block - confirm.
mainWindow.label_2.setPixmap(
saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio))
else:
print("Could not save the image. Too many variations exist.")
current_index = 0
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
# Window position is computed from a fixed 1920x1080 size.
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
# Mouse callback: left click adds a point labelled 1, right click a point labelled 0.
# NOTE(review): this reads self.input_stop, while the key handling below assigns a local
# `input_stop` - confirm which one was intended.
def mouse_click(event, x, y, flags, param):
if not self.input_stop:
if event == cv2.EVENT_LBUTTONDOWN:
input_point.append([x, y])
input_label.append(1)
elif event == cv2.EVENT_RBUTTONDOWN:
input_point.append([x, y])
input_label.append(0)
else:
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
print('此时不能添加点,按w退出mask选择模式')
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
# The file shown is chosen by self.current_index, not by the local current_index.
filename = self.image_files[self.current_index]
image_orign = cv2.imread(os.path.join(input_dir, filename))
image_crop = image_orign.copy()
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
selected_mask = None
logit_input = None
while True:
# Redraw the frame: file name, prompt points and the selected mask, if any.
image_display = image_orign.copy()
display_info = f'{filename} '
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
2,
cv2.LINE_AA)
for point, label in zip(input_point, input_label):
color = (0, 255, 0) if label == 1 else (0, 0, 255)
cv2.circle(image_display, tuple(point), 5, color, -1)
if selected_mask is not None:
color = tuple(np.random.randint(0, 256, 3).tolist())
selected_image = apply_color_mask(image_display, selected_mask, color)
cv2.imshow("image", image_display)
key = cv2.waitKey(1)
# Space: clear all points and the current selection.
if key == ord(" "):
input_point = []
input_label = []
selected_mask = None
logit_input = None
# 'w': run a prediction and enter the mask-selection loop.
elif key == ord("w"):
input_stop = True
if len(input_point) > 0 and len(input_label) > 0:
predictor.set_image(image)
input_point_np = np.array(input_point)
input_label_np = np.array(input_label)
# The logits kept from the previous selection, if any, are passed back in as mask_input.
masks, scores, logits = predictor.predict(
point_coords=input_point_np,
point_labels=input_label_np,
mask_input=logit_input[None, :, :] if logit_input is not None else None,
multimask_output=True,
)
mask_idx = 0
num_masks = len(masks)
prediction_window_name = "Prediction"
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2,
(1080 - WINDOW_HEIGHT) // 2)
while True:
color = tuple(np.random.randint(0, 256, 3).tolist())
image_select = image_orign.copy()
selected_mask = masks[mask_idx]
selected_image = apply_color_mask(image_select, selected_mask, color)
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
(0, 255, 255), 2, cv2.LINE_AA)
cv2.imshow(prediction_window_name, selected_image)
key = cv2.waitKey(10)
# NOTE(review): this branch pops the point but not its label, unlike the 'q' handler
# further down - confirm whether the label should be removed too.
if key == ord('q') and len(input_point) > 0:
input_point.pop(-1)
elif key == ord('s'):
save_masked_image(image_crop, selected_mask, output_dir, filename,
crop_mode_=crop_mode)
# 'a' / 'd': previous / next mask, wrapping around.
elif key == ord('a'):
if mask_idx > 0:
mask_idx -= 1
else:
mask_idx = num_masks - 1
elif key == ord('d'):
if mask_idx < num_masks - 1:
mask_idx += 1
else:
mask_idx = 0
elif key == ord('w'):
input_stop = False # Allow adding points again
break
elif key == ord(" "):
input_point = []
input_label = []
selected_mask = None
logit_input = None
break
# Keep the logits of the selected mask for the next prediction.
logit_input = logits[mask_idx, :, :]
print('max score:', np.argmax(scores), ' select:', mask_idx)
# 'a' / 'd': step the local index (clamped), clear the points and leave this loop.
elif key == ord('a'):
current_index = max(0, current_index - 1)
input_point = []
input_label = []
break
elif key == ord('d'):
current_index = min(len(image_files) - 1, current_index + 1)
input_point = []
input_label = []
break
# Esc (key code 27) leaves the loop.
elif key == 27:
break
elif key == ord('q') and len(input_point) > 0:
input_point.pop(-1)
input_label.pop(-1)
elif key == ord('s') and selected_mask is not None:
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
if key == 27:
break
cv2.destroyAllWindows() # Close all windows before exiting
if key == 27:
break
def save_prediction(self):
    """Save the pixmap shown in ``label_2`` to ``prediction_result.png`` and re-apply the slider size.

    Nothing happens when the prediction label holds no pixmap.
    """
    shown = self.label_2.pixmap()
    if not shown:  # nothing is displayed in the prediction area
        return
    # Convert the displayed pixmap to an image and write it to a fixed file
    # name in the working directory.
    snapshot = shown.toImage()
    snapshot.save("prediction_result.png")
    # Rescale the prediction area according to the current slider position.
    slider_value = self.horizontalSlider.value()
    self.adjust_prediction_size(slider_value)
def adjust_prediction_size(self, value):
    """Rescale the prediction pixmap in response to a new slider ``value``.

    Each slider step changes the size by 10 % relative to the pixmap currently
    shown: a higher value shrinks it, a lower value enlarges it. Nothing
    happens when no image is selected or the prediction label has no pixmap.

    Args:
        value: New slider position; it is remembered in ``last_value``.
    """
    if not (self.image_files and self.current_index < len(self.image_files)):
        return
    pixmap = self.label_2.pixmap()
    # pixmap() returns None until something has been displayed in the label.
    if pixmap is None or pixmap.isNull():
        return
    original_size = pixmap.size()
    # The two former branches were the same formula with the sign moved around.
    # Clamp so that a large jump cannot produce a zero or negative size, which
    # would leave the label with a null pixmap.
    scale_factor = max(0.1, 1.0 - (value - self.last_value) * 0.1)
    self.last_value = value  # remember the slider position for the next call
    # QSize only accepts integers; keep every dimension at one pixel or more.
    scaled_size = QtCore.QSize(max(1, int(round(original_size.width() * scale_factor))),
                               max(1, int(round(original_size.height() * scale_factor))))
    scaled_pixmap = pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
    self.label_2.setPixmap(scaled_pixmap)
if __name__ == "__main__":
    # Script entry point. The window is bound to the module-level name
    # `mainWindow`, which call_opencv_interaction's save helper refers to.
    app = QtWidgets.QApplication(sys.argv)
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

248
scripts/amg.py Normal file
View File

@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2 # type: ignore
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import argparse
import json
import os
from typing import Any, Dict, List
import numpy as np
# Command-line interface. Every option has a default, so the script also runs
# without arguments. Options in the "AMG Settings" group default to None so
# that only values the user supplies are forwarded to the mask generator.
parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
        "as well as pycocotools if saving in RLE format."
    )
)
# --input scripts/input/crops/Guide/23140.jpg --output scripts/output/crops --model-type vit_b --checkpoint D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth
parser.add_argument(
    "--input",
    type=str,
    default=r'scripts/input/crops/Guide/23140.jpg',
    required=False,
    help="Path to either a single input image or folder of images.",
)
parser.add_argument(
    "--output",
    type=str,
    required=False,
    default=r'scripts/output/crops/Guide/',
    help=(
        "Path to the directory where masks will be output. Output will be either a folder "
        "of PNGs per image or a single json with COCO-style masks."
    ),
)
parser.add_argument(
    "--model-type",
    type=str,
    required=False,
    default='vit_b',
    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)
parser.add_argument(
    "--checkpoint",
    type=str,
    required=False,
    default=r'D:/Program Files/Pycharm items/segment-anything-model/weights/vit_b.pth',
    help="The path to the SAM checkpoint to use for mask generation.",
)
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)
amg_settings = parser.add_argument_group("AMG Settings")
amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)
amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)
amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)
amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)
amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)
amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)
amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)
amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)
# The overlap ratio is a fraction, so it must be parsed as a float; with
# type=int a value such as 0.34 was rejected on the command line.
amg_settings.add_argument(
    "--crop-overlap-ratio",
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)
amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)
amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Write each mask to ``<path>/<index>.png`` and a ``metadata.csv`` describing them all.

    Args:
        masks: Mask records providing the keys ``segmentation``, ``area``,
            ``bbox``, ``point_coords``, ``predicted_iou``, ``stability_score``
            and ``crop_box``.
        path: Existing directory that receives the PNG files and the CSV.
    """
    header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    rows = [header]
    for index, entry in enumerate(masks):
        # The segmentation is multiplied by 255 before being written as an image.
        cv2.imwrite(os.path.join(path, f"{index}.png"), entry["segmentation"] * 255)
        fields = [index, entry["area"]]
        fields.extend(entry["bbox"])
        fields.extend(entry["point_coords"][0])
        fields.append(entry["predicted_iou"])
        fields.append(entry["stability_score"])
        fields.extend(entry["crop_box"])
        rows.append(",".join(str(field) for field in fields))
    with open(os.path.join(path, "metadata.csv"), "w") as f:
        f.write("\n".join(rows))
def get_amg_kwargs(args):
    """Collect the AMG options from ``args`` that are not ``None`` into a keyword dict."""
    option_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    amg_kwargs = {}
    for option in option_names:
        value = getattr(args, option)
        # Options left at None are omitted from the result.
        if value is not None:
            amg_kwargs[option] = value
    return amg_kwargs
def main(args: argparse.Namespace) -> None:
    """Generate masks for the input image(s) and write them under ``args.output``.

    ``args.input`` is either one file or a directory; for a directory every
    entry that is not itself a directory is processed. With
    ``args.convert_to_rle`` unset, each image gets its own folder written by
    ``write_masks_to_folder``; otherwise the masks are dumped to
    ``<image name>.json``.

    Raises:
        FileExistsError: in PNG mode, when the per-image output folder already
            exists (it is created with ``exist_ok=False``).
    """
    print("Loading model...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    if not os.path.isdir(args.input):
        targets = [args.input]
    else:
        # Sub-directories are skipped; the listing is not recursive.
        targets = [
            os.path.join(args.input, f)
            for f in os.listdir(args.input)
            if not os.path.isdir(os.path.join(args.input, f))
        ]

    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        image = cv2.imread(t)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks = generator.generate(image)

        # Outputs are named after the input file without its extension.
        base = os.path.splitext(os.path.basename(t))[0]
        save_base = os.path.join(args.output, base)
        if output_mode == "binary_mask":
            os.makedirs(save_base, exist_ok=False)
            write_masks_to_folder(masks, save_base)
        else:
            save_file = save_base + ".json"
            with open(save_file, "w") as f:
                json.dump(masks, f)
    print("Done!")


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)

# Example invocation (paths are specific to the original author's Windows machine).
# Kept as a comment: as a non-raw string literal the backslashes formed invalid
# escape sequences (and "\a" / "\v" silently became control characters).
#
#   D:\anaconda3\envs\pytorch\python.exe
#       "D:/Program Files/Pycharm items/segment-anything-model/scripts/amg.py"
#       --input "scripts/input/crops/Guide/23140.jpg"
#       --output "scripts/output/crops"
#       --model-type vit_b
#       --checkpoint "D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"

272
scripts/amg1.py Normal file
View File

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2 # type: ignore
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import argparse
import json
import os
from typing import Any, Dict, List
# Command-line interface of the automatic mask generation script.
parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
        "as well as pycocotools if saving in RLE format."
    )
)

# NOTE(review): the defaults for --input and --checkpoint are absolute Windows
# paths; on any other machine they must be overridden on the command line.
parser.add_argument(
    "--input",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images',
    help="Path to either a single input image or folder of images.",
)

parser.add_argument(
    "--output",
    type=str,
    required=False,
    default=r'output/mask',
    help=(
        "Path to the directory where masks will be output. Output will be either a folder "
        "of PNGs per image or a single json with COCO-style masks."
    ),
)

parser.add_argument(
    "--model-type",
    type=str,
    required=False,
    default='vit_b',
    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=False,
    default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth',
    help="The path to the SAM checkpoint to use for mask generation.",
)

# Disabled alternative: the same four options declared as required arguments
# without defaults.
# parser.add_argument(
#     "--input",
#     type=str,
#     required=True,
#     help="Path to either a single input image or folder of images.",
# )
#
# parser.add_argument(
#     "--output",
#     type=str,
#     required=True,
#     help=(
#         "Path to the directory where masks will be output. Output will be either a folder "
#         "of PNGs per image or a single json with COCO-style masks."
#     ),
# )
#
# parser.add_argument(
#     "--model-type",
#     type=str,
#     required=True,
#     help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
# )
#
# parser.add_argument(
#     "--checkpoint",
#     type=str,
#     required=True,
#     help="The path to the SAM checkpoint to use for mask generation.",
# )

parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

# Every AMG setting defaults to None, meaning "not given on the command line".
amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

# Parsed as a float: integer parsing rejected fractional ratios such as 0.25.
# NOTE(review): the mask generator presumably expects a fraction of the image
# length here -- confirm against the installed segment_anything package.
amg_settings.add_argument(
    "--crop-overlap-ratio",
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Write one PNG per mask plus a ``metadata.csv`` summary into ``path``.

    Args:
        masks: Mask records. Each record is read for the keys "segmentation",
            "area", "bbox", "point_coords", "predicted_iou", "stability_score"
            and "crop_box".
        path: Directory that receives ``<index>.png`` for every mask and the
            ``metadata.csv`` file. This function does not create it.
    """
    header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    metadata = [header]
    for i, mask_data in enumerate(masks):
        mask = mask_data["segmentation"]
        filename = f"{i}.png"
        # Cast explicitly to 8-bit. The dtype of `bool_array * 255` depends on
        # NumPy's scalar promotion rules (it widens to the default integer
        # under NumPy 2), and wide integer images are not a PNG pixel format.
        # NOTE(review): assumes `mask` is a NumPy array of 0/1 or bool values.
        cv2.imwrite(os.path.join(path, filename), (mask * 255).astype("uint8"))
        mask_metadata = [
            str(i),
            str(mask_data["area"]),
            *[str(x) for x in mask_data["bbox"]],
            # Only the first input point of each mask is recorded.
            *[str(x) for x in mask_data["point_coords"][0]],
            str(mask_data["predicted_iou"]),
            str(mask_data["stability_score"]),
            *[str(x) for x in mask_data["crop_box"]],
        ]
        metadata.append(",".join(mask_metadata))

    metadata_path = os.path.join(path, "metadata.csv")
    with open(metadata_path, "w") as f:
        f.write("\n".join(metadata))
def get_amg_kwargs(args):
    """Return the AMG options on ``args`` whose value is not ``None``.

    Every option listed below is read from ``args``; options left at ``None``
    are omitted from the result. Keys keep the order of the listing.
    """
    option_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    # Read every attribute up front so a missing option fails before filtering.
    values = [(name, getattr(args, name)) for name in option_names]
    selected = {}
    for name, value in values:
        if value is not None:
            selected[name] = value
    return selected
def main(args: argparse.Namespace) -> None:
    """Generate masks for the input image(s) and write them under ``args.output``.

    ``args.input`` is either one file or a directory; for a directory every
    entry that is not itself a directory is processed. Without
    ``args.convert_to_rle`` each image gets its own folder written by
    ``write_masks_to_folder`` (created with ``exist_ok=False``, so an existing
    folder raises); otherwise the masks are dumped to ``<image name>.json``.
    """
    print("Loading model...")
    build_sam = sam_model_registry[args.model_type]
    sam = build_sam(checkpoint=args.checkpoint)
    sam.to(device=args.device)

    output_mode = "binary_mask"
    if args.convert_to_rle:
        output_mode = "coco_rle"
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **get_amg_kwargs(args))

    if os.path.isdir(args.input):
        # Non-recursive listing: sub-directories are left out.
        targets = [
            os.path.join(args.input, entry)
            for entry in os.listdir(args.input)
            if not os.path.isdir(os.path.join(args.input, entry))
        ]
    else:
        targets = [args.input]

    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        bgr = cv2.imread(t)
        if bgr is None:
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        masks = generator.generate(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        # Outputs are named after the input file without its extension.
        stem, _ = os.path.splitext(os.path.basename(t))
        save_base = os.path.join(args.output, stem)
        if output_mode != "binary_mask":
            with open(save_base + ".json", "w") as f:
                json.dump(masks, f)
            continue
        os.makedirs(save_base, exist_ok=False)
        write_masks_to_folder(masks, save_base)
    print("Done!")


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)

View File

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import argparse
import warnings
try:
import onnxruntime # type: ignore
onnxruntime_exists = True
except ImportError:
onnxruntime_exists = False
# Command-line interface of the ONNX export script.
parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

# --- Required inputs and outputs -------------------------------------------
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)
parser.add_argument(
    "--output",
    type=str,
    required=True,
    help="The filename to save the ONNX model to.",
)
parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

# --- Export options ----------------------------------------------------------
parser.add_argument(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)
parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)
parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)
parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
)
parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Export the SAM ONNX wrapper model to ``output``.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the model builder.
        checkpoint: Checkpoint path handed to the model builder.
        output: Destination file of the exported ONNX model.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` module in the
            wrapper is switched to the ``"tanh"`` approximation before export.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``; also selects the
            five-name output list below.

    If ``onnxruntime`` could be imported, the exported file is loaded and run
    once on the CPU provider with the same dummy inputs as a smoke test.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the number of prompt points is exported as a dynamic dimension.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Run the wrapper once eagerly so failures surface before the export.
    _ = onnx_model(**dummy_inputs)

    # Name all five results when extra metrics are requested. With only three
    # names the exporter would presumably attach "low_res_masks" to the third
    # result, i.e. the stability scores.
    # NOTE(review): the order follows the --return-extra-metrics help text
    # (masks, scores, stability_scores, areas, low_res_logits); confirm it
    # against SamOnnxModel.forward.
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # set cpu provider default
        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(output, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    The tensor is detached first: calling ``.numpy()`` on a tensor that
    requires grad raises a RuntimeError.
    """
    return tensor.detach().cpu().numpy()
if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    # Optional second step: write a dynamically quantized copy of the export.
    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        # `optimize_model` is intentionally not passed: newer onnxruntime
        # releases no longer accept it and fail with a TypeError.
        # NOTE(review): older releases that still have the parameter are
        # believed to default it to True -- confirm if pinning an old version.
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 964 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 918 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 906 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 915 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 959 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Some files were not shown because too many files have changed in this diff Show More