400 lines
15 KiB
Python
400 lines
15 KiB
Python
import cv2
|
||
import os
|
||
import numpy as np
|
||
import json
|
||
import shutil
|
||
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
|
||
import random
|
||
import traceback
|
||
import base64
|
||
import io
|
||
from work_util.logger import logger
|
||
import pickle
|
||
from PIL import Image
|
||
|
||
|
||
|
||
def get_display_image_ori(loaded_data):
    """Render the RGB image with the prompt points and the selected mask.

    A copy of ``loaded_data['image_rgb']`` is drawn on, so the stored image
    is left untouched.

    Args:
        loaded_data: State dict providing ``image_rgb``, ``input_point``,
            ``input_label``, ``selected_mask``, ``class_names``,
            ``class_index`` and ``class_colors``.

    Returns:
        ``{"status": True, "reason": <image>}`` on success, otherwise
        ``{"status": False, "reason": <message>}``.
    """
    source_image = loaded_data['image_rgb']
    if source_image is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "无图片加载"}

    display_image = source_image.copy()
    try:
        prompts = zip(loaded_data['input_point'], loaded_data['input_label'])
        for point, label in prompts:
            # Label 1 is drawn with (0, 255, 0); any other label with (0, 0, 255).
            if label == 1:
                dot_color = (0, 255, 0)
            else:
                dot_color = (0, 0, 255)
            cv2.circle(display_image, tuple(point), 5, dot_color, -1)

        chosen_mask = loaded_data['selected_mask']
        if chosen_mask is not None:
            class_name = loaded_data['class_names'][loaded_data['class_index']]
            # Classes without a registered colour fall back to (0, 0, 128).
            mask_color = loaded_data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, chosen_mask, mask_color)

        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
        return {"status": True, "reason": display_image}
    except Exception as e:
        # Any drawing failure is reported through the status dict, not raised.
        logger.info(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
|
||
|
||
def get_all_classes(data):
    """Return the class names and a shallow copy of the class colours.

    Args:
        data: State dict providing ``class_names`` and ``class_colors``.

    Returns:
        ``{"classes": <the class_names object itself>,
        "colors": <new dict with the same name -> colour pairs>}``.
    """
    class_names = data['class_names']
    # New dict so callers cannot add/remove entries of the stored mapping.
    color_snapshot = dict(data['class_colors'].items())
    return {"classes": class_names, "colors": color_snapshot}
|
||
|
||
def get_classes(data):
    """Return the class listing for ``data``; thin alias of ``get_all_classes``."""
    class_listing = get_all_classes(data)
    return class_listing
|
||
|
||
|
||
def reset_annotation(loaded_data):
    """Clear the in-progress annotation state stored in ``loaded_data``.

    Resets the prompt points and labels, the selected mask, the logit input
    and the ``masks`` mapping, then logs the reset.

    Args:
        loaded_data: State dict; modified in place.

    Returns:
        The same ``loaded_data`` object.
    """
    cleared_state = {
        'input_point': [],
        'input_label': [],
        'selected_mask': None,
        'logit_input': None,
        'masks': {},
    }
    for key, empty_value in cleared_state.items():
        loaded_data[key] = empty_value
    logger.info("已重置标注状态")
    return loaded_data
|
||
|
||
|
||
def remove_class(data, api, class_name):
    """Remove a class and everything stored for it from ``data`` and ``api``.

    Besides dropping the name, its colours, annotations and masks, this keeps
    ``data['class_index']`` pointing at the same class as before:

    * removing the currently selected class sets ``data['class_index']`` to
      ``-1`` and ``api['current_class']`` to ``None``;
    * removing a class that sits before the selected one shifts the index
      down by one, because the list shrinks underneath it;
    * an index of ``-1`` (no selection) is left alone rather than being
      treated as "last element".

    Args:
        data: State dict providing ``class_names``, ``class_index``,
            ``class_colors`` and ``masks``; modified in place.
        api: State dict providing ``class_annotations`` and ``class_colors``;
            modified in place.
        class_name: Name of the class to remove.

    Returns:
        The ``(data, api)`` pair.

    Raises:
        ValueError: If ``class_name`` is not in ``data['class_names']``.
            Raised before anything is modified.
    """
    class_names = data['class_names']
    # Look the name up first so an unknown class fails before any mutation.
    removed_index = class_names.index(class_name)

    # pop(..., None): a class may have annotations without an api colour.
    api['class_annotations'].pop(class_name, None)
    api['class_colors'].pop(class_name, None)

    current_index = data['class_index']
    if current_index == removed_index:
        data['class_index'] = -1
        api['current_class'] = None
    else:
        if current_index > removed_index:
            # Entries after the removed one move down by one position.
            data['class_index'] = current_index - 1
        if api.get('current_class') == class_name:
            # Never leave the api pointing at a class that no longer exists.
            api['current_class'] = None

    del class_names[removed_index]
    data['class_colors'].pop(class_name, None)
    data['masks'].pop(class_name, None)
    return data, api
|
||
|
||
|
||
def reset_annotation_all(data, api):
    """Reset the global annotation state and every per-class annotation.

    ``data`` is reset through ``reset_annotation``. For each entry of
    ``api['class_annotations']`` the points, point types, masks and scores
    are emptied, the selected index becomes ``-1`` and any ``selected_mask``
    key is removed. Other keys of an entry are not touched.

    Returns:
        The ``(data, api)`` pair.
    """
    data = reset_annotation(data)
    for annotation in api['class_annotations'].values():
        for list_field in ('points', 'point_types', 'masks', 'scores'):
            annotation[list_field] = []
        annotation['selected_mask_index'] = -1
        annotation.pop('selected_mask', None)
    return data, api
|
||
|
||
|
||
def reset_class_annotations(data, api):
    """Reset ``class_annotations``: keep the classes, drop masks and points.

    ``api['class_annotations']`` is replaced by a new dict holding one empty
    annotation record per class name reported by ``get_classes``.

    Returns:
        The ``(data, api)`` pair.
    """
    class_names = get_classes(data).get('classes', [])
    api['class_annotations'] = {
        name: {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }
        for name in class_names
    }
    logger.info("已重置class_annotations,保留类别但清除掩码和点")
    return data, api
|
||
|
||
|
||
def add_class(data, class_name, color=None):
    """Register a new class together with its display colour.

    Args:
        data: State dict providing ``class_names`` and ``class_colors``;
            modified in place on success.
        class_name: Name of the class to add.
        color: Optional three-component colour. Its components are unpacked
            as ``(r, g, b)`` and stored in reverse order ``(b, g, r)``.
            When omitted, three random integers in ``[100, 256)`` are used.

    Returns:
        ``{'status': True, 'reason': data}`` on success, or
        ``{'status': False, 'reason': <message>}`` if the class already
        exists.

    Raises:
        ValueError / TypeError: If ``color`` does not hold exactly three
            int-convertible components. ``data`` is left unmodified.
    """
    if class_name in data['class_names']:
        logger.info(f"类别 '{class_name}' 已存在")
        return {'status': False, 'reason': f"类别 '{class_name}' 已存在"}

    if color is None:
        color = tuple(np.random.randint(100, 256, 3).tolist())

    # Convert the colour BEFORE touching ``data``: a malformed colour must not
    # leave a class name registered without a matching colour entry.
    r, g, b = (int(c) for c in color)
    bgr_color = (b, g, r)

    data['class_names'].append(class_name)
    data['class_colors'][class_name] = bgr_color
    logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
    return {'status': True, 'reason': data}
|
||
|
||
|
||
def set_current_class(data, api, class_index, color=None):
    """Make the class at ``class_index`` the current class of ``api``.

    An empty annotation record is created for the class if ``api`` has none
    yet, and a colour is recorded in ``api['class_colors']`` using the first
    available of:

    1. the colour stored in ``data['class_colors']``;
    2. the ``color`` argument;
    3. the colour ``api`` already holds for the class;
    4. an entry of a fixed six-colour palette, chosen by the number of
       colours ``api`` currently holds.

    Args:
        data: State dict providing ``class_names`` and ``class_colors``.
        api: State dict providing ``class_annotations`` and ``class_colors``;
            modified in place.
        class_index: Position of the class in ``data['class_names']``.
            Negative values are rejected, since ``-1`` is used elsewhere in
            this module as the "no class selected" marker.
        color: Optional fallback colour, used only when ``data`` has no
            colour for the class.

    Returns:
        ``(class_name, api)`` on success, ``(None, api)`` when the index is
        out of range (``api`` is then left unchanged).
    """
    # Same list that get_classes(data)['classes'] exposes.
    class_names = data['class_names']
    if not 0 <= class_index < len(class_names):
        return None, api

    class_name = class_names[class_index]
    api['current_class'] = class_name
    if class_name not in api['class_annotations']:
        api['class_annotations'][class_name] = {
            'points': [],
            'point_types': [],
            'masks': [],
            'selected_mask_index': -1,
        }

    # .get() instead of [] so a class without a stored colour reaches the
    # fallbacks below instead of raising KeyError.
    resolved_color = data['class_colors'].get(class_name)
    if not resolved_color:
        resolved_color = color

    if resolved_color:
        api['class_colors'][class_name] = resolved_color
    elif class_name not in api['class_colors']:
        predefined_colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
        ]
        color_index = len(api['class_colors']) % len(predefined_colors)
        api['class_colors'][class_name] = predefined_colors[color_index]
    return class_name, api
|
||
|
||
|
||
def add_point(data, x, y, is_foreground=True):
    """Append one prompt point and its label to ``data``.

    Args:
        data: State dict providing the ``input_point`` and ``input_label``
            lists; modified in place.
        x: First coordinate of the point, stored as given.
        y: Second coordinate of the point, stored as given.
        is_foreground: Truthy stores label ``1``, falsy stores label ``0``.

    Returns:
        The same ``data`` object.
    """
    point_label = 1 if is_foreground else 0
    data['input_point'].append([x, y])
    data['input_label'].append(point_label)
    logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
    return data
|
||
|
||
|
||
|
||
def load_model(path):
    """Load the saved ``(loaded_data, api_config)`` pair from ``path``.

    Reads ``model_params.pickle`` inside the directory ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code from the file;
    only point this at files the application wrote itself.

    Returns:
        The ``(loaded_data, api_config)`` tuple stored in the file.
    """
    # Load the configuration content.
    pickle_file = os.path.join(path, 'model_params.pickle')
    with open(pickle_file, 'rb') as handle:
        stored_state = pickle.load(handle)
    # The file must hold exactly two items; unpacking enforces that.
    loaded_data, api_config = stored_state
    return loaded_data, api_config
|
||
|
||
def save_model(loaded_data, api_config, path):
    """Pickle ``(loaded_data, api_config)`` into ``path``.

    The pair is written as one tuple to ``model_params.pickle`` inside the
    directory ``path``, which is the layout ``load_model`` reads back.
    """
    pickle_file = os.path.join(path, 'model_params.pickle')
    with open(pickle_file, 'wb') as handle:
        pickle.dump((loaded_data, api_config), handle)
|
||
|
||
|
||
def sam_apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into the first three channels of ``image`` under ``mask``.

    Where ``mask == 1`` each channel becomes
    ``image * (1 - alpha) + alpha * color[channel]``; everywhere else the
    original value is kept. The input image is not modified.

    Args:
        image: Array indexed as ``[row, column, channel]``.
        mask: Array compared against ``1`` to select the blended pixels.
        color: Sequence with at least three components.
        alpha: Weight of ``color`` in the blend.

    Returns:
        A copy of ``image`` with the blend applied.
    """
    result = image.copy()
    inside = mask == 1
    keep = 1 - alpha
    for channel in range(3):
        original_plane = image[:, :, channel]
        tinted_plane = original_plane * keep + alpha * color[channel]
        result[:, :, channel] = np.where(inside, tinted_plane, original_plane)
    return result
|
||
|
||
|
||
def apply_mask_overlay(image, mask, color, alpha=0.5):
    """Overlay ``color`` on ``image`` wherever ``mask`` is positive.

    A zero-filled layer shaped like ``image`` is painted with ``color`` at
    the positions where ``mask > 0``. It is then combined with the image
    through ``cv2.addWeighted``, using weight ``1`` for the image and
    ``alpha`` for the colour layer.

    Returns:
        The array returned by ``cv2.addWeighted``.
    """
    overlay = np.zeros_like(image)
    painted = mask > 0
    overlay[painted] = color
    blended = cv2.addWeighted(image, 1, overlay, alpha, 0)
    return blended
|
||
|
||
def load_tmp_image(path):
    """Return the content of the file at ``path`` as a Base64 text string."""
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    # Base64-encode the binary content, then turn the result into text.
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
|
||
|
||
def refresh_image(data, api, path):
    """Render the current display image and save it under ``path``.

    The image comes from ``get_image_display``. It is converted with
    ``cv2.COLOR_BGR2RGB`` and written to ``<path>/temp/output_image.jpg``.

    Args:
        data: State dict forwarded to ``get_image_display``.
        api: State dict forwarded to ``get_image_display``.
        path: Base directory; the image is written to its ``temp``
            sub-directory, which is created when missing.

    Returns:
        ``{"status": True, "reason": <saved file path>}`` on success, or
        ``{"status": False, "reason": <message>}`` when rendering failed.
    """
    result = get_image_display(data, api)
    if not result['status']:
        return {
            "status": False,
            "reason": result['reason'],
        }

    img = result['reason']
    display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(display_img_rgb)

    tmp_path = os.path.join(path, 'temp/output_image.jpg')
    # Saving fails with FileNotFoundError if the temp directory is absent,
    # so make sure it exists first.
    os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
    pil_img.save(tmp_path)
    return {
        "status": True,
        "reason": tmp_path,
    }
|
||
|
||
|
||
def set_class_color(data, api, index, color):
    """Store ``color`` for the class at ``index`` in both ``api`` and ``data``.

    Args:
        data: State dict providing ``class_names`` and ``class_colors``.
        api: State dict providing ``class_colors``.
        index: Position of the class in ``data['class_names']``.
        color: Colour value, stored as given.

    Returns:
        The ``(data, api)`` pair.
    """
    class_name = data['class_names'][index]
    # api is updated first, then data.
    for store in (api, data):
        store['class_colors'][class_name] = color
    return data, api
|
||
|
||
|
||
|
||
def add_to_class(api, class_name):
    """Commit the selected mask of a class as its final mask.

    A copy of the class's ``selected_mask`` is stored as ``final_mask``. The
    points, point types, masks and scores of the class are then emptied and
    ``selected_mask_index`` is set to ``-1``.

    Args:
        api: State dict providing ``current_class`` and
            ``class_annotations``; modified in place.
        class_name: Class to commit; ``None`` means ``api['current_class']``.

    Returns:
        ``{'status': True, 'reason': api}`` on success, or
        ``{'status': False, 'reason': <message>}`` when the class has no
        annotation record or no selected mask.
    """
    target = api['current_class'] if class_name is None else class_name
    if not target or target not in api['class_annotations']:
        return {'status': False, 'reason': "该分类没有标注点"}

    class_data = api['class_annotations'][target]
    pending_mask = class_data.get('selected_mask')
    if pending_mask is None:
        return {'status': False, 'reason': "该分类没有进行预测"}

    class_data['final_mask'] = pending_mask.copy()
    for list_field in ('points', 'point_types', 'masks', 'scores'):
        class_data[list_field] = []
    class_data['selected_mask_index'] = -1
    return {'status': True, 'reason': api}
|
||
|
||
|
||
|
||
def get_image_display(data, api):
    """Compose the display image: prompt points, masks and class points.

    Drawing happens on a copy of ``data['image']`` in this order:

    1. the prompt points of ``data`` (label ``1`` in ``(0, 255, 0)``, any
       other label in ``(0, 0, 255)``);
    2. the selected mask of ``data``, when a class is selected
       (``class_index != -1``) and a mask exists;
    3. for every class in ``api['class_annotations']`` its ``final_mask``,
       or its ``selected_mask`` when there is no final mask;
    4. the annotation points of ``api['current_class']``.

    Args:
        data: State dict providing ``image``, ``input_point``,
            ``input_label``, ``class_index``, ``selected_mask``,
            ``class_names`` and ``class_colors``.
        api: State dict providing ``class_annotations``, ``class_colors``
            and ``current_class``.

    Returns:
        ``{"status": True, "reason": <image>}`` on success, otherwise
        ``{"status": False, "reason": <message>}``. Errors raised while
        drawing are caught and reported through the status dict.
    """
    if data['image'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "获取图像:没有图像加载"}

    display_image = data['image'].copy()
    try:
        for point, label in zip(data['input_point'], data['input_label']):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # Cast to int like the class-point loop below does, so that
            # float coordinates are handled the same way in both places.
            center = tuple(int(coordinate) for coordinate in point)
            cv2.circle(display_image, center, 5, color, -1)

        if data['class_index'] != -1 and data['selected_mask'] is not None:
            class_name = data['class_names'][data['class_index']]
            color = data['class_colors'].get(class_name, (0, 0, 128))
            display_image = sam_apply_mask(display_image, data['selected_mask'], color)
        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")

        img = display_image
        if not isinstance(img, np.ndarray) or img.size == 0:
            logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}")
            return {"status": False, "reason": f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"}

        # Only apply the masks stored for the current image: the final mask
        # wins, the selected mask is the fallback.
        for class_name, class_data in api['class_annotations'].items():
            mask = class_data.get('final_mask')
            if mask is None:
                mask = class_data.get('selected_mask')
            if mask is None:
                continue
            color = api['class_colors'].get(class_name, (0, 255, 0))
            if isinstance(mask, list):
                # Masks may be stored as nested lists; convert before use.
                mask = np.array(mask, dtype=np.uint8)
            logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
            img = apply_mask_overlay(img, mask, color, alpha=0.5)

        current_class = api['current_class']
        if current_class and current_class in api['class_annotations']:
            class_data = api['class_annotations'][current_class]
            for i, (x, y) in enumerate(class_data['points']):
                is_fg = class_data['point_types'][i]
                color = (0, 255, 0) if is_fg else (0, 0, 255)
                # Logged through the module logger like every other message
                # here, instead of a bare print.
                logger.info(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
                cv2.circle(img, (int(x), int(y)), 5, color, -1)

        logger.info(f"get_image_display: Returning image with shape {img.shape}")
        return {"status": True, "reason": img}
    except Exception as e:
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}
|
||
|
||
|
||
def add_annotation_point(api, x, y, is_foreground=True):
    """Append an annotation point to the current class.

    Args:
        api: State dict providing ``current_class`` and
            ``class_annotations``; modified in place.
        x: First coordinate of the point.
        y: Second coordinate of the point.
        is_foreground: Point type flag, stored as given.

    Returns:
        On success ``{'status': True, 'reason': {'points': ..., 'types':
        ...}, 'api': api}``, where the two lists are the class's own list
        objects. When there is no usable current class,
        ``{"status": False, "reason": <message>}``.
    """
    current_class = api['current_class']
    if not current_class or current_class not in api['class_annotations']:
        return {"status": False, "reason": "请选择或新建分类"}

    class_data = api['class_annotations'][current_class]
    points = class_data['points']
    point_types = class_data['point_types']
    points.append((x, y))
    point_types.append(is_foreground)

    summary = {'points': points, 'types': point_types}
    return {'status': True, 'reason': summary, 'api': api}
|
||
|
||
|
||
|
||
def delete_last_point(api):
    """Remove the most recently added point of the current class.

    The last entry of both ``points`` and ``point_types`` is dropped.

    Returns:
        ``{"status": True, "reason": api}`` on success, or
        ``{"status": False, "reason": <message>}`` when there is no usable
        current class or it has no points.
    """
    nothing_to_delete = {"status": False, "reason": "当前类没有点需要删除"}

    current_class = api['current_class']
    if not current_class or current_class not in api['class_annotations']:
        return nothing_to_delete

    class_data = api['class_annotations'][current_class]
    if not class_data['points']:
        return nothing_to_delete

    del class_data['points'][-1]
    del class_data['point_types'][-1]
    return {"status": True, "reason": api}
|
||
|
||
|
||
|
||
def reset_current_class_points(api):
    """Drop every annotation point of the current class.

    ``points`` and ``point_types`` are rebound to new empty lists; the
    previous list objects are not modified.

    Returns:
        ``{"status": True, "reason": api}`` on success, or
        ``{"status": False, "reason": <message>}`` when there is no usable
        current class.
    """
    current_class = api['current_class']
    if not current_class or current_class not in api['class_annotations']:
        return {"status": False, "reason": "当前类没有点需要删除"}

    class_data = api['class_annotations'][current_class]
    for list_field in ('points', 'point_types'):
        class_data[list_field] = []
    return {"status": True, "reason": api}
|
||
|
||
|
||
|
||
|
||
def predict_mask(data, sam_predictor):
    """Run the predictor on the prompt points and keep the best mask.

    The predictor receives ``data['image_rgb']`` through ``set_image`` and
    is then asked for multiple masks. The previous ``data['logit_input']``
    (when present) is passed as ``mask_input`` with a leading axis added.
    The mask with the highest score becomes ``data['selected_mask']`` and
    its logits become the new ``data['logit_input']``.

    Args:
        data: State dict providing ``image_rgb``, ``input_point``,
            ``input_label`` and ``logit_input``; updated in place with
            ``masks_pred``, ``scores``, ``logits``, ``selected_mask`` and
            ``logit_input``.
        sam_predictor: Object exposing ``set_image`` and ``predict``.

    Returns:
        ``{"status": True, "reason": {"masks", "scores", "selected_index",
        "data"}}`` on success, otherwise
        ``{"status": False, "reason": <message>}``.
    """
    image_rgb = data['image_rgb']
    if image_rgb is None:
        logger.info("predict_mask: No image loaded")
        return {"status": False, "reason": "预测掩码:没有图像加载"}

    if len(data['input_point']) == 0:
        logger.info("predict_mask: No points added")
        return {"status": False, "reason": "预测掩码:没有进行点标注"}

    try:
        sam_predictor.set_image(image_rgb)
    except Exception as e:
        logger.error(f"predict_mask: Error setting image: {str(e)}")
        return {"status": False, "reason": f"predict_mask: Error setting image: {str(e)}"}

    input_point_np = np.array(data['input_point'])
    input_label_np = np.array(data['input_label'])

    try:
        # Built inside the try so that a badly shaped logit_input is reported
        # as a prediction error rather than raised.
        previous_logits = data['logit_input']
        if previous_logits is None:
            mask_input = None
        else:
            mask_input = previous_logits[None, :, :]
        masks_pred, scores, logits = sam_predictor.predict(
            point_coords=input_point_np,
            point_labels=input_label_np,
            mask_input=mask_input,
            multimask_output=True,
        )
    except Exception as e:
        logger.error(f"predict_mask: Error during prediction: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"predict_mask: Error during prediction: {str(e)}"}

    data['masks_pred'] = masks_pred
    data['scores'] = scores
    data['logits'] = logits

    best_mask_idx = np.argmax(scores)
    data['selected_mask'] = masks_pred[best_mask_idx]
    # The winning logits seed mask_input on the next call.
    data['logit_input'] = logits[best_mask_idx, :, :]
    logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")

    payload = {
        "masks": [candidate.tolist() for candidate in masks_pred],
        "scores": scores.tolist(),
        "selected_index": int(best_mask_idx),
        "data": data,
    }
    return {"status": True, "reason": payload}
|
||
|
||
|
||
def get_image_info(data):
    """Summarise the currently loaded image.

    Args:
        data: State dict providing ``filename``, ``current_index``,
            ``image_files`` and ``image`` (an object with a ``shape``).

    Returns:
        Dict with ``filename``, ``index``, ``total`` (number of image
        files), ``width`` (``shape[1]``) and ``height`` (``shape[0]``).
    """
    info = {
        "filename": data['filename'],
        "index": data['current_index'],
        "total": len(data['image_files']),
    }
    image_shape = data['image'].shape
    info["width"] = image_shape[1]
    info["height"] = image_shape[0]
    return info
|
||
|