diff --git a/.gitignore b/.gitignore index 1d2cbe6..5faa378 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ *.log -tmp/ \ No newline at end of file +tmp/ +datas/ +imgs/ \ No newline at end of file diff --git a/__pycache__/run.cpython-310.pyc b/__pycache__/run.cpython-310.pyc index f0b5d59..2a3db61 100644 Binary files a/__pycache__/run.cpython-310.pyc and b/__pycache__/run.cpython-310.pyc differ diff --git a/download.zip b/download.zip index 8759f48..10fbcd7 100644 Binary files a/download.zip and b/download.zip differ diff --git a/fast_api_run.py b/fast_api_run.py index 6867621..94086d4 100644 --- a/fast_api_run.py +++ b/fast_api_run.py @@ -19,51 +19,1071 @@ from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_r from wudingpv.predictandeval_util import segmentation from guangfufadian import model_base as guangfufadian_model_base from fenglifadian import model_base as fenglifadian_model_base -from work_util import prepare_data,model_deal,params,data_util,post_model +from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal from work_util.logger import logger import joblib -def get_roof_model(): - model_roof = roof_resUnetpamcarb() - model_path_roof = os.path.join(current_dir,'wudingpv/models/roof_best.pth') - model_dict_roof = torch.load(model_path_roof, map_location=torch.device('cpu')) - model_roof.load_state_dict(model_dict_roof['net']) - logger.info("屋顶识别权重加载成功") - model_roof.eval() - model_roof.cuda() - return model_roof +import pickle +from segment_anything_model import sam_annotator +import cv2 +import io +import json +import base64 +from segment_anything_model.sam_config import sam_config,sam_api_config +from segment_anything_model.segment_anything import sam_model_registry, SamPredictor +import traceback +# def get_roof_model(): +# model_roof = roof_resUnetpamcarb() +# model_path_roof = os.path.join(current_dir,'wudingpv/models/roof_best.pth') +# model_dict_roof = torch.load(model_path_roof, map_location=torch.device('cpu')) +# model_roof.load_state_dict(model_dict_roof['net']) +# logger.info("屋顶识别权重加载成功") +# model_roof.eval() +# model_roof.cuda() +# return model_roof + +# def get_pv_model(): +# model_roof = roof_resUnetpamcarb() +# model_path_roof = os.path.join(current_dir,'wudingpv/models/pv_best.pth') +# model_dict_roof = torch.load(model_path_roof, map_location=torch.device('cpu')) +# model_roof.load_state_dict(model_dict_roof['net']) +# logger.info("屋顶识别权重加载成功") +# model_roof.eval() +# model_roof.cuda() +# return model_roof + + +# # 初始化 FastAPI +# app = FastAPI() + +# # 初始化参数 +# param = params.ModelParams() +# pvfd_param = guangfufadian_model_base.guangfufadian_Args() +# windfd_args = fenglifadian_model_base.fenglifadian_Args() + +# # 模型实例 +# dimaoshibie_SegFormer = segformer.SegFormer_Segmentation() +# roof_model = get_roof_model() +# pv_model = get_pv_model() + +# pvfd_model_path = os.path.join(pvfd_param.checkpoints,'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径 +# pvfd_model = guangfufadian_model_base.ModelInference(pvfd_model_path, pvfd_param) + +# windfd_model_path = os.path.join(windfd_args.checkpoints,'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径 +# windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args) + +# ch4_model_flow = joblib.load(os.path.join(current_dir,'jiawanyuce/liuliang_model/xgb_model_liuliang.pkl')) +# ch4_model_gas = joblib.load(os.path.join(current_dir,'jiawanyuce/qixiangnongdu_model/xgb_model_qixiangnongdu.pkl')) + + + +# ==================================SAM================================================= + +# 前端需要进行图片对画布压缩和解压缩的point,来回确定 + + + + + + + +location = "http://124.16.151.196:13432/files/tmp/sam/dc345a3c4-a75a-4121-91fc-e4b9f488384f/input/微信图片_20250506163349.jpg", +input_dir = "/home/xiazj/ai-station-code/tmp/sam/c345a3c4-a75a-4121-91fc-e4b9f488384f/input", +output_dir= "/home/xiazj/ai-station-code/tmp/sam/c345a3c4-a75a-4121-91fc-e4b9f488384f/output", +file_path= "/home/xiazj/ai-station-code/tmp/sam/282dc905-6d9d-40aa-8ce1-62160f5f9864" + + +# 模型加载 +checkpoint_path = os.path.join(current_dir,'segment_anything_model/weights/vit_b.pth') +sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path) +device = "cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu" +_ = sam.to(device=device) +sam_predictor = SamPredictor(sam) +print(f"SAM模型已加载,使用设备: {device}") + + +# 添加分类 +def sam_class_set(class_name, color,path): + # 加载配置内容 + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.add_class(loaded_data,class_name,color) + if result['status'] == True: + loaded_data = result['reason'] + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } + loaded_data['class_index'] = loaded_data['class_names'].index(class_name) + r, g, b = [int(c) for c in color] + bgr_color = (b, g, r) + result, api_config = sam_deal.set_current_class(loaded_data, api_config, loaded_data['class_index'], color=bgr_color) + # 更新配置内容 + sam_deal.save_model(loaded_data,api_config,path) + return {"success":True, + "msg":f"已添加类别: {class_name}, 颜色: {color}", + "data":{"class_name_list": loaded_data['class_names'], + "current_index": loaded_data['class_index'], + "class_dict":loaded_data['class_colors'] + }} + + +# class_name = 'water' +# color = [0,255,0] +# path = file_path +# sam_class_set(class_name=class_name, color=color, path=path) + + +# 选择分类 - 传递标签索引 +def on_class_selected(class_index,path): + # 加载配置内容 + loaded_data,api_config = sam_deal.load_model(path) + result, api_config = sam_deal.set_current_class(loaded_data, api_config, class_index, color=None) + sam_deal.save_model(loaded_data,api_config,path) + if result: + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + else: + return { + "success":False, + "msg":"分类标签识别错误", + "data":None + } + +# on_class_selected(2,file_path) + + +# 选择颜色 +def set_sam_color(current_index, rgb_color,path): + loaded_data,api_config = sam_deal.load_model(path) + r, g, b = [int(c) for c in rgb_color] + bgr_color = (b, g, r) + data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color) + sam_deal.save_model(data,api,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + + + +# set_sam_color(2,(0,255,0),file_path) + + +# 移除分类 +def sam_remove_class(path,select_index): + loaded_data,api_config = sam_deal.load_model(path) + class_name = loaded_data['class_names'][select_index] + loaded_data,api_config = sam_deal.remove_class(loaded_data,api_config,class_name) + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + +# sam_remove_class(file_path,1) + +# 加点 +def left_mouse_down(x,y,path): + loaded_data,api_config = sam_deal.load_model(path) + if not api_config['current_class']: + return { + "success":False, + "msg":"请先选择一个分类,在添加标点之前", + "image":None + } + is_foreground = True + result = sam_deal.add_annotation_point(api_config,x,y,is_foreground) + if result['status']== False: + return { + "success":False, + "msg":result['reason'], + "image":None + } + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":encoded_string + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + + + + + +# left_mouse_down(27,326,file_path) +# left_mouse_down(67,449,file_path) + + +def right_mouse_down(x,y,path): + loaded_data,api_config = sam_deal.load_model(path) + if not api_config['current_class']: + return { + "success":False, + "msg":"请先选择一个分类,在添加标点之前", + "data":None + } + is_foreground = False + api_config, result = sam_deal.add_annotation_point(api_config,x,y,is_foreground) + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + +# x = 231 +# y = 281 +# right_mouse_down(x,y,file_path) + + + + +# 删除前一个点 +def sam_delete_last_point(path): + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.delete_last_point(api_config) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } + + +# sam_delete_last_point(file_path) + + + + +# 清除所有点 +def sam_clear_all_point(path): + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.reset_current_class_points(api_config) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } +# sam_clear_all_point(file_path) + + + +# 模型预测 +def sam_predict_mask(path): + loaded_data,api_config = sam_deal.load_model(path) + class_data = api_config['class_annotations'].get(api_config['current_class'], {}) + if not class_data.get('points'): + return { + "success":False, + "msg":"请在预测前添加至少一个预测样本点", + "data":None + } + else: + loaded_data = sam_deal.reset_annotation(loaded_data) + # 将标注点添加到loaded_data 中 + for i, (x, y) in enumerate(class_data['points']): + is_foreground = class_data['point_types'][i] + loaded_data = sam_deal.add_point(loaded_data, x, y, is_foreground=is_foreground) + try: + result = sam_deal.predict_mask(loaded_data,sam_predictor) + if result['status'] == False: + return { + "success":False, + "msg":result['reason'], + "data":None} + + result = result['reason'] + loaded_data = result['data'] + class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']] + class_data['scores'] = result['scores'] + class_data['selected_mask_index'] = result['selected_index'] + if result['selected_index'] >= 0: + class_data['selected_mask'] = class_data['masks'][result['selected_index']] + logger.info(f"predict: Predicted {len(result['masks'])} masks, selected index: {result['selected_index']}") + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + except Exception as e: + logger.error(f"predict: Error during prediction: {str(e)}") + traceback.print_exc() + return { + "success":False, + "msg":f"predict: Error during prediction: {str(e)}", + "data":None} + +# sam_predict_mask(file_path) + + + +# 清除预测内容以及点标记 +def sam_reset_annotation(path): + loaded_data,api_config = sam_deal.load_model(path) + loaded_data,api_config = sam_deal.reset_annotation_all(loaded_data,api_config) + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + +# sam_reset_annotation(file_path) + + +# 保存一个分类 +# class_index , 当前的选择分类 +def sam_add_to_class(path,class_index): + loaded_data,api_config = sam_deal.load_model(path) + class_name = loaded_data['class_names'][class_index] + result = sam_deal.add_to_class(api_config,class_name) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + return { + "success":True, + "msg":"", + "data":img['reason'] + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } + +# sam_add_to_class(file_path,2) + + + + +def sam_save_annotation(path): + loaded_data,api_config = sam_deal.load_model(path) + + if not api_config['output_dir']: + logger.info("save_annotation: Output directory not set") + return { + "success":False, + "msg":"save_annotation: Output directory not set", + "data":None + } + has_annotations = False + for class_name, class_data in api_config['class_annotations'].items(): + if 'final_mask' in class_data and class_data['final_mask'] is not None: + has_annotations = True + break + if not has_annotations: + logger.info("save_annotation: No final masks to save") + return { + "success":False, + "msg":"save_annotation: No final masks to save", + "data":None + } + image_info = sam_deal.get_image_info(loaded_data) + if not image_info: + logger.info("save_annotation: No image info available") + return { + "success":False, + "msg":"save_annotation: No image info available", + "data":None + } + image_basename = os.path.splitext(image_info['filename'])[0] + annotation_dir = os.path.join(api_config['output_dir'], image_basename) + annotation_dir = file_path + "/output/20241231160414" + os.makedirs(annotation_dir, exist_ok=True) + saved_files = [] + orig_img = loaded_data['image'] + original_img_path = os.path.join(annotation_dir, f"{image_basename}.jpg") + cv2.imwrite(original_img_path, orig_img) + saved_files.append(original_img_path) + vis_img = orig_img.copy() + img_height, img_width = orig_img.shape[:2] + labelme_data = { + "version": "5.1.1", + "flags": {}, + "shapes": [], + "imagePath": f"{image_basename}.jpg", + "imageData": None, + "imageHeight": img_height, + "imageWidth": img_width + } + for class_name, class_data in api_config['class_annotations'].items(): + if 'final_mask' in class_data and class_data['final_mask'] is not None: + color = api_config['class_colors'].get(class_name, (0, 255, 0)) + vis_mask = class_data['final_mask'].copy() + color_mask = np.zeros_like(vis_img) + color_mask[vis_mask > 0] = color + vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0) + binary_mask = (class_data['final_mask'] > 0).astype(np.uint8) + contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + epsilon = 0.0001 * cv2.arcLength(contour, True) + approx_contour = cv2.approxPolyDP(contour, epsilon, True) + points = [[float(point[0][0]), float(point[0][1])] for point in approx_contour] + if len(points) >= 3: + shape_data = { + "label": class_name, + "points": points, + "group_id": None, + "shape_type": "polygon", + "flags": {} + } + labelme_data["shapes"].append(shape_data) + vis_path = os.path.join(annotation_dir, f"{image_basename}_mask.jpg") + cv2.imwrite(vis_path, vis_img) + saved_files.append(vis_path) + try: + is_success, buffer = cv2.imencode(".jpg", orig_img) + if is_success: + img_bytes = io.BytesIO(buffer).getvalue() + labelme_data["imageData"] = base64.b64encode(img_bytes).decode('utf-8') + else: + print("save_annotation: Failed to encode image data") + labelme_data["imageData"] = "" + except Exception as e: + logger.error(f"save_annotation: Could not encode image data: {str(e)}") + labelme_data["imageData"] = "" + json_path = os.path.join(annotation_dir, f"{image_basename}.json") + with open(json_path, 'w') as f: + json.dump(labelme_data, f, indent=2) + saved_files.append(json_path) + logger.info(f"save_annotation: Annotation saved to {annotation_dir}") + return { + "success":False, + "msg":"", + "data":{ + 'output_dir': annotation_dir, + 'files': saved_files, + 'classes': list(api_config['class_annotations'].keys()) + } + + } + +# sam_save_annotation(file_path) + + +# # 添加类别 +# def sam_class_set(class_name:str=None, color:List = None ,path:str = None): +# # 加载配置内容 +# config_path = os.path.join(path,'model_params.pickle') +# with open(config_path, 'rb') as file: +# loaded_data,api_config = pickle.load(file) + +# if class_name in loaded_data['class_names']: +# return { +# "success":False, +# "msg":f"类别 '{class_name}' 已存在", +# "data":None +# } +# loaded_data['class_names'].append(class_name) +# if color is None: +# return { +# "success":False, +# "msg":f"请指定{class_name}代表的颜色", +# "data":None +# } + +# loaded_data['class_colors'][class_name] = tuple(color) +# logger.info(f"已添加类别: {class_name}, 颜色: {color}") +# loaded_data['class_index'] = loaded_data['class_names'].index(class_name) + +# api_config['current_class'] = class_name +# if class_name not in api_config['class_annotations']: +# api_config['class_annotations'][class_name] = { +# 'points': [], +# 'point_types': [], +# 'masks': [], +# 'selected_mask_index': -1 +# } +# api_config['class_colors'][class_name] = tuple(color) + +# save_data = (loaded_data,api_config) +# with open(config_path, 'wb') as file: +# pickle.dump(save_data, file) + +# # 加载存储中的图片信息 +# img = model_deal.get_display_image_ori(loaded_data) +# image = Image.fromarray(img['reason']) +# tmp_path = os.path.join(path,'temp/output_image.jpg') +# image.save(tmp_path) + +# return {"success":True, +# "msg":f"已添加类别: {class_name}, 颜色: {color}", +# "data":{"class_name_list": loaded_data['class_names'], +# "current_index": loaded_data['class_index'], +# "class_dict":loaded_data['class_colors'] +# }} +# # 加载class分类 +# # sam_class_set(class_name='test11', color = [120,0,0], path=file_path) + + +# # 添加point + + + +# # 添加标注器中的点信息 +# def add_annotation_point(loaded_data,x, y, is_foreground=True): +# loaded_data['input_point'].append([x, y]) +# loaded_data['input_label'].append(1 if is_foreground else 0) +# logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})") +# return loaded_data + + +# def predict_mask(data): +# if data['image_rgb'] is None: +# logger.info("predict_mask: No image loaded") +# return {"status":False, "reason":"预测掩码:没有图像加载"} +# if len(data['input_point']) == 0: +# logger.info("predict_mask: No points added") +# return {"status":False, "reason":"预测掩码:没有进行点标注"} +# try: +# sam_predictor.set_image(data['image_rgb']) +# except Exception as e: +# logger.error(f"predict_mask: Error setting image: {str(e)}") +# return {"status":False, "reason":f"predict_mask: Error setting image: {str(e)}"} +# input_point_np = np.array(data['input_point']) +# input_label_np = np.array(data['input_label']) +# try: +# masks_pred, scores, logits = sam_predictor.predict( +# point_coords=input_point_np, +# point_labels=input_label_np, +# mask_input=data['logit_input'][None, :, :] if data['logit_input'] is not None else None, +# multimask_output=True, +# ) +# except Exception as e: +# logger.error(f"predict_mask: Error during prediction: {str(e)}") +# traceback.print_exc() +# return {"status":False, "reason":f"predict_mask: Error during prediction: {str(e)}"} +# data['masks_pred'] = masks_pred +# data['scores'] = scores +# data['logits'] = logits +# best_mask_idx = np.argmax(scores) +# data['selected_mask'] = masks_pred[best_mask_idx] +# data['logit_input'] = logits[best_mask_idx, :, :] +# logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}") +# return {"status":True, "reason":{ +# "masks": [mask.tolist() for mask in masks_pred], +# "scores": scores.tolist(), +# "selected_index": int(best_mask_idx), +# "data":data +# }} + +# def sam_apply_mask(image, mask, color, alpha=0.5): +# masked_image = image.copy() +# for c in range(3): +# masked_image[:, :, c] = np.where( +# mask == 1, +# image[:, :, c] * (1 - alpha) + alpha * color[c], +# image[:, :, c] +# ) +# return masked_image + + +# def apply_mask_overlay(image, mask, color, alpha=0.5): +# colored_mask = np.zeros_like(image) +# colored_mask[mask > 0] = color +# return cv2.addWeighted(image, 1, colored_mask, alpha, 0) + + +# def get_image_display(data,api): +# if data['image'] is None: +# print("get_display_image: No image loaded") +# return {"status":False, "reason":"获取图像:没有图像加载"} +# display_image = data['image'].copy() +# try: +# for point, label in zip(data['input_point'], data['input_label']): +# color = (0, 255, 0) if label == 1 else (0, 0, 255) +# cv2.circle(display_image, tuple(point), 5, color, -1) +# if data['selected_mask'] is not None: +# class_name = data['class_names'][data['class_index']] +# color = data['class_colors'].get(class_name, (0, 0, 128)) +# display_image = sam_apply_mask(display_image, data['selected_mask'], color) +# logger.info(f"get_display_image: Returning image with shape {display_image.shape}") + +# if not isinstance(display_image, np.ndarray) or display_image.size == 0: +# logger.info(f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}") +# return {"status":False, "reason":f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"} +# # 仅应用当前图片的final_mask +# for class_name, class_data in api['class_annotations'].items(): +# if 'final_mask' in class_data and class_data['final_mask'] is not None: +# color = api['class_colors'].get(class_name, (0, 255, 0)) +# mask = class_data['final_mask'] +# if isinstance(mask, list): +# mask = np.array(mask, dtype=np.uint8) +# logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}") +# display_image = apply_mask_overlay(display_image, mask, color, alpha=0.5) +# elif 'selected_mask' in class_data and class_data['selected_mask'] is not None: +# color = api['class_colors'].get(class_name, (0, 255, 0)) +# mask = class_data['selected_mask'] +# if isinstance(mask, list): +# mask = np.array(mask, dtype=np.uint8) +# logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}") +# img = apply_mask_overlay(display_image, mask, color, alpha=0.5) +# if api['current_class'] and api['current_class'] in api['class_annotations']: +# class_data = api['class_annotations'][api['current_class']] +# for i, (x, y) in enumerate(class_data['points']): +# is_fg = class_data['point_types'][i] +# color = (0, 255, 0) if is_fg else (0, 0, 255) +# logger.info(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}") +# cv2.circle(img, (int(x), int(y)), 5, color, -1) +# logger.info(f"get_image_display: Returning image with shape {img.shape}") +# return {"status":True, "reason":img} + +# except Exception as e: +# logger.error(f"get_display_image: Error processing image: {str(e)}") +# traceback.print_exc() +# return {"status":False, "reason":f"get_display_image: Error processing image: {str(e)}"} + + +# # 前端需要保存每次点击的点的信息, +# # 何时删除: ????? +# point_list = [[303,123],[123,232],[234,234],[343,123]] +# point_type = [True,True,False,True] +# def sam_predict(point_list,point_type,path): +# # 加载配置内容 +# config_path = os.path.join(path,'model_params.pickle') +# with open(config_path, 'rb') as file: +# loaded_data,api_config = pickle.load(file) +# # 获取当前类别 +# class_data = api_config['class_annotations'][api_config['current_class']] +# for index, point in enumerate(point_list): +# x,y = point +# class_data['points'].append((x, y)) +# class_data['point_types'].append(point_type[index]) + + +# if not api_config['current_class'] or api_config['current_class'] not in api_config['class_annotations']: +# logger.info("predict: No current class selected") +# return {"status":False, "reason":"没有选择分类类别"} +# class_data = api_config['class_annotations'][api_config['current_class']] +# if not class_data['points']: +# logger.info("predict: No points added for current class") +# return {"status":False, "reason":"当前类别没有标注点信息,请添加点信息"} +# loaded_data = reset_annotation(loaded_data) +# for i, (x, y) in enumerate(class_data['points']): +# is_foreground = class_data['point_types'][i] +# loaded_data = add_annotation_point(loaded_data, x, y, is_foreground=is_foreground) +# try: + +# result = predict_mask(loaded_data) +# if result['status'] == False: +# return { +# "success":False, +# "msg":result['reason'], +# "data":None +# } +# result = result['reason'] + +# loaded_data = result['data'] + +# if result is None: +# logger.info("predict: SAMAnnotator.predict_mask returned None") +# return {"status":False, "reason":"模型预测返回空值"} +# loaded_data = result['data'] +# class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']] +# class_data['scores'] = result['scores'] +# class_data['selected_mask_index'] = result['selected_index'] +# if result['selected_index'] >= 0: +# class_data['selected_mask'] = class_data['masks'][result['selected_index']] +# print(f"predict: Predicted {len(result['masks'])} masks, selected index: {result['selected_index']}") +# save_data = (loaded_data,api_config) +# with open(config_path, 'wb') as file: +# pickle.dump(save_data, file) + +# # 生成预测结果图 +# result = get_image_display(loaded_data,api_config) +# if result['status'] == True: +# img = result['reason'] +# display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# pil_img = Image.fromarray(display_img_rgb) +# tmp_path = os.path.join(path,'temp/output_image.jpg') +# pil_img.save(tmp_path) +# else: +# return { +# "success":False, +# "msg":result['reason'], +# "data":None +# } + +# return { +# "success":True, +# "msg":result['reason'], +# "location":tmp_path +# } +# except Exception as e: +# print(f"predict: Error during prediction: {str(e)}") +# traceback.print_exc() +# return {"status":False, "reason":f"predict: Error during prediction: {str(e)}"} + +# sam_predict(point_list,point_type,file_path) + + + +""" +保存当前结果 +""" + + + + + + +# 添加pionts +# def add_point(x,y,type,path,foreground_mode=True): +# # 加载配置内容 +# config_path = os.path.join(path,'model_params.pickle') +# with open(config_path, 'rb') as file: +# loaded_data,api_config = pickle.load(file) + +# class_data = api_config['class_annotations'][api_config['current_class']] +# class_data['points'].append((x, y)) +# class_data['point_types'].append(is_foreground) + + + + + +# # 支持的图像文件扩展名 +# SUPPORTED_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', '.tiff') + + + + +# DEFAULT_MODEL_PATH = r"/home/xiazj/ai-station-code/segment_anything_model/weights/vit_b.pth" +# config_path = r'/home/xiazj/ai-station-code/segment_anything_model/model_params.pickle' + +# # # 初始化配置 , 每次上传图片时,会创建一个新的配置文件 +# with open(config_path, 'wb') as file: +# pickle.dump(config, file) + +# # # 初始化模型,让模型加载,然后类直接copy即可,暂时没测试 +# # sam = sam_model_registry["vit_b"](checkpoint=DEFAULT_MODEL_PATH) +# # # 将模型移至GPU(如果可用) +# # device = "cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu" +# # _ = sam.to(device=device) +# # sam_predictor = SamPredictor(sam) + + + +# """ +# 设置分类 +# """ +# { +# "":"", +# "":"", +# } + + +# """ +# 选择分类 +# """ + + +# """ +# 加点预测 +# """ + + + +# """ +# 图像保存 +# """ + + +# # # 加载模型 +# annotator = sam_annotator.SAMAnnotator(DEFAULT_MODEL_PATH,config_path) +# print('success') + +# # # 加载路径 +# input_dir = r"/home/xiazj/ai-station-code/segment_anything_model/input" +# output_dir = r"/home/xiazj/ai-station-code/segment_anything_model/output" +# temp_dir = r"/home/xiazj/ai-station-code/segment_anything_model/temp_images" +# image_count = annotator.set_input_directory(input_dir) +# print(image_count) +# annotator.set_output_directory(output_dir) +# annotator.save_params_to_file(config_path) + + +# # # 加载当前图像 ,基于current_index,进行图片替换 +# annotator.load_image() +# image_info = annotator.get_current_image_info() + +# # 获取显示图像, 这里显示的是当前标注的图像当前状态信息 +# display_image = annotator.get_display_image() + +# # 保存为临时文件 +# temp_image_path = os.path.join(temp_dir,image_info['filename']) +# cv2.imwrite(temp_image_path, display_image) +# annotator.save_params_to_file(config_path) + +# """添加标注点""" +# # ,测试一下加载新模型 +# annotator = sam_annotator.SAMAnnotator(DEFAULT_MODEL_PATH,config_path) + +# point = { +# "x": 200, +# "y": 640, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# # 第二次加点 +# point = { +# "x": 270, +# "y": 480, +# "is_foreground": False, +# "button_type": "right" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# point = { +# "x": 400, +# "y": 430, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# point = { +# "x": 705, +# "y": 497, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# point = { +# "x": 408, +# "y": 460, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# point = { +# "x": 428, +# "y": 435, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# point = { +# "x": 270, +# "y": 1110, +# "is_foreground": True, +# "button_type": "left" +# } +# is_foreground = True if point['button_type'] == "left" else False +# result = annotator.add_point(point['x'], point['y'], is_foreground) + +# # 更新显示图像 +# display_image = annotator.get_display_image() +# if display_image is not None: +# image_info = annotator.get_current_image_info() +# temp_image_path = os.path.join(temp_dir,image_info['filename']) +# cv2.imwrite(temp_image_path, display_image) +# annotator.save_params_to_file(config_path) + + +# """ +# 获取列表 +# """ +# points_data = annotator.get_points_with_labels() +# print(points_data) + +# """ +# 删除点 +# """ +# result = annotator.delete_last_point() + +# # 更新显示图像 +# display_image = annotator.get_display_image() +# if display_image is not None: +# image_info = annotator.get_current_image_info() +# temp_image_path = os.path.join(temp_dir,image_info['filename']) +# cv2.imwrite(temp_image_path, display_image) +# annotator.save_params_to_file(config_path) + + +# with open(config_path, 'rb') as file: +# loaded_data = pickle.load(file) +# print(loaded_data) + + +# """ +# 开启预测 +# """ +# result = annotator.predict_mask() +# if result is None: +# print({"status": "error", "message": "预测失败,请检查是否有添加点或加载图像"}) +# # 更新显示图像 +# display_image = annotator.get_display_image() +# if display_image is not None: +# image_info = annotator.get_current_image_info() +# temp_image_path = os.path.join(temp_dir,image_info['filename']) +# cv2.imwrite(temp_image_path, display_image) + +# annotator.save_params_to_file(config_path) + + +# """ +# 逻辑: +# 添加一个点,就预测一次 +# 因此,需要一个add_point, 然后再predict + + +# 同理,删除一个point,也需要一次predict,重新预测,等价于回滚; +# """ + + + + +# """ +# 添加类别 , 也就是要分割的内容,以及想标注的颜色 +# """ + +# class_name = "origin" +# color = [128, 128, 128] +# result = annotator.add_class(class_name,color) +# print({"status": "success", "classes": result}) + + +# """ +# 设置某一类别 +# """ + + +# """ +# 确定分类 +# """ + + + + + + -def get_pv_model(): - model_roof = roof_resUnetpamcarb() - model_path_roof = os.path.join(current_dir,'wudingpv/models/pv_best.pth') - model_dict_roof = torch.load(model_path_roof, map_location=torch.device('cpu')) - model_roof.load_state_dict(model_dict_roof['net']) - logger.info("屋顶识别权重加载成功") - model_roof.eval() - model_roof.cuda() - return model_roof -# 初始化 FastAPI -app = FastAPI() -# 初始化参数 -param = params.ModelParams() -pvfd_param = guangfufadian_model_base.guangfufadian_Args() -windfd_args = fenglifadian_model_base.fenglifadian_Args() -# 模型实例 -dimaoshibie_SegFormer = segformer.SegFormer_Segmentation() -roof_model = get_roof_model() -pv_model = get_pv_model() -pvfd_model_path = os.path.join(pvfd_param.checkpoints,'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径 -pvfd_model = guangfufadian_model_base.ModelInference(pvfd_model_path, pvfd_param) -windfd_model_path = os.path.join(windfd_args.checkpoints,'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径 -windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args) -ch4_model_flow = joblib.load(os.path.join(current_dir,'jiawanyuce/liuliang_model/xgb_model_liuliang.pkl')) -ch4_model_gas = joblib.load(os.path.join(current_dir,'jiawanyuce/qixiangnongdu_model/xgb_model_qixiangnongdu.pkl')) # 模型调用 @@ -163,13 +1183,13 @@ test_data_path = "/home/xiazj/ai-station-code/guangfufadian/datasets/run_test.cs # # 3. 上传数据的预测 -test_data_path = "/home/xiazj/ai-station-code/fenglifadian/datasets/Wind_farm_test.csv" -predictions = windfd_model.run_inference(test_data_path) -predictions = np.array(predictions).flatten() -print(len(predictions)) -pred_data, true_data = prepare_data.result_merge_fenglifadian(test_data_path,predictions) -print(pred_data) -print(true_data) +# test_data_path = "/home/xiazj/ai-station-code/fenglifadian/datasets/Wind_farm_test.csv" +# predictions = windfd_model.run_inference(test_data_path) +# predictions = np.array(predictions).flatten() +# print(len(predictions)) +# pred_data, true_data = prepare_data.result_merge_fenglifadian(test_data_path,predictions) +# print(pred_data) +# print(true_data) """甲烷产量预测""" diff --git a/fenglifadian/__pycache__/model_base.cpython-39.pyc b/fenglifadian/__pycache__/model_base.cpython-39.pyc index cf5d8f7..7f336d7 100644 Binary files a/fenglifadian/__pycache__/model_base.cpython-39.pyc and b/fenglifadian/__pycache__/model_base.cpython-39.pyc differ diff --git a/guangfufadian/__pycache__/model_base.cpython-39.pyc b/guangfufadian/__pycache__/model_base.cpython-39.pyc index 31676c8..5140439 100644 Binary files a/guangfufadian/__pycache__/model_base.cpython-39.pyc and b/guangfufadian/__pycache__/model_base.cpython-39.pyc differ diff --git a/meirejie/data/char_data_test.csv b/meirejie/data/char_data_test.csv index 0a545b1..3411730 100644 --- a/meirejie/data/char_data_test.csv +++ b/meirejie/data/char_data_test.csv @@ -1,11 +1,11 @@ -A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Char -16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,650,48.44904 -8.38943552563439,42.4194460146976,57.5805539853024,82.63,4.08,1.09,0.33,18.56,0.592520876195087,0.16846181774174,0.011306858456804,20.0,5.0,0.2,490,78.69 -13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,600,57.0752 -17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,5.0,0.2,510,79.9 -24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,66.820533 -21.14,35.73,64.27,77.41,4.39,1.62,0.51,16.08,0.68053223097791,0.155793825087198,0.0179378817797627,15.0,5.0,0.08,600,78.5678276838 -4.77944939105516,41.9496990541703,58.0503009458297,75.09,4.79,3.56,0.32,19.22,0.765481422293248,0.191969636436276,0.0406369499457793,20.0,5.0,0.2,510,84.31 -8.24,38.05,61.95,82.3,4.73,0.92,1.32,12.05,0.689671931956258,0.109811664641555,0.0095816698489845,60.0,10.0,0.2,650,73.39 -5.35137948984904,33.22,66.78,80.23,5.17,1.08,0.24,13.28,0.773276829116291,0.124143088620217,0.0115382560851837,10.0,30.0,0.07,600,73.314965 -43.585255354201,50.5566709253513,49.4433290746487,64.68,5.18,0.11,4.89,25.28,0.961038961038961,0.293135435992579,0.0014577259475218,30.0,10.0,1.0,800,59.675632 +A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Char +16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,650,48.44904 +8.38943552563439,42.4194460146976,57.5805539853024,82.63,4.08,1.09,0.33,18.56,0.592520876195087,0.16846181774174,0.011306858456804,20.0,5.0,0.2,490,78.69 +13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,600,57.0752 +17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,5.0,0.2,510,79.9 +24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,66.820533 +21.14,35.73,64.27,77.41,4.39,1.62,0.51,16.08,0.68053223097791,0.155793825087198,0.0179378817797627,15.0,5.0,0.08,600,78.5678276838 +4.77944939105516,41.9496990541703,58.0503009458297,75.09,4.79,3.56,0.32,19.22,0.765481422293248,0.191969636436276,0.0406369499457793,20.0,5.0,0.2,510,84.31 +8.24,38.05,61.95,82.3,4.73,0.92,1.32,12.05,0.689671931956258,0.109811664641555,0.0095816698489845,60.0,10.0,0.2,650,73.39 +5.35137948984904,33.22,66.78,80.23,5.17,1.08,0.24,13.28,0.773276829116291,0.124143088620217,0.0115382560851837,10.0,30.0,0.07,600,73.314965 +43.585255354201,50.5566709253513,49.4433290746487,64.68,5.18,0.11,4.89,25.28,0.961038961038961,0.293135435992579,0.0014577259475218,30.0,10.0,1.0,800,59.675632 diff --git a/meirejie/data/gas_data_test.csv b/meirejie/data/gas_data_test.csv index 88ed4d9..141002c 100644 --- a/meirejie/data/gas_data_test.csv +++ b/meirejie/data/gas_data_test.csv @@ -1,11 +1,11 @@ -A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Gas -5.78,37.71,62.29,76.25,4.37,0.89,0.46,12.25,0.687737704918033,0.120491803278689,0.0100046838407494,30.0,10.0,13.0,650,6.1 -25.57,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,20.0,5.0,0.15,600,15.38 -5.5005500550055,31.5483119906868,68.4516880093132,83.09,4.62,1.07,0.48,10.74,0.667228306655434,0.0969430737754242,0.0110379450853635,15.0,5.0,0.25,600,7.97193 -16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,750,10.88348 -9.79,48.99,51.01,73.91,3.98,0.88,0.49,20.75,0.646191313759978,0.210560140711676,0.0102054622417226,20.0,10.0,3.0,550,16.68 -5.44,41.2,58.8,82.0,4.79,1.51,0.58,11.1,0.7009756097560976,0.1015243902439024,0.0157839721254355,15.0,5.0,0.2,600,11.3 -38.7730061349693,47.15,52.85,74.47,4.8,1.41,0.98,18.34,0.773465825164496,0.184705250436417,0.0162289704387193,60.0,10.0,1.0,700,8.05 -15.85645,36.5599621123663,63.4400378876337,80.0704032260806,5.75550432943082,1.22339222461332,0.44301627527664,12.5076839445986,0.8625665560614234,0.1171564345937181,0.013096248608248,15.0,5.0,0.2,600,8.4 -30.38,59.81,40.19,67.36,4.54,1.67,1.12,25.31,0.808788598574822,0.281806710213777,0.0212504241601629,15.0,5.0,0.08,1100,25.8819005672 -11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,450,18.87111 +A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Gas +5.78,37.71,62.29,76.25,4.37,0.89,0.46,12.25,0.687737704918033,0.120491803278689,0.0100046838407494,30.0,10.0,13.0,650,6.1 +25.57,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,20.0,5.0,0.15,600,15.38 +5.5005500550055,31.5483119906868,68.4516880093132,83.09,4.62,1.07,0.48,10.74,0.667228306655434,0.0969430737754242,0.0110379450853635,15.0,5.0,0.25,600,7.97193 +16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,750,10.88348 +9.79,48.99,51.01,73.91,3.98,0.88,0.49,20.75,0.646191313759978,0.210560140711676,0.0102054622417226,20.0,10.0,3.0,550,16.68 +5.44,41.2,58.8,82.0,4.79,1.51,0.58,11.1,0.7009756097560976,0.1015243902439024,0.0157839721254355,15.0,5.0,0.2,600,11.3 +38.7730061349693,47.15,52.85,74.47,4.8,1.41,0.98,18.34,0.773465825164496,0.184705250436417,0.0162289704387193,60.0,10.0,1.0,700,8.05 +15.85645,36.5599621123663,63.4400378876337,80.0704032260806,5.75550432943082,1.22339222461332,0.44301627527664,12.5076839445986,0.8625665560614234,0.1171564345937181,0.013096248608248,15.0,5.0,0.2,600,8.4 +30.38,59.81,40.19,67.36,4.54,1.67,1.12,25.31,0.808788598574822,0.281806710213777,0.0212504241601629,15.0,5.0,0.08,1100,25.8819005672 +11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,450,18.87111 diff --git a/meirejie/data/tar_data_test.csv b/meirejie/data/tar_data_test.csv new file mode 100644 index 0000000..6b0f995 --- /dev/null +++ b/meirejie/data/tar_data_test.csv @@ -0,0 +1,11 @@ +A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Tar +13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,450,6.2108384 +14.270205066345,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,12.0,650,9.38428 +11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,20.0,15.0,6.0,480,10.14648 +39.4,53.71,46.29,77.46,6.64,1.57,0.57,20.61,1.0286599535244,0.199554608830364,0.0173730220205821,15.0,30.5,5.0,600,9.6968 +11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,10.0,15.0,6.0,480,10.07076 +25.56,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,15.0,5.0,0.15,600,8.87 +11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,650,9.53456 +17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,40.0,6.0,500,6.1334 +15.84092126406,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,6.0,650,6.504628 +6.17,46.5,53.5,76.14,3.06,1.06,0.24,19.5,0.482269503546099,0.192080378250591,0.0119329055499268,30.0,10.0,0.07,650,8.585445 diff --git a/meirejie/data/water_data_test.csv b/meirejie/data/water_data_test.csv index 3aea0d8..d0f3241 100644 --- a/meirejie/data/water_data_test.csv +++ b/meirejie/data/water_data_test.csv @@ -1,11 +1,11 @@ -A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Water -4.54,47.35,52.65,70.41,6.95,1.15,0.43,21.06,1.18449083936941,0.224328930549638,0.0139996347921359,30.0,20.0,2.0,600,10.046517999999995 -13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,550,8.1107936 -11.3260262196432,54.9321376635967,45.0678623364033,63.75,4.39,1.25,0.55,30.11,0.826352941176471,0.354235294117647,0.0168067226890756,20.0,5.0,3.0,376,5.59 -10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,20.0,5.0,0.2,510,6.15 -10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,60.0,10.0,0.15,650,5.44968 -8.8735776177054,38.1,61.9,78.54,5.28,1.2,0.39,14.59,0.80672268907563,0.139323911382735,0.0130961475499291,20.0,5.0,0.07,600,4.368024 -24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,7.078245 -16.6326530612245,36.82,63.18,83.04,5.39,1.48,0.64,9.45,0.778901734104046,0.0853504335260115,0.0152766308835673,60.0,10.0,1.0,500,5.58 -4.4152621238755,29.6624837732583,70.3375162267417,81.78,4.79,1.1,0.38,11.95,0.702861335289802,0.10959280997799,0.0115291898123886,30.0,20.0,0.85,700,12.071072 -11.9159836065574,51.1573804815633,48.8426195184367,83.22,3.89,2.72,0.45,20.21,0.560922855082913,0.182137707281903,0.0280152435884231,30.0,5.0,0.2,510,1.1399999999999952 +A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Water +4.54,47.35,52.65,70.41,6.95,1.15,0.43,21.06,1.18449083936941,0.224328930549638,0.0139996347921359,30.0,20.0,2.0,600,10.046517999999995 +13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,550,8.1107936 +11.3260262196432,54.9321376635967,45.0678623364033,63.75,4.39,1.25,0.55,30.11,0.826352941176471,0.354235294117647,0.0168067226890756,20.0,5.0,3.0,376,5.59 +10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,20.0,5.0,0.2,510,6.15 +10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,60.0,10.0,0.15,650,5.44968 +8.8735776177054,38.1,61.9,78.54,5.28,1.2,0.39,14.59,0.80672268907563,0.139323911382735,0.0130961475499291,20.0,5.0,0.07,600,4.368024 +24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,7.078245 +16.6326530612245,36.82,63.18,83.04,5.39,1.48,0.64,9.45,0.778901734104046,0.0853504335260115,0.0152766308835673,60.0,10.0,1.0,500,5.58 +4.4152621238755,29.6624837732583,70.3375162267417,81.78,4.79,1.1,0.38,11.95,0.702861335289802,0.10959280997799,0.0115291898123886,30.0,20.0,0.85,700,12.071072 +11.9159836065574,51.1573804815633,48.8426195184367,83.22,3.89,2.72,0.45,20.21,0.560922855082913,0.182137707281903,0.0280152435884231,30.0,5.0,0.2,510,1.1399999999999952 diff --git a/meirejie/utils/demo_data_make.py b/meirejie/utils/demo_data_make.py index 001f652..7ad0073 100644 --- a/meirejie/utils/demo_data_make.py +++ b/meirejie/utils/demo_data_make.py @@ -1,6 +1,6 @@ import pandas as pd # 读取Excel文件 -file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # 替换为你的Excel文件路径 +file_path = "/home/xiazj/ai-station-code/meirejie/data/tar_data.csv" # 替换为你的Excel文件路径 df = pd.read_csv(file_path) # 随机抽取10条数据 test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同 @@ -8,45 +8,45 @@ test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结 columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Tar'] test_set = test_set[columns] # 保存测试集到新的Excel文件 -test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 +test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/tar_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 print("测试集已保存为 char_data_test.csv") -file_path = "D:\\project\\ai_station\\meirejie\\data\\gas_data.csv" # 替换为你的Excel文件路径 +file_path = "/home/xiazj/ai-station-code/meirejie/data/gas_data.csv" # 替换为你的Excel文件路径 df = pd.read_csv(file_path) # 随机抽取10条数据 test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同 columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Gas'] test_set = test_set[columns] # 保存测试集到新的Excel文件 -test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\gas_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 +test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/gas_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 print("测试集已保存为 gas_data_test.csv") -file_path = "D:\\project\\ai_station\\meirejie\\data\\water_data.csv" # 替换为你的Excel文件路径 +file_path = "/home/xiazj/ai-station-code/meirejie/data/water_data.csv" # 替换为你的Excel文件路径 df = pd.read_csv(file_path) # 随机抽取10条数据 test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同 columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Water'] test_set = test_set[columns] # 保存测试集到新的Excel文件 -test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\water_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 +test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/water_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 print("测试集已保存为 water_data_test.csv") -file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # 替换为你的Excel文件路径 +file_path = "/home/xiazj/ai-station-code/meirejie/data/char_data.csv" # 替换为你的Excel文件路径 df = pd.read_csv(file_path) # 随机抽取10条数据 test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同 columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Char'] test_set = test_set[columns] # 保存测试集到新的Excel文件 -test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 +test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引 print("测试集已保存为 char_data_test.csv") \ No newline at end of file diff --git a/run.py b/run.py index 68e5d9e..e490824 100644 --- a/run.py +++ b/run.py @@ -1,18 +1,22 @@ import sys - +import io from fastapi import FastAPI,File, UploadFile,Form,Query from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, HTTPException -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse,JSONResponse import sys import os import shutil from pydantic import BaseModel, validator -from typing import List +from typing import List, Optional import asyncio import pandas as pd import numpy as np from PIL import Image +import pickle +import cv2 +import copy +import base64 # 获取当前脚本所在目录 print("Current working directory:", os.getcwd()) current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -25,7 +29,7 @@ from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_r from wudingpv.predictandeval_util import segmentation from guangfufadian import model_base as guangfufadian_model_base from fenglifadian import model_base as fenglifadian_model_base -from work_util import prepare_data,model_deal,params,data_util,post_model +from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal from work_util.logger import logger import joblib import mysql.connector @@ -33,7 +37,10 @@ import uuid import json import zipfile from pathlib import Path - +from segment_anything_model import sam_annotator +from segment_anything_model.sam_config import sam_config, sam_api_config +from segment_anything_model.segment_anything import sam_model_registry, SamPredictor +import traceback version = f"{sys.version_info.major}.{sys.version_info.minor}" app = FastAPI() @@ -77,6 +84,17 @@ windfd_args = fenglifadian_model_base.fenglifadian_Args() windfd_model_path = os.path.join(windfd_args.checkpoints,'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径 windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args) + +# 模型加载 +checkpoint_path = os.path.join(current_dir,'segment_anything_model/weights/vit_b.pth') +sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path) +device = "cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu" +_ = sam.to(device=device) +sam_predictor = SamPredictor(sam) +print(f"SAM模型已加载,使用设备: {device}") + + + # 将 /root/app 目录挂载为静态文件 app.mount("/files", StaticFiles(directory="/root/app"), name="files") @@ -86,6 +104,25 @@ async def read_root(): return {"message": message} +# 首页 +# 获取数据界面资源信息 +@app.get("/ai-station-api/index/show") +async def get_source_index_info(): + sql = "SELECT id,application_name, describe_data, img_url FROM app_shouye" + data = data_util.fetch_data(sql) + if data is None: + return { + "success":False, + "msg":"获取信息列表失败", + "data":None + } + else: + return {"success":True, + "msg":"获取信息成功", + "data":data} + + + # 获取数据界面资源信息 @app.get("/ai-station-api/data_source/show") @@ -188,7 +225,7 @@ type = {tar,char,gas,water} """ @app.get("/ai-station-api/mrj_feature/show") async def get_mrj_feature_info(type:str = None): - sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale FROM meirejie_features where use_type = %s;" + sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;" data = data_util.fetch_data_with_param(sql,(type,)) if data is None: return { @@ -235,7 +272,7 @@ async def mjt_models_predict_ssa(content: post_model.Zongbiaomianji): new_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"] meijiegou_test_content = meijiegou_test_content.reindex(columns=new_order) ssa_result = model_deal.pred_single_ssa(meijiegou_test_content) - logger.info("Root endpoint was accessed") + # logger.info("Root endpoint was accessed") if ssa_result is None: return { "success":False, @@ -273,9 +310,14 @@ async def upload_file(file: UploadFile = File(...),type: str = Form(...), ): @app.get("/ai-station-api/mjt_multi_ssa/predict") async def mjt_multi_ssa_pred(model:str = None, path:str = None): data = model_deal.get_excel_ssa(model, path) - return {"success":True, + if data['status'] == True: + return {"success":True, "msg":"获取信息成功", - "data":data} + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} @@ -308,9 +350,14 @@ async def mjt_models_predict_tpv(content: post_model.Zongbiaomianji): @app.get("/ai-station-api/mjt_multi_tpv/predict") async def mjt_multi_tpv_pred(model:str = None, path:str = None): data = model_deal.get_excel_tpv(model, path) - return {"success":True, + if data['status'] == True: + return {"success":True, "msg":"获取信息成功", - "data":data} + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} #==========================煤基碳材料-煤炭材料应用细节接口================================== @app.post("/ai-station-api/mjt_models_meitan/predict") @@ -339,9 +386,14 @@ async def mjt_models_predict_meitan(content: post_model.Meitan): @app.get("/ai-station-api/mjt_multi_meitan/predict") async def mjt_multi_meitan_pred(model:str = None, path:str = None): data = model_deal.get_excel_meitan(model, path) - return {"success":True, + if data['status'] == True: + return {"success":True, "msg":"获取信息成功", - "data":data} + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} #==========================煤基碳材料-煤沥青应用细节接口================================== @@ -371,10 +423,14 @@ async def mjt_models_predict_meiliqing(content: post_model.Meitan): @app.get("/ai-station-api/mjt_multi_meiliqing/predict") async def mjt_multi_meiliqing_pred(model:str = None, path:str = None): data = model_deal.get_excel_meiliqing(model, path) - return {"success":True, + if data['status'] == True: + return {"success":True, "msg":"获取信息成功", - "data":data} - + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} @@ -428,6 +484,176 @@ async def mjt_multi_meiliqing_pred(model:str = None, path:str = None, type:int = } +#===============================煤热解-tar ======================================================= +""" +模型综合分析 +""" +@app.post("/ai-station-api/mrj_models_tar/predict") +async def mrj_models_predict_tar(content: post_model.Meirejie): + # 处理接收到的字典数据 + test_content = pd.DataFrame([content.model_dump()]) + test_content= test_content.rename(columns={"H_C":"H/C"}) + test_content= test_content.rename(columns={"O_C":"O/C"}) + test_content= test_content.rename(columns={"N_C":"N/C"}) + new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"] + test_content = test_content.reindex(columns=new_order) + tmp_result = model_deal.pred_single_tar(test_content) + if tmp_result is None: + return { + "success":False, + "msg":"获取信息列表失败", + "data":None + } + else: + return {"success":True, + "msg":"获取信息成功", + "data":tmp_result} + + + +""" +批量预测接口 +""" +@app.get("/ai-station-api/mrj_multi_tar/predict") +async def mrj_multi_tar_pred(model:str = None, path:str = None): + data = model_deal.get_excel_tar(model, path) + if data['status'] == True: + return {"success":True, + "msg":"获取信息成功", + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} + + + + +#===============================煤热解-char ======================================================= +""" +模型综合分析 +""" +@app.post("/ai-station-api/mrj_models_char/predict") +async def mrj_models_predict_char(content: post_model.Meirejie): + # 处理接收到的字典数据 + test_content = pd.DataFrame([content.model_dump()]) + test_content= test_content.rename(columns={"H_C":"H/C"}) + test_content= test_content.rename(columns={"O_C":"O/C"}) + test_content= test_content.rename(columns={"N_C":"N/C"}) + new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"] + test_content = test_content.reindex(columns=new_order) + tmp_result = model_deal.pred_single_char(test_content) + if tmp_result is None: + return { + "success":False, + "msg":"获取信息列表失败", + "data":None + } + else: + return {"success":True, + "msg":"获取信息成功", + "data":tmp_result} + +""" +批量预测接口 +""" +@app.get("/ai-station-api/mrj_multi_char/predict") +async def mrj_multi_char_pred(model:str = None, path:str = None): + data = model_deal.get_excel_char(model, path) + if data['status'] == True: + return {"success":True, + "msg":"获取信息成功", + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} + + +#===============================煤热解-water ======================================================= +""" +模型综合分析 +""" +@app.post("/ai-station-api/mrj_models_water/predict") +async def mrj_models_predict_water(content: post_model.Meirejie): + # 处理接收到的字典数据 + test_content = pd.DataFrame([content.model_dump()]) + test_content= test_content.rename(columns={"H_C":"H/C"}) + test_content= test_content.rename(columns={"O_C":"O/C"}) + test_content= test_content.rename(columns={"N_C":"N/C"}) + new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"] + test_content = test_content.reindex(columns=new_order) + tmp_result = model_deal.pred_single_water(test_content) + if tmp_result is None: + return { + "success":False, + "msg":"获取信息列表失败", + "data":None + } + else: + return {"success":True, + "msg":"获取信息成功", + "data":tmp_result} + +""" +批量预测接口 +""" +@app.get("/ai-station-api/mrj_multi_water/predict") +async def mrj_multi_water_pred(model:str = None, path:str = None): + data = model_deal.get_excel_water(model, path) + if data['status'] == True: + return {"success":True, + "msg":"获取信息成功", + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} + + +#===============================煤热解-gas ======================================================= +""" +模型综合分析 +""" +@app.post("/ai-station-api/mrj_models_gas/predict") +async def mrj_models_predict_gas(content: post_model.Meirejie): + # 处理接收到的字典数据 + test_content = pd.DataFrame([content.model_dump()]) + test_content= test_content.rename(columns={"H_C":"H/C"}) + test_content= test_content.rename(columns={"O_C":"O/C"}) + test_content= test_content.rename(columns={"N_C":"N/C"}) + new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"] + test_content = test_content.reindex(columns=new_order) + tmp_result = model_deal.pred_single_gas(test_content) + if tmp_result is None: + return { + "success":False, + "msg":"获取信息列表失败", + "data":None + } + else: + return {"success":True, + "msg":"获取信息成功", + "data":tmp_result} + +""" +批量预测接口 +""" +@app.get("/ai-station-api/mrj_multi_gas/predict") +async def mrj_multi_gas_pred(model:str = None, path:str = None): + data = model_deal.get_excel_gas(model, path) + if data['status'] == True: + return {"success":True, + "msg":"获取信息成功", + "data":data['reason']} + else: + return {"success":False, + "msg":data['reason'], + "data":None} + + + + #================================ 地貌识别 ==================================================== """ 上传图片 , type = dimaoshibie @@ -458,6 +684,7 @@ async def upload_image(file: UploadFile = File(...),type: str = Form(...), ): "msg":"获取信息成功", "data":{"location": file_location}} + """ 图片地貌识别 """ @@ -753,7 +980,6 @@ async def get_ch4_data(path:str = None, page: int = Query(1, ge=1), page_size: i csv上传 """ # @app.post("/ai-station-api/document/upload") type = "ch4" - """ 预测, type """ @@ -771,7 +997,6 @@ async def get_ch4_predict(path:str=None,start_time:str=None, end_time:str = None "msg":"获取信息成功", "data":data['reason']} - #========================================光伏预测========================================================== """ 返回显示列表 @@ -982,6 +1207,8 @@ csv上传 """ # @app.post("/ai-station-api/document/upload") type = "wind_electric" + + """ 预测, type """ @@ -1000,9 +1227,576 @@ async def get_wind_electri_predict(path:str=None,start_time:str=None, end_time:s "data":data['reason']} +#======================== SAM ============================================================= +""" +文件上传 +1、 创建inputs目录,outputs目录 +2、 将图片上传到iuputs目录 +4、 加载当前图像 +5、 返回当前图像路径给前端 +""" +@app.post("/ai-station-api/sam-image/upload") +async def upload_sam_image(file: UploadFile = File(...),type: str = Form(...), ): + if not data_util.allowed_file(file.filename): + return { + "success":False, + "msg":"图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾", + "data":None + } + if file.size > param.MAX_FILE_SAM_SIZE: + return { + "success":False, + "msg":"图片大小不能大于10MB", + "data":None + } + upload_dir = os.path.join(current_dir,'tmp',type, str(uuid.uuid4())) + if not os.path.exists(upload_dir): + os.makedirs(upload_dir ) + input_dir = os.path.join(upload_dir,'input') + os.makedirs(input_dir) + output_dir = os.path.join(upload_dir,'output') + os.makedirs(output_dir) + temp_dir = os.path.join(upload_dir,'temp') + os.makedirs(temp_dir) + # 将文件保存到指定目录 + file_location = os.path.join(input_dir , file.filename) + with open(file_location, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + config_path = os.path.join(upload_dir,'model_params.pickle') + # # 初始化配置 , 每次上传图片时,会创建一个新的配置文件 + config = copy.deepcopy(sam_config) + api_config = copy.deepcopy(sam_api_config) + + config['input_dir'] = input_dir + config['output_dir'] = output_dir + config['image_files'] = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))] + config['current_index'] = 0 + config['filename'] = config['image_files'][config['current_index']] + image_path = os.path.join(config['input_dir'], config['filename']) + config['image'] = cv2.imread(image_path) + config['image_rgb'] = cv2.cvtColor(config['image'].copy(), cv2.COLOR_BGR2RGB) + # 配置类获取参数信息 + api_config['output_dir'] = output_dir + # # 重置pickle中的信息 + # 重置API中的class_annotations,保留类别但清除掩码和点 + config, api_config = sam_deal.reset_class_annotations(config,api_config) + config = sam_deal.reset_annotation(config) + + save_data = (config,api_config) + with open(config_path, 'wb') as file: + pickle.dump(save_data, file) + # 将图片拷贝到temp目录下 + pil_img = Image.fromarray(config['image_rgb']) + tmp_path = os.path.join(temp_dir,'output_image.jpg') + pil_img.save(tmp_path) + encoded_string = sam_deal.load_tmp_image(tmp_path) + return {"success":True, + "msg":"获取信息成功", + "image": JSONResponse(content={"image_data": encoded_string}), + #"input_dir": input_dir, + #'output_dir': output_dir, + #'temp_dir': temp_dir, + 'file_path': upload_dir} + +""" +添加分类 +添加分类,并将分类设置为最新的添加的分类; + +要求:针对返回current_index,将列表默认成选择current_index对应的分类 +""" +@app.get("/ai-station-api/sam_class/create") +async def sam_class_set( + class_name: str = None, + color: Optional[List[int]] = Query(None, description="list of RGB color"), + path: str = None +): + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.add_class(loaded_data,class_name,color) + if result['status'] == True: + loaded_data = result['reason'] + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } + loaded_data['class_index'] = loaded_data['class_names'].index(class_name) + r, g, b = [int(c) for c in color] + bgr_color = (b, g, r) + result, api_config = sam_deal.set_current_class(loaded_data, api_config, loaded_data['class_index'], color=bgr_color) + # 更新配置内容 + sam_deal.save_model(loaded_data,api_config,path) + tmp_path = os.path.join(path,'temp/output_image.jpg') + encoded_string = sam_deal.load_tmp_image(tmp_path) + return {"success":True, + "msg":f"已添加类别: {class_name}, 颜色: {color}", + "image": JSONResponse(content={"image_data": encoded_string}), + "data":{"class_name_list": loaded_data['class_names'], + "current_index": loaded_data['class_index'], + "class_dict":loaded_data['class_colors'], + }} +""" +选择颜色, +current_index : 下拉列表中的分类索引 +rgb_color : +""" +@app.get("/ai-station-api/sam_color/select") +async def set_sam_color( + current_index: int = None, + rgb_color: List[int] = Query(None, description="list of RGB color"), + path: str = None + ): + loaded_data,api_config = sam_deal.load_model(path) + r, g, b = [int(c) for c in rgb_color] + bgr_color = (b, g, r) + data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color) + result, api_config = sam_deal.set_current_class(data, api, current_index, color=bgr_color) + sam_deal.save_model(data,api,path) + img = sam_deal.refresh_image(data,api,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image": JSONResponse(content={"image_data": encoded_string}), + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + + +""" +选择分类 +""" +@app.get("/ai-station-api/sam_classs/change") +async def on_class_selected(class_index : int = None,path: str = None): + # 加载配置内容 + loaded_data,api_config = sam_deal.load_model(path) + result, api_config = sam_deal.set_current_class(loaded_data, api_config, class_index, color=None) + # loaded_data['class_index'] = class_index + sam_deal.save_model(loaded_data,api_config,path) + if result: + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + else: + return { + "success":False, + "msg":"分类标签识别错误", + "image":None + } + + +""" +删除分类, 前端跳转为normal ,index=0 +""" +@app.get("/ai-station-api/sam_class/delete") +async def sam_remove_class(path:str=None,select_index:int=None): + loaded_data,api_config = sam_deal.load_model(path) + class_name = loaded_data['class_names'][select_index] + loaded_data,api_config = sam_deal.remove_class(loaded_data,api_config,class_name) + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + +""" +添加标注点 -左键 +""" +@app.get("/ai-station-api/sam_point/left_add") +async def left_mouse_down(x:int=None,y:int=None,path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + if not api_config['current_class']: + return { + "success":False, + "msg":"请先选择一个分类,在添加标点之前", + "image":None + } + is_foreground = True + result = sam_deal.add_annotation_point(api_config,x,y,is_foreground) + if result['status']== False: + return { + "success":False, + "msg":result['reason'], + "image":None + } + api_config = result['api'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + +""" +添加标注点 -右键 +""" +@app.get("/ai-station-api/sam_point/right_add") +async def right_mouse_down(x:int=None,y:int=None,path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + if not api_config['current_class']: + return { + "success":False, + "msg":"请先选择一个分类,在添加标点之前", + "image":None + } + is_foreground = False + result = sam_deal.add_annotation_point(api_config,x,y,is_foreground) + if result['status']== False: + return { + "success":False, + "msg":result['reason'], + "image":None + } + api_config = result['api'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + +""" +删除上一个点 +""" +@app.get("/ai-station-api/sam_point/delete_last") +async def sam_delete_last_point(path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.delete_last_point(api_config) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "data":None + } + else: + return { + "success":False, + "msg":result['reason'], + "data":None + } + + +""" +删除所有点 +""" +@app.get("/ai-station-api/sam_point/delete_all") +async def sam_clear_all_point(path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + result = sam_deal.reset_current_class_points(api_config) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + else: + return { + "success":False, + "msg":result['reason'], + "image":None + } + + +""" +模型预测 +""" +@app.get("/ai-station-api/sam_model/predict") +async def sam_predict_mask(path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + class_data = api_config['class_annotations'].get(api_config['current_class'], {}) + if not class_data.get('points'): + return { + "success":False, + "msg":"请在预测前添加至少一个预测样本点", + "data":None + } + else: + loaded_data = sam_deal.reset_annotation(loaded_data) + # 将标注点添加到loaded_data 中 + for i, (x, y) in enumerate(class_data['points']): + is_foreground = class_data['point_types'][i] + loaded_data = sam_deal.add_point(loaded_data, x, y, is_foreground=is_foreground) + try: + result = sam_deal.predict_mask(loaded_data,sam_predictor) + if result['status'] == False: + return { + "success":False, + "msg":result['reason'], + "data":None} + + result = result['reason'] + loaded_data = result['data'] + class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']] + class_data['scores'] = result['scores'] + class_data['selected_mask_index'] = result['selected_index'] + if result['selected_index'] >= 0: + class_data['selected_mask'] = class_data['masks'][result['selected_index']] + logger.info(f"predict: Predicted {len(result['masks'])} masks, selected index: {result['selected_index']}") + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + except Exception as e: + logger.error(f"predict: Error during prediction: {str(e)}") + traceback.print_exc() + return { + "success":False, + "msg":f"predict: Error during prediction: {str(e)}", + "data":None} + + +""" +清除所有信息 +""" +@app.get("/ai-station-api/sam_model/clear_all") +async def sam_reset_annotation(path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + loaded_data,api_config = sam_deal.reset_annotation_all(loaded_data,api_config) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + + +""" +保存预测分类 +""" +@app.get("/ai-station-api/sam_model/save_stage") +async def sam_add_to_class(path:str=None,class_index:int=None): + loaded_data,api_config = sam_deal.load_model(path) + class_name = loaded_data['class_names'][class_index] + result = sam_deal.add_to_class(api_config,class_name) + if result['status'] == True: + api_config = result['reason'] + sam_deal.save_model(loaded_data,api_config,path) + img = sam_deal.refresh_image(loaded_data,api_config,path) + if img['status'] == True: + encoded_string = sam_deal.load_tmp_image(img['reason']) + return { + "success":True, + "msg":"", + "image":JSONResponse(content={"image_data": encoded_string}) + } + else: + return { + "success":False, + "msg":img['reason'], + "image":None + } + else: + return { + "success":False, + "msg":result['reason'], + "image":None + } + + + +""" +保存结果 +""" +@app.get("/ai-station-api/sam_model/save") +async def sam_save_annotation(path:str=None): + loaded_data,api_config = sam_deal.load_model(path) + + if not api_config['output_dir']: + logger.info("save_annotation: Output directory not set") + return { + "success":False, + "msg":"save_annotation: Output directory not set", + "data":None + } + has_annotations = False + for class_name, class_data in api_config['class_annotations'].items(): + if 'final_mask' in class_data and class_data['final_mask'] is not None: + has_annotations = True + break + if not has_annotations: + logger.info("save_annotation: No final masks to save") + return { + "success":False, + "msg":"save_annotation: No final masks to save", + "data":None + } + image_info = sam_deal.get_image_info(loaded_data) + if not image_info: + logger.info("save_annotation: No image info available") + return { + "success":False, + "msg":"save_annotation: No image info available", + "data":None + } + image_basename = os.path.splitext(image_info['filename'])[0] + annotation_dir = os.path.join(api_config['output_dir'], image_basename) + os.makedirs(annotation_dir, exist_ok=True) + saved_files = [] + orig_img = loaded_data['image'] + original_img_path = os.path.join(annotation_dir, f"{image_basename}.jpg") + cv2.imwrite(original_img_path, orig_img) + saved_files.append(original_img_path) + vis_img = orig_img.copy() + img_height, img_width = orig_img.shape[:2] + labelme_data = { + "version": "5.1.1", + "flags": {}, + "shapes": [], + "imagePath": f"{image_basename}.jpg", + "imageData": None, + "imageHeight": img_height, + "imageWidth": img_width + } + for class_name, class_data in api_config['class_annotations'].items(): + if 'final_mask' in class_data and class_data['final_mask'] is not None: + color = api_config['class_colors'].get(class_name, (0, 255, 0)) + vis_mask = class_data['final_mask'].copy() + color_mask = np.zeros_like(vis_img) + color_mask[vis_mask > 0] = color + vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0) + binary_mask = (class_data['final_mask'] > 0).astype(np.uint8) + contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + epsilon = 0.0001 * cv2.arcLength(contour, True) + approx_contour = cv2.approxPolyDP(contour, epsilon, True) + points = [[float(point[0][0]), float(point[0][1])] for point in approx_contour] + if len(points) >= 3: + shape_data = { + "label": class_name, + "points": points, + "group_id": None, + "shape_type": "polygon", + "flags": {} + } + labelme_data["shapes"].append(shape_data) + vis_path = os.path.join(annotation_dir, f"{image_basename}_mask.jpg") + cv2.imwrite(vis_path, vis_img) + saved_files.append(vis_path) + try: + is_success, buffer = cv2.imencode(".jpg", orig_img) + if is_success: + img_bytes = io.BytesIO(buffer).getvalue() + labelme_data["imageData"] = base64.b64encode(img_bytes).decode('utf-8') + else: + print("save_annotation: Failed to encode image data") + labelme_data["imageData"] = "" + except Exception as e: + logger.error(f"save_annotation: Could not encode image data: {str(e)}") + labelme_data["imageData"] = "" + json_path = os.path.join(annotation_dir, f"{image_basename}.json") + with open(json_path, 'w') as f: + json.dump(labelme_data, f, indent=2) + saved_files.append(json_path) + logger.info(f"save_annotation: Annotation saved to {annotation_dir}") + + # 将其打包 + zip_filename = "download.zip" + zip_filepath = Path(zip_filename) + dir_path = os.path.dirname(annotation_dir) + # 创建 ZIP 文件 + try: + with zipfile.ZipFile(zip_filepath, 'w') as zip_file: + # 遍历文件夹中的所有文件 + for foldername, subfolders, filenames in os.walk(dir_path): + for filename in filenames: + file_path = os.path.join(foldername, filename) + # 将文件写入 ZIP 文件 + zip_file.write(file_path, os.path.relpath(file_path, dir_path)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}") + + # 返回 ZIP 文件作为响应 + return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename) # @app.post("/ai-station-api/items/") diff --git a/segment_anything_model b/segment_anything_model new file mode 160000 index 0000000..685b6e5 --- /dev/null +++ b/segment_anything_model @@ -0,0 +1 @@ +Subproject commit 685b6e545db2114b9b834b1c94712b026298b0a7 diff --git a/work_util/__pycache__/data_util.cpython-310.pyc b/work_util/__pycache__/data_util.cpython-310.pyc index b268b2b..a9bb158 100644 Binary files a/work_util/__pycache__/data_util.cpython-310.pyc and b/work_util/__pycache__/data_util.cpython-310.pyc differ diff --git a/work_util/__pycache__/data_util.cpython-39.pyc b/work_util/__pycache__/data_util.cpython-39.pyc index c7985d5..2e8b92d 100644 Binary files a/work_util/__pycache__/data_util.cpython-39.pyc and b/work_util/__pycache__/data_util.cpython-39.pyc differ diff --git a/work_util/__pycache__/model_deal.cpython-310.pyc b/work_util/__pycache__/model_deal.cpython-310.pyc index 6fc43fc..0440bd0 100644 Binary files a/work_util/__pycache__/model_deal.cpython-310.pyc and b/work_util/__pycache__/model_deal.cpython-310.pyc differ diff --git a/work_util/__pycache__/model_deal.cpython-39.pyc b/work_util/__pycache__/model_deal.cpython-39.pyc index ceeffc4..8631118 100644 Binary files a/work_util/__pycache__/model_deal.cpython-39.pyc and b/work_util/__pycache__/model_deal.cpython-39.pyc differ diff --git a/work_util/__pycache__/params.cpython-310.pyc b/work_util/__pycache__/params.cpython-310.pyc index 6e8eb59..c57146d 100644 Binary files a/work_util/__pycache__/params.cpython-310.pyc and b/work_util/__pycache__/params.cpython-310.pyc differ diff --git a/work_util/__pycache__/params.cpython-39.pyc b/work_util/__pycache__/params.cpython-39.pyc index f091689..f09e998 100644 Binary files a/work_util/__pycache__/params.cpython-39.pyc and b/work_util/__pycache__/params.cpython-39.pyc differ diff --git a/work_util/__pycache__/post_model.cpython-310.pyc b/work_util/__pycache__/post_model.cpython-310.pyc index fe4d342..9531e36 100644 Binary files a/work_util/__pycache__/post_model.cpython-310.pyc and b/work_util/__pycache__/post_model.cpython-310.pyc differ diff --git a/work_util/__pycache__/post_model.cpython-39.pyc b/work_util/__pycache__/post_model.cpython-39.pyc index ac297c7..c299c61 100644 Binary files a/work_util/__pycache__/post_model.cpython-39.pyc and b/work_util/__pycache__/post_model.cpython-39.pyc differ diff --git a/work_util/__pycache__/prepare_data.cpython-39.pyc b/work_util/__pycache__/prepare_data.cpython-39.pyc index 2407ed9..7a20b03 100644 Binary files a/work_util/__pycache__/prepare_data.cpython-39.pyc and b/work_util/__pycache__/prepare_data.cpython-39.pyc differ diff --git a/work_util/__pycache__/sam_deal.cpython-310.pyc b/work_util/__pycache__/sam_deal.cpython-310.pyc new file mode 100644 index 0000000..59a8ca6 Binary files /dev/null and b/work_util/__pycache__/sam_deal.cpython-310.pyc differ diff --git a/work_util/__pycache__/sam_deal.cpython-39.pyc b/work_util/__pycache__/sam_deal.cpython-39.pyc new file mode 100644 index 0000000..bfd4b09 Binary files /dev/null and b/work_util/__pycache__/sam_deal.cpython-39.pyc differ diff --git a/work_util/data_util.py b/work_util/data_util.py index a10b470..1fa94d1 100644 --- a/work_util/data_util.py +++ b/work_util/data_util.py @@ -129,6 +129,11 @@ def allowed_file(filename: str) -> bool: param = params.ModelParams() return any(filename.endswith(ext) for ext in param.ALLOWED_EXTENSIONS) + + +####################SAM################################### + + # data_info ={ # "result": [ # { diff --git a/work_util/model_deal.py b/work_util/model_deal.py index 684cfc1..fdfb7ba 100644 --- a/work_util/model_deal.py +++ b/work_util/model_deal.py @@ -13,10 +13,10 @@ import cv2 as cv import torch.nn.functional as F from joblib import dump, load from fastapi import HTTPException - +import pickle param = params.ModelParams() - - +import cv2 +import traceback ################################################### 图像类函数调用########################################################################### @@ -579,6 +579,7 @@ def pred_single_tar(test_content): + def pred_single_gas(test_content): gas_pred = [] current_directory = os.getcwd() @@ -587,7 +588,7 @@ def pred_single_gas(test_content): gas_model = load(model_path) pred = gas_model.predict(test_content) gas_pred.append(pred[0]) - result = [param.meirejie_tar_mae, param.meirejie_tar_r2,gas_pred] + result = [param.meirejie_gas_mae, param.meirejie_gas_r2,gas_pred] # 创建 DataFrame df = pd.DataFrame(result, index=param.index, columns=param.columns) result = df.to_json(orient='index') @@ -602,7 +603,7 @@ def pred_single_water(test_content): water_model = load(model_path) pred = water_model.predict(test_content) water_pred.append(pred[0]) - result = [param.meirejie_tar_mae, param.meirejie_tar_r2,water_pred] + result = [param.meirejie_water_mae, param.meirejie_water_r2,water_pred] # 创建 DataFrame df = pd.DataFrame(result, index=param.index, columns=param.columns) result = df.to_json(orient='index') @@ -616,7 +617,7 @@ def pred_single_char(test_content): char_model = load(model_path) pred = char_model.predict(test_content) char_pred.append(pred[0]) - result = [param.meirejie_tar_mae, param.meirejie_tar_r2,char_pred] + result = [param.meirejie_char_mae, param.meirejie_char_r2,char_pred] # 创建 DataFrame df = pd.DataFrame(result, index=param.index, columns=param.columns) result = df.to_json(orient='index') @@ -624,50 +625,65 @@ def pred_single_char(test_content): -def choose_model(name,data): +def choose_model_meirejie(name,data): current_directory = os.getcwd() model_path = os.path.join(current_directory,'meirejie',param.meirejie_model_dict[name]) model = load(model_path) pred = model.predict(data) return pred - -def get_excel_tar(model_name): - data_name = param.meirejie_test_data['tar'] - current_directory = os.getcwd() - data_path = os.path.join(current_directory,"tmp","meirejie",data_name) - test_data = pd.read_csv(data_path) - pred = choose_model(model_name,test_data) +def get_excel_tar(model_name,file_path): + if not file_path.endswith('.csv'): + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + test_data = pd.read_csv(file_path) + expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Tar"] + if list(test_data.columns) != expected_columns: + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + del test_data['Tar'] + pred = choose_model_meirejie(model_name,test_data) test_data['tar_pred'] = pred - return test_data.to_json(orient='records', lines=True) + return {"status":True, "reason":test_data.to_dict(orient='records')} -def get_excel_gas(model_name): - data_name = param.meirejie_test_data['gas'] - current_directory = os.getcwd() - data_path = os.path.join(current_directory,"tmp","meirejie",data_name) - test_data = pd.read_csv(data_path) - pred = choose_model(model_name,test_data) +def get_excel_gas(model_name,file_path): + if not file_path.endswith('.csv'): + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + test_data = pd.read_csv(file_path) + expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Gas"] + if list(test_data.columns) != expected_columns: + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + del test_data['Gas'] + pred = choose_model_meirejie(model_name,test_data) test_data['gas_pred'] = pred - return test_data.to_json(orient='records', lines=True) + return {"status":True, "reason":test_data.to_dict(orient='records')} -def get_excel_char(model_name): - data_name = param.meirejie_test_data['char'] - current_directory = os.getcwd() - data_path = os.path.join(current_directory,"tmp","meirejie",data_name) - test_data = pd.read_csv(data_path) - pred = choose_model(model_name,test_data) +def get_excel_char(model_name,file_path): + if not file_path.endswith('.csv'): + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + test_data = pd.read_csv(file_path) + expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Char"] + if list(test_data.columns) != expected_columns: + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + del test_data['Char'] + pred = choose_model_meirejie(model_name,test_data) test_data['char_pred'] = pred - return test_data.to_json(orient='records', lines=True) + return {"status":True, "reason":test_data.to_dict(orient='records')} -def get_excel_water(model_name): - data_name = param.meirejie_test_data['water'] - current_directory = os.getcwd() - data_path = os.path.join(current_directory,"tmp","meirejie",data_name) - test_data = pd.read_csv(data_path) - pred = choose_model(model_name,test_data) +def get_excel_water(model_name,file_path): + if not file_path.endswith('.csv'): + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + test_data = pd.read_csv(file_path) + expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Water"] + if list(test_data.columns) != expected_columns: + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + del test_data['Water'] + pred = choose_model_meirejie(model_name,test_data) test_data['water_pred'] = pred - return test_data.to_json(orient='records', lines=True) + return {"status":True, "reason":test_data.to_dict(orient='records')} @@ -741,52 +757,59 @@ def choose_model_meijitancailiao(name,data): def get_excel_ssa(model_name,file_path): if not file_path.endswith('.csv'): - raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + # raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") test_data = pd.read_csv(file_path) expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "SSA"] if list(test_data.columns) != expected_columns: - raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") del test_data['SSA'] pred = choose_model_meijitancailiao(model_name,test_data) test_data['ssa_pred'] = pred - return test_data.to_json(orient='records') + return {"status":True, "reason":test_data.to_dict(orient='records')} def get_excel_tpv(model_name,file_path): if not file_path.endswith('.csv'): - raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + # raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") test_data = pd.read_csv(file_path) expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "TPV"] if list(test_data.columns) != expected_columns: - raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} del test_data['TPV'] pred = choose_model_meijitancailiao(model_name,test_data) test_data['tpv_pred'] = pred - return test_data.to_json(orient='records') + return {"status":True, "reason":test_data.to_dict(orient='records')} def get_excel_meitan(model_name,file_path): if not file_path.endswith('.csv'): - raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} + # raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") test_data = pd.read_csv(file_path) expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J","C"] if list(test_data.columns) != expected_columns: - raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + # raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} del test_data['C'] pred = choose_model_meijitancailiao(model_name,test_data) test_data['C_pred'] = pred - return test_data.to_json(orient='records', lines=True) + # return test_data.to_dict(orient='records') + return {"status":True, "reason":test_data.to_dict(orient='records')} def get_excel_meiliqing(model_name,file_path): if not file_path.endswith('.csv'): - raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件") + return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"} test_data = pd.read_csv(file_path) expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J", "C"] if list(test_data.columns) != expected_columns: - raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}") + return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"} del test_data['C'] pred = choose_model_meijitancailiao(model_name,test_data) test_data['C_pred'] = pred - return test_data.to_json(orient='records', lines=True) + return {"status":True, "reason":test_data.to_dict(orient='records')} def pred_func(func_name, pred_data): @@ -834,4 +857,7 @@ def get_pic_path(url): # 3. 添加本地根目录 local_path = f"/root/app{relative_path}" - return local_path \ No newline at end of file + return local_path + + + diff --git a/work_util/params.py b/work_util/params.py index 0cea832..2a65193 100644 --- a/work_util/params.py +++ b/work_util/params.py @@ -140,9 +140,18 @@ class ModelParams(): } index = ['mae', 'r2', 'result'] - columns = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression', - 'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression', - 'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression'] + columns = [ + 'XGBoost(极端梯度提升)', + 'Linear Regression(线性回归)', + 'Ridge Regression(岭回归)', + 'Gaussian Process Regression(高斯过程回归)', + 'ElasticNet Regression(弹性网回归)', + 'K-Nearest Neighbors(K最近邻)', + 'Support Vector Regression(支持向量回归)', + 'Decision Tree Regression(决策树回归)', + 'Random Forest Regression(随机森林回归)', + 'AdaBoost Regression(AdaBoost回归)' +] meirejie_model_list_gas = ['xgb_gas','lr_gas','ridge_gas','gp_gas','en_gas','kn_gas','svr_gas','dtr_gas','rfr_gas','adb_gas'] meirejie_model_list_char = ['xgb_char','lr_char','ridge_char','gp_char','en_char','kn_char','svr_char','dtr_char','rfr_char','adb_char'] meirejie_model_list_water = ['xgb_water','lr_water','ridge_water','gp_water','en_water','kn_water','svr_water','dtr_water','rfr_water','adb_water'] @@ -224,29 +233,65 @@ class ModelParams(): "xgb_meiliqing":"model/meiliqing_XGB.joblib", } - columns_ssa = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression', - 'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression', - 'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression'] + columns_ssa = [ + 'XGBoost(极端梯度提升)', + 'Linear Regression(线性回归)', + 'Ridge Regression(岭回归)', + 'Gaussian Process Regression(高斯过程回归)', + 'ElasticNet Regression(弹性网回归)', + 'K-Nearest Neighbors(最近邻居算法)', + 'Support Vector Regression(支持向量回归)', + 'Decision Tree Regression(决策树回归)', + 'Random Forest Regression(随机森林回归)', + 'AdaBoost Regression(自适应提升回归)' +] meijitancailiao_model_list_ssa = ['xgb_ssa','lr_ssa','ridge_ssa','gp_ssa','en_ssa','kn_ssa','svr_ssa','dtr_ssa','rfr_ssa','adb_ssa'] meijitancailiao_ssa_mae = [258, 407,408 ,282 ,411 ,389, 405, 288,193, 330] meijitancailiao_ssa_r2 = [0.92,0.82,0.82,0.89,0.81,0.82,0.87,0.88,0.95,0.88] - columns_tpv = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression', - 'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression', - 'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression'] + columns_tpv = [ + 'XGBoost(极端梯度提升)', + 'Linear Regression(线性回归)', + 'Ridge Regression(岭回归)', + 'Gaussian Process Regression(高斯过程回归)', + 'ElasticNet Regression(弹性网回归)', + 'Gradient Boosting Regression(梯度提升回归)', + 'Support Vector Regression(支持向量回归)', + 'Decision Tree Regression(决策树回归)', + 'Random Forest Regression(随机森林回归)', + 'AdaBoost Regression(自适应提升回归)' +] meijitancailiao_model_list_tpv = ['xgb_tpv', 'lr_tpv', 'ridge_tpv', 'gp_tpv', 'en_tpv', 'gdbt_tpv', 'svr_tpv', 'dtr_tpv', 'rfr_tpv', 'adb_tpv'] meijitancailiao_tpv_mae = [0.2, 0.2, 0.2, 0.2, 0.2, 0.23, 0.23, 0.21, 0.16, 0.21] meijitancailiao_tpv_r2 = [0.81, 0.81, 0.81, 0.8, 0.82, 0.80, 0.78, 0.73, 0.85, 0.84] - columns_meitan = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression', - 'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression', - 'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression'] + columns_meitan = [ + 'XGBoost(极端梯度提升)', + 'Linear Regression(线性回归)', + 'Ridge Regression(岭回归)', + 'Gaussian Process Regression(高斯过程回归)', + 'ElasticNet Regression(弹性网回归)', + 'Gradient Boosting Regression(梯度提升回归)', + 'Support Vector Regression(支持向量回归)', + 'Decision Tree Regression(决策树回归)', + 'Random Forest Regression(随机森林回归)', + 'AdaBoost Regression(自适应提升回归)' +] meijitancailiao_model_list_meitan = ['xgb_meitan', 'lr_meitan', 'ridge_meitan', 'gp_meitan', 'en_meitan', 'gdbt_meitan', 'svr_meitan', 'dtr_meitan', 'rfr_meitan', 'adb_meitan'] meijitancailiao_meitan_mae = [8.17, 37.61, 37.66, 13.41, 20.96, 8.03, 14.89, 19.48, 12.53, 15.6] meijitancailiao_meitan_r2 = [0.96, 0.19, 0.19, 0.91, 0.8, 0.96, 0.88, 0.86, 0.91, 0.91] - columns_meiliqing = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression', - 'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression', - 'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression'] + columns_meiliqing = [ + 'XGBoost(极端梯度提升)', + 'Linear Regression(线性回归)', + 'Ridge Regression(岭回归)', + 'Gaussian Process Regression(高斯过程回归)', + 'ElasticNet Regression(弹性网回归)', + 'Gradient Boosting Regression(梯度提升回归)', + 'Support Vector Regression(支持向量回归)', + 'Decision Tree Regression(决策树回归)', + 'Random Forest Regression(随机森林回归)', + 'AdaBoost Regression(自适应提升回归)' +] meijitancailiao_model_list_meiliqing = ['xgb_meiliqing', 'lr_meiliqing', 'ridge_meiliqing', 'gp_meiliqing', 'en_meiliqing', 'gdbt_meiliqing', 'svr_meiliqing', 'dtr_meiliqing', 'rfr_meiliqing', 'adb_meiliqing'] meijitancailiao_meiliqing_mae = [8.38, 35.02, 35.1, 11.02, 13.58, 7.04, 13.13, 13.13, 11.25, 9.99] meijitancailiao_meiliqing_r2 = [0.95, 0.33, 0.33, 0.94, 0.91, 0.97, 0.88, 0.89, 0.92, 0.94] @@ -276,4 +321,12 @@ class ModelParams(): } ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.tif'} - MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB \ No newline at end of file + MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB + MAX_FILE_SAM_SIZE = 10 * 1024 * 1024 + + + + + + DEFAULT_MODEL_PATH = r"/home/xiazj/ai-station-code/segment_anything_model/weights/vit_b.pth" + \ No newline at end of file diff --git a/work_util/post_model.py b/work_util/post_model.py index 5ea6342..2c5fc8a 100644 --- a/work_util/post_model.py +++ b/work_util/post_model.py @@ -21,6 +21,26 @@ class Meitan(BaseModel): J: float +class Meirejie(BaseModel): + A: float + V: float + FC: float + C: float + H: float + N: float + S: float + O: float + H_C: float + O_C: float + N_C: float + Rt: float + Hr: float + dp: float + T: float + + + + class FormData(BaseModel): A_min: Optional[float] = None A_max: Optional[float] = None diff --git a/work_util/sam_deal.py b/work_util/sam_deal.py new file mode 100644 index 0000000..919f5cf --- /dev/null +++ b/work_util/sam_deal.py @@ -0,0 +1,398 @@ +import cv2 +import os +import numpy as np +import json +import shutil +from segment_anything_model.segment_anything import sam_model_registry, SamPredictor +import random +import traceback +import base64 +import io +from work_util.logger import logger +import pickle +from PIL import Image + + + +def get_display_image_ori(loaded_data): + if loaded_data['image_rgb'] is None: + logger.info("get_display_image: No image loaded") + return {"status": False,"reason":"无图片加载"} + display_image = loaded_data['image_rgb'].copy() + try: + for point, label in zip(loaded_data['input_point'], loaded_data['input_label']): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(display_image, tuple(point), 5, color, -1) + if loaded_data['selected_mask'] is not None: + class_name = loaded_data['class_names'][loaded_data['class_index']] + color = loaded_data['class_colors'].get(class_name, (0, 0, 128)) + display_image = sam_apply_mask(display_image, loaded_data['selected_mask'], color) + logger.info(f"get_display_image: Returning image with shape {display_image.shape}") + return {"status" : True, "reason":display_image} + except Exception as e: + logger.info(f"get_display_image: Error processing image: {str(e)}") + traceback.print_exc() + return {"status": False,"reason":f"get_display_image: Error processing image: {str(e)}"} + +def get_all_classes(data): + return { + "classes": data['class_names'], + "colors": {name: color for name, color in data['class_colors'].items()} + } + +def get_classes(data): + return get_all_classes(data) + + +def reset_annotation(loaded_data): + loaded_data['input_point'] = [] + loaded_data['input_label'] = [] + loaded_data['selected_mask'] = None + loaded_data['logit_input'] = None + loaded_data['masks'] = {} + logger.info("已重置标注状态") + return loaded_data + + +def remove_class(data,api,class_name): + if class_name in api['class_annotations']: + del api['class_annotations'][class_name] + + if class_name == data['class_names'][data['class_index']]: + data['class_index'] = 0 + data['class_names'].remove(class_name) + if class_name in data['class_colors']: + del data['class_colors'][class_name] + if class_name in data['masks']: + del data['masks'][class_name] + return data,api + + +def reset_annotation_all(data,api): + data = reset_annotation(data) + for class_data in api['class_annotations'].values(): + class_data['points'] = [] + class_data['point_types'] = [] + class_data['masks'] = [] + class_data['scores'] = [] + class_data['selected_mask_index'] = -1 + if 'selected_mask' in class_data: + del class_data['selected_mask'] + return data,api + + +def reset_class_annotations(data,api): + """重置class_annotations,保留类别但清除掩码和点""" + classes = get_classes(data).get('classes', []) + new_annotations = {} + for class_name in classes: + new_annotations[class_name] = { + 'points': [], + 'point_types': [], + 'masks': [], + 'selected_mask_index': -1 + } + api['class_annotations'] = new_annotations + logger.info("已重置class_annotations,保留类别但清除掩码和点") + return data,api + + +def add_class(data,class_name, color=None): + if class_name in data['class_names']: + logger.info(f"类别 '{class_name}' 已存在") + return {'status':False, 'reason':f"类别 '{class_name}' 已存在"} + data['class_names'].append(class_name) + if color is None: + color = tuple(np.random.randint(100, 256, 3).tolist()) + r, g, b = [int(c) for c in color] + bgr_color = (b, g, r) + data['class_colors'][class_name] = tuple(bgr_color) + logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}") + return {'status':True, 'reason':data} + + +def set_current_class(data, api, class_index, color=None): + classes = get_classes(data) + if 'classes' in classes and class_index < len(classes['classes']): + class_name = classes['classes'][class_index] + api['current_class'] = class_name + if class_name not in api['class_annotations']: + api['class_annotations'][class_name] = { + 'points': [], + 'point_types': [], + 'masks': [], + 'selected_mask_index': -1 + } + color = data['class_colors'][class_name] + if color: + api['class_colors'][class_name] = color + elif class_name not in api['class_colors']: + predefined_colors = [ + (255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255) + ] + color_index = len(api['class_colors']) % len(predefined_colors) + api['class_colors'][class_name] = predefined_colors[color_index] + return class_name,api + return None,api + + +def add_point(data, x, y, is_foreground=True): + data['input_point'].append([x, y]) + data['input_label'].append(1 if is_foreground else 0) + logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})") + return data + + + +def load_model(path): + # 加载配置内容 + config_path = os.path.join(path,'model_params.pickle') + with open(config_path, 'rb') as file: + loaded_data,api_config = pickle.load(file) + return loaded_data,api_config + +def save_model(loaded_data,api_config,path): + config_path = os.path.join(path,'model_params.pickle') + save_data = (loaded_data,api_config) + with open(config_path, 'wb') as file: + pickle.dump(save_data, file) + + +def sam_apply_mask(image, mask, color, alpha=0.5): + masked_image = image.copy() + for c in range(3): + masked_image[:, :, c] = np.where( + mask == 1, + image[:, :, c] * (1 - alpha) + alpha * color[c], + image[:, :, c] + ) + return masked_image + + +def apply_mask_overlay(image, mask, color, alpha=0.5): + colored_mask = np.zeros_like(image) + colored_mask[mask > 0] = color + return cv2.addWeighted(image, 1, colored_mask, alpha, 0) + +def load_tmp_image(path): + with open(path, "rb") as image_file: + # 将图片文件读取为二进制数据并进行 Base64 编码 + encoded_string = base64.b64encode(image_file.read()).decode('utf-8') + return encoded_string + +def refresh_image(data,api,path): + result = get_image_display(data,api) + if result['status'] == False: + return { + "status" : False, + "reason" : result['reason'] + } + else: + img = result['reason'] + display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(display_img_rgb) + tmp_path = os.path.join(path,'temp/output_image.jpg') + pil_img.save(tmp_path) + return { + "status" : True, + "reason" : tmp_path + } + + +def set_class_color(data,api,index,color): + class_name = data['class_names'][index] + api['class_colors'][class_name] = color + data['class_colors'][class_name] = color + return data,api + + + +def add_to_class(api,class_name): + if class_name is None: + class_name = api['current_class'] + if not class_name or class_name not in api['class_annotations']: + return { + 'status':False, + 'reason': "该分类没有标注点" + } + class_data = api['class_annotations'][class_name] + if 'selected_mask' not in class_data or class_data['selected_mask'] is None: + return { + 'status':False, + 'reason': "该分类没有进行预测" + } + class_data['final_mask'] = class_data['selected_mask'].copy() + class_data['points'] = [] + class_data['point_types'] = [] + class_data['masks'] = [] + class_data['scores'] = [] + class_data['selected_mask_index'] = -1 + return { + 'status':True, + 'reason': api + } + + + +def get_image_display(data,api): + + if data['image'] is None: + logger.info("get_display_image: No image loaded") + return {"status":False, "reason":"获取图像:没有图像加载"} + display_image = data['image'].copy() + try: + for point, label in zip(data['input_point'], data['input_label']): + color = (0, 255, 0) if label == 1 else (0, 0, 255) + cv2.circle(display_image, tuple(point), 5, color, -1) + if data['selected_mask'] is not None: + class_name = data['class_names'][data['class_index']] + color = data['class_colors'].get(class_name, (0, 0, 128)) + display_image = sam_apply_mask(display_image, data['selected_mask'], color) + logger.info(f"get_display_image: Returning image with shape {display_image.shape}") + img = display_image + if not isinstance(img, np.ndarray) or img.size == 0: + logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}") + return {"status":False, "reason":f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"} + # 仅应用当前图片的final_mask + for class_name, class_data in api['class_annotations'].items(): + if 'final_mask' in class_data and class_data['final_mask'] is not None: + color = api['class_colors'].get(class_name, (0, 255, 0)) + mask = class_data['final_mask'] + if isinstance(mask, list): + mask = np.array(mask, dtype=np.uint8) + logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}") + img = apply_mask_overlay(img, mask, color, alpha=0.5) + elif 'selected_mask' in class_data and class_data['selected_mask'] is not None: + color = api['class_colors'].get(class_name, (0, 255, 0)) + mask = class_data['selected_mask'] + if isinstance(mask, list): + mask = np.array(mask, dtype=np.uint8) + logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}") + img = apply_mask_overlay(img, mask, color, alpha=0.5) + if api['current_class'] and api['current_class'] in api['class_annotations']: + class_data = api['class_annotations'][api['current_class']] + for i, (x, y) in enumerate(class_data['points']): + is_fg = class_data['point_types'][i] + color = (0, 255, 0) if is_fg else (0, 0, 255) + print(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}") + cv2.circle(img, (int(x), int(y)), 5, color, -1) + logger.info(f"get_image_display: Returning image with shape {img.shape}") + return {"status":True, "reason":img} + except Exception as e: + logger.error(f"get_display_image: Error processing image: {str(e)}") + traceback.print_exc() + return {"status":False, "reason":f"get_display_image: Error processing image: {str(e)}"} + + +def add_annotation_point(api, x, y, is_foreground=True): + if not api['current_class'] or api['current_class'] not in api['class_annotations']: + return { + "status": False, + "reason": "请选择或新建分类" + } + class_data = api['class_annotations'][api['current_class']] + + class_data['points'].append((x, y)) + class_data['point_types'].append(is_foreground) + return { + 'status' : True, + 'reason' : { + 'points': class_data['points'], + 'types': class_data['point_types'] + }, + 'api':api + } + + + +def delete_last_point(api): + if not api['current_class'] or api['current_class'] not in api['class_annotations']: + return { + "status": False, + "reason": "当前类没有点需要删除" + } + class_data = api['class_annotations'][api['current_class']] + if not class_data['points']: + return { + "status": False, + "reason": "当前类没有点需要删除" + } + class_data['points'].pop() + class_data['point_types'].pop() + return { + "status": True, + "reason": api + } + + + +def reset_current_class_points(api): + if not api['current_class'] or api['current_class'] not in api['class_annotations']: + return { + "status": False, + "reason": "当前类没有点需要删除" + } + class_data = api['class_annotations'][api['current_class']] + class_data['points'] = [] + class_data['point_types'] = [] + return { + "status": True, + "reason": api + } + + + + +def predict_mask(data,sam_predictor): + if data['image_rgb'] is None: + logger.info("predict_mask: No image loaded") + return {"status":False, "reason":"预测掩码:没有图像加载"} + if len(data['input_point']) == 0: + logger.info("predict_mask: No points added") + return {"status":False, "reason":"预测掩码:没有进行点标注"} + try: + sam_predictor.set_image(data['image_rgb']) + except Exception as e: + logger.error(f"predict_mask: Error setting image: {str(e)}") + return {"status":False, "reason":f"predict_mask: Error setting image: {str(e)}"} + input_point_np = np.array(data['input_point']) + input_label_np = np.array(data['input_label']) + try: + masks_pred, scores, logits = sam_predictor.predict( + point_coords=input_point_np, + point_labels=input_label_np, + mask_input=data['logit_input'][None, :, :] if data['logit_input'] is not None else None, + multimask_output=True, + ) + except Exception as e: + logger.error(f"predict_mask: Error during prediction: {str(e)}") + traceback.print_exc() + return {"status":False, "reason":f"predict_mask: Error during prediction: {str(e)}"} + data['masks_pred'] = masks_pred + data['scores'] = scores + data['logits'] = logits + best_mask_idx = np.argmax(scores) + data['selected_mask'] = masks_pred[best_mask_idx] + data['logit_input'] = logits[best_mask_idx, :, :] + logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}") + return { + "status":True, + "reason":{ + "masks": [mask.tolist() for mask in masks_pred], + "scores": scores.tolist(), + "selected_index": int(best_mask_idx), + "data":data + }} + + +def get_image_info(data): + + return { + "filename": data['filename'], + "index": data['current_index'], + "total": len(data['image_files']), + "width": data['image'].shape[1], + "height": data['image'].shape[0] + } + \ No newline at end of file