ai-station-code/run.py

1847 lines
65 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
import sys
2025-06-04 17:04:02 +08:00
import io
2025-05-14 11:00:24 +08:00
from fastapi import FastAPI,File, UploadFile,Form,Query
2025-05-06 11:18:48 +08:00
from fastapi.staticfiles import StaticFiles
2025-05-14 11:00:24 +08:00
from fastapi import FastAPI, HTTPException
2025-06-04 17:04:02 +08:00
from fastapi.responses import FileResponse,JSONResponse
2025-05-06 11:18:48 +08:00
import sys
import os
import shutil
from pydantic import BaseModel, validator
2025-06-04 17:04:02 +08:00
from typing import List, Optional
2025-05-06 11:18:48 +08:00
import asyncio
import pandas as pd
import numpy as np
from PIL import Image
2025-06-04 17:04:02 +08:00
import pickle
import cv2
import copy
import base64
2025-05-06 11:18:48 +08:00
# Resolve the directory that contains this script so that the bundled
# packages imported below (dimaoshibie, wudingpv, work_util, ...) can be found
# regardless of the process's working directory.
print("Current working directory:", os.getcwd())
current_dir = os.path.dirname(os.path.abspath(__file__))
# Make the script directory importable. Skip it when already present so that
# re-importing this module does not keep growing sys.path with duplicates.
if current_dir not in sys.path:
    sys.path.append(current_dir)
print("Current sys.path:", sys.path)
import torch
from dimaoshibie import segformer
from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_resUnetpamcarb
from wudingpv.predictandeval_util import segmentation
from guangfufadian import model_base as guangfufadian_model_base
from fenglifadian import model_base as fenglifadian_model_base
2025-06-04 17:04:02 +08:00
from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal
2025-05-06 11:18:48 +08:00
from work_util.logger import logger
import joblib
import mysql.connector
import uuid
2025-05-14 11:00:24 +08:00
import json
import zipfile
from pathlib import Path
2025-06-04 17:04:02 +08:00
from segment_anything_model import sam_annotator
from segment_anything_model.sam_config import sam_config, sam_api_config
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import traceback
2025-05-06 11:18:48 +08:00
# "major.minor" interpreter version, reported by the root endpoint.
version = "{}.{}".format(*sys.version_info[:2])
app = FastAPI()
# Shared configuration used by the endpoints below (upload size limit,
# segmentation palettes/colours, class-name tables, ...).
param = params.ModelParams()
def _load_resunet_model(weights_file, label):
    """Build a ``resUnetpamcarb`` network and load ``wudingpv/models/<weights_file>``.

    The checkpoint is read onto the CPU first (``map_location``) and its
    ``'net'`` entry is used as the state dict. The model is then switched to
    eval mode and moved to CUDA.

    Args:
        weights_file: checkpoint file name inside ``wudingpv/models``.
        label: human-readable model name used in the success log line.
    """
    model = roof_resUnetpamcarb()
    weights_path = os.path.join(current_dir, 'wudingpv/models', weights_file)
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['net'])
    logger.info(label + "识别权重加载成功")
    model.eval()
    model.cuda()
    return model


def get_roof_model():
    """Return the roof segmentation network (eval mode, on CUDA)."""
    return _load_resunet_model('roof_best.pth', "屋顶")


def get_pv_model():
    """Return the PV segmentation network (eval mode, on CUDA)."""
    # Label it as PV; the previous copy of this loader logged the roof message.
    return _load_resunet_model('pv_best.pth', "光伏")
# Model instances, created once at import time and shared by all requests.
# Landform-recognition (dimaoshibie) SegFormer. Construct it exactly once: it
# used to be built twice in a row, loading its weights twice for nothing.
dimaoshibie_SegFormer = segformer.SegFormer_Segmentation()
# Roof and PV segmentation networks.
roof_model = get_roof_model()
pv_model = get_pv_model()
# Methane (CH4) XGBoost models: flow ("liuliang") and gas concentration ("qixiangnongdu").
ch4_model_flow = joblib.load(os.path.join(current_dir, 'jiawanyuce/liuliang_model/xgb_model_liuliang.pkl'))
ch4_model_gas = joblib.load(os.path.join(current_dir, 'jiawanyuce/qixiangnongdu_model/xgb_model_qixiangnongdu.pkl'))
# Photovoltaic power forecasting (Crossformer checkpoint).
pvfd_param = guangfufadian_model_base.guangfufadian_Args()
pvfd_model_path = os.path.join(pvfd_param.checkpoints, 'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth')  # adjust to the actual model path
pvfd_model = guangfufadian_model_base.ModelInference(pvfd_model_path, pvfd_param)
# Wind power forecasting (Crossformer checkpoint).
windfd_args = fenglifadian_model_base.fenglifadian_Args()
windfd_model_path = os.path.join(windfd_args.checkpoints, 'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth')  # adjust to the actual model path
windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args)
2025-05-06 11:18:48 +08:00
2025-06-04 17:04:02 +08:00
# Load the Segment Anything model (ViT-B weights) and wrap it in a predictor.
checkpoint_path = os.path.join(current_dir, 'segment_anything_model/weights/vit_b.pth')
sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
# SAM is a PyTorch model, so ask PyTorch whether CUDA is usable.
# cv2.cuda.getCudaEnabledDeviceCount() only counts devices for CUDA-enabled
# OpenCV builds and returns 0 for standard OpenCV wheels, which silently
# forced SAM onto the CPU even on a GPU host.
device = "cuda" if torch.cuda.is_available() else "cpu"
_ = sam.to(device=device)
sam_predictor = SamPredictor(sam)
print(f"SAM模型已加载使用设备: {device}")
2025-05-06 11:18:48 +08:00
# Expose the /root/app directory as static files under the /files URL prefix.
app.mount("/files", StaticFiles(directory="/root/app"), name="files")


@app.get("/")
async def read_root():
    """Root endpoint: return a greeting that includes the Python version."""
    greeting = (
        "Hello world! From FastAPI running on Uvicorn with Gunicorn. "
        f"Using Python {version}"
    )
    return {"message": greeting}
2025-06-04 17:04:02 +08:00
# Home page: application cards shown on the landing page.
@app.get("/ai-station-api/index/show")
async def get_source_index_info():
    """Return all rows of ``app_shouye`` (id, name, description, image URL, link URL)."""
    query = "SELECT id,application_name, describe_data, img_url,url FROM app_shouye"
    rows = data_util.fetch_data(query)
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
2025-05-06 11:18:48 +08:00
# Data page: list the sample datasets.
@app.get("/ai-station-api/data_source/show")
async def get_data_source_info():
    """Return the ``data_samples`` rows reshaped by ``generate_json_data_source``.

    Responds with ``success: False`` when the query or the reshaping yields None.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time, download_url FROM data_samples"
    data = data_util.fetch_data(sql)
    # Check the raw query result before reshaping it: a failed query yields
    # None, which should not be handed to the JSON generator.
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    data = data_util.generate_json_data_source(data)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": data}
# Application page: list the application samples.
@app.get("/ai-station-api/app_source/show")
async def get_app_source_info():
    """Return the ``app_samples`` rows reshaped by ``generate_json_app_source``.

    Responds with ``success: False`` when the query or the reshaping yields None.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM app_samples"
    data = data_util.fetch_data(sql)
    # Guard the raw query result before reshaping it.
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    data = data_util.generate_json_app_source(data)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": data}
# Coal-based carbon materials page: list its samples.
@app.get("/ai-station-api/mjtcl_source/show")
async def get_mjtcl_source_info():
    """Return ``meijitancailiao_samples`` rows reshaped by ``generate_json_meijitancailiao_source``.

    Responds with ``success: False`` when the query or the reshaping yields None.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM meijitancailiao_samples"
    data = data_util.fetch_data(sql)
    # Guard the raw query result before reshaping it.
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    data = data_util.generate_json_meijitancailiao_source(data)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": data}
# Coal pyrolysis page: list its samples.
@app.get("/ai-station-api/mrj_source/show")
async def get_mrj_source_info():
    """Return ``meirejie_samples`` rows reshaped by ``generate_json_meirejie_source``.

    Responds with ``success: False`` when the query or the reshaping yields None.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM meirejie_samples"
    data = data_util.fetch_data(sql)
    # Guard the raw query result before reshaping it.
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    data = data_util.generate_json_meirejie_source(data)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": data}
"""煤基碳材料输入特征说明接口列表
type = {zongkongtiji,zongbiaomianji,tancailiao,meiliqing,moniqi}
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mjt_feature/show")
async def get_mjt_feature_info(type: str = None):
    """Return input-feature descriptions for one coal-based carbon model.

    ``type`` is matched against ``use_type``; expected values are
    zongkongtiji, zongbiaomianji, tancailiao, meiliqing and moniqi.
    """
    query = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
    rows = data_util.fetch_data_with_param(query, (type,))
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""煤热解材料输入特征说明接口列表
type = {tar,char,gas,water}
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mrj_feature/show")
async def get_mrj_feature_info(type: str = None):
    """Return input-feature descriptions for one coal-pyrolysis model.

    ``type`` is matched against ``use_type``; expected values are tar, char,
    gas and water.
    """
    # NOTE(review): this reads the same meijitancailiao_features table as the
    # coal-based carbon endpoint; confirm the pyrolysis features live there.
    query = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
    rows = data_util.fetch_data_with_param(query, (type,))
    if rows is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": rows}
# ======================== Coal-based carbon materials ========================
@app.post("/ai-station-api/mjt/predict")
async def mjt_models_predict(content: post_model.Zongbiaomianji):
    """Run the total-surface-area (SSA) model on a single sample.

    The body's ``K_C`` field is renamed to the ``K/C`` column and the columns
    are reordered before calling ``model_deal.pred_single_ssa``.
    """
    sample = pd.DataFrame([content.model_dump()]).rename(columns={"K_C": "K/C"})
    sample = sample.reindex(columns=["A", "VM", "K/C", "MM", "AT", "At", "Rt"])
    ssa_result = model_deal.pred_single_ssa(sample)
    # Identify this endpoint correctly; the old message was copied from the
    # root handler ("Root endpoint was accessed").
    logger.info("mjt/predict endpoint was accessed")
    if ssa_result is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": ssa_result}
# ============ Coal-based carbon materials: total surface area (SSA) ============
@app.post("/ai-station-api/mjt_models_ssa/predict")
async def mjt_models_predict_ssa(content: post_model.Zongbiaomianji):
    """Combined model analysis: run the SSA model on a single sample."""
    column_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns={"K_C": "K/C"})
        .reindex(columns=column_order)
    )
    prediction = model_deal.pred_single_ssa(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
上传文件
"""
2025-05-14 11:00:24 +08:00
@app.post("/ai-station-api/document/upload")
async def upload_file(file: UploadFile = File(...), type: str = Form(...), ):
    """Save an uploaded CSV into a fresh directory ``tmp/<type>/<uuid>/``.

    Returns the saved file location, which later endpoints take as ``path``.

    Raises:
        HTTPException(400): the file is not a ``.csv``, or ``type`` is not a
            plain directory name.
    """
    # Client-supplied names are untrusted. Strip directory parts from the
    # filename and reject a ``type`` that could escape ``tmp`` (e.g. "../x").
    filename = os.path.basename(file.filename or "")
    if not filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
    if not type or type in (".", "..") or os.path.basename(type) != type:
        raise HTTPException(status_code=400, detail="type 参数不合法")
    upload_dir = os.path.join(current_dir, 'tmp', type, str(uuid.uuid4()))
    os.makedirs(upload_dir, exist_ok=True)
    file_location = os.path.join(upload_dir, filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"location": file_location}}
"""
批量预测接口
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mjt_multi_ssa/predict")
async def mjt_multi_ssa_pred(model: str = None, path: str = None):
    """Batch SSA prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_ssa`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_ssa(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
2025-05-06 11:18:48 +08:00
# ============ Coal-based carbon materials: total pore volume (TPV) ============
@app.post("/ai-station-api/mjt_models_tpv/predict")
async def mjt_models_predict_tpv(content: post_model.Zongbiaomianji):
    """Combined model analysis: run the TPV model on a single sample."""
    column_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns={"K_C": "K/C"})
        .reindex(columns=column_order)
    )
    prediction = model_deal.pred_single_tpv(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mjt_multi_tpv/predict")
async def mjt_multi_tpv_pred(model: str = None, path: str = None):
    """Batch TPV prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_tpv`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_tpv(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
2025-05-06 11:18:48 +08:00
# ============ Coal-based carbon materials: coal-derived carbon application ============
@app.post("/ai-station-api/mjt_models_meitan/predict")
async def mjt_models_predict_meitan(content: post_model.Meitan):
    """Run the coal-carbon application model on a single sample.

    The body's ``ID_IG`` field is renamed to the ``ID/IG`` column first.
    """
    column_order = ["SSA", "TPV", "N", "O", "ID/IG", "J"]
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns={"ID_IG": "ID/IG"})
        .reindex(columns=column_order)
    )
    prediction = model_deal.pred_single_meitan(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mjt_multi_meitan/predict")
async def mjt_multi_meitan_pred(model: str = None, path: str = None):
    """Batch coal-carbon application prediction over the file at ``path``.

    ``model_deal.get_excel_meitan`` returns a dict with ``status`` and
    ``reason``; ``reason`` holds the results on success and the error otherwise.
    """
    outcome = model_deal.get_excel_meitan(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
2025-05-06 11:18:48 +08:00
# ============ Coal-based carbon materials: coal-pitch application ============
@app.post("/ai-station-api/mjt_models_meiliqing/predict")
async def mjt_models_predict_meiliqing(content: post_model.Meitan):
    """Run the coal-pitch application model on a single sample.

    The body's ``ID_IG`` field is renamed to the ``ID/IG`` column first.
    """
    column_order = ["SSA", "TPV", "N", "O", "ID/IG", "J"]
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns={"ID_IG": "ID/IG"})
        .reindex(columns=column_order)
    )
    prediction = model_deal.pred_single_meiliqing(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
2025-05-14 11:00:24 +08:00
@app.get("/ai-station-api/mjt_multi_meiliqing/predict")
async def mjt_multi_meiliqing_pred(model: str = None, path: str = None):
    """Batch coal-pitch application prediction over the file at ``path``.

    ``model_deal.get_excel_meiliqing`` returns a dict with ``status`` and
    ``reason``; ``reason`` holds the results on success and the error otherwise.
    """
    outcome = model_deal.get_excel_meiliqing(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
2025-05-06 11:18:48 +08:00
# ============ Coal-based carbon materials: preparation simulator ============
# Forward simulator
@app.post("/ai-station-api/mnq_model/analysis")
async def mnq_model_predict(return_count: int,  # number of rows to return
                            model_choice: str,  # which model to run
                            form_data: post_model.FormData  # form payload
                            ):
    """Run the forward simulator and persist the ranked predictions.

    Prediction inputs built from the form are scored with ``model_choice`` and
    sorted by SSA, then TPV (both descending). The full ranking is written to
    ``tmp/moniqi/<uuid>/<model_choice>_moniqi.csv``; the response carries the
    first ``return_count`` rows plus that CSV path.
    """
    form_values = form_data.model_dump()
    sim_params = prepare_data.get_params(form_values)
    model_inputs = prepare_data.create_pred_data(sim_params)
    scored = model_deal.pred_func(model_choice, model_inputs)
    ranked = scored.sort_values(by=['SSA', 'TPV'], ascending=[False, False])

    output_dir = os.path.join(current_dir, 'tmp', 'moniqi', str(uuid.uuid4()))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    csv_path = os.path.join(output_dir, model_choice + "_moniqi.csv")
    ranked.to_csv(csv_path, index=False)

    # NOTE(review): return_count has no default, so FastAPI requires it and the
    # "return everything" (None) case looks unreachable over HTTP - confirm.
    rows = ranked if return_count is None else ranked.iloc[:return_count]
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": rows.to_dict(orient="records"), "path": csv_path}}
2025-05-06 11:18:48 +08:00
2025-05-14 11:00:24 +08:00
# Reverse simulator
@app.get("/ai-station-api/mnq_model/reverse_analysis")
async def mnq_model_reverse_analysis(model: str = None, path: str = None, type: int = 1, scale: float = None, num: int = 5):
    """Reverse simulator: derive preparation conditions from a simulator result.

    All arguments are forwarded unchanged to ``prepare_data.moniqi_data_prepare``,
    which defines their meaning. Its ``reason`` field becomes ``data.data`` on
    success and ``msg`` on failure.

    This handler used to be named ``mjt_multi_meiliqing_pred``. That rebound the
    batch coal-pitch endpoint's module-level name and gave two routes the same
    route name. The HTTP path is unchanged.
    """
    result = prepare_data.moniqi_data_prepare(model, path, type, scale, num)
    if result['status'] == False:  # noqa: E712 - keep the original equality test
        return {"success": False, "msg": result['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": {"data": result['reason']}}
2025-06-04 17:04:02 +08:00
# ======================== Coal pyrolysis: tar ========================
@app.post("/ai-station-api/mrj_models_tar/predict")
async def mrj_models_predict_tar(content: post_model.Meirejie):
    """Combined model analysis: run the tar model on a single sample."""
    feature_order = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    # Python field names cannot contain "/", so ratio fields arrive as H_C etc.
    ratio_columns = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns=ratio_columns)
        .reindex(columns=feature_order)
    )
    prediction = model_deal.pred_single_tar(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
@app.get("/ai-station-api/mrj_multi_tar/predict")
async def mrj_multi_tar_pred(model: str = None, path: str = None):
    """Batch tar prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_tar`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_tar(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# ======================== Coal pyrolysis: char ========================
@app.post("/ai-station-api/mrj_models_char/predict")
async def mrj_models_predict_char(content: post_model.Meirejie):
    """Combined model analysis: run the char model on a single sample."""
    feature_order = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    # Python field names cannot contain "/", so ratio fields arrive as H_C etc.
    ratio_columns = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns=ratio_columns)
        .reindex(columns=feature_order)
    )
    prediction = model_deal.pred_single_char(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
@app.get("/ai-station-api/mrj_multi_char/predict")
async def mrj_multi_char_pred(model: str = None, path: str = None):
    """Batch char prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_char`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_char(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# ======================== Coal pyrolysis: water ========================
@app.post("/ai-station-api/mrj_models_water/predict")
async def mrj_models_predict_water(content: post_model.Meirejie):
    """Combined model analysis: run the water model on a single sample."""
    feature_order = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    # Python field names cannot contain "/", so ratio fields arrive as H_C etc.
    ratio_columns = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns=ratio_columns)
        .reindex(columns=feature_order)
    )
    prediction = model_deal.pred_single_water(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
@app.get("/ai-station-api/mrj_multi_water/predict")
async def mrj_multi_water_pred(model: str = None, path: str = None):
    """Batch water prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_water`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_water(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# ======================== Coal pyrolysis: gas ========================
@app.post("/ai-station-api/mrj_models_gas/predict")
async def mrj_models_predict_gas(content: post_model.Meirejie):
    """Combined model analysis: run the gas model on a single sample."""
    feature_order = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    # Python field names cannot contain "/", so ratio fields arrive as H_C etc.
    ratio_columns = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    frame = (
        pd.DataFrame([content.model_dump()])
        .rename(columns=ratio_columns)
        .reindex(columns=feature_order)
    )
    prediction = model_deal.pred_single_gas(frame)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
批量预测接口
"""
@app.get("/ai-station-api/mrj_multi_gas/predict")
async def mrj_multi_gas_pred(model: str = None, path: str = None):
    """Batch gas prediction with ``model`` over the uploaded file at ``path``.

    ``model_deal.get_excel_gas`` returns a dict with ``status`` and ``reason``;
    ``reason`` holds the results on success and the error message otherwise.
    """
    outcome = model_deal.get_excel_gas(model, path)
    succeeded = outcome['status'] == True  # noqa: E712 - same test as before
    if not succeeded:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
2025-05-14 11:00:24 +08:00
# ======================== Landform recognition ========================
@app.post("/ai-station-api/image/upload")
async def upload_image(file: UploadFile = File(...), type: str = Form(...), ):
    """Save an uploaded image into ``tmp/<type>/<uuid>/`` and return its URL.

    ``type`` names the task directory (e.g. dimaoshibie, roof, pv, roofpv).
    Failures are reported as ``success: False`` dicts, not HTTP errors.
    """
    # Client-supplied names are untrusted. Strip directory parts from the
    # filename and reject a ``type`` that could escape ``tmp`` (e.g. "../x").
    filename = os.path.basename(file.filename or "")
    if not data_util.allowed_file(filename):
        return {"success": False,
                "msg": "图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
                "data": None}
    if not type or type in (".", "..") or os.path.basename(type) != type:
        return {"success": False, "msg": "type 参数不合法", "data": None}
    if file.size > param.MAX_FILE_SIZE:
        return {"success": False, "msg": "图片大小不能大于100MB", "data": None}
    upload_dir = os.path.join(current_dir, 'tmp', type, str(uuid.uuid4()))
    os.makedirs(upload_dir, exist_ok=True)
    file_location = os.path.join(upload_dir, filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    encoded_string = sam_deal.load_tmp_image(file_location)
    file_location = model_deal.get_pic_url(file_location)
    # NOTE(review): a Response object nested in a dict is likely serialized
    # attribute-by-attribute by FastAPI rather than as {"image_data": ...};
    # the front end presumably depends on that shape - confirm before changing.
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"location": file_location},
            "image": JSONResponse(content={"image_data": encoded_string})
            }
2025-05-14 11:00:24 +08:00
2025-06-04 17:04:02 +08:00
2025-05-14 11:00:24 +08:00
"""
图片地貌识别
"""
@app.get("/ai-station-api/dmsb_image/analysis")
async def dmsb_image_analysis(path: str = None):
    """Run landform segmentation on the image referenced by ``path``.

    ``path`` is presumably the URL returned by the upload endpoint;
    ``model_deal.get_pic_path`` maps it back to a local file. On success,
    ``reason`` is a list of output paths and the first one is returned.
    """
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.dimaoshibie_pic(dimaoshibie_SegFormer, local_path, param.dmsb_count, param.dmsb_name_classes)
    if outcome['status'] == False:  # noqa: E712 - same test as before
        return {"success": False, "msg": outcome['reason'], "data": None}
    first_output = outcome['reason'][0]
    encoded_string = sam_deal.load_tmp_image(first_output)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": model_deal.get_pic_url(first_output)},
            "image": JSONResponse(content={"image_data": encoded_string})
            }
2025-05-14 11:00:24 +08:00
"""
图片地貌 - 分割结果计算 ,图片demo的是z17
"""
@app.get("/ai-station-api/dmsb_image/calculate")
async def dmsb_image_calculate(path: str = None, scale: float = 1.92*1.92):
    """Turn per-class pixel counts from landform segmentation into areas.

    Reads ``result.txt`` next to the image at ``path``, whose lines look like
    ``<class>: <pixel count>``, and multiplies each count by ``scale``.
    ``scale`` is presumably the area per pixel; the default matches the z17
    demo image.
    """
    path = model_deal.get_pic_path(path)
    output_file_path = os.path.join(os.path.dirname(path), 'result.txt')
    # Report a missing result file instead of failing with HTTP 500.
    if not os.path.isfile(output_file_path):
        return {"success": False, "msg": "未找到分割结果文件", "data": None}
    total_piex = {}
    with open(output_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            # maxsplit=1 keeps any further ": " inside the value part.
            key, value = line.split(': ', 1)
            total_piex[key] = int(value) * scale
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": total_piex}}
2025-05-14 11:00:24 +08:00
"""
下载
"""
@app.get("/ai-station-api/download")
async def download_zip(path: str = None):
    """Zip the directory containing ``path`` and return it as ``download.zip``.

    Raises:
        HTTPException(500): building the archive failed.
    """
    import tempfile

    from fastapi import BackgroundTasks

    # NOTE(review): whatever directory get_pic_path resolves to is zipped in
    # full; confirm it cannot point outside the upload area.
    path = model_deal.get_pic_path(path)
    zip_filename = "download.zip"
    dir_path = os.path.dirname(path)
    # Build the archive in a unique temp file. The old fixed "download.zip" in
    # the CWD was shared by concurrent requests, which could clobber each other.
    fd, zip_filepath = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(zip_filepath, 'w') as zip_file:
            for foldername, _subfolders, filenames in os.walk(dir_path):
                for filename in filenames:
                    file_path = os.path.join(foldername, filename)
                    zip_file.write(file_path, os.path.relpath(file_path, dir_path))
    except Exception as e:
        os.remove(zip_filepath)
        raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}") from e
    # Delete the temporary archive after the response has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, zip_filepath)
    return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename, background=cleanup)
# ======================== Roof recognition ========================
# Images are uploaded through /ai-station-api/image/upload with type = roof.
@app.get("/ai-station-api/roof_image/analysis")
async def roof_image_analysis(path: str = None):
    """Run roof segmentation on the image referenced by ``path``.

    On success ``reason`` is a list of output paths; the first one is returned
    as a URL together with its encoded image data.
    """
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roof_pic(roof_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:  # noqa: E712 - same test as before
        return {"success": False, "msg": outcome['reason'], "data": None}
    first_output = outcome['reason'][0]
    encoded_string = sam_deal.load_tmp_image(first_output)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": model_deal.get_pic_url(first_output)},
            "image": JSONResponse(content={"image_data": encoded_string})
            }
2025-05-14 11:00:24 +08:00
"""
屋顶面积 - 分割结果计算 ,图片demo的是z18
"""
@app.get("/ai-station-api/roof_image/calculate")
async def roof_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute roof areas from the segmentation result (default scale: z18 demo image).

    Class keys are translated through ``param.wd_type`` and the "其他" (other)
    class is dropped from the response.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:  # noqa: E712 - keep the original equality test
        translated_dict = {param.wd_type[key]: value for key, value in result['reason'].items()}
        # pop() instead of del: a result without the "other" class must not
        # raise KeyError (HTTP 500).
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ======================== PV recognition ========================
# Images are uploaded through /ai-station-api/image/upload with type = pv.
@app.get("/ai-station-api/pv_image/analysis")
async def pv_image_analysis(path: str = None):
    """Run PV segmentation on the image referenced by ``path``.

    On success ``reason`` is a list of output paths; the first one is returned
    as a URL together with its encoded image data.
    """
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roof_pic(pv_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:  # noqa: E712 - same test as before
        return {"success": False, "msg": outcome['reason'], "data": None}
    first_output = outcome['reason'][0]
    encoded_string = sam_deal.load_tmp_image(first_output)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": model_deal.get_pic_url(first_output)},
            "image": JSONResponse(content={"image_data": encoded_string})}
2025-05-14 11:00:24 +08:00
"""
光伏面积 - 分割结果计算 ,图片demo的是z18
"""
@app.get("/ai-station-api/pv_image/calculate")
async def pv_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute PV areas from the segmentation result (default scale: z18 demo image).

    Class keys are translated through ``param.pv_type`` and the "其他" (other)
    class is dropped from the response.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:  # noqa: E712 - keep the original equality test
        translated_dict = {param.pv_type[key]: value for key, value in result['reason'].items()}
        # pop() instead of del: a result without the "other" class must not
        # raise KeyError (HTTP 500).
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ======================== Roof + PV recognition ========================
# Images are uploaded through /ai-station-api/image/upload with type = roofpv.
@app.get("/ai-station-api/roofpv_image/analysis")
async def roofpv_image_analysis(path: str = None):
    """Run roof and PV segmentation, merge the two outputs and return the result.

    The output files listed in ``reason`` are merged by
    ``prepare_data.merge_binary``; the merged image is returned as a URL plus
    encoded image data.
    """
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roofpv_pic(roof_model, pv_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:  # noqa: E712 - same test as before
        return {"success": False, "msg": outcome['reason'], "data": None}
    merged_path = prepare_data.merge_binary(outcome['reason'])
    encoded_string = sam_deal.load_tmp_image(merged_path)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": model_deal.get_pic_url(merged_path)},
            "image": JSONResponse(content={"image_data": encoded_string})
            }
2025-05-14 11:00:24 +08:00
"""
屋顶光伏面积 - 分割结果计算 ,图片demo的是z18
"""
@app.get("/ai-station-api/roofpv_image/calculate")
async def roofpv_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute roof and PV areas from the merged segmentation (default scale: z18 demo image).

    Class keys are translated through ``param.wdpv_type`` and the "其他"
    (other) class is dropped from the response.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area_roofpv(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:  # noqa: E712 - keep the original equality test
        translated_dict = {param.wdpv_type[key]: value for key, value in result['reason'].items()}
        # pop() instead of del: a result without the "other" class must not
        # raise KeyError (HTTP 500).
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ======================== Time-series prediction ========================
# ======================== Methane (CH4) prediction ========================
@app.get("/ai-station-api/ch4_features/show")
async def get_ch4_features_info():
    """Return the CH4 feature list (names, column, type, unit, sample value)."""
    query = "SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM ch4_features;"
    rows = data_util.fetch_data(query)
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
@app.get("/ai-station-api/ch4_pic_table/show")
async def get_ch4_pic_table_info():
    """Return the chartable CH4 columns (every feature except ``date_time``)."""
    query = "SELECT chinese_name, col_name FROM ch4_features where col_name != 'date_time';"
    rows = data_util.fetch_data(query)
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
针对列表返回数据供画图
"""
@app.get("/ai-station-api/ch4_feature_pic/select")
async def get_ch4_feature_pic_select(type: str = None, path: str = None):
    """Return the data for column ``type`` of the uploaded CH4 file, for charting.

    The ``date_time`` column itself cannot be charted and is rejected.
    """
    if type == 'date_time':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_jiawanyuce(type, path)
    if series is not None:
        return {"success": True, "msg": "获取信息成功", "data": series}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
上传数据分页展示
"""
@app.get("/ai-station-api/ch4_data/show")
async def get_ch4_data(path: str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the uploaded CH4 data at ``path``.

    Raises:
        HTTPException(404): ``page`` is beyond the last page.
    """
    data = prepare_data.get_jiawanyuce_data(path)
    # len(None) would raise; report a failed load like the other endpoints do.
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # An empty dataset has zero pages; page 1 then returns an empty list
    # instead of a misleading 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    items = data.iloc[start_idx:start_idx + page_size].to_dict(orient="records")
    return {
        "success": True,
        "msg": "获取信息成功",
        "data": {
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items,
        },
    }
"""
csv上传
"""
# @app.post("/ai-station-api/document/upload") type = "ch4"
"""
预测, type
"""
@app.get("/ai-station-api/ch4_data/predict")
async def get_ch4_predict(path: str = None, start_time: str = None, end_time: str = None, type: int = 1, is_show: bool = True):
    """Run CH4 prediction on the uploaded file at ``path``.

    The flow and gas-concentration models are passed to
    ``model_deal.start_predict_endpoint`` together with every argument,
    unchanged. ``end_time`` is echoed back as ``time`` on success.
    """
    outcome = model_deal.start_predict_endpoint(ch4_model_flow, ch4_model_gas, path, start_time, end_time, type, is_show)
    if outcome['status'] == False:  # noqa: E712 - same test as before
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True,
            "msg": "获取信息成功",
            "data": outcome['reason'],
            'time': end_time}
2025-05-14 11:00:24 +08:00
# ======================== PV power prediction ========================
@app.get("/ai-station-api/pvelectric_features/show")
async def get_pvelectric_features_info():
    """Return the PV-power feature list (names, column, type, unit, sample value)."""
    query = "SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM pv_electric_features;"
    rows = data_util.fetch_data(query)
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
@app.get("/ai-station-api/pvelectric_pic_table/show")
async def get_pvelectric_pic_table_info():
    """Return the chartable PV-power columns (every feature except ``date_time``)."""
    query = "SELECT chinese_name, col_name FROM pv_electric_features where col_name != 'date_time';"
    rows = data_util.fetch_data(query)
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
针对列表返回数据供画图
"""
@app.get("/ai-station-api/pvelectric_feature_pic/select")
async def get_pvelectric_feature_pic_select(type:str = None, path:str =None):
    """Return the values of column ``type`` from the PV data at ``path`` for plotting.

    The time column ``date_time`` cannot be plotted and is rejected up front.
    """
    if type == 'date_time':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_pvfd(path, type)
    ok = series is not None
    return {
        "success": ok,
        "msg": "获取信息成功" if ok else "获取信息列表失败",
        "data": series,
    }
"""
上传数据分页展示
"""
@app.get("/ai-station-api/pvelectric_data/show")
async def get_pvelectric_data(path:str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the uploaded PV-generation dataset.

    Args:
        path: location of the uploaded data; passed to
            ``prepare_data.show_testdata_pvfd``.
        page: 1-based page number (FastAPI enforces ``>= 1``).
        page_size: rows per page (FastAPI enforces ``>= 1``).

    Returns:
        ``{"success", "msg", "data"}``. On success ``data`` holds the
        pagination metadata and the page's rows as records. If the loader
        returns ``None`` the usual failure envelope is returned.

    Raises:
        HTTPException: 404 when ``page`` is past the last page. An empty
            dataset still answers page 1 with an empty ``items`` list.
    """
    data = prepare_data.show_testdata_pvfd(path)
    if data is None:
        return {
            "success":False,
            "msg":"获取信息列表失败",
            "data":None
        }
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # max(..., 1): with zero rows total_pages is 0, which would turn page 1 into a 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    items = data.iloc[start_idx:end_idx].to_dict(orient="records")
    return {
        "success":True,
        "msg":"获取信息成功",
        "data":{
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items
        }
    }
"""
csv上传
"""
# @app.post("/ai-station-api/document/upload") type = "pvelectric"
"""
预测, type
"""
@app.get("/ai-station-api/pvelectric_data/predict")
async def get_pvelectri_predict(path:str=None,start_time:str=None, end_time:str = None,is_show:bool=True):
    """Run the PV-generation model on the uploaded data at ``path``.

    Parameters are forwarded to ``model_deal.start_pvelectric_predict_endpoint``.
    Its ``reason`` field becomes ``msg`` on failure. On success it is returned
    as ``data``, with ``end_time`` echoed back as ``time``.
    """
    outcome = model_deal.start_pvelectric_predict_endpoint(
        pvfd_model, path, start_time, end_time, is_show
    )
    failed = outcome['status'] == False  # noqa: E712 - equality kept deliberately
    if failed:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {
        "success": True,
        "msg": "获取信息成功",
        'time': end_time,
        "data": outcome['reason'],
    }
#========================================风力发电预测==========================================================
"""
返回显示列表
"""
@app.get("/ai-station-api/wind_electric_features/show")
async def get_wind_electric_features_info():
    """List the wind-generation feature metadata rows from ``wind_electric_features``."""
    rows = data_util.fetch_data(
        "SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM wind_electric_features;"
    )
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
@app.get("/ai-station-api/wind_electric_pic_table/show")
async def get_wind_electric_pic_table_info():
    """List the plottable wind feature columns, i.e. every column except ``date``."""
    rows = data_util.fetch_data(
        "SELECT chinese_name, col_name FROM wind_electric_features where col_name != 'date';"
    )
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
"""
针对列表返回数据供画图
"""
@app.get("/ai-station-api/wind_electric_feature_pic/select")
async def get_wind_electric_feature_pic_select(type:str = None, path:str =None):
    """Return the values of column ``type`` from the wind data at ``path`` for plotting.

    The time column ``date`` cannot be plotted and is rejected up front.
    """
    if type == 'date':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_windfd(path, type)
    ok = series is not None
    return {
        "success": ok,
        "msg": "获取信息成功" if ok else "获取信息列表失败",
        "data": series,
    }
"""
上传数据分页展示
"""
@app.get("/ai-station-api/wind_electric_data/show")
async def get_wind_electric_data(path:str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the uploaded wind-generation dataset.

    Args:
        path: location of the uploaded data; passed to
            ``prepare_data.show_testdata_windfd``.
        page: 1-based page number (FastAPI enforces ``>= 1``).
        page_size: rows per page (FastAPI enforces ``>= 1``).

    Returns:
        ``{"success", "msg", "data"}``. On success ``data`` holds the
        pagination metadata and the page's rows as records. If the loader
        returns ``None`` the usual failure envelope is returned.

    Raises:
        HTTPException: 404 when ``page`` is past the last page. An empty
            dataset still answers page 1 with an empty ``items`` list.
    """
    data = prepare_data.show_testdata_windfd(path)
    if data is None:
        return {
            "success":False,
            "msg":"获取信息列表失败",
            "data":None
        }
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # max(..., 1): with zero rows total_pages is 0, which would turn page 1 into a 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    items = data.iloc[start_idx:end_idx].to_dict(orient="records")
    return {
        "success":True,
        "msg":"获取信息成功",
        "data":{
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items
        }
    }
"""
csv上传
"""
# @app.post("/ai-station-api/document/upload") type = "wind_electric"
2025-06-04 17:04:02 +08:00
2025-05-14 11:00:24 +08:00
"""
预测, type
"""
@app.get("/ai-station-api/wind_electric_data/predict")
async def get_wind_electri_predict(path:str=None,start_time:str=None, end_time:str = None,is_show:bool=True):
    """Run the wind-generation model on the uploaded data at ``path``.

    Parameters are forwarded to ``model_deal.start_wind_electric_predict_endpoint``.
    Its ``reason`` field becomes ``msg`` on failure. On success it is returned
    as ``data``, with ``end_time`` echoed back as ``time``.
    """
    outcome = model_deal.start_wind_electric_predict_endpoint(
        windfd_model, path, start_time, end_time, is_show
    )
    failed = outcome['status'] == False  # noqa: E712 - equality kept deliberately
    if failed:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {
        "success": True,
        "msg": "获取信息成功",
        'time': end_time,
        "data": outcome['reason'],
    }
2025-06-04 17:04:02 +08:00
#======================== SAM =============================================================
"""
文件上传
1 创建会话目录及其 input、output、temp 子目录
2 将图片保存到 input 目录
3 加载当前图像并生成标注配置文件 model_params.pickle
4 返回编码后的当前图像及会话目录路径(file_path)给前端
"""
@app.post("/ai-station-api/sam-image/upload")
async def upload_sam_image(file: UploadFile = File(...),type: str = Form(...), ):
    """Start a SAM annotation session for one uploaded image.

    Steps:
      1. Validate the upload's extension and size.
      2. Create ``tmp/<type>/<uuid>/`` with ``input``, ``output`` and ``temp``
         sub-directories.
      3. Save the image into ``input`` and build a fresh config pair from
         deep copies of ``sam_config`` / ``sam_api_config``.
      4. Pickle the pair to ``model_params.pickle`` in the session directory.
      5. Write an RGB copy to ``temp/output_image.jpg`` and return it encoded,
         together with the session directory as ``file_path``.

    Returns a ``{"success": False, ...}`` envelope for a rejected filename,
    an oversized file, an unsafe ``type``, or an image OpenCV cannot read.
    """
    if not data_util.allowed_file(file.filename):
        return {
            "success":False,
            "msg":"图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
            "data":None
        }
    # UploadFile.size can be None when the client sent no length; skip the check then.
    # NOTE(review): `param` is not among the visible header imports (only `params` is);
    # confirm it is defined earlier in this module.
    if file.size is not None and file.size > param.MAX_FILE_SAM_SIZE:
        return {
            "success":False,
            "msg":"图片大小不能大于10MB",
            "data":None
        }
    # `type` and `file.filename` are client-controlled. Keep only the last path
    # component so neither can walk out of the tmp tree via "../".
    session_type = os.path.basename(type)
    if session_type in ("", ".", ".."):
        return {
            "success":False,
            "msg":"type 参数不合法",
            "data":None
        }
    filename = os.path.basename(file.filename)
    upload_dir = os.path.join(current_dir,'tmp',session_type, str(uuid.uuid4()))
    input_dir = os.path.join(upload_dir,'input')
    output_dir = os.path.join(upload_dir,'output')
    temp_dir = os.path.join(upload_dir,'temp')
    for sub_dir in (input_dir, output_dir, temp_dir):
        os.makedirs(sub_dir, exist_ok=True)
    file_location = os.path.join(input_dir, filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    # Fresh per-session config so sessions never share mutable state.
    config = copy.deepcopy(sam_config)
    api_config = copy.deepcopy(sam_api_config)
    config['input_dir'] = input_dir
    config['output_dir'] = output_dir
    # '.tif' is advertised as accepted above, so it must pass this filter too;
    # otherwise image_files is empty and indexing it raises IndexError.
    config['image_files'] = [
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))
    ]
    if not config['image_files']:
        return {
            "success":False,
            "msg":"图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
            "data":None
        }
    config['current_index'] = 0
    config['filename'] = config['image_files'][config['current_index']]
    image_path = os.path.join(config['input_dir'], config['filename'])
    config['image'] = cv2.imread(image_path)
    if config['image'] is None:
        # cv2.imread returns None instead of raising on unreadable files.
        return {
            "success":False,
            "msg":"图片读取失败",
            "data":None
        }
    config['image_rgb'] = cv2.cvtColor(config['image'].copy(), cv2.COLOR_BGR2RGB)
    api_config['output_dir'] = output_dir
    # Keep the class list but clear any masks/points in both configs.
    config, api_config = sam_deal.reset_class_annotations(config,api_config)
    config = sam_deal.reset_annotation(config)
    config_path = os.path.join(upload_dir,'model_params.pickle')
    with open(config_path, 'wb') as config_file:
        pickle.dump((config, api_config), config_file)
    pil_img = Image.fromarray(config['image_rgb'])
    tmp_path = os.path.join(temp_dir,'output_image.jpg')
    pil_img.save(tmp_path)
    encoded_string = sam_deal.load_tmp_image(tmp_path)
    return {"success":True,
            "msg":"获取信息成功",
            "image": JSONResponse(content={"image_data": encoded_string}),
            'file_path': upload_dir}
"""
添加分类
添加分类并将分类设置为最新的添加的分类;
要求针对返回current_index,将列表默认成选择current_index对应的分类
"""
2025-06-13 16:18:58 +08:00
@app.post("/ai-station-api/sam_class/create")
async def sam_class_set(item:post_model.samItem):
    """Add a new annotation class and make it the current class.

    ``item`` provides ``class_name``, ``color`` (three components unpacked as
    r, g, b) and ``path`` (the session directory). The class is appended via
    ``sam_deal.add_class`` and then selected with its colour passed in BGR
    order, after which the session is saved. The cached
    ``temp/output_image.jpg`` is returned unchanged, alongside the class list.
    """
    class_name, color, path = item.class_name, item.color, item.path
    loaded_data, api_config = sam_deal.load_model(path)
    added = sam_deal.add_class(loaded_data, class_name, color)
    if added['status'] != True:
        return {"success": False, "msg": added['reason'], "data": None}
    loaded_data = added['reason']
    new_index = loaded_data['class_names'].index(class_name)
    loaded_data['class_index'] = new_index
    red, green, blue = [int(channel) for channel in color]
    _, api_config = sam_deal.set_current_class(
        loaded_data, api_config, new_index, color=(blue, green, red)
    )
    sam_deal.save_model(loaded_data, api_config, path)
    encoded_string = sam_deal.load_tmp_image(os.path.join(path, 'temp/output_image.jpg'))
    return {
        "success": True,
        "msg": f"已添加类别: {class_name}, 颜色: {color}",
        "image": JSONResponse(content={"image_data": encoded_string}),
        "data": {
            "class_name_list": loaded_data['class_names'],
            "current_index": loaded_data['class_index'],
            "class_dict": loaded_data['class_colors'],
            "color": color,
        },
    }
"""
选择颜色
current_index : 下拉列表中的分类索引
rgb_color : 新颜色,按 (r, g, b) 三个分量给出
path : 会话目录(上传接口返回的 file_path)
"""
2025-06-13 16:18:58 +08:00
@app.post("/ai-station-api/sam_color/select")
async def set_sam_color(item:post_model.samItem2):
    """Recolour class ``current_index``, make it the current class, and redraw.

    ``item`` provides ``current_index``, ``rgb_color`` (three components
    unpacked as r, g, b) and ``path`` (the session directory). The colour is
    passed to ``sam_deal`` in BGR order. The updated session is saved before
    the preview image is re-rendered.
    """
    current_index = item.current_index
    rgb_color = item.rgb_color
    path = item.path
    loaded_data,api_config = sam_deal.load_model(path)
    r, g, b = [int(c) for c in rgb_color]
    bgr_color = (b, g, r)
    data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color)
    # Save and render the api config *returned* by set_current_class, as the
    # other class endpoints do. Saving the pre-call `api` would drop the
    # current-class update whenever set_current_class returns a new object.
    _, api = sam_deal.set_current_class(data, api, current_index, color=bgr_color)
    sam_deal.save_model(data,api,path)
    img = sam_deal.refresh_image(data,api,path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success":True,
            "msg":"",
            "color":rgb_color,
            "image": JSONResponse(content={"image_data": encoded_string}),
        }
    return {
        "success":False,
        "msg":img['reason'],
        "color":rgb_color,
        "image":None
    }
"""
选择分类
"""
2025-06-13 16:18:58 +08:00
@app.get("/ai-station-api/sam_class/change")
async def on_class_selected(class_index : int = None,path: str = None):
    """Make ``class_index`` the current annotation class and redraw the preview.

    Nothing is persisted when ``sam_deal.set_current_class`` rejects the index.
    An invalid ``class_index`` is therefore never written into the saved
    session, and the previous valid selection stays in place.
    """
    loaded_data,api_config = sam_deal.load_model(path)
    result, api_config = sam_deal.set_current_class(loaded_data, api_config, class_index, color=None)
    if not result:
        return {
            "success":False,
            "msg":"分类标签识别错误",
            "image":None
        }
    loaded_data['class_index']=class_index
    sam_deal.save_model(loaded_data,api_config,path)
    img = sam_deal.refresh_image(loaded_data,api_config,path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success":True,
            "msg":"更改成功",
            "image":JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success":False,
        "msg":img['reason'],
        "image":None
    }
"""
删除分类, 前端跳转为normal index=0
"""
@app.get("/ai-station-api/sam_class/delete")
async def sam_remove_class(path:str=None,select_index:int=None):
    """Delete the class at ``select_index`` and return the refreshed class list.

    ``select_index`` of -1 (or a missing parameter) means "nothing selected".
    Any other negative value is also rejected, because Python would silently
    index from the end and delete the wrong class. Indices past the end of the
    class list are rejected as well.
    """
    loaded_data,api_config = sam_deal.load_model(path)
    if select_index is None or select_index < 0:
        return {
            "success":False,
            "msg":"没有选定分类,请选定当前分类后再删除",
            "image":None
        }
    if select_index >= len(loaded_data['class_names']):
        return {
            "success":False,
            "msg":"分类标签识别错误",
            "image":None
        }
    class_name = loaded_data['class_names'][select_index]
    loaded_data,api_config = sam_deal.remove_class(loaded_data,api_config,class_name)
    sam_deal.save_model(loaded_data,api_config,path)
    img = sam_deal.refresh_image(loaded_data,api_config,path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success":True,
            "msg":"删除成功",
            "class_name_list": loaded_data['class_names'],
            "current_index": loaded_data['class_index'],
            "class_dict":loaded_data['class_colors'],
            "image":JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success":False,
        "msg":img['reason'],
        "image":None
    }
"""
添加标注点 -左键
"""
@app.get("/ai-station-api/sam_point/left_add")
async def left_mouse_down(x:int=None,y:int=None,path:str=None):
    """Add a foreground prompt point at (x, y) for the current class and redraw.

    Refused when no class is selected. Otherwise the point is recorded via
    ``sam_deal.add_annotation_point``, the session is saved, and the preview is
    re-rendered.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    if not api_config['current_class']:
        return {"success": False, "msg": "请先选择一个分类,在添加标点之前", "image": None}
    added = sam_deal.add_annotation_point(api_config, x, y, True)  # True -> foreground
    if added['status'] == False:
        return {"success": False, "msg": added['reason'], "image": None}
    api_config = added['api']
    sam_deal.save_model(loaded_data, api_config, path)
    rendered = sam_deal.refresh_image(loaded_data, api_config, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "image": None}
    encoded_string = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
"""
添加标注点 -右键
"""
@app.get("/ai-station-api/sam_point/right_add")
async def right_mouse_down(x:int=None,y:int=None,path:str=None):
    """Add a background prompt point at (x, y) for the current class and redraw.

    Refused when no class is selected. Otherwise the point is recorded via
    ``sam_deal.add_annotation_point``, the session is saved, and the preview is
    re-rendered.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    if not api_config['current_class']:
        return {"success": False, "msg": "请先选择一个分类,在添加标点之前", "image": None}
    added = sam_deal.add_annotation_point(api_config, x, y, False)  # False -> background
    if added['status'] == False:
        return {"success": False, "msg": added['reason'], "image": None}
    api_config = added['api']
    sam_deal.save_model(loaded_data, api_config, path)
    rendered = sam_deal.refresh_image(loaded_data, api_config, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "image": None}
    encoded_string = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
"""
删除上一个点
"""
@app.get("/ai-station-api/sam_point/delete_last")
async def sam_delete_last_point(path:str=None):
    """Remove the most recently added prompt point, save, and redraw.

    Failure responses carry ``data: None`` rather than ``image: None``.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    removal = sam_deal.delete_last_point(api_config)
    if removal['status'] != True:
        return {"success": False, "msg": removal['reason'], "data": None}
    api_config = removal['reason']  # on success `reason` holds the updated api config
    sam_deal.save_model(loaded_data, api_config, path)
    rendered = sam_deal.refresh_image(loaded_data, api_config, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "data": None}
    encoded_string = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
"""
删除所有点
"""
@app.get("/ai-station-api/sam_point/delete_all")
async def sam_clear_all_point(path:str=None):
    """Clear every prompt point of the current class, save, and redraw."""
    loaded_data, api_config = sam_deal.load_model(path)
    cleared = sam_deal.reset_current_class_points(api_config)
    if cleared['status'] != True:
        return {"success": False, "msg": cleared['reason'], "image": None}
    api_config = cleared['reason']  # on success `reason` holds the updated api config
    sam_deal.save_model(loaded_data, api_config, path)
    rendered = sam_deal.refresh_image(loaded_data, api_config, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "image": None}
    encoded_string = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
"""
模型预测
"""
@app.get("/ai-station-api/sam_model/predict")
async def sam_predict_mask(path:str=None):
    """Run SAM on the current class's prompt points and store the predicted masks.

    The annotation state in ``loaded_data`` is reset and rebuilt from the
    current class's ``points`` / ``point_types``. ``sam_predictor`` is then
    run via ``sam_deal.predict_mask``. The masks (as uint8 arrays), scores and
    selected index are written back into the class entry, the session is
    saved, and the preview is redrawn. Any exception from prediction onward is
    logged and returned as a failure response.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    class_data = api_config['class_annotations'].get(api_config['current_class'], {})
    prompt_points = class_data.get('points')
    if not prompt_points:
        return {"success": False, "msg": "请在预测前添加至少一个预测样本点", "data": None}
    loaded_data = sam_deal.reset_annotation(loaded_data)
    # Replay the stored prompts; point_types[idx] is the foreground flag of point idx.
    for idx, (px, py) in enumerate(prompt_points):
        loaded_data = sam_deal.add_point(loaded_data, px, py, is_foreground=class_data['point_types'][idx])
    try:
        prediction = sam_deal.predict_mask(loaded_data, sam_predictor)
        if prediction['status'] == False:
            return {"success": False, "msg": prediction['reason'], "data": None}
        outcome = prediction['reason']
        loaded_data = outcome['data']
        class_data['masks'] = [np.array(m, dtype=np.uint8) for m in outcome['masks']]
        class_data['scores'] = outcome['scores']
        class_data['selected_mask_index'] = outcome['selected_index']
        # NOTE(review): with a negative selected_index, any 'selected_mask' left
        # over from an earlier prediction is kept; confirm downstream code
        # relies only on selected_mask_index in that case.
        if outcome['selected_index'] >= 0:
            class_data['selected_mask'] = class_data['masks'][outcome['selected_index']]
        logger.info(f"predict: Predicted {len(outcome['masks'])} masks, selected index: {outcome['selected_index']}")
        sam_deal.save_model(loaded_data, api_config, path)
        rendered = sam_deal.refresh_image(loaded_data, api_config, path)
        if rendered['status'] != True:
            return {"success": False, "msg": rendered['reason'], "image": None}
        encoded_string = sam_deal.load_tmp_image(rendered['reason'])
        return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
    except Exception as e:
        logger.error(f"predict: Error during prediction: {str(e)}")
        traceback.print_exc()
        return {"success": False, "msg": f"predict: Error during prediction: {str(e)}", "data": None}
"""
清除所有信息
"""
@app.get("/ai-station-api/sam_model/clear_all")
async def sam_reset_annotation(path:str=None):
    """Reset all annotation state via ``sam_deal.reset_annotation_all``, save, and redraw."""
    loaded_data, api_config = sam_deal.reset_annotation_all(*sam_deal.load_model(path))
    sam_deal.save_model(loaded_data, api_config, path)
    rendered = sam_deal.refresh_image(loaded_data, api_config, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "image": None}
    encoded_string = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded_string})}
"""
保存预测分类
"""
@app.get("/ai-station-api/sam_model/save_stage")
async def sam_add_to_class(path:str=None,class_index:int=None):
    """Commit the current prediction to class ``class_index``, save, and redraw.

    ``class_index`` must be a valid position in ``class_names``. A missing
    value would raise TypeError, and a negative one would silently pick a
    class from the end of the list, so both are rejected up front.
    """
    loaded_data,api_config = sam_deal.load_model(path)
    class_names = loaded_data['class_names']
    if class_index is None or not 0 <= class_index < len(class_names):
        return {
            "success":False,
            "msg":"分类标签识别错误",
            "image":None
        }
    class_name = class_names[class_index]
    result = sam_deal.add_to_class(api_config,class_name)
    if result['status'] != True:
        return {
            "success":False,
            "msg":result['reason'],
            "image":None
        }
    api_config = result['reason']  # on success `reason` holds the updated api config
    sam_deal.save_model(loaded_data,api_config,path)
    img = sam_deal.refresh_image(loaded_data,api_config,path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success":True,
            "msg":"",
            "image":JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success":False,
        "msg":img['reason'],
        "image":None
    }
"""
保存结果
"""
def _sam_mask_to_labelme_shapes(class_name, mask):
    """Turn one class mask into labelme polygon shapes, one per external contour.

    Contours of ``mask > 0`` are simplified with ``cv2.approxPolyDP`` using
    epsilon = 0.01% of the contour perimeter. Contours that end up with fewer
    than three points are dropped, since they are not valid polygons.
    """
    binary_mask = (mask > 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    shapes = []
    for contour in contours:
        epsilon = 0.0001 * cv2.arcLength(contour, True)
        approx_contour = cv2.approxPolyDP(contour, epsilon, True)
        points = [[float(pt[0][0]), float(pt[0][1])] for pt in approx_contour]
        if len(points) >= 3:
            shapes.append({
                "label": class_name,
                "points": points,
                "group_id": None,
                "shape_type": "polygon",
                "flags": {}
            })
    return shapes


def _sam_zip_directory(src_dir, zip_path):
    """Write every file under ``src_dir`` into ``zip_path``, with paths relative to ``src_dir``."""
    with zipfile.ZipFile(zip_path, 'w') as zip_file:
        for foldername, _subfolders, filenames in os.walk(src_dir):
            for filename in filenames:
                file_path = os.path.join(foldername, filename)
                zip_file.write(file_path, os.path.relpath(file_path, src_dir))


@app.get("/ai-station-api/sam_model/save")
async def sam_save_annotation(path:str=None):
    """Export the session's final masks and return the output folder as a ZIP.

    Into ``<output_dir>/<image basename>/`` this writes:
      * ``<name>.jpg``      - the original image;
      * ``<name>_mask.jpg`` - the image with each class mask blended in its colour;
      * ``<name>.json``     - a labelme (version "5.1.1") file with one polygon per
        mask contour and the image embedded as base64 ``imageData``.
    The whole ``output_dir`` is then zipped and returned as ``download.zip``.

    Returns a ``{"success": False, ...}`` envelope when the output directory is
    unset, no class has a final mask, or no image info is available.

    Raises:
        HTTPException: 500 if the archive cannot be created.
    """
    loaded_data,api_config = sam_deal.load_model(path)
    if not api_config['output_dir']:
        logger.info("save_annotation: Output directory not set")
        return {
            "success":False,
            "msg":"save_annotation: Output directory not set",
            "data":None
        }
    annotations = api_config['class_annotations']
    has_annotations = any(
        'final_mask' in class_data and class_data['final_mask'] is not None
        for class_data in annotations.values()
    )
    if not has_annotations:
        logger.info("save_annotation: No final masks to save")
        return {
            "success":False,
            "msg":"save_annotation: No final masks to save",
            "data":None
        }
    image_info = sam_deal.get_image_info(loaded_data)
    if not image_info:
        logger.info("save_annotation: No image info available")
        return {
            "success":False,
            "msg":"save_annotation: No image info available",
            "data":None
        }
    image_basename = os.path.splitext(image_info['filename'])[0]
    annotation_dir = os.path.join(api_config['output_dir'], image_basename)
    os.makedirs(annotation_dir, exist_ok=True)
    orig_img = loaded_data['image']
    cv2.imwrite(os.path.join(annotation_dir, f"{image_basename}.jpg"), orig_img)
    vis_img = orig_img.copy()
    img_height, img_width = orig_img.shape[:2]
    labelme_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": [],
        "imagePath": f"{image_basename}.jpg",
        "imageData": None,
        "imageHeight": img_height,
        "imageWidth": img_width
    }
    for class_name, class_data in annotations.items():
        if 'final_mask' not in class_data or class_data['final_mask'] is None:
            continue
        final_mask = class_data['final_mask']
        color = api_config['class_colors'].get(class_name, (0, 255, 0))
        color_mask = np.zeros_like(vis_img)
        color_mask[final_mask > 0] = color
        vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0)
        labelme_data["shapes"].extend(_sam_mask_to_labelme_shapes(class_name, final_mask))
    cv2.imwrite(os.path.join(annotation_dir, f"{image_basename}_mask.jpg"), vis_img)
    try:
        is_success, buffer = cv2.imencode(".jpg", orig_img)
        if is_success:
            labelme_data["imageData"] = base64.b64encode(buffer.tobytes()).decode('utf-8')
        else:
            logger.error("save_annotation: Failed to encode image data")
            labelme_data["imageData"] = ""
    except Exception as e:
        logger.error(f"save_annotation: Could not encode image data: {str(e)}")
        labelme_data["imageData"] = ""
    json_path = os.path.join(annotation_dir, f"{image_basename}.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(labelme_data, f, indent=2)
    logger.info(f"save_annotation: Annotation saved to {annotation_dir}")
    zip_filename = "download.zip"
    dir_path = os.path.dirname(annotation_dir)
    # Build the archive in the parent of the walked directory (the session
    # directory for sessions created by upload_sam_image), not in the process
    # CWD. A shared ./download.zip is overwritten by concurrent saves while
    # FileResponse is still streaming it, and a file inside dir_path would be
    # zipped into itself.
    zip_filepath = os.path.join(os.path.dirname(os.path.abspath(dir_path)), zip_filename)
    try:
        _sam_zip_directory(dir_path, zip_filepath)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}") from e
    return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename)
2025-05-14 11:00:24 +08:00
2025-05-06 11:18:48 +08:00
2025-05-14 11:00:24 +08:00
# @app.post("/ai-station-api/items/")
2025-05-06 11:18:48 +08:00
# async def create_item(item: post_model.Zongbiaomianji):
# try:
# data = {
# "type": type(item).__name__, # 返回类型名称
# "received_data": item.model_dump() # 使用 model_dump() 方法
# }
# print(pd.DataFrame([item.model_dump()]))
# # 返回接收到的数据
# return data
# except Exception as e:
# # 记录错误信息
# logger.error(f"Error occurred: {e}")