# Listing metadata: 147 lines, 6.6 KiB, Python.
import sys
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
from PyQt5 import QtCore, QtGui, QtWidgets
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
class ImageLabel(QtWidgets.QLabel):
    """QLabel that records and draws point prompts clicked by the user.

    Left clicks are stored in ``foreground_points`` and drawn in green;
    right clicks are stored in ``background_points`` and drawn in red.
    Points are kept in this widget's own coordinate system (``event.pos()``).
    """

    # Emitted with the click position (widget coordinates) after a left or
    # right click has been recorded in one of the point lists.
    clicked = QtCore.pyqtSignal(QtCore.QPoint)

    def __init__(self, *args, **kwargs):
        super(ImageLabel, self).__init__(*args, **kwargs)
        self.foreground_points = []  # QPoints from left clicks
        self.background_points = []  # QPoints from right clicks

    def clear_points(self):
        """Forget all recorded points and schedule a repaint."""
        self.foreground_points.clear()
        self.background_points.clear()
        self.update()

    def mousePressEvent(self, event):
        """Record a left/right click as a foreground/background point.

        The point is stored *before* ``clicked`` is emitted so that slots
        connected to the signal already see it in the point lists. Clicks
        with any other button are passed to QLabel and do not emit
        ``clicked``, since they are not recorded as prompts.
        """
        pos = event.pos()
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            self.foreground_points.append(pos)
        elif button == QtCore.Qt.RightButton:
            self.background_points.append(pos)
        else:
            super().mousePressEvent(event)
            return

        self.update()
        self.clicked.emit(pos)

    def paintEvent(self, event):
        """Paint the label normally, then overlay the recorded points."""
        super().paintEvent(event)

        painter = QtGui.QPainter(self)
        try:
            for color, points in (("green", self.foreground_points),
                                  ("red", self.background_points)):
                painter.setPen(QtGui.QPen(QtGui.QColor(color), 5))
                for point in points:
                    painter.drawPoint(point)
        finally:
            # End painting explicitly rather than relying on the Python
            # wrapper being garbage-collected at the end of the method.
            painter.end()
|
|
|
|
|
|
class MainWindow(QtWidgets.QMainWindow):
    """Main window: open an image, click point prompts, run SAM, save the mask.

    Left clicks on the left image are foreground prompts and right clicks are
    background prompts (both recorded by ``label_Originalimg``). "保存掩码"
    predicts a mask from those prompts, shows it on the right, and asks where
    to save it as an image file.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, predictor):
        super().__init__()
        # Object exposing set_image(rgb) and predict(point_coords=...,
        # point_labels=..., multimask_output=...) (a SamPredictor in __main__).
        self.predictor = predictor
        self.current_points = []   # raw click positions [x, y] in label coordinates
        self.current_image = None  # RGB uint8 array of the displayed image
        self.current_mask = None   # last predicted mask, or None
        self.image_files = []
        self.current_index = 0
        self.setupUi()

    def setupUi(self):
        """Create the widgets and connect their signals."""
        self.setObjectName("MainWindow")
        self.resize(1333, 657)
        self.centralwidget = QtWidgets.QWidget(self)
        self.centralwidget.setObjectName("centralwidget")

        self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
        self.pushButton_init.setObjectName("pushButton_init")

        self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
        self.pushButton_openimg.setObjectName("pushButton_openimg")

        self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41))
        self.pushButton_save_mask.setObjectName("pushButton_save_mask")

        self.label_Originalimg = ImageLabel(self.centralwidget)
        self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
        self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Originalimg.setObjectName("label_Originalimg")

        self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
        self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Maskimg.setObjectName("label_Maskimg")

        self.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(self)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        self.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)

        self.retranslateUi()

        # Connect here (once) rather than in retranslateUi, which may be
        # re-run on a language change and would then duplicate connections.
        self.pushButton_init.clicked.connect(self.reset_selection)
        self.pushButton_openimg.clicked.connect(self.button_image_open)
        self.pushButton_save_mask.clicked.connect(self.button_save_mask)
        self.label_Originalimg.clicked.connect(self.mouse_click)

        QtCore.QMetaObject.connectSlotsByName(self)

    def retranslateUi(self):
        """Set all user-visible texts."""
        _translate = QtCore.QCoreApplication.translate
        self.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
        self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
        self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码"))

    def button_image_open(self):
        """Ask whether to open a folder or a single image, then show the first image."""
        box = QtWidgets.QMessageBox(self)
        box.setWindowTitle("选择")
        box.setText("您想要打开文件夹还是选择一个图片文件?")
        folder_button = box.addButton("文件夹", QtWidgets.QMessageBox.ActionRole)
        file_button = box.addButton("图片文件", QtWidgets.QMessageBox.ActionRole)
        # A real Cancel button makes Esc / closing the box a no-op instead of
        # falling through to a file picker.
        box.addButton(QtWidgets.QMessageBox.Cancel)
        box.exec_()
        choice = box.clickedButton()

        if choice == folder_button:
            folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, "选择文件夹", "")
            if not folder_path:
                return
            # Sorted so the order is stable across runs (os.listdir is arbitrary).
            image_files = sorted(
                os.path.join(folder_path, name)
                for name in os.listdir(folder_path)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
                and os.path.isfile(os.path.join(folder_path, name))
            )
            if not image_files:
                self.statusbar.showMessage("文件夹中没有找到图片")
                return
        elif choice == file_button:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(
                self, "选择图片", "", "Image files (*.png *.jpg *.jpeg *.bmp)")
            if not selected_image:
                return
            image_files = [selected_image]
        else:
            return

        self.image_files = image_files
        self.current_index = 0
        self.display_image()

    def display_image(self):
        """Show ``image_files[current_index]``, load it into the predictor, reset prompts."""
        if not self.image_files:
            return
        path = self.image_files[self.current_index]
        rgb = self._read_rgb(path)
        if rgb is None:
            self.statusbar.showMessage(f"无法读取图片: {path}")
            return

        # Computing the image embedding can take a while; show a busy cursor.
        QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.WaitCursor)
        try:
            self.predictor.set_image(rgb)
        finally:
            QtWidgets.QApplication.restoreOverrideCursor()
        self.current_image = rgb

        self.reset_selection()
        # Display the same decoded array the predictor received so that click
        # coordinates refer to exactly the pixels the model sees.
        self.label_Originalimg.setPixmap(self._to_pixmap(rgb))
        self.label_Originalimg.setScaledContents(True)

    def mouse_click(self, pos):
        """Log a click on the image label and remember its label coordinates.

        Prediction prompts are built from the label's foreground/background
        point lists, which also carry the clicked button.
        """
        x, y = pos.x(), pos.y()
        print("Mouse clicked at position:", x, y)
        self.current_points.append([x, y])

    def reset_selection(self):
        """Clear all clicked points and the displayed mask."""
        self.current_points.clear()
        self.current_mask = None
        self.label_Originalimg.foreground_points.clear()
        self.label_Originalimg.background_points.clear()
        self.label_Originalimg.update()
        self.label_Maskimg.clear()

    def button_save_mask(self):
        """Predict a mask from the clicked points, display it, and offer to save it."""
        if self.current_image is None:
            self.statusbar.showMessage("请先打开图片")
            return
        coords, labels = self._collect_prompts()
        if len(coords) == 0:
            self.statusbar.showMessage("请先在图片上点击选择点")
            return

        masks, scores, _ = self.predictor.predict(point_coords=coords,
                                                  point_labels=labels,
                                                  multimask_output=True)
        if masks is None or len(masks) == 0:
            return
        # multimask_output=True yields several candidates; keep the best-scoring one.
        mask = np.asarray(masks[int(np.argmax(scores))])
        self.current_mask = mask

        mask_u8 = np.ascontiguousarray(mask.astype(np.uint8) * 255)
        mask_rgb = cv2.cvtColor(mask_u8, cv2.COLOR_GRAY2RGB)
        self.label_Maskimg.setPixmap(self._to_pixmap(mask_rgb))
        self.label_Maskimg.setScaledContents(True)

        self._save_mask(mask_u8)

    def _collect_prompts(self):
        """Build SAM point prompts (image-pixel coords, 1=fg / 0=bg labels)."""
        label = self.label_Originalimg
        img_h, img_w = self.current_image.shape[:2]
        # The pixmap is stretched over the whole label (setScaledContents), so
        # scale label coordinates back to image pixels.
        # NOTE(review): assumes the label size is unchanged since the clicks.
        sx = img_w / max(label.width(), 1)
        sy = img_h / max(label.height(), 1)

        coords, labels = [], []
        for points, value in ((label.foreground_points, 1), (label.background_points, 0)):
            for point in points:
                coords.append([point.x() * sx, point.y() * sy])
                labels.append(value)
        return (np.array(coords, dtype=np.float32).reshape(-1, 2),
                np.array(labels, dtype=np.int32))

    def _save_mask(self, mask_u8):
        """Ask for a destination and write the 0/255 mask as an image file."""
        base = os.path.splitext(self.image_files[self.current_index])[0]
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "保存掩码", base + "_mask.png", "PNG (*.png)")
        if not path:
            return
        h, w = mask_u8.shape
        qimage = QtGui.QImage(mask_u8.data, w, h, mask_u8.strides[0],
                              QtGui.QImage.Format_Grayscale8)
        # QImage.save goes through Qt's file layer, so non-ASCII paths work.
        if qimage.save(path):
            self.statusbar.showMessage(f"掩码已保存: {path}")
        else:
            self.statusbar.showMessage(f"掩码保存失败: {path}")

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as an RGB uint8 array, or return None if it cannot be decoded."""
        try:
            # np.fromfile + imdecode instead of cv2.imread, which is known to
            # fail on non-ASCII paths on Windows.
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        bgr = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if bgr is None:
            return None
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _to_pixmap(rgb):
        """Convert an HxWx3 uint8 RGB array to a QPixmap that owns its pixels."""
        rgb = np.ascontiguousarray(rgb)
        h, w = rgb.shape[:2]
        qimage = QtGui.QImage(rgb.data, w, h, rgb.strides[0], QtGui.QImage.Format_RGB888)
        # The QImage only borrows the numpy buffer; copy so the pixmap cannot
        # outlive the array's memory.
        return QtGui.QPixmap.fromImage(qimage.copy())
|
|
|
|
|
|
if __name__ == "__main__":
    import torch  # only needed for the CUDA availability check below

    app = QtWidgets.QApplication(sys.argv)

    # Allow overriding the model via environment variables; the defaults
    # reproduce the original hard-coded model type and checkpoint path.
    model_type = os.environ.get("SAM_MODEL_TYPE", "vit_b")
    checkpoint = os.environ.get(
        "SAM_CHECKPOINT",
        r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
    if not os.path.isfile(checkpoint):
        QtWidgets.QMessageBox.critical(None, "错误", f"找不到模型权重文件: {checkpoint}")
        sys.exit(1)

    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam = sam.to(device=device)
    predictor = SamPredictor(sam)

    window = MainWindow(predictor)
    window.show()
    sys.exit(app.exec_())
|