SAM/salt/suibian.py

454 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Widget layout for the fixed-size (1140x450) main window.

    ``setupUi`` creates every child widget and stores it as an attribute
    named after its Qt object name; ``retranslateUi`` assigns the captions.
    """

    def setupUi(self, MainWindow):
        """Build all child widgets of *MainWindow* and attach them to ``self``."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        # Identical minimum and maximum sizes pin the window to one size.
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central

        # Left-hand column of action buttons: (attribute name, x, y, w, h).
        button_specs = (
            ("pushButton_w", 10, 90, 151, 51),
            ("pushButton_a", 10, 160, 71, 51),
            ("pushButton_d", 90, 160, 71, 51),
            ("pushButton_s", 10, 360, 151, 51),
            ("pushButton_5", 10, 230, 151, 51),
        )
        for name, x, y, width, height in button_specs:
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(x, y, width, height))
            button.setObjectName(name)
            setattr(self, name, button)

        # Two white panels side by side: original image and prediction.
        panel_specs = (
            ("label_orign", 180, 20, 471, 401),
            ("label_2", 660, 20, 471, 401),
        )
        for name, x, y, width, height in panel_specs:
            panel = QtWidgets.QLabel(central)
            panel.setGeometry(QtCore.QRect(x, y, width, height))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")
            panel.setObjectName(name)
            setattr(self, name, panel)

        open_button = QtWidgets.QPushButton(central)
        open_button.setGeometry(QtCore.QRect(10, 20, 151, 51))
        open_button.setObjectName("pushButton_w_2")
        self.pushButton_w_2 = open_button

        line_edit = QtWidgets.QLineEdit(central)
        line_edit.setGeometry(QtCore.QRect(50, 290, 81, 21))
        line_edit.setObjectName("lineEdit")
        self.lineEdit = line_edit

        slider = QtWidgets.QSlider(central)
        slider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        slider.setRange(0, 10)  # slider runs from 0 up to its maximum
        slider.setSingleStep(1)
        slider.setValue(0)  # start at 0
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setTickInterval(0)
        slider.setObjectName("horizontalSlider")
        self.horizontalSlider = slider

        MainWindow.setCentralWidget(central)

        menubar = QtWidgets.QMenuBar(MainWindow)
        menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        menubar.setObjectName("menubar")
        self.menubar = menubar
        MainWindow.setMenuBar(menubar)

        statusbar = QtWidgets.QStatusBar(MainWindow)
        statusbar.setObjectName("statusbar")
        self.statusbar = statusbar
        MainWindow.setStatusBar(statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign the window title and the caption of every widget."""
        _translate = QtCore.QCoreApplication.translate
        context = "MainWindow"
        MainWindow.setWindowTitle(_translate(context, "MainWindow"))
        captions = (
            (self.pushButton_w, "Predict"),
            (self.pushButton_a, "Pre"),
            (self.pushButton_d, "Next"),
            (self.pushButton_s, "Save"),
            (self.pushButton_5, "背景图"),
            (self.label_orign, "<html><head/><body><p align=\"center\">原始图像</p></body></html>"),
            (self.label_2, "<html><head/><body><p align=\"center\">预测图像</p></body></html>"),
            (self.pushButton_w_2, "Openimg"),
            (self.lineEdit, "改变mask大小"),
        )
        for widget, caption in captions:
            widget.setText(_translate(context, caption))
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Main window: pick an image, segment it with SAM in OpenCV windows.

    The Qt window shows the chosen image on the left and the most recently
    saved mask cut-out on the right.  Point prompts are collected in a
    blocking OpenCV session (see ``call_opencv_interaction``).
    """

    # File suffixes (lower-case, including the dot) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')
    # Segment Anything model variant and its weights file.
    SAM_MODEL_TYPE = "vit_b"
    SAM_CHECKPOINT = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
    # Directory that receives the saved mask cut-outs.
    OUTPUT_DIR = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
    # Each slider step changes the prediction size by this fraction ...
    SCALE_STEP = 0.1
    # ... but never below this fraction, so the pixmap cannot collapse to 0x0.
    MIN_SCALE = 0.1

    def __init__(self):
        """Build the UI, initialise state and wire the buttons and slider."""
        super().__init__()
        self.setupUi(self)
        self.k = 0
        self.last_value = 0  # slider value seen by the last resize
        self.image_path = ""
        self.image_folder = ""
        self.image_files = []
        self.current_index = 0
        self.input_stop = False  # True while point prompts must be ignored
        self.listWidget = None  # created by show_image_selection_dialog
        self._predictor = None  # SamPredictor, loaded on first use
        self._prediction_base = None  # unscaled copy of the prediction pixmap
        self._prediction_base_value = 0  # slider value when the base was stored
        self.pushButton_w_2.clicked.connect(self.open_image_folder)
        self.pushButton_a.clicked.connect(self.load_previous_image)
        self.pushButton_d.clicked.connect(self.load_next_image)
        self.pushButton_s.clicked.connect(self.save_prediction)
        self.pushButton_5.clicked.connect(self.select_background_image)
        # Resize the prediction whenever the horizontal slider moves.
        self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size)

    def adjust_pixmap_size(self, pixmap, scale_factor):
        """Return *pixmap* scaled to *scale_factor* percent of its size."""
        size = pixmap.size()
        # QSize only has integer constructors, so convert explicitly.
        scaled_size = QtCore.QSize(int(size.width() * scale_factor / 100),
                                   int(size.height() * scale_factor / 100))
        return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)

    def open_image_folder(self):
        """Ask for a folder, list its images and open the selection dialog."""
        folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, 'Open Image Folder', '')
        if not folder_path:
            return
        self.image_folder = folder_path
        self.image_files = self.get_image_files(folder_path)
        # An index left over from a previous folder may be out of range here.
        self.current_index = 0
        if self.image_files:
            self.show_image_selection_dialog()

    def load_previous_image(self):
        """Show the previous image, wrapping from the first to the last."""
        if self.image_files:
            self.current_index = (self.current_index - 1) % len(self.image_files)
            self.show_image()

    def load_next_image(self):
        """Show the next image, wrapping from the last to the first."""
        if self.image_files:
            self.current_index = (self.current_index + 1) % len(self.image_files)
            self.show_image()

    def get_image_files(self, folder_path):
        """Return the sorted image file names found directly in *folder_path*.

        Matching is case-insensitive and requires a real extension, so
        ``photo.JPG`` is accepted while a file merely ending in ``png`` is not.
        """
        return sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def show_image_selection_dialog(self):
        """Show a modal list of the folder's images with thumbnails.

        Double-clicking an entry or pressing OK runs ``image_selected``;
        Cancel closes the dialog.
        """
        dialog = QtWidgets.QDialog(self)
        dialog.setWindowTitle("Select Image")
        layout = QtWidgets.QVBoxLayout()
        self.listWidget = QtWidgets.QListWidget()
        for image_file in self.image_files:
            item = QtWidgets.QListWidgetItem(image_file)
            pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file))
            # Files Qt cannot decode are still listed, just without an icon.
            if not pixmap.isNull():
                item.setIcon(QtGui.QIcon(pixmap.scaledToWidth(100)))
            self.listWidget.addItem(item)
        self.listWidget.itemDoubleClicked.connect(self.image_selected)
        layout.addWidget(self.listWidget)
        buttonBox = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
        buttonBox.accepted.connect(self.image_selected)
        buttonBox.rejected.connect(dialog.reject)
        layout.addWidget(buttonBox)
        dialog.setLayout(layout)
        dialog.exec_()

    def image_selected(self):
        """Display the highlighted list entry and start segmenting it."""
        if self.listWidget is None:
            return
        selected_index = self.listWidget.currentRow()
        # currentRow() is -1 when nothing is highlighted.
        if not 0 <= selected_index < len(self.image_files):
            return
        self.current_index = selected_index
        self.show_image()
        # Hand the image to the OpenCV point-prompt session.
        self.call_opencv_interaction(
            os.path.join(self.image_folder, self.image_files[selected_index]))

    def select_background_image(self):
        """Ask for an image file and use it as the prediction's background."""
        image_path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select Background Image', '',
            'Image Files (*.png *.jpg *.jpeg *.bmp)')
        if image_path:
            self.show_background_image(image_path)

    def show_background_image(self, image_path):
        """Draw the current prediction on top of the image at *image_path*.

        When no prediction is displayed yet, the background alone is shown,
        scaled to fit the prediction panel.
        """
        background = QtGui.QPixmap(image_path)
        if background.isNull():
            print(f"Could not load background image: {image_path}")
            return
        shown = self.label_2.pixmap()
        if shown is None or shown.isNull():
            self._set_prediction_pixmap(
                background.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
            return
        # Copy first: the label's pixmap is replaced below.
        foreground = QtGui.QPixmap(shown)
        scene = QtWidgets.QGraphicsScene()
        scene.addPixmap(background)
        scene.addPixmap(foreground)  # added second, so it is painted on top
        merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize())
        merged_pixmap.fill(QtCore.Qt.transparent)
        painter = QtGui.QPainter(merged_pixmap)
        try:
            scene.render(painter)
        finally:
            painter.end()
        self._set_prediction_pixmap(merged_pixmap)

    def show_image(self):
        """Show the image at ``current_index`` in the left-hand panel."""
        if self.image_files and self.current_index < len(self.image_files):
            file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
            pixmap = QtGui.QPixmap(file_path)
            self.label_orign.setPixmap(
                pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))

    def _set_prediction_pixmap(self, pixmap):
        """Display *pixmap* as the prediction and remember it for resizing.

        The stored copy and the slider position at this moment are the
        reference that ``adjust_prediction_size`` scales from.
        """
        self._prediction_base = QtGui.QPixmap(pixmap)
        self._prediction_base_value = self.horizontalSlider.value()
        self.label_2.setPixmap(pixmap)

    def _get_predictor(self):
        """Return the SAM predictor, loading the model on first use."""
        if self._predictor is None:
            import torch  # already required by segment_anything

            device = "cuda" if torch.cuda.is_available() else "cpu"
            sam = sam_model_registry[self.SAM_MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT)
            sam.to(device=device)
            self._predictor = SamPredictor(sam)
        return self._predictor

    @staticmethod
    def _apply_mask(image, mask, alpha_channel=True):
        """Return *image* restricted to *mask*.

        With *alpha_channel* the result has a fourth channel that is 255
        inside the mask and 0 outside; otherwise pixels outside are zeroed.
        """
        if alpha_channel:
            alpha = np.zeros_like(image[..., 0])
            alpha[mask == 1] = 255
            return cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
        return np.where(mask[..., None] == 1, image, 0)

    @staticmethod
    def _apply_color_mask(image, mask, color, color_dark=0.5):
        """Blend *color* into *image* where *mask* is set; edits in place."""
        for c in range(3):
            image[:, :, c] = np.where(
                mask == 1,
                image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                image[:, :, c])
        return image

    @staticmethod
    def _next_filename(base_path, filename):
        """Return the first unused ``name_<i>.ext`` (i = 1..100), or None."""
        name, ext = os.path.splitext(filename)
        for i in range(1, 101):
            new_name = f"{name}_{i}{ext}"
            if not os.path.exists(os.path.join(base_path, new_name)):
                return new_name
        return None

    def _save_masked_image(self, image, mask, output_dir, filename, crop_mode_):
        """Write the masked cut-out as a PNG and show it as the prediction.

        With *crop_mode_* the output is cropped to the mask's bounding box.
        """
        if not np.any(mask):
            # The bounding box of an empty mask is undefined.
            print("Could not save the image. The mask is empty.")
            return
        if crop_mode_:
            rows, cols = np.where(mask)
            y_min, y_max = rows.min(), rows.max()
            x_min, x_max = cols.min(), cols.max()
            masked_image = self._apply_mask(image[y_min:y_max + 1, x_min:x_max + 1],
                                            mask[y_min:y_max + 1, x_min:x_max + 1])
        else:
            masked_image = self._apply_mask(image, mask)
        png_name = os.path.splitext(filename)[0] + '.png'
        new_filename = self._next_filename(output_dir, png_name)
        if new_filename is None:
            print("Could not save the image. Too many variations exist.")
            return
        saved_image_path = os.path.join(output_dir, new_filename)
        if masked_image.shape[-1] == 4:
            written = cv2.imwrite(saved_image_path, masked_image,
                                  [cv2.IMWRITE_PNG_COMPRESSION, 9])
        else:
            written = cv2.imwrite(saved_image_path, masked_image)
        if not written:
            print(f"Could not write {saved_image_path}")
            return
        print(f"Saved as {new_filename}")
        # Load the file back and show it in the prediction panel.
        saved_image_pixmap = QtGui.QPixmap(saved_image_path)
        self._set_prediction_pixmap(
            saved_image_pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))

    def call_opencv_interaction(self, image_path):
        """Run a blocking OpenCV session that segments *image_path* with SAM.

        Keys in the "image" window:
          left / right click  add a foreground / background point
          w                   predict masks from the current points
          space, a, d         discard all points and the current mask
          q                   remove the last point
          s                   save the currently selected mask
          Esc                 close all OpenCV windows and return

        Keys in the "Prediction" window (shown after ``w``):
          a / d   previous / next candidate mask
          s       save the shown mask
          q       remove the last point
          w       confirm the shown mask and return to adding points
          space   discard points and mask and return to adding points

        Only this one image is handled per call, which is why ``a`` and
        ``d`` in the "image" window only reset the prompts.
        """
        image_orign = cv2.imread(image_path)
        if image_orign is None:
            # imread signals failure with None instead of raising.
            print(f"Could not read image: {image_path}")
            return
        filename = os.path.basename(image_path)
        output_dir = self.OUTPUT_DIR
        crop_mode = True
        print('最好是每加一个点就按w键predict一次')
        os.makedirs(output_dir, exist_ok=True)
        predictor = self._get_predictor()

        WINDOW_WIDTH = 1280
        WINDOW_HEIGHT = 720
        # Centres the windows on a 1920x1080 screen.
        window_x = (1920 - WINDOW_WIDTH) // 2
        window_y = (1080 - WINDOW_HEIGHT) // 2

        input_point = []
        input_label = []
        selected_mask = None
        logit_input = None
        self.input_stop = False

        def mouse_click(event, x, y, flags, param):
            if event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
                return
            if self.input_stop:
                print('此时不能添加点,按w退出mask选择模式')
                return
            input_point.append([x, y])
            # Label 1 marks foreground (left click), 0 background (right click).
            input_label.append(1 if event == cv2.EVENT_LBUTTONDOWN else 0)

        def reset_prompts():
            nonlocal selected_mask, logit_input
            input_point.clear()
            input_label.clear()
            selected_mask = None
            logit_input = None

        cv2.namedWindow("image", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("image", WINDOW_WIDTH, WINDOW_HEIGHT)
        cv2.moveWindow("image", window_x, window_y)
        cv2.setMouseCallback("image", mouse_click)

        image_crop = image_orign.copy()
        image = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)
        # The image never changes during the session, so it is handed to
        # the predictor once, on the first prediction.
        image_is_set = False

        try:
            while True:
                image_display = image_orign.copy()
                display_info = f'{filename} '
                cv2.putText(image_display, display_info, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                            cv2.LINE_AA)
                for point, label in zip(input_point, input_label):
                    color = (0, 255, 0) if label == 1 else (0, 0, 255)
                    cv2.circle(image_display, tuple(point), 5, color, -1)
                if selected_mask is not None:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    self._apply_color_mask(image_display, selected_mask, color)
                cv2.imshow("image", image_display)
                key = cv2.waitKey(1)

                if key == 27:
                    break
                if key in (ord(" "), ord('a'), ord('d')):
                    reset_prompts()
                elif key == ord("w"):
                    if not input_point or not input_label:
                        continue
                    if not image_is_set:
                        predictor.set_image(image)
                        image_is_set = True
                    masks, scores, logits = predictor.predict(
                        point_coords=np.array(input_point),
                        point_labels=np.array(input_label),
                        mask_input=logit_input[None, :, :] if logit_input is not None else None,
                        multimask_output=True,
                    )
                    mask_idx = 0
                    num_masks = len(masks)
                    confirmed = False
                    prediction_window_name = "Prediction"
                    cv2.namedWindow(prediction_window_name, cv2.WINDOW_NORMAL)
                    cv2.resizeWindow(prediction_window_name, WINDOW_WIDTH, WINDOW_HEIGHT)
                    cv2.moveWindow(prediction_window_name, window_x, window_y)
                    # Clicks are rejected while a candidate mask is being chosen.
                    self.input_stop = True
                    try:
                        while True:
                            color = tuple(np.random.randint(0, 256, 3).tolist())
                            selected_mask = masks[mask_idx]
                            selected_image = self._apply_color_mask(
                                image_orign.copy(), selected_mask, color)
                            mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | w 预测 | d 切换下一个 | a 切换上一个 |q 移除最后一个'
                            cv2.putText(selected_image, mask_info, (10, 30),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                                        (0, 255, 255), 2, cv2.LINE_AA)
                            cv2.imshow(prediction_window_name, selected_image)
                            key = cv2.waitKey(10)
                            if key == ord('q') and input_point:
                                # Points and labels are parallel lists.
                                input_point.pop()
                                input_label.pop()
                            elif key == ord('s'):
                                self._save_masked_image(image_crop, selected_mask, output_dir,
                                                        filename, crop_mode_=crop_mode)
                            elif key == ord('a'):
                                mask_idx = (mask_idx - 1) % num_masks
                            elif key == ord('d'):
                                mask_idx = (mask_idx + 1) % num_masks
                            elif key == ord('w'):
                                confirmed = True
                                break
                            elif key == ord(" "):
                                reset_prompts()
                                break
                    finally:
                        # Allow adding points again however the loop ended.
                        self.input_stop = False
                    if confirmed:
                        # Feed the chosen mask's logits into the next prediction.
                        logit_input = logits[mask_idx, :, :]
                        print('max score:', np.argmax(scores), ' select:', mask_idx)
                elif key == ord('q') and input_point:
                    input_point.pop()
                    input_label.pop()
                elif key == ord('s') and selected_mask is not None:
                    self._save_masked_image(image_crop, selected_mask, output_dir,
                                            filename, crop_mode_=crop_mode)
        finally:
            self.input_stop = False
            cv2.destroyAllWindows()  # close every OpenCV window on exit

    def save_prediction(self):
        """Save the displayed prediction to ``prediction_result.png``."""
        pixmap = self.label_2.pixmap()
        if pixmap is None or pixmap.isNull():  # nothing is displayed yet
            return
        save_path = "prediction_result.png"
        if not pixmap.toImage().save(save_path):
            print(f"Could not write {save_path}")
        # Re-apply the size selected by the horizontal slider.
        self.adjust_prediction_size(self.horizontalSlider.value())

    def adjust_prediction_size(self, value):
        """Resize the displayed prediction for slider position *value*.

        Each step above the position at which the prediction was set
        shrinks it by ``SCALE_STEP``; each step below enlarges it.  Scaling
        always starts from the stored original, so moving the slider back
        restores the earlier size exactly, and the factor never drops
        below ``MIN_SCALE``.
        """
        if not (self.image_files and self.current_index < len(self.image_files)):
            return
        shown = self.label_2.pixmap()
        if shown is None or shown.isNull():
            return
        base = self._prediction_base
        if base is None or base.isNull():
            # No reference stored yet: adopt what is currently displayed.
            base = QtGui.QPixmap(shown)
            self._prediction_base = base
            self._prediction_base_value = self.last_value
        steps = value - self._prediction_base_value
        scale_factor = max(self.MIN_SCALE, 1.0 - steps * self.SCALE_STEP)
        self.last_value = value  # remember the slider position
        # QSize only has integer constructors; keep at least one pixel.
        scaled_size = QtCore.QSize(max(1, round(base.width() * scale_factor)),
                                   max(1, round(base.height() * scale_factor)))
        self.label_2.setPixmap(base.scaled(scaled_size, QtCore.Qt.KeepAspectRatio))
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window
    # and exit the interpreter with the event loop's return code.
    app = QtWidgets.QApplication(sys.argv)
    # Kept as a module-level name because other code in this module may
    # look the window up as ``mainWindow``.
    mainWindow = MyMainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)