ai-station-code/work_util/sam_deal.py

398 lines
15 KiB
Python
Raw Normal View History

2025-06-04 17:04:02 +08:00
import cv2
import os
import numpy as np
import json
import shutil
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import random
import traceback
import base64
import io
from work_util.logger import logger
import pickle
from PIL import Image
def get_display_image_ori(loaded_data):
    """Render prompt points and the selected mask onto a copy of the image.

    Args:
        loaded_data: State dict. Reads ``image_rgb``, ``input_point``,
            ``input_label``, ``selected_mask``, ``class_names``,
            ``class_index`` and ``class_colors``.

    Returns:
        ``{"status": True, "reason": <annotated image>}`` on success, or
        ``{"status": False, "reason": <message>}`` when no image is loaded
        or drawing fails. The stored image itself is never modified.
    """
    source_image = loaded_data['image_rgb']
    if source_image is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "无图片加载"}

    display_image = source_image.copy()
    try:
        for point, label in zip(loaded_data['input_point'], loaded_data['input_label']):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # cv2.circle rejects non-integer centres, and the points are
            # stored exactly as the caller supplied them, so coerce here.
            center = (int(point[0]), int(point[1]))
            cv2.circle(display_image, center, 5, color, -1)

        if loaded_data['selected_mask'] is not None:
            class_name = loaded_data['class_names'][loaded_data['class_index']]
            color = loaded_data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, loaded_data['selected_mask'], color)

        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
        return {"status": True, "reason": display_image}
    except Exception as e:
        # Failures are reported at error level, matching get_image_display.
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
def get_all_classes(data):
    """Return the class names and a shallow copy of the class colour map.

    Args:
        data: State dict providing ``class_names`` and ``class_colors``.

    Returns:
        ``{"classes": <the class_names object>, "colors": <new dict>}``.
    """
    names = data['class_names']
    palette = dict(data['class_colors'].items())
    return {"classes": names, "colors": palette}
def get_classes(data):
    """Alias of ``get_all_classes``: return its result for ``data`` unchanged."""
    summary = get_all_classes(data)
    return summary
def reset_annotation(loaded_data):
    """Clear the interactive prompt state stored in ``loaded_data``.

    Empties ``input_point`` / ``input_label`` / ``masks`` and sets
    ``selected_mask`` / ``logit_input`` to ``None``.

    Returns:
        The same ``loaded_data`` object, mutated in place.
    """
    cleared_state = {
        'input_point': [],
        'input_label': [],
        'selected_mask': None,
        'logit_input': None,
        'masks': {},
    }
    loaded_data.update(cleared_state)
    logger.info("已重置标注状态")
    return loaded_data
def remove_class(data, api, class_name):
    """Remove a class and everything stored under its name.

    Deletes ``class_name`` from ``api['class_annotations']``,
    ``api['class_colors']`` (when that map exists), ``data['class_names']``,
    ``data['class_colors']`` and ``data['masks']``.

    ``data['class_index']`` keeps pointing at the class that was current
    before the removal; it is reset to 0 only when the current class itself
    is removed or the stored index was already out of range. Removing a
    name that is not in ``data['class_names']`` is a no-op for the name
    list instead of raising ``ValueError``.

    Args:
        data: State dict with ``class_names``, ``class_index``,
            ``class_colors`` and ``masks``.
        api: State dict with ``class_annotations`` (and optionally
            ``class_colors``).
        class_name: Name of the class to remove.

    Returns:
        ``(data, api)`` — the same objects, mutated in place.
    """
    api['class_annotations'].pop(class_name, None)
    api_colors = api.get('class_colors')
    if api_colors is not None:
        api_colors.pop(class_name, None)

    names = data['class_names']
    if class_name in names:
        index = data['class_index']
        # Remember the current class by name: positions shift once an
        # earlier entry is removed, so the raw index cannot be trusted.
        current_name = names[index] if 0 <= index < len(names) else None
        names.remove(class_name)
        if current_name is None or current_name == class_name:
            data['class_index'] = 0
        else:
            data['class_index'] = names.index(current_name)

    data['class_colors'].pop(class_name, None)
    data['masks'].pop(class_name, None)
    return data, api
def reset_annotation_all(data, api):
    """Reset the prompt state and every class's in-progress annotation.

    ``data`` is passed through ``reset_annotation``; each entry of
    ``api['class_annotations']`` gets empty ``points`` / ``point_types`` /
    ``masks`` / ``scores``, ``selected_mask_index`` of -1, and loses its
    ``selected_mask`` key if it has one. Other keys are left untouched.

    Returns:
        ``(data, api)`` after the reset.
    """
    data = reset_annotation(data)
    for entry in api['class_annotations'].values():
        entry.update(
            points=[],
            point_types=[],
            masks=[],
            scores=[],
            selected_mask_index=-1,
        )
        entry.pop('selected_mask', None)
    return data, api
def reset_class_annotations(data, api):
    """Reset class_annotations: keep the classes but drop masks and points.

    Replaces ``api['class_annotations']`` with a fresh dict holding one
    empty annotation record per class name reported by ``get_classes``.

    Returns:
        ``(data, api)``; ``data`` is not modified.
    """
    class_names = get_classes(data).get('classes', [])
    api['class_annotations'] = {
        name: {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }
        for name in class_names
    }
    logger.info("已重置class_annotations保留类别但清除掩码和点")
    return data, api
def add_class(data, class_name, color=None):
    """Register a new class with a display colour.

    Args:
        data: State dict with ``class_names`` (list) and ``class_colors``
            (dict).
        class_name: Name to add; rejected when already present.
        color: Optional 3-item sequence unpacked as ``(r, g, b)``. When
            omitted, three random integers in ``[100, 255]`` are drawn.

    Returns:
        ``{'status': False, 'reason': <message>}`` for a duplicate name,
        otherwise ``{'status': True, 'reason': data}``.

    Raises:
        ValueError / TypeError: If ``color`` is not three int-convertible
            items. Nothing in ``data`` has been changed when this happens.
    """
    if class_name in data['class_names']:
        logger.info(f"类别 '{class_name}' 已存在")
        return {'status': False, 'reason': f"类别 '{class_name}' 已存在"}

    if color is None:
        color = tuple(np.random.randint(100, 256, 3).tolist())

    # Validate and convert the colour BEFORE touching `data`, so a malformed
    # colour cannot leave a class name registered without a colour entry.
    r, g, b = [int(c) for c in color]

    data['class_names'].append(class_name)
    # The colour is stored with its components reversed: (b, g, r).
    data['class_colors'][class_name] = (b, g, r)
    logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
    return {'status': True, 'reason': data}
def set_current_class(data, api, class_index, color=None):
    """Make the class at ``class_index`` the current annotation class.

    Sets ``api['current_class']``, creates an empty annotation record for
    the class if it has none, and syncs its colour into
    ``api['class_colors']``.

    Colour precedence: the colour stored in ``data['class_colors']`` wins;
    if there is none, the ``color`` argument is used; if that is also
    empty and the class has no colour in ``api`` yet, one is picked from a
    fixed palette.

    Args:
        data: State dict with ``class_names`` and ``class_colors``.
        api: State dict with ``class_annotations``, ``class_colors`` and
            ``current_class``.
        class_index: Zero-based position in ``data['class_names']``.
        color: Fallback colour, used only when ``data`` has none stored.

    Returns:
        ``(class_name, api)`` on success, ``(None, api)`` when
        ``class_index`` is outside ``[0, len(class_names))``.
    """
    class_names = data['class_names']
    # Reject negative indices explicitly; they would otherwise pass an
    # upper-bound-only check and silently select a class from the end.
    if not 0 <= class_index < len(class_names):
        return None, api

    class_name = class_names[class_index]
    api['current_class'] = class_name
    if class_name not in api['class_annotations']:
        api['class_annotations'][class_name] = {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }

    # .get() keeps the palette fallback below reachable when `data` holds
    # no colour for this class (a plain lookup would raise KeyError).
    color = data['class_colors'].get(class_name) or color
    if color:
        api['class_colors'][class_name] = color
    elif class_name not in api['class_colors']:
        predefined_colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
        ]
        color_index = len(api['class_colors']) % len(predefined_colors)
        api['class_colors'][class_name] = predefined_colors[color_index]
    return class_name, api
def add_point(data, x, y, is_foreground=True):
    """Append a prompt point and its label to the prompt state.

    Args:
        data: State dict with ``input_point`` and ``input_label`` lists.
        x, y: Point coordinates, stored as ``[x, y]`` without conversion.
        is_foreground: Truthy stores label 1, falsy stores label 0.

    Returns:
        The same ``data`` object, mutated in place.
    """
    label = 1 if is_foreground else 0
    coordinates = [x, y]
    data['input_point'].append(coordinates)
    data['input_label'].append(label)
    logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
    return data
def load_model(path):
    """Load the pickled ``(loaded_data, api_config)`` pair from ``path``.

    Reads ``<path>/model_params.pickle``.

    NOTE(review): ``pickle.load`` can execute arbitrary code from the file;
    only use this on files the application wrote itself.

    Returns:
        ``(loaded_data, api_config)`` as stored in the file.
    """
    # Load the saved configuration.
    pickle_file = os.path.join(path, 'model_params.pickle')
    with open(pickle_file, 'rb') as handle:
        stored = pickle.load(handle)
    loaded_data, api_config = stored
    return loaded_data, api_config
def save_model(loaded_data, api_config, path):
    """Pickle ``(loaded_data, api_config)`` to ``<path>/model_params.pickle``.

    The file is overwritten if it exists. Returns ``None``.
    """
    target = os.path.join(path, 'model_params.pickle')
    with open(target, 'wb') as handle:
        pickle.dump((loaded_data, api_config), handle)
def sam_apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask == 1``.

    Only the first three channels are processed. Pixels where the mask
    equals 1 become ``image * (1 - alpha) + alpha * color``; all others
    keep their value. The blend is written back into a copy of ``image``,
    so the result has the input's dtype.

    Args:
        image: Array indexed as ``[row, col, channel]`` with at least three
            channels.
        mask: Array compared against 1; must broadcast against one channel
            plane of ``image``.
        color: Sequence whose first three items tint channels 0, 1 and 2.
        alpha: Blend weight of ``color``.

    Returns:
        A new array; ``image`` itself is not modified.
    """
    blended = image.copy()
    inside = mask == 1
    keep_weight = 1 - alpha
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep_weight + alpha * color[channel]
        blended[:, :, channel] = np.where(inside, tinted, plane)
    return blended
def apply_mask_overlay(image, mask, color, alpha=0.5):
    """Add a coloured overlay to ``image`` where ``mask > 0``.

    Builds an all-zero layer shaped like ``image``, paints ``color`` into it
    at the masked positions, and returns
    ``cv2.addWeighted(image, 1, layer, alpha, 0)``.
    """
    overlay_layer = np.zeros_like(image)
    painted = mask > 0
    overlay_layer[painted] = color
    return cv2.addWeighted(image, 1, overlay_layer, alpha, 0)
def load_tmp_image(path):
    """Return the contents of the file at ``path`` as a Base64 text string."""
    with open(path, "rb") as image_file:
        # Read the file as raw bytes; the Base64 encoding happens below.
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
def refresh_image(data, api, path):
    """Render the annotated image and write it to ``<path>/temp/output_image.jpg``.

    Args:
        data: State dict forwarded to ``get_image_display``.
        api: Annotation dict forwarded to ``get_image_display``.
        path: Base directory under which the ``temp`` folder lives.

    Returns:
        ``{"status": True, "reason": <saved file path>}`` on success, or
        ``{"status": False, "reason": <message>}`` when rendering failed.
    """
    result = get_image_display(data, api)
    if not result['status']:
        return {
            "status": False,
            "reason": result['reason'],
        }

    # The rendered image is converted BGR -> RGB before handing it to PIL.
    display_img_rgb = cv2.cvtColor(result['reason'], cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(display_img_rgb)

    tmp_path = os.path.join(path, 'temp/output_image.jpg')
    # Create the temp folder on demand; saving into a missing directory
    # would otherwise fail with FileNotFoundError.
    os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
    pil_img.save(tmp_path)
    return {
        "status": True,
        "reason": tmp_path,
    }
def set_class_color(data, api, index, color):
    """Store ``color`` for the class at ``index`` in both colour maps.

    The colour is written as given into ``api['class_colors']`` and then
    ``data['class_colors']``, keyed by ``data['class_names'][index]``.

    Returns:
        ``(data, api)`` — the same objects, mutated in place.
    """
    target_class = data['class_names'][index]
    for owner in (api, data):
        owner['class_colors'][target_class] = color
    return data, api
def add_to_class(api, class_name):
    """Commit a class's selected mask as its final mask.

    Args:
        api: Annotation dict with ``current_class`` and
            ``class_annotations``.
        class_name: Class to commit; ``None`` means ``api['current_class']``.

    Returns:
        ``{'status': True, 'reason': api}`` after copying ``selected_mask``
        into ``final_mask`` and clearing the class's points, point types,
        masks, scores and selected index; otherwise
        ``{'status': False, 'reason': <message>}`` when the class is
        unknown or has no selected mask.
    """
    target = api['current_class'] if class_name is None else class_name
    if not target or target not in api['class_annotations']:
        return {'status': False, 'reason': "该分类没有标注点"}

    record = api['class_annotations'][target]
    if record.get('selected_mask') is None:
        return {'status': False, 'reason': "该分类没有进行预测"}

    record['final_mask'] = record['selected_mask'].copy()
    record.update(
        points=[],
        point_types=[],
        masks=[],
        scores=[],
        selected_mask_index=-1,
    )
    return {'status': True, 'reason': api}
def get_image_display(data, api):
    """Render the image with prompt points, class masks and annotation points.

    Drawing order on a copy of ``data['image']``:
      1. every ``data['input_point']`` (colour chosen by its label),
      2. ``data['selected_mask']`` tinted with the current class colour,
      3. for each class in ``api['class_annotations']`` its ``final_mask``,
         or its ``selected_mask`` when no final mask is set,
      4. the points of ``api['current_class']``.

    Args:
        data: State dict. Reads ``image``, ``input_point``, ``input_label``,
            ``selected_mask``, ``class_names``, ``class_index`` and
            ``class_colors``.
        api: Annotation dict. Reads ``class_annotations``, ``class_colors``
            and ``current_class``.

    Returns:
        ``{"status": True, "reason": <annotated image>}`` on success, or
        ``{"status": False, "reason": <message>}`` when no image is loaded,
        the rendered array is invalid, or any drawing step raises.
    """
    if data['image'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "获取图像:没有图像加载"}

    display_image = data['image'].copy()
    try:
        for point, label in zip(data['input_point'], data['input_label']):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # cv2.circle rejects non-integer centres; coerce here exactly as
            # the annotation-point loop further down does.
            center = (int(point[0]), int(point[1]))
            cv2.circle(display_image, center, 5, color, -1)

        if data['selected_mask'] is not None:
            class_name = data['class_names'][data['class_index']]
            color = data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, data['selected_mask'], color)
        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")

        img = display_image
        if not isinstance(img, np.ndarray) or img.size == 0:
            logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}")
            return {"status": False, "reason": f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"}

        # Only apply the final_mask of the current image; a class without a
        # final mask falls back to its selected_mask.
        for class_name, class_data in api['class_annotations'].items():
            mask = class_data.get('final_mask')
            if mask is None:
                mask = class_data.get('selected_mask')
            if mask is None:
                continue
            color = api['class_colors'].get(class_name, (0, 255, 0))
            if isinstance(mask, list):
                mask = np.array(mask, dtype=np.uint8)
            logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
            img = apply_mask_overlay(img, mask, color, alpha=0.5)

        current_class = api['current_class']
        if current_class and current_class in api['class_annotations']:
            class_data = api['class_annotations'][current_class]
            # Indexing (rather than zip) keeps a points/point_types length
            # mismatch visible as an error instead of silently truncating.
            for i, (x, y) in enumerate(class_data['points']):
                is_fg = class_data['point_types'][i]
                color = (0, 255, 0) if is_fg else (0, 0, 255)
                logger.info(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
                cv2.circle(img, (int(x), int(y)), 5, color, -1)

        logger.info(f"get_image_display: Returning image with shape {img.shape}")
        return {"status": True, "reason": img}
    except Exception as e:
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
def add_annotation_point(api, x, y, is_foreground=True):
    """Append a point to the current class's annotation record.

    Args:
        api: Annotation dict with ``current_class`` and
            ``class_annotations``.
        x, y: Coordinates, stored as the tuple ``(x, y)``.
        is_foreground: Stored as-is in ``point_types``.

    Returns:
        On success ``{'status': True, 'reason': {'points': ..., 'types':
        ...}, 'api': api}`` where the lists are the class's live lists;
        ``{'status': False, 'reason': <message>}`` when no valid current
        class is selected.
    """
    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return {"status": False, "reason": "请选择或新建分类"}

    record = api['class_annotations'][current]
    record['points'].append((x, y))
    record['point_types'].append(is_foreground)

    summary = {'points': record['points'], 'types': record['point_types']}
    return {'status': True, 'reason': summary, 'api': api}
def delete_last_point(api):
    """Remove the most recently added point of the current class.

    Returns:
        ``{"status": True, "reason": api}`` after popping the last entry of
        both ``points`` and ``point_types``; ``{"status": False, "reason":
        <message>}`` when there is no valid current class or it has no
        points.
    """
    nothing_to_delete = {"status": False, "reason": "当前类没有点需要删除"}

    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return nothing_to_delete

    record = api['class_annotations'][current]
    if not record['points']:
        return nothing_to_delete

    record['points'].pop()
    record['point_types'].pop()
    return {"status": True, "reason": api}
def reset_current_class_points(api):
    """Discard all points of the current class.

    Returns:
        ``{"status": True, "reason": api}`` after replacing ``points`` and
        ``point_types`` with new empty lists; ``{"status": False, "reason":
        <message>}`` when there is no valid current class.
    """
    current = api['current_class']
    has_record = bool(current) and current in api['class_annotations']
    if not has_record:
        return {"status": False, "reason": "当前类没有点需要删除"}

    record = api['class_annotations'][current]
    record.update(points=[], point_types=[])
    return {"status": True, "reason": api}
def predict_mask(data, sam_predictor):
    """Run the predictor on the stored prompt points and keep the best mask.

    Args:
        data: State dict. Reads ``image_rgb``, ``input_point``,
            ``input_label`` and ``logit_input``; on success writes
            ``masks_pred``, ``scores``, ``logits``, ``selected_mask`` and
            ``logit_input``.
        sam_predictor: Object providing ``set_image(image)`` and
            ``predict(point_coords=, point_labels=, mask_input=,
            multimask_output=)`` returning ``(masks, scores, logits)`` —
            presumably a ``SamPredictor``; confirm against the caller.

    Returns:
        ``{"status": True, "reason": {"masks", "scores", "selected_index",
        "data"}}`` on success, or ``{"status": False, "reason": <message>}``
        when there is no image, no points, or the predictor raises.
    """
    image = data['image_rgb']
    if image is None:
        logger.info("predict_mask: No image loaded")
        return {"status": False, "reason": "预测掩码:没有图像加载"}
    if len(data['input_point']) == 0:
        logger.info("predict_mask: No points added")
        return {"status": False, "reason": "预测掩码:没有进行点标注"}

    try:
        sam_predictor.set_image(image)
    except Exception as e:
        logger.error(f"predict_mask: Error setting image: {str(e)}")
        return {"status": False, "reason": f"predict_mask: Error setting image: {str(e)}"}

    coords = np.array(data['input_point'])
    labels = np.array(data['input_label'])
    try:
        # The previous best logits, when present, are passed back in with
        # an added leading axis. Built inside the try so a malformed value
        # is reported as a prediction failure.
        previous_logits = data['logit_input']
        mask_prompt = None if previous_logits is None else previous_logits[None, :, :]
        masks_pred, scores, logits = sam_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            mask_input=mask_prompt,
            multimask_output=True,
        )
    except Exception as e:
        logger.error(f"predict_mask: Error during prediction: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"predict_mask: Error during prediction: {str(e)}"}

    best_mask_idx = np.argmax(scores)
    data.update(
        masks_pred=masks_pred,
        scores=scores,
        logits=logits,
        selected_mask=masks_pred[best_mask_idx],
        logit_input=logits[best_mask_idx, :, :],
    )
    logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")

    payload = {
        "masks": [mask.tolist() for mask in masks_pred],
        "scores": scores.tolist(),
        "selected_index": int(best_mask_idx),
        "data": data,
    }
    return {"status": True, "reason": payload}
def get_image_info(data):
    """Summarise the currently loaded image.

    Returns:
        Dict with ``filename``, ``index`` (``data['current_index']``),
        ``total`` (number of ``data['image_files']``), and ``width`` /
        ``height`` taken from ``data['image'].shape[1]`` / ``shape[0]``.
    """
    info = {
        "filename": data['filename'],
        "index": data['current_index'],
        "total": len(data['image_files']),
    }
    dimensions = data['image'].shape
    info["width"] = dimensions[1]
    info["height"] = dimensions[0]
    return info