# ai-station-code/work_util/sam_deal.py
#
# 400 lines
# 15 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import os
import numpy as np
import json
import shutil
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import random
import traceback
import base64
import io
from work_util.logger import logger
import pickle
from PIL import Image
def get_display_image_ori(loaded_data):
    """Render the RGB working image with prompt points and the selected mask.

    Args:
        loaded_data: Session state dict. Reads ``image_rgb``, ``input_point``,
            ``input_label``, ``selected_mask``, ``class_index``,
            ``class_names`` and ``class_colors``.

    Returns:
        ``{"status": True, "reason": <annotated copy of the image>}`` on
        success, or ``{"status": False, "reason": <message>}`` when no image
        is loaded or drawing raised an exception.
    """
    if loaded_data['image_rgb'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "无图片加载"}

    # Draw on a copy so the stored image is never modified.
    display_image = loaded_data['image_rgb'].copy()
    try:
        for point, label in zip(loaded_data['input_point'], loaded_data['input_label']):
            # Label 1 gets one marker color, every other label another.
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(display_image, tuple(point), 5, color, -1)

        # class_index == -1 is the "no class selected" sentinel written by
        # remove_class; without this guard it would silently address the
        # last entry of class_names.
        if loaded_data['selected_mask'] is not None and loaded_data['class_index'] != -1:
            class_name = loaded_data['class_names'][loaded_data['class_index']]
            color = loaded_data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, loaded_data['selected_mask'], color)

        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
        return {"status": True, "reason": display_image}
    except Exception as e:
        # Failures are reported through the status dict; log them as errors.
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
def get_all_classes(data):
    """Return the class names and a shallow copy of the class-color mapping.

    Args:
        data: Session state dict providing ``class_names`` and
            ``class_colors``.

    Returns:
        ``{"classes": <the class_names object>, "colors": <new dict>}``.
    """
    color_by_name = dict(data['class_colors'].items())
    return {"classes": data['class_names'], "colors": color_by_name}
def get_classes(data):
    """Alias of ``get_all_classes``: return its result for ``data`` unchanged."""
    class_info = get_all_classes(data)
    return class_info
def reset_annotation(loaded_data):
    """Clear the prompt points, selected mask, logit input and masks in place.

    Args:
        loaded_data: Session state dict; it is mutated.

    Returns:
        The same ``loaded_data`` object.
    """
    loaded_data.update({
        'input_point': [],
        'input_label': [],
        'selected_mask': None,
        'logit_input': None,
        'masks': {},
    })
    logger.info("已重置标注状态")
    return loaded_data
def remove_class(data, api, class_name):
    """Remove a class from both the session state and the API state.

    Each container is cleaned independently, so a class that is only
    partially registered (for example added to ``data`` but never selected,
    hence absent from ``api['class_annotations']``) is still removed and no
    ``KeyError``/``ValueError`` is raised for missing entries.

    Args:
        data: Session state dict. Uses ``class_names``, ``class_index``,
            ``class_colors`` and ``masks``.
        api: API state dict. Uses ``class_annotations``, ``class_colors`` and
            ``current_class``.
        class_name: Name of the class to remove.

    Returns:
        The ``(data, api)`` pair, both mutated in place.
    """
    api['class_annotations'].pop(class_name, None)
    api['class_colors'].pop(class_name, None)
    if api.get('current_class') == class_name:
        api['current_class'] = None

    names = data['class_names']
    if class_name in names:
        removed_index = names.index(class_name)
        current_index = data['class_index']
        names.remove(class_name)
        if current_index == removed_index:
            # The selected class itself was removed: -1 marks "no selection".
            data['class_index'] = -1
            api['current_class'] = None
        elif current_index > removed_index:
            # Entries after the removed one shift left by one; follow the
            # selected class so the index keeps naming the same class.
            data['class_index'] = current_index - 1

    data['class_colors'].pop(class_name, None)
    data['masks'].pop(class_name, None)
    return data, api
def reset_annotation_all(data, api):
    """Reset the session annotation state and every per-class annotation.

    For each entry of ``api['class_annotations']`` the points, point types,
    masks and scores are emptied, ``selected_mask_index`` is set to -1 and
    any ``selected_mask`` key is dropped. Other keys are left untouched.

    Returns:
        The ``(data, api)`` pair after the reset.
    """
    data = reset_annotation(data)
    for annotation in api['class_annotations'].values():
        annotation.update(
            points=[],
            point_types=[],
            masks=[],
            scores=[],
            selected_mask_index=-1,
        )
        annotation.pop('selected_mask', None)
    return data, api
def reset_class_annotations(data, api):
    """Rebuild ``api['class_annotations']``: keep the classes, drop masks and points.

    A fresh empty annotation record is created for every class name reported
    by ``get_classes(data)``; the previous mapping is replaced wholesale.

    Returns:
        The ``(data, api)`` pair.
    """
    class_list = get_classes(data).get('classes', [])
    api['class_annotations'] = {
        name: {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }
        for name in class_list
    }
    logger.info("已重置class_annotations保留类别但清除掩码和点")
    return data, api
def add_class(data, class_name, color=None):
    """Register a new class name with a display color.

    Args:
        data: Session state dict; ``class_names`` and ``class_colors`` are
            mutated.
        class_name: Name to add; rejected if already present.
        color: Three components unpacked as ``(r, g, b)``. When ``None`` each
            component is drawn at random from ``[100, 256)``. The color is
            stored with its components reversed, as ``(b, g, r)``.

    Returns:
        ``{'status': True, 'reason': data}`` on success, otherwise
        ``{'status': False, 'reason': <message>}``.
    """
    if class_name in data['class_names']:
        message = f"类别 '{class_name}' 已存在"
        logger.info(message)
        return {'status': False, 'reason': message}

    data['class_names'].append(class_name)
    if color is None:
        color = tuple(np.random.randint(100, 256, 3).tolist())
    red, green, blue = (int(channel) for channel in color)
    data['class_colors'][class_name] = (blue, green, red)
    logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
    return {'status': True, 'reason': data}
def set_current_class(data, api, class_index, color=None):
    """Select the class at ``class_index`` as the current annotation class.

    An empty annotation record is created for the class if the API state has
    none yet, and the class color in ``api['class_colors']`` is synchronized.

    Color resolution order:
      1. the color stored in ``data['class_colors']`` (always wins, as before);
      2. the ``color`` argument, when the class has no stored color;
      3. a predefined palette entry, only if the API state has no color yet.

    Args:
        data: Session state dict. Reads ``class_names`` and ``class_colors``.
        api: API state dict. Mutates ``current_class``, ``class_annotations``
            and ``class_colors``.
        class_index: Position of the class in ``data['class_names']``.
            Negative or out-of-range values select nothing.
        color: Optional fallback color for a class without a stored color.

    Returns:
        ``(class_name, api)`` on success, ``(None, api)`` when the index is
        invalid (``api`` is then left unchanged).
    """
    class_names = data['class_names']
    # Reject negative indices explicitly: -1 is this module's "no class"
    # sentinel and would otherwise wrap around to the last class.
    if not 0 <= class_index < len(class_names):
        return None, api

    class_name = class_names[class_index]
    api['current_class'] = class_name
    if class_name not in api['class_annotations']:
        api['class_annotations'][class_name] = {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }

    # .get() keeps a class without a stored color from raising KeyError, so
    # the fallbacks below are actually reachable.
    resolved_color = data['class_colors'].get(class_name)
    if not resolved_color and color is not None:
        resolved_color = color

    if resolved_color:
        api['class_colors'][class_name] = resolved_color
    elif class_name not in api['class_colors']:
        predefined_colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
        ]
        color_index = len(api['class_colors']) % len(predefined_colors)
        api['class_colors'][class_name] = predefined_colors[color_index]
    return class_name, api
def add_point(data, x, y, is_foreground=True):
    """Append a prompt point and its label to the session state.

    Args:
        data: Session state dict; ``input_point`` and ``input_label`` are
            appended to.
        x: Point x coordinate.
        y: Point y coordinate.
        is_foreground: Truthy stores label 1, falsy stores label 0.

    Returns:
        The same ``data`` object.
    """
    label = 1 if is_foreground else 0
    data['input_point'].append([x, y])
    data['input_label'].append(label)
    logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
    return data
def load_model(path):
    """Load the pickled ``(loaded_data, api_config)`` pair from a directory.

    Args:
        path: Directory containing ``model_params.pickle``.

    Returns:
        The ``(loaded_data, api_config)`` tuple stored in the file.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    unpickling; ``path`` must only ever point at files this application
    wrote itself.
    """
    # Load the configuration contents.
    pickle_file = os.path.join(path, 'model_params.pickle')
    with open(pickle_file, 'rb') as handle:
        payload = pickle.load(handle)
    loaded_data, api_config = payload
    return loaded_data, api_config
def save_model(loaded_data, api_config, path):
    """Pickle ``(loaded_data, api_config)`` into ``<path>/model_params.pickle``.

    Args:
        loaded_data: Session state to persist.
        api_config: API state to persist.
        path: Existing directory that receives the pickle file.
    """
    pickle_file = os.path.join(path, 'model_params.pickle')
    with open(pickle_file, 'wb') as handle:
        pickle.dump((loaded_data, api_config), handle)
def sam_apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into a copy of ``image`` wherever ``mask == 1``.

    Args:
        image: Array indexed as ``image[:, :, c]`` for channels 0..2.
        mask: Array compared against 1 to choose the pixels to tint.
        color: Sequence whose first three entries are the per-channel tint.
        alpha: Weight of the tint; the image keeps weight ``1 - alpha``.

    Returns:
        A new array; ``image`` itself is not modified.
    """
    tinted = image.copy()
    selected = mask == 1
    keep = 1 - alpha
    for channel in range(3):
        plane = image[:, :, channel]
        tinted[:, :, channel] = np.where(
            selected,
            plane * keep + alpha * color[channel],
            plane,
        )
    return tinted
def apply_mask_overlay(image, mask, color, alpha=0.5):
    """Overlay ``color`` on ``image`` where ``mask > 0``.

    A zero image of the same shape is filled with ``color`` at the masked
    pixels and combined through ``cv2.addWeighted`` with weight 1 for
    ``image`` and weight ``alpha`` for the colored layer.

    Returns:
        The array produced by ``cv2.addWeighted``.
    """
    overlay = np.zeros_like(image)
    inside = mask > 0
    overlay[inside] = color
    blended = cv2.addWeighted(image, 1, overlay, alpha, 0)
    return blended
def load_tmp_image(path):
    """Read the file at ``path`` and return its content as a Base64 string."""
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    # Base64-encode the binary content and decode the result to text.
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def refresh_image(data, api, path):
    """Render the current annotated image and write it under ``path``.

    Args:
        data: Session state dict, forwarded to ``get_image_display``.
        api: API state dict, forwarded to ``get_image_display``.
        path: Base directory; the image is written to
            ``<path>/temp/output_image.jpg``.

    Returns:
        ``{"status": True, "reason": <written file path>}`` on success, or
        ``{"status": False, "reason": <message>}`` when rendering failed.
    """
    result = get_image_display(data, api)
    if not result['status']:
        return {
            "status": False,
            "reason": result['reason'],
        }

    img = result['reason']
    # Swap the first and third channels before handing the array to PIL.
    display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(display_img_rgb)
    tmp_path = os.path.join(path, 'temp/output_image.jpg')
    # The temp directory may not exist yet (e.g. a fresh workspace); create
    # it so the save below cannot fail on a missing directory.
    os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
    pil_img.save(tmp_path)
    return {
        "status": True,
        "reason": tmp_path,
    }
def set_class_color(data, api, index, color):
    """Assign ``color`` to the class at ``index`` in both state dicts.

    Args:
        data: Session state dict; reads ``class_names``, writes ``class_colors``.
        api: API state dict; writes ``class_colors``.
        index: Position of the class in ``data['class_names']``.
        color: Value stored as-is for the class.

    Returns:
        The ``(data, api)`` pair, both mutated in place.
    """
    class_name = data['class_names'][index]
    for palette in (api['class_colors'], data['class_colors']):
        palette[class_name] = color
    return data, api
def add_to_class(api, class_name):
    """Commit a class's selected mask as its final mask and clear its prompts.

    Args:
        api: API state dict. Reads ``current_class`` and mutates the matching
            entry of ``class_annotations``.
        class_name: Class to commit; ``None`` means ``api['current_class']``.

    Returns:
        ``{'status': True, 'reason': api}`` on success. ``status`` is False
        with a message when the class has no annotation record or no
        selected mask.
    """
    target = api['current_class'] if class_name is None else class_name
    if not target or target not in api['class_annotations']:
        return {'status': False, 'reason': "该分类没有标注点"}

    annotation = api['class_annotations'][target]
    if annotation.get('selected_mask') is None:
        return {'status': False, 'reason': "该分类没有进行预测"}

    # Store an independent copy, then wipe the prompt and candidate state.
    annotation['final_mask'] = annotation['selected_mask'].copy()
    annotation.update(
        points=[],
        point_types=[],
        masks=[],
        scores=[],
        selected_mask_index=-1,
    )
    return {'status': True, 'reason': api}
def get_image_display(data, api):
    """Compose the display image: prompts, live mask, class masks and points.

    Drawing order:
      1. session prompt points from ``data['input_point']``;
      2. the session's selected mask, when a class is selected
         (``class_index != -1``);
      3. one overlay per class in ``api['class_annotations']``, using its
         ``final_mask`` if set, otherwise its ``selected_mask``;
      4. the annotation points of ``api['current_class']``.

    Args:
        data: Session state dict. Reads ``image``, ``input_point``,
            ``input_label``, ``class_index``, ``selected_mask``,
            ``class_names`` and ``class_colors``.
        api: API state dict. Reads ``class_annotations``, ``class_colors``
            and ``current_class``.

    Returns:
        ``{"status": True, "reason": <image array>}`` on success, or
        ``{"status": False, "reason": <message>}`` when no image is loaded,
        the image is not a non-empty ndarray, or drawing raised.
    """
    if data['image'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "获取图像:没有图像加载"}

    # Draw on a copy so the stored image is never modified.
    display_image = data['image'].copy()
    try:
        for point, label in zip(data['input_point'], data['input_label']):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(display_image, tuple(point), 5, color, -1)

        # class_index == -1 means "no class selected".
        if data['class_index'] != -1 and data['selected_mask'] is not None:
            class_name = data['class_names'][data['class_index']]
            color = data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, data['selected_mask'], color)
        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")

        img = display_image
        if not isinstance(img, np.ndarray) or img.size == 0:
            message = (
                "get_image_display: Invalid image array, shape: "
                f"{img.shape if isinstance(img, np.ndarray) else 'None'}"
            )
            logger.info(message)
            return {"status": False, "reason": message}

        # Overlay one mask per class, preferring the committed final_mask
        # over the not-yet-committed selected_mask.
        for class_name, class_data in api['class_annotations'].items():
            mask = class_data.get('final_mask')
            if mask is None:
                mask = class_data.get('selected_mask')
            if mask is None:
                continue
            if isinstance(mask, list):
                # Masks may arrive as nested lists; convert them to arrays.
                mask = np.array(mask, dtype=np.uint8)
            color = api['class_colors'].get(class_name, (0, 255, 0))
            logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
            img = apply_mask_overlay(img, mask, color, alpha=0.5)

        current_class = api['current_class']
        if current_class and current_class in api['class_annotations']:
            class_data = api['class_annotations'][current_class]
            for (x, y), is_fg in zip(class_data['points'], class_data['point_types']):
                color = (0, 255, 0) if is_fg else (0, 0, 255)
                # Routed through the module logger instead of print().
                logger.info(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
                cv2.circle(img, (int(x), int(y)), 5, color, -1)

        logger.info(f"get_image_display: Returning image with shape {img.shape}")
        return {"status": True, "reason": img}
    except Exception as e:
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
def add_annotation_point(api, x, y, is_foreground=True):
    """Append a point to the current class's annotation record.

    Args:
        api: API state dict; reads ``current_class`` and mutates the matching
            entry of ``class_annotations``.
        x: Point x coordinate.
        y: Point y coordinate.
        is_foreground: Stored as-is in ``point_types``.

    Returns:
        On success a dict with ``status`` True, ``reason`` holding the
        class's ``points`` and ``types`` lists, and ``api``. Otherwise
        ``{"status": False, "reason": <message>}``.
    """
    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return {"status": False, "reason": "请选择或新建分类"}

    annotation = api['class_annotations'][current]
    annotation['points'].append((x, y))
    annotation['point_types'].append(is_foreground)
    summary = {
        'points': annotation['points'],
        'types': annotation['point_types'],
    }
    return {'status': True, 'reason': summary, 'api': api}
def delete_last_point(api):
    """Remove the most recently added point of the current class.

    Returns:
        ``{"status": True, "reason": api}`` after removing the last point and
        its type, or ``{"status": False, "reason": <message>}`` when there is
        no current class record or it has no points.
    """
    nothing_to_delete = {"status": False, "reason": "当前类没有点需要删除"}

    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return nothing_to_delete

    annotation = api['class_annotations'][current]
    if not annotation['points']:
        return nothing_to_delete

    # Keep the two parallel lists in step.
    annotation['points'].pop()
    annotation['point_types'].pop()
    return {"status": True, "reason": api}
def reset_current_class_points(api):
    """Clear all points and point types of the current class.

    Returns:
        ``{"status": True, "reason": api}`` on success, or
        ``{"status": False, "reason": <message>}`` when there is no current
        class record.
    """
    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return {"status": False, "reason": "当前类没有点需要删除"}

    annotation = api['class_annotations'][current]
    annotation['points'], annotation['point_types'] = [], []
    return {"status": True, "reason": api}
def predict_mask(data, sam_predictor):
    """Run the predictor on the session's prompt points and store the results.

    Args:
        data: Session state dict. Reads ``image_rgb``, ``input_point``,
            ``input_label`` and ``logit_input``; writes ``masks_pred``,
            ``scores``, ``logits``, ``selected_mask`` and ``logit_input``.
        sam_predictor: Object exposing ``set_image(image)`` and
            ``predict(point_coords=..., point_labels=..., mask_input=...,
            multimask_output=...)`` returning ``(masks, scores, logits)``.

    Returns:
        ``{"status": True, "reason": {...}}`` with the masks as nested lists,
        the scores, the index of the highest-scoring mask and ``data``; or
        ``{"status": False, "reason": <message>}`` on any failure.
    """
    if data['image_rgb'] is None:
        logger.info("predict_mask: No image loaded")
        return {"status": False, "reason": "预测掩码:没有图像加载"}
    if len(data['input_point']) == 0:
        logger.info("predict_mask: No points added")
        return {"status": False, "reason": "预测掩码:没有进行点标注"}

    try:
        sam_predictor.set_image(data['image_rgb'])
    except Exception as exc:
        failure = f"predict_mask: Error setting image: {str(exc)}"
        logger.error(failure)
        return {"status": False, "reason": failure}

    point_coords = np.array(data['input_point'])
    point_labels = np.array(data['input_label'])
    try:
        # Feed the previous best logits back in, with a leading axis added.
        previous_logits = data['logit_input']
        if previous_logits is None:
            mask_input = None
        else:
            mask_input = previous_logits[None, :, :]
        masks_pred, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            mask_input=mask_input,
            multimask_output=True,
        )
    except Exception as exc:
        failure = f"predict_mask: Error during prediction: {str(exc)}"
        logger.error(failure)
        traceback.print_exc()
        return {"status": False, "reason": failure}

    best_mask_idx = np.argmax(scores)
    data.update({
        'masks_pred': masks_pred,
        'scores': scores,
        'logits': logits,
        'selected_mask': masks_pred[best_mask_idx],
        'logit_input': logits[best_mask_idx, :, :],
    })
    logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")

    payload = {
        "masks": [mask.tolist() for mask in masks_pred],
        "scores": scores.tolist(),
        "selected_index": int(best_mask_idx),
        "data": data,
    }
    return {"status": True, "reason": payload}
def get_image_info(data):
    """Summarize the currently loaded image.

    Args:
        data: Session state dict providing ``filename``, ``current_index``,
            ``image_files`` and ``image``.

    Returns:
        Dict with ``filename``, ``index``, ``total`` (number of image files),
        ``width`` (``image.shape[1]``) and ``height`` (``image.shape[0]``).
    """
    info = {
        "filename": data['filename'],
        "index": data['current_index'],
        "total": len(data['image_files']),
    }
    dims = data['image'].shape
    info["width"] = dims[1]
    info["height"] = dims[0]
    return info