"""Helpers for point-prompted mask annotation state.

Every function operates on two plain dictionaries that the caller owns:

* ``data`` (also called ``loaded_data``) holds the current image, the prompt
  points/labels, the predicted mask and the class list/colours.
* ``api`` holds per-class annotation records (``class_annotations``), the
  name of the active class (``current_class``) and display colours
  (``class_colors``).

Most functions report their outcome as ``{"status": bool, "reason": ...}``
where ``reason`` is either an error message or the payload.
"""

import base64
import io
import json
import os
import pickle
import random
import shutil
import traceback

import cv2
import numpy as np
from PIL import Image

from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
from work_util.logger import logger

# Marker colours for prompt points (foreground / background) and marker size.
_FOREGROUND_COLOR = (0, 255, 0)
_BACKGROUND_COLOR = (0, 0, 255)
_POINT_RADIUS = 5

# Colours cycled through when a class has no colour registered anywhere.
_FALLBACK_COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
]


def _empty_annotation():
    """Return a fresh, empty per-class annotation record."""
    return {
        'points': [],
        'point_types': [],
        'masks': [],
        'selected_mask_index': -1,
    }


def _current_class_data(api):
    """Return the annotation record of the active class, or None if unavailable."""
    current = api['current_class']
    if not current or current not in api['class_annotations']:
        return None
    return api['class_annotations'][current]


def _draw_prompts(display_image, data):
    """Draw the prompt points and the selected mask of ``data`` onto an image.

    Points are drawn in place on ``display_image``; the mask overlay produces
    a new array, so callers must use the returned image.
    """
    for point, label in zip(data['input_point'], data['input_label']):
        color = _FOREGROUND_COLOR if label == 1 else _BACKGROUND_COLOR
        # cv2.circle needs integer pixel coordinates; cast like the
        # per-class point drawing in get_image_display does.
        center = tuple(int(v) for v in point)
        cv2.circle(display_image, center, _POINT_RADIUS, color, -1)

    if data['selected_mask'] is not None:
        class_name = data['class_names'][data['class_index']]
        color = data['class_colors'].get(class_name, (0, 0, 128))
        display_image = sam_apply_mask(display_image, data['selected_mask'], color)
    return display_image


def get_display_image_ori(loaded_data):
    """Render ``loaded_data['image_rgb']`` with prompt points and selected mask.

    Returns:
        ``{"status": True, "reason": image}`` on success, otherwise
        ``{"status": False, "reason": message}``.
    """
    if loaded_data['image_rgb'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "无图片加载"}

    display_image = loaded_data['image_rgb'].copy()
    try:
        display_image = _draw_prompts(display_image, loaded_data)
        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
        return {"status": True, "reason": display_image}
    except Exception as e:
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}


def get_all_classes(data):
    """Return the class name list and a shallow copy of the colour mapping."""
    return {
        "classes": data['class_names'],
        "colors": dict(data['class_colors']),
    }


def get_classes(data):
    """Alias of :func:`get_all_classes`."""
    return get_all_classes(data)


def reset_annotation(loaded_data):
    """Clear prompt points, labels, selected mask, logits and masks in place."""
    loaded_data['input_point'] = []
    loaded_data['input_label'] = []
    loaded_data['selected_mask'] = None
    loaded_data['logit_input'] = None
    loaded_data['masks'] = {}
    logger.info("已重置标注状态")
    return loaded_data


def remove_class(data, api, class_name):
    """Remove ``class_name`` from the class list, colours, masks and annotations.

    Removing a name that is not registered is a no-op for the class list.
    ``data['class_index']`` is reset to 0 when the active class is removed and
    shifted down by one when an earlier class is removed, so that it keeps
    referring to the same class.

    Returns:
        The ``(data, api)`` pair, both modified in place.
    """
    api['class_annotations'].pop(class_name, None)

    class_names = data['class_names']
    if class_name in class_names:
        removed_at = class_names.index(class_name)
        current = data['class_index']
        class_names.remove(class_name)
        if removed_at == current:
            data['class_index'] = 0
        elif removed_at < current:
            data['class_index'] = current - 1

    data['class_colors'].pop(class_name, None)
    data['masks'].pop(class_name, None)
    return data, api


def reset_annotation_all(data, api):
    """Reset the prompt state in ``data`` and empty every class record in ``api``."""
    data = reset_annotation(data)
    for class_data in api['class_annotations'].values():
        class_data['points'] = []
        class_data['point_types'] = []
        class_data['masks'] = []
        class_data['scores'] = []
        class_data['selected_mask_index'] = -1
        class_data.pop('selected_mask', None)
    return data, api


def reset_class_annotations(data, api):
    """Reset class_annotations: keep the classes but clear their masks and points."""
    api['class_annotations'] = {
        class_name: _empty_annotation() for class_name in data['class_names']
    }
    logger.info("已重置class_annotations,保留类别但清除掩码和点")
    return data, api


def add_class(data, class_name, color=None):
    """Register a new class with a display colour.

    Args:
        data: State dictionary holding ``class_names`` and ``class_colors``.
        class_name: Name of the class to add.
        color: Three components in (r, g, b) order; a random colour with
            components in [100, 255] is drawn when omitted.

    Returns:
        ``{'status': True, 'reason': data}`` on success, or a failure dict
        when the class already exists.
    """
    if class_name in data['class_names']:
        logger.info(f"类别 '{class_name}' 已存在")
        return {'status': False, 'reason': f"类别 '{class_name}' 已存在"}

    if color is None:
        color = tuple(np.random.randint(100, 256, 3).tolist())
    # Validate/convert the colour before touching ``data`` so that a malformed
    # colour cannot leave a class registered without a colour entry.
    r, g, b = [int(c) for c in color]

    data['class_names'].append(class_name)
    # The colour is stored with the component order reversed (b, g, r).
    data['class_colors'][class_name] = (b, g, r)
    logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
    return {'status': True, 'reason': data}


def set_current_class(data, api, class_index, color=None):
    """Make the class at ``class_index`` the active class.

    Creates an empty annotation record for the class if none exists and
    synchronises its display colour into ``api['class_colors']``.

    Args:
        data: State dictionary holding ``class_names`` and ``class_colors``.
        api: Annotation dictionary, modified in place.
        class_index: Position of the class in ``data['class_names']``.
            Negative or out-of-range values are rejected.
        color: Fallback colour, used only when ``data['class_colors']`` has no
            usable entry for the class.

    Returns:
        ``(class_name, api)`` on success, ``(None, api)`` for an invalid index.
    """
    class_names = data['class_names']
    # Reject negative indices explicitly: they would otherwise silently select
    # a class counted from the end of the list.
    if not 0 <= class_index < len(class_names):
        return None, api

    class_name = class_names[class_index]
    api['current_class'] = class_name
    if class_name not in api['class_annotations']:
        api['class_annotations'][class_name] = _empty_annotation()

    # The colour registered in ``data`` takes precedence; the argument is only
    # a fallback so a missing entry no longer raises KeyError.
    chosen = data['class_colors'].get(class_name) or color
    if chosen:
        api['class_colors'][class_name] = chosen
    elif class_name not in api['class_colors']:
        color_index = len(api['class_colors']) % len(_FALLBACK_COLORS)
        api['class_colors'][class_name] = _FALLBACK_COLORS[color_index]
    return class_name, api


def add_point(data, x, y, is_foreground=True):
    """Append a prompt point with label 1 (foreground) or 0 (background)."""
    data['input_point'].append([x, y])
    data['input_label'].append(1 if is_foreground else 0)
    logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
    return data


def load_model(path):
    """Load ``(loaded_data, api_config)`` from ``<path>/model_params.pickle``."""
    config_path = os.path.join(path, 'model_params.pickle')
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load state files written by save_model from a trusted location.
    with open(config_path, 'rb') as file:
        loaded_data, api_config = pickle.load(file)
    return loaded_data, api_config


def save_model(loaded_data, api_config, path):
    """Pickle ``(loaded_data, api_config)`` to ``<path>/model_params.pickle``."""
    config_path = os.path.join(path, 'model_params.pickle')
    with open(config_path, 'wb') as file:
        pickle.dump((loaded_data, api_config), file)


def sam_apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into a copy of ``image`` where ``mask == 1``.

    Only the first three channels are touched; pixels outside the mask are
    left unchanged. The input image is not modified.
    """
    masked_image = image.copy()
    for c in range(3):
        channel = image[:, :, c]
        masked_image[:, :, c] = np.where(
            mask == 1,
            channel * (1 - alpha) + alpha * color[c],
            channel,
        )
    return masked_image


def apply_mask_overlay(image, mask, color, alpha=0.5):
    """Add ``color`` (weighted by ``alpha``) onto ``image`` where ``mask > 0``."""
    colored_mask = np.zeros_like(image)
    colored_mask[mask > 0] = color
    return cv2.addWeighted(image, 1, colored_mask, alpha, 0)


def load_tmp_image(path):
    """Read the file at ``path`` and return its contents Base64-encoded as str."""
    with open(path, "rb") as image_file:
        # Read the image file as raw bytes and Base64-encode it.
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return encoded_string


def refresh_image(data, api, path):
    """Render the current display image and save it under ``path``.

    The image is written to ``<path>/temp/output_image.jpg``; the directory is
    created if it does not exist yet.

    Returns:
        ``{"status": True, "reason": saved_path}`` on success, otherwise the
        failure reported by :func:`get_image_display`.
    """
    result = get_image_display(data, api)
    if not result['status']:
        return {"status": False, "reason": result['reason']}

    display_img_rgb = cv2.cvtColor(result['reason'], cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(display_img_rgb)
    tmp_path = os.path.join(path, 'temp/output_image.jpg')
    # Saving fails if the temp directory is missing, so make sure it exists.
    os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
    pil_img.save(tmp_path)
    return {"status": True, "reason": tmp_path}


def set_class_color(data, api, index, color):
    """Set the colour of the class at ``index`` in both ``data`` and ``api``.

    NOTE(review): add_class stores colours with the components reversed,
    whereas this function stores ``color`` exactly as given - confirm callers
    pass the component order they expect.
    """
    class_name = data['class_names'][index]
    api['class_colors'][class_name] = color
    data['class_colors'][class_name] = color
    return data, api


def add_to_class(api, class_name):
    """Commit a class's selected mask as its final mask and clear its prompts.

    Args:
        api: Annotation dictionary, modified in place.
        class_name: Class to commit; ``None`` means the active class.

    Returns:
        ``{'status': True, 'reason': api}`` on success, or a failure dict when
        the class has no annotation record or no selected mask.
    """
    if class_name is None:
        class_name = api['current_class']
    if not class_name or class_name not in api['class_annotations']:
        return {'status': False, 'reason': "该分类没有标注点"}

    class_data = api['class_annotations'][class_name]
    if class_data.get('selected_mask') is None:
        return {'status': False, 'reason': "该分类没有进行预测"}

    class_data['final_mask'] = class_data['selected_mask'].copy()
    class_data['points'] = []
    class_data['point_types'] = []
    class_data['masks'] = []
    class_data['scores'] = []
    class_data['selected_mask_index'] = -1
    return {'status': True, 'reason': api}


def get_image_display(data, api):
    """Render ``data['image']`` with all overlays.

    Drawn in order: the prompt points and selected mask from ``data``, one
    mask per class from ``api`` (``final_mask`` if present, otherwise
    ``selected_mask``), and the points of the active class.

    Returns:
        ``{"status": True, "reason": image}`` on success, otherwise
        ``{"status": False, "reason": message}``.
    """
    if data['image'] is None:
        logger.info("get_display_image: No image loaded")
        return {"status": False, "reason": "获取图像:没有图像加载"}

    display_image = data['image'].copy()
    try:
        display_image = _draw_prompts(display_image, data)
        logger.info(f"get_display_image: Returning image with shape {display_image.shape}")

        img = display_image
        if not isinstance(img, np.ndarray) or img.size == 0:
            logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}")
            return {"status": False, "reason": f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"}

        # Overlay one mask per class: the committed final_mask wins over a
        # not-yet-committed selected_mask.
        for class_name, class_data in api['class_annotations'].items():
            mask = class_data.get('final_mask')
            if mask is None:
                mask = class_data.get('selected_mask')
            if mask is None:
                continue
            color = api['class_colors'].get(class_name, (0, 255, 0))
            if isinstance(mask, list):
                mask = np.array(mask, dtype=np.uint8)
            logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
            img = apply_mask_overlay(img, mask, color, alpha=0.5)

        class_data = _current_class_data(api)
        if class_data is not None:
            for (x, y), is_fg in zip(class_data['points'], class_data['point_types']):
                color = _FOREGROUND_COLOR if is_fg else _BACKGROUND_COLOR
                logger.info(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
                cv2.circle(img, (int(x), int(y)), _POINT_RADIUS, color, -1)

        logger.info(f"get_image_display: Returning image with shape {img.shape}")
        return {"status": True, "reason": img}
    except Exception as e:
        logger.error(f"get_display_image: Error processing image: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"get_display_image: Error processing image: {str(e)}"}


def add_annotation_point(api, x, y, is_foreground=True):
    """Append a point to the active class's annotation record.

    Returns:
        On success a dict with ``status``, the updated point lists under
        ``reason`` and the ``api`` dictionary; a failure dict when no valid
        class is active.
    """
    class_data = _current_class_data(api)
    if class_data is None:
        return {"status": False, "reason": "请选择或新建分类"}

    class_data['points'].append((x, y))
    class_data['point_types'].append(is_foreground)
    return {
        'status': True,
        'reason': {
            'points': class_data['points'],
            'types': class_data['point_types'],
        },
        'api': api,
    }


def delete_last_point(api):
    """Remove the most recently added point of the active class."""
    class_data = _current_class_data(api)
    if class_data is None or not class_data['points']:
        return {"status": False, "reason": "当前类没有点需要删除"}

    class_data['points'].pop()
    # Guard the parallel list so a shorter point_types cannot raise here.
    if class_data['point_types']:
        class_data['point_types'].pop()
    return {"status": True, "reason": api}


def reset_current_class_points(api):
    """Remove every point of the active class."""
    class_data = _current_class_data(api)
    if class_data is None:
        return {"status": False, "reason": "当前类没有点需要删除"}

    class_data['points'] = []
    class_data['point_types'] = []
    return {"status": True, "reason": api}


def predict_mask(data, sam_predictor):
    """Run the predictor on the current prompt points and keep the best mask.

    The mask with the highest score becomes ``data['selected_mask']`` and its
    logits are stored in ``data['logit_input']``, which is passed back to the
    predictor as ``mask_input`` on the next call.

    Args:
        data: State dictionary; updated in place with ``masks_pred``,
            ``scores``, ``logits``, ``selected_mask`` and ``logit_input``.
        sam_predictor: Object providing ``set_image`` and ``predict``.

    Returns:
        ``{"status": True, "reason": {...}}`` with the masks and scores as
        lists, the selected index and ``data``; otherwise a failure dict.
    """
    if data['image_rgb'] is None:
        logger.info("predict_mask: No image loaded")
        return {"status": False, "reason": "预测掩码:没有图像加载"}
    if len(data['input_point']) == 0:
        logger.info("predict_mask: No points added")
        return {"status": False, "reason": "预测掩码:没有进行点标注"}

    try:
        sam_predictor.set_image(data['image_rgb'])
    except Exception as e:
        logger.error(f"predict_mask: Error setting image: {str(e)}")
        return {"status": False, "reason": f"predict_mask: Error setting image: {str(e)}"}

    input_point_np = np.array(data['input_point'])
    input_label_np = np.array(data['input_label'])
    prior_logits = data['logit_input']
    try:
        masks_pred, scores, logits = sam_predictor.predict(
            point_coords=input_point_np,
            point_labels=input_label_np,
            mask_input=prior_logits[None, :, :] if prior_logits is not None else None,
            multimask_output=True,
        )
    except Exception as e:
        logger.error(f"predict_mask: Error during prediction: {str(e)}")
        traceback.print_exc()
        return {"status": False, "reason": f"predict_mask: Error during prediction: {str(e)}"}

    data['masks_pred'] = masks_pred
    data['scores'] = scores
    data['logits'] = logits

    best_mask_idx = np.argmax(scores)
    data['selected_mask'] = masks_pred[best_mask_idx]
    data['logit_input'] = logits[best_mask_idx, :, :]
    logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")

    return {
        "status": True,
        "reason": {
            "masks": [mask.tolist() for mask in masks_pred],
            "scores": scores.tolist(),
            "selected_index": int(best_mask_idx),
            "data": data,
        },
    }


def get_image_info(data):
    """Return filename, index, total count and pixel size of the current image.

    Width and height are reported as 0 when no image is loaded.
    """
    image = data['image']
    height, width = (image.shape[0], image.shape[1]) if image is not None else (0, 0)
    return {
        "filename": data['filename'],
        "index": data['current_index'],
        "total": len(data['image_files']),
        "width": width,
        "height": height,
    }