import os
import sys

import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor


class Ui_MainWindow(object):
    """Widget layout of the main window (pyuic-style generated code)."""

    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1140, 450)
        MainWindow.setMinimumSize(QtCore.QSize(1140, 450))
        MainWindow.setMaximumSize(QtCore.QSize(1140, 450))
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
        self.pushButton_w.setObjectName("pushButton_w")
        self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 71, 51))
        self.pushButton_a.setObjectName("pushButton_a")
        self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_d.setGeometry(QtCore.QRect(90, 160, 71, 51))
        self.pushButton_d.setObjectName("pushButton_d")
        self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_s.setGeometry(QtCore.QRect(10, 360, 151, 51))
        self.pushButton_s.setObjectName("pushButton_s")
        self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_5.setGeometry(QtCore.QRect(10, 230, 151, 51))
        self.pushButton_5.setObjectName("pushButton_5")
        self.label_orign = QtWidgets.QLabel(self.centralwidget)
        self.label_orign.setGeometry(QtCore.QRect(180, 20, 471, 401))
        self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
        # Centre the placeholder caption (and later the pixmap) inside the label.
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
        self.label_orign.setObjectName("label_orign")
        self.label_2 = QtWidgets.QLabel(self.centralwidget)
        self.label_2.setGeometry(QtCore.QRect(660, 20, 471, 401))
        self.label_2.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_2.setAlignment(QtCore.Qt.AlignCenter)
        self.label_2.setObjectName("label_2")
        self.pushButton_w_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w_2.setGeometry(QtCore.QRect(10, 20, 151, 51))
        self.pushButton_w_2.setObjectName("pushButton_w_2")
        self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
        self.lineEdit.setGeometry(QtCore.QRect(50, 290, 81, 21))
        # Only used as a caption for the slider below, so it must not be editable.
        self.lineEdit.setReadOnly(True)
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalSlider = QtWidgets.QSlider(self.centralwidget)
        self.horizontalSlider.setGeometry(QtCore.QRect(10, 320, 141, 22))
        self.horizontalSlider.setRange(0, 10)  # slider range 0..10
        self.horizontalSlider.setSingleStep(1)
        self.horizontalSlider.setValue(0)  # start at 0
        self.horizontalSlider.setOrientation(QtCore.Qt.Horizontal)
        self.horizontalSlider.setTickInterval(0)
        self.horizontalSlider.setObjectName("horizontalSlider")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1140, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "获取json文件"))
        self.pushButton_a.setText(_translate("MainWindow", "Pre"))
        self.pushButton_d.setText(_translate("MainWindow", "Next"))
        self.pushButton_s.setText(_translate("MainWindow", "Save"))
        self.pushButton_5.setText(_translate("MainWindow", "背景图"))
        # The captions were multi-line (raw newline) literals, which is a
        # SyntaxError; centring is done with setAlignment in setupUi instead.
        self.label_orign.setText(_translate("MainWindow", "原始图像"))
        self.label_2.setText(_translate("MainWindow", "预测图像"))
        self.pushButton_w_2.setText(_translate("MainWindow", "Openimg"))
        self.lineEdit.setText(_translate("MainWindow", "改变mask大小"))


class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Main window: browse a folder of images and cut objects out of them with SAM.

    The original image is shown in ``label_orign``; the last saved cut-out (or a
    background composite) is shown in ``label_2``. Point prompting and mask
    selection happen in OpenCV windows opened by ``call_opencv_interaction``.
    """

    # Where cut-outs are written and which SAM checkpoint is loaded. Override these
    # class attributes (or set them on an instance) instead of editing method bodies.
    OUTPUT_DIR = r'D:\Program Files\Pycharm items\segment-anything-model\scripts\output\maskt'
    SAM_CHECKPOINT = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"
    SAM_MODEL_TYPE = "vit_b"
    # File extensions listed by get_image_files (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')
    # Initial size of the OpenCV windows.
    WINDOW_WIDTH = 1280
    WINDOW_HEIGHT = 720

    def __init__(self):
        super().__init__()
        self.setupUi(self)
        self.k = 0
        self.last_value = 0  # slider value seen by the previous adjust_prediction_size call
        self.image_path = ""
        self.image_folder = ""
        self.image_files = []
        self.current_index = 0
        # While True, clicks in the OpenCV "image" window do not add prompt points
        # (set while the "Prediction" window is open).
        self.input_stop = False
        # SamPredictor created on first use and reused; loading the model is slow.
        self._predictor = None

        self.pushButton_w_2.clicked.connect(self.open_image_folder)
        self.pushButton_a.clicked.connect(self.load_previous_image)
        self.pushButton_d.clicked.connect(self.load_next_image)
        self.pushButton_s.clicked.connect(self.save_prediction)
        self.pushButton_5.clicked.connect(self.select_background_image)
        # NOTE(review): pushButton_w ("获取json文件") is not connected to any handler.
        self.horizontalSlider.valueChanged.connect(self.adjust_prediction_size)

    # ------------------------------------------------------------------ Qt side

    def adjust_pixmap_size(self, pixmap, scale_factor):
        """Return *pixmap* scaled to *scale_factor* percent, keeping its aspect ratio."""
        # QSize only accepts ints; passing floats raises TypeError on recent PyQt5.
        scaled_size = QtCore.QSize(int(pixmap.width() * scale_factor / 100),
                                   int(pixmap.height() * scale_factor / 100))
        return pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio)

    def open_image_folder(self):
        """Ask for a folder, list its images and open the image picker."""
        folder_path = QtWidgets.QFileDialog.getExistingDirectory(self, 'Open Image Folder', '')
        if not folder_path:
            return
        self.image_folder = folder_path
        self.image_files = self.get_image_files(self.image_folder)
        self.current_index = 0  # an index from a previous folder is meaningless here
        if self.image_files:
            self.show_image_selection_dialog()

    def load_previous_image(self):
        """Show the previous image of the folder, wrapping to the last one."""
        if self.image_files:
            self.current_index = (self.current_index - 1) % len(self.image_files)
            self.show_image()

    def load_next_image(self):
        """Show the next image of the folder, wrapping to the first one."""
        if self.image_files:
            self.current_index = (self.current_index + 1) % len(self.image_files)
            self.show_image()

    def get_image_files(self, folder_path):
        """Return the sorted names of the image files directly inside *folder_path*.

        Matching is case-insensitive (``.JPG`` is accepted) and requires the dot,
        so e.g. ``notapng`` is not picked up. Sorting makes Pre/Next order stable.
        """
        return sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def show_image_selection_dialog(self):
        """Show a modal list of the folder's images with thumbnails.

        Double-clicking an entry or pressing OK closes the dialog; the chosen
        image is then opened with image_selected(). The (blocking) OpenCV
        session therefore never runs while the modal dialog is still open.
        """
        dialog = QtWidgets.QDialog(self)
        dialog.setWindowTitle("Select Image")
        layout = QtWidgets.QVBoxLayout(dialog)

        self.listWidget = QtWidgets.QListWidget()
        for image_file in self.image_files:
            item = QtWidgets.QListWidgetItem(image_file)
            pixmap = QtGui.QPixmap(os.path.join(self.image_folder, image_file))
            if not pixmap.isNull():
                item.setIcon(QtGui.QIcon(pixmap.scaledToWidth(100)))
            self.listWidget.addItem(item)
        self.listWidget.itemDoubleClicked.connect(dialog.accept)
        layout.addWidget(self.listWidget)

        buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
        buttonBox.accepted.connect(dialog.accept)
        buttonBox.rejected.connect(dialog.reject)
        layout.addWidget(buttonBox)

        if dialog.exec_() == QtWidgets.QDialog.Accepted:
            self.image_selected()

    def image_selected(self):
        """Display the image chosen in the picker and start the SAM session on it."""
        if self.listWidget.currentItem() is None:
            return
        selected_index = self.listWidget.currentRow()
        if not 0 <= selected_index < len(self.image_files):
            return
        self.current_index = selected_index
        self.show_image()
        self.call_opencv_interaction(os.path.join(self.image_folder, self.image_files[self.current_index]))

    def select_background_image(self):
        """Ask for an image file and use it as background in the prediction area."""
        image_path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select Background Image', '', 'Image Files (*.png *.jpg *.jpeg *.bmp)')
        if image_path:
            self.show_background_image(image_path)

    def show_background_image(self, image_path):
        """Show *image_path* in the prediction area, under the pixmap already there.

        The background is scaled to the label first so it matches the size of
        the (label-scaled) prediction drawn on top of it; both are anchored at the
        top-left corner.
        """
        background = QtGui.QPixmap(image_path)
        if background.isNull():
            self.statusbar.showMessage(f"Could not load background image: {image_path}", 5000)
            return
        background = background.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio)

        current_pixmap = self.label_2.pixmap()
        if current_pixmap is None or current_pixmap.isNull():
            self.label_2.setPixmap(background)
            return

        scene = QtWidgets.QGraphicsScene()
        scene.addPixmap(background)
        scene.addPixmap(QtGui.QPixmap(current_pixmap))  # drawn last, i.e. on top
        merged_pixmap = QtGui.QPixmap(scene.sceneRect().size().toSize())
        merged_pixmap.fill(QtCore.Qt.transparent)
        painter = QtGui.QPainter(merged_pixmap)
        try:
            scene.render(painter)
        finally:
            painter.end()
        self.label_2.setPixmap(merged_pixmap)

    def show_image(self):
        """Show the image at current_index in the original-image area."""
        if not self.image_files or not 0 <= self.current_index < len(self.image_files):
            return
        file_path = os.path.join(self.image_folder, self.image_files[self.current_index])
        pixmap = QtGui.QPixmap(file_path)
        self.label_orign.setPixmap(pixmap.scaled(self.label_orign.size(), QtCore.Qt.KeepAspectRatio))

    def save_prediction(self):
        """Save the pixmap shown in the prediction area to ``prediction_result.png``.

        The file is written to the current working directory; the outcome is
        reported in the status bar.
        """
        pixmap = self.label_2.pixmap()
        if pixmap is None or pixmap.isNull():
            self.statusbar.showMessage("No prediction image to save.", 3000)
            return
        save_path = "prediction_result.png"
        if pixmap.toImage().save(save_path):
            self.statusbar.showMessage(f"Saved {os.path.abspath(save_path)}", 5000)
        else:
            self.statusbar.showMessage(f"Failed to save {save_path}", 5000)

    def adjust_prediction_size(self, value):
        """Rescale the pixmap in the prediction area when the slider moves.

        Each slider step relative to the previous value changes the currently
        shown pixmap by 10 %: moving the slider down enlarges it, moving it up
        shrinks it. The factor is clamped at 0.1 so a big jump (e.g. a page step
        from 0 to 10) cannot collapse the pixmap to nothing.

        NOTE(review): the original comments labelled these branches the other way
        round (shrink when the slider goes down); confirm the intended direction.
        Scaling is cumulative on the already-scaled pixmap, so repeated moves lose
        quality.
        """
        # Always track the slider position, even when nothing can be rescaled,
        # so the next delta is measured from where the slider really was.
        previous_value, self.last_value = self.last_value, value
        if not (self.image_files and self.current_index < len(self.image_files)):
            return
        pixmap = self.label_2.pixmap()
        if pixmap is None or pixmap.isNull():
            return

        if value < previous_value:
            scale_factor = 1.0 + (previous_value - value) * 0.1
        else:
            scale_factor = 1.0 - (value - previous_value) * 0.1
        scale_factor = max(scale_factor, 0.1)

        original_size = pixmap.size()
        scaled_size = QtCore.QSize(max(1, int(original_size.width() * scale_factor)),
                                   max(1, int(original_size.height() * scale_factor)))
        self.label_2.setPixmap(pixmap.scaled(scaled_size, QtCore.Qt.KeepAspectRatio))

    # ---------------------------------------------------------- OpenCV / SAM side

    @staticmethod
    def _read_image(path):
        """Read *path* as a BGR image; return None if it cannot be read or decoded.

        cv2.imread cannot open paths with non-ASCII characters on Windows, so the
        bytes are read with numpy and decoded in memory instead.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    @staticmethod
    def _write_image(path, image, params=()):
        """Encode *image* by the extension of *path* and write it; return success.

        Counterpart of _read_image that also works with non-ASCII Windows paths.
        """
        ext = os.path.splitext(path)[1] or '.png'
        ok, buffer = cv2.imencode(ext, image, list(params))
        if not ok:
            return False
        try:
            buffer.tofile(path)
        except OSError:
            return False
        return True

    @staticmethod
    def _apply_mask(image, mask, alpha_channel=True):
        """Cut *image* out by *mask*.

        With *alpha_channel* a BGRA image is returned whose alpha is 255 inside
        the mask and 0 outside; otherwise pixels outside the mask are zeroed.
        """
        if alpha_channel:
            alpha = np.zeros_like(image[..., 0])
            alpha[mask == 1] = 255
            return cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))
        return np.where(mask[..., None] == 1, image, 0)

    @staticmethod
    def _apply_color_mask(image, mask, color, color_dark=0.5):
        """Blend *color* into the pixels of *image* covered by *mask*.

        Modifies *image* in place and also returns it.
        """
        for c in range(3):
            image[:, :, c] = np.where(mask == 1,
                                      image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                                      image[:, :, c])
        return image

    @staticmethod
    def _get_next_filename(base_path, filename, max_tries=100):
        """Return ``<stem>_<i><ext>`` for the first i in 1..max_tries not present in *base_path*, else None."""
        name, ext = os.path.splitext(filename)
        for i in range(1, max_tries + 1):
            candidate = f"{name}_{i}{ext}"
            if not os.path.exists(os.path.join(base_path, candidate)):
                return candidate
        return None

    @staticmethod
    def _window_closed(name):
        """Return True if the OpenCV window *name* was closed by the user (e.g. via its X button)."""
        try:
            return cv2.getWindowProperty(name, cv2.WND_PROP_VISIBLE) < 1
        except cv2.error:
            return True

    @staticmethod
    def _destroy_window(name):
        """Destroy OpenCV window *name*, ignoring the error raised if it is already gone."""
        try:
            cv2.destroyWindow(name)
        except cv2.error:
            pass

    def _create_cv_window(self, name):
        """Create a resizable OpenCV window of WINDOW_WIDTH x WINDOW_HEIGHT centred on the primary screen."""
        cv2.namedWindow(name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(name, self.WINDOW_WIDTH, self.WINDOW_HEIGHT)
        left, top, screen_w, screen_h = 0, 0, 1920, 1080  # fallback if Qt reports no screen
        screen = QtWidgets.QApplication.primaryScreen()
        if screen is not None:
            geometry = screen.availableGeometry()
            left, top, screen_w, screen_h = geometry.x(), geometry.y(), geometry.width(), geometry.height()
        cv2.moveWindow(name,
                       left + max(0, (screen_w - self.WINDOW_WIDTH) // 2),
                       top + max(0, (screen_h - self.WINDOW_HEIGHT) // 2))

    def _get_predictor(self):
        """Return a SamPredictor, loading the SAM checkpoint only on first use.

        Uses CUDA when available and falls back to the CPU otherwise.
        """
        predictor = getattr(self, "_predictor", None)
        if predictor is None:
            import torch  # segment_anything depends on torch, so it is installed

            device = "cuda" if torch.cuda.is_available() else "cpu"
            sam = sam_model_registry[self.SAM_MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT)
            sam.to(device=device)
            predictor = SamPredictor(sam)
            self._predictor = predictor
        return predictor

    def _save_masked_image(self, image, mask, output_dir, filename, crop_mode):
        """Save *image* cut out by *mask* as a PNG in *output_dir* and preview it in label_2.

        In crop mode the image is first cropped to the mask's bounding box. The
        output is named ``<stem>_<i>.png`` with the first free i in 1..100.
        """
        if not np.any(mask):
            print("The selected mask is empty; nothing to save.")
            return
        if crop_mode:
            ys, xs = np.where(mask)
            y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
            mask = mask[y_min:y_max + 1, x_min:x_max + 1]
            image = image[y_min:y_max + 1, x_min:x_max + 1]
        masked_image = self._apply_mask(image, mask)

        png_name = os.path.splitext(filename)[0] + '.png'
        new_filename = self._get_next_filename(output_dir, png_name)
        if new_filename is None:
            print("Could not save the image. Too many variations exist.")
            return
        saved_image_path = os.path.join(output_dir, new_filename)
        params = [cv2.IMWRITE_PNG_COMPRESSION, 9] if masked_image.shape[-1] == 4 else []
        if not self._write_image(saved_image_path, masked_image, params):
            print(f"Failed to write {saved_image_path}")
            return
        print(f"Saved as {new_filename}")

        saved_image_pixmap = QtGui.QPixmap(saved_image_path)
        self.label_2.setPixmap(saved_image_pixmap.scaled(self.label_2.size(), QtCore.Qt.KeepAspectRatio))
        # The OpenCV loop blocks the Qt event loop, so paint the preview right now.
        self.label_2.repaint()

    def _run_prediction_window(self, predictor, image_bgr, input_point, input_label,
                               logit_input, filename, output_dir, crop_mode):
        """Run SAM on the current points and let the user browse the candidate masks.

        Keys: a/d previous/next mask, s save the shown mask, q remove the last
        point, w accept the shown mask, space clear all points. Closing the window
        behaves like w. Points are removed from *input_point*/*input_label* in place.

        Returns ``(mask, logits)`` for the accepted mask (the logits are fed back
        as ``mask_input`` on the next prediction) or ``(None, None)`` if cleared.
        """
        masks, scores, logits = predictor.predict(
            point_coords=np.array(input_point),
            point_labels=np.array(input_label),
            mask_input=logit_input[None, :, :] if logit_input is not None else None,
            multimask_output=True,
        )
        num_masks = len(masks)
        mask_idx = 0
        # One fixed colour per candidate instead of a new random colour every frame.
        colors = [tuple(np.random.randint(0, 256, 3).tolist()) for _ in range(num_masks)]
        window = "Prediction"
        cleared = False

        self._create_cv_window(window)
        self.input_stop = True  # block new points in the "image" window meanwhile
        try:
            while True:
                shown = self._apply_color_mask(image_bgr.copy(), masks[mask_idx], colors[mask_idx])
                # cv2.putText only renders ASCII, so the hint text must be ASCII.
                mask_info = (f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} '
                             f'| w: back | a/d: prev/next | s: save | q: undo point')
                cv2.putText(shown, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 255, 255), 2, cv2.LINE_AA)
                cv2.imshow(window, shown)
                key = cv2.waitKey(10) & 0xFF

                if key == ord('w') or self._window_closed(window):
                    break
                if key == ord(' '):
                    input_point.clear()
                    input_label.clear()
                    cleared = True
                    break
                if key == ord('q') and input_point:
                    # Pop both lists so point_coords and point_labels stay aligned.
                    input_point.pop()
                    input_label.pop()
                elif key == ord('s'):
                    self._save_masked_image(image_bgr, masks[mask_idx], output_dir, filename, crop_mode)
                elif key == ord('a'):
                    mask_idx = (mask_idx - 1) % num_masks
                elif key == ord('d'):
                    mask_idx = (mask_idx + 1) % num_masks
        finally:
            self.input_stop = False  # allow adding points again however we left
            self._destroy_window(window)

        print('max score:', np.argmax(scores), ' select:', mask_idx)
        if cleared:
            return None, None
        return masks[mask_idx], logits[mask_idx, :, :]

    def call_opencv_interaction(self, image_path):
        """Run the interactive SAM cut-out session in OpenCV, starting at *image_path*.

        "image" window controls:
          left / right click  add a foreground / background point
          w                   predict masks from the points (opens "Prediction")
          space               clear the points and the current mask
          q                   remove the last point
          s                   save the current mask's cut-out to OUTPUT_DIR
          a / d               previous / next image of the folder
          Esc or close        end the session

        Blocks until the session ends (the Qt event loop is not serviced meanwhile).
        """
        if not self.image_files:
            return
        input_dir = os.path.dirname(image_path)
        selected_name = os.path.basename(image_path)
        if selected_name in self.image_files:
            self.current_index = self.image_files.index(selected_name)
        output_dir = self.OUTPUT_DIR
        crop_mode = True

        try:
            predictor = self._get_predictor()
        except Exception as err:  # UI boundary: report instead of aborting the app
            QtWidgets.QMessageBox.critical(self, "SAM", f"Could not load the SAM model:\n{err}")
            return

        print('最好是每加一个点就按w键predict一次')
        os.makedirs(output_dir, exist_ok=True)

        # Mutated in place (never rebound) so the mouse callback always sees them.
        input_point = []
        input_label = []
        self.input_stop = False

        def mouse_click(event, x, y, flags, param):
            if event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
                return
            if self.input_stop:
                print('此时不能添加点,按w退出mask选择模式')
                return
            input_point.append([x, y])
            input_label.append(1 if event == cv2.EVENT_LBUTTONDOWN else 0)

        # The window is created once; it keeps its mouse callback for the whole session.
        self._create_cv_window("image")
        cv2.setMouseCallback("image", mouse_click)
        try:
            while True:
                filename = self.image_files[self.current_index]
                image_orign = self._read_image(os.path.join(input_dir, filename))
                if image_orign is None:
                    print(f"Could not read {filename}")
                    break
                image_rgb = cv2.cvtColor(image_orign, cv2.COLOR_BGR2RGB)
                image_embedded = False  # predictor.set_image is costly: run it once per image
                selected_mask = None
                logit_input = None
                mask_color = tuple(np.random.randint(0, 256, 3).tolist())

                action = None  # becomes "prev", "next" or "quit"
                while action is None:
                    image_display = image_orign.copy()
                    cv2.putText(image_display, f'{filename} ', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                                (0, 255, 255), 2, cv2.LINE_AA)
                    for point, label in zip(input_point, input_label):
                        point_color = (0, 255, 0) if label == 1 else (0, 0, 255)
                        cv2.circle(image_display, tuple(point), 5, point_color, -1)
                    if selected_mask is not None:
                        self._apply_color_mask(image_display, selected_mask, mask_color)
                    cv2.imshow("image", image_display)
                    key = cv2.waitKey(1) & 0xFF

                    if key == 27 or self._window_closed("image"):
                        action = "quit"
                    elif key == ord("a"):
                        action = "prev"
                    elif key == ord("d"):
                        action = "next"
                    elif key == ord(" "):
                        input_point.clear()
                        input_label.clear()
                        selected_mask = None
                        logit_input = None
                    elif key == ord("q") and input_point:
                        input_point.pop()
                        input_label.pop()
                    elif key == ord("s") and selected_mask is not None:
                        self._save_masked_image(image_orign, selected_mask, output_dir, filename, crop_mode)
                    elif key == ord("w") and input_point:
                        if not image_embedded:
                            predictor.set_image(image_rgb)
                            image_embedded = True
                        selected_mask, logit_input = self._run_prediction_window(
                            predictor, image_orign, input_point, input_label, logit_input,
                            filename, output_dir, crop_mode)

                if action == "quit":
                    break
                # Switch image (wrapping like the Pre/Next buttons) and start a fresh prompt.
                input_point.clear()
                input_label.clear()
                if action == "prev":
                    self.load_previous_image()
                else:
                    self.load_next_image()
                self.label_orign.repaint()
        finally:
            self.input_stop = False
            cv2.destroyAllWindows()


if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    mainWindow = MyMainWindow()
    mainWindow.show()
    sys.exit(app.exec_())