# SAM/salt/display1.py
#
# (Original listing: 147 lines, 6.6 KiB, Python.)

import sys
import os
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class ImageLabel(QtWidgets.QLabel):
    """QLabel that records prompt clicks and draws them over its pixmap.

    Left-clicks are appended to ``foreground_points`` (drawn green) and
    right-clicks to ``background_points`` (drawn red). Points are stored as
    ``QPoint`` objects in this widget's own coordinates.
    """

    # Emitted with the press position (widget coordinates) on every mouse press.
    clicked = QtCore.pyqtSignal(QtCore.QPoint)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.foreground_points = []
        self.background_points = []

    def mousePressEvent(self, event):
        """Record a left/right click, schedule a repaint, then emit ``clicked``."""
        pos = event.pos()
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            self.foreground_points.append(pos)
        elif button == QtCore.Qt.RightButton:
            self.background_points.append(pos)
        self.update()
        # Emit after recording so slots connected to `clicked` already see the
        # new point in foreground_points / background_points.
        self.clicked.emit(pos)

    def clear_points(self):
        """Forget all recorded points and repaint the label."""
        self.foreground_points.clear()
        self.background_points.clear()
        self.update()

    def paintEvent(self, event):
        """Paint the label normally, then overlay the recorded points."""
        super().paintEvent(event)
        painter = QtGui.QPainter(self)
        try:
            painter.setPen(QtGui.QPen(QtGui.QColor("green"), 5))
            for point in self.foreground_points:
                painter.drawPoint(point)
            painter.setPen(QtGui.QPen(QtGui.QColor("red"), 5))
            for point in self.background_points:
                painter.drawPoint(point)
        finally:
            # End the painter explicitly instead of relying on garbage
            # collection, so it never outlives this paint event (even on error).
            painter.end()
class MainWindow(QtWidgets.QMainWindow):
    """Window for prompting a SAM predictor with clicked points.

    The left label shows the opened image. Left-clicks on it add foreground
    prompts and right-clicks add background prompts (recorded by the label's
    ``foreground_points`` / ``background_points`` lists). "保存掩码" (save
    mask) runs the predictor on those prompts, shows the best-scoring mask in
    the right label and offers to write it to disk.
    """

    # File suffixes treated as images when opening a folder / file.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp")
    # Lossless formats offered when saving a mask.
    _MASK_SUFFIXES = (".png", ".bmp")

    def __init__(self, predictor):
        """Build the UI around *predictor*.

        Args:
            predictor: a ``segment_anything.SamPredictor``; this class calls its
                ``set_image`` and ``predict`` methods.
        """
        super().__init__()
        self.predictor = predictor
        # History of clicked points in image pixel coordinates (x, y).
        self.current_points = []
        # RGB uint8 array of the image currently loaded into the predictor.
        self.current_image = None
        self.image_files = []
        self.current_index = 0
        self.setupUi()

    def setupUi(self):
        """Create the widgets (absolute geometry) and connect their signals."""
        self.setObjectName("MainWindow")
        self.resize(1333, 657)
        self.centralwidget = QtWidgets.QWidget(self)
        self.centralwidget.setObjectName("centralwidget")

        self.pushButton_init = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_init.setGeometry(QtCore.QRect(10, 30, 141, 41))
        self.pushButton_init.setObjectName("pushButton_init")

        self.pushButton_openimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_openimg.setGeometry(QtCore.QRect(10, 90, 141, 41))
        self.pushButton_openimg.setObjectName("pushButton_openimg")

        self.pushButton_save_mask = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_save_mask.setGeometry(QtCore.QRect(10, 600, 141, 41))
        self.pushButton_save_mask.setObjectName("pushButton_save_mask")

        self.label_Originalimg = ImageLabel(self.centralwidget)
        self.label_Originalimg.setGeometry(QtCore.QRect(160, 30, 571, 581))
        self.label_Originalimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Originalimg.setObjectName("label_Originalimg")

        self.label_Maskimg = QtWidgets.QLabel(self.centralwidget)
        self.label_Maskimg.setGeometry(QtCore.QRect(740, 30, 581, 581))
        self.label_Maskimg.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_Maskimg.setObjectName("label_Maskimg")

        self.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(self)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1333, 26))
        self.menubar.setObjectName("menubar")
        self.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)

        self.retranslateUi()
        # Connect here rather than in retranslateUi(), so re-translating the UI
        # does not connect every slot a second time.
        self.pushButton_init.clicked.connect(self.reset_selection)
        self.pushButton_openimg.clicked.connect(self.button_image_open)
        self.pushButton_save_mask.clicked.connect(self.button_save_mask)
        self.label_Originalimg.clicked.connect(self.mouse_click)
        QtCore.QMetaObject.connectSlotsByName(self)

    def retranslateUi(self):
        """Set all user-visible texts."""
        _translate = QtCore.QCoreApplication.translate
        self.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_init.setText(_translate("MainWindow", "重置选择"))
        self.pushButton_openimg.setText(_translate("MainWindow", "打开图片"))
        self.pushButton_save_mask.setText(_translate("MainWindow", "保存掩码"))

    def button_image_open(self):
        """Ask whether to open a folder or a single image, then show the first image."""
        box = QtWidgets.QMessageBox(self)
        box.setIcon(QtWidgets.QMessageBox.Question)
        box.setWindowTitle("选择")
        box.setText("您想要打开文件夹还是选择一个图片文件?")
        folder_button = box.addButton("文件夹", QtWidgets.QMessageBox.ActionRole)
        file_button = box.addButton("图片文件", QtWidgets.QMessageBox.ActionRole)
        # A real Cancel button, so Esc / closing the box does nothing.
        box.addButton(QtWidgets.QMessageBox.Cancel)
        box.exec_()
        clicked = box.clickedButton()

        if clicked == folder_button:
            folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, "选择文件夹", "")
            if not folder_path:
                return
            image_files = self._list_images(folder_path)
            if not image_files:
                self.statusbar.showMessage("文件夹中没有图片: " + folder_path, 5000)
                return
        elif clicked == file_button:
            selected_image, _ = QtWidgets.QFileDialog.getOpenFileName(
                self, "选择图片", "", "Image files (*.png *.jpg *.jpeg *.bmp)")
            if not selected_image:
                return
            image_files = [selected_image]
        else:
            return

        self.image_files = image_files
        self.current_index = 0
        self.display_image()

    def display_image(self):
        """Show ``image_files[current_index]`` and load it into the predictor.

        Prompt points and the previous mask are cleared because they belong to
        the previously shown image. If the file cannot be read, the current
        state is left unchanged and a status message is shown.
        """
        if not self.image_files:
            return
        path = self.image_files[self.current_index]
        pixmap = QtGui.QPixmap(path)
        image = self._read_rgb(path)
        if pixmap.isNull() or image is None:
            self.statusbar.showMessage("无法读取图片: " + path, 5000)
            return

        self.label_Originalimg.setPixmap(pixmap)
        self.label_Originalimg.setScaledContents(True)
        self.current_image = image
        # SamPredictor.predict() requires set_image() on the current image first.
        self.predictor.set_image(image, image_format="RGB")
        self.reset_selection()
        self.statusbar.showMessage(path)

    def reset_selection(self):
        """Forget all prompt points and clear the displayed mask."""
        self.current_points.clear()
        self.label_Originalimg.foreground_points.clear()
        self.label_Originalimg.background_points.clear()
        self.label_Originalimg.update()
        self.label_Maskimg.clear()

    def mouse_click(self, pos):
        """Log a click on the image label, converted to image pixel coordinates."""
        if self.current_image is None:
            return
        x, y = self._to_image_coords(pos)
        print("Mouse clicked at position:", round(x), round(y))
        self.current_points.append([x, y])

    def button_save_mask(self):
        """Predict a mask from the clicked prompts, show it, and offer to save it."""
        if self.current_image is None:
            self.statusbar.showMessage("请先打开图片", 3000)
            return
        foreground = list(self.label_Originalimg.foreground_points)
        background = list(self.label_Originalimg.background_points)
        if not foreground and not background:
            self.statusbar.showMessage("请先在图片上点击选择点", 3000)
            return

        point_coords = np.array(
            [self._to_image_coords(p) for p in foreground + background], dtype=np.float32)
        # SAM point labels: 1 = foreground, 0 = background.
        point_labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)
        masks, scores, _ = self.predictor.predict(point_coords=point_coords,
                                                  point_labels=point_labels,
                                                  multimask_output=True)
        if len(masks) == 0:
            return
        # multimask_output=True yields several candidates; keep the best-scoring one.
        best_mask = masks[int(np.argmax(scores))]
        mask_u8 = np.ascontiguousarray(best_mask.astype(np.uint8) * 255)
        self._show_mask(mask_u8)
        self._save_mask(mask_u8)

    def _to_image_coords(self, pos):
        """Map a position on the original-image label to image pixel (x, y) floats.

        The label uses setScaledContents(True), so the image is stretched to the
        label's full width and height independently.
        """
        img_h, img_w = self.current_image.shape[:2]
        label_w = max(self.label_Originalimg.width(), 1)
        label_h = max(self.label_Originalimg.height(), 1)
        x = pos.x() * img_w / label_w
        y = pos.y() * img_h / label_h
        # Clamp so clicks on the label's edge stay inside the image.
        return min(max(x, 0.0), img_w - 1.0), min(max(y, 0.0), img_h - 1.0)

    def _show_mask(self, mask_u8):
        """Display a 2-D uint8 mask (0/255) in the right-hand label."""
        height, width = mask_u8.shape
        # copy() so the QImage owns its pixels instead of borrowing the numpy buffer.
        q_image = QtGui.QImage(mask_u8.data, width, height, mask_u8.strides[0],
                               QtGui.QImage.Format_Grayscale8).copy()
        self.label_Maskimg.setPixmap(QtGui.QPixmap.fromImage(q_image))
        self.label_Maskimg.setScaledContents(True)

    def _save_mask(self, mask_u8):
        """Ask for a destination next to the source image and write *mask_u8* there."""
        source = self.image_files[self.current_index]
        stem = os.path.splitext(os.path.basename(source))[0]
        default_path = os.path.join(os.path.dirname(source), stem + "_mask.png")
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "保存掩码", default_path, "PNG (*.png);;BMP (*.bmp)")
        if not path:
            return
        ext = os.path.splitext(path)[1].lower()
        if ext not in self._MASK_SUFFIXES:
            path += ".png"
            ext = ".png"
        ok, encoded = cv2.imencode(ext, mask_u8)
        if not ok:
            self.statusbar.showMessage("掩码编码失败", 5000)
            return
        # Encode in memory + ndarray.tofile instead of cv2.imwrite, which does not
        # handle non-ASCII (e.g. Chinese) paths on Windows.
        encoded.tofile(path)
        self.statusbar.showMessage("掩码已保存: " + path, 5000)

    @classmethod
    def _list_images(cls, folder_path):
        """Return image files directly inside *folder_path*, sorted by file name."""
        return [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if name.lower().endswith(cls._IMAGE_SUFFIXES)
            and os.path.isfile(os.path.join(folder_path, name))
        ]

    @staticmethod
    def _read_rgb(path):
        """Return the image at *path* as an RGB uint8 array, or None if unreadable."""
        try:
            # np.fromfile + imdecode instead of cv2.imread for non-ASCII Windows paths.
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        bgr = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if bgr is None:
            return None
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
if __name__ == "__main__":
    # Local import: only the script entry point needs torch (to pick a device).
    import torch

    app = QtWidgets.QApplication(sys.argv)
    # Default checkpoint location; override with the SAM_CHECKPOINT environment variable.
    checkpoint = os.environ.get(
        "SAM_CHECKPOINT",
        r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth",
    )
    # Fall back to CPU so the tool still starts on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
    sam = sam.to(device=device)
    predictor = SamPredictor(sam)
    window = MainWindow(predictor)
    window.show()
    sys.exit(app.exec_())