import os
import sys

import cv2
import numpy as np
import torch
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor


class Ui_MainWindow(object):
    """Builds the widgets of the annotation window: action buttons and two image labels."""

    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.pushButton_w = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_w.setGeometry(QtCore.QRect(10, 90, 151, 51))
        self.pushButton_w.setObjectName("pushButton_w")
        self.pushButton_a = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_a.setGeometry(QtCore.QRect(10, 160, 151, 51))
        self.pushButton_a.setObjectName("pushButton_a")
        self.pushButton_d = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_d.setGeometry(QtCore.QRect(10, 230, 151, 51))
        self.pushButton_d.setObjectName("pushButton_d")
        self.pushButton_s = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_s.setGeometry(QtCore.QRect(10, 300, 151, 51))
        self.pushButton_s.setObjectName("pushButton_s")
        self.pushButton_q = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_q.setGeometry(QtCore.QRect(10, 370, 151, 51))
        self.pushButton_q.setObjectName("pushButton_q")
        self.label_orign = QtWidgets.QLabel(self.centralwidget)
        self.label_orign.setGeometry(QtCore.QRect(180, 20, 450, 450))
        self.label_orign.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_orign.setObjectName("label_orign")
        self.label_pre = QtWidgets.QLabel(self.centralwidget)
        self.label_pre.setGeometry(QtCore.QRect(660, 20, 450, 450))
        self.label_pre.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.label_pre.setObjectName("label_pre")
        self.pushButton_opimg = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_opimg.setGeometry(QtCore.QRect(10, 20, 151, 51))
        self.pushButton_opimg.setObjectName("pushButton_opimg")
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton_w.setText(_translate("MainWindow", "w"))
        self.pushButton_a.setText(_translate("MainWindow", "a"))
        self.pushButton_d.setText(_translate("MainWindow", "d"))
        self.pushButton_s.setText(_translate("MainWindow", "s"))
        self.pushButton_q.setText(_translate("MainWindow", "q"))
        # Line breaks in the placeholder texts are written as escapes so the
        # literals stay on one line.
        self.label_orign.setText(_translate("MainWindow", "\nOriginal Image\n"))
        self.label_pre.setText(_translate("MainWindow", "Predicted Image\n"))
        self.pushButton_opimg.setText(_translate("MainWindow", "Open Image"))


class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Interactive SAM annotation window.

    Left-click on the original image adds a foreground point (label 1),
    right-click adds a background point (label 0). The buttons, also bound to
    the W/A/D/S/Q keys, do the following:

    * w: predict masks from the current points
    * a / d: step to the previous / next candidate mask (wrapping around)
    * s: save the current mask
    * q: remove the last point
    """

    # Location of the SAM ViT-B weights loaded at start-up.
    CHECKPOINT_PATH = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"

    # Point marker colors (RGB): green for foreground, red for background.
    FOREGROUND_COLOR = (0, 255, 0)
    BACKGROUND_COLOR = (255, 0, 0)
    POINT_SIZE = 5  # pen width of a point marker, in display pixels

    def __init__(self):
        super().__init__()
        self.setupUi(self)

        self.image_files = []        # loaded images, RGB uint8 arrays
        self.current_index = 0       # index in image_files of the displayed image
        self.input_point = []        # clicked points as [x, y] in image pixels
        self.input_label = []        # 1 = foreground, 0 = background (parallel to input_point)
        self.input_stop = False      # while True, mouse clicks are ignored
        self.interaction_count = 0   # number of predictions run so far

        # Results of the most recent prediction (None until "w" is used).
        self.masks = None
        self.scores = None
        self.mask_idx = 0
        self.mask_color = (0, 255, 0)
        # Index of the image last passed to predictor.set_image, so the
        # (expensive) call is not repeated for every prediction on that image.
        self._embedded_index = None

        # Fall back to CPU so the tool still starts on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam = sam_model_registry["vit_b"](checkpoint=self.CHECKPOINT_PATH)
        self.sam.to(device=device)
        self.predictor = SamPredictor(self.sam)

        # Image-pixels per displayed-pixel ratios of label_orign; refreshed on every redraw.
        self.scale_x = 1.0
        self.scale_y = 1.0
        self.label_pre_width = self.label_pre.width()
        self.label_pre_height = self.label_pre.height()
        self._orig_pixmap_size = QtCore.QSize()  # size of the pixmap shown in label_orign

        self._connect_actions()
        # Set mouse click event for original image label
        self.set_mouse_click_event()

    def _connect_actions(self):
        """Connect every button to its slot and bind its letter key as a shortcut."""
        bindings = (
            (self.pushButton_opimg, None, self.open_image),
            (self.pushButton_w, "W", self.predict_and_interact),
            (self.pushButton_a, "A", self.previous_mask),
            (self.pushButton_d, "D", self.next_mask),
            (self.pushButton_s, "S", self.save_current_mask),
            (self.pushButton_q, "Q", self.remove_last_point),
        )
        self._shortcuts = []  # keep references so the shortcuts stay alive
        for button, key, slot in bindings:
            button.clicked.connect(slot)
            if key is not None:
                shortcut = QtWidgets.QShortcut(QtGui.QKeySequence(key), self)
                # Route the key through the button so both share one code path.
                shortcut.activated.connect(button.click)
                self._shortcuts.append(shortcut)

    @staticmethod
    def _read_rgb_image(filename):
        """Return the image at *filename* as an RGB uint8 array, or None if unreadable."""
        # np.fromfile + imdecode instead of cv2.imread so that paths containing
        # non-ASCII characters can be read as well.
        try:
            data = np.fromfile(filename, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if image is None:
            return None
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def open_image(self):
        """Ask for an image file, load it and make it the displayed image."""
        options = QtWidgets.QFileDialog.Options()
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open Image File", "",
            "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
            options=options)
        if not filename:
            return
        image = self._read_rgb_image(filename)
        if image is None:
            QtWidgets.QMessageBox.warning(self, "Open Image File",
                                          f"Could not read image:\n{filename}")
            return
        self.image_files.append(image)
        # Show the image that was just opened, not whichever was shown before.
        self.current_index = len(self.image_files) - 1
        self._reset_annotations()
        self.display_original_image()

    def _reset_annotations(self):
        """Drop the points and masks that belong to the previously displayed image."""
        self.input_point = []
        self.input_label = []
        if self.masks is not None:
            self.label_pre.clear()
        self.masks = None
        self.scores = None
        self.mask_idx = 0

    @staticmethod
    def _scaled_pixmap(image, size):
        """Convert an RGB uint8 array into a QPixmap fitted inside *size*, aspect ratio kept."""
        image = np.ascontiguousarray(image)
        height, width = image.shape[:2]
        q_image = QtGui.QImage(image.data, width, height, 3 * width,
                               QtGui.QImage.Format_RGB888)
        return QtGui.QPixmap.fromImage(q_image).scaled(size, QtCore.Qt.KeepAspectRatio)

    def _draw_points(self, pixmap, scale_x, scale_y):
        """Paint the clicked points onto *pixmap* (green = foreground, red = background).

        scale_x / scale_y are the image-pixels per pixmap-pixel ratios used to
        map the stored image coordinates onto the pixmap.
        """
        if not self.input_point:
            return
        painter = QtGui.QPainter(pixmap)
        try:
            # Foreground points first, then background points.
            for wanted, rgb in ((1, self.FOREGROUND_COLOR), (0, self.BACKGROUND_COLOR)):
                pen = QtGui.QPen(QtGui.QColor(*rgb))
                pen.setWidth(self.POINT_SIZE)
                painter.setPen(pen)
                for (x, y), label in zip(self.input_point, self.input_label):
                    if label == wanted:
                        # QPointF: the mapped coordinates are generally fractional.
                        painter.drawPoint(QtCore.QPointF(x / scale_x, y / scale_y))
        finally:
            painter.end()

    def display_original_image(self):
        """Show the current image in label_orign with the clicked points on top.

        Also refreshes scale_x / scale_y, which map label pixels back to image
        pixels for mouse clicks.
        """
        if not self.image_files:
            return
        image = self.image_files[self.current_index]
        height, width = image.shape[:2]
        pixmap = self._scaled_pixmap(image, self.label_orign.size())
        # Ratios must be up to date before drawing, because the markers use them.
        self.scale_x = width / max(pixmap.width(), 1)
        self.scale_y = height / max(pixmap.height(), 1)
        self._orig_pixmap_size = pixmap.size()
        self._draw_points(pixmap, self.scale_x, self.scale_y)
        self.label_orign.setPixmap(pixmap)

    def convert_to_label_coords(self, point):
        """Map an [x, y] image-pixel point to coordinates on the pixmap in label_orign."""
        x = point[0] / self.scale_x
        y = point[1] / self.scale_y
        return x, y

    def mouse_click(self, event):
        """Add a point where the original image was clicked.

        Left button adds a foreground point (label 1); right button adds a
        background point (label 0). The click is ignored in any of these cases:
        other buttons, clicks outside the displayed image, no image loaded, or
        input_stop set.
        """
        if self.input_stop or not self.image_files:
            return
        if event.button() == QtCore.Qt.LeftButton:
            label = 1
        elif event.button() == QtCore.Qt.RightButton:
            label = 0
        else:
            # Without this, the point and label lists would get out of step.
            return
        # The label draws its pixmap aligned inside contentsRect(), so the
        # offset of that rectangle is removed before scaling to image pixels.
        shown = QtWidgets.QStyle.alignedRect(
            self.label_orign.layoutDirection(),
            self.label_orign.alignment(),
            self._orig_pixmap_size,
            self.label_orign.contentsRect(),
        )
        pos = event.pos()
        if not shown.contains(pos):
            return
        x = int((pos.x() - shown.x()) * self.scale_x)
        y = int((pos.y() - shown.y()) * self.scale_y)
        self.input_point.append([x, y])
        self.input_label.append(label)
        # Update the original image (and the prediction view) with the new point
        self.display_original_image()
        self._show_current_mask()

    def predict_and_interact(self):
        """Run SAM with the clicked points and show the first candidate mask.

        Does nothing when no image is loaded. If there are no points yet, it
        shows a status-bar hint instead. predictor.set_image is called only
        when the displayed image differs from the one last passed to it.
        """
        if not self.image_files:
            return
        if not self.input_point:
            self.statusbar.showMessage("Click on the image to add points first.")
            return
        if self._embedded_index != self.current_index:
            self.predictor.set_image(self.image_files[self.current_index])
            self._embedded_index = self.current_index
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(self.input_point),
            point_labels=np.array(self.input_label),
            multimask_output=True,
        )
        self.masks = masks
        self.scores = scores
        self.mask_idx = 0
        # One color per prediction, so the overlay does not flicker between redraws.
        self.mask_color = tuple(np.random.randint(0, 256, 3).tolist())
        self.interaction_count += 1
        self._show_current_mask()

    def _show_current_mask(self):
        """Render the selected candidate mask over the image into label_pre."""
        if self.masks is None or not self.image_files:
            return
        num_masks = len(self.masks)
        overlay = self.apply_color_mask(self.image_files[self.current_index].copy(),
                                        self.masks[self.mask_idx], self.mask_color)
        mask_info = (f'Total: {num_masks} | Current: {self.mask_idx} | '
                     f'Score: {self.scores[self.mask_idx]:.2f} | '
                     f'w Predict | d Next | a Previous | q Remove Last | s Save')
        cv2.putText(overlay, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                    (0, 255, 255), 2, cv2.LINE_AA)
        self.statusbar.showMessage(mask_info)
        self.display_prediction_image(overlay)

    def _step_mask(self, step):
        """Move the selected mask index by *step*, wrapping around the candidates."""
        if self.masks is None:
            return
        self.mask_idx = (self.mask_idx + step) % len(self.masks)
        self._show_current_mask()

    def previous_mask(self):
        """Select the previous candidate mask (the last one after the first)."""
        self._step_mask(-1)

    def next_mask(self):
        """Select the next candidate mask (the first one after the last)."""
        self._step_mask(1)

    def save_current_mask(self):
        """Save the displayed image with the selected mask as image_<index>_masked.png."""
        if self.masks is None or not self.image_files:
            return
        filename = f"image_{self.current_index}.png"
        self.save_masked_image(self.image_files[self.current_index],
                               self.masks[self.mask_idx], filename)

    def remove_last_point(self):
        """Remove the most recently clicked point and redraw both views."""
        if not self.input_point:
            return
        self.input_point.pop()
        self.input_label.pop()
        self.display_original_image()
        self._show_current_mask()

    def set_mouse_click_event(self):
        """Route mouse presses on label_orign to mouse_click."""
        self.label_orign.mousePressEvent = self.mouse_click

    def display_prediction_image(self, image):
        """Show an RGB image in label_pre with the clicked points drawn on top."""
        height, width = image.shape[:2]
        pixmap = self._scaled_pixmap(image, self.label_pre.size())
        # Use label_pre's own ratios rather than label_orign's.
        self._draw_points(pixmap,
                          width / max(pixmap.width(), 1),
                          height / max(pixmap.height(), 1))
        self.label_pre.setPixmap(pixmap)

    def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
        """Blend *color* into *image* where *mask* == 1, in place, and return *image*.

        color_dark is the weight of *color* in the blend (0 leaves the pixel
        unchanged, 1 replaces it with the color).
        """
        for c in range(3):
            image[:, :, c] = np.where(mask == 1,
                                      image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                                      image[:, :, c])
        return image

    def save_masked_image(self, image, mask, filename):
        """Write *image* with *mask* blended in green to '<stem>_masked.png' next to *filename*.

        *image* is RGB and is not modified; it is converted to BGR for cv2.imwrite.
        """
        output_dir = os.path.dirname(filename)
        stem = os.path.splitext(os.path.basename(filename))[0]
        new_filename = os.path.join(output_dir, stem + '_masked.png')
        # Blend into a copy so repeated saves do not stack overlays on the source.
        masked_image = self.apply_color_mask(image.copy(), mask)
        if cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)):
            print(f"Saved as {new_filename}")
            self.statusbar.showMessage(f"Saved as {new_filename}")
        else:
            print(f"Failed to save {new_filename}")
            self.statusbar.showMessage(f"Failed to save {new_filename}")

    def previous_image(self):
        """Display the previous loaded image, if any, with fresh annotations."""
        if self.current_index > 0:
            self.current_index -= 1
            self._reset_annotations()
            self.display_original_image()

    def next_image(self):
        """Display the next loaded image, if any, with fresh annotations."""
        if self.current_index < len(self.image_files) - 1:
            self.current_index += 1
            self._reset_annotations()
            self.display_original_image()


if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())