# SAM/salt/segment1.py - interactive Segment Anything (SAM) point-prompt GUI
# (source listing metadata: 300 lines, 14 KiB, Python)

import cv2
import os
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from segment_anything import sam_model_registry, SamPredictor
class Ui_MainWindow(object):
    """Fixed-geometry widget layout for the segmentation window.

    A column of command buttons runs down the left edge; two 450x450 white
    panels on the right show the original image and the predicted mask.
    """

    def setupUi(self, MainWindow):
        """Create every child widget of *MainWindow* and apply the UI strings."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1170, 486)

        central = QtWidgets.QWidget(MainWindow)
        self.centralwidget = central
        central.setObjectName("centralwidget")

        def place(widget_cls, name, rect):
            # Child of the central widget, stored on self under its object name.
            widget = widget_cls(central)
            widget.setGeometry(QtCore.QRect(*rect))
            widget.setObjectName(name)
            setattr(self, name, widget)
            return widget

        # Construction order matches the original exactly: Qt's default tab
        # order follows the order in which widgets are constructed.
        button_rows = (
            ("pushButton_w", 90),
            ("pushButton_a", 160),
            ("pushButton_d", 230),
            ("pushButton_s", 300),
            ("pushButton_q", 370),
        )
        for name, top in button_rows:
            place(QtWidgets.QPushButton, name, (10, top, 151, 51))

        for name, left in (("label_orign", 180), ("label_pre", 660)):
            panel = place(QtWidgets.QLabel, name, (left, 20, 450, 450))
            panel.setStyleSheet("background-color: rgb(255, 255, 255);")

        place(QtWidgets.QPushButton, "pushButton_opimg", (10, 20, 151, 51))
        MainWindow.setCentralWidget(central)

        menu = QtWidgets.QMenuBar(MainWindow)
        self.menubar = menu
        menu.setGeometry(QtCore.QRect(0, 0, 1170, 26))
        menu.setObjectName("menubar")
        MainWindow.setMenuBar(menu)

        status = QtWidgets.QStatusBar(MainWindow)
        self.statusbar = status
        status.setObjectName("statusbar")
        MainWindow.setStatusBar(status)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign every user-visible caption through Qt's translation hook."""
        # Literal (context, text) pairs are kept so translation tooling can find them.
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))

        self.pushButton_w.setText(_translate("MainWindow", "w"))
        self.pushButton_a.setText(_translate("MainWindow", "a"))
        self.pushButton_d.setText(_translate("MainWindow", "d"))
        self.pushButton_s.setText(_translate("MainWindow", "s"))
        self.pushButton_q.setText(_translate("MainWindow", "q"))

        orig_caption = _translate(
            "MainWindow",
            "<html><head/><body><p align=\"center\">Original Image</p></body></html>")
        self.label_orign.setText(orig_caption)
        pre_caption = _translate(
            "MainWindow",
            "<html><head/><body><p align=\"center\">Predicted Image</p></body></html>")
        self.label_pre.setText(pre_caption)

        self.pushButton_opimg.setText(_translate("MainWindow", "Open Image"))
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    """Interactive Segment Anything (SAM) segmentation window.

    Workflow:

    * "Open Image" loads a file and makes it the current image.
    * Left-click on the original image adds a foreground point (label 1);
      right-click adds a background point (label 0).
    * ``w`` runs SAM with the collected points (``predict_and_interact``).
    * ``a`` / ``d`` show the previous / next of the predicted masks.
    * ``q`` removes the most recently added point.
    * ``s`` saves the current image with the selected mask overlaid.

    The letters are both the button captions and keyboard shortcuts. Each
    action is a plain Qt slot; none of them runs its own polling loop.
    """

    def __init__(self, *,
                 checkpoint=r"D:\Program Files\Pycharm items\segment-anything-model\weights\vit_b.pth",
                 model_type="vit_b", device="cuda"):
        """Build the UI, wire the buttons, and load the SAM model.

        Args:
            checkpoint: Path to the SAM weights file.
            model_type: Key into ``sam_model_registry`` matching the checkpoint.
            device: Device passed to ``self.sam.to``.
        """
        super().__init__()
        self.setupUi(self)

        self.image_files = []       # loaded images, RGB (converted from OpenCV BGR)
        self.current_index = 0      # index into image_files of the displayed image
        self.input_point = []       # prompt points as [x, y] in image pixel coordinates
        self.input_label = []       # 1 = foreground, 0 = background; parallel to input_point
        self.input_stop = False     # when True, clicks and predictions are ignored
        self.interaction_count = 0  # number of predictions run so far

        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam.to(device=device)
        self.predictor = SamPredictor(self.sam)
        # set_image() runs the image encoder; remember which image it currently
        # holds so repeated predictions on the same image reuse the embedding.
        self._embedded_index = None

        # Latest prediction: candidate masks, their scores, selection, overlay colour.
        self._masks = None
        self._scores = None
        self._mask_idx = 0
        self._mask_color = (0, 255, 0)

        # Image <-> label coordinate mapping; refreshed by display_original_image.
        self.scale_x = 1.0
        self.scale_y = 1.0
        self._offset_x = 0
        self._offset_y = 0
        self.label_pre_width = self.label_pre.width()
        self.label_pre_height = self.label_pre.height()

        # Centre the fitted pixmaps so the letterbox margins are symmetric and
        # can be subtracted when mapping clicks back to image pixels.
        self.label_orign.setAlignment(QtCore.Qt.AlignCenter)
        self.label_pre.setAlignment(QtCore.Qt.AlignCenter)

        for button, slot, key in (
                (self.pushButton_opimg, self.open_image, None),
                (self.pushButton_w, self.predict_and_interact, "W"),
                (self.pushButton_a, self.previous_mask, "A"),
                (self.pushButton_d, self.next_mask, "D"),
                (self.pushButton_q, self.remove_last_point, "Q"),
                (self.pushButton_s, self.save_current_mask, "S"),
        ):
            button.clicked.connect(slot)
            if key is not None:
                button.setShortcut(QtGui.QKeySequence(key))

        self.set_mouse_click_event()

    # ------------------------------------------------------------------ images

    def open_image(self):
        """Ask for an image file, load it as RGB, and make it the current image."""
        options = QtWidgets.QFileDialog.Options()
        filename, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open Image File", "",
            "Image Files (*.png *.jpg *.jpeg *.JPG *.JPEG *.PNG *.tiff)",
            options=options)
        if not filename:
            return
        # cv2.imread is known to fail on non-ASCII paths on Windows; decoding
        # the raw bytes works for any path. Failure is reported as None.
        try:
            data = np.fromfile(filename, dtype=np.uint8)
        except OSError:
            data = None
        image = None
        if data is not None and data.size:
            image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if image is None:
            self.statusbar.showMessage(f"Could not read image: {filename}")
            return
        self.image_files.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        self._show_image(len(self.image_files) - 1)

    def _show_image(self, index):
        """Switch to image *index*, discarding prompts and masks of the previous one."""
        self.current_index = index
        self.input_point = []
        self.input_label = []
        self._masks = None
        self._scores = None
        self._mask_idx = 0
        self.label_pre.clear()
        self.display_original_image()

    def previous_image(self):
        """Step back to the previous loaded image (prompts are reset)."""
        if self.current_index > 0:
            self._show_image(self.current_index - 1)

    def next_image(self):
        """Step forward to the next loaded image (prompts are reset)."""
        if self.current_index < len(self.image_files) - 1:
            self._show_image(self.current_index + 1)

    # ----------------------------------------------------------------- drawing

    @staticmethod
    def _fit_pixmap(image, label):
        """Convert an RGB uint8 array to a QPixmap scaled to fit *label*, aspect kept."""
        image = np.ascontiguousarray(image)  # QImage needs a dense row layout
        height, width = image.shape[:2]
        q_img = QtGui.QImage(image.data, width, height, 3 * width, QtGui.QImage.Format_RGB888)
        # fromImage copies the pixels, so the numpy buffer need not outlive this call.
        return QtGui.QPixmap.fromImage(q_img).scaled(label.size(), QtCore.Qt.KeepAspectRatio)

    def _draw_points(self, pixmap, scale_x, scale_y):
        """Paint the prompt points onto *pixmap* (image pixels divided by the scales)."""
        pen_foreground = QtGui.QPen(QtGui.QColor(0, 255, 0))  # green = foreground
        pen_foreground.setWidth(5)
        pen_background = QtGui.QPen(QtGui.QColor(255, 0, 0))  # red = background
        pen_background.setWidth(5)
        painter = QtGui.QPainter(pixmap)
        try:
            # Foreground first, then background on top (same order as before).
            for wanted, pen in ((1, pen_foreground), (0, pen_background)):
                painter.setPen(pen)
                for (x, y), label in zip(self.input_point, self.input_label):
                    if label == wanted:
                        # QPointF: scaled coordinates are floats; QPoint rejects them.
                        painter.drawPoint(QtCore.QPointF(x / scale_x, y / scale_y))
        finally:
            painter.end()

    def display_original_image(self):
        """Draw the current image into ``label_orign`` with prompt points on top.

        Also refreshes ``scale_x``/``scale_y`` (image pixels per displayed pixel)
        and the letterbox offsets used to map clicks back to image pixels.
        """
        if not self.image_files:
            return
        image = self.image_files[self.current_index]
        height, width = image.shape[:2]
        pixmap = self._fit_pixmap(image, self.label_orign)
        # Update the mapping before drawing so the points use the current geometry.
        self.scale_x = width / max(pixmap.width(), 1)
        self.scale_y = height / max(pixmap.height(), 1)
        self._offset_x = (self.label_orign.width() - pixmap.width()) // 2
        self._offset_y = (self.label_orign.height() - pixmap.height()) // 2
        # Paint before setPixmap so the label shows the annotated copy.
        self._draw_points(pixmap, self.scale_x, self.scale_y)
        self.label_orign.setPixmap(pixmap)

    def display_prediction_image(self, image):
        """Show an RGB *image* in ``label_pre`` with the prompt points drawn on top.

        *image* is assumed to have the current image's size so the points line up.
        """
        pixmap = self._fit_pixmap(image, self.label_pre)
        height, width = image.shape[:2]
        self._draw_points(pixmap, width / max(pixmap.width(), 1), height / max(pixmap.height(), 1))
        self.label_pre.setPixmap(pixmap)

    def convert_to_label_coords(self, point):
        """Map an image-pixel ``[x, y]`` point to ``label_orign`` widget coordinates."""
        x = point[0] / self.scale_x + self._offset_x
        y = point[1] / self.scale_y + self._offset_y
        return x, y

    # ------------------------------------------------------------------- input

    def set_mouse_click_event(self):
        """Route mouse presses on the original-image label to ``mouse_click``."""
        self.label_orign.mousePressEvent = self.mouse_click

    def mouse_click(self, event):
        """Record a prompt point: left button = foreground (1), right = background (0).

        Clicks are ignored while ``input_stop`` is set, before any image is
        loaded, for other mouse buttons, and outside the displayed image.
        """
        if self.input_stop or not self.image_files:
            return
        button = event.button()
        if button == QtCore.Qt.LeftButton:
            label = 1
        elif button == QtCore.Qt.RightButton:
            label = 0
        else:
            # Storing a point without a label would desynchronise the two lists.
            return
        label_x = event.pos().x() - self._offset_x
        label_y = event.pos().y() - self._offset_y
        if label_x < 0 or label_y < 0:
            return  # click in the letterbox margin
        height, width = self.image_files[self.current_index].shape[:2]
        x = int(label_x * self.scale_x)
        y = int(label_y * self.scale_y)
        if x >= width or y >= height:
            return  # click in the letterbox margin
        self.input_point.append([x, y])
        self.input_label.append(label)
        self.display_original_image()

    def remove_last_point(self):
        """Drop the most recent prompt point and redraw both panels.

        Masks are not recomputed; press ``w`` to predict again.
        """
        if not self.input_point:
            return
        self.input_point.pop()
        self.input_label.pop()
        self.display_original_image()
        self._show_current_mask()

    # -------------------------------------------------------------- prediction

    def predict_and_interact(self):
        """Run SAM on the current image with the collected points.

        Shows the first of the returned candidate masks; use ``a``/``d`` to
        browse the others. Does nothing without an image or points, or while
        ``input_stop`` is set.
        """
        if self.input_stop or not self.image_files:
            return
        if not self.input_point:
            self.statusbar.showMessage("Add at least one point before predicting")
            return
        if self._embedded_index != self.current_index:
            self.predictor.set_image(self.image_files[self.current_index])
            self._embedded_index = self.current_index
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(self.input_point),
            point_labels=np.array(self.input_label),
            multimask_output=True,
        )
        self.interaction_count += 1
        self._masks = masks
        self._scores = scores
        self._mask_idx = 0
        # One colour per prediction, so browsing masks does not flicker.
        self._mask_color = tuple(np.random.randint(0, 256, 3).tolist())
        self._show_current_mask()

    def _show_current_mask(self):
        """Render the selected mask over the current image into ``label_pre``."""
        if self._masks is None:
            return
        num_masks = len(self._masks)
        overlay = self.apply_color_mask(self.image_files[self.current_index].copy(),
                                        self._masks[self._mask_idx], self._mask_color)
        mask_info = (f'Total: {num_masks} | Current: {self._mask_idx} | '
                     f'Score: {self._scores[self._mask_idx]:.2f} | '
                     'w Predict | d Next | a Previous | q Remove Last | s Save')
        cv2.putText(overlay, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255),
                    2, cv2.LINE_AA)
        self.display_prediction_image(overlay)

    def previous_mask(self):
        """Select the previous candidate mask, wrapping round to the last one."""
        if self._masks is None:
            return
        self._mask_idx = (self._mask_idx - 1) % len(self._masks)
        self._show_current_mask()

    def next_mask(self):
        """Select the next candidate mask, wrapping round to the first one."""
        if self._masks is None:
            return
        self._mask_idx = (self._mask_idx + 1) % len(self._masks)
        self._show_current_mask()

    def save_current_mask(self):
        """Save the selected mask overlay as ``image_<index>_masked.png`` in the CWD."""
        if self._masks is None:
            self.statusbar.showMessage("Nothing to save yet - press w to predict first")
            return
        saved = self.save_masked_image(self.image_files[self.current_index],
                                       self._masks[self._mask_idx],
                                       f"image_{self.current_index}.png")
        if saved:
            self.statusbar.showMessage(f"Saved as {saved}")

    # ----------------------------------------------------------------- helpers

    def apply_color_mask(self, image, mask, color=(0, 255, 0), color_dark=0.5):
        """Blend *color* into *image* wherever ``mask == 1``, in place.

        Args:
            image: H x W x 3 array; modified in place and also returned.
            mask: H x W array; pixels equal to 1 (or True) are tinted.
            color: Per-channel colour in the image's channel order.
            color_dark: Weight of *color* in the blend (0 keeps the pixel, 1 paints solid).

        Returns:
            The same *image* array.
        """
        for c in range(3):
            image[:, :, c] = np.where(mask == 1,
                                      image[:, :, c] * (1 - color_dark) + color_dark * color[c],
                                      image[:, :, c])
        return image

    def save_masked_image(self, image, mask, filename):
        """Write RGB *image* with *mask* overlaid to ``<dir>/<stem>_masked.png``.

        *image* itself is not modified. Returns the path written, or None if
        OpenCV reports that the write failed.
        """
        output_dir, base = os.path.split(filename)
        stem = os.path.splitext(base)[0]  # safe for names without an extension
        new_filename = os.path.join(output_dir, stem + '_masked.png')
        # apply_color_mask works in place; copy so overlays never accumulate.
        masked_image = self.apply_color_mask(image.copy(), mask)
        if not cv2.imwrite(new_filename, cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)):
            print(f"Failed to save {new_filename}")
            return None
        print(f"Saved as {new_filename}")
        return new_filename
if __name__ == "__main__":
    # Script entry point: start Qt, show the window, exit with Qt's status code.
    import sys

    qt_app = QtWidgets.QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)