# Source: SAM/salt/GUI.py (repository listing: 201 lines, 8.9 KiB, Python)

import os
import sys

import cv2
import numpy as np
import torch
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import QTimer
from PyQt5.QtGui import QPixmap, QImage
from segment_anything import sam_model_registry, SamPredictor
# Folder that is scanned for input images when the UI is set up.
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
# Folder intended to receive the masked results.
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
# NOTE(review): presumably forwarded as ``crop_mode_`` to save_masked_image;
# it is not referenced by the code visible in this file — confirm.
crop_mode = True

# Build the ViT-B Segment-Anything model from the local checkpoint.
sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
# Prefer the GPU, but fall back to the CPU so that importing this module does
# not fail on machines without a usable CUDA device.
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
_ = sam.to(device=sam_device)
predictor = SamPredictor(sam)

# NOTE(review): not referenced by the visible code; the main window is
# resized to 1170x486 in Ui_MainWindow.setupUi — confirm these are still needed.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
class Ui_MainWindow(object):
    """Main-window UI for placing positive/negative prompt points on an image.

    The Qt window shows the current image (with the points drawn on it) in
    ``label_orign`` and refreshes it every 100 ms. Points are collected from
    mouse clicks in a separate OpenCV window named ``"image"``.

    State created by :meth:`setupUi`:
        image_files:   sorted image file names found in ``input_dir``.
        current_index: index into ``image_files`` of the image being edited.
        selected_path: path picked via the "open image" dialog, or ``None``;
                       when set it takes precedence over ``image_files``.
        input_point:   list of ``[x, y]`` click positions.
        input_label:   list of labels parallel to ``input_point``
                       (1 = left click, 0 = right click).
        input_stop:    when True, clicks are rejected with a console message.
    """

    # File extensions (compared lower-case) accepted from ``input_dir``.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff')

    def setupUi(self, MainWindow):
        """Create all widgets on *MainWindow*, initialise state and start the refresh timer."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        # Left-hand column of action buttons (creation order is kept as in
        # the original layout so Qt's default tab order does not change).
        for suffix, top in (("w", 90), ("a", 160), ("d", 230), ("s", 300), ("q", 370)):
            self._add_button(suffix, top)

        self.label_orign = self._add_image_label("label_orign", 180)
        self.label_pre = self._add_image_label("label_pre", 660)
        self._add_button("opimg", 20)

        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

        # Initialise all state before any slot can run.
        # os.listdir() returns names in arbitrary order, so sort for a
        # reproducible image sequence.
        self.image_files = sorted(
            f for f in os.listdir(input_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.current_index = 0
        self.selected_path = None
        self.input_point = []
        self.input_label = []
        self.input_stop = False
        # One-entry cache so the 100 ms refresh does not re-read the file
        # from disk on every tick.
        self._cached_path = None
        self._cached_rgb = None

        self.pushButton_opimg.clicked.connect(self.open_image)
        self.pushButton_w.clicked.connect(self.predict_image)
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_original_image)
        self.timer.start(100)  # Update every 100 milliseconds

        cv2.namedWindow("image")
        cv2.setMouseCallback("image", self.mouse_click)

    def _add_button(self, suffix, top):
        """Create ``self.pushButton_<suffix>`` in the left column at vertical offset *top*."""
        name = "pushButton_" + suffix
        button = QtWidgets.QPushButton(self.centralwidget)
        button.setGeometry(QtCore.QRect(10, top, 151, 51))
        button.setObjectName(name)
        setattr(self, name, button)
        return button

    def _add_image_label(self, name, left):
        """Create a white 471x401 image label at horizontal offset *left* and return it."""
        label = QtWidgets.QLabel(self.centralwidget)
        label.setGeometry(QtCore.QRect(left, 20, 471, 401))
        label.setStyleSheet("background-color: rgb(255, 255, 255);")
        label.setObjectName(name)
        return label

    def retranslateUi(self, MainWindow):
        """Set the user-visible texts of the window, buttons and labels."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "w"))
        self.pushButton_a.setText(_translate("MainWindow", "a"))
        self.pushButton_d.setText(_translate("MainWindow", "d"))
        self.pushButton_s.setText(_translate("MainWindow", "s"))
        self.pushButton_q.setText(_translate("MainWindow", "q"))
        self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_pre.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
        self.pushButton_opimg.setText(_translate("MainWindow", "打开图像"))

    def open_image(self):
        """Let the user pick an image file and make it the current image.

        The chosen path is remembered in ``selected_path`` so that the
        periodic refresh and :meth:`predict_image` keep using it; previously
        collected points are discarded. Nothing happens if the dialog is
        cancelled.
        """
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, "Open Image File", "", "Image files (*.jpg *.png)")
        if not filename:
            return
        self.current_index = 0
        # Without remembering the path, the next timer tick would replace
        # the picture below with image_files[0] again.
        self.selected_path = filename
        pixmap = QPixmap(filename)
        pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
        self.label_orign.setPixmap(pixmap)
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
        self.input_point = []
        self.input_label = []

    def _current_image_path(self):
        """Return the path of the image being edited, or ``None`` if there is none."""
        if self.selected_path:
            return self.selected_path
        if self.current_index < len(self.image_files):
            return os.path.join(input_dir, self.image_files[self.current_index])
        return None

    def _render_current_image(self):
        """Return an RGB copy of the current image with the prompt points drawn on it.

        Returns ``None`` when there is no current image or it cannot be read
        (``cv2.imread`` yields ``None`` for missing/unsupported files).
        """
        path = self._current_image_path()
        if path is None:
            return None
        if path != self._cached_path:
            bgr = cv2.imread(path)
            # A failed read is cached as well, so an unreadable file is not
            # retried on every timer tick.
            self._cached_rgb = None if bgr is None else cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            self._cached_path = path
        if self._cached_rgb is None:
            return None
        canvas = self._cached_rgb.copy()
        for point, label in zip(self.input_point, self.input_label):
            # RGB order: label 1 -> green marker, any other label -> blue marker.
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(canvas, tuple(point), 5, color, -1)
        return canvas

    def predict_image(self):
        """Show the current image with its prompt points in the OpenCV "image" window."""
        canvas = self._render_current_image()
        if canvas is None:
            return
        # cv2.imshow interprets 3-channel data as BGR, so convert back from
        # RGB; otherwise red and blue appear swapped.
        cv2.imshow("image", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))

    def update_original_image(self):
        """Timer slot: redraw the current image and its prompt points in ``label_orign``."""
        canvas = self._render_current_image()
        if canvas is None:
            return
        height, width, _ = canvas.shape
        bytes_per_line = 3 * width
        # QImage only wraps the buffer; QPixmap.fromImage copies it while
        # ``canvas`` is still alive.
        q_img = QImage(canvas.data, width, height, bytes_per_line, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(q_img)
        pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
        self.label_orign.setPixmap(pixmap)
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)

    def mouse_click(self, event, x, y, flags, param):
        """OpenCV mouse callback: record a prompt point for left/right button presses.

        A left press stores ``[x, y]`` with label 1, a right press with
        label 0. While ``input_stop`` is set, presses are rejected and a
        notice is printed instead. All other events are ignored.
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            label = 1
        elif event == cv2.EVENT_RBUTTONDOWN:
            label = 0
        else:
            return
        if self.input_stop:
            print('此时不能添加点,按w退出mask选择模式')
            return
        self.input_point.append([x, y])
        self.input_label.append(label)
def apply_mask(image, mask, alpha_channel=True):
    """Keep only the pixels of *image* selected by *mask*.

    Args:
        image: array whose last axis holds at least three channels.
        mask: array matching the image's leading dimensions; entries equal
            to 1 mark the pixels to keep.
        alpha_channel: if true, return the first three channels plus a
            fourth channel that is 255 where ``mask == 1`` and 0 elsewhere;
            otherwise return the image with unselected pixels set to 0.

    Returns:
        The masked image as a new array.
    """
    if not alpha_channel:
        # Broadcast the mask over the channel axis and zero the rest.
        return np.where(mask[..., None] == 1, image, 0)

    first, second, third = image[..., 0], image[..., 1], image[..., 2]
    opacity = np.zeros_like(first)
    opacity[mask == 1] = 255
    return cv2.merge((first, second, third, opacity))
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Tint the pixels of *image* selected by *mask* with *color*, in place.

    For each of the first three channels, pixels where ``mask == 1`` become
    ``pixel * (1 - color_dark) + color_dark * color[channel]``; all other
    pixels are left as they are.

    Args:
        image: array indexed as ``[row, col, channel]``; it is modified.
        mask: array compared against 1 to select pixels.
        color: sequence with one value per channel (at least three).
        color_dark: blend weight of *color*.

    Returns:
        The same *image* object that was passed in.
    """
    selected = mask == 1
    keep = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep + color_dark * color[channel]
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
def get_next_filename(base_path, filename, max_index=100):
    """Return the first unused ``<name>_<i><ext>`` variant of *filename*.

    Candidates are tried for ``i = 1 .. max_index`` and the first one that
    does not yet exist inside *base_path* is returned.

    Args:
        base_path: directory in which the candidates are checked.
        filename: file name whose stem and extension are reused.
        max_index: highest suffix number to try (defaults to 100).

    Returns:
        The free file name (without directory), or ``None`` if every
        candidate up to *max_index* already exists.
    """
    name, ext = os.path.splitext(filename)
    for i in range(1, max_index + 1):
        candidate = f"{name}_{i}{ext}"
        if not os.path.exists(os.path.join(base_path, candidate)):
            return candidate
    return None
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write the masked part of *image* as a PNG into *output_dir*.

    The result is produced by ``apply_mask`` and stored under the first free
    name returned by ``get_next_filename`` for ``<stem>.png``. The outcome is
    reported on stdout; nothing is returned.

    Args:
        image: image array passed to ``apply_mask``.
        mask: 2-D array; non-zero entries mark the region to keep.
        output_dir: destination directory (created if it does not exist).
        filename: original file name; only its stem is reused.
        crop_mode_: if true, image and mask are first cropped to the
            bounding box of the mask's non-zero entries.
    """
    if crop_mode_:
        y, x = np.where(mask)
        if y.size == 0:
            # min()/max() of an empty selection would raise ValueError.
            print("Could not save the image. The mask is empty.")
            return
        y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext also copes with names that contain no dot, where slicing at
    # rfind('.') == -1 would chop off the last character.
    filename = os.path.splitext(filename)[0] + '.png'
    # cv2.imwrite does not create missing directories; it just fails.
    os.makedirs(output_dir, exist_ok=True)
    new_filename = get_next_filename(output_dir, filename)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    target = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        written = cv2.imwrite(target, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        written = cv2.imwrite(target, masked_image)
    # imwrite signals failure through its return value, not an exception.
    if written:
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save the image. Writing {target} failed.")
if __name__ == "__main__":
    # Script entry point: build the Qt application, attach the generated UI
    # to a fresh main window, and hand control to the Qt event loop.
    qt_app = QtWidgets.QApplication(sys.argv)
    window = QtWidgets.QMainWindow()
    ui = Ui_MainWindow()
    ui.setupUi(window)
    window.show()
    # Propagate the event loop's exit status to the shell.
    exit_code = qt_app.exec_()
    sys.exit(exit_code)