import os
import sys

import cv2
import numpy as np
import torch
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import SamPredictor, sam_model_registry

# Image file suffixes accepted when browsing a folder.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")

# SAM checkpoint used when the SAM_CHECKPOINT environment variable is not set.
DEFAULT_CHECKPOINT = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"


class ImageLabel(QtWidgets.QLabel):
    """QLabel that records mouse clicks as SAM prompt points and draws them.

    Left clicks are stored in ``foreground_points`` (drawn green) and right
    clicks in ``background_points`` (drawn red), both as ``QPoint`` in widget
    coordinates. ``clicked`` is emitted for every mouse press, after the
    point has been recorded.
    """

    clicked = QtCore.pyqtSignal(QtCore.QPoint)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.foreground_points = []
        self.background_points = []

    def clear_points(self):
        """Forget every recorded point and repaint the label."""
        self.foreground_points.clear()
        self.background_points.clear()
        self.update()

    def mousePressEvent(self, event):
        pos = event.pos()
        if event.button() == QtCore.Qt.LeftButton:
            self.foreground_points.append(pos)
        elif event.button() == QtCore.Qt.RightButton:
            self.background_points.append(pos)
        self.update()
        # Emit only after recording so connected slots already see the new point.
        self.clicked.emit(pos)

    def paintEvent(self, event):
        super().paintEvent(event)
        painter = QtGui.QPainter(self)
        try:
            painter.setPen(QtGui.QPen(QtGui.QColor("green"), 5))
            for point in self.foreground_points:
                painter.drawPoint(point)
            painter.setPen(QtGui.QPen(QtGui.QColor("red"), 5))
            for point in self.background_points:
                painter.drawPoint(point)
        finally:
            # Release the paint device explicitly instead of relying on GC.
            painter.end()


class MainWindow(QtWidgets.QMainWindow):
    """Interactive SAM segmentation window.

    The user opens an image (or a folder, showing its first image) and clicks on
    the left panel. Left click adds a foreground point, right click a
    background point. "保存掩码" runs the predictor on those points, shows
    the mask on the right panel and writes it to a PNG file. "重置选择"
    clears the points.

    Args:
        predictor: a ``segment_anything.SamPredictor``; ``set_image`` is
            called on it whenever a new image is displayed.
    """

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor
        self.current_points = []   # raw click positions [x, y] in label coordinates
        self.current_image = None  # RGB uint8 array (H, W, 3) fed to the predictor
        self.image_files = []
        self.current_index = 0
        self.setupUi()

    def setupUi(self):
        self.setObjectName("MainWindow")
        self.resize(1333, 657)
        self.centralwidget = QtWidgets.QWidget(self)
        self.centralwidget.setObjectName("centralwidget")

        self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
        self.pushButton_init.setObjectName("pushButton_init")

        self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
        self.pushButton_openimg.setObjectName("pushButton_openimg")

        self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41))
        self.pushButton_save_mask.setObjectName("pushButton_save_mask")

        self.label_Originalimg = ImageLabel(self.centralwidget)
        self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
        self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Originalimg.setObjectName("label_Originalimg")

        self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
        self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Maskimg.setObjectName("label_Maskimg")

        self.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(self)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        self.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)

        self.retranslateUi()
        self._connect_signals()
        QtCore.QMetaObject.connectSlotsByName(self)

    def retranslateUi(self):
        """Set user-visible texts only (safe to call again on language change)."""
        _translate = QtCore.QCoreApplication.translate
        self.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
        self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
        self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码"))

    def _connect_signals(self):
        """Wire widgets to handlers exactly once (kept out of retranslateUi)."""
        self.pushButton_init.clicked.connect(self.reset_selection)
        self.pushButton_openimg.clicked.connect(self.button_image_open)
        self.pushButton_save_mask.clicked.connect(self.button_save_mask)
        self.label_Originalimg.clicked.connect(self.mouse_click)

    # ------------------------------------------------------------------ images

    def button_image_open(self):
        """Ask whether to open a folder or a single image, then display it.

        Cancelling the choice (Cancel button, Esc, or closing the box) does
        nothing.
        """
        box = QtWidgets.QMessageBox(self)
        box.setWindowTitle("选择")
        box.setText("您想要打开文件夹还是选择一个图片文件?")
        folder_button = box.addButton("文件夹", QtWidgets.QMessageBox.ActionRole)
        file_button = box.addButton("图片文件", QtWidgets.QMessageBox.ActionRole)
        box.addButton(QtWidgets.QMessageBox.Cancel)
        box.exec_()
        clicked = box.clickedButton()

        if clicked is folder_button:
            image_files = self._pick_folder_images()
        elif clicked is file_button:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(
                self, "选择图片", "", "Image files (*.png *.jpg *.jpeg *.bmp)")
            image_files = [selected_image] if selected_image else []
        else:
            return

        if image_files:
            self.image_files = image_files
            self.current_index = 0
            self.display_image()

    def _pick_folder_images(self):
        """Let the user pick a folder; return its image files, sorted by name."""
        folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, "选择文件夹", "")
        if not folder_path:
            return []
        image_files = sorted(
            os.path.join(folder_path, name)
            for name in os.listdir(folder_path)
            if name.lower().endswith(IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )
        if not image_files:
            self.statusbar.showMessage("文件夹中没有图片")
        return image_files

    def display_image(self):
        """Show ``image_files[current_index]`` and compute its SAM embedding.

        The same RGB array is used for display and for the predictor, so
        click coordinates map consistently onto what SAM sees. Previous
        points and mask are cleared.
        """
        if not self.image_files:
            return
        path = self.image_files[self.current_index]
        image = self._read_image_rgb(path)
        if image is None:
            self.statusbar.showMessage(f"无法读取图片: {path}")
            return

        self.current_image = image
        # The image embedding is computed once here; predict() then only runs the decoder.
        self.predictor.set_image(image)
        self.reset_selection()
        self.label_Originalimg.setPixmap(self._array_to_pixmap(image))
        self.label_Originalimg.setScaledContents(True)
        self.statusbar.showMessage(path)

    @staticmethod
    def _read_image_rgb(path):
        """Read *path* as an RGB uint8 array, or return None if unreadable."""
        # Decode from raw bytes: cv2.imread can fail on non-ASCII paths on Windows.
        try:
            raw = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if raw.size == 0:
            return None
        bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if bgr is None:
            return None
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _array_to_pixmap(array):
        """Convert an (H, W) uint8 or (H, W, 3) RGB uint8 array to a QPixmap."""
        array = np.ascontiguousarray(array)
        height, width = array.shape[:2]
        fmt = QtGui.QImage.Format_Grayscale8 if array.ndim == 2 else QtGui.QImage.Format_RGB888
        q_image = QtGui.QImage(array.data, width, height, array.strides[0], fmt)
        # copy() makes the QImage own its pixels instead of borrowing the numpy buffer.
        return QtGui.QPixmap.fromImage(q_image.copy())

    # ------------------------------------------------------------- interaction

    def mouse_click(self, pos):
        """Record a click position (label coordinates) and report it."""
        x, y = pos.x(), pos.y()
        self.current_points.append([x, y])
        self.statusbar.showMessage(f"Mouse clicked at position: {x} {y}")

    def reset_selection(self):
        """Clear all prompt points and the displayed mask."""
        self.current_points.clear()
        self.label_Originalimg.clear_points()
        self.label_Maskimg.clear()

    def _label_to_image_coords(self, points):
        """Map label-space QPoints to (x, y) pixel coords of ``current_image``.

        The pixmap is stretched over the whole label (scaled contents), so
        the mapping is a per-axis linear scale, clamped to the image bounds.
        """
        height, width = self.current_image.shape[:2]
        scale_x = width / max(1, self.label_Originalimg.width())
        scale_y = height / max(1, self.label_Originalimg.height())
        return [
            [min(max(p.x() * scale_x, 0.0), width - 1), min(max(p.y() * scale_y, 0.0), height - 1)]
            for p in points
        ]

    def _predict_mask(self):
        """Run SAM on the current points; return an (H, W) uint8 0/255 mask or None."""
        label = self.label_Originalimg
        foreground, background = label.foreground_points, label.background_points
        if self.current_image is None or not (foreground or background):
            return None

        coords = np.array(self._label_to_image_coords(foreground + background), dtype=np.float32)
        labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)
        # A single point is ambiguous, so ask for several candidates and keep the best-scored one.
        multimask = len(coords) == 1
        masks, scores, _ = self.predictor.predict(
            point_coords=coords, point_labels=labels, multimask_output=multimask)
        best = masks[int(np.argmax(scores))]
        return np.ascontiguousarray(best.astype(np.uint8) * 255)

    def button_save_mask(self):
        """Predict the mask for the clicked points, display it and save it as PNG."""
        if self.current_image is None:
            self.statusbar.showMessage("请先打开图片")
            return
        mask = self._predict_mask()
        if mask is None:
            self.statusbar.showMessage("请先在图片上点击选择目标")
            return

        self.label_Maskimg.setPixmap(self._array_to_pixmap(mask))
        self.label_Maskimg.setScaledContents(True)
        self._save_mask(mask)

    def _save_mask(self, mask):
        """Ask for a destination (default ``<image>_mask.png``) and write *mask*."""
        image_path = self.image_files[self.current_index]
        default_path = os.path.splitext(image_path)[0] + "_mask.png"
        save_path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "保存掩码", default_path, "PNG (*.png)")
        if not save_path:
            return
        if not save_path.lower().endswith(".png"):
            save_path += ".png"

        ok, buffer = cv2.imencode(".png", mask)
        if not ok:
            self.statusbar.showMessage("掩码编码失败")
            return
        try:
            # tofile() instead of cv2.imwrite so non-ASCII paths also work.
            buffer.tofile(save_path)
        except OSError as err:
            self.statusbar.showMessage(f"保存失败: {err}")
            return
        self.statusbar.showMessage(f"掩码已保存: {save_path}")


def main():
    """Load the SAM model and run the GUI."""
    app = QtWidgets.QApplication(sys.argv)
    checkpoint = os.environ.get("SAM_CHECKPOINT", DEFAULT_CHECKPOINT)
    sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
    # Fall back to CPU so the tool still starts on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam.to(device=device)
    predictor = SamPredictor(sam)
    window = MainWindow(predictor)
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()