# 201 lines · 8.9 KiB · Python
|
import sys
|
||
|
from PyQt5 import QtCore, QtGui, QtWidgets
|
||
|
from PyQt5.QtGui import QPixmap, QImage
|
||
|
from PyQt5.QtCore import QTimer
|
||
|
import cv2
|
||
|
import os
|
||
|
import numpy as np
|
||
|
from segment_anything import sam_model_registry, SamPredictor
|
||
|
|
||
|
# --- Paths and options --------------------------------------------------------
# NOTE(review): machine-specific absolute Windows paths; adjust for your setup.
input_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\input\images'
output_dir = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
# Not referenced elsewhere in this file; presumably meant for the `crop_mode_`
# argument of save_masked_image — confirm with the caller.
crop_mode = True

# --- SAM model ------------------------------------------------------------------
import torch  # noqa: E402  (needed only to pick the device below)

sam = sam_model_registry["vit_b"](checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth")
# Fall back to the CPU instead of crashing at startup when no CUDA device exists.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=DEVICE)
predictor = SamPredictor(sam)

# --- Window geometry ----------------------------------------------------------
# Not referenced elsewhere in this file; the main window size is set in
# Ui_MainWindow.setupUi.
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
|
||
|
|
||
|
class Ui_MainWindow(object):
    """Main window for picking SAM prompt points on an image.

    ``label_orign`` shows the current image with its prompt points; a QTimer
    refreshes it every 100 ms. Clicking the "w" button shows the same image
    in the OpenCV window ``"image"``. In that window, a left click adds a
    positive point (label 1) and a right click adds a negative point
    (label 0).

    The current image is the file chosen via "open image" if one was chosen.
    Otherwise it is ``image_files[current_index]`` from ``input_dir``.

    NOTE(review): in this file the "a", "d", "s" and "q" buttons are not
    connected to any slot, and ``label_pre`` is never filled. ``input_stop``
    is never set to True, and the SAM ``predictor`` is not called.
    """

    # Prompt-point colors in OpenCV's BGR channel order.
    POSITIVE_COLOR = (0, 255, 0)  # green: label 1 (left click)
    NEGATIVE_COLOR = (0, 0, 255)  # red: label 0 (right click)
    # Compared against lower-cased file names.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def setupUi(self, MainWindow):
        """Create the widgets on ``MainWindow``, wire signals and start the refresh timer."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
        self.pushButton_w.setObjectName("pushButton_w")
        self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51))
        self.pushButton_a.setObjectName("pushButton_a")
        self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51))
        self.pushButton_d.setObjectName("pushButton_d")
        self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51))
        self.pushButton_s.setObjectName("pushButton_s")
        self.pushButton_q = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51))
        self.pushButton_q.setObjectName("pushButton_q")
        self.label_orign = QtWidgets.QLabel(self.centralwidget)
        self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
        self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_orign.setObjectName("label_orign")
        self.label_pre = QtWidgets.QLabel(self.centralwidget)
        self.label_pre.setGeometry(QtCore.QRect(660, 20, 471, 401))
        self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_pre.setObjectName("label_pre")
        self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
        self.pushButton_opimg.setObjectName("pushButton_opimg")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

        self.pushButton_opimg.clicked.connect(self.open_image)
        self.pushButton_w.clicked.connect(self.predict_image)

        # Annotation state. It is initialised before the timer starts because
        # update_original_image reads it.
        # Sorted so the image order does not depend on the file system.
        self.image_files = sorted(
            f for f in os.listdir(input_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.current_index = 0
        self.current_path = None  # file picked in open_image; overrides image_files
        self.input_point = []  # [x, y] pixel coordinates from the "image" window
        self.input_label = []  # 1 = positive (left click), 0 = negative (right click)
        self.input_stop = False
        # Decoded image cache, so the 100 ms refresh does not re-read the disk.
        self._cached_path = None
        self._cached_image = None

        cv2.namedWindow("image")
        cv2.setMouseCallback("image", self.mouse_click)

        self.timer = QTimer(MainWindow)
        self.timer.timeout.connect(self.update_original_image)
        self.timer.start(100)  # Update every 100 milliseconds

    def retranslateUi(self, MainWindow):
        """Set the user-visible texts (original image / predicted image / open image)."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "w"))
        self.pushButton_a.setText(_translate("MainWindow", "a"))
        self.pushButton_d.setText(_translate("MainWindow", "d"))
        self.pushButton_s.setText(_translate("MainWindow", "s"))
        self.pushButton_q.setText(_translate("MainWindow", "q"))
        self.label_orign.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>"))
        self.label_pre.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>"))
        self.pushButton_opimg.setText(_translate("MainWindow", "打开图像"))

    def open_image(self):
        """Let the user pick an image file and make it the image being annotated.

        Clears the prompt points. The chosen path is stored in
        ``current_path``, so the periodic refresh keeps showing this file
        rather than switching back to ``image_files[current_index]``.
        """
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, "Open Image File", "", "Image files (*.jpg *.png)")
        if not filename:
            return
        self.current_index = 0
        self.current_path = filename
        self.input_point = []
        self.input_label = []
        pixmap = QPixmap(filename)
        pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
        self.label_orign.setPixmap(pixmap)
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)

    def _current_image_path(self):
        """Return the path of the image being annotated, or None if there is none."""
        if self.current_path:
            return self.current_path
        if self.current_index < len(self.image_files):
            return os.path.join(input_dir, self.image_files[self.current_index])
        return None

    def _load_current_image(self):
        """Return the current image as a BGR array, or None if it is missing or unreadable."""
        path = self._current_image_path()
        if path is None:
            return None
        if path != self._cached_path:
            self._cached_path = path
            self._cached_image = cv2.imread(path)  # None when the file cannot be decoded
            if self._cached_image is None:
                print(f"Could not read image: {path}")
        return self._cached_image

    def _render_annotated(self):
        """Return a BGR copy of the current image with the prompt points drawn on it, or None."""
        image = self._load_current_image()
        if image is None:
            return None
        annotated = image.copy()  # keep the cached original clean
        for point, label in zip(self.input_point, self.input_label):
            color = self.POSITIVE_COLOR if label == 1 else self.NEGATIVE_COLOR
            cv2.circle(annotated, tuple(point), 5, color, -1)
        return annotated

    def predict_image(self):
        """Show the current image and its prompt points in the OpenCV "image" window.

        Clicks in that window are handled by ``mouse_click``.
        NOTE(review): despite the name, this does not run the SAM predictor.
        """
        annotated = self._render_annotated()
        if annotated is None:
            return
        # cv2.imshow expects BGR data, so show the BGR rendering directly.
        cv2.imshow("image", annotated)
        cv2.waitKey(1)  # give HighGUI a chance to process the redraw

    def update_original_image(self):
        """Timer slot: redraw ``label_orign`` with the current image and its prompt points."""
        annotated = self._render_annotated()
        if annotated is None:
            return
        rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)  # Format_RGB888 expects RGB order
        height, width, _ = rgb.shape
        q_image = QImage(rgb.data, width, height, 3 * width, QImage.Format_RGB888)
        # QPixmap.fromImage converts into a new pixmap, so `rgb` need not outlive this call.
        pixmap = QPixmap.fromImage(q_image)
        pixmap = pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio)
        self.label_orign.setPixmap(pixmap)
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)

    def mouse_click(self, event, x, y, flags, param):
        """OpenCV mouse callback for the "image" window.

        A left button press records (x, y) with label 1 and a right button
        press records it with label 0. When ``input_stop`` is set, clicks are
        rejected with a message instead. Other events are ignored.
        """
        if event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
            return
        if self.input_stop:
            # "Points cannot be added now; press w to exit mask-selection mode."
            print('此时不能添加点,按w退出mask选择模式')
            return
        self.input_point.append([x, y])
        self.input_label.append(1 if event == cv2.EVENT_LBUTTONDOWN else 0)
|
||
|
|
||
|
def apply_mask(image, mask, alpha_channel=True):
    """Cut ``image`` out by ``mask``.

    Args:
        image: H x W x C array. Only channels 0-2 are kept in the alpha output.
        mask: H x W array. Any nonzero / True value marks foreground. This
            matches the nonzero bounding box used when cropping in
            save_masked_image, so 0/255 masks work as well as 0/1 or boolean
            masks.
        alpha_channel: if True, return an H x W x 4 array of the first three
            channels plus an alpha plane (255 on foreground, 0 elsewhere).
            Otherwise return ``image`` with background pixels set to 0.

    Returns:
        A new array; ``image`` is not modified.
    """
    foreground = np.asarray(mask).astype(bool)
    if alpha_channel:
        alpha = np.zeros_like(image[..., 0])
        alpha[foreground] = 255
        return np.dstack((image[..., 0], image[..., 1], image[..., 2], alpha))
    return np.where(foreground[..., None], image, 0)
|
||
|
|
||
|
|
||
|
def apply_color_mask(image, mask, color, color_dark=0.5):
    """Tint the masked pixels of ``image`` towards ``color``, modifying ``image`` in place.

    For channels 0-2, every pixel where ``mask == 1`` becomes
    ``value * (1 - color_dark) + color_dark * color[channel]``. All other
    pixels keep their value. The blended values are cast back to
    ``image``'s dtype when assigned. Returns the same ``image`` object.
    """
    keep_ratio = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep_ratio + color_dark * color[channel]
        image[:, :, channel] = np.where(mask == 1, tinted, plane)
    return image
|
||
|
|
||
|
|
||
|
def get_next_filename(base_path, filename, max_index=100):
    """Return the first unused ``<stem>_<i><ext>`` name in ``base_path``.

    Tries i = 1 .. ``max_index`` in order and returns the first candidate
    that does not exist in ``base_path``. The un-suffixed ``filename`` is
    never returned, even if it is free.

    Args:
        base_path: directory in which to check for existing files.
        filename: base file name, e.g. ``"img.png"``.
        max_index: largest suffix to try (default 100, as before).

    Returns:
        The new file name (without directory), or None if every candidate
        is already taken.
    """
    stem, ext = os.path.splitext(filename)
    for index in range(1, max_index + 1):
        candidate = f"{stem}_{index}{ext}"
        if not os.path.exists(os.path.join(base_path, candidate)):
            return candidate
    return None
|
||
|
|
||
|
|
||
|
def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Write ``image`` cut out by ``mask`` as a PNG into ``output_dir``.

    Args:
        image: H x W x C image array, as passed on to apply_mask. Presumably
            BGR from cv2.imread; confirm at the call site.
        mask: H x W mask. When cropping, the bounding box covers its nonzero
            pixels.
        output_dir: destination directory; it is created if missing.
        filename: source file name. Its extension is replaced by ``.png``, and
            get_next_filename appends a free ``_<n>`` suffix.
        crop_mode_: if true, crop to the mask's bounding box before masking.

    Prints the outcome. Nothing is written when the mask is empty in crop
    mode, when no free file name exists, or when cv2.imwrite fails.
    """
    if crop_mode_:
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            # An empty mask has no bounding box; min() of an empty array would raise.
            print("Mask is empty; nothing to save.")
            return
        y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext copes with names that have no dot. The previous
    # filename[:filename.rfind('.')] dropped the last character in that case.
    filename = os.path.splitext(filename)[0] + '.png'
    # Without the directory, cv2.imwrite fails (returns False) instead of raising.
    os.makedirs(output_dir, exist_ok=True)
    new_filename = get_next_filename(output_dir, filename)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    path = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        written = cv2.imwrite(path, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        written = cv2.imwrite(path, masked_image)
    if written:
        print(f"Saved as {new_filename}")
    else:
        print(f"Failed to write {path}")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: create the Qt application, build the generated UI
    # on a plain QMainWindow, show it, then run the Qt event loop and exit
    # the process with the loop's return code.
    app = QtWidgets.QApplication(sys.argv)
    MainWindow = QtWidgets.QMainWindow()
    ui = Ui_MainWindow()
    ui.setupUi(MainWindow)
    MainWindow.show()
    exit_status = app.exec_()
    raise SystemExit(exit_status)
|
||
|
|