SAM/salt/suibian.py

454 lines
21 KiB
Python
Raw Permalink Normal View History

2024-06-19 08:51:04 +08:00
import os
import sys
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout for the annotation tool's main window (Qt Designer style).

    ``setupUi`` builds every widget on the given ``QMainWindow`` and
    ``retranslateUi`` fills in their user-visible texts.
    """

    def _ui_button(self, x, y, width, height, object_name):
        """Create a push button on the central widget with the given geometry."""
        button = QtWidgets.QPushButton(self.centralwidget)
        button.setGeometry(QtCore.QRect(x, y, width, height))
        button.setObjectName(object_name)
        return button

    def _ui_white_label(self, x, y, width, height, object_name):
        """Create a white-background label on the central widget."""
        label = QtWidgets.QLabel(self.centralwidget)
        label.setGeometry(QtCore.QRect(x, y, width, height))
        label.setStyleSheet("background-color: rgb(255, 255, 255);")
        label.setObjectName(object_name)
        return label

    def setupUi(self, MainWindow):
        """Create all widgets of the fixed-size 1140x450 main window."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Minimum == maximum, so the window cannot be resized.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central  # must exist before the factory helpers run

        # Widgets are created in the original order; creation order decides
        # the default keyboard tab order.
        self.pushButton_w = self._ui_button(10, 90, 151, 51, "pushButton_w")
        self.pushButton_a = self._ui_button(10, 160, 71, 51, "pushButton_a")
        self.pushButton_d = self._ui_button(90, 160, 71, 51, "pushButton_d")
        self.pushButton_s = self._ui_button(10, 360, 151, 51, "pushButton_s")
        self.pushButton_5 = self._ui_button(10, 230, 151, 51, "pushButton_5")
        self.label_orign = self._ui_white_label(180, 20, 471, 401, "label_orign")
        self.label_2 = self._ui_white_label(660, 20, 471, 401, "label_2")
        self.pushButton_w_2 = self._ui_button(10, 20, 151, 51, "pushButton_w_2")

        line_edit = QtWidgets.QLineEdit(central)
        line_edit.setGeometry(QtCore.QRect(50, 290, 81, 21))
        line_edit.setObjectName("lineEdit")
        self.lineEdit = line_edit

        slider = QtWidgets.QSlider(central)
        slider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        slider.setRange(0, 10)  # slider positions 0..10
        slider.setSingleStep(1)
        slider.setValue(0)  # start at 0
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setTickInterval(0)
        slider.setObjectName("horizontalSlider")
        self.horizontalSlider = slider

        MainWindow.setCentralWidget(central)

        menubar = QtWidgets.QMenuBar(MainWindow)
        menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        menubar.setObjectName("menubar")
        self.menubar = menubar
        MainWindow.setMenuBar(menubar)

        statusbar = QtWidgets.QStatusBar(MainWindow)
        statusbar.setObjectName("statusbar")
        self.statusbar = statusbar
        MainWindow.setStatusBar(statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the (translatable) texts of the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        # Literal _translate(...) calls are kept so translation tooling can
        # still extract the strings.
        texts = (
            (self.pushButton_w, _translate("MainWindow", "Predict")),
            (self.pushButton_a, _translate("MainWindow", "Pre")),
            (self.pushButton_d, _translate("MainWindow", "Next")),
            (self.pushButton_s, _translate("MainWindow", "Save")),
            # "background image"
            (self.pushButton_5, _translate("MainWindow", "背景图")),
            # "original image"
            (self.label_orign, _translate("MainWindow", "<html><head/><body><p align=\"center\">原始图像</p></body></html>")),
            # "predicted image"
            (self.label_2, _translate("MainWindow", "<html><head/><body><p align=\"center\">预测图像</p></body></html>")),
            (self.pushButton_w_2, _translate("MainWindow", "Openimg")),
            # "change mask size"
            (self.lineEdit, _translate("MainWindow", "改变mask大小")),
        )
        for widget, text in texts:
            widget.setText(text)
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Main window of the SAM (segment_anything) point-prompt annotation tool.

    Workflow: "Openimg" picks an image folder and opens a thumbnail dialog.
    Choosing an image shows it in the left label and starts an OpenCV session:
    points are clicked, SAM predicts masks, and the chosen mask is saved as a
    transparent PNG cut-out under ``OUTPUT_DIR``. The last saved cut-out is
    shown in the right label. That label can be composited over a background
    image and rescaled with the slider.
    """

    # Configuration. These are class attributes so they can be overridden
    # (e.g. in a subclass) without editing the methods. The defaults are the
    # values that used to be hard-coded inside call_opencv_interaction.
    SAM_MODEL_TYPE = "vit_b"
    SAM_CHECKPOINT = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
    SAM_DEVICE = "cuda"
    OUTPUT_DIR = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
    CROP_MODE = True  # save only the bounding box of the mask
    MAX_SAVE_VARIANTS = 100  # <name>_1.png ... <name>_100.png per source image
    PREDICTION_SAVE_PATH = "prediction_result.png"  # relative to the working directory
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')  # matched case-insensitively
    WINDOW_WIDTH = 1280  # initial size of the OpenCV windows
    WINDOW_HEIGHT = 720
    SLIDER_STEP_SCALE = 0.9  # display scale factor applied per slider step

    def __init__(self):
        super().__init__()
        self.setupUi(self)
        self.k = 0  # not used by this class; kept for compatibility
        self.last_value = 0  # previous slider value, used to compute the step delta
        self.image_path = ""
        self.image_folder = ""
        self.image_files = []
        self.current_index = 0
        # True while the "Prediction" window is active. Clicks on the "image"
        # window are rejected then.
        self.input_stop = False
        self._predictor = None  # SamPredictor, created lazily and reused
        self.pushButton_w_2.clicked.connect(self.open_image_folder)
        self.pushButton_a.clicked.connect(self.load_previous_image)
        self.pushButton_d.clicked.connect(self.load_next_image)
        self.pushButton_s.clicked.connect(self.save_prediction)
        self.pushButton_5.clicked.connect(self.select_background_image)
        self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size)

    # ------------------------------------------------------------------ Qt side

    def adjust_pixmap_size(self, pixmap, scale_factor):
        """Return ``pixmap`` scaled to ``scale_factor`` percent, keeping aspect ratio."""
        size = pixmap.size()
        # QSize needs ints; passing floats raises TypeError on recent PyQt5.
        scaled_size = QtCore.QSize(round(size.width() * scale_factor / 100),
                                   round(size.height() * scale_factor / 100))
        return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)

    def open_image_folder(self):
        """Ask for an image folder, list its images and open the selection dialog."""
        folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, 'Open Image Folder', '')
        if not folder_path:
            return
        self.image_folder = folder_path
        self.image_files = self.get_image_files(folder_path)
        # The previous index may be out of range for the new folder.
        self.current_index = 0
        if self.image_files:
            self.show_image_selection_dialog()
        else:
            self.statusbar.showMessage("No images found in the selected folder.", 3000)

    def load_previous_image(self):
        """Show the previous image, wrapping around to the last one."""
        if self.image_files:
            self.current_index = (self.current_index - 1) % len(self.image_files)
            self.show_image()

    def load_next_image(self):
        """Show the next image, wrapping around to the first one."""
        if self.image_files:
            self.current_index = (self.current_index + 1) % len(self.image_files)
            self.show_image()

    def get_image_files(self, folder_path):
        """Return the sorted names of image files directly inside ``folder_path``."""
        return sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def show_image_selection_dialog(self):
        """Show a thumbnail list of ``self.image_files``. On OK or double-click, open the image."""
        dialog = QtWidgets.QDialog(self)
        dialog.setWindowTitle("Select Image")
        layout = QtWidgets.QVBoxLayout(dialog)
        self.listWidget = QtWidgets.QListWidget()
        for image_file in self.image_files:
            item = QtWidgets.QListWidgetItem(image_file)
            thumbnail = QtGui.QPixmap(os.path.join(self.image_folder, image_file))
            if not thumbnail.isNull():
                thumbnail = thumbnail.scaledToWidth(100)
            item.setIcon(QtGui.QIcon(thumbnail))
            self.listWidget.addItem(item)
        if 0 <= self.current_index < len(self.image_files):
            self.listWidget.setCurrentRow(self.current_index)
        # Both OK and double-click close the dialog first. The selection is
        # then handled below, so the dialog no longer stays open behind the
        # OpenCV session.
        self.listWidget.itemDoubleClicked.connect(dialog.accept)
        layout.addWidget(self.listWidget)
        button_box = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
        button_box.accepted.connect(dialog.accept)
        button_box.rejected.connect(dialog.reject)
        layout.addWidget(button_box)
        if dialog.exec_() == QtWidgets.QDialog.Accepted:
            self.image_selected()

    def image_selected(self):
        """Make the highlighted list entry current, show it, and start annotating it."""
        if self.listWidget.currentItem() is None:
            return
        row = self.listWidget.currentRow()
        if not 0 <= row < len(self.image_files):
            return
        self.current_index = row
        self.show_image()
        self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[row]))

    def select_background_image(self):
        """Ask for an image file and use it as background of the prediction area."""
        image_path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select Background Image', '', 'Image Files (*.png *.jpg *.jpeg *.bmp)')
        if image_path:
            self.show_background_image(image_path)

    def show_background_image(self, image_path):
        """Draw the image at ``image_path`` behind the pixmap shown in ``label_2``.

        The background is first scaled to fit the label (keeping aspect ratio).
        The current pixmap, if any, is then drawn on top of it at the top-left
        corner. Without a current pixmap, the scaled background is shown alone.
        """
        background = QtGui.QPixmap(image_path)
        if background.isNull():
            self.statusbar.showMessage(f"Could not load {image_path}", 3000)
            return
        background = background.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio)
        foreground = self.label_2.pixmap()  # may be None when the label shows text
        if foreground is None or foreground.isNull():
            self.label_2.setPixmap(background)
            return
        merged = QtGui.QPixmap(background.size().expandedTo(foreground.size()))
        merged.fill(QtCore.Qt.transparent)
        painter = QtGui.QPainter(merged)
        try:
            painter.drawPixmap(0, 0, background)
            painter.drawPixmap(0, 0, foreground)
        finally:
            painter.end()
        self.label_2.setPixmap(merged)

    def show_image(self):
        """Show ``self.image_files[self.current_index]`` in the left label."""
        if self.image_files and self.current_index < len(self.image_files):
            file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
            pixmap = QtGui.QPixmap(file_path)
            self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))

    def save_prediction(self):
        """Save the pixmap currently shown in the prediction area to PREDICTION_SAVE_PATH.

        This saves what is displayed, i.e. after any slider rescaling or
        background compositing. It does not save the full-resolution cut-out
        written by the annotation loop.
        """
        pixmap = self.label_2.pixmap()
        if pixmap is None or pixmap.isNull():
            self.statusbar.showMessage("Nothing to save: the prediction area is empty.", 3000)
            return
        save_path = self.PREDICTION_SAVE_PATH
        if pixmap.toImage().save(save_path):
            self.statusbar.showMessage(f"Saved prediction to {os.path.abspath(save_path)}", 3000)
        else:
            self.statusbar.showMessage(f"Failed to save prediction to {save_path}", 3000)

    def adjust_prediction_size(self, value):
        """Rescale the prediction-area pixmap when the slider moves to ``value``.

        Moving the slider up by one step shrinks the displayed pixmap by
        ``SLIDER_STEP_SCALE``. Moving it down enlarges it by the inverse
        factor. A round trip therefore returns to the starting size (up to
        pixel rounding), and a large jump can never collapse the pixmap to
        zero size.

        NOTE: the pixmap currently on screen is the one rescaled, so repeated
        shrinking and enlarging loses detail.
        """
        steps = value - self.last_value
        self.last_value = value  # track the slider even when nothing is scaled
        pixmap = self.label_2.pixmap()  # None when the label only shows text
        if steps == 0 or pixmap is None or pixmap.isNull():
            return
        scale = self.SLIDER_STEP_SCALE ** steps
        size = pixmap.size()
        target = QtCore.QSize(max(1, round(size.width() * scale)),
                              max(1, round(size.height() * scale)))
        self.label_2.setPixmap(pixmap.scaled(target, QtCore.Qt.KeepAspectRatio,
                                             QtCore.Qt.SmoothTransformation))

    # -------------------------------------------------------------- OpenCV side

    def call_opencv_interaction(self, image_path):
        """Annotate images interactively with SAM in an OpenCV window.

        Keys in the "image" window:
            left / right click  add a foreground (1) / background (0) point
            w                   predict masks and open the "Prediction" window
            q                   remove the last point
            space               clear the points and the selected mask
            s                   save the selected mask
            a / d               previous / next image of ``self.image_files``
            Esc                 close the OpenCV windows and return

        ``image_path`` only selects the folder images are read from. The image
        shown is always ``self.image_files[self.current_index]``, so a/d
        navigation really changes the image.

        This call blocks the Qt event loop until the session ends.
        """
        if not self.image_files:
            return
        input_dir = os.path.dirname(image_path)
        print('最好是每加一个点就按w键predict一次')
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        predictor = self._get_predictor()

        input_point = []  # [[x, y], ...] in image pixel coordinates
        input_label = []  # parallel to input_point: 1 = foreground, 0 = background

        def mouse_click(event, x, y, flags, param):
            if event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
                return
            if self.input_stop:
                print('此时不能添加点,按w退出mask选择模式')
                return
            input_point.append([x, y])
            input_label.append(1 if event == cv2.EVENT_LBUTTONDOWN else 0)

        self.input_stop = False
        self._create_centered_window("image")
        cv2.setMouseCallback("image", mouse_click)
        try:
            while self._annotate_current_image(predictor, input_dir, input_point, input_label):
                pass
        finally:
            self.input_stop = False
            cv2.destroyAllWindows()

    def _annotate_current_image(self, predictor, input_dir, input_point, input_label):
        """Run the key loop for the current image.

        Returns True to continue with a (new) current image and False to quit.
        """
        filename = self.image_files[self.current_index]
        image_bgr = self._read_image(os.path.join(input_dir, filename))
        if image_bgr is None:
            print(f"Could not read image: {filename}")
            return False
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        embedded = False  # predictor.set_image is expensive: run it once per image
        selected_mask = None
        logit_input = None  # logits of the last chosen mask, fed back as mask_input
        mask_color = None
        while True:
            display = image_bgr.copy()
            cv2.putText(display, f'{filename} ', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (0, 255, 255), 2, cv2.LINE_AA)
            for point, label in zip(input_point, input_label):
                color = (0, 255, 0) if label == 1 else (0, 0, 255)
                cv2.circle(display, tuple(point), 5, color, -1)
            if selected_mask is not None:
                self._apply_color_mask(display, selected_mask, mask_color)
            cv2.imshow("image", display)
            key = cv2.waitKey(1)

            # Closing with the X button would otherwise make imshow recreate
            # the window without the mouse callback.
            if key == 27 or self._window_closed("image"):
                return False
            if key == ord(" "):
                input_point.clear()
                input_label.clear()
                selected_mask = None
                logit_input = None
            elif key == ord("w") and input_point:
                if not embedded:
                    predictor.set_image(image_rgb)
                    embedded = True
                choice = self._choose_mask(predictor, image_bgr, filename,
                                           input_point, input_label, logit_input)
                if choice is None:  # reset with space in the prediction window
                    selected_mask, logit_input = None, None
                else:
                    selected_mask, logit_input = choice
                    mask_color = self._random_color()  # fixed, so no flicker
            elif key in (ord("a"), ord("d")):
                step = -1 if key == ord("a") else 1
                self.current_index = min(max(self.current_index + step, 0),
                                         len(self.image_files) - 1)
                input_point.clear()
                input_label.clear()
                self.show_image()
                self.label_orign.repaint()  # the Qt event loop is blocked here
                return True
            elif key == ord("q") and input_point:
                input_point.pop()
                input_label.pop()
            elif key == ord("s") and selected_mask is not None:
                self._save_masked_image(image_bgr, selected_mask, filename)

    def _choose_mask(self, predictor, image_bgr, filename, input_point, input_label, logit_input):
        """Predict masks for the current points and let the user pick one.

        Returns ``(mask, logits)`` for the mask chosen with ``w`` (or by
        closing the window). Returns None when the user presses space, which
        also clears the points in place. Clicks on the "image" window are
        blocked while this runs.
        """
        masks, scores, logits = predictor.predict(
            point_coords=np.array(input_point),
            point_labels=np.array(input_label),
            mask_input=logit_input[None, :, :] if logit_input is not None else None,
            multimask_output=True,
        )
        window = "Prediction"
        num_masks = len(masks)
        mask_idx = 0
        color = self._random_color()  # fixed for the session so it does not flicker
        self.input_stop = True
        self._create_centered_window(window)
        try:
            while True:
                shown = self._apply_color_mask(image_bgr.copy(), masks[mask_idx], color)
                # ASCII only: cv2.putText's Hershey fonts cannot draw CJK text.
                info = (f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f}'
                        ' | a/d: switch | s: save | q: undo | w: done | space: reset')
                cv2.putText(shown, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 255, 255), 2, cv2.LINE_AA)
                cv2.imshow(window, shown)
                key = cv2.waitKey(10)
                if key == ord('w') or self._window_closed(window):
                    print('max score:', np.argmax(scores), ' select:', mask_idx)
                    return masks[mask_idx], logits[mask_idx, :, :]
                if key == ord(' '):
                    input_point.clear()
                    input_label.clear()
                    return None
                if key == ord('q') and input_point:
                    # Pop both lists so points and labels stay the same length.
                    input_point.pop()
                    input_label.pop()
                elif key == ord('s'):
                    self._save_masked_image(image_bgr, masks[mask_idx], filename)
                elif key == ord('a'):
                    mask_idx = (mask_idx - 1) % num_masks
                elif key == ord('d'):
                    mask_idx = (mask_idx + 1) % num_masks
        finally:
            # Always re-enable point input, however this window is left.
            self.input_stop = False
            self._destroy_window(window)

    def _save_masked_image(self, image, mask, filename, crop_mode=None):
        """Save ``image`` with ``mask`` as alpha into OUTPUT_DIR and show it in label_2.

        Returns the written path, or None if nothing was saved.
        """
        if crop_mode is None:
            crop_mode = self.CROP_MODE
        if crop_mode:
            ys, xs = np.nonzero(mask)
            if ys.size == 0:
                print("Mask is empty; nothing to save.")
                return None
            y0, y1, x0, x1 = ys.min(), ys.max() + 1, xs.min(), xs.max() + 1
            image, mask = image[y0:y1, x0:x1], mask[y0:y1, x0:x1]
        masked_image = self._apply_mask(image, mask)
        stem = os.path.splitext(filename)[0]
        new_filename = self._get_next_filename(self.OUTPUT_DIR, stem + '.png',
                                               self.MAX_SAVE_VARIANTS)
        if new_filename is None:
            print("Could not save the image. Too many variations exist.")
            return None
        out_path = os.path.join(self.OUTPUT_DIR, new_filename)
        # imencode + tofile instead of imwrite, so non-ASCII paths work on Windows.
        ok, encoded = cv2.imencode('.png', masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
        if not ok:
            print(f"Could not encode {new_filename}")
            return None
        encoded.tofile(out_path)
        print(f"Saved as {new_filename}")
        saved_pixmap = QtGui.QPixmap(out_path)
        self.label_2.setPixmap(saved_pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
        self.label_2.repaint()  # the Qt event loop is blocked by the OpenCV loop
        return out_path

    def _get_predictor(self):
        """Return the SamPredictor, loading the checkpoint on first use only."""
        if getattr(self, "_predictor", None) is None:
            sam = sam_model_registry[self.SAM_MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT)
            sam.to(device=self.SAM_DEVICE)
            self._predictor = SamPredictor(sam)
        return self._predictor

    def _create_centered_window(self, name):
        """Create a resizable OpenCV window centred on the primary screen."""
        cv2.namedWindow(name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(name, self.WINDOW_WIDTH, self.WINDOW_HEIGHT)
        screen = QtWidgets.QApplication.primaryScreen()
        if screen is not None:
            screen_w, screen_h = screen.geometry().width(), screen.geometry().height()
        else:
            screen_w, screen_h = 1920, 1080
        cv2.moveWindow(name, max(0, (screen_w - self.WINDOW_WIDTH) // 2),
                       max(0, (screen_h - self.WINDOW_HEIGHT) // 2))

    @staticmethod
    def _window_closed(name):
        """Return True if the OpenCV window ``name`` is no longer visible."""
        try:
            return cv2.getWindowProperty(name, cv2.WND_PROP_VISIBLE) < 1
        except cv2.error:
            return True

    @staticmethod
    def _destroy_window(name):
        """Destroy an OpenCV window, ignoring the error if it is already gone."""
        try:
            cv2.destroyWindow(name)
        except cv2.error:
            pass

    @staticmethod
    def _read_image(path):
        """Read a BGR image, or return None if it cannot be read.

        Uses np.fromfile + cv2.imdecode because cv2.imread can fail on
        non-ASCII paths on Windows.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    @staticmethod
    def _random_color():
        """Return a random (B, G, R) tuple of ints in 0..255."""
        return tuple(np.random.randint(0, 256, 3).tolist())

    @staticmethod
    def _apply_mask(image, mask, alpha_channel=True):
        """Return ``image`` with pixels outside ``mask`` hidden.

        With ``alpha_channel`` the result gets a 4th channel that is 255
        inside the mask and 0 outside. Otherwise outside pixels are zeroed.
        """
        if alpha_channel:
            alpha = np.zeros_like(image[..., 0])
            alpha[mask == 1] = 255
            return cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
        return np.where(mask[..., None] == 1, image, 0)

    @staticmethod
    def _apply_color_mask(image, mask, color, color_dark=0.5):
        """Blend ``color`` into the first 3 channels of ``image`` where ``mask`` is set.

        Modifies ``image`` in place and also returns it.
        """
        selected = mask == 1
        blended = image[selected, :3] * (1 - color_dark) + color_dark * np.asarray(color, dtype=float)
        image[selected, :3] = blended
        return image

    @staticmethod
    def _get_next_filename(base_path, filename, limit=100):
        """Return the first ``<name>_<i><ext>`` (i = 1..limit) not existing in base_path, else None."""
        name, ext = os.path.splitext(filename)
        for i in range(1, limit + 1):
            candidate = f"{name}_{i}{ext}"
            if not os.path.exists(os.path.join(base_path, candidate)):
                return candidate
        return None
if __name__ == "__main__":
    # Start Qt, show the main window and run the event loop. The loop's return
    # code becomes the process exit status. The window is bound to the
    # module-level name ``mainWindow``, which other code in this module may
    # refer to.
    qt_app = QtWidgets.QApplication(sys.argv)
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)