SAM_Project
|
@ -0,0 +1,7 @@
|
|||
[flake8]
|
||||
ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
|
||||
max-line-length = 100
|
||||
max-complexity = 18
|
||||
select = B,C,E,F,W,T4,B9
|
||||
per-file-ignores =
|
||||
**/__init__.py:F401,F403,E402
|
|
@ -0,0 +1,8 @@
|
|||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,98 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="root@123.125.240.150:45809">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@222.187.226.110:28961">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:15907">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749 (2)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749 (3)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749 (4)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749 (5)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@connect.east.seetacloud.com:26749 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@region-42.seetacloud.com:14975">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@region-42.seetacloud.com:34252">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@region-8.seetacloud.com:35693">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="root@region-8.seetacloud.com:35693 (2)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,26 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="11">
|
||||
<item index="0" class="java.lang.String" itemvalue="sklearn" />
|
||||
<item index="1" class="java.lang.String" itemvalue="tqdm" />
|
||||
<item index="2" class="java.lang.String" itemvalue="scipy" />
|
||||
<item index="3" class="java.lang.String" itemvalue="h5py" />
|
||||
<item index="4" class="java.lang.String" itemvalue="matplotlib" />
|
||||
<item index="5" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="7" class="java.lang.String" itemvalue="torchvision" />
|
||||
<item index="8" class="java.lang.String" itemvalue="opencv_python" />
|
||||
<item index="9" class="java.lang.String" itemvalue="Pillow" />
|
||||
<item index="10" class="java.lang.String" itemvalue="lxml" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||
</profile>
|
||||
</component>
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/segment-anything-model.iml" filepath="$PROJECT_DIR$/.idea/segment-anything-model.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,14 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.9 (pytorch)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,80 @@
|
|||
# Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all project spaces, and it also applies when
|
||||
an individual is representing the project or its community in public spaces.
|
||||
Examples of representing a project or community include using an official
|
||||
project e-mail address, posting via an official social media account, or acting
|
||||
as an appointed representative at an online or offline event. Representation of
|
||||
a project may be further defined and clarified by project maintainers.
|
||||
|
||||
This Code of Conduct also applies outside the project spaces when there is a
|
||||
reasonable belief that an individual's behavior may have a negative impact on
|
||||
the project or its community.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
|
@ -0,0 +1,31 @@
|
|||
# Contributing to segment-anything
|
||||
We want to make contributing to this project as easy and transparent as
|
||||
possible.
|
||||
|
||||
## Pull Requests
|
||||
We actively welcome your pull requests.
|
||||
|
||||
1. Fork the repo and create your branch from `main`.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If you've changed APIs, update the documentation.
|
||||
4. Ensure the test suite passes.
|
||||
5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`.
|
||||
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||
|
||||
## Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Facebook's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
## Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
## License
|
||||
By contributing to segment-anything, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,171 @@
|
|||
# Segment Anything
|
||||
|
||||
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
|
||||
|
||||
[Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/)
|
||||
|
||||
[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)]
|
||||
|
||||

|
||||
|
||||
The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
|
||||
|
||||
<p float="left">
|
||||
<img src="assets/masks1.png?raw=true" width="37.25%" />
|
||||
<img src="assets/masks2.jpg?raw=true" width="61.5%" />
|
||||
</p>
|
||||
|
||||
## Installation
|
||||
|
||||
The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
|
||||
|
||||
Install Segment Anything:
|
||||
|
||||
```
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
```
|
||||
|
||||
or clone the repository locally and install with
|
||||
|
||||
```
|
||||
git clone git@github.com:facebookresearch/segment-anything.git
|
||||
cd segment-anything; pip install -e .
|
||||
```
|
||||
|
||||
The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks.
|
||||
|
||||
```
|
||||
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
||||
```
|
||||
|
||||
## <a name="GettingStarted"></a>Getting Started
|
||||
|
||||
First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt:
|
||||
|
||||
```
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
|
||||
predictor = SamPredictor(sam)
|
||||
predictor.set_image(<your_image>)
|
||||
masks, _, _ = predictor.predict(<input_prompts>)
|
||||
```
|
||||
|
||||
or generate masks for an entire image:
|
||||
|
||||
```
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
masks = mask_generator.generate(<your_image>)
|
||||
```
|
||||
|
||||
Additionally, masks can be generated for images from the command line:
|
||||
|
||||
```
|
||||
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
|
||||
```
|
||||
|
||||
See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.
|
||||
|
||||
<p float="left">
|
||||
<img src="assets/notebook1.png?raw=true" width="49.1%" />
|
||||
<img src="assets/notebook2.png?raw=true" width="48.9%" />
|
||||
</p>
|
||||
|
||||
## ONNX Export
|
||||
|
||||
SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
|
||||
|
||||
```
|
||||
python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
|
||||
```
|
||||
|
||||
See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
|
||||
|
||||
### Web demo
|
||||
|
||||
The `demo/` folder has a simple one page React app which shows how to run mask prediction with the exported ONNX model in a web browser with multithreading. Please see [`demo/README.md`](https://github.com/facebookresearch/segment-anything/blob/main/demo/README.md) for more details.
|
||||
|
||||
## <a name="Models"></a>Model Checkpoints
|
||||
|
||||
Three model versions of the model are available with different backbone sizes. These models can be instantiated by running
|
||||
|
||||
```
|
||||
from segment_anything import sam_model_registry
|
||||
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
|
||||
```
|
||||
|
||||
Click the links below to download the checkpoint for the corresponding model type.
|
||||
|
||||
- **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
|
||||
- `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
|
||||
- `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
|
||||
|
||||
## Dataset
|
||||
|
||||
See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the datastet. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the datasets you agree that you have read and accepted the terms of the SA-1B Dataset Research License.
|
||||
|
||||
We save masks per image as a json file. It can be loaded as a dictionary in python in the below format.
|
||||
|
||||
```python
|
||||
{
|
||||
"image" : image_info,
|
||||
"annotations" : [annotation],
|
||||
}
|
||||
|
||||
image_info {
|
||||
"image_id" : int, # Image id
|
||||
"width" : int, # Image width
|
||||
"height" : int, # Image height
|
||||
"file_name" : str, # Image filename
|
||||
}
|
||||
|
||||
annotation {
|
||||
"id" : int, # Annotation id
|
||||
"segmentation" : dict, # Mask saved in COCO RLE format.
|
||||
"bbox" : [x, y, w, h], # The box around the mask, in XYWH format
|
||||
"area" : int, # The area in pixels of the mask
|
||||
"predicted_iou" : float, # The model's own prediction of the mask's quality
|
||||
"stability_score" : float, # A measure of the mask's quality
|
||||
"crop_box" : [x, y, w, h], # The crop of the image used to generate the mask, in XYWH format
|
||||
"point_coords" : [[x, y]], # The point coordinates input to the model to generate the mask
|
||||
}
|
||||
```
|
||||
|
||||
Image ids can be found in sa_images_ids.txt which can be downloaded using the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/) as well.
|
||||
|
||||
To decode a mask in COCO RLE format into binary:
|
||||
|
||||
```
|
||||
from pycocotools import mask as mask_utils
|
||||
mask = mask_utils.decode(annotation["segmentation"])
|
||||
```
|
||||
|
||||
See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format.
|
||||
|
||||
## License
|
||||
|
||||
The model is licensed under the [Apache 2.0 license](LICENSE).
|
||||
|
||||
## Contributing
|
||||
|
||||
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
||||
|
||||
## Contributors
|
||||
|
||||
The Segment Anything project was made possible with the help of many contributors (alphabetical):
|
||||
|
||||
Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom
|
||||
|
||||
## Citing Segment Anything
|
||||
|
||||
If you use SAM or SA-1B in your research, please use the following BibTeX entry.
|
||||
|
||||
```
|
||||
@article{kirillov2023segany,
|
||||
title={Segment Anything},
|
||||
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
|
||||
journal={arXiv:2304.02643},
|
||||
year={2023}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,175 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = 'scripts/pv1/0.8/wuding/jpg'
|
||||
output_dir = 'scripts/pv1/0.8/wuding/mask'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tif'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param): # 鼠标点击事件
|
||||
global input_point, input_label, input_stop # 全局变量,输入点,
|
||||
if not input_stop: # 判定标志是否停止输入响应了!
|
||||
if event == cv2.EVENT_LBUTTONDOWN: # 鼠标左键
|
||||
input_point.append([x, y])
|
||||
input_label.append(1) # 1表示前景点
|
||||
elif event == cv2.EVENT_RBUTTONDOWN: # 鼠标右键
|
||||
input_point.append([x, y])
|
||||
input_label.append(0) # 0表示背景点
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: # 提示添加不了
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(mask)
|
||||
alpha[mask == 1] = 255
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0] # 这里用绿色表示mask的区域,可以根据需求修改
|
||||
return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
|
||||
else:
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0] # 这里用绿色表示mask的区域,可以根据需求修改
|
||||
return masked_image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
masked_image = image.copy()
|
||||
for c in range(3):
|
||||
masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return masked_image
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '_masked.png' # 修改保存的文件名
|
||||
cv2.imwrite(os.path.join(output_dir, filename), masked_image)
|
||||
print(f"Saved as {filename}")
|
||||
|
||||
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy() # 原图裁剪
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) # 原图色彩转变
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
# print(input_point)
|
||||
input_stop = False
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label): # 输入点和输入类型
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
# S保存,w预测,ad切换,esc退出
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image) # 设置输入图像
|
||||
input_point_np = np.array(input_point) # 输入暗示点,需要转变array类型才可以输入
|
||||
input_label_np = np.array(input_label) # 输入暗示点的类型
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks) # masks的数量
|
||||
while (1):
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
|
||||
cv2.LINE_AA)
|
||||
# todo 显示在当前的图片,
|
||||
cv2.imshow("image", selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = r'C:\Users\t2581\Desktop\222\images'
|
||||
output_dir = r'C:\Users\t2581\Desktop\222\json'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
global input_point, input_label, input_stop
|
||||
if not input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(mask)
|
||||
alpha[mask == 1] = 255
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0]
|
||||
return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
|
||||
else:
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0]
|
||||
return masked_image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
masked_image = image.copy()
|
||||
for c in range(3):
|
||||
masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return masked_image
|
||||
|
||||
def draw_external_rectangle(image, mask, pv):
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle
|
||||
cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_, pv):
|
||||
masked_image = apply_mask(image, mask)
|
||||
draw_external_rectangle(masked_image, mask, pv)
|
||||
filename = filename[:filename.rfind('.')] + '_masked.png'
|
||||
cv2.imwrite(os.path.join(output_dir, filename), masked_image)
|
||||
print(f"Saved as {filename}")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
# print(input_point)
|
||||
input_stop = False
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks) # masks的数量
|
||||
while (1):
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
|
||||
cv2.LINE_AA)
|
||||
# todo 显示在当前的图片,
|
||||
cv2.imshow("image", selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
|
||||
|
||||
if key == 27:
|
||||
break
|
|
@ -0,0 +1,209 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = r'C:\Users\t2581\Desktop\222\images'
|
||||
output_dir = r'C:\Users\t2581\Desktop\222\2'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
global input_point, input_label, input_stop
|
||||
if not input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(mask)
|
||||
alpha[mask == 1] = 255
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0]
|
||||
return cv2.addWeighted(image, 0.5, masked_image, 0.5, 0)
|
||||
else:
|
||||
masked_image = image.copy()
|
||||
masked_image[mask == 1] = [0, 255, 0]
|
||||
return masked_image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
masked_image = image.copy()
|
||||
for c in range(3):
|
||||
masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return masked_image
|
||||
|
||||
def draw_external_rectangle(image, mask, pv):
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle
|
||||
cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
||||
|
||||
def save_masked_image_and_json(image, mask, output_dir, filename, crop_mode_, pv):
|
||||
masked_image = apply_mask(image, mask)
|
||||
draw_external_rectangle(masked_image, mask, pv)
|
||||
masked_filename = filename[:filename.rfind('.')] + '_masked.png'
|
||||
cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image)
|
||||
print(f"Saved image as {masked_filename}")
|
||||
|
||||
# Convert mask to polygons
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
polygons = [contour.reshape(-1, 2).tolist() for contour in contours]
|
||||
|
||||
# Create JSON data
|
||||
json_data = {
|
||||
"version": "5.1.1",
|
||||
"flags": {},
|
||||
"shapes": [
|
||||
{
|
||||
"label": "pv",
|
||||
"points": polygon,
|
||||
"group_id": None,
|
||||
"shape_type": "polygon",
|
||||
"flags": {}
|
||||
} for polygon in polygons
|
||||
],
|
||||
"imagePath": filename,
|
||||
"imageData": None,
|
||||
"imageHeight": mask.shape[0],
|
||||
"imageWidth": mask.shape[1]
|
||||
}
|
||||
|
||||
# Save JSON file
|
||||
json_filename = filename[:filename.rfind('.')] + '_masked.json'
|
||||
with open(os.path.join(output_dir, json_filename), 'w') as json_file:
|
||||
json.dump(json_data, json_file, indent=4)
|
||||
print(f"Saved JSON as {json_filename}")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
input_stop = False
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks) # masks的数量
|
||||
while (1):
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色,就是
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx] # 选择msks也就是,a,d切换
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
|
||||
cv2.LINE_AA)
|
||||
# todo 显示在当前的图片,
|
||||
cv2.imshow("image", selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image_and_json(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode, pv=f"pv: {scores[mask_idx]:.2f}")
|
||||
|
||||
if key == 27:
|
||||
break
|
|
@ -0,0 +1,161 @@
|
|||
import sys
|
||||
import os
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1140, 450)
|
||||
MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
|
||||
MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_5.setObjectName("pushButton_5")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_2 = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
|
||||
self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_2.setObjectName("label_2")
|
||||
self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_w_2.setObjectName("pushButton_w_2")
|
||||
self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
|
||||
self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
|
||||
self.lineEdit.setObjectName("lineEdit")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
|
||||
self.horizontalSlider.setSliderPosition(50)
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setTickInterval(0)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "Predict"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "Pre"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "Next"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "Save"))
|
||||
self.pushButton_5.setText(_translate("MainWindow", "背景图"))
|
||||
self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
|
||||
self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
|
||||
self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))
|
||||
|
||||
|
||||
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.image_path = ""
|
||||
self.image_folder = ""
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
|
||||
self.pushButton_w_2.clicked.connect(self.open_image_folder)
|
||||
self.pushButton_a.clicked.connect(self.load_previous_image)
|
||||
self.pushButton_d.clicked.connect(self.load_next_image)
|
||||
|
||||
def open_image_folder(self):
|
||||
folder_dialog = QtWidgets.QFileDialog()
|
||||
folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
|
||||
if folder_path:
|
||||
self.image_folder = folder_path
|
||||
self.image_files = self.get_image_files(self.image_folder)
|
||||
if self.image_files:
|
||||
self.show_image_selection_dialog()
|
||||
|
||||
def get_image_files(self, folder_path):
|
||||
image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))]
|
||||
return image_files
|
||||
|
||||
def show_image_selection_dialog(self):
|
||||
dialog = QtWidgets.QDialog(self)
|
||||
dialog.setWindowTitle("Select Image")
|
||||
layout = QtWidgets.QVBoxLayout()
|
||||
|
||||
self.listWidget = QtWidgets.QListWidget()
|
||||
for image_file in self.image_files:
|
||||
item = QtWidgets.QListWidgetItem(image_file)
|
||||
pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
|
||||
item.setIcon(QtGui.QIcon(pixmap))
|
||||
self.listWidget.addItem(item)
|
||||
self.listWidget.itemDoubleClicked.connect(self.image_selected)
|
||||
layout.addWidget(self.listWidget)
|
||||
|
||||
buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
|
||||
buttonBox.accepted.connect(self.image_selected)
|
||||
buttonBox.rejected.connect(dialog.reject)
|
||||
layout.addWidget(buttonBox)
|
||||
|
||||
dialog.setLayout(layout)
|
||||
|
||||
dialog.exec_()
|
||||
|
||||
def image_selected(self):
|
||||
selected_item = self.listWidget.currentItem()
|
||||
if selected_item:
|
||||
selected_index = self.listWidget.currentRow()
|
||||
if selected_index >= 0:
|
||||
self.current_index = selected_index
|
||||
self.show_image()
|
||||
|
||||
def show_image(self):
|
||||
file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
|
||||
pixmap = QtGui.QPixmap(file_path)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def load_previous_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
else:
|
||||
self.current_index = len(self.image_files) - 1
|
||||
self.show_image()
|
||||
|
||||
def load_next_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
else:
|
||||
self.current_index = 0
|
||||
self.show_image()
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
mainWindow = MyMainWindow()
|
||||
mainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<ui version="4.0">
|
||||
<class>MainWindow</class>
|
||||
<widget class="QMainWindow" name="MainWindow">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>1140</width>
|
||||
<height>450</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="minimumSize">
|
||||
<size>
|
||||
<width>1140</width>
|
||||
<height>450</height>
|
||||
</size>
|
||||
</property>
|
||||
<property name="maximumSize">
|
||||
<size>
|
||||
<width>1140</width>
|
||||
<height>450</height>
|
||||
</size>
|
||||
</property>
|
||||
<property name="windowTitle">
|
||||
<string>MainWindow</string>
|
||||
</property>
|
||||
<widget class="QWidget" name="centralwidget">
|
||||
<widget class="QPushButton" name="pushButton_w">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>90</y>
|
||||
<width>151</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>Predict</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_a">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>160</y>
|
||||
<width>71</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>Pre</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_d">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>90</x>
|
||||
<y>160</y>
|
||||
<width>71</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>Next</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_s">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>360</y>
|
||||
<width>151</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>Save</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_5">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>230</y>
|
||||
<width>151</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>背景图</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="label_orign">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>180</x>
|
||||
<y>20</y>
|
||||
<width>471</width>
|
||||
<height>401</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true">background-color: rgb(255, 255, 255);</string>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string><html><head/><body><p align="center">原始图像</p></body></html></string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="label_2">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>660</x>
|
||||
<y>20</y>
|
||||
<width>471</width>
|
||||
<height>401</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true">background-color: rgb(255, 255, 255);</string>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string><html><head/><body><p align="center">预测图像</p></body></html></string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_w_2">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>20</y>
|
||||
<width>151</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>Openimg</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLineEdit" name="lineEdit">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>50</x>
|
||||
<y>290</y>
|
||||
<width>81</width>
|
||||
<height>21</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>改变mask大小</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QSlider" name="horizontalSlider">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>320</y>
|
||||
<width>141</width>
|
||||
<height>22</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="sliderPosition">
|
||||
<number>50</number>
|
||||
</property>
|
||||
<property name="orientation">
|
||||
<enum>Qt::Horizontal</enum>
|
||||
</property>
|
||||
<property name="tickInterval">
|
||||
<number>0</number>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<widget class="QMenuBar" name="menubar">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>1140</width>
|
||||
<height>23</height>
|
||||
</rect>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QStatusBar" name="statusbar"/>
|
||||
</widget>
|
||||
<resources/>
|
||||
<connections/>
|
||||
</ui>
|
|
@ -0,0 +1,734 @@
|
|||
import math
|
||||
import os
|
||||
import cv2
|
||||
|
||||
"""
|
||||
标注关键点只能存在一个框和多个点,并且不能删除点和删除框,读取本地文件的关键点要保证其中的关键点
|
||||
数和key_point_num的值是一样的,本地标签中如果只存在框的信息就不要使用该脚本标注,不然会出错,
|
||||
本地文件夹中可以有标签,如果有会优先加载本地标签,没有才会创建一个
|
||||
关键点标注
|
||||
这个是默认的标注就是普通标注
|
||||
按Q切换下一张,R清除干扰再标注(把之前的框屏蔽)
|
||||
按Y从本地把图像和标签一切删掉
|
||||
按T 退出
|
||||
鼠标双击可以删除框
|
||||
单机就是拖拽
|
||||
|
||||
"""
|
||||
draw_line_circle = True # True/None 是否在框上绘制点(8个点)
|
||||
key_point_is = None # 是否标记关键点 设置为None标注普通yolo标签
|
||||
# 可以自定义得参数
|
||||
image_path = R'C:\Users\lengdan\Desktop\data1\images\images' # 标注完成保存到的文件夹
|
||||
label_path = R'C:\Users\lengdan\Desktop\data1\images\labels' # 要标注的图像所在文件夹
|
||||
circle_distance = 10 # 半径范围:鼠标进入点的半径范围内会出现光圈
|
||||
key_point_num = 5 # 关键点个数
|
||||
box_thickness = 2 # 框的粗细
|
||||
small_box_thickness = 1 # 框的8个点的粗细
|
||||
label_thickness = 1 # 框上面的类别字体的粗细
|
||||
label_fontScale = 0.4 # 框上面的类别字体的倍数
|
||||
key_thick = -1 # 关键点的粗细
|
||||
key_text_thick = 2 # 关键点上文字粗细
|
||||
key_text_scale = 0.6 # 关键点上文字的放大倍数
|
||||
key_radius = 4 # 关键点绘制半径
|
||||
dot = 6 # 选择保留几位小数
|
||||
|
||||
key_color = {
|
||||
0: (0, 0, 200),
|
||||
1: (255, 0, 0),
|
||||
2: (0, 222, 0)
|
||||
} # 关键点的颜色
|
||||
key_text_color = {
|
||||
0: (0, 100, 200),
|
||||
1: (255, 0, 0),
|
||||
2: (0, 255, 125)
|
||||
} # 关键点上文本的颜色
|
||||
box_color = {
|
||||
0: (125, 125, 125),
|
||||
1: (0, 255, 0),
|
||||
2: (0, 255, 0),
|
||||
3: (255, 0, 0),
|
||||
4: (0, 255, 255),
|
||||
5: (255, 255, 0),
|
||||
6: (255, 0, 255),
|
||||
7: (0, 125, 125),
|
||||
8: (125, 125, 125),
|
||||
9: (125, 125, 125),
|
||||
10: (125, 0, 125),
|
||||
11: (125, 0, 125),
|
||||
12: (125, 0, 125),
|
||||
13: (125, 0, 125),
|
||||
14: (125, 0, 125),
|
||||
15: (125, 0, 125)
|
||||
} # 每个不同类别框的颜色
|
||||
my_cls = {
|
||||
0: '0',
|
||||
1: '1',
|
||||
2: '10',
|
||||
3: '11',
|
||||
4: '12',
|
||||
5: '13',
|
||||
6: '14',
|
||||
7: '15',
|
||||
8: '2',
|
||||
9: '3',
|
||||
10: '4',
|
||||
11: '5',
|
||||
12: '6',
|
||||
13: '7',
|
||||
14: '8',
|
||||
15: '9'
|
||||
} # 添加自己的框的标签,如果没有就用i:'i'替代
|
||||
final_class = {
|
||||
i: my_cls[i] if i in my_cls else str(i) for i in range(16)
|
||||
} # 框的默认名字
|
||||
|
||||
# 不要修改的参数
|
||||
position = None # 这里判断鼠标放到了哪个点上,方便后面移动的时候做计算
|
||||
label = None # 操作图像对应的标签
|
||||
img = None # 操作的图像
|
||||
Mouse_move = None # 选择移动框的标志位
|
||||
label_index = None # 鼠标选中的框在标签中的位置
|
||||
label_index_pos = None # 记录选中了框的8个点位的哪一个
|
||||
Mouse_insert = None # 用来记录是否进入删除状态
|
||||
draw_rectangle = None # 用来记录开始添加新框
|
||||
end_draw_rectangle = None # 用来记录结束绘制新框
|
||||
append_str_temp = None # 用来保存新增加的框的信息
|
||||
empty_label = None # 本地是否存在标签文件标志
|
||||
# 关键点相关的参数
|
||||
key_points = None
|
||||
key_points_move = None
|
||||
key_x = None # 移动关键点的时候记录其每个关键点的x
|
||||
key_y = None # 移动关键点的时候记录其每个关键点的y
|
||||
key_v = None # 移动关键点的时候记录其每个关键点的状态
|
||||
key_box = None
|
||||
box_move = None # 移动的是框的时候的标志位
|
||||
key_insert = None # 对某个关键点双击,切换其状态
|
||||
move_key_point = None # 把其他位置的关键点移动到这个地方
|
||||
la_path = None
|
||||
key_point_one = None # 使用双击移动关键点的时候,记录第一个按下的键
|
||||
key_point_two = None # 使用双击移动关键点的时候,记录第二个按下的键
|
||||
append_new_key_point = None # 增加第二个关键点
|
||||
append_new_key_point_index = 0 # 增加第二个关键点
|
||||
window_w = None # 获取创建窗口的宽度
|
||||
window_h = None # 获取创建窗口的高度
|
||||
|
||||
|
||||
def flag_init():
|
||||
# 初始化下参数
|
||||
global position, label, img, Mouse_insert, Mouse_move, label_index, draw_rectangle, end_draw_rectangle, append_str_temp, empty_label, \
|
||||
label_index_pos, key_v, key_x, key_y, key_x, key_box, key_points_move, key_points, box_move, move_key_point, window_w, window_h
|
||||
position = None
|
||||
Mouse_move = None
|
||||
label_index = None
|
||||
label_index_pos = None
|
||||
Mouse_insert = None
|
||||
draw_rectangle = None
|
||||
end_draw_rectangle = None
|
||||
append_str_temp = None
|
||||
empty_label = None
|
||||
key_points = None
|
||||
key_points_move = None
|
||||
key_x = None
|
||||
key_y = None
|
||||
key_v = None
|
||||
key_box = None
|
||||
box_move = None
|
||||
move_key_point = None
|
||||
window_w = None
|
||||
window_h = None
|
||||
|
||||
|
||||
# 用来绘制小的填充矩形框
|
||||
def draw_rect_box(img, center, length_1, color=(0, 0, 255)):
|
||||
x1, y1 = center[0] - length_1, center[1] - length_1
|
||||
x2, y2 = center[0] + length_1, center[1] + length_1
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness=-1)
|
||||
|
||||
|
||||
# 用来读取本地图像
|
||||
def img_read(img_path, scale_):
|
||||
global window_w, window_h
|
||||
# scale_填写屏幕的最小尺寸
|
||||
image = cv2.imread(img_path)
|
||||
scale_x, scale_y, _ = image.shape
|
||||
if max(scale_x, scale_y) > scale_ and window_w is None:
|
||||
scale = max(scale_x, scale_y) / scale_
|
||||
image = cv2.resize(image, (int(image.shape[1] / scale), int(image.shape[0] / scale)))
|
||||
if window_w is not None:
|
||||
image = cv2.resize(image, (window_w, window_h))
|
||||
return image
|
||||
|
||||
|
||||
# 判断两点的间距,用来判断鼠标所在位置是否进入了8个点所在的区域
|
||||
def distance(p1, p2):
|
||||
global circle_distance
|
||||
if math.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2) < circle_distance:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# 绘制虚线矩形框,当切换到删除时,由实线框转为虚线框
|
||||
def draw_dotted_rectangle(img, pt1, pt2, length_1=5, gap=6, thick=2, color=(100, 254, 100)):
|
||||
(x1, y1), (x2, y2) = pt1, pt2
|
||||
temp1, temp2 = x1, y1
|
||||
while x1 + length_1 < x2:
|
||||
cv2.line(img, (x1, y1), (x1 + length_1, y1), color, thickness=thick)
|
||||
cv2.line(img, (x1, y2), (x1 + length_1, y2), color, thickness=thick)
|
||||
x1 += length_1 + gap
|
||||
while y1 + length_1 < y2:
|
||||
cv2.line(img, (temp1, y1), (temp1, y1 + length_1), color, thickness=thick)
|
||||
cv2.line(img, (x1, y1), (x1, y1 + length_1), color, thickness=thick)
|
||||
y1 += length_1 + gap
|
||||
|
||||
|
||||
# 把本地标签展示到图像中
|
||||
def label_show(img1, label_path, index):
|
||||
global small_box_thickness, box_thickness, label_fontScale, label_thickness, key_point_is, key_points, \
|
||||
key_radius, key_color, key_thick, key_text_scale, key_text_thick, key_text_color, label, draw_line_circle
|
||||
with open(la_path) as f:
|
||||
label = f.readlines()
|
||||
if len(label) == 0:
|
||||
return
|
||||
for i, points in enumerate(label):
|
||||
if key_point_is:
|
||||
# 获取关键点参数
|
||||
key_points = points.split(' ')[5:]
|
||||
points = points.split(' ')[0:5]
|
||||
classify = int(float(points[0]))
|
||||
points.pop(0)
|
||||
point = [float(s.strip('\n')) for s in points]
|
||||
# point = list(map(float, points))
|
||||
scale_y, scale_x, _ = img1.shape
|
||||
x, y, w, h = int((point[0] - point[2] / 2) * scale_x), int(
|
||||
(point[1] - point[3] / 2) * scale_y), int(
|
||||
point[2] * scale_x), int(point[3] * scale_y)
|
||||
if i == index:
|
||||
draw_dotted_rectangle(img1, (x, y), (x + w, y + h), box_thickness)
|
||||
else:
|
||||
cv2.rectangle(img1, (x, y), (x + w, y + h), box_color[classify], thickness=box_thickness)
|
||||
if draw_line_circle:
|
||||
# 绘制边上中心点,与四个顶点,矩形框中心点
|
||||
draw_rect_box(img1, (x, int(0.5 * (y + y + h))), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (x + w - 1, int(0.5 * (y + y + h))), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (int(0.5 * (x + x + w)), y), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (int(0.5 * (x + x + w)), y + h), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (x, y), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (x + w, y), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (x + w, y + h), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (x, y + h), length_1=small_box_thickness)
|
||||
draw_rect_box(img1, (int(x + 0.5 * w), int(y + 0.5 * h)), length_1=small_box_thickness)
|
||||
cv2.putText(img1, str(final_class[classify]), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, label_fontScale,
|
||||
(255, 0, 255), label_thickness)
|
||||
if key_point_is:
|
||||
# 依次获取每个关键点
|
||||
key_x = [float(i) for i in key_points[::3]]
|
||||
key_y = [float(i) for i in key_points[1::3]]
|
||||
key_v = [int(float(i)) for i in key_points[2::3]]
|
||||
index = 0
|
||||
key_point = zip(key_x, key_y)
|
||||
for p in key_point:
|
||||
cv2.circle(img, (int(p[0] * scale_x), int(p[1] * scale_y)), key_radius, key_color[key_v[index]],
|
||||
thickness=key_thick,
|
||||
lineType=cv2.LINE_AA)
|
||||
cv2.putText(img, str(index), (int(p[0] * scale_x - 5), int(p[1] * scale_y - 10)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
key_text_scale, key_text_color[0], key_text_thick)
|
||||
index += 1
|
||||
key_points = None
|
||||
|
||||
|
||||
# 回调函数,用于记录鼠标操作
|
||||
def mouse_event(event, x, y, flag, param):
|
||||
global label, img, position, Mouse_move, label_index, label_index_pos, dot, Mouse_insert, draw_rectangle, \
|
||||
end_draw_rectangle, key_points, key_v, key_x, key_y, key_x, key_box, key_points_move, box_move, \
|
||||
key_insert, label_path, move_key_point
|
||||
scale_y, scale_x, _ = img.shape
|
||||
# 鼠标如果位于8个点左右,即通过position记录当前位置,通过主函数在鼠标附近绘制空心圈
|
||||
# 通过label_index记录鼠标选择了第几个框,通过label_index_pos记录该框第几个点被选中了
|
||||
with open(la_path) as f:
|
||||
label = f.readlines()
|
||||
if move_key_point is None and key_insert is None and Mouse_insert is None and empty_label is None and event == cv2.EVENT_MOUSEMOVE and img is not None and label is not None and \
|
||||
Mouse_move is None:
|
||||
for i, la in enumerate(label):
|
||||
la = la.strip('\n').split(' ')
|
||||
if key_point_is:
|
||||
key_points = list(map(float, la))[5:]
|
||||
la = list(map(float, la))[0:5]
|
||||
x1, y1 = int((la[1] - la[3] / 2) * scale_x), int((la[2] - la[4] / 2) * scale_y)
|
||||
x2, y2 = x1 + int(la[3] * scale_x), y1 + int(la[4] * scale_y)
|
||||
# 这里判断鼠标放到了哪个点上,方便后面移动的时候做计算
|
||||
if distance((x, y), (x1, y1)):
|
||||
label_index_pos = 0
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (x2, y2)):
|
||||
label_index_pos = 1
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (x1, int(0.5 * y1 + 0.5 * y2))):
|
||||
label_index_pos = 2
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (int((x1 + x2) / 2), y2)):
|
||||
label_index_pos = 3
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (int((x1 + x2) / 2), y1)):
|
||||
label_index_pos = 4
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (x2, int(0.5 * y1 + 0.5 * y2))):
|
||||
label_index_pos = 5
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (x1, y2)):
|
||||
label_index_pos = 6
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), (x2, y1)):
|
||||
label_index_pos = 7
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
elif distance((x, y), ((x1 + x2) / 2, (y1 + y2) / 2)):
|
||||
# 框中心
|
||||
label_index_pos = 8
|
||||
position = (x, y)
|
||||
label_index = i
|
||||
box_move = True
|
||||
key_points_move = None
|
||||
break
|
||||
else:
|
||||
label_index_pos = None
|
||||
position = None
|
||||
label_index = None
|
||||
if key_point_is:
|
||||
# 判断鼠标是不是放到了关键点上
|
||||
key_x = [float(i) for i in key_points[::3]]
|
||||
key_y = [float(i) for i in key_points[1::3]]
|
||||
key_v = [float(i) for i in key_points[2::3]] # 能见度
|
||||
if len(key_x) == len(key_v) and len(key_x) == len(key_y):
|
||||
for index, key_ in enumerate(key_x):
|
||||
if distance((x, y), (int(key_ * scale_x), int(key_y[index] * scale_y))):
|
||||
position = (x, y)
|
||||
label_index, label_index_pos = i, index
|
||||
key_box = la
|
||||
key_points_move = True
|
||||
box_move = None
|
||||
break
|
||||
|
||||
# 这里到下一个注释都是为了移动已有的框做准备
|
||||
if position is not None and event == cv2.EVENT_LBUTTONDOWN:
|
||||
Mouse_move = True
|
||||
position = None
|
||||
|
||||
# 首先判断鼠标选择了该框的第几个点,然后移动鼠标的时候只负责移动该点
|
||||
if Mouse_move and box_move:
|
||||
# 先把要移动的框的标签记录下来,然后删除,添加到末尾,不断修改末尾标签来达到移动框的目的
|
||||
# temp_label用来记录标签
|
||||
temp_label = label[label_index]
|
||||
label.pop(label_index)
|
||||
temp_label = temp_label.strip('\n').split(' ')
|
||||
temp_label = [float(i) for i in temp_label]
|
||||
x_1, y_1 = (temp_label[1] - 0.5 * temp_label[3]), (temp_label[2] - 0.5 * temp_label[4])
|
||||
x_2, y_2 = x_1 + temp_label[3], y_1 + temp_label[4]
|
||||
# 判断移动的是8个点中的哪个
|
||||
if label_index_pos == 0:
|
||||
x_1, y_1 = x / scale_x, y / scale_y
|
||||
elif label_index_pos == 1:
|
||||
x_2, y_2 = x / scale_x, y / scale_y
|
||||
elif label_index_pos == 2:
|
||||
x_1 = x / scale_x
|
||||
elif label_index_pos == 3:
|
||||
y_2 = y / scale_y
|
||||
elif label_index_pos == 4:
|
||||
y_1 = y / scale_y
|
||||
elif label_index_pos == 5:
|
||||
x_2 = x / scale_x
|
||||
elif label_index_pos == 6:
|
||||
x_1, y_2 = x / scale_x, y / scale_y
|
||||
elif label_index_pos == 7:
|
||||
y_1, x_2 = y / scale_y, x / scale_x
|
||||
elif label_index_pos == 8:
|
||||
x_1, y_1 = x / scale_x - (abs(temp_label[3]) / 2), y / scale_y - (abs(temp_label[4]) / 2)
|
||||
x_2, y_2 = x / scale_x + (abs(temp_label[3]) / 2), y / scale_y + (abs(temp_label[4]) / 2)
|
||||
# 把移动后的点信息保存下来添加到标签中,以此形成动态绘制一个框的效果
|
||||
temp_label[0], temp_label[1], temp_label[2], temp_label[3], temp_label[4] = str(
|
||||
round((int(temp_label[0])), dot)), \
|
||||
str(round(((x_1 + x_2) * 0.5), dot)), str(round(((y_1 + y_2) * 0.5), dot)), str(
|
||||
round((abs(x_1 - x_2)), dot)), str(round((abs(y_1 - y_2)), dot))
|
||||
temp_label = [str(i) for i in temp_label]
|
||||
str_temp = ' '.join(temp_label) + '\n'
|
||||
label.append(str_temp)
|
||||
label_index = len(label) - 1
|
||||
elif Mouse_move and key_points_move:
|
||||
label.pop(label_index)
|
||||
key_x[label_index_pos] = round(x / scale_x, dot)
|
||||
key_y[label_index_pos] = round(y / scale_y, dot)
|
||||
key_box[0] = int(key_box[0])
|
||||
str_temp = ' '.join([str(j) for j in key_box])
|
||||
for index, kx in enumerate(key_x):
|
||||
str_temp += ' ' + str(kx) + ' ' + str(key_y[index]) + ' ' + str(int(key_v[index]))
|
||||
label.append(str_temp)
|
||||
label_index = len(label) - 1
|
||||
|
||||
if Mouse_move and event == cv2.EVENT_LBUTTONUP:
|
||||
flag_init()
|
||||
|
||||
# 这里是为了删除框
|
||||
if key_point_is is None and Mouse_insert is None and position is not None and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None:
|
||||
Mouse_insert = label_index
|
||||
|
||||
if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_move is None and key_points_move and box_move is None:
|
||||
key_insert = label_index_pos
|
||||
|
||||
if key_point_is and event == cv2.EVENT_LBUTTONDBLCLK and Mouse_insert is None and key_insert is None and position is None:
|
||||
move_key_point = (x, y)
|
||||
|
||||
# 这里是为了增加新的框
|
||||
if key_point_is is None and Mouse_insert is None and position is None and Mouse_move is None and event == cv2.EVENT_LBUTTONDOWN and end_draw_rectangle is None:
|
||||
draw_rectangle = [(x, y), (x, y)]
|
||||
|
||||
# 如果鼠标左键一直没有松开,则不断更新第二个点的位置
|
||||
elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_MOUSEMOVE and end_draw_rectangle is None:
|
||||
draw_rectangle[1] = (x, y)
|
||||
|
||||
# 鼠标松开了,最后记录松开时鼠标的位置,现在则记录了开始和松开鼠标的两个位置
|
||||
# 如果两个位置太近,则不添加
|
||||
elif Mouse_insert is None and draw_rectangle is not None and event == cv2.EVENT_LBUTTONUP:
|
||||
if end_draw_rectangle is None:
|
||||
draw_rectangle[1] = (x, y)
|
||||
if not distance(draw_rectangle[0], draw_rectangle[1]):
|
||||
end_draw_rectangle = True
|
||||
else:
|
||||
draw_rectangle = None
|
||||
|
||||
|
||||
def create_file_key(img_path, label_path):
|
||||
empty_la = None
|
||||
if not os.path.exists(label_path):
|
||||
with open(label_path, 'w') as f:
|
||||
pass
|
||||
empty_la = True
|
||||
with open(label_path) as f:
|
||||
label_ = f.readlines()
|
||||
if len(label_) == 0 or label_[0] == '\n':
|
||||
empty_la = True
|
||||
img_s = img_read(img_path, 950) # 950调整图像的大小
|
||||
if key_point_is and empty_la:
|
||||
box_create = '0 0.5 0.5 0.3 0.3 '
|
||||
len_t = img_s.shape[1] // key_point_num
|
||||
key_num_x = [str(round((i * len_t + 20) / img_s.shape[1], dot)) + ' ' + str(0.5) + ' ' + '2' for i in
|
||||
range(key_point_num)]
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(box_create + ' '.join(key_num_x))
|
||||
|
||||
|
||||
def main(img_path, label_path):
|
||||
global img, position, label, Mouse_insert, draw_rectangle, end_draw_rectangle, append_str_temp, empty_label, \
|
||||
Mouse_move, dot, box_move, key_insert, key_point_one, key_point_two, key_x, key_y, key_v, \
|
||||
move_key_point, append_new_key_point, append_new_key_point_index, window_w, window_h
|
||||
# 判断本地是否存在文件,或者文件中是否为空或者存在一个换行符,就先把标签删除,添加'0 0 0 0 0\n'
|
||||
# 如果不预先添加一个处理起来有点麻烦,这里就先加一个,然后后面删掉就行了
|
||||
if not os.path.exists(label_path):
|
||||
empty_label = True
|
||||
with open(label_path, 'w') as f:
|
||||
pass
|
||||
with open(label_path) as f:
|
||||
label = f.readlines()
|
||||
if len(label) == 0 or label[0] == '\n':
|
||||
empty_label = True
|
||||
# 这里的2是将原图缩小为2分之一
|
||||
print(img_path)
|
||||
img_s = img_read(img_path, 900)
|
||||
if key_point_is and empty_label:
|
||||
box_create = '0 0.5 0.5 0.3 0.3 '
|
||||
len_t = img_s.shape[1] // key_point_num
|
||||
key_num_x = [str(round((i * len_t + 20) / img_s.shape[1], dot)) + ' ' + str(0.5) + ' ' + '2' for i in
|
||||
range(key_point_num)]
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(box_create + ' '.join(key_num_x))
|
||||
label = box_create + ' '.join(key_num_x)
|
||||
# 创建回调函数,绑定窗口
|
||||
cv2.namedWindow('image', cv2.WINDOW_NORMAL)
|
||||
_, _, window_w, window_h = cv2.getWindowImageRect('image')
|
||||
cv2.resizeWindow('image', img_s.shape[1], img_s.shape[0])
|
||||
cv2.setMouseCallback('image', mouse_event)
|
||||
# 刷新图像的地方
|
||||
while True:
|
||||
# 首先读取下标签,用来初始化显示
|
||||
with open(label_path, 'w') as f:
|
||||
for i in label:
|
||||
f.write(i)
|
||||
# 如果鼠标选中了框的8个点之一,就在鼠标周围绘制空心圈
|
||||
if Mouse_insert is None and draw_rectangle is None and position is not None and key_insert is None:
|
||||
img = img_s.copy()
|
||||
label_show(img, label_path, Mouse_insert)
|
||||
cv2.circle(img, position, 10, (0, 255, 100), 2)
|
||||
# 如果选择开始增加新的框,则不断绘制鼠标起始点和移动过程之间形成的框
|
||||
elif draw_rectangle is not None and end_draw_rectangle is None:
|
||||
img = img_s.copy()
|
||||
label_show(img, label_path, Mouse_insert)
|
||||
cv2.rectangle(img, draw_rectangle[0], draw_rectangle[1], color=box_color[1], thickness=2)
|
||||
# 当松开鼠标后,记录两点位置,并提示选择类别
|
||||
elif draw_rectangle is not None and end_draw_rectangle:
|
||||
scale_y, scale_x, _ = img.shape
|
||||
x1, y1 = draw_rectangle[0]
|
||||
x2, y2 = draw_rectangle[1]
|
||||
w1, h1 = abs(x2 - x1), abs(y2 - y1)
|
||||
append_str_temp = str(round((x1 + x2) / 2 / scale_x, dot)) + ' ' + str(
|
||||
round((y1 + y2) / 2 / scale_y, dot)) + ' ' + \
|
||||
str(round((w1 / scale_x), dot)) + ' ' + str(round((h1 / scale_y), dot)) + '\n'
|
||||
cv2.putText(img, 'choose your classify', (0, img.shape[0] // 2 - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.7, (255, 0, 255), 2)
|
||||
cv2.putText(img, ' '.join([str(i) + ':' + my_cls[i] for i in my_cls]), (0, img.shape[0] // 2 + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.7, (100, 255, 255), 2)
|
||||
elif key_insert is not None:
|
||||
position, Mouse_move, box_move = None, None, None # 禁用其他操作
|
||||
cv2.putText(img, 'Switching visibility: 0 1 2', (0, img.shape[0] // 2 - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1, (100, 255, 255), 2,
|
||||
lineType=cv2.LINE_AA)
|
||||
elif move_key_point is not None:
|
||||
position, Mouse_move, box_move = None, None, None # 禁用其他操作
|
||||
cv2.putText(img, 'choose point: 0 - {}'.format(key_point_num - 1), (0, img.shape[0] // 2 - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1, (100, 255, 255), 2,
|
||||
lineType=cv2.LINE_AA)
|
||||
# 如果什么标志都没有,就正常显示一个图
|
||||
else:
|
||||
img = img_s.copy()
|
||||
if Mouse_insert is not None:
|
||||
position, Mouse_move = None, None
|
||||
cv2.putText(img, 'delete: W, exit: E', (0, img.shape[0] // 2 - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1, (100, 255, 255), 2,
|
||||
lineType=cv2.LINE_AA)
|
||||
label_show(img, label_path, Mouse_insert)
|
||||
cv2.imshow('image', img)
|
||||
|
||||
# key用来获取键盘输入
|
||||
key = cv2.waitKey(10)
|
||||
# 输入为Q则退出
|
||||
if key == ord('Q'):
|
||||
append_new_key_point = None
|
||||
# 退出按键
|
||||
break
|
||||
if move_key_point is not None and key_point_one is None and 48 <= key <= 57:
|
||||
key_point_one = int(chr(key))
|
||||
key = 0
|
||||
if move_key_point is not None and key_point_two is None and 48 <= key <= 57:
|
||||
key_point_two = int(chr(key))
|
||||
key = 0
|
||||
if (move_key_point is not None) and (key_point_one is not None) and (key_point_two is not None):
|
||||
with open(la_path) as f:
|
||||
label = f.readlines()
|
||||
for i, la in enumerate(label):
|
||||
la = la.strip('\n').split(' ')
|
||||
key_points_ = list(map(float, la))[5:]
|
||||
key_box_ = list(map(float, la))[0:5]
|
||||
key_x_ = [float(i) for i in key_points_[::3]]
|
||||
key_y_ = [float(i) for i in key_points_[1::3]]
|
||||
key_v_ = [float(i) for i in key_points_[2::3]] # 能见度
|
||||
key_box_[0] = int(key_box_[0])
|
||||
index_ = key_point_one * 10 + key_point_two
|
||||
if index_ >= key_point_num:
|
||||
break
|
||||
key_x_[index_] = round(move_key_point[0] / img.shape[1], dot)
|
||||
key_y_[index_] = round(move_key_point[1] / img.shape[0], dot)
|
||||
str_temp = ' '.join([str(j) for j in key_box_])
|
||||
for index, kx in enumerate(key_x_):
|
||||
str_temp += ' ' + str(kx) + ' ' + str(key_y_[index]) + ' ' + str(int(key_v_[index]))
|
||||
label = str_temp
|
||||
with open(la_path, 'w') as f:
|
||||
f.write(str_temp)
|
||||
move_key_point, key_point_one, key_point_two = None, None, None
|
||||
break
|
||||
move_key_point, key_point_one, key_point_two = None, None, None
|
||||
# 如果按键输入为W则删除选中的框
|
||||
if Mouse_insert is not None and key == ord('W'):
|
||||
label.pop(Mouse_insert)
|
||||
Mouse_insert = None
|
||||
elif key_insert is not None and key == ord('0'):
|
||||
with open(label_path, 'r') as f:
|
||||
label_temp = f.read()
|
||||
str_temp = label_temp.split(' ')
|
||||
str_temp[3 * int(key_insert) + 7] = '0'
|
||||
str_temp[3 * int(key_insert) + 7 - 1] = '0'
|
||||
str_temp[3 * int(key_insert) + 7 - 2] = '0'
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(' '.join(str_temp))
|
||||
label = ' '.join(str_temp)
|
||||
key_insert = None
|
||||
elif key_insert is not None and key == ord('1'):
|
||||
with open(label_path, 'r') as f:
|
||||
label_temp = f.read()
|
||||
str_temp = label_temp.split(' ')
|
||||
str_temp[3 * int(key_insert) + 7] = '1'
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(' '.join(str_temp))
|
||||
label = ' '.join(str_temp)
|
||||
key_insert = None
|
||||
elif key_insert is not None and key == ord('2'):
|
||||
with open(label_path, 'r') as f:
|
||||
label_temp = f.read()
|
||||
str_temp = label_temp.split(' ')
|
||||
str_temp[3 * int(key_insert) + 7] = '2'
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(' '.join(str_temp))
|
||||
label = ' '.join(str_temp)
|
||||
key_insert = None
|
||||
# 如果输入为E则从选中框的状态退出
|
||||
elif key == ord('E'):
|
||||
Mouse_insert = None
|
||||
# 通过键盘获取输入的类别
|
||||
elif Mouse_move is None and Mouse_insert is None and draw_rectangle is not None and end_draw_rectangle is not None \
|
||||
and (48 <= key <= 57 or key == ord('Z') or key == ord('X') or key == ord('C') or key == ord('V') or key==ord('B') or key==ord('N')):
|
||||
if 48 <= key <= 57:
|
||||
str_temp = str(chr(key)) + ' ' + append_str_temp
|
||||
elif key == ord('Z'):
|
||||
str_temp = str(10) + ' ' + append_str_temp
|
||||
elif key == ord('X'):
|
||||
str_temp = str(11) + ' ' + append_str_temp
|
||||
elif key == ord('C'):
|
||||
str_temp = str(12) + ' ' + append_str_temp
|
||||
elif key == ord('V'):
|
||||
str_temp = str(13) + ' ' + append_str_temp
|
||||
elif key == ord('B'):
|
||||
str_temp = str(14) + ' ' + append_str_temp
|
||||
elif key == ord('N'):
|
||||
str_temp = str(15) + ' ' + append_str_temp
|
||||
label.append(str_temp)
|
||||
append_str_temp, draw_rectangle, end_draw_rectangle, empty_label = None, None, None, None
|
||||
elif key == ord('R'):
|
||||
flag_init()
|
||||
append_new_key_point = True
|
||||
break
|
||||
elif key == ord('T'):
|
||||
exit(0)
|
||||
elif key == ord('Y'):
|
||||
os.remove(img_path)
|
||||
os.remove(label_path)
|
||||
break
|
||||
|
||||
|
||||
def delete_line_feed(label_path):
|
||||
# 去掉最后一行的换行符'\n',保存的时候需要
|
||||
if os.path.exists(label_path):
|
||||
with open(label_path) as f:
|
||||
label_ = f.read()
|
||||
label_ = label_.rstrip('\n')
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(label_)
|
||||
|
||||
|
||||
def append__line_feed(label_path):
|
||||
# 加上最后一行的换行符'\n',标注的时候增加新的框的时候需要
|
||||
with open(label_path) as f:
|
||||
label_ = f.read()
|
||||
if len(label_) < 4:
|
||||
with open(label_path, 'w') as f:
|
||||
pass
|
||||
return
|
||||
label_ = label_.rstrip('\n') + '\n'
|
||||
with open(label_path, 'w') as f:
|
||||
f.write(label_)
|
||||
|
||||
|
||||
def key_check(label_path):
|
||||
# 检查开启关键点之后本地标签是否满足要求, 如果本地标签中和预设关键点数不等以及关键点数量不是3的倍数都会将原有标签重置
|
||||
if os.path.exists(label_path):
|
||||
with open(label_path) as f:
|
||||
label_ = f.readlines()
|
||||
for label_ in label_:
|
||||
label_ = label_.strip('\n').split(' ')
|
||||
if ((len(label_) - 5) % 3) or ((len(label_) - 5) // 3 - key_point_num):
|
||||
with open(label_path, 'w') as f:
|
||||
pass
|
||||
|
||||
|
||||
def label_check(label_path):
|
||||
# 检查普通标签,判断每行是包含5个数值
|
||||
if os.path.exists(label_path):
|
||||
with open(label_path) as f:
|
||||
label_ = f.readlines()
|
||||
for i in label_:
|
||||
i = i.strip('\n').split(' ')
|
||||
if len(i) - 5 != 0:
|
||||
with open(label_path, 'w'):
|
||||
pass
|
||||
|
||||
|
||||
def merge_file_key(la_path, index):
|
||||
with open(la_path) as f:
|
||||
text = f.read().strip('\n')
|
||||
for i in range(index):
|
||||
with open(la_path.split('.')[0] + str(i) + '.txt') as f:
|
||||
text += '\n' + f.read().strip('\n')
|
||||
os.remove(la_path.split('.')[0] + str(i) + '.txt')
|
||||
with open(la_path, 'w') as f:
|
||||
f.write(text)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image_ = os.listdir(image_path)
|
||||
for im in image_:
|
||||
flag_init()
|
||||
im_path = os.path.join(image_path, im)
|
||||
la_path = os.path.join(label_path, im.split('.')[0] + '.txt')
|
||||
if key_point_is:
|
||||
key_check(la_path) # 检查本地标签的关键点数量是否和预设的关键点数量相等,以及去除框的5点后点数是否满足为3的倍数
|
||||
create_file_key(im_path, la_path)
|
||||
else:
|
||||
delete_line_feed(la_path)
|
||||
label_check(la_path)
|
||||
if os.path.exists(la_path):
|
||||
# 先增加一个换行符为了后面的增加框的操作
|
||||
append__line_feed(la_path)
|
||||
while True:
|
||||
main(im_path, la_path)
|
||||
if append_new_key_point is None:
|
||||
break
|
||||
else:
|
||||
la_path = os.path.join(label_path, im.split('.')[0] + str(append_new_key_point_index) + '.txt')
|
||||
with open(la_path, 'w') as f:
|
||||
pass
|
||||
if key_point_is:
|
||||
key_check(la_path) # 检查本地标签的关键点数量是否和预设的关键点数量相等,以及去除框的5点后点数是否满足为3的倍数
|
||||
create_file_key(im_path, la_path)
|
||||
else:
|
||||
delete_line_feed(la_path)
|
||||
label_check(la_path)
|
||||
append_new_key_point_index += 1
|
||||
if append_new_key_point_index != 0:
|
||||
merge_file_key(os.path.join(label_path, im.split('.')[0] + '.txt'), append_new_key_point_index)
|
||||
append_new_key_point_index = 0
|
||||
if os.path.exists(la_path):
|
||||
# 去掉最后一行的换行符
|
||||
delete_line_feed(la_path)
|
|
@ -0,0 +1,40 @@
|
|||
from PIL import Image
|
||||
|
||||
|
||||
def crop_image_into_nine_parts(image_path):
|
||||
# 打开图片
|
||||
image = Image.open(image_path)
|
||||
|
||||
# 获取图片的宽度和高度
|
||||
width, height = image.size
|
||||
|
||||
# 计算裁剪后小图片的宽度和高度
|
||||
part_width = width // 3
|
||||
part_height = height // 3
|
||||
|
||||
parts = []
|
||||
|
||||
# 循环裁剪九等分的小图片
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
left = j * part_width
|
||||
top = i * part_height
|
||||
right = (j + 1) * part_width
|
||||
bottom = (i + 1) * part_height
|
||||
|
||||
# 裁剪小图片
|
||||
part = image.crop((left, top, right, bottom))
|
||||
parts.append(part)
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
# 图片路径
|
||||
image_path = "scripts/input/images/725.jpg"
|
||||
|
||||
# 调用函数裁剪图片
|
||||
nine_parts = crop_image_into_nine_parts(image_path)
|
||||
|
||||
# 保存裁剪后的小图片
|
||||
for idx, part in enumerate(nine_parts):
|
||||
part.save(f"part_{idx + 1}.jpg")
|
|
@ -0,0 +1,91 @@
|
|||
import time
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import ImageGrab
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.general import non_max_suppression, scale_boxes
|
||||
from utils.plots import Annotator, colors
|
||||
from utils.augmentations import letterbox
|
||||
from utils.torch_utils import select_device
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
folder_path = R'C:\Users\lengdan\Desktop\yolov5-master\data\images' # 本地文件夹
|
||||
camera = 2 # 0调用本地相机, 1检测文件夹中的图像, 2检测屏幕上内容
|
||||
|
||||
|
||||
class mydataload:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
if camera == 0:
|
||||
self.cap = cv2.VideoCapture(0)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if camera == 0:
|
||||
_, im0 = self.cap.read()
|
||||
elif camera == 1:
|
||||
file_list = glob.glob(os.path.join(folder_path, '*.jpg'))
|
||||
# 对文件列表按时间戳排序,以确保最新添加的图像排在最后面
|
||||
file_list.sort(key=os.path.getmtime)
|
||||
im0 = cv2.imread(file_list[self.count])
|
||||
self.count = self.count if self.count == len(file_list) - 1 else self.count + 1
|
||||
else: # camera 检测屏幕上的内容
|
||||
# 指定截图区域的左上角和右下角坐标
|
||||
x1, y1 = 1000, 100 # 左上角
|
||||
x2, y2 = 1900, 1000 # 右下角
|
||||
# 截取屏幕区域
|
||||
img = ImageGrab.grab(bbox=(x1, y1, x2, y2))
|
||||
im0 = np.array(img)
|
||||
im = letterbox(im0, 640, auto=True)[0] # padded resize
|
||||
im = im.transpose((2, 0, 1)) # HWC to CHW, BGR to RGB
|
||||
if camera != 2:
|
||||
im = im[::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im, im0
|
||||
|
||||
|
||||
def get_image(model, im, im0s, conf_thres=0.5, iou_thres=0.5, line_thickness=3):
|
||||
# temp_list = []
|
||||
pred = model(im, visualize=False)
|
||||
pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
|
||||
for i, det in enumerate(pred):
|
||||
im0, names = im0s.copy(), model.names
|
||||
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
|
||||
if len(det):
|
||||
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
|
||||
for *xyxy, conf, cls in reversed(det):
|
||||
c = int(cls)
|
||||
# temp_list.append((xyxy[0].item(), xyxy[1].item(), xyxy[2].item(), xyxy[3].item()))
|
||||
label = f'{names[c]} {conf:.2f}'
|
||||
annotator.box_label(xyxy, label, color=colors(c, True))
|
||||
return annotator.result()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = select_device('0')
|
||||
model = DetectMultiBackend('7_29_last.pt', device=device, dnn=False, data='', fp16=False)
|
||||
dataset = mydataload()
|
||||
model.warmup(imgsz=(1, 3, 640, 640)) # warmup
|
||||
for im, im0s in dataset:
|
||||
t0 = time.time()
|
||||
im = torch.from_numpy(im).to(model.device)
|
||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
im0 = get_image(model, im, im0s)
|
||||
|
||||
if camera == 2:
|
||||
im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
|
||||
cv2.namedWindow('1', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow('1', im0.shape[1] // 2, im0.shape[0] // 2)
|
||||
cv2.imshow('1', im0)
|
||||
if cv2.waitKey(1) == ord('Q'):
|
||||
exit(0)
|
||||
print(time.time() - t0)
|
|
@ -0,0 +1,169 @@
|
|||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1333, 657)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
|
||||
self.pushButton_init.setObjectName("pushButton_init")
|
||||
self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
|
||||
self.pushButton_openimg.setObjectName("pushButton_openimg")
|
||||
self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41))
|
||||
self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg")
|
||||
self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41))
|
||||
self.pushButton_exit.setObjectName("pushButton_exit")
|
||||
self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41))
|
||||
self.pushButton_Transparency.setObjectName("pushButton_Transparency")
|
||||
self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41))
|
||||
self.pushButton_copymask.setObjectName("pushButton_copymask")
|
||||
self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41))
|
||||
self.pushButton_saveimg.setObjectName("pushButton_saveimg")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22))
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
self.horizontalSlider.setValue(50)
|
||||
self.label_Originalimg = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
|
||||
self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Originalimg.setObjectName("label_Originalimg")
|
||||
self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
|
||||
self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Maskimg.setObjectName("label_Maskimg")
|
||||
self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41))
|
||||
self.pushButton_shang.setObjectName("pushButton_shang")
|
||||
self.pushButton_xia = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_xia.setGeometry(QtCore.QRect(10, 210, 141, 41))
|
||||
self.pushButton_xia.setObjectName("pushButton_openimg_3")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
|
||||
self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
|
||||
self.pushButton_shang.setText(_translate("MainWindow", "上一张"))
|
||||
self.pushButton_xia.setText(_translate("MainWindow", "下一张"))
|
||||
self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片"))
|
||||
self.pushButton_exit.setText(_translate("MainWindow", "退出"))
|
||||
self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度"))
|
||||
self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码"))
|
||||
self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片"))
|
||||
self.label_Originalimg.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_Maskimg.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">掩码图像</p></body></html>"))
|
||||
|
||||
|
||||
def init_slots(self):
|
||||
self.pushButton_openimg.clicked.connect(self.button_image_open)
|
||||
self.pushButton_init.clicked.connect(self.button_image_init)
|
||||
self.pushButton_shang.clicked.connect(self.button_image_shang)
|
||||
self.pushButton_xia.clicked.connect(self.button_image_xia)
|
||||
self.pushButton_Fusionimg.clicked.connect(self.button_image_Fusionimg)
|
||||
self.pushButton_copymask.clicked.connect(self.button_image_copymask)
|
||||
self.pushButton_saveimg.clicked.connect(self.button_image_saveimg)
|
||||
self.pushButton_exit.clicked.connect(self.button_image_exit)
|
||||
self.horizontalSlider.valueChanged.connect(self.slider_value_changed)
|
||||
self.pushButton_openimg.clicked.connect(self.button_image_open)
|
||||
# 连接label_Originalimg点击事件
|
||||
self.label_Originalimg.mousePressEvent = self.label_Originalimg_click_event
|
||||
|
||||
def label_Originalimg_click_event(self, event):
|
||||
# 捕获label_Originalimg中的鼠标点击事件
|
||||
point = event.pos()
|
||||
x, y = point.x(), point.y()
|
||||
print("Clicked at:", x, y) # 您可以使用这些坐标进行预测
|
||||
def button_image_open(self):
|
||||
choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
|
||||
QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
|
||||
if choice == QtWidgets.QMessageBox.Open:
|
||||
folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
|
||||
if folder_path:
|
||||
image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
|
||||
if image_files:
|
||||
self.image_files = image_files
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
elif choice == QtWidgets.QMessageBox.Cancel:
|
||||
selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
|
||||
"Image files (*.png *.jpg *.jpeg *.bmp)")
|
||||
if selected_image:
|
||||
self.image_files = [selected_image]
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
|
||||
def display_image(self):
|
||||
if hasattr(self, 'image_files') and self.image_files:
|
||||
pixmap = QtGui.QPixmap(self.image_files[self.current_index])
|
||||
self.label_Originalimg.setPixmap(pixmap)
|
||||
self.label_Originalimg.setScaledContents(True)
|
||||
|
||||
def button_image_shang(self):
|
||||
if hasattr(self, 'image_files') and self.image_files:
|
||||
self.current_index = (self.current_index - 1) % len(self.image_files)
|
||||
self.display_image()
|
||||
|
||||
def button_image_xia(self):
|
||||
if hasattr(self, 'image_files') and self.image_files:
|
||||
self.current_index = (self.current_index + 1) % len(self.image_files)
|
||||
self.display_image()
|
||||
|
||||
def button_image_exit(self):
|
||||
sys.exit()
|
||||
|
||||
def slider_value_changed(self, value):
|
||||
print("Slider value changed:", value)
|
||||
|
||||
def button_image_saveimg(self):
|
||||
pass
|
||||
|
||||
def button_image_Fusionimg(self):
|
||||
pass
|
||||
|
||||
def button_image_copymask(self):
|
||||
pass
|
||||
|
||||
def button_image_init(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
MainWindow = QtWidgets.QMainWindow()
|
||||
ui = Ui_MainWindow()
|
||||
ui.setupUi(MainWindow)
|
||||
ui.init_slots() # 调用init_slots以连接信号和槽
|
||||
MainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash -e
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
|
||||
{
|
||||
black --version | grep -E "23\." > /dev/null
|
||||
} || {
|
||||
echo "Linter requires 'black==23.*' !"
|
||||
exit 1
|
||||
}
|
||||
|
||||
ISORT_VERSION=$(isort --version-number)
|
||||
if [[ "$ISORT_VERSION" != 5.12* ]]; then
|
||||
echo "Linter requires isort==5.12.0 !"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Running isort ..."
|
||||
isort . --atomic
|
||||
|
||||
echo "Running black ..."
|
||||
black -l 100 .
|
||||
|
||||
echo "Running flake8 ..."
|
||||
if [ -x "$(command -v flake8)" ]; then
|
||||
flake8 .
|
||||
else
|
||||
python3 -m flake8 .
|
||||
fi
|
||||
|
||||
echo "Running mypy..."
|
||||
|
||||
mypy --exclude 'setup.py|notebooks' .
|
|
@ -0,0 +1,113 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<ui version="4.0">
|
||||
<class>MainWindow</class>
|
||||
<widget class="QMainWindow" name="MainWindow">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>889</width>
|
||||
<height>600</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="windowTitle">
|
||||
<string>MainWindow</string>
|
||||
</property>
|
||||
<widget class="QWidget" name="centralwidget">
|
||||
<widget class="QLabel" name="label">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>10</x>
|
||||
<y>120</y>
|
||||
<width>421</width>
|
||||
<height>331</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true">background-color: rgb(255, 255, 255);</string>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>原始图像</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="label_2">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>440</x>
|
||||
<y>120</y>
|
||||
<width>421</width>
|
||||
<height>331</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true">background-color: rgb(255, 255, 255);</string>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>预测图像</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>150</x>
|
||||
<y>470</y>
|
||||
<width>131</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>打开图像</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="pushButton_2">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>570</x>
|
||||
<y>470</y>
|
||||
<width>131</width>
|
||||
<height>51</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>预测图像</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QTextEdit" name="textEdit">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>320</x>
|
||||
<y>20</y>
|
||||
<width>221</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="html">
|
||||
<string><!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd">
|
||||
<html><head><meta name="qrichtext" content="1" /><style type="text/css">
|
||||
p, li { white-space: pre-wrap; }
|
||||
</style></head><body style=" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;">
|
||||
<p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:16pt;">分割图像GUI界面</span></p></body></html></string>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<widget class="QMenuBar" name="menubar">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>889</width>
|
||||
<height>26</height>
|
||||
</rect>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QStatusBar" name="statusbar"/>
|
||||
</widget>
|
||||
<resources/>
|
||||
<connections/>
|
||||
</ui>
|
|
@ -0,0 +1,72 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Form implementation generated from reading ui file 'mask6.ui'
|
||||
#
|
||||
# Created by: PyQt5 UI code generator 5.15.9
|
||||
#
|
||||
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
|
||||
# run again. Do not edit this file unless you know what you are doing.
|
||||
|
||||
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(889, 600)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.label = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label.setGeometry(QtCore.QRect(10, 120, 421, 331))
|
||||
self.label.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.label.setObjectName("label")
|
||||
self.label_2 = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_2.setGeometry(QtCore.QRect(440, 120, 421, 331))
|
||||
self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_2.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.label_2.setObjectName("label_2")
|
||||
self.pushButton = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton.setGeometry(QtCore.QRect(150, 470, 131, 51))
|
||||
self.pushButton.setObjectName("pushButton")
|
||||
self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_2.setGeometry(QtCore.QRect(570, 470, 131, 51))
|
||||
self.pushButton_2.setObjectName("pushButton_2")
|
||||
self.textEdit = QtWidgets.QTextEdit(self.centralwidget)
|
||||
self.textEdit.setGeometry(QtCore.QRect(320, 20, 221, 41))
|
||||
self.textEdit.setObjectName("textEdit")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 889, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.label.setText(_translate("MainWindow", "原始图像"))
|
||||
self.label_2.setText(_translate("MainWindow", "预测图像"))
|
||||
self.pushButton.setText(_translate("MainWindow", "打开图像"))
|
||||
self.pushButton_2.setText(_translate("MainWindow", "预测图像"))
|
||||
self.textEdit.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
|
||||
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
|
||||
"p, li { white-space: pre-wrap; }\n"
|
||||
"</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
|
||||
"<p style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:16pt;\">分割图像GUI界面</span></p></body></html>"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
MainWindow = QtWidgets.QMainWindow()
|
||||
ui = Ui_MainWindow()
|
||||
ui.setupUi(MainWindow)
|
||||
MainWindow.show()
|
||||
sys.exit(app.exec_())
|
|
@ -0,0 +1,426 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d, MLPBlock
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
# 这个代码定义了 SAM 的图像编码器 ImageEncoderViT。它包含以下主要部分:
|
||||
# 1. patch_embed: 这是 ViT 的 patch embedding 层,用于将输入图像划分为 patch,并获得 patch 的 embedding。
|
||||
# 2. pos_embed: 这是 ViT的绝对位置 embedding,用于为每个patch提供位置信息。
|
||||
# 3. blocks: 这是 ViT 的 transformer encoder 块的列表,每个块包含多头自注意力层和前馈神经网络。
|
||||
# 4. neck: 这是图像编码器的“颈部”,包含几个卷积层和 LayerNorm 层,用于从 transformer encoder 块的输出中提取特征。
|
||||
# 5. forward(): 这是图像编码器的前向传播过程。首先通过 patch_embed 层获得 patch embedding, 然后加上 pos_embed。
|
||||
# 接着,patch embedding通过transformer encoder块。最后, neck 层从 transformer encoder 块的输出中提取特征。
|
||||
# 所以,这个 ImageEncoderViT 类定义了 SAM 的图像编码器,它基于 ViT,包含 patch embedding、位置 embedding、
|
||||
# transformer encoder块以及 neck, 可以从输入图像中提取特征。
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.pos_embed: Optional[nn.Parameter] = None
|
||||
if use_abs_pos:
|
||||
# Initialize absolute positional embedding with pretrain image size.
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
block = Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
window_size=window_size if i not in global_attn_indexes else 0,
|
||||
input_size=(img_size // patch_size, img_size // patch_size),
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
# 这个 Block 类实现了 transformer block, 可以选择使用全局注意力或局部窗口注意力,同时包含残差连接。它包含:
|
||||
# __init__方法:
|
||||
# 1. 输入参数:
|
||||
# - dim: 输入通道数
|
||||
# - num_heads: 注意力头数
|
||||
# - mlp_ratio: mlp 隐藏层与输入 embedding 维度的比例
|
||||
# - qkv_bias: 是否为 query、key、value 添加偏置
|
||||
# - norm_layer: 归一化层
|
||||
# - act_layer: 激活层
|
||||
# - use_rel_pos: 是否使用相对位置 embedding
|
||||
# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为 0
|
||||
# - window_size: 窗口注意力的窗口大小,如果为 0 则使用全局注意力
|
||||
# - input_size: 计算相对位置 embedding 大小所需的输入分辨率
|
||||
# 2. 实例化第 1 次和第 2 次归一化层 norm1 和 norm2。
|
||||
# 3. 实例化 Attention 层和 MLPBlock 层。Attention 层的输入大小根据是否使用窗口注意力进行了调整。
|
||||
# 4. 记录窗口注意力的窗口大小 window_size。
|
||||
# forward方法:
|
||||
# 1. 提取 shortcut 并对 x 进行第 1 次归一化。
|
||||
# 2. 如果使用窗口注意力, 则调用 window_partition 对 x 进行窗口划分。
|
||||
# 3. 将 x 输入 Attention 层。
|
||||
# 4. 如果使用窗口注意力,则调用 window_unpartition 对 x 进行窗口反划分。
|
||||
# 5. x = shortcut + x,实现第 1 次残差连接。
|
||||
# 6. x = x + mlp(norm2(x)),实现第 2 次残差连接和 MLPBlock。
|
||||
# 7. 返回最终的 x。
|
||||
# 所以,这个 Block 类实现了带有可选的窗口注意力和双残差连接的transformer block。
|
||||
# 窗口注意力可以更好地建模局部结构,双残差连接可以提高梯度流动,都是transformer结构的重要改进。
|
||||
# 这个 Block 类实现了 transformer 的关键组成部分,同时提供了窗口注意力和残差连接等重要变体,可以显著提高其表现力和泛化能力。
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
input_size=input_size if window_size == 0 else (window_size, window_size),
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
# 这个Attention类实现了多头注意力机制,可以加入相对位置 embedding。它包含:
|
||||
# __init__方法:
|
||||
# 1. 输入参数:
|
||||
# - dim: 输入通道数
|
||||
# - num_heads: 注意力头数
|
||||
# - qkv_bias: 是否为查询、键、值添加偏置
|
||||
# - use_rel_pos: 是否使用相对位置 embedding
|
||||
# - rel_pos_zero_init: 是否将相对位置 embedding 初始化为0
|
||||
# - input_size: 计算相对位置 embedding 大小所需的输入分辨率
|
||||
# 2. 计算每个注意力头的维度 head_dim。
|
||||
# 3. 实例化 self.qkv和 输出投影 self.proj。
|
||||
# 4. 如果使用相对位置 embedding, 则初始化 rel_pos_h 和 rel_pos_w。
|
||||
# forward方法:
|
||||
# 1. 从输入 x 中提取批次大小 B、高度 H、宽度 W 和通道数 C。
|
||||
# 2. 计算 qkv,形状为 (3, B, nHead, H * W, C), 包含 query、key 和 value。
|
||||
# 3. 提取 q、 k 和 v, 形状为 (B * nHead, H * W, C)。
|
||||
# 4. 计算注意力图 attn,形状为 (B * nHead, H * W, H * W)。
|
||||
# 5. 如果使用相对位置 embedding, 则调用 add_decomposed_rel_pos 函数将其加入 attn。
|
||||
# 6. 对 attn 进行 softmax 归一化。
|
||||
# 7. 计算输出 x , (attn @ v), 形状为 (B, nHead, H, W, C), 然后合并注意力头, 形状为(B, H, W, C)。
|
||||
# 8. 对 x 进行投影, 返回最终的输出。
|
||||
# 所以,这个 Attention 类实现了带有相对位置 embedding 的多头注意力机制。
|
||||
# 它可以高效地建模图像和视频等二维结构数据,是 transformer 在这些领域得到广泛应用的关键。
|
||||
# 这个 Attention 类提供了相对位置 embedding 和多头注意力机制的实现,
|
||||
# 是理解 transformer 在图像和视频建模中的重要组成部分。
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (
|
||||
input_size is not None
|
||||
), "Input size must be provided if using relative positional encoding."
|
||||
# initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
# 这个 window_partition 函数的作用是将输入张量划分为非重叠的窗口。它包含:
|
||||
# 1. 输入参数:
|
||||
# - x: 输入的张量,形状为 [B, H, W, C]
|
||||
# - window_size: 窗口大小
|
||||
# 2. 首先计算输入需要 padding 的高度和宽度,将x进行padding。
|
||||
# 3. 然后将 x 的形状变化为 [B, Hp//window_size, window_size, Wp//window_size, window_size, C],
|
||||
# 表示将图像划分为 Hp//window_size * Wp//window_size 个 window_size * window_size 的 patch。
|
||||
# 4. 最后,通过 permute 和 view 操作,得到 windows 的形状为 [B * num_windows, window_size, window_size, C],
|
||||
# 表示将所有 patch 打平, num_windows 是 patch 的总数
|
||||
# 5. 返回windows和原来的高度和宽度(包含padding)Hp和Wp。
|
||||
# 所以,这个 window_partition 函数的作用是,将输入的图像划分为 window_size * window_size 的 patch,
|
||||
# 并将所有的 patch 打平, 输出可以输入到 transformer encoder 中的 token 序列。
|
||||
# 这个函数实现了将二维图像转化为一维 token 序列的过程,是 transformer 用于处理图像的一个关键步骤。
|
||||
# 通过这个函数,图像可以被 transformer encoder 所处理,就像处理文本序列一样。
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
# 这个 window_unpartition 函数的作用是将 window_partition 函数的输出进行反划分, 恢复成原始的图像形状。它包含:
|
||||
# 1. 输入参数:
|
||||
# - windows: window_partition的输出,形状为 [B * num_windows, window_size, window_size, C]
|
||||
# - window_size: 窗口大小
|
||||
# - pad_hw: padding后的高度和宽度 (Hp, Wp)
|
||||
# - hw: padding前的原始高度和宽度 (H, W)
|
||||
# 2. 首先根据窗口大小和 padding 后的 hw 计算原始的 batch_size B。
|
||||
# 3. 然后将 windows 的形状变回 [B, Hp//window_size, Wp//window_size, window_size, window_size, C], 表示每个patch的位置。
|
||||
# 4. 接着通过permute和view操作,得到x的形状为 [B, Hp, Wp, C], 恢复成图像的形状。
|
||||
# 5. 最后,如果进行了padding,则截取x到原始的高度H和宽度W。
|
||||
# 6. 返回恢复后的图像x。
|
||||
# 所以,这个 window_unpartition 函数的作用是将通过 window_partition 函数得到的 patch 序列恢复成原始的图像。
|
||||
# 它实现了从一维 patch token 序列到二维图像的反过程。
|
||||
# 这个函数与 window_partition 函数相反,使得 transformer 能够最终从 patch token 序列恢复成图像,完成对图像的建模。
|
||||
# 总的来说,这个 window_unpartition 函数实现了从 patch token 序列恢复成原始图像的过程,与 window_partition 函数相对应,
|
||||
# 是使得 transformer 可以处理图像的另一个关键步骤
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
# 这个 get_rel_pos 函数的作用是根据 query 和 key 的相对位置获取相对位置 embedding。它包含:
|
||||
# 1. 输入参数:
|
||||
# - q_size: query 的大小
|
||||
# - k_size: key 的大小
|
||||
# - rel_pos: 相对位置 embedding, 形状为[L, C]
|
||||
# 2. 首先计算最大的相对距离 max_rel_dist, 它等于 query 和 key 大小的 2 倍减 1。
|
||||
# 3. 如果相对位置 embedding 的长度小于 max_rel_dist, 则通过线性插值将其调整到 max_rel_dist 的长度。
|
||||
# 4. 如果 q_size 和 k_size 不同, 则将 q_size 和 k_size 的坐标按比例缩放,使它们之间的相对距离保持不变。
|
||||
# 5. 根据调整后的 q_size 和 k_size 坐标计算相对坐标 relative_coords。
|
||||
# 6. 根据 relative_coords 从 rel_pos_resized 中提取相对位置 embedding。
|
||||
# 7. 返回提取出的相对位置 embedding。
|
||||
# 所以,这个 get_rel_pos 函数的主要作用是,当 query 和 key 的大小不同时,根据它们的相对位置关系提取相应的相对位置 embedding。
|
||||
# 它实现了相对位置 embedding 的可变长度和可缩放性。
|
||||
# 这个函数使得相对位置 embedding 可以用于 query 和 key 大小不同的 attention 中,是相对位置表示的一个关键步骤。
|
||||
# 总的来说,这个 get_rel_pos 函数实现了根据 query 和 key 的相对位置关系提取相应相对位置 embedding 的过程。
|
||||
# 它提供了相对位置 embedding 的可变长度和可缩放性,使其可以支持不同的 query 和 key 大小,从而应用到更加灵活的 attention 机制中。
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
# 这个 add_decomposed_rel_pos 函数的作用是根据 query q 和 key k 的空间尺寸, 添加分解的相对位置 embedding 到注意力图 attn 中。它包含:
|
||||
# 1. 输入参数:
|
||||
# - attn: 注意力图,形状为 [B, q_h * q_w, k_h * k_w]
|
||||
# - q: 查询 q,形状为 [B, q_h * q_w, C]
|
||||
# - rel_pos_h: 高度轴的相对位置 embedding, 形状为[Lh, C]
|
||||
# - rel_pos_w: 宽度轴的相对位置 embedding, 形状为[Lw, C]
|
||||
# - q_size: 查询 q的空间尺寸 (q_h, q_w)
|
||||
# - k_size: 键 k的空间尺寸 (k_h, k_w)
|
||||
# 2. 从 q_size 和 k_size 中提取高度 q_h、宽度 q_w 以及高度 k_h、宽度 k_w。
|
||||
# 3. 调用 get_rel_pos 函数获取高度轴 Rh 和宽度轴 Rw 的相对位置 embedding。
|
||||
# 4. 重塑 q 为 [B, q_h, q_w, C]。
|
||||
# 5. 计算高度轴 rel_h 和宽度轴 rel_w 的相对位置图, 形状为 [B, q_h, q_w, k_h] 和 [B, q_h, q_w, k_w]。
|
||||
# 6. 将 attn 的形状变为 [B, q_h, q_w, k_h, k_w], 并加上 rel_h 和 rel_w。
|
||||
# 7. 将 attn 的形状变回 [B, q_h * q_w, k_h * k_w]。
|
||||
# 8. 返回加了相对位置 embedding 的 attn。
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (
|
||||
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
||||
).view(B, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
||||
# 这个 PatchEmbed 类定义了 ViT 的 patch embedding 层。它包含:
|
||||
# 1. __init__: 初始化,设置卷积层的 kernel size、stride、padding以 及输入通道数和 embedding 维度。
|
||||
# 2. proj: 这是一个卷积层,用于将输入图像划分为 patch, 并获得每个 patch 的 embedding。
|
||||
# 3. forward: 前向传播过程。首先通过 proj 卷积层获得 patch embedding ,然后将维度从 [B, C, H, W] 转置成 [B, H, W, C]。
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
|
@ -0,0 +1,197 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = 'scripts/input/images'
|
||||
output_dir = 'scripts/output/mask'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
global input_point, input_label, input_stop
|
||||
if not input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return image
|
||||
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
if crop_mode_:
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
|
||||
current_index = 0
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
|
||||
while True:
|
||||
input_stop = False
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
while (1):
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 | q 移除最后一个 | s 保存'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
|
||||
cv2.LINE_AA)
|
||||
cv2.imshow("image", selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
|
@ -0,0 +1,200 @@
|
|||
import sys
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from PyQt5.QtGui import QPixmap, QImage
|
||||
from PyQt5.QtCore import QTimer
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
|
||||
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
|
||||
crop_mode = True
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1170, 486)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_q = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51))
|
||||
self.pushButton_q.setObjectName("pushButton_q")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_pre = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_pre.setGeometry(QtCore.QRect(660, 20, 471, 401))
|
||||
self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_pre.setObjectName("label_pre")
|
||||
self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_opimg.setObjectName("pushButton_opimg")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
self.pushButton_opimg.clicked.connect(self.open_image)
|
||||
self.pushButton_w.clicked.connect(self.predict_image)
|
||||
|
||||
self.timer = QTimer()
|
||||
self.timer.timeout.connect(self.update_original_image)
|
||||
self.timer.start(100) # Update every 100 milliseconds
|
||||
|
||||
self.image_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
self.current_index = 0
|
||||
self.input_point = []
|
||||
self.input_label = []
|
||||
self.input_stop = False
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", self.mouse_click)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "w"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "a"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "d"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "s"))
|
||||
self.pushButton_q.setText(_translate("MainWindow", "q"))
|
||||
self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_pre.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
|
||||
self.pushButton_opimg.setText(_translate("MainWindow", "打开图像"))
|
||||
|
||||
def open_image(self):
|
||||
filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, "Open Image File", "", "Image files (*.jpg *.png)")
|
||||
if filename:
|
||||
self.current_index = 0
|
||||
pixmap = QPixmap(filename)
|
||||
pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
|
||||
self.label_orign.setPixmap(pixmap)
|
||||
self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.input_point = []
|
||||
self.input_label = []
|
||||
|
||||
def predict_image(self):
|
||||
if self.current_index < len(self.image_files):
|
||||
filename = self.image_files[self.current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_display = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
|
||||
def update_original_image(self):
|
||||
if self.current_index < len(self.image_files):
|
||||
filename = self.image_files[self.current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_display = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
|
||||
height, width, channel = image_display.shape
|
||||
bytesPerLine = 3 * width
|
||||
qImg = QImage(image_display.data, width, height, bytesPerLine, QImage.Format_RGB888)
|
||||
pixmap = QPixmap.fromImage(qImg)
|
||||
pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
|
||||
self.label_orign.setPixmap(pixmap)
|
||||
self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
def mouse_click(self, event, x, y, flags, param):
|
||||
if not self.input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
self.input_point.append([x, y])
|
||||
self.input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
self.input_point.append([x, y])
|
||||
self.input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return image
|
||||
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
if crop_mode_:
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
MainWindow = QtWidgets.QMainWindow()
|
||||
ui = Ui_MainWindow()
|
||||
ui.setupUi(MainWindow)
|
||||
MainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = r'C:\Users\t2581\Desktop\222\images'
|
||||
output_dir = r'C:\Users\t2581\Desktop\222\2'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
# 定义类别
|
||||
categories = {1: "category1", 2: "category2", 3: "category3"}
|
||||
category_colors = {1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 0, 0)}
|
||||
current_label = 1 # 默认类别
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
global input_points, input_labels, input_stop
|
||||
if not input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_points.append([x, y])
|
||||
input_labels.append(current_label)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_points.append([x, y])
|
||||
input_labels.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
masked_image = image.copy()
|
||||
for c in range(3):
|
||||
masked_image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return masked_image
|
||||
|
||||
def draw_external_rectangle(image, mask, pv):
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), 2) # Yellow rectangle
|
||||
cv2.putText(image, pv, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
||||
|
||||
def save_masked_image_and_json(image, masks, output_dir, filename, crop_mode_, pv):
|
||||
masked_image = image.copy()
|
||||
json_shapes = []
|
||||
|
||||
for mask, label, score in masks:
|
||||
color = category_colors[int(label[-1])] # 获取类别对应的颜色
|
||||
masked_image = apply_color_mask(masked_image, mask, color)
|
||||
draw_external_rectangle(masked_image, mask, f"{label}: {score:.2f}")
|
||||
|
||||
# Convert mask to polygons
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
polygons = [contour.reshape(-1, 2).tolist() for contour in contours]
|
||||
|
||||
# Append JSON shapes
|
||||
json_shapes.extend([
|
||||
{
|
||||
"label": label,
|
||||
"points": polygon,
|
||||
"group_id": None,
|
||||
"shape_type": "polygon",
|
||||
"flags": {}
|
||||
} for polygon in polygons
|
||||
])
|
||||
|
||||
masked_filename = filename[:filename.rfind('.')] + '_masked.png'
|
||||
cv2.imwrite(os.path.join(output_dir, masked_filename), masked_image)
|
||||
print(f"Saved image as {masked_filename}")
|
||||
|
||||
# Create JSON data
|
||||
json_data = {
|
||||
"version": "5.1.1",
|
||||
"flags": {},
|
||||
"shapes": json_shapes,
|
||||
"imagePath": filename,
|
||||
"imageData": None,
|
||||
"imageHeight": image.shape[0],
|
||||
"imageWidth": image.shape[1]
|
||||
}
|
||||
|
||||
# Save JSON file
|
||||
json_filename = filename[:filename.rfind('.')] + '_masked.json'
|
||||
with open(os.path.join(output_dir, json_filename), 'w') as json_file:
|
||||
json.dump(json_data, json_file, indent=4)
|
||||
print(f"Saved JSON as {json_filename}")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_points = []
|
||||
input_labels = []
|
||||
input_stop = False
|
||||
masks = []
|
||||
all_masks = [] # 用于保存所有类别的标注
|
||||
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
input_stop = False
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} | Current label: {categories[current_label]}'
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
for point, label in zip(input_points, input_labels):
|
||||
color = (0, 255, 0) if label > 0 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_points = []
|
||||
input_labels = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_points) > 0 and len(input_labels) > 0:
|
||||
try:
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_points)
|
||||
input_label_np = np.array(input_labels)
|
||||
|
||||
masks_pred, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks_pred) # masks的数量
|
||||
while True:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist()) # 随机列表颜色
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks_pred[mask_idx] # 选择msks也就是,a,d切换
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2, cv2.LINE_AA)
|
||||
# todo 显示在当前的图片,
|
||||
cv2.imshow("image", selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_points) > 0:
|
||||
input_points.pop(-1)
|
||||
input_labels.pop(-1)
|
||||
elif key == ord('s'):
|
||||
masks.append((selected_mask, categories[current_label], scores[mask_idx]))
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_points = []
|
||||
input_labels = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
except Exception as e:
|
||||
print(f"Error during prediction: {e}")
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_points = []
|
||||
input_labels = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_points = []
|
||||
input_labels = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_points) > 0:
|
||||
input_points.pop(-1)
|
||||
input_labels.pop(-1)
|
||||
elif key == ord('r'):
|
||||
if masks:
|
||||
all_masks.extend(masks) # 保存当前的masks到all_masks
|
||||
masks = [] # 清空当前的masks
|
||||
input_points = []
|
||||
input_labels = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord('s'):
|
||||
if masks:
|
||||
all_masks.extend(masks) # 保存当前的masks到all_masks
|
||||
if all_masks:
|
||||
save_masked_image_and_json(image_crop, all_masks, output_dir, filename, crop_mode_=crop_mode, pv="")
|
||||
all_masks = [] # 清空所有保存的masks
|
||||
masks = [] # 清空当前的masks
|
||||
|
||||
elif key in [ord(str(i)) for i in categories.keys()]:
|
||||
current_label = int(chr(key))
|
||||
print(f"Switched to label: {categories[current_label]}")
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
|
||||
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [f for f in os.listdir(input_dir) if
|
||||
f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff'))]
|
||||
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
global input_point, input_label, input_stop
|
||||
if not input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
|
||||
return image
|
||||
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
if crop_mode_:
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = image_files[current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
prediction_window_name = "Prediction"
|
||||
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
while True:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个 | s 保存'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
|
||||
|
||||
cv2.imshow(prediction_window_name, selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
input_stop = False # Allow adding points again
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows() # Close all windows before exiting
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1170, 486)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_q = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51))
|
||||
self.pushButton_q.setObjectName("pushButton_q")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 450, 450))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_pre = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_pre.setGeometry(QtCore.QRect(660, 20, 450, 450))
|
||||
self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_pre.setObjectName("label_pre")
|
||||
self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_opimg.setObjectName("pushButton_opimg")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "w"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "a"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "d"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "s"))
|
||||
self.pushButton_q.setText(_translate("MainWindow", "q"))
|
||||
self.label_orign.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">Original Image</p></body></html>"))
|
||||
self.label_pre.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>"))
|
||||
self.pushButton_opimg.setText(_translate("MainWindow", "Open Image"))
|
||||
|
||||
|
||||
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.pushButton_opimg.clicked.connect(self.open_image)
|
||||
self.pushButton_w.clicked.connect(self.predict_and_interact)
|
||||
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
self.input_point = []
|
||||
self.input_label = []
|
||||
self.input_stop = False
|
||||
self.interaction_count = 0 # 记录交互次数
|
||||
self.sam = sam_model_registry["vit_b"](
|
||||
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = self.sam.to(device="cuda")
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# Calculate coordinate scaling factors
|
||||
self.scale_x = 1.0
|
||||
self.scale_y = 1.0
|
||||
self.label_pre_width = self.label_pre.width()
|
||||
self.label_pre_height = self.label_pre.height()
|
||||
|
||||
# Set mouse click event for original image label
|
||||
self.set_mouse_click_event()
|
||||
|
||||
def open_image(self):
|
||||
options = QtWidgets.QFileDialog.Options()
|
||||
filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "",
|
||||
"Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
|
||||
options=options)
|
||||
if filename:
|
||||
image = cv2.imread(filename)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.image_files.append(image)
|
||||
self.display_original_image()
|
||||
|
||||
def display_original_image(self):
|
||||
if self.image_files:
|
||||
image = self.image_files[self.current_index]
|
||||
height, width, channel = image.shape
|
||||
bytesPerLine = 3 * width
|
||||
qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
|
||||
pixmap = QtGui.QPixmap.fromImage(qImg)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
# Add mouse click event
|
||||
self.label_orign.mousePressEvent = self.mouse_click
|
||||
|
||||
# Draw marked points on the original image
|
||||
painter = QtGui.QPainter(self.label_orign.pixmap()) # Use label_orign for drawing points
|
||||
pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points
|
||||
pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points
|
||||
pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
painter.setPen(pen_foreground)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 1: # Foreground point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.setPen(pen_background)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 0: # Background point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.end()
|
||||
|
||||
# Calculate coordinate scaling factors
|
||||
self.scale_x = width / self.label_orign.width()
|
||||
self.scale_y = height / self.label_orign.height()
|
||||
|
||||
def convert_to_label_coords(self, point):
|
||||
x = point[0] / self.scale_x
|
||||
y = point[1] / self.scale_y
|
||||
return x, y
|
||||
|
||||
def mouse_click(self, event):
|
||||
if not self.input_stop:
|
||||
x = int(event.pos().x() * self.scale_x)
|
||||
y = int(event.pos().y() * self.scale_y)
|
||||
if event.button() == QtCore.Qt.LeftButton: # If left-clicked, mark as foreground
|
||||
self.input_label.append(1) # Foreground label is 1
|
||||
elif event.button() == QtCore.Qt.RightButton: # If right-clicked, mark as background
|
||||
self.input_label.append(0) # Background label is 0
|
||||
|
||||
self.input_point.append([x, y])
|
||||
|
||||
# Update the original image with marked points
|
||||
self.display_original_image()
|
||||
|
||||
def predict_and_interact(self):
|
||||
if not self.image_files:
|
||||
return
|
||||
|
||||
image = self.image_files[self.current_index].copy()
|
||||
filename = f"image_{self.current_index}.png"
|
||||
image_crop = image.copy()
|
||||
|
||||
while True: # Outer loop for prediction
|
||||
# Prediction logic
|
||||
if not self.input_stop: # If not in interaction mode
|
||||
if len(self.input_point) > 0 and len(self.input_label) > 0:
|
||||
self.predictor.set_image(image)
|
||||
input_point_np = np.array(self.input_point)
|
||||
input_label_np = np.array(self.input_label)
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
while True: # Inner loop for interaction
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = self.apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2, cv2.LINE_AA)
|
||||
|
||||
# Display the predicted result in label_pre area
|
||||
self.display_prediction_image(selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
|
||||
# Handle key press events
|
||||
if key == ord('q') and len(self.input_point) > 0:
|
||||
self.input_point.pop(-1)
|
||||
self.input_label.pop(-1)
|
||||
self.display_original_image()
|
||||
elif key == ord('s'):
|
||||
self.save_masked_image(image_crop, selected_mask, filename)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord(" "):
|
||||
break
|
||||
|
||||
if cv2.getWindowProperty("Prediction", cv2.WND_PROP_VISIBLE) < 1:
|
||||
break
|
||||
|
||||
# If 'w' is pressed, toggle interaction mode
|
||||
if key == ord('w'):
|
||||
self.input_stop = not self.input_stop # Toggle interaction mode
|
||||
if not self.input_stop: # If entering interaction mode
|
||||
self.interaction_count += 1
|
||||
if self.interaction_count % 2 == 0: # If even number of interactions, call the interaction function
|
||||
self.input_point = [] # Reset input points for the next interaction
|
||||
self.input_label = [] # Reset input labels for the next interaction
|
||||
self.display_original_image() # Display original image
|
||||
self.set_mouse_click_event() # Set mouse click event
|
||||
break # Exit outer loop
|
||||
else:
|
||||
continue # Continue prediction
|
||||
|
||||
# Exit the outer loop if not in interaction mode
|
||||
if not self.input_stop:
|
||||
break
|
||||
|
||||
def set_mouse_click_event(self):
|
||||
self.label_orign.mousePressEvent = self.mouse_click
|
||||
|
||||
def display_prediction_image(self, image):
|
||||
height, width, channel = image.shape
|
||||
bytesPerLine = 3 * width
|
||||
qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
|
||||
pixmap = QtGui.QPixmap.fromImage(qImg)
|
||||
self.label_pre.setPixmap(pixmap.scaled(self.label_pre.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
# Draw marked points on the predicted image
|
||||
painter = QtGui.QPainter(self.label_pre.pixmap())
|
||||
pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points
|
||||
pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points
|
||||
pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
painter.setPen(pen_foreground)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 1: # Foreground point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.setPen(pen_background)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 0: # Background point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.end()
|
||||
|
||||
def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return image
|
||||
|
||||
def save_masked_image(self, image, mask, filename):
|
||||
output_dir = os.path.dirname(filename)
|
||||
filename = os.path.basename(filename)
|
||||
filename = filename[:filename.rfind('.')] + '_masked.png'
|
||||
new_filename = os.path.join(output_dir, filename)
|
||||
|
||||
masked_image = self.apply_color_mask(image, mask)
|
||||
cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR))
|
||||
print(f"Saved as {new_filename}")
|
||||
|
||||
def previous_image(self):
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
self.display_original_image()
|
||||
|
||||
def next_image(self):
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
self.display_original_image()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
window = MainWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
||||
|
||||
|
|
@ -0,0 +1,378 @@
|
|||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1140, 450)
|
||||
MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
|
||||
MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_5.setObjectName("pushButton_5")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_2 = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
|
||||
self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_2.setObjectName("label_2")
|
||||
self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_w_2.setObjectName("pushButton_w_2")
|
||||
self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
|
||||
self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
|
||||
self.lineEdit.setObjectName("lineEdit")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
|
||||
self.horizontalSlider.setSliderPosition(50)
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setTickInterval(0)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "Predict"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "Pre"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "Next"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "Save"))
|
||||
self.pushButton_5.setText(_translate("MainWindow", "背景图"))
|
||||
self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
|
||||
self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
|
||||
self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))
|
||||
|
||||
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.image_path = ""
|
||||
self.image_folder = ""
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
self.input_stop = False # 在这里初始化 input_stop
|
||||
|
||||
self.pushButton_w_2.clicked.connect(self.open_image_folder)
|
||||
self.pushButton_a.clicked.connect(self.load_previous_image)
|
||||
self.pushButton_d.clicked.connect(self.load_next_image)
|
||||
|
||||
|
||||
def open_image_folder(self):
|
||||
folder_dialog = QtWidgets.QFileDialog()
|
||||
folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
|
||||
if folder_path:
|
||||
self.image_folder = folder_path
|
||||
self.image_files = self.get_image_files(self.image_folder)
|
||||
if self.image_files:
|
||||
self.show_image_selection_dialog()
|
||||
|
||||
def load_previous_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
else:
|
||||
self.current_index = len(self.image_files) - 1
|
||||
self.show_image()
|
||||
|
||||
def load_next_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
else:
|
||||
self.current_index = 0
|
||||
self.show_image()
|
||||
|
||||
def get_image_files(self, folder_path):
|
||||
image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))]
|
||||
return image_files
|
||||
|
||||
def show_image_selection_dialog(self):
|
||||
dialog = QtWidgets.QDialog(self)
|
||||
dialog.setWindowTitle("Select Image")
|
||||
layout = QtWidgets.QVBoxLayout()
|
||||
|
||||
self.listWidget = QtWidgets.QListWidget()
|
||||
for image_file in self.image_files:
|
||||
item = QtWidgets.QListWidgetItem(image_file)
|
||||
pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
|
||||
item.setIcon(QtGui.QIcon(pixmap))
|
||||
self.listWidget.addItem(item)
|
||||
self.listWidget.itemDoubleClicked.connect(self.image_selected)
|
||||
layout.addWidget(self.listWidget)
|
||||
|
||||
buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
|
||||
buttonBox.accepted.connect(self.image_selected)
|
||||
buttonBox.rejected.connect(dialog.reject)
|
||||
layout.addWidget(buttonBox)
|
||||
|
||||
dialog.setLayout(layout)
|
||||
|
||||
dialog.exec_()
|
||||
|
||||
def image_selected(self):
|
||||
selected_item = self.listWidget.currentItem()
|
||||
if selected_item:
|
||||
selected_index = self.listWidget.currentRow()
|
||||
if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内
|
||||
self.current_index = selected_index
|
||||
self.show_image() # 显示所选图像
|
||||
# 调用OpenCV窗口显示
|
||||
self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index]))
|
||||
|
||||
def show_image(self):
|
||||
if self.image_files and self.current_index < len(self.image_files):
|
||||
file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
|
||||
pixmap = QtGui.QPixmap(file_path)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def call_opencv_interaction(self, image_path):
|
||||
input_dir = os.path.dirname(image_path)
|
||||
image_orign = cv2.imread(image_path)
|
||||
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [self.image_files[self.current_index]]
|
||||
|
||||
sam = sam_model_registry["vit_b"](
|
||||
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return image
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
# 保存图像到指定路径
|
||||
if crop_mode_:
|
||||
# 如果采用了裁剪模式,则裁剪图像
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image,
|
||||
[cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
|
||||
# 读取保存的图像文件
|
||||
saved_image_path = os.path.join(output_dir, new_filename)
|
||||
saved_image_pixmap = QtGui.QPixmap(saved_image_path)
|
||||
|
||||
# 将保存的图像显示在预测图像区域
|
||||
mainWindow.label_2.setPixmap(
|
||||
saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio))
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
if not self.input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = self.image_files[self.current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
prediction_window_name = "Prediction"
|
||||
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2,
|
||||
(1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
while True:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
|
||||
(0, 255, 255), 2, cv2.LINE_AA)
|
||||
|
||||
cv2.imshow(prediction_window_name, selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename,
|
||||
crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
input_stop = False # Allow adding points again
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows() # Close all windows before exiting
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
mainWindow = MyMainWindow()
|
||||
mainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
||||
|
|
@ -0,0 +1,453 @@
|
|||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1140, 450)
|
||||
MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
|
||||
MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_5.setObjectName("pushButton_5")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_2 = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
|
||||
self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_2.setObjectName("label_2")
|
||||
self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_w_2.setObjectName("pushButton_w_2")
|
||||
self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
|
||||
self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
|
||||
self.lineEdit.setObjectName("lineEdit")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
|
||||
self.horizontalSlider.setRange(0, 10) # 将范围设置为从0到最大值
|
||||
self.horizontalSlider.setSingleStep(1)
|
||||
self.horizontalSlider.setValue(0) # 初始值设为0
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setTickInterval(0)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "获取json文件"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "Pre"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "Next"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "Save"))
|
||||
self.pushButton_5.setText(_translate("MainWindow", "背景图"))
|
||||
self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
|
||||
self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
|
||||
self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))
|
||||
|
||||
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.k = 0
|
||||
self.last_value = 0 # 保存上一次滑块值
|
||||
|
||||
self.image_path = ""
|
||||
self.image_folder = ""
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
self.input_stop = False # 在这里初始化 input_stop
|
||||
|
||||
self.pushButton_w_2.clicked.connect(self.open_image_folder)
|
||||
self.pushButton_a.clicked.connect(self.load_previous_image)
|
||||
self.pushButton_d.clicked.connect(self.load_next_image)
|
||||
self.pushButton_s.clicked.connect(self.save_prediction)
|
||||
self.pushButton_5.clicked.connect(self.select_background_image)
|
||||
self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size) # 连接水平滑块的值改变信号
|
||||
|
||||
def adjust_pixmap_size(self, pixmap, scale_factor):
|
||||
scaled_size = QtCore.QSize(pixmap.size().width() * scale_factor / 100,
|
||||
pixmap.size().height() * scale_factor / 100)
|
||||
return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
|
||||
|
||||
|
||||
def open_image_folder(self):
|
||||
folder_dialog = QtWidgets.QFileDialog()
|
||||
folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
|
||||
if folder_path:
|
||||
self.image_folder = folder_path
|
||||
self.image_files = self.get_image_files(self.image_folder)
|
||||
if self.image_files:
|
||||
self.show_image_selection_dialog()
|
||||
|
||||
def load_previous_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
else:
|
||||
self.current_index = len(self.image_files) - 1
|
||||
self.show_image()
|
||||
|
||||
def load_next_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
else:
|
||||
self.current_index = 0
|
||||
self.show_image()
|
||||
|
||||
def get_image_files(self, folder_path):
|
||||
image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))]
|
||||
return image_files
|
||||
|
||||
def show_image_selection_dialog(self):
|
||||
dialog = QtWidgets.QDialog(self)
|
||||
dialog.setWindowTitle("Select Image")
|
||||
layout = QtWidgets.QVBoxLayout()
|
||||
|
||||
self.listWidget = QtWidgets.QListWidget()
|
||||
for image_file in self.image_files:
|
||||
item = QtWidgets.QListWidgetItem(image_file)
|
||||
pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
|
||||
item.setIcon(QtGui.QIcon(pixmap))
|
||||
self.listWidget.addItem(item)
|
||||
self.listWidget.itemDoubleClicked.connect(self.image_selected)
|
||||
layout.addWidget(self.listWidget)
|
||||
|
||||
buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
|
||||
buttonBox.accepted.connect(self.image_selected)
|
||||
buttonBox.rejected.connect(dialog.reject)
|
||||
layout.addWidget(buttonBox)
|
||||
|
||||
dialog.setLayout(layout)
|
||||
|
||||
dialog.exec_()
|
||||
|
||||
def image_selected(self):
|
||||
selected_item = self.listWidget.currentItem()
|
||||
if selected_item:
|
||||
selected_index = self.listWidget.currentRow()
|
||||
if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内
|
||||
self.current_index = selected_index
|
||||
self.show_image() # 显示所选图像
|
||||
# 调用OpenCV窗口显示
|
||||
self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index]))
|
||||
|
||||
def select_background_image(self):
|
||||
file_dialog = QtWidgets.QFileDialog()
|
||||
image_path, _ = file_dialog.getOpenFileName(self, 'Select Background Image', '',
|
||||
'Image Files (*.png *.jpg *.jpeg *.bmp)')
|
||||
if image_path:
|
||||
self.show_background_image(image_path)
|
||||
|
||||
def show_background_image(self, image_path):
|
||||
pixmap = QtGui.QPixmap(image_path)
|
||||
current_pixmap = self.label_2.pixmap()
|
||||
if current_pixmap:
|
||||
current_pixmap = QtGui.QPixmap(current_pixmap)
|
||||
scene = QtWidgets.QGraphicsScene()
|
||||
scene.addPixmap(pixmap)
|
||||
scene.addPixmap(current_pixmap)
|
||||
merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize())
|
||||
merged_pixmap.fill(QtCore.Qt.transparent)
|
||||
painter = QtGui.QPainter(merged_pixmap)
|
||||
scene.render(painter)
|
||||
painter.end()
|
||||
self.label_2.setPixmap(merged_pixmap)
|
||||
else:
|
||||
self.label_2.setPixmap(pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def show_image(self):
|
||||
if self.image_files and self.current_index < len(self.image_files):
|
||||
file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
|
||||
pixmap = QtGui.QPixmap(file_path)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def call_opencv_interaction(self, image_path):
|
||||
input_dir = os.path.dirname(image_path)
|
||||
image_orign = cv2.imread(image_path)
|
||||
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [self.image_files[self.current_index]]
|
||||
|
||||
sam = sam_model_registry["vit_b"](
|
||||
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return image
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
# 保存图像到指定路径
|
||||
if crop_mode_:
|
||||
# 如果采用了裁剪模式,则裁剪图像
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image,
|
||||
[cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
|
||||
# 读取保存的图像文件
|
||||
saved_image_path = os.path.join(output_dir, new_filename)
|
||||
saved_image_pixmap = QtGui.QPixmap(saved_image_path)
|
||||
|
||||
# 将保存的图像显示在预测图像区域
|
||||
mainWindow.label_2.setPixmap(
|
||||
saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio))
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
if not self.input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = self.image_files[self.current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
prediction_window_name = "Prediction"
|
||||
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2,
|
||||
(1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
while True:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
|
||||
(0, 255, 255), 2, cv2.LINE_AA)
|
||||
|
||||
cv2.imshow(prediction_window_name, selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename,
|
||||
crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
input_stop = False # Allow adding points again
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows() # Close all windows before exiting
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
def save_prediction(self):
|
||||
if self.label_2.pixmap(): # 检查预测图像区域是否有图像
|
||||
# 保存预测结果的部分,这里假设你已经有了保存预测结果的代码,我用 placeholer 代替
|
||||
# placeholder: 这里假设 save_prediction_result 是一个保存预测结果的函数,它接受预测结果的图像数据以及保存路径作为参数
|
||||
# 这里假设预测结果图像数据为 prediction_image,保存路径为 save_path
|
||||
prediction_image = self.label_2.pixmap().toImage()
|
||||
save_path = "prediction_result.png"
|
||||
prediction_image.save(save_path)
|
||||
|
||||
# 调用 adjust_prediction_size 方法来根据 horizontalSlider 的值调整预测区域的大小
|
||||
self.adjust_prediction_size(self.horizontalSlider.value())
|
||||
|
||||
def adjust_prediction_size(self, value):
|
||||
if self.image_files and self.current_index < len(self.image_files):
|
||||
# 获取预测图像区域的原始大小
|
||||
pixmap = self.label_2.pixmap()
|
||||
if pixmap.isNull():
|
||||
return
|
||||
|
||||
original_size = pixmap.size()
|
||||
|
||||
# 判断是缩小还是还原图像
|
||||
if value < self.last_value:
|
||||
# 缩小掩码
|
||||
scale_factor = 1.0 + (self.last_value - value) * 0.1
|
||||
else:
|
||||
# 放大掩码
|
||||
scale_factor = 1.0 - (value - self.last_value) * 0.1
|
||||
|
||||
self.last_value = value # 更新上一次的滑块值
|
||||
|
||||
# 根据缩放比例调整预测图像区域的大小,并保持纵横比例
|
||||
scaled_size = QtCore.QSize(original_size.width() * scale_factor, original_size.height() * scale_factor)
|
||||
scaled_pixmap = pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
|
||||
|
||||
# 更新预测图像区域的大小并显示
|
||||
self.label_2.setPixmap(scaled_pixmap)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
mainWindow = MyMainWindow()
|
||||
mainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
class ImageLabel(QtWidgets.QLabel):
|
||||
clicked = QtCore.pyqtSignal(QtCore.QPoint)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ImageLabel, self).__init__(*args, **kwargs)
|
||||
self.foreground_points = []
|
||||
self.background_points = []
|
||||
|
||||
def mousePressEvent(self, event):
|
||||
self.clicked.emit(event.pos())
|
||||
|
||||
if event.button() == QtCore.Qt.LeftButton:
|
||||
self.foreground_points.append(event.pos())
|
||||
elif event.button() == QtCore.Qt.RightButton:
|
||||
self.background_points.append(event.pos())
|
||||
|
||||
self.update()
|
||||
|
||||
def paintEvent(self, event):
|
||||
super().paintEvent(event)
|
||||
|
||||
painter = QtGui.QPainter(self)
|
||||
painter.setPen(QtGui.QPen(QtGui.QColor("green"), 5))
|
||||
|
||||
for point in self.foreground_points:
|
||||
painter.drawPoint(point)
|
||||
|
||||
painter.setPen(QtGui.QPen(QtGui.QColor("red"), 5))
|
||||
|
||||
for point in self.background_points:
|
||||
painter.drawPoint(point)
|
||||
|
||||
|
||||
class MainWindow(QtWidgets.QMainWindow):
|
||||
def __init__(self, predictor):
|
||||
super().__init__()
|
||||
self.predictor = predictor
|
||||
self.current_points = []
|
||||
self.current_image = None
|
||||
self.setupUi()
|
||||
|
||||
def setupUi(self):
|
||||
self.setObjectName("MainWindow")
|
||||
self.resize(1333, 657)
|
||||
self.centralwidget = QtWidgets.QWidget(self)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
|
||||
self.pushButton_init.setObjectName("pushButton_init")
|
||||
self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
|
||||
self.pushButton_openimg.setObjectName("pushButton_openimg")
|
||||
self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41))
|
||||
self.pushButton_save_mask.setObjectName("pushButton_save_mask")
|
||||
self.label_Originalimg = ImageLabel(self.centralwidget)
|
||||
self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
|
||||
self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Originalimg.setObjectName("label_Originalimg")
|
||||
self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
|
||||
self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Maskimg.setObjectName("label_Maskimg")
|
||||
self.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(self)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
self.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(self)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
self.setStatusBar(self.statusbar)
|
||||
self.retranslateUi()
|
||||
QtCore.QMetaObject.connectSlotsByName(self)
|
||||
|
||||
def retranslateUi(self):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
self.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
|
||||
self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
|
||||
self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码"))
|
||||
|
||||
self.pushButton_openimg.clicked.connect(self.button_image_open)
|
||||
self.pushButton_save_mask.clicked.connect(self.button_save_mask)
|
||||
self.label_Originalimg.clicked.connect(self.mouse_click)
|
||||
|
||||
def button_image_open(self):
|
||||
choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
|
||||
QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
|
||||
if choice == QtWidgets.QMessageBox.Open:
|
||||
folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
|
||||
if folder_path:
|
||||
image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
|
||||
if image_files:
|
||||
self.image_files = image_files
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
elif choice == QtWidgets.QMessageBox.Cancel:
|
||||
selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
|
||||
"Image files (*.png *.jpg *.jpeg *.bmp)")
|
||||
if selected_image:
|
||||
self.image_files = [selected_image]
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
|
||||
def display_image(self):
|
||||
if hasattr(self, 'image_files') and self.image_files:
|
||||
pixmap = QtGui.QPixmap(self.image_files[self.current_index])
|
||||
self.label_Originalimg.setPixmap(pixmap)
|
||||
self.label_Originalimg.setScaledContents(True)
|
||||
|
||||
def mouse_click(self, pos):
|
||||
x, y = pos.x(), pos.y()
|
||||
print("Mouse clicked at position:", x, y)
|
||||
self.current_points.append([x, y])
|
||||
|
||||
def button_save_mask(self):
|
||||
if self.current_image is not None and len(self.current_points) > 0:
|
||||
masks, _, _ = self.predictor.predict(point_coords=np.array(self.current_points),
|
||||
point_labels=np.ones(len(self.current_points), dtype=np.uint8),
|
||||
multimask_output=True)
|
||||
if masks:
|
||||
mask_image = masks[0]
|
||||
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
||||
q_image = QtGui.QImage(mask_image.data, mask_image.shape[1], mask_image.shape[0], mask_image.strides[0],
|
||||
QtGui.QImage.Format_RGB888)
|
||||
pixmap = QtGui.QPixmap.fromImage(q_image)
|
||||
self.label_Maskimg.setPixmap(pixmap)
|
||||
self.label_Maskimg.setScaledContents(True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
sam = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
window = MainWindow(predictor)
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
|
@ -0,0 +1,144 @@
|
|||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1333, 657)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
|
||||
self.pushButton_init.setObjectName("pushButton_init")
|
||||
self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
|
||||
self.pushButton_openimg.setObjectName("pushButton_openimg")
|
||||
self.pushButton_Fusionimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_Fusionimg.setGeometry(QtCore.QRect(10, 270, 141, 41))
|
||||
self.pushButton_Fusionimg.setObjectName("pushButton_Fusionimg")
|
||||
self.pushButton_exit = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_exit.setGeometry(QtCore.QRect(10, 570, 141, 41))
|
||||
self.pushButton_exit.setObjectName("pushButton_exit")
|
||||
self.pushButton_Transparency = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_Transparency.setGeometry(QtCore.QRect(10, 380, 141, 41))
|
||||
self.pushButton_Transparency.setObjectName("pushButton_Transparency")
|
||||
self.pushButton_copymask = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_copymask.setGeometry(QtCore.QRect(10, 450, 141, 41))
|
||||
self.pushButton_copymask.setObjectName("pushButton_copymask")
|
||||
self.pushButton_saveimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_saveimg.setGeometry(QtCore.QRect(10, 510, 141, 41))
|
||||
self.pushButton_saveimg.setObjectName("pushButton_saveimg")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 330, 141, 22))
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
self.horizontalSlider.setValue(50)
|
||||
self.label_Originalimg = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
|
||||
self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Originalimg.setObjectName("label_Originalimg")
|
||||
self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
|
||||
self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_Maskimg.setObjectName("label_Maskimg")
|
||||
self.pushButton_shang = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_shang.setGeometry(QtCore.QRect(10, 150, 141, 41))
|
||||
self.pushButton_shang.setObjectName("pushButton_shang")
|
||||
self.pushButton_xia = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_xia.setGeometry(QtCore.QRect(10, 210, 141, 41))
|
||||
self.pushButton_xia.setObjectName("pushButton_openimg_3")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
|
||||
self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
|
||||
self.pushButton_shang.setText(_translate("MainWindow", "上一张"))
|
||||
self.pushButton_xia.setText(_translate("MainWindow", "下一张"))
|
||||
self.pushButton_Fusionimg.setText(_translate("MainWindow", "融合背景图片"))
|
||||
self.pushButton_exit.setText(_translate("MainWindow", "退出"))
|
||||
self.pushButton_Transparency.setText(_translate("MainWindow", "调整透明度"))
|
||||
self.pushButton_copymask.setText(_translate("MainWindow", "复制掩码"))
|
||||
self.pushButton_saveimg.setText(_translate("MainWindow", "保存图片"))
|
||||
self.label_Originalimg.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_Maskimg.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">掩码图像</p></body></html>"))
|
||||
|
||||
|
||||
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super(MyMainWindow, self).__init__()
|
||||
self.setupUi(self)
|
||||
self.init_slots()
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
|
||||
def init_slots(self):
|
||||
self.pushButton_openimg.clicked.connect(self.button_image_open)
|
||||
self.pushButton_shang.clicked.connect(self.button_image_shang)
|
||||
self.pushButton_xia.clicked.connect(self.button_image_xia)
|
||||
self.pushButton_exit.clicked.connect(self.button_image_exit)
|
||||
|
||||
def button_image_open(self):
|
||||
choice = QtWidgets.QMessageBox.question(None, "选择", "您想要打开文件夹还是选择一个图片文件?",
|
||||
QtWidgets.QMessageBox.Open | QtWidgets.QMessageBox.Cancel)
|
||||
if choice == QtWidgets.QMessageBox.Open:
|
||||
folder_path = QtWidgets.QFileDialog.getExistingDirectory(None, "选择文件夹", "")
|
||||
if folder_path:
|
||||
self.image_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
|
||||
if self.image_files:
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
elif choice == QtWidgets.QMessageBox.Cancel:
|
||||
selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(None, "选择图片", "",
|
||||
"Image files (*.png *.jpg *.jpeg *.bmp)")
|
||||
if selected_image:
|
||||
self.image_files = [selected_image]
|
||||
self.current_index = 0
|
||||
self.display_image()
|
||||
|
||||
def display_image(self):
|
||||
if self.image_files:
|
||||
pixmap = QtGui.QPixmap(self.image_files[self.current_index])
|
||||
self.label_Originalimg.setPixmap(pixmap)
|
||||
self.label_Originalimg.setScaledContents(True)
|
||||
|
||||
def button_image_shang(self):
|
||||
if self.image_files:
|
||||
self.current_index = (self.current_index - 1) % len(self.image_files)
|
||||
self.display_image()
|
||||
|
||||
def button_image_xia(self):
|
||||
if self.image_files:
|
||||
self.current_index = (self.current_index + 1) % len(self.image_files)
|
||||
self.display_image()
|
||||
|
||||
def button_image_exit(self):
|
||||
sys.exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
my_main_window = MyMainWindow()
|
||||
my_main_window.show()
|
||||
sys.exit(app.exec_())
|
||||
|
After Width: | Height: | Size: 650 KiB |
|
@ -0,0 +1,299 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1170, 486)
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_q = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51))
|
||||
self.pushButton_q.setObjectName("pushButton_q")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 450, 450))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_pre = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_pre.setGeometry(QtCore.QRect(660, 20, 450, 450))
|
||||
self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_pre.setObjectName("label_pre")
|
||||
self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_opimg.setObjectName("pushButton_opimg")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "w"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "a"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "d"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "s"))
|
||||
self.pushButton_q.setText(_translate("MainWindow", "q"))
|
||||
self.label_orign.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">Original Image</p></body></html>"))
|
||||
self.label_pre.setText(
|
||||
_translate("MainWindow", "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>"))
|
||||
self.pushButton_opimg.setText(_translate("MainWindow", "Open Image"))
|
||||
|
||||
|
||||
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.pushButton_opimg.clicked.connect(self.open_image)
|
||||
self.pushButton_w.clicked.connect(self.predict_and_interact)
|
||||
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
self.input_point = []
|
||||
self.input_label = []
|
||||
self.input_stop = False
|
||||
self.interaction_count = 0 # 记录交互次数
|
||||
self.sam = sam_model_registry["vit_b"](
|
||||
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = self.sam.to(device="cuda")
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# Calculate coordinate scaling factors
|
||||
self.scale_x = 1.0
|
||||
self.scale_y = 1.0
|
||||
self.label_pre_width = self.label_pre.width()
|
||||
self.label_pre_height = self.label_pre.height()
|
||||
|
||||
# Set mouse click event for original image label
|
||||
self.set_mouse_click_event()
|
||||
|
||||
def open_image(self):
|
||||
options = QtWidgets.QFileDialog.Options()
|
||||
filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Image File", "",
|
||||
"Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
|
||||
options=options)
|
||||
if filename:
|
||||
image = cv2.imread(filename)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.image_files.append(image)
|
||||
self.display_original_image()
|
||||
|
||||
def display_original_image(self):
|
||||
if self.image_files:
|
||||
image = self.image_files[self.current_index]
|
||||
height, width, channel = image.shape
|
||||
bytesPerLine = 3 * width
|
||||
qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
|
||||
pixmap = QtGui.QPixmap.fromImage(qImg)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
# Add mouse click event
|
||||
self.label_orign.mousePressEvent = self.mouse_click
|
||||
|
||||
# Draw marked points on the original image
|
||||
painter = QtGui.QPainter(self.label_orign.pixmap()) # Use label_orign for drawing points
|
||||
pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points
|
||||
pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points
|
||||
pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
painter.setPen(pen_foreground)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 1: # Foreground point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.setPen(pen_background)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 0: # Background point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.end()
|
||||
|
||||
# Calculate coordinate scaling factors
|
||||
self.scale_x = width / self.label_orign.width()
|
||||
self.scale_y = height / self .label_orign.height()
|
||||
|
||||
def convert_to_label_coords(self, point):
|
||||
x = point[0] / self.scale_x
|
||||
y = point[1] / self.scale_y
|
||||
return x, y
|
||||
|
||||
def mouse_click(self, event):
|
||||
if not self.input_stop:
|
||||
x = int(event.pos().x() * self.scale_x)
|
||||
y = int(event.pos().y() * self.scale_y)
|
||||
if event.button() == QtCore.Qt.LeftButton: # If left-clicked, mark as foreground
|
||||
self.input_label.append(1) # Foreground label is 1
|
||||
elif event.button() == QtCore.Qt.RightButton: # If right-clicked, mark as background
|
||||
self.input_label.append(0) # Background label is 0
|
||||
|
||||
self.input_point.append([x, y])
|
||||
|
||||
# Update the original image with marked points
|
||||
self.display_original_image()
|
||||
|
||||
def predict_and_interact(self):
|
||||
if not self.image_files:
|
||||
return
|
||||
|
||||
image = self.image_files[self.current_index].copy()
|
||||
filename = f"image_{self.current_index}.png"
|
||||
image_crop = image.copy()
|
||||
|
||||
while True: # Outer loop for prediction
|
||||
# Prediction logic
|
||||
if not self.input_stop: # If not in interaction mode
|
||||
if len(self.input_point) > 0 and len(self.input_label) > 0:
|
||||
self.predictor.set_image(image)
|
||||
input_point_np = np.array(self.input_point)
|
||||
input_label_np = np.array(self.input_label)
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
while True: # Inner loop for interaction
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = self.apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2, cv2.LINE_AA)
|
||||
|
||||
# Display the predicted result in label_pre area
|
||||
self.display_prediction_image(selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
|
||||
# Handle key press events
|
||||
if key == ord('q') and len(self.input_point) > 0:
|
||||
self.input_point.pop(-1)
|
||||
self.input_label.pop(-1)
|
||||
self.display_original_image()
|
||||
elif key == ord('s'):
|
||||
self.save_masked_image(image_crop, selected_mask, filename)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord(" "):
|
||||
break
|
||||
|
||||
if cv2.getWindowProperty("Prediction", cv2.WND_PROP_VISIBLE) < 1:
|
||||
break
|
||||
|
||||
# If 'w' is pressed, toggle interaction mode
|
||||
if key == ord('w'):
|
||||
self.input_stop = not self.input_stop # Toggle interaction mode
|
||||
if not self.input_stop: # If entering interaction mode
|
||||
self.interaction_count += 1
|
||||
if self.interaction_count % 2 == 0: # If even number of interactions, call the interaction function
|
||||
self.input_point = [] # Reset input points for the next interaction
|
||||
self.input_label = [] # Reset input labels for the next interaction
|
||||
self.display_original_image() # Display original image
|
||||
self.set_mouse_click_event() # Set mouse click event
|
||||
break # Exit outer loop
|
||||
else:
|
||||
continue # Continue prediction
|
||||
|
||||
# Exit the outer loop if not in interaction mode
|
||||
if not self.input_stop:
|
||||
break
|
||||
|
||||
def set_mouse_click_event(self):
|
||||
self.label_orign.mousePressEvent = self.mouse_click
|
||||
|
||||
def display_prediction_image(self, image):
|
||||
height, width, channel = image.shape
|
||||
bytesPerLine = 3 * width
|
||||
qImg = QtGui.QImage(image.data, width, height, bytesPerLine, QtGui.QImage.Format_RGB888)
|
||||
pixmap = QtGui.QPixmap.fromImage(qImg)
|
||||
self.label_pre.setPixmap(pixmap.scaled(self.label_pre.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
# Draw marked points on the predicted image
|
||||
painter = QtGui.QPainter(self.label_pre.pixmap())
|
||||
pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0)) # Red color for foreground points
|
||||
pen_foreground.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0)) # Green color for background points
|
||||
pen_background.setWidth(5) # Set the width of the pen to 5 (adjust as needed)
|
||||
painter.setPen(pen_foreground)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 1: # Foreground point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.setPen(pen_background)
|
||||
for point, label in zip(self.input_point, self.input_label):
|
||||
x, y = self.convert_to_label_coords(point)
|
||||
if label == 0: # Background point
|
||||
painter.drawPoint(QtCore.QPoint(x, y))
|
||||
painter.end()
|
||||
|
||||
def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return image
|
||||
|
||||
def save_masked_image(self, image, mask, filename):
|
||||
output_dir = os.path.dirname(filename)
|
||||
filename = os.path.basename(filename)
|
||||
filename = filename[:filename.rfind('.')] + '_masked.png'
|
||||
new_filename = os.path.join(output_dir, filename)
|
||||
|
||||
masked_image = self.apply_color_mask(image, mask)
|
||||
cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR))
|
||||
print(f"Saved as { new_filename}")
|
||||
|
||||
def previous_image(self):
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
self.display_original_image()
|
||||
|
||||
def next_image(self):
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
self.display_original_image()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
window = MainWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
||||
|
|
@ -0,0 +1,453 @@
|
|||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1140, 450)
|
||||
MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
|
||||
MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
|
||||
self.pushButton_w.setObjectName("pushButton_w")
|
||||
self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
|
||||
self.pushButton_a.setObjectName("pushButton_a")
|
||||
self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
|
||||
self.pushButton_d.setObjectName("pushButton_d")
|
||||
self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
|
||||
self.pushButton_s.setObjectName("pushButton_s")
|
||||
self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
|
||||
self.pushButton_5.setObjectName("pushButton_5")
|
||||
self.label_orign = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
|
||||
self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_orign.setObjectName("label_orign")
|
||||
self.label_2 = QtWidgets.QLabel(self.centralwidget)
|
||||
self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
|
||||
self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
|
||||
self.label_2.setObjectName("label_2")
|
||||
self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
|
||||
self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
|
||||
self.pushButton_w_2.setObjectName("pushButton_w_2")
|
||||
self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
|
||||
self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
|
||||
self.lineEdit.setObjectName("lineEdit")
|
||||
self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
|
||||
self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
|
||||
self.horizontalSlider.setRange(0, 10) # 将范围设置为从0到最大值
|
||||
self.horizontalSlider.setSingleStep(1)
|
||||
self.horizontalSlider.setValue(0) # 初始值设为0
|
||||
self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.horizontalSlider.setTickInterval(0)
|
||||
self.horizontalSlider.setObjectName("horizontalSlider")
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
self.menubar = QtWidgets.QMenuBar(MainWindow)
|
||||
self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
|
||||
self.menubar.setObjectName("menubar")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setObjectName("statusbar")
|
||||
MainWindow.setStatusBar(self.statusbar)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.pushButton_w.setText(_translate("MainWindow", "Predict"))
|
||||
self.pushButton_a.setText(_translate("MainWindow", "Pre"))
|
||||
self.pushButton_d.setText(_translate("MainWindow", "Next"))
|
||||
self.pushButton_s.setText(_translate("MainWindow", "Save"))
|
||||
self.pushButton_5.setText(_translate("MainWindow", "背景图"))
|
||||
self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
|
||||
self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
|
||||
self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
|
||||
self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))
|
||||
|
||||
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setupUi(self)
|
||||
|
||||
self.k = 0
|
||||
self.last_value = 0 # 保存上一次滑块值
|
||||
|
||||
self.image_path = ""
|
||||
self.image_folder = ""
|
||||
self.image_files = []
|
||||
self.current_index = 0
|
||||
self.input_stop = False # 在这里初始化 input_stop
|
||||
|
||||
self.pushButton_w_2.clicked.connect(self.open_image_folder)
|
||||
self.pushButton_a.clicked.connect(self.load_previous_image)
|
||||
self.pushButton_d.clicked.connect(self.load_next_image)
|
||||
self.pushButton_s.clicked.connect(self.save_prediction)
|
||||
self.pushButton_5.clicked.connect(self.select_background_image)
|
||||
self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size) # 连接水平滑块的值改变信号
|
||||
|
||||
def adjust_pixmap_size(self, pixmap, scale_factor):
|
||||
scaled_size = QtCore.QSize(pixmap.size().width() * scale_factor / 100,
|
||||
pixmap.size().height() * scale_factor / 100)
|
||||
return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
|
||||
|
||||
|
||||
def open_image_folder(self):
|
||||
folder_dialog = QtWidgets.QFileDialog()
|
||||
folder_path = folder_dialog.getExistingDirectory(self, 'Open Image Folder', '')
|
||||
if folder_path:
|
||||
self.image_folder = folder_path
|
||||
self.image_files = self.get_image_files(self.image_folder)
|
||||
if self.image_files:
|
||||
self.show_image_selection_dialog()
|
||||
|
||||
def load_previous_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
else:
|
||||
self.current_index = len(self.image_files) - 1
|
||||
self.show_image()
|
||||
|
||||
def load_next_image(self):
|
||||
if self.image_files:
|
||||
if self.current_index < len(self.image_files) - 1:
|
||||
self.current_index += 1
|
||||
else:
|
||||
self.current_index = 0
|
||||
self.show_image()
|
||||
|
||||
def get_image_files(self, folder_path):
|
||||
image_files = [file for file in os.listdir(folder_path) if file.endswith(('png', 'jpg', 'jpeg', 'bmp'))]
|
||||
return image_files
|
||||
|
||||
def show_image_selection_dialog(self):
|
||||
dialog = QtWidgets.QDialog(self)
|
||||
dialog.setWindowTitle("Select Image")
|
||||
layout = QtWidgets.QVBoxLayout()
|
||||
|
||||
self.listWidget = QtWidgets.QListWidget()
|
||||
for image_file in self.image_files:
|
||||
item = QtWidgets.QListWidgetItem(image_file)
|
||||
pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file)).scaledToWidth(100)
|
||||
item.setIcon(QtGui.QIcon(pixmap))
|
||||
self.listWidget.addItem(item)
|
||||
self.listWidget.itemDoubleClicked.connect(self.image_selected)
|
||||
layout.addWidget(self.listWidget)
|
||||
|
||||
buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
|
||||
buttonBox.accepted.connect(self.image_selected)
|
||||
buttonBox.rejected.connect(dialog.reject)
|
||||
layout.addWidget(buttonBox)
|
||||
|
||||
dialog.setLayout(layout)
|
||||
|
||||
dialog.exec_()
|
||||
|
||||
def image_selected(self):
|
||||
selected_item = self.listWidget.currentItem()
|
||||
if selected_item:
|
||||
selected_index = self.listWidget.currentRow()
|
||||
if selected_index >= 0 and selected_index < len(self.image_files): # 检查索引是否在有效范围内
|
||||
self.current_index = selected_index
|
||||
self.show_image() # 显示所选图像
|
||||
# 调用OpenCV窗口显示
|
||||
self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index]))
|
||||
|
||||
def select_background_image(self):
|
||||
file_dialog = QtWidgets.QFileDialog()
|
||||
image_path, _ = file_dialog.getOpenFileName(self, 'Select Background Image', '',
|
||||
'Image Files (*.png *.jpg *.jpeg *.bmp)')
|
||||
if image_path:
|
||||
self.show_background_image(image_path)
|
||||
|
||||
def show_background_image(self, image_path):
|
||||
pixmap = QtGui.QPixmap(image_path)
|
||||
current_pixmap = self.label_2.pixmap()
|
||||
if current_pixmap:
|
||||
current_pixmap = QtGui.QPixmap(current_pixmap)
|
||||
scene = QtWidgets.QGraphicsScene()
|
||||
scene.addPixmap(pixmap)
|
||||
scene.addPixmap(current_pixmap)
|
||||
merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize())
|
||||
merged_pixmap.fill(QtCore.Qt.transparent)
|
||||
painter = QtGui.QPainter(merged_pixmap)
|
||||
scene.render(painter)
|
||||
painter.end()
|
||||
self.label_2.setPixmap(merged_pixmap)
|
||||
else:
|
||||
self.label_2.setPixmap(pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def show_image(self):
|
||||
if self.image_files and self.current_index < len(self.image_files):
|
||||
file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
|
||||
pixmap = QtGui.QPixmap(file_path)
|
||||
self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))
|
||||
|
||||
def call_opencv_interaction(self, image_path):
|
||||
input_dir = os.path.dirname(image_path)
|
||||
image_orign = cv2.imread(image_path)
|
||||
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
|
||||
crop_mode = True
|
||||
|
||||
print('最好是每加一个点就按w键predict一次')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_files = [self.image_files[self.current_index]]
|
||||
|
||||
sam = sam_model_registry["vit_b"](
|
||||
checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
|
||||
_ = sam.to(device="cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
WINDOW_WIDTH = 1280
|
||||
WINDOW_HEIGHT = 720
|
||||
|
||||
def apply_mask(image, mask, alpha_channel=True):
|
||||
if alpha_channel:
|
||||
alpha = np.zeros_like(image[..., 0])
|
||||
alpha[mask == 1] = 255
|
||||
image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
|
||||
else:
|
||||
image = np.where(mask[..., None] == 1, image, 0)
|
||||
return image
|
||||
|
||||
def apply_color_mask(image, mask, color, color_dark=0.5):
|
||||
for c in range(3):
|
||||
image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c],
|
||||
image[:, :, c])
|
||||
return image
|
||||
|
||||
def get_next_filename(base_path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
for i in range(1, 101):
|
||||
new_name = f"{name}_{i}{ext}"
|
||||
if not os.path.exists(os.path.join(base_path, new_name)):
|
||||
return new_name
|
||||
return None
|
||||
|
||||
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
|
||||
# 保存图像到指定路径
|
||||
if crop_mode_:
|
||||
# 如果采用了裁剪模式,则裁剪图像
|
||||
y, x = np.where(mask)
|
||||
y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
masked_image = apply_mask(cropped_image, cropped_mask)
|
||||
else:
|
||||
masked_image = apply_mask(image, mask)
|
||||
filename = filename[:filename.rfind('.')] + '.png'
|
||||
new_filename = get_next_filename(output_dir, filename)
|
||||
|
||||
if new_filename:
|
||||
if masked_image.shape[-1] == 4:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image,
|
||||
[cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||
else:
|
||||
cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
|
||||
print(f"Saved as {new_filename}")
|
||||
|
||||
# 读取保存的图像文件
|
||||
saved_image_path = os.path.join(output_dir, new_filename)
|
||||
saved_image_pixmap = QtGui.QPixmap(saved_image_path)
|
||||
|
||||
# 将保存的图像显示在预测图像区域
|
||||
mainWindow.label_2.setPixmap(
|
||||
saved_image_pixmap.scaled(mainWindow.label_2.size(), QtCore.Qt.KeepAspectRatio))
|
||||
else:
|
||||
print("Could not save the image. Too many variations exist.")
|
||||
|
||||
current_index = 0
|
||||
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow("image", (1920 - WINDOW_WIDTH) // 2, (1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
def mouse_click(event, x, y, flags, param):
|
||||
if not self.input_stop:
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(1)
|
||||
elif event == cv2.EVENT_RBUTTONDOWN:
|
||||
input_point.append([x, y])
|
||||
input_label.append(0)
|
||||
else:
|
||||
if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN:
|
||||
print('此时不能添加点,按w退出mask选择模式')
|
||||
|
||||
cv2.setMouseCallback("image", mouse_click)
|
||||
input_point = []
|
||||
input_label = []
|
||||
input_stop = False
|
||||
while True:
|
||||
filename = self.image_files[self.current_index]
|
||||
image_orign = cv2.imread(os.path.join(input_dir, filename))
|
||||
image_crop = image_orign.copy()
|
||||
image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
while True:
|
||||
image_display = image_orign.copy()
|
||||
display_info = f'{filename} '
|
||||
cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA)
|
||||
for point, label in zip(input_point, input_label):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(image_display, tuple(point), 5, color, -1)
|
||||
if selected_mask is not None:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
selected_image = apply_color_mask(image_display, selected_mask, color)
|
||||
|
||||
cv2.imshow("image", image_display)
|
||||
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
elif key == ord("w"):
|
||||
input_stop = True
|
||||
if len(input_point) > 0 and len(input_label) > 0:
|
||||
|
||||
predictor.set_image(image)
|
||||
input_point_np = np.array(input_point)
|
||||
input_label_np = np.array(input_label)
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=logit_input[None, :, :] if logit_input is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
mask_idx = 0
|
||||
num_masks = len(masks)
|
||||
|
||||
prediction_window_name = "Prediction"
|
||||
cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
cv2.moveWindow(prediction_window_name, (1920 - WINDOW_WIDTH) // 2,
|
||||
(1080 - WINDOW_HEIGHT) // 2)
|
||||
|
||||
while True:
|
||||
color = tuple(np.random.randint(0, 256, 3).tolist())
|
||||
image_select = image_orign.copy()
|
||||
selected_mask = masks[mask_idx]
|
||||
selected_image = apply_color_mask(image_select, selected_mask, color)
|
||||
mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
|
||||
cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
|
||||
(0, 255, 255), 2, cv2.LINE_AA)
|
||||
|
||||
cv2.imshow(prediction_window_name, selected_image)
|
||||
|
||||
key = cv2.waitKey(10)
|
||||
if key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
elif key == ord('s'):
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename,
|
||||
crop_mode_=crop_mode)
|
||||
elif key == ord('a'):
|
||||
if mask_idx > 0:
|
||||
mask_idx -= 1
|
||||
else:
|
||||
mask_idx = num_masks - 1
|
||||
elif key == ord('d'):
|
||||
if mask_idx < num_masks - 1:
|
||||
mask_idx += 1
|
||||
else:
|
||||
mask_idx = 0
|
||||
elif key == ord('w'):
|
||||
input_stop = False # Allow adding points again
|
||||
break
|
||||
elif key == ord(" "):
|
||||
input_point = []
|
||||
input_label = []
|
||||
selected_mask = None
|
||||
logit_input = None
|
||||
break
|
||||
logit_input = logits[mask_idx, :, :]
|
||||
print('max score:', np.argmax(scores), ' select:', mask_idx)
|
||||
|
||||
elif key == ord('a'):
|
||||
current_index = max(0, current_index - 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == ord('d'):
|
||||
current_index = min(len(image_files) - 1, current_index + 1)
|
||||
input_point = []
|
||||
input_label = []
|
||||
break
|
||||
elif key == 27:
|
||||
break
|
||||
elif key == ord('q') and len(input_point) > 0:
|
||||
input_point.pop(-1)
|
||||
input_label.pop(-1)
|
||||
elif key == ord('s') and selected_mask is not None:
|
||||
save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
|
||||
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows() # Close all windows before exiting
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
def save_prediction(self):
|
||||
if self.label_2.pixmap(): # 检查预测图像区域是否有图像
|
||||
# 保存预测结果的部分,这里假设你已经有了保存预测结果的代码,我用 placeholer 代替
|
||||
# placeholder: 这里假设 save_prediction_result 是一个保存预测结果的函数,它接受预测结果的图像数据以及保存路径作为参数
|
||||
# 这里假设预测结果图像数据为 prediction_image,保存路径为 save_path
|
||||
prediction_image = self.label_2.pixmap().toImage()
|
||||
save_path = "prediction_result.png"
|
||||
prediction_image.save(save_path)
|
||||
|
||||
# 调用 adjust_prediction_size 方法来根据 horizontalSlider 的值调整预测区域的大小
|
||||
self.adjust_prediction_size(self.horizontalSlider.value())
|
||||
|
||||
def adjust_prediction_size(self, value):
|
||||
if self.image_files and self.current_index < len(self.image_files):
|
||||
# 获取预测图像区域的原始大小
|
||||
pixmap = self.label_2.pixmap()
|
||||
if pixmap.isNull():
|
||||
return
|
||||
|
||||
original_size = pixmap.size()
|
||||
|
||||
# 判断是缩小还是还原图像
|
||||
if value < self.last_value:
|
||||
# 缩小掩码
|
||||
scale_factor = 1.0 + (self.last_value - value) * 0.1
|
||||
else:
|
||||
# 放大掩码
|
||||
scale_factor = 1.0 - (value - self.last_value) * 0.1
|
||||
|
||||
self.last_value = value # 更新上一次的滑块值
|
||||
|
||||
# 根据缩放比例调整预测图像区域的大小,并保持纵横比例
|
||||
scaled_size = QtCore.QSize(original_size.width() * scale_factor, original_size.height() * scale_factor)
|
||||
scaled_pixmap = pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)
|
||||
|
||||
# 更新预测图像区域的大小并显示
|
||||
self.label_2.setPixmap(scaled_pixmap)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
mainWindow = MyMainWindow()
|
||||
mainWindow.show()
|
||||
sys.exit(app.exec_())
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import cv2 # type: ignore
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
import numpy as np
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Runs automatic mask generation on an input image or directory of images, "
|
||||
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
|
||||
"as well as pycocotools if saving in RLE format."
|
||||
)
|
||||
)
|
||||
# --input scripts/input/crops/Guide/23140.jpg --output scripts/output/crops --model-type vit_b --checkpoint D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default=r'scripts/input/crops/Guide/23140.jpg',
|
||||
required=False,
|
||||
help="Path to either a single input image or folder of images.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
required=False,
|
||||
default=r'scripts/output/crops/Guide/',
|
||||
help=(
|
||||
"Path to the directory where masks will be output. Output will be either a folder "
|
||||
"of PNGs per image or a single json with COCO-style masks."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
required=False,
|
||||
default='vit_b',
|
||||
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=False,
|
||||
default=r'D:/Program Files/Pycharm items/segment-anything-model/weights/vit_b.pth',
|
||||
help="The path to the SAM checkpoint to use for mask generation.",
|
||||
)
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
|
||||
|
||||
parser.add_argument(
|
||||
"--convert-to-rle",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
|
||||
"Requires pycocotools."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings = parser.add_argument_group("AMG Settings")
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-side",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Generate masks by sampling a grid over the image with this many points to a side.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="How many input points to process simultaneously in one batch.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--pred-iou-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a stability score lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-offset",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Larger values perturb the mask more when measuring stability score.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--box-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding a duplicate mask.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-layers",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
|
||||
"The value sets how many different scales to crop at."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding duplicate masks across different crops.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-overlap-ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Larger numbers mean image crops will overlap more.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-points-downscale-factor",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of points-per-side in each layer of crop is reduced by this factor.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--min-mask-region-area",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Disconnected mask regions or holes with area smaller than this value "
|
||||
"in pixels are removed by postprocessing."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
|
||||
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
|
||||
metadata = [header]
|
||||
for i, mask_data in enumerate(masks):
|
||||
mask = mask_data["segmentation"]
|
||||
filename = f"{i}.png"
|
||||
cv2.imwrite(os.path.join(path, filename), mask * 255)
|
||||
mask_metadata = [
|
||||
str(i),
|
||||
str(mask_data["area"]),
|
||||
*[str(x) for x in mask_data["bbox"]],
|
||||
*[str(x) for x in mask_data["point_coords"][0]],
|
||||
str(mask_data["predicted_iou"]),
|
||||
str(mask_data["stability_score"]),
|
||||
*[str(x) for x in mask_data["crop_box"]],
|
||||
]
|
||||
row = ",".join(mask_metadata)
|
||||
metadata.append(row)
|
||||
metadata_path = os.path.join(path, "metadata.csv")
|
||||
with open(metadata_path, "w") as f:
|
||||
f.write("\n".join(metadata))
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_amg_kwargs(args):
|
||||
amg_kwargs = {
|
||||
"points_per_side": args.points_per_side,
|
||||
"points_per_batch": args.points_per_batch,
|
||||
"pred_iou_thresh": args.pred_iou_thresh,
|
||||
"stability_score_thresh": args.stability_score_thresh,
|
||||
"stability_score_offset": args.stability_score_offset,
|
||||
"box_nms_thresh": args.box_nms_thresh,
|
||||
"crop_n_layers": args.crop_n_layers,
|
||||
"crop_nms_thresh": args.crop_nms_thresh,
|
||||
"crop_overlap_ratio": args.crop_overlap_ratio,
|
||||
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
|
||||
"min_mask_region_area": args.min_mask_region_area,
|
||||
}
|
||||
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
|
||||
return amg_kwargs
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
_ = sam.to(device=args.device)
|
||||
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
|
||||
amg_kwargs = get_amg_kwargs(args)
|
||||
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
|
||||
|
||||
if not os.path.isdir(args.input):
|
||||
targets = [args.input]
|
||||
else:
|
||||
targets = [
|
||||
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
|
||||
]
|
||||
targets = [os.path.join(args.input, f) for f in targets]
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
for t in targets:
|
||||
print(f"Processing '{t}'...")
|
||||
image = cv2.imread(t)
|
||||
if image is None:
|
||||
print(f"Could not load '{t}' as an image, skipping...")
|
||||
continue
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks = generator.generate(image)
|
||||
|
||||
base = os.path.basename(t)
|
||||
base = os.path.splitext(base)[0]
|
||||
save_base = os.path.join(args.output, base)
|
||||
if output_mode == "binary_mask":
|
||||
os.makedirs(save_base, exist_ok=False)
|
||||
write_masks_to_folder(masks, save_base)
|
||||
else:
|
||||
save_file = save_base + ".json"
|
||||
with open(save_file, "w") as f:
|
||||
json.dump(masks, f)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
"""
|
||||
D:\anaconda3\envs\pytorch\python.exe "D:/Program Files/Pycharm items/segment-anything-model/scripts/amg.py" --input "scripts/input/crops/Guide/23140.jpg" --output "scripts/output/crops" --model-type vit_b --checkpoint "D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
|
||||
|
||||
"""
|
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import cv2 # type: ignore
|
||||
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Runs automatic mask generation on an input image or directory of images, "
|
||||
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
|
||||
"as well as pycocotools if saving in RLE format."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=False,
|
||||
default=r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images',
|
||||
help="Path to either a single input image or folder of images.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
required=False,
|
||||
default=r'output/mask',
|
||||
help=(
|
||||
"Path to the directory where masks will be output. Output will be either a folder "
|
||||
"of PNGs per image or a single json with COCO-style masks."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
required=False,
|
||||
default='vit_b',
|
||||
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=False,
|
||||
default=r'D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth',
|
||||
help="The path to the SAM checkpoint to use for mask generation.",
|
||||
)
|
||||
|
||||
# parser.add_argument(
|
||||
# "--input",
|
||||
# type=str,
|
||||
# required=True,
|
||||
# help="Path to either a single input image or folder of images.",
|
||||
# )
|
||||
#
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# type=str,
|
||||
# required=True,
|
||||
# help=(
|
||||
# "Path to the directory where masks will be output. Output will be either a folder "
|
||||
# "of PNGs per image or a single json with COCO-style masks."
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# parser.add_argument(
|
||||
# "--model-type",
|
||||
# type=str,
|
||||
# required=True,
|
||||
# help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
|
||||
# )
|
||||
#
|
||||
# parser.add_argument(
|
||||
# "--checkpoint",
|
||||
# type=str,
|
||||
# required=True,
|
||||
# help="The path to the SAM checkpoint to use for mask generation.",
|
||||
# )
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
|
||||
|
||||
parser.add_argument(
|
||||
"--convert-to-rle",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
|
||||
"Requires pycocotools."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings = parser.add_argument_group("AMG Settings")
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-side",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Generate masks by sampling a grid over the image with this many points to a side.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="How many input points to process simultaneously in one batch.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--pred-iou-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a stability score lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-offset",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Larger values perturb the mask more when measuring stability score.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--box-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding a duplicate mask.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-layers",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
|
||||
"The value sets how many different scales to crop at."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding duplicate masks across different crops.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-overlap-ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Larger numbers mean image crops will overlap more.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-points-downscale-factor",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of points-per-side in each layer of crop is reduced by this factor.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--min-mask-region-area",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Disconnected mask regions or holes with area smaller than this value "
|
||||
"in pixels are removed by postprocessing."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
|
||||
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
|
||||
metadata = [header]
|
||||
for i, mask_data in enumerate(masks):
|
||||
mask = mask_data["segmentation"]
|
||||
filename = f"{i}.png"
|
||||
cv2.imwrite(os.path.join(path, filename), mask * 255)
|
||||
mask_metadata = [
|
||||
str(i),
|
||||
str(mask_data["area"]),
|
||||
*[str(x) for x in mask_data["bbox"]],
|
||||
*[str(x) for x in mask_data["point_coords"][0]],
|
||||
str(mask_data["predicted_iou"]),
|
||||
str(mask_data["stability_score"]),
|
||||
*[str(x) for x in mask_data["crop_box"]],
|
||||
]
|
||||
row = ",".join(mask_metadata)
|
||||
metadata.append(row)
|
||||
metadata_path = os.path.join(path, "metadata.csv")
|
||||
with open(metadata_path, "w") as f:
|
||||
f.write("\n".join(metadata))
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_amg_kwargs(args):
|
||||
amg_kwargs = {
|
||||
"points_per_side": args.points_per_side,
|
||||
"points_per_batch": args.points_per_batch,
|
||||
"pred_iou_thresh": args.pred_iou_thresh,
|
||||
"stability_score_thresh": args.stability_score_thresh,
|
||||
"stability_score_offset": args.stability_score_offset,
|
||||
"box_nms_thresh": args.box_nms_thresh,
|
||||
"crop_n_layers": args.crop_n_layers,
|
||||
"crop_nms_thresh": args.crop_nms_thresh,
|
||||
"crop_overlap_ratio": args.crop_overlap_ratio,
|
||||
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
|
||||
"min_mask_region_area": args.min_mask_region_area,
|
||||
}
|
||||
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
|
||||
return amg_kwargs
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
_ = sam.to(device=args.device)
|
||||
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
|
||||
amg_kwargs = get_amg_kwargs(args)
|
||||
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
|
||||
|
||||
if not os.path.isdir(args.input):
|
||||
targets = [args.input]
|
||||
else:
|
||||
targets = [
|
||||
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
|
||||
]
|
||||
targets = [os.path.join(args.input, f) for f in targets]
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
for t in targets:
|
||||
print(f"Processing '{t}'...")
|
||||
image = cv2.imread(t)
|
||||
if image is None:
|
||||
print(f"Could not load '{t}' as an image, skipping...")
|
||||
continue
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks = generator.generate(image)
|
||||
|
||||
base = os.path.basename(t)
|
||||
base = os.path.splitext(base)[0]
|
||||
save_base = os.path.join(args.output, base)
|
||||
if output_mode == "binary_mask":
|
||||
os.makedirs(save_base, exist_ok=False)
|
||||
write_masks_to_folder(masks, save_base)
|
||||
else:
|
||||
save_file = save_base + ".json"
|
||||
with open(save_file, "w") as f:
|
||||
json.dump(masks, f)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from segment_anything import sam_model_registry
|
||||
from segment_anything.utils.onnx import SamOnnxModel
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import onnxruntime # type: ignore
|
||||
|
||||
onnxruntime_exists = True
|
||||
except ImportError:
|
||||
onnxruntime_exists = False
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output", type=str, required=True, help="The filename to save the ONNX model to."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
required=True,
|
||||
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-single-mask",
|
||||
action="store_true",
|
||||
help=(
|
||||
"If true, the exported ONNX model will only return the best mask, "
|
||||
"instead of returning multiple masks. For high resolution images "
|
||||
"this can improve runtime when upscaling masks is expensive."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
type=int,
|
||||
default=17,
|
||||
help="The ONNX opset version to use. Must be >=11",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--quantize-out",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"If set, will quantize the model and save it with this name. "
|
||||
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gelu-approximate",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replace GELU operations with approximations using tanh. Useful "
|
||||
"for some runtimes that have slow or unimplemented erf ops, used in GELU."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-stability-score",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replaces the model's predicted mask quality score with the stability "
|
||||
"score calculated on the low resolution masks using an offset of 1.0. "
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-extra-metrics",
|
||||
action="store_true",
|
||||
help=(
|
||||
"The model will return five results: (masks, scores, stability_scores, "
|
||||
"areas, low_res_logits) instead of the usual three. This can be "
|
||||
"significantly slower for high resolution outputs."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_export(
|
||||
model_type: str,
|
||||
checkpoint: str,
|
||||
output: str,
|
||||
opset: int,
|
||||
return_single_mask: bool,
|
||||
gelu_approximate: bool = False,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics=False,
|
||||
):
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
|
||||
onnx_model = SamOnnxModel(
|
||||
model=sam,
|
||||
return_single_mask=return_single_mask,
|
||||
use_stability_score=use_stability_score,
|
||||
return_extra_metrics=return_extra_metrics,
|
||||
)
|
||||
|
||||
if gelu_approximate:
|
||||
for n, m in onnx_model.named_modules():
|
||||
if isinstance(m, torch.nn.GELU):
|
||||
m.approximate = "tanh"
|
||||
|
||||
dynamic_axes = {
|
||||
"point_coords": {1: "num_points"},
|
||||
"point_labels": {1: "num_points"},
|
||||
}
|
||||
|
||||
embed_dim = sam.prompt_encoder.embed_dim
|
||||
embed_size = sam.prompt_encoder.image_embedding_size
|
||||
mask_input_size = [4 * x for x in embed_size]
|
||||
dummy_inputs = {
|
||||
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
|
||||
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
|
||||
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
|
||||
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
|
||||
"has_mask_input": torch.tensor([1], dtype=torch.float),
|
||||
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
|
||||
}
|
||||
|
||||
_ = onnx_model(**dummy_inputs)
|
||||
|
||||
output_names = ["masks", "iou_predictions", "low_res_masks"]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
with open(output, "wb") as f:
|
||||
print(f"Exporting onnx model to {output}...")
|
||||
torch.onnx.export(
|
||||
onnx_model,
|
||||
tuple(dummy_inputs.values()),
|
||||
f,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
do_constant_folding=True,
|
||||
input_names=list(dummy_inputs.keys()),
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
|
||||
if onnxruntime_exists:
|
||||
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
|
||||
# set cpu provider default
|
||||
providers = ["CPUExecutionProvider"]
|
||||
ort_session = onnxruntime.InferenceSession(output, providers=providers)
|
||||
_ = ort_session.run(None, ort_inputs)
|
||||
print("Model has successfully been run with ONNXRuntime.")
|
||||
|
||||
|
||||
def to_numpy(tensor):
|
||||
return tensor.cpu().numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
run_export(
|
||||
model_type=args.model_type,
|
||||
checkpoint=args.checkpoint,
|
||||
output=args.output,
|
||||
opset=args.opset,
|
||||
return_single_mask=args.return_single_mask,
|
||||
gelu_approximate=args.gelu_approximate,
|
||||
use_stability_score=args.use_stability_score,
|
||||
return_extra_metrics=args.return_extra_metrics,
|
||||
)
|
||||
|
||||
if args.quantize_out is not None:
|
||||
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
|
||||
from onnxruntime.quantization import QuantType # type: ignore
|
||||
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
|
||||
|
||||
print(f"Quantizing model and writing to {args.quantize_out}...")
|
||||
quantize_dynamic(
|
||||
model_input=args.output,
|
||||
model_output=args.quantize_out,
|
||||
optimize_model=True,
|
||||
per_channel=False,
|
||||
reduce_range=False,
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
print("Done!")
|
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 3.3 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 1.9 KiB |
After Width: | Height: | Size: 964 B |
After Width: | Height: | Size: 3.1 KiB |
After Width: | Height: | Size: 3.1 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 918 B |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.4 KiB |
After Width: | Height: | Size: 1.0 KiB |
After Width: | Height: | Size: 1.7 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 3.8 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 2.8 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 3.8 KiB |
After Width: | Height: | Size: 1.4 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 1.7 KiB |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 1.3 KiB |
After Width: | Height: | Size: 1.4 KiB |
After Width: | Height: | Size: 2.5 KiB |
After Width: | Height: | Size: 4.3 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 906 B |
After Width: | Height: | Size: 915 B |
After Width: | Height: | Size: 959 B |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 3.1 KiB |
After Width: | Height: | Size: 3.3 KiB |
After Width: | Height: | Size: 3.8 KiB |
After Width: | Height: | Size: 3.8 KiB |
After Width: | Height: | Size: 3.5 KiB |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 3.3 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 1.2 KiB |
After Width: | Height: | Size: 1.2 KiB |