301 lines
14 KiB
Python
301 lines
14 KiB
Python
|
import cv2
|
||
|
import os
|
||
|
import numpy as np
|
||
|
from PyQt5 import QtCore, QtGui, QtWidgets
|
||
|
from segment_anything import sam_model_registry, SamPredictor
|
||
|
|
||
|
|
||
|
class Ui_MainWindow(object):
    """Widget layout for the SAM annotation window (Qt Designer / pyuic style).

    ``setupUi`` creates every widget as an attribute of this object and installs
    them on the given main window; ``retranslateUi`` fills in the visible text.
    """

    def setupUi(self, MainWindow):
        """Create, position and name all widgets, then attach them to *MainWindow*."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)

        central = QtWidgets.QWidget(MainWindow)
        central.setObjectName("centralwidget")
        self.centralwidget = central

        def add_button(name, top):
            # Every command button sits in the left column with a 151x51 footprint.
            button = QtWidgets.QPushButton(central)
            button.setGeometry(QtCore.QRect(10, top, 151, 51))
            button.setObjectName(name)
            setattr(self, name, button)

        def add_image_label(name, left):
            # A white 450x450 canvas that later receives a scaled image.
            label = QtWidgets.QLabel(central)
            label.setGeometry(QtCore.QRect(left, 20, 450, 450))
            label.setStyleSheet("background-color: rgb(255, 255, 255);")
            label.setObjectName(name)
            setattr(self, name, label)

        # Widgets are created in the same order as the Designer output, because
        # Qt derives stacking and default tab order from creation order.
        key_buttons = (
            ("pushButton_w", 90),
            ("pushButton_a", 160),
            ("pushButton_d", 230),
            ("pushButton_s", 300),
            ("pushButton_q", 370),
        )
        for name, top in key_buttons:
            add_button(name, top)
        add_image_label("label_orign", 180)
        add_image_label("label_pre", 660)
        add_button("pushButton_opimg", 20)

        MainWindow.setCentralWidget(central)

        menu_bar = QtWidgets.QMenuBar(MainWindow)
        menu_bar.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        menu_bar.setObjectName("menubar")
        self.menubar = menu_bar
        MainWindow.setMenuBar(menu_bar)

        status_bar = QtWidgets.QStatusBar(MainWindow)
        status_bar.setObjectName("statusbar")
        self.statusbar = status_bar
        MainWindow.setStatusBar(status_bar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and the (translatable) text of every widget."""
        tr = QtCore.QCoreApplication.translate
        context = "MainWindow"
        MainWindow.setWindowTitle(tr(context, "MainWindow"))
        button_texts = (
            (self.pushButton_w, "w"),
            (self.pushButton_a, "a"),
            (self.pushButton_d, "d"),
            (self.pushButton_s, "s"),
            (self.pushButton_q, "q"),
        )
        for button, text in button_texts:
            button.setText(tr(context, text))
        self.label_orign.setText(
            tr(context, "<html><head/><body><p align=\"center\">Original Image</p></body></html>"))
        self.label_pre.setText(
            tr(context, "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>"))
        self.pushButton_opimg.setText(tr(context, "Open Image"))
|
||
|
|
||
|
|
||
|
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Interactive Segment Anything (SAM) annotation window.

    Workflow:
      * "Open Image" loads an image file and makes it the current image.
      * Left-click on the original image adds a foreground prompt point (label 1);
        right-click adds a background prompt point (label 0).
      * "w" runs SAM on the recorded prompts, "a" / "d" cycle through the predicted
        masks, "q" removes the last prompt point and "s" saves the selected mask
        overlay to disk. The buttons also respond to the matching keys.

    Prompt points are stored in original-image pixel coordinates.
    """

    # Default SAM checkpoint; override via the ``checkpoint`` constructor argument.
    DEFAULT_CHECKPOINT = r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth"

    # File filter offered by the "Open Image" dialog.
    IMAGE_FILTER = "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff *.tif *.bmp)"

    def __init__(self, checkpoint=DEFAULT_CHECKPOINT, model_type="vit_b", device="cuda"):
        """Build the UI, wire the buttons and load the SAM model.

        Args:
            checkpoint: path of the SAM weights file.
            model_type: key into ``sam_model_registry`` (e.g. "vit_b").
            device: torch device the model is moved to.
        """
        super().__init__()
        self.setupUi(self)

        self.pushButton_opimg.clicked.connect(self.open_image)
        self.pushButton_w.clicked.connect(self.predict_and_interact)
        self.pushButton_a.clicked.connect(self.previous_mask)
        self.pushButton_d.clicked.connect(self.next_mask)
        self.pushButton_q.clicked.connect(self.remove_last_point)
        self.pushButton_s.clicked.connect(self.save_current_mask)
        # The button captions are key names; bind them as keyboard shortcuts too.
        # (Set after setupUi: QAbstractButton.setText resets the shortcut.)
        for button, key in ((self.pushButton_w, "W"), (self.pushButton_a, "A"),
                            (self.pushButton_d, "D"), (self.pushButton_s, "S"),
                            (self.pushButton_q, "Q")):
            button.setShortcut(QtGui.QKeySequence(key))

        self.image_files = []        # RGB uint8 arrays, one per opened file
        self.current_index = 0       # index into image_files of the displayed image
        self.input_point = []        # [x, y] prompt points in original-image pixels
        self.input_label = []        # 1 = foreground, 0 = background (parallel to input_point)
        self.input_stop = False      # when True, clicks on the original image are ignored
        self.interaction_count = 0   # number of SAM predictions run so far

        # Result of the latest prediction for the current image.
        self.masks = None
        self.scores = None
        self.mask_idx = 0
        self.mask_color = (0, 255, 0)
        self._embedded_index = None  # image index whose embedding the predictor holds

        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam.to(device=device)
        self.predictor = SamPredictor(self.sam)

        # Image pixels per displayed pixel, and the pixmap's top-left inside label_orign.
        self.scale_x = 1.0
        self.scale_y = 1.0
        self._offset_x = 0
        self._offset_y = 0
        self.label_pre_width = self.label_pre.width()
        self.label_pre_height = self.label_pre.height()

        # Centre the pixmaps so the click-to-image mapping has a well-defined offset.
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
        self.label_pre.setAlignment(QtCore.Qt.AlignCenter)

        self.set_mouse_click_event()

    # ------------------------------------------------------------------ images

    def open_image(self):
        """Ask for an image file, load it as RGB and make it the current image."""
        options = QtWidgets.QFileDialog.Options()
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open Image File", "", self.IMAGE_FILTER, options=options)
        if not filename:
            return
        image = self._read_rgb(filename)
        if image is None:
            QtWidgets.QMessageBox.warning(self, "Open Image File",
                                          f"Could not read image:\n{filename}")
            return
        self.image_files.append(image)
        # Show the image that was just opened, not whichever was loaded first.
        self.current_index = len(self.image_files) - 1
        self._reset_prompts()
        self.display_original_image()

    @staticmethod
    def _read_rgb(filename):
        """Decode *filename* into an RGB uint8 array, or return None if unreadable.

        The bytes are read with NumPy and decoded in memory because ``cv2.imread``
        cannot open non-ASCII paths on Windows.
        """
        try:
            data = np.fromfile(filename, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if image is None:
            return None
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def previous_image(self):
        """Switch to the previously opened image, discarding the current prompts."""
        if self.current_index > 0:
            self.current_index -= 1
            self._reset_prompts()
            self.display_original_image()

    def next_image(self):
        """Switch to the next opened image, discarding the current prompts."""
        if self.current_index < len(self.image_files) - 1:
            self.current_index += 1
            self._reset_prompts()
            self.display_original_image()

    def _reset_prompts(self):
        """Forget prompt points and predicted masks; they belong to a single image."""
        self.input_point.clear()
        self.input_label.clear()
        self.masks = None
        self.scores = None
        self.mask_idx = 0
        self.label_pre.clear()
        self.statusbar.clearMessage()

    # ----------------------------------------------------------------- display

    @staticmethod
    def _to_scaled_pixmap(image, label):
        """Convert an RGB uint8 (H, W, 3) array into a QPixmap fitted to *label*.

        The aspect ratio is preserved, so the pixmap may be smaller than the label
        along one axis.
        """
        image = np.ascontiguousarray(image)
        height, width = image.shape[:2]
        q_image = QtGui.QImage(image.data, width, height, 3 * width, QtGui.QImage.Format_RGB888)
        # fromImage copies the pixels, so the NumPy buffer need not outlive this call.
        return QtGui.QPixmap.fromImage(q_image).scaled(
            label.size(), QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation)

    def _draw_points(self, pixmap, image_width, image_height):
        """Paint the prompt points onto *pixmap*: foreground green, background red.

        *image_width* / *image_height* are the full-resolution image size that
        *pixmap* was scaled from, used to map point coordinates onto the pixmap.
        """
        if not self.input_point:
            return
        sx = image_width / max(pixmap.width(), 1)
        sy = image_height / max(pixmap.height(), 1)
        painter = QtGui.QPainter(pixmap)
        try:
            for wanted, color in ((1, QtGui.QColor(0, 255, 0)), (0, QtGui.QColor(255, 0, 0))):
                pen = QtGui.QPen(color)
                pen.setWidth(5)
                painter.setPen(pen)
                for point, label in zip(self.input_point, self.input_label):
                    if label == wanted:
                        # QPointF: PyQt5's QPoint rejects float coordinates.
                        painter.drawPoint(QtCore.QPointF(point[0] / sx, point[1] / sy))
        finally:
            painter.end()

    def display_original_image(self):
        """Show the current image in ``label_orign`` with the prompt points on top.

        Also refreshes ``scale_x`` / ``scale_y`` and the pixmap offset inside the
        label, which ``mouse_click`` uses to map clicks back to image pixels.
        """
        if not self.image_files:
            return
        image = self.image_files[self.current_index]
        height, width = image.shape[:2]
        pixmap = self._to_scaled_pixmap(image, self.label_orign)

        # Scale from the pixmap actually shown (aspect-fitted), not the label box.
        self.scale_x = width / max(pixmap.width(), 1)
        self.scale_y = height / max(pixmap.height(), 1)
        rect = self.label_orign.contentsRect()
        self._offset_x = rect.x() + (rect.width() - pixmap.width()) // 2
        self._offset_y = rect.y() + (rect.height() - pixmap.height()) // 2

        # Draw on our own pixmap before handing it to the label.
        self._draw_points(pixmap, width, height)
        self.label_orign.setPixmap(pixmap)

    def display_prediction_image(self, image):
        """Show an RGB uint8 *image* in ``label_pre`` with the prompt points on top."""
        height, width = image.shape[:2]
        pixmap = self._to_scaled_pixmap(image, self.label_pre)
        self._draw_points(pixmap, width, height)
        self.label_pre.setPixmap(pixmap)

    def convert_to_label_coords(self, point):
        """Map an image-pixel point onto the pixmap shown in ``label_orign``.

        The result is relative to the pixmap's top-left corner, not the label's.
        """
        return point[0] / self.scale_x, point[1] / self.scale_y

    # ------------------------------------------------------------------- input

    def set_mouse_click_event(self):
        """Route mouse presses on ``label_orign`` to ``mouse_click``."""
        self.label_orign.mousePressEvent = self.mouse_click

    def mouse_click(self, event):
        """Record a prompt point from a mouse press on ``label_orign``.

        A left click adds a foreground point (label 1) and a right click adds a
        background point (label 0). The press is ignored if any of these hold:
          * another mouse button was used;
          * no image is loaded;
          * ``input_stop`` is set;
          * the click lands outside the displayed image.
        """
        if self.input_stop or not self.image_files:
            return
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            label = 1
        elif button == QtCore.Qt.RightButton:
            label = 0
        else:
            # Only left/right map to a label; a point without one would
            # desynchronise input_point and input_label.
            return

        height, width = self.image_files[self.current_index].shape[:2]
        fx = (event.pos().x() - self._offset_x) * self.scale_x
        fy = (event.pos().y() - self._offset_y) * self.scale_y
        if not (0 <= fx < width and 0 <= fy < height):
            return  # click fell in the margin around the aspect-fitted image
        self.input_point.append([int(fx), int(fy)])
        self.input_label.append(label)
        self.display_original_image()

    def remove_last_point(self):
        """Drop the most recent prompt point and refresh both views."""
        if not self.input_point:
            return
        self.input_point.pop()
        self.input_label.pop()
        self.display_original_image()
        # Refresh the point markers on the prediction; press "w" to re-predict.
        self._show_current_mask()

    # -------------------------------------------------------------- prediction

    def predict_and_interact(self):
        """Run SAM on the current image with the recorded prompts and show mask 0.

        Does nothing until an image is loaded and at least one point exists. The
        image embedding (``set_image``) is only recomputed when the current
        image changes.
        """
        if not self.image_files or not self.input_point:
            return
        if self._embedded_index != self.current_index:
            self.predictor.set_image(self.image_files[self.current_index])
            self._embedded_index = self.current_index

        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(self.input_point),
            point_labels=np.array(self.input_label),
            multimask_output=True,
        )
        self.masks = masks
        self.scores = scores
        self.mask_idx = 0
        # One colour per prediction so the overlay does not flicker while browsing.
        self.mask_color = tuple(np.random.randint(0, 256, 3).tolist())
        self.interaction_count += 1
        self._show_current_mask()

    def _show_current_mask(self):
        """Render the selected mask over the current image into ``label_pre``."""
        if self.masks is None:
            return
        num_masks = len(self.masks)
        mask_idx = self.mask_idx
        overlay = self.apply_color_mask(self.image_files[self.current_index].copy(),
                                        self.masks[mask_idx], self.mask_color)
        mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {self.scores[mask_idx]:.2f} | w Predict | d Next | a Previous | q Remove Last | s Save'
        cv2.putText(overlay, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                    2, cv2.LINE_AA)
        # The overlay text shrinks with the image; the status bar stays readable.
        self.statusbar.showMessage(mask_info)
        self.display_prediction_image(overlay)

    def previous_mask(self):
        """Select the previous predicted mask (wrapping around) and redisplay."""
        if self.masks is None:
            return
        self.mask_idx = (self.mask_idx - 1) % len(self.masks)
        self._show_current_mask()

    def next_mask(self):
        """Select the next predicted mask (wrapping around) and redisplay."""
        if self.masks is None:
            return
        self.mask_idx = (self.mask_idx + 1) % len(self.masks)
        self._show_current_mask()

    def save_current_mask(self):
        """Save the selected mask overlay as ``image_<index>_masked.png``."""
        if self.masks is None:
            return
        self.save_masked_image(self.image_files[self.current_index],
                               self.masks[self.mask_idx],
                               f"image_{self.current_index}.png")

    # ----------------------------------------------------------------- helpers

    def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
        """Blend *color* into the pixels of *image* where ``mask == 1``.

        *image* (H, W, 3) is modified in place and also returned; pass a copy if the
        original must be kept. *color_dark* is the weight of *color* in the blend.
        Blended values are truncated when stored back into an integer *image*.
        """
        selected = mask == 1
        blended = image[selected] * (1 - color_dark) + color_dark * np.asarray(color, dtype=np.float64)
        image[selected] = blended
        return image

    def save_masked_image(self, image, mask, filename):
        """Write *image* with *mask* overlaid to ``<stem>_masked.png`` beside *filename*.

        *image* is an RGB uint8 array and is left unmodified. Returns the written
        path, or None if encoding or writing failed.
        """
        output_dir = os.path.dirname(filename)
        stem = os.path.splitext(os.path.basename(filename))[0]
        new_filename = os.path.join(output_dir, stem + "_masked.png")

        # Copy so repeated saves do not stack overlays on the caller's image.
        masked_image = self.apply_color_mask(image.copy(), mask)
        ok, encoded = cv2.imencode(".png", cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR))
        if not ok:
            print(f"Failed to encode {new_filename}")
            return None
        try:
            # tofile handles non-ASCII paths that cv2.imwrite cannot on Windows.
            encoded.tofile(new_filename)
        except OSError as err:
            print(f"Failed to save {new_filename}: {err}")
            return None
        print(f"Saved as {new_filename}")
        return new_filename
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: start Qt, show the annotation window, run the event loop.
    import sys

    qt_app = QtWidgets.QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
|
||
|
|
||
|
|