Daily code commit

xiazongji 2025-06-04 17:04:02 +08:00
parent 54e8b0411b
commit 68e3076f6a
29 changed files with 2497 additions and 167 deletions

.gitignore (vendored)

@@ -1,2 +1,4 @@
*.log
tmp/
datas/
imgs/

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large


@@ -0,0 +1,11 @@
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Tar
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,450,6.2108384
14.270205066345,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,12.0,650,9.38428
11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,20.0,15.0,6.0,480,10.14648
39.4,53.71,46.29,77.46,6.64,1.57,0.57,20.61,1.0286599535244,0.199554608830364,0.0173730220205821,15.0,30.5,5.0,600,9.6968
11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,10.0,15.0,6.0,480,10.07076
25.56,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,15.0,5.0,0.15,600,8.87
11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,650,9.53456
17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,40.0,6.0,500,6.1334
15.84092126406,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,6.0,650,6.504628
6.17,46.5,53.5,76.14,3.06,1.06,0.24,19.5,0.482269503546099,0.192080378250591,0.0119329055499268,30.0,10.0,0.07,650,8.585445
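For reference, a minimal pandas sketch (assuming the rows above are saved locally as tar_data_test.csv, the name used by the split script below) that loads the file and checks its columns:

import pandas as pd

df = pd.read_csv("tar_data_test.csv")  # hypothetical local path
expected = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt', 'Hr', 'dp', 'T', 'Tar']
assert list(df.columns) == expected, f"unexpected columns: {list(df.columns)}"
print(df.describe())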


@@ -1,6 +1,6 @@
import pandas as pd
# Read the CSV file
file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # Replace with your file path
file_path = "/home/xiazj/ai-station-code/meirejie/data/tar_data.csv" # Replace with your file path
df = pd.read_csv(file_path)
# Randomly sample 10 rows
test_set = df.sample(n=10, random_state=1) # random_state keeps the sample reproducible
@@ -8,45 +8,45 @@ test_set = df.sample(n=10, random_state=1) # random_state keeps the sample reproducible
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Tar']
test_set = test_set[columns]
# Save the test set to a new CSV file
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # Save as CSV without the index
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/tar_data_test.csv', index=False) # Save as CSV without the index
print("测试集已保存为 tar_data_test.csv")
file_path = "D:\\project\\ai_station\\meirejie\\data\\gas_data.csv" # Replace with your file path
file_path = "/home/xiazj/ai-station-code/meirejie/data/gas_data.csv" # Replace with your file path
df = pd.read_csv(file_path)
# Randomly sample 10 rows
test_set = df.sample(n=10, random_state=1) # random_state keeps the sample reproducible
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Gas']
test_set = test_set[columns]
# Save the test set to a new CSV file
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\gas_data_test.csv', index=False) # Save as CSV without the index
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/gas_data_test.csv', index=False) # Save as CSV without the index
print("测试集已保存为 gas_data_test.csv")
file_path = "D:\\project\\ai_station\\meirejie\\data\\water_data.csv" # Replace with your file path
file_path = "/home/xiazj/ai-station-code/meirejie/data/water_data.csv" # Replace with your file path
df = pd.read_csv(file_path)
# Randomly sample 10 rows
test_set = df.sample(n=10, random_state=1) # random_state keeps the sample reproducible
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Water']
test_set = test_set[columns]
# Save the test set to a new CSV file
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\water_data_test.csv', index=False) # Save as CSV without the index
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/water_data_test.csv', index=False) # Save as CSV without the index
print("测试集已保存为 water_data_test.csv")
file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # Replace with your file path
file_path = "/home/xiazj/ai-station-code/meirejie/data/char_data.csv" # Replace with your file path
df = pd.read_csv(file_path)
# Randomly sample 10 rows
test_set = df.sample(n=10, random_state=1) # random_state keeps the sample reproducible
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Char']
test_set = test_set[columns]
# Save the test set to a new CSV file
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # Save as CSV without the index
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/char_data_test.csv', index=False) # Save as CSV without the index
print("测试集已保存为 char_data_test.csv")

run.py

@@ -1,18 +1,22 @@
import sys
import io
from fastapi import FastAPI,File, UploadFile,Form,Query
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse,JSONResponse
import sys
import os
import shutil
from pydantic import BaseModel, validator
from typing import List
from typing import List, Optional
import asyncio
import pandas as pd
import numpy as np
from PIL import Image
import pickle
import cv2
import copy
import base64
# Directory of the current script
print("Current working directory:", os.getcwd())
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -25,7 +29,7 @@ from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_r
from wudingpv.predictandeval_util import segmentation
from guangfufadian import model_base as guangfufadian_model_base
from fenglifadian import model_base as fenglifadian_model_base
from work_util import prepare_data,model_deal,params,data_util,post_model
from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal
from work_util.logger import logger
import joblib
import mysql.connector
@@ -33,7 +37,10 @@ import uuid
import json
import zipfile
from pathlib import Path
from segment_anything_model import sam_annotator
from segment_anything_model.sam_config import sam_config, sam_api_config
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import traceback
version = f"{sys.version_info.major}.{sys.version_info.minor}"
app = FastAPI()
@@ -77,6 +84,17 @@ windfd_args = fenglifadian_model_base.fenglifadian_Args()
windfd_model_path = os.path.join(windfd_args.checkpoints,'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # Change to the actual model path
windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args)
# Load the SAM model
checkpoint_path = os.path.join(current_dir,'segment_anything_model/weights/vit_b.pth')
sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
device = "cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu"
_ = sam.to(device=device)
sam_predictor = SamPredictor(sam)
print(f"SAM模型已加载使用设备: {device}")
# Mount the /root/app directory as static files
app.mount("/files", StaticFiles(directory="/root/app"), name="files")
@@ -86,6 +104,25 @@ async def read_root():
return {"message": message}
# Home page
# Fetch resource info for the data page
@app.get("/ai-station-api/index/show")
async def get_source_index_info():
sql = "SELECT id,application_name, describe_data, img_url FROM app_shouye"
data = data_util.fetch_data(sql)
if data is None:
return {
"success":False,
"msg":"获取信息列表失败",
"data":None
}
else:
return {"success":True,
"msg":"获取信息成功",
"data":data}
# Fetch resource info for the data page
@app.get("/ai-station-api/data_source/show")
@@ -188,7 +225,7 @@ type = {tar,char,gas,water}
"""
@app.get("/ai-station-api/mrj_feature/show")
async def get_mrj_feature_info(type:str = None):
sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale FROM meirejie_features where use_type = %s;"
sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
data = data_util.fetch_data_with_param(sql,(type,))
if data is None:
return {
@@ -235,7 +272,7 @@ async def mjt_models_predict_ssa(content: post_model.Zongbiaomianji):
new_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
meijiegou_test_content = meijiegou_test_content.reindex(columns=new_order)
ssa_result = model_deal.pred_single_ssa(meijiegou_test_content)
logger.info("Root endpoint was accessed")
# logger.info("Root endpoint was accessed")
if ssa_result is None:
return {
"success":False,
@@ -273,9 +310,14 @@ async def upload_file(file: UploadFile = File(...),type: str = Form(...), ):
@app.get("/ai-station-api/mjt_multi_ssa/predict")
async def mjt_multi_ssa_pred(model:str = None, path:str = None):
data = model_deal.get_excel_ssa(model, path)
return {"success":True,
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data}
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
@@ -308,9 +350,14 @@ async def mjt_models_predict_tpv(content: post_model.Zongbiaomianji):
@app.get("/ai-station-api/mjt_multi_tpv/predict")
async def mjt_multi_tpv_pred(model:str = None, path:str = None):
data = model_deal.get_excel_tpv(model, path)
return {"success":True,
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data}
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#========================== Coal-based carbon materials: coal application detail endpoints ==================================
@app.post("/ai-station-api/mjt_models_meitan/predict")
@@ -339,9 +386,14 @@ async def mjt_models_predict_meitan(content: post_model.Meitan):
@app.get("/ai-station-api/mjt_multi_meitan/predict")
async def mjt_multi_meitan_pred(model:str = None, path:str = None):
data = model_deal.get_excel_meitan(model, path)
return {"success":True,
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data}
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#========================== Coal-based carbon materials: coal pitch application detail endpoints ==================================
@@ -371,10 +423,14 @@ async def mjt_models_predict_meiliqing(content: post_model.Meitan):
@app.get("/ai-station-api/mjt_multi_meiliqing/predict")
async def mjt_multi_meiliqing_pred(model:str = None, path:str = None):
data = model_deal.get_excel_meiliqing(model, path)
return {"success":True,
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data}
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
@@ -428,6 +484,176 @@ async def mjt_multi_meiliqing_pred(model:str = None, path:str = None, type:int =
}
#=============================== Coal pyrolysis - tar =======================================================
"""
Comprehensive model analysis
"""
@app.post("/ai-station-api/mrj_models_tar/predict")
async def mrj_models_predict_tar(content: post_model.Meirejie):
# Convert the received dict into a DataFrame
test_content = pd.DataFrame([content.model_dump()])
test_content= test_content.rename(columns={"H_C":"H/C"})
test_content= test_content.rename(columns={"O_C":"O/C"})
test_content= test_content.rename(columns={"N_C":"N/C"})
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
test_content = test_content.reindex(columns=new_order)
tmp_result = model_deal.pred_single_tar(test_content)
if tmp_result is None:
return {
"success":False,
"msg":"获取信息列表失败",
"data":None
}
else:
return {"success":True,
"msg":"获取信息成功",
"data":tmp_result}
"""
Batch prediction endpoint
"""
@app.get("/ai-station-api/mrj_multi_tar/predict")
async def mrj_multi_tar_pred(model:str = None, path:str = None):
data = model_deal.get_excel_tar(model, path)
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#=============================== Coal pyrolysis - char =======================================================
"""
Comprehensive model analysis
"""
@app.post("/ai-station-api/mrj_models_char/predict")
async def mrj_models_predict_char(content: post_model.Meirejie):
# Convert the received dict into a DataFrame
test_content = pd.DataFrame([content.model_dump()])
test_content= test_content.rename(columns={"H_C":"H/C"})
test_content= test_content.rename(columns={"O_C":"O/C"})
test_content= test_content.rename(columns={"N_C":"N/C"})
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
test_content = test_content.reindex(columns=new_order)
tmp_result = model_deal.pred_single_char(test_content)
if tmp_result is None:
return {
"success":False,
"msg":"获取信息列表失败",
"data":None
}
else:
return {"success":True,
"msg":"获取信息成功",
"data":tmp_result}
"""
Batch prediction endpoint
"""
@app.get("/ai-station-api/mrj_multi_char/predict")
async def mrj_multi_char_pred(model:str = None, path:str = None):
data = model_deal.get_excel_char(model, path)
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#=============================== Coal pyrolysis - water =======================================================
"""
Comprehensive model analysis
"""
@app.post("/ai-station-api/mrj_models_water/predict")
async def mrj_models_predict_water(content: post_model.Meirejie):
# Convert the received dict into a DataFrame
test_content = pd.DataFrame([content.model_dump()])
test_content= test_content.rename(columns={"H_C":"H/C"})
test_content= test_content.rename(columns={"O_C":"O/C"})
test_content= test_content.rename(columns={"N_C":"N/C"})
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
test_content = test_content.reindex(columns=new_order)
tmp_result = model_deal.pred_single_water(test_content)
if tmp_result is None:
return {
"success":False,
"msg":"获取信息列表失败",
"data":None
}
else:
return {"success":True,
"msg":"获取信息成功",
"data":tmp_result}
"""
Batch prediction endpoint
"""
@app.get("/ai-station-api/mrj_multi_water/predict")
async def mrj_multi_water_pred(model:str = None, path:str = None):
data = model_deal.get_excel_water(model, path)
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#=============================== Coal pyrolysis - gas =======================================================
"""
Comprehensive model analysis
"""
@app.post("/ai-station-api/mrj_models_gas/predict")
async def mrj_models_predict_gas(content: post_model.Meirejie):
# Convert the received dict into a DataFrame
test_content = pd.DataFrame([content.model_dump()])
test_content= test_content.rename(columns={"H_C":"H/C"})
test_content= test_content.rename(columns={"O_C":"O/C"})
test_content= test_content.rename(columns={"N_C":"N/C"})
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
test_content = test_content.reindex(columns=new_order)
tmp_result = model_deal.pred_single_gas(test_content)
if tmp_result is None:
return {
"success":False,
"msg":"获取信息列表失败",
"data":None
}
else:
return {"success":True,
"msg":"获取信息成功",
"data":tmp_result}
"""
Batch prediction endpoint
"""
@app.get("/ai-station-api/mrj_multi_gas/predict")
async def mrj_multi_gas_pred(model:str = None, path:str = None):
data = model_deal.get_excel_gas(model, path)
if data['status'] == True:
return {"success":True,
"msg":"获取信息成功",
"data":data['reason']}
else:
return {"success":False,
"msg":data['reason'],
"data":None}
#================================ Landform recognition ====================================================
"""
Upload an image, type = dimaoshibie
@@ -458,6 +684,7 @@ async def upload_image(file: UploadFile = File(...),type: str = Form(...), ):
"msg":"获取信息成功",
"data":{"location": file_location}}
"""
Landform recognition on an image
"""
@@ -753,7 +980,6 @@ async def get_ch4_data(path:str = None, page: int = Query(1, ge=1), page_size: i
CSV upload
"""
# @app.post("/ai-station-api/document/upload") type = "ch4"
"""
Prediction, by type
"""
@@ -771,7 +997,6 @@ async def get_ch4_predict(path:str=None,start_time:str=None, end_time:str = None
"msg":"获取信息成功",
"data":data['reason']}
#======================================== Photovoltaic power forecasting ==========================================================
"""
Return the display list
@@ -982,6 +1207,8 @@ CSV upload
"""
# @app.post("/ai-station-api/document/upload") type = "wind_electric"
"""
Prediction, by type
"""
@@ -1000,9 +1227,576 @@ async def get_wind_electri_predict(path:str=None,start_time:str=None, end_time:s
"data":data['reason']}
#======================== SAM =============================================================
"""
File upload
1 Create the inputs and outputs directories
2 Upload the image to the inputs directory
3 Load the current image
4 Return the current image path to the front end
"""
@app.post("/ai-station-api/sam-image/upload")
async def upload_sam_image(file: UploadFile = File(...),type: str = Form(...), ):
if not data_util.allowed_file(file.filename):
return {
"success":False,
"msg":"图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
"data":None
}
if file.size > param.MAX_FILE_SAM_SIZE:
return {
"success":False,
"msg":"图片大小不能大于10MB",
"data":None
}
upload_dir = os.path.join(current_dir,'tmp',type, str(uuid.uuid4()))
if not os.path.exists(upload_dir):
os.makedirs(upload_dir )
input_dir = os.path.join(upload_dir,'input')
os.makedirs(input_dir)
output_dir = os.path.join(upload_dir,'output')
os.makedirs(output_dir)
temp_dir = os.path.join(upload_dir,'temp')
os.makedirs(temp_dir)
# Save the file to the target directory
file_location = os.path.join(input_dir , file.filename)
with open(file_location, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
config_path = os.path.join(upload_dir,'model_params.pickle')
# Initialize the config; a new config file is created on every image upload
config = copy.deepcopy(sam_config)
api_config = copy.deepcopy(sam_api_config)
config['input_dir'] = input_dir
config['output_dir'] = output_dir
config['image_files'] = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]
config['current_index'] = 0
config['filename'] = config['image_files'][config['current_index']]
image_path = os.path.join(config['input_dir'], config['filename'])
config['image'] = cv2.imread(image_path)
config['image_rgb'] = cv2.cvtColor(config['image'].copy(), cv2.COLOR_BGR2RGB)
# Populate the API config
api_config['output_dir'] = output_dir
# Reset the state stored in the pickle
# Reset class_annotations in the API config: keep the classes but clear masks and points
config, api_config = sam_deal.reset_class_annotations(config,api_config)
config = sam_deal.reset_annotation(config)
save_data = (config,api_config)
with open(config_path, 'wb') as file:
pickle.dump(save_data, file)
# Copy the image into the temp directory
pil_img = Image.fromarray(config['image_rgb'])
tmp_path = os.path.join(temp_dir,'output_image.jpg')
pil_img.save(tmp_path)
encoded_string = sam_deal.load_tmp_image(tmp_path)
return {"success":True,
"msg":"获取信息成功",
"image": JSONResponse(content={"image_data": encoded_string}),
#"input_dir": input_dir,
#'output_dir': output_dir,
#'temp_dir': temp_dir,
'file_path': upload_dir}
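A sketch of the matching upload call (hypothetical host and type value; file and type are the form fields defined above):

import requests

with open("example.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/ai-station-api/sam-image/upload",
        files={"file": ("example.jpg", f, "image/jpeg")},
        data={"type": "sam"},
    )
print(resp.json().get("file_path"))  # working directory to pass as path in the follow-up calls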
"""
Add a class
Adds the class and sets it as the currently selected class;
the response returns current_index, and the front-end list should default to the class at that index.
"""
@app.get("/ai-station-api/sam_class/create")
async def sam_class_set(
class_name: str = None,
color: Optional[List[int]] = Query(None, description="list of RGB color"),
path: str = None
):
loaded_data,api_config = sam_deal.load_model(path)
result = sam_deal.add_class(loaded_data,class_name,color)
if result['status'] == True:
loaded_data = result['reason']
else:
return {
"success":False,
"msg":result['reason'],
"data":None
}
loaded_data['class_index'] = loaded_data['class_names'].index(class_name)
r, g, b = [int(c) for c in color]
bgr_color = (b, g, r)
result, api_config = sam_deal.set_current_class(loaded_data, api_config, loaded_data['class_index'], color=bgr_color)
# Persist the updated config
sam_deal.save_model(loaded_data,api_config,path)
tmp_path = os.path.join(path,'temp/output_image.jpg')
encoded_string = sam_deal.load_tmp_image(tmp_path)
return {"success":True,
"msg":f"已添加类别: {class_name}, 颜色: {color}",
"image": JSONResponse(content={"image_data": encoded_string}),
"data":{"class_name_list": loaded_data['class_names'],
"current_index": loaded_data['class_index'],
"class_dict":loaded_data['class_colors'],
}}
"""
Select a color
current_index : index of the class in the dropdown list
rgb_color : RGB color as a list of ints
"""
@app.get("/ai-station-api/sam_color/select")
async def set_sam_color(
current_index: int = None,
rgb_color: List[int] = Query(None, description="list of RGB color"),
path: str = None
):
loaded_data,api_config = sam_deal.load_model(path)
r, g, b = [int(c) for c in rgb_color]
bgr_color = (b, g, r)
data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color)
result, api_config = sam_deal.set_current_class(data, api, current_index, color=bgr_color)
sam_deal.save_model(data,api,path)
img = sam_deal.refresh_image(data,api,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image": JSONResponse(content={"image_data": encoded_string}),
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
"""
Select a class
"""
@app.get("/ai-station-api/sam_classs/change")
async def on_class_selected(class_index : int = None,path: str = None):
# Load the saved config
loaded_data,api_config = sam_deal.load_model(path)
result, api_config = sam_deal.set_current_class(loaded_data, api_config, class_index, color=None)
# loaded_data['class_index'] = class_index
sam_deal.save_model(loaded_data,api_config,path)
if result:
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
else:
return {
"success":False,
"msg":"分类标签识别错误",
"image":None
}
"""
Delete a class; the front end then falls back to normal, index=0
"""
@app.get("/ai-station-api/sam_class/delete")
async def sam_remove_class(path:str=None,select_index:int=None):
loaded_data,api_config = sam_deal.load_model(path)
class_name = loaded_data['class_names'][select_index]
loaded_data,api_config = sam_deal.remove_class(loaded_data,api_config,class_name)
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
"""
Add an annotation point - left mouse button (foreground)
"""
@app.get("/ai-station-api/sam_point/left_add")
async def left_mouse_down(x:int=None,y:int=None,path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
if not api_config['current_class']:
return {
"success":False,
"msg":"请先选择一个分类,在添加标点之前",
"image":None
}
is_foreground = True
result = sam_deal.add_annotation_point(api_config,x,y,is_foreground)
if result['status']== False:
return {
"success":False,
"msg":result['reason'],
"image":None
}
api_config = result['api']
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
"""
Add an annotation point - right mouse button (background)
"""
@app.get("/ai-station-api/sam_point/right_add")
async def right_mouse_down(x:int=None,y:int=None,path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
if not api_config['current_class']:
return {
"success":False,
"msg":"请先选择一个分类,在添加标点之前",
"image":None
}
is_foreground = False
result = sam_deal.add_annotation_point(api_config,x,y,is_foreground)
if result['status']== False:
return {
"success":False,
"msg":result['reason'],
"image":None
}
api_config = result['api']
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
"""
Delete the last point
"""
@app.get("/ai-station-api/sam_point/delete_last")
async def sam_delete_last_point(path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
result = sam_deal.delete_last_point(api_config)
if result['status'] == True:
api_config = result['reason']
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"data":None
}
else:
return {
"success":False,
"msg":result['reason'],
"data":None
}
"""
Delete all points of the current class
"""
@app.get("/ai-station-api/sam_point/delete_all")
async def sam_clear_all_point(path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
result = sam_deal.reset_current_class_points(api_config)
if result['status'] == True:
api_config = result['reason']
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
else:
return {
"success":False,
"msg":result['reason'],
"image":None
}
"""
Model prediction
"""
@app.get("/ai-station-api/sam_model/predict")
async def sam_predict_mask(path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
class_data = api_config['class_annotations'].get(api_config['current_class'], {})
if not class_data.get('points'):
return {
"success":False,
"msg":"请在预测前添加至少一个预测样本点",
"data":None
}
else:
loaded_data = sam_deal.reset_annotation(loaded_data)
# Add the annotation points to loaded_data
for i, (x, y) in enumerate(class_data['points']):
is_foreground = class_data['point_types'][i]
loaded_data = sam_deal.add_point(loaded_data, x, y, is_foreground=is_foreground)
try:
result = sam_deal.predict_mask(loaded_data,sam_predictor)
if result['status'] == False:
return {
"success":False,
"msg":result['reason'],
"data":None}
result = result['reason']
loaded_data = result['data']
class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']]
class_data['scores'] = result['scores']
class_data['selected_mask_index'] = result['selected_index']
if result['selected_index'] >= 0:
class_data['selected_mask'] = class_data['masks'][result['selected_index']]
logger.info(f"predict: Predicted {len(result['masks'])} masks, selected index: {result['selected_index']}")
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
except Exception as e:
logger.error(f"predict: Error during prediction: {str(e)}")
traceback.print_exc()
return {
"success":False,
"msg":f"predict: Error during prediction: {str(e)}",
"data":None}
"""
Clear all annotations
"""
@app.get("/ai-station-api/sam_model/clear_all")
async def sam_reset_annotation(path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
loaded_data,api_config = sam_deal.reset_annotation_all(loaded_data,api_config)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
"""
Save the predicted mask to its class
"""
@app.get("/ai-station-api/sam_model/save_stage")
async def sam_add_to_class(path:str=None,class_index:int=None):
loaded_data,api_config = sam_deal.load_model(path)
class_name = loaded_data['class_names'][class_index]
result = sam_deal.add_to_class(api_config,class_name)
if result['status'] == True:
api_config = result['reason']
sam_deal.save_model(loaded_data,api_config,path)
img = sam_deal.refresh_image(loaded_data,api_config,path)
if img['status'] == True:
encoded_string = sam_deal.load_tmp_image(img['reason'])
return {
"success":True,
"msg":"",
"image":JSONResponse(content={"image_data": encoded_string})
}
else:
return {
"success":False,
"msg":img['reason'],
"image":None
}
else:
return {
"success":False,
"msg":result['reason'],
"image":None
}
"""
Save the results
"""
@app.get("/ai-station-api/sam_model/save")
async def sam_save_annotation(path:str=None):
loaded_data,api_config = sam_deal.load_model(path)
if not api_config['output_dir']:
logger.info("save_annotation: Output directory not set")
return {
"success":False,
"msg":"save_annotation: Output directory not set",
"data":None
}
has_annotations = False
for class_name, class_data in api_config['class_annotations'].items():
if 'final_mask' in class_data and class_data['final_mask'] is not None:
has_annotations = True
break
if not has_annotations:
logger.info("save_annotation: No final masks to save")
return {
"success":False,
"msg":"save_annotation: No final masks to save",
"data":None
}
image_info = sam_deal.get_image_info(loaded_data)
if not image_info:
logger.info("save_annotation: No image info available")
return {
"success":False,
"msg":"save_annotation: No image info available",
"data":None
}
image_basename = os.path.splitext(image_info['filename'])[0]
annotation_dir = os.path.join(api_config['output_dir'], image_basename)
os.makedirs(annotation_dir, exist_ok=True)
saved_files = []
orig_img = loaded_data['image']
original_img_path = os.path.join(annotation_dir, f"{image_basename}.jpg")
cv2.imwrite(original_img_path, orig_img)
saved_files.append(original_img_path)
vis_img = orig_img.copy()
img_height, img_width = orig_img.shape[:2]
labelme_data = {
"version": "5.1.1",
"flags": {},
"shapes": [],
"imagePath": f"{image_basename}.jpg",
"imageData": None,
"imageHeight": img_height,
"imageWidth": img_width
}
for class_name, class_data in api_config['class_annotations'].items():
if 'final_mask' in class_data and class_data['final_mask'] is not None:
color = api_config['class_colors'].get(class_name, (0, 255, 0))
vis_mask = class_data['final_mask'].copy()
color_mask = np.zeros_like(vis_img)
color_mask[vis_mask > 0] = color
vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0)
binary_mask = (class_data['final_mask'] > 0).astype(np.uint8)
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
epsilon = 0.0001 * cv2.arcLength(contour, True)
approx_contour = cv2.approxPolyDP(contour, epsilon, True)
points = [[float(point[0][0]), float(point[0][1])] for point in approx_contour]
if len(points) >= 3:
shape_data = {
"label": class_name,
"points": points,
"group_id": None,
"shape_type": "polygon",
"flags": {}
}
labelme_data["shapes"].append(shape_data)
vis_path = os.path.join(annotation_dir, f"{image_basename}_mask.jpg")
cv2.imwrite(vis_path, vis_img)
saved_files.append(vis_path)
try:
is_success, buffer = cv2.imencode(".jpg", orig_img)
if is_success:
img_bytes = io.BytesIO(buffer).getvalue()
labelme_data["imageData"] = base64.b64encode(img_bytes).decode('utf-8')
else:
print("save_annotation: Failed to encode image data")
labelme_data["imageData"] = ""
except Exception as e:
logger.error(f"save_annotation: Could not encode image data: {str(e)}")
labelme_data["imageData"] = ""
json_path = os.path.join(annotation_dir, f"{image_basename}.json")
with open(json_path, 'w') as f:
json.dump(labelme_data, f, indent=2)
saved_files.append(json_path)
logger.info(f"save_annotation: Annotation saved to {annotation_dir}")
# Package the outputs into a ZIP
zip_filename = "download.zip"
zip_filepath = Path(zip_filename)
dir_path = os.path.dirname(annotation_dir)
# Create the ZIP file
try:
with zipfile.ZipFile(zip_filepath, 'w') as zip_file:
# Walk every file under the folder
for foldername, subfolders, filenames in os.walk(dir_path):
for filename in filenames:
file_path = os.path.join(foldername, filename)
# Write the file into the ZIP archive
zip_file.write(file_path, os.path.relpath(file_path, dir_path))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}")
# Return the ZIP file as the response
return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename)
# @app.post("/ai-station-api/items/")

@@ -0,0 +1 @@
Subproject commit 685b6e545db2114b9b834b1c94712b026298b0a7

Binary file not shown.

Binary file not shown.


@@ -129,6 +129,11 @@ def allowed_file(filename: str) -> bool:
param = params.ModelParams()
return any(filename.endswith(ext) for ext in param.ALLOWED_EXTENSIONS)
####################SAM###################################
# data_info ={
# "result": [
# {


@@ -13,10 +13,10 @@ import cv2 as cv
import torch.nn.functional as F
from joblib import dump, load
from fastapi import HTTPException
import pickle
param = params.ModelParams()
import cv2
import traceback
################################################### Image-processing helpers ###########################################################################
@@ -579,6 +579,7 @@ def pred_single_tar(test_content):
def pred_single_gas(test_content):
gas_pred = []
current_directory = os.getcwd()
@@ -587,7 +588,7 @@ def pred_single_gas(test_content):
gas_model = load(model_path)
pred = gas_model.predict(test_content)
gas_pred.append(pred[0])
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,gas_pred]
result = [param.meirejie_gas_mae, param.meirejie_gas_r2,gas_pred]
# Build the DataFrame
df = pd.DataFrame(result, index=param.index, columns=param.columns)
result = df.to_json(orient='index')
@@ -602,7 +603,7 @@ def pred_single_water(test_content):
water_model = load(model_path)
pred = water_model.predict(test_content)
water_pred.append(pred[0])
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,water_pred]
result = [param.meirejie_water_mae, param.meirejie_water_r2,water_pred]
# Build the DataFrame
df = pd.DataFrame(result, index=param.index, columns=param.columns)
result = df.to_json(orient='index')
@@ -616,7 +617,7 @@ def pred_single_char(test_content):
char_model = load(model_path)
pred = char_model.predict(test_content)
char_pred.append(pred[0])
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,char_pred]
result = [param.meirejie_char_mae, param.meirejie_char_r2,char_pred]
# Build the DataFrame
df = pd.DataFrame(result, index=param.index, columns=param.columns)
result = df.to_json(orient='index')
@@ -624,50 +625,65 @@
def choose_model(name,data):
def choose_model_meirejie(name,data):
current_directory = os.getcwd()
model_path = os.path.join(current_directory,'meirejie',param.meirejie_model_dict[name])
model = load(model_path)
pred = model.predict(data)
return pred
def get_excel_tar(model_name):
data_name = param.meirejie_test_data['tar']
current_directory = os.getcwd()
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
test_data = pd.read_csv(data_path)
pred = choose_model(model_name,test_data)
def get_excel_tar(model_name,file_path):
if not file_path.endswith('.csv'):
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
test_data = pd.read_csv(file_path)
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Tar"]
if list(test_data.columns) != expected_columns:
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
del test_data['Tar']
pred = choose_model_meirejie(model_name,test_data)
test_data['tar_pred'] = pred
return test_data.to_json(orient='records', lines=True)
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_gas(model_name):
data_name = param.meirejie_test_data['gas']
current_directory = os.getcwd()
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
test_data = pd.read_csv(data_path)
pred = choose_model(model_name,test_data)
def get_excel_gas(model_name,file_path):
if not file_path.endswith('.csv'):
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
test_data = pd.read_csv(file_path)
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Gas"]
if list(test_data.columns) != expected_columns:
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
del test_data['Gas']
pred = choose_model_meirejie(model_name,test_data)
test_data['gas_pred'] = pred
return test_data.to_json(orient='records', lines=True)
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_char(model_name):
data_name = param.meirejie_test_data['char']
current_directory = os.getcwd()
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
test_data = pd.read_csv(data_path)
pred = choose_model(model_name,test_data)
def get_excel_char(model_name,file_path):
if not file_path.endswith('.csv'):
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
test_data = pd.read_csv(file_path)
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Char"]
if list(test_data.columns) != expected_columns:
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
del test_data['Char']
pred = choose_model_meirejie(model_name,test_data)
test_data['char_pred'] = pred
return test_data.to_json(orient='records', lines=True)
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_water(model_name):
data_name = param.meirejie_test_data['water']
current_directory = os.getcwd()
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
test_data = pd.read_csv(data_path)
pred = choose_model(model_name,test_data)
def get_excel_water(model_name,file_path):
if not file_path.endswith('.csv'):
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
test_data = pd.read_csv(file_path)
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Water"]
if list(test_data.columns) != expected_columns:
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
del test_data['Water']
pred = choose_model_meirejie(model_name,test_data)
test_data['water_pred'] = pred
return test_data.to_json(orient='records', lines=True)
return {"status":True, "reason":test_data.to_dict(orient='records')}
@@ -741,52 +757,59 @@ def choose_model_meijitancailiao(name,data):
def get_excel_ssa(model_name,file_path):
if not file_path.endswith('.csv'):
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
test_data = pd.read_csv(file_path)
expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "SSA"]
if list(test_data.columns) != expected_columns:
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
del test_data['SSA']
pred = choose_model_meijitancailiao(model_name,test_data)
test_data['ssa_pred'] = pred
return test_data.to_json(orient='records')
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_tpv(model_name,file_path):
if not file_path.endswith('.csv'):
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
test_data = pd.read_csv(file_path)
expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "TPV"]
if list(test_data.columns) != expected_columns:
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
del test_data['TPV']
pred = choose_model_meijitancailiao(model_name,test_data)
test_data['tpv_pred'] = pred
return test_data.to_json(orient='records')
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_meitan(model_name,file_path):
if not file_path.endswith('.csv'):
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
test_data = pd.read_csv(file_path)
expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J","C"]
if list(test_data.columns) != expected_columns:
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
del test_data['C']
pred = choose_model_meijitancailiao(model_name,test_data)
test_data['C_pred'] = pred
return test_data.to_json(orient='records', lines=True)
# return test_data.to_dict(orient='records')
return {"status":True, "reason":test_data.to_dict(orient='records')}
def get_excel_meiliqing(model_name,file_path):
if not file_path.endswith('.csv'):
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
test_data = pd.read_csv(file_path)
expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J", "C"]
if list(test_data.columns) != expected_columns:
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
del test_data['C']
pred = choose_model_meijitancailiao(model_name,test_data)
test_data['C_pred'] = pred
return test_data.to_json(orient='records', lines=True)
return {"status":True, "reason":test_data.to_dict(orient='records')}
def pred_func(func_name, pred_data):
@@ -835,3 +858,6 @@ def get_pic_path(url):
# 3. Prepend the local root directory
local_path = f"/root/app{relative_path}"
return local_path


@@ -140,9 +140,18 @@ class ModelParams():
}
index = ['mae', 'r2', 'result']
columns = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression',
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
columns = [
'XGBoost(极端梯度提升)',
'Linear Regression(线性回归)',
'Ridge Regression(岭回归)',
'Gaussian Process Regression(高斯过程回归)',
'ElasticNet Regression(弹性网回归)',
'K-Nearest Neighbors(K最近邻)',
'Support Vector Regression(支持向量回归)',
'Decision Tree Regression(决策树回归)',
'Random Forest Regression(随机森林回归)',
'AdaBoost Regression(AdaBoost回归)'
]
meirejie_model_list_gas = ['xgb_gas','lr_gas','ridge_gas','gp_gas','en_gas','kn_gas','svr_gas','dtr_gas','rfr_gas','adb_gas']
meirejie_model_list_char = ['xgb_char','lr_char','ridge_char','gp_char','en_char','kn_char','svr_char','dtr_char','rfr_char','adb_char']
meirejie_model_list_water = ['xgb_water','lr_water','ridge_water','gp_water','en_water','kn_water','svr_water','dtr_water','rfr_water','adb_water']
@@ -224,29 +233,65 @@ class ModelParams():
"xgb_meiliqing":"model/meiliqing_XGB.joblib",
}
columns_ssa = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression',
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
columns_ssa = [
'XGBoost(极端梯度提升)',
'Linear Regression(线性回归)',
'Ridge Regression(岭回归)',
'Gaussian Process Regression(高斯过程回归)',
'ElasticNet Regression(弹性网回归)',
'K-Nearest Neighbors(最近邻居算法)',
'Support Vector Regression(支持向量回归)',
'Decision Tree Regression(决策树回归)',
'Random Forest Regression(随机森林回归)',
'AdaBoost Regression(自适应提升回归)'
]
meijitancailiao_model_list_ssa = ['xgb_ssa','lr_ssa','ridge_ssa','gp_ssa','en_ssa','kn_ssa','svr_ssa','dtr_ssa','rfr_ssa','adb_ssa']
meijitancailiao_ssa_mae = [258, 407,408 ,282 ,411 ,389, 405, 288,193, 330]
meijitancailiao_ssa_r2 = [0.92,0.82,0.82,0.89,0.81,0.82,0.87,0.88,0.95,0.88]
columns_tpv = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
columns_tpv = [
'XGBoost(极端梯度提升)',
'Linear Regression(线性回归)',
'Ridge Regression(岭回归)',
'Gaussian Process Regression(高斯过程回归)',
'ElasticNet Regression(弹性网回归)',
'Gradient Boosting Regression(梯度提升回归)',
'Support Vector Regression(支持向量回归)',
'Decision Tree Regression(决策树回归)',
'Random Forest Regression(随机森林回归)',
'AdaBoost Regression(自适应提升回归)'
]
meijitancailiao_model_list_tpv = ['xgb_tpv', 'lr_tpv', 'ridge_tpv', 'gp_tpv', 'en_tpv', 'gdbt_tpv', 'svr_tpv', 'dtr_tpv', 'rfr_tpv', 'adb_tpv']
meijitancailiao_tpv_mae = [0.2, 0.2, 0.2, 0.2, 0.2, 0.23, 0.23, 0.21, 0.16, 0.21]
meijitancailiao_tpv_r2 = [0.81, 0.81, 0.81, 0.8, 0.82, 0.80, 0.78, 0.73, 0.85, 0.84]
columns_meitan = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
columns_meitan = [
'XGBoost(极端梯度提升)',
'Linear Regression(线性回归)',
'Ridge Regression(岭回归)',
'Gaussian Process Regression(高斯过程回归)',
'ElasticNet Regression(弹性网回归)',
'Gradient Boosting Regression(梯度提升回归)',
'Support Vector Regression(支持向量回归)',
'Decision Tree Regression(决策树回归)',
'Random Forest Regression(随机森林回归)',
'AdaBoost Regression(自适应提升回归)'
]
meijitancailiao_model_list_meitan = ['xgb_meitan', 'lr_meitan', 'ridge_meitan', 'gp_meitan', 'en_meitan', 'gdbt_meitan', 'svr_meitan', 'dtr_meitan', 'rfr_meitan', 'adb_meitan']
meijitancailiao_meitan_mae = [8.17, 37.61, 37.66, 13.41, 20.96, 8.03, 14.89, 19.48, 12.53, 15.6]
meijitancailiao_meitan_r2 = [0.96, 0.19, 0.19, 0.91, 0.8, 0.96, 0.88, 0.86, 0.91, 0.91]
columns_meiliqing = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
columns_meiliqing = [
'XGBoost(极端梯度提升)',
'Linear Regression(线性回归)',
'Ridge Regression(岭回归)',
'Gaussian Process Regression(高斯过程回归)',
'ElasticNet Regression(弹性网回归)',
'Gradient Boosting Regression(梯度提升回归)',
'Support Vector Regression(支持向量回归)',
'Decision Tree Regression(决策树回归)',
'Random Forest Regression(随机森林回归)',
'AdaBoost Regression(自适应提升回归)'
]
meijitancailiao_model_list_meiliqing = ['xgb_meiliqing', 'lr_meiliqing', 'ridge_meiliqing', 'gp_meiliqing', 'en_meiliqing', 'gdbt_meiliqing', 'svr_meiliqing', 'dtr_meiliqing', 'rfr_meiliqing', 'adb_meiliqing']
meijitancailiao_meiliqing_mae = [8.38, 35.02, 35.1, 11.02, 13.58, 7.04, 13.13, 13.13, 11.25, 9.99]
meijitancailiao_meiliqing_r2 = [0.95, 0.33, 0.33, 0.94, 0.91, 0.97, 0.88, 0.89, 0.92, 0.94]
@@ -277,3 +322,11 @@ class ModelParams():
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.tif'}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB
MAX_FILE_SAM_SIZE = 10 * 1024 * 1024
DEFAULT_MODEL_PATH = r"/home/xiazj/ai-station-code/segment_anything_model/weights/vit_b.pth"


@@ -21,6 +21,26 @@ class Meitan(BaseModel):
J: float
class Meirejie(BaseModel):
A: float
V: float
FC: float
C: float
H: float
N: float
S: float
O: float
H_C: float
O_C: float
N_C: float
Rt: float
Hr: float
dp: float
T: float
class FormData(BaseModel):
A_min: Optional[float] = None
A_max: Optional[float] = None

work_util/sam_deal.py (new file)

@@ -0,0 +1,398 @@
import cv2
import os
import numpy as np
import json
import shutil
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import random
import traceback
import base64
import io
from work_util.logger import logger
import pickle
from PIL import Image
def get_display_image_ori(loaded_data):
if loaded_data['image_rgb'] is None:
logger.info("get_display_image: No image loaded")
return {"status": False,"reason":"无图片加载"}
display_image = loaded_data['image_rgb'].copy()
try:
for point, label in zip(loaded_data['input_point'], loaded_data['input_label']):
color = (0, 255, 0) if label == 1 else (0, 0, 255)
cv2.circle(display_image, tuple(point), 5, color, -1)
if loaded_data['selected_mask'] is not None:
class_name = loaded_data['class_names'][loaded_data['class_index']]
color = loaded_data['class_colors'].get(class_name, (0, 0, 128))
display_image = sam_apply_mask(display_image, loaded_data['selected_mask'], color)
logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
return {"status" : True, "reason":display_image}
except Exception as e:
logger.info(f"get_display_image: Error processing image: {str(e)}")
traceback.print_exc()
return {"status": False,"reason":f"get_display_image: Error processing image: {str(e)}"}
def get_all_classes(data):
return {
"classes": data['class_names'],
"colors": {name: color for name, color in data['class_colors'].items()}
}
def get_classes(data):
return get_all_classes(data)
def reset_annotation(loaded_data):
loaded_data['input_point'] = []
loaded_data['input_label'] = []
loaded_data['selected_mask'] = None
loaded_data['logit_input'] = None
loaded_data['masks'] = {}
logger.info("已重置标注状态")
return loaded_data
def remove_class(data,api,class_name):
if class_name in api['class_annotations']:
del api['class_annotations'][class_name]
if class_name == data['class_names'][data['class_index']]:
data['class_index'] = 0
data['class_names'].remove(class_name)
if class_name in data['class_colors']:
del data['class_colors'][class_name]
if class_name in data['masks']:
del data['masks'][class_name]
return data,api
def reset_annotation_all(data,api):
data = reset_annotation(data)
for class_data in api['class_annotations'].values():
class_data['points'] = []
class_data['point_types'] = []
class_data['masks'] = []
class_data['scores'] = []
class_data['selected_mask_index'] = -1
if 'selected_mask' in class_data:
del class_data['selected_mask']
return data,api
def reset_class_annotations(data,api):
"""重置class_annotations保留类别但清除掩码和点"""
classes = get_classes(data).get('classes', [])
new_annotations = {}
for class_name in classes:
new_annotations[class_name] = {
'points': [],
'point_types': [],
'masks': [],
'selected_mask_index': -1
}
api['class_annotations'] = new_annotations
logger.info("已重置class_annotations保留类别但清除掩码和点")
return data,api
def add_class(data,class_name, color=None):
if class_name in data['class_names']:
logger.info(f"类别 '{class_name}' 已存在")
return {'status':False, 'reason':f"类别 '{class_name}' 已存在"}
data['class_names'].append(class_name)
if color is None:
color = tuple(np.random.randint(100, 256, 3).tolist())
r, g, b = [int(c) for c in color]
bgr_color = (b, g, r)
data['class_colors'][class_name] = tuple(bgr_color)
logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
return {'status':True, 'reason':data}
def set_current_class(data, api, class_index, color=None):
classes = get_classes(data)
if 'classes' in classes and class_index < len(classes['classes']):
class_name = classes['classes'][class_index]
api['current_class'] = class_name
if class_name not in api['class_annotations']:
api['class_annotations'][class_name] = {
'points': [],
'point_types': [],
'masks': [],
'selected_mask_index': -1
}
color = data['class_colors'][class_name]
if color:
api['class_colors'][class_name] = color
elif class_name not in api['class_colors']:
predefined_colors = [
(255, 0, 0), (0, 255, 0), (0, 0, 255),
(255, 255, 0), (255, 0, 255), (0, 255, 255)
]
color_index = len(api['class_colors']) % len(predefined_colors)
api['class_colors'][class_name] = predefined_colors[color_index]
return class_name,api
return None,api
def add_point(data, x, y, is_foreground=True):
data['input_point'].append([x, y])
data['input_label'].append(1 if is_foreground else 0)
logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
return data
def load_model(path):
# Load the saved config
config_path = os.path.join(path,'model_params.pickle')
with open(config_path, 'rb') as file:
loaded_data,api_config = pickle.load(file)
return loaded_data,api_config
def save_model(loaded_data,api_config,path):
config_path = os.path.join(path,'model_params.pickle')
save_data = (loaded_data,api_config)
with open(config_path, 'wb') as file:
pickle.dump(save_data, file)
def sam_apply_mask(image, mask, color, alpha=0.5):
masked_image = image.copy()
for c in range(3):
masked_image[:, :, c] = np.where(
mask == 1,
image[:, :, c] * (1 - alpha) + alpha * color[c],
image[:, :, c]
)
return masked_image
def apply_mask_overlay(image, mask, color, alpha=0.5):
colored_mask = np.zeros_like(image)
colored_mask[mask > 0] = color
return cv2.addWeighted(image, 1, colored_mask, alpha, 0)
def load_tmp_image(path):
with open(path, "rb") as image_file:
# Read the image file as binary data and Base64-encode it
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
return encoded_string
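On the client side, the image_data string produced here decodes back to a JPEG with the standard library (a sketch):

import base64

def decode_to_file(encoded_string, out_path="decoded.jpg"):
    # Inverse of load_tmp_image: Base64 text back to image bytes on disk
    with open(out_path, "wb") as out:
        out.write(base64.b64decode(encoded_string))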
def refresh_image(data,api,path):
result = get_image_display(data,api)
if result['status'] == False:
return {
"status" : False,
"reason" : result['reason']
}
else:
img = result['reason']
display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(display_img_rgb)
tmp_path = os.path.join(path,'temp/output_image.jpg')
pil_img.save(tmp_path)
return {
"status" : True,
"reason" : tmp_path
}
def set_class_color(data,api,index,color):
class_name = data['class_names'][index]
api['class_colors'][class_name] = color
data['class_colors'][class_name] = color
return data,api
def add_to_class(api,class_name):
if class_name is None:
class_name = api['current_class']
if not class_name or class_name not in api['class_annotations']:
return {
'status':False,
'reason': "该分类没有标注点"
}
class_data = api['class_annotations'][class_name]
if 'selected_mask' not in class_data or class_data['selected_mask'] is None:
return {
'status':False,
'reason': "该分类没有进行预测"
}
class_data['final_mask'] = class_data['selected_mask'].copy()
class_data['points'] = []
class_data['point_types'] = []
class_data['masks'] = []
class_data['scores'] = []
class_data['selected_mask_index'] = -1
return {
'status':True,
'reason': api
}
def get_image_display(data,api):
if data['image'] is None:
logger.info("get_display_image: No image loaded")
return {"status":False, "reason":"获取图像:没有图像加载"}
display_image = data['image'].copy()
try:
for point, label in zip(data['input_point'], data['input_label']):
color = (0, 255, 0) if label == 1 else (0, 0, 255)
cv2.circle(display_image, tuple(point), 5, color, -1)
if data['selected_mask'] is not None:
class_name = data['class_names'][data['class_index']]
color = data['class_colors'].get(class_name, (0, 0, 128))
display_image = sam_apply_mask(display_image, data['selected_mask'], color)
logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
img = display_image
if not isinstance(img, np.ndarray) or img.size == 0:
logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}")
return {"status":False, "reason":f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"}
# Apply only the final_mask of the current image
for class_name, class_data in api['class_annotations'].items():
if 'final_mask' in class_data and class_data['final_mask'] is not None:
color = api['class_colors'].get(class_name, (0, 255, 0))
mask = class_data['final_mask']
if isinstance(mask, list):
mask = np.array(mask, dtype=np.uint8)
logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
img = apply_mask_overlay(img, mask, color, alpha=0.5)
elif 'selected_mask' in class_data and class_data['selected_mask'] is not None:
color = api['class_colors'].get(class_name, (0, 255, 0))
mask = class_data['selected_mask']
if isinstance(mask, list):
mask = np.array(mask, dtype=np.uint8)
logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
img = apply_mask_overlay(img, mask, color, alpha=0.5)
if api['current_class'] and api['current_class'] in api['class_annotations']:
class_data = api['class_annotations'][api['current_class']]
for i, (x, y) in enumerate(class_data['points']):
is_fg = class_data['point_types'][i]
color = (0, 255, 0) if is_fg else (0, 0, 255)
print(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
cv2.circle(img, (int(x), int(y)), 5, color, -1)
logger.info(f"get_image_display: Returning image with shape {img.shape}")
return {"status":True, "reason":img}
except Exception as e:
logger.error(f"get_display_image: Error processing image: {str(e)}")
traceback.print_exc()
return {"status":False, "reason":f"get_display_image: Error processing image: {str(e)}"}
def add_annotation_point(api, x, y, is_foreground=True):
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
return {
"status": False,
"reason": "请选择或新建分类"
}
class_data = api['class_annotations'][api['current_class']]
class_data['points'].append((x, y))
class_data['point_types'].append(is_foreground)
return {
'status' : True,
'reason' : {
'points': class_data['points'],
'types': class_data['point_types']
},
'api':api
}
def delete_last_point(api):
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
return {
"status": False,
"reason": "当前类没有点需要删除"
}
class_data = api['class_annotations'][api['current_class']]
if not class_data['points']:
return {
"status": False,
"reason": "当前类没有点需要删除"
}
class_data['points'].pop()
class_data['point_types'].pop()
return {
"status": True,
"reason": api
}
def reset_current_class_points(api):
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
return {
"status": False,
"reason": "当前类没有点需要删除"
}
class_data = api['class_annotations'][api['current_class']]
class_data['points'] = []
class_data['point_types'] = []
return {
"status": True,
"reason": api
}
def predict_mask(data,sam_predictor):
if data['image_rgb'] is None:
logger.info("predict_mask: No image loaded")
return {"status":False, "reason":"预测掩码:没有图像加载"}
if len(data['input_point']) == 0:
logger.info("predict_mask: No points added")
return {"status":False, "reason":"预测掩码:没有进行点标注"}
try:
sam_predictor.set_image(data['image_rgb'])
except Exception as e:
logger.error(f"predict_mask: Error setting image: {str(e)}")
return {"status":False, "reason":f"predict_mask: Error setting image: {str(e)}"}
input_point_np = np.array(data['input_point'])
input_label_np = np.array(data['input_label'])
try:
masks_pred, scores, logits = sam_predictor.predict(
point_coords=input_point_np,
point_labels=input_label_np,
mask_input=data['logit_input'][None, :, :] if data['logit_input'] is not None else None,
multimask_output=True,
)
except Exception as e:
logger.error(f"predict_mask: Error during prediction: {str(e)}")
traceback.print_exc()
return {"status":False, "reason":f"predict_mask: Error during prediction: {str(e)}"}
data['masks_pred'] = masks_pred
data['scores'] = scores
data['logits'] = logits
best_mask_idx = np.argmax(scores)
data['selected_mask'] = masks_pred[best_mask_idx]
data['logit_input'] = logits[best_mask_idx, :, :]
logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")
return {
"status":True,
"reason":{
"masks": [mask.tolist() for mask in masks_pred],
"scores": scores.tolist(),
"selected_index": int(best_mask_idx),
"data":data
}}
def get_image_info(data):
return {
"filename": data['filename'],
"index": data['current_index'],
"total": len(data['image_files']),
"width": data['image'].shape[1],
"height": data['image'].shape[0]
}