ai-station-code/run.py

1815 lines
63 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import io
from fastapi import FastAPI,File, UploadFile,Form,Query
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse,JSONResponse
import sys
import os
import shutil
from pydantic import BaseModel, validator
from typing import List, Optional
import asyncio
import pandas as pd
import numpy as np
from PIL import Image
import pickle
import cv2
import copy
import base64
# Locate the directory holding this script and append it to the import path,
# so the project packages imported below resolve regardless of the working
# directory the server was launched from.
print("Current working directory:", os.getcwd())
_script_path = os.path.abspath(__file__)
current_dir = os.path.dirname(_script_path)
sys.path.append(current_dir)
print("Current sys.path:", sys.path)
import torch
from dimaoshibie import segformer
from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_resUnetpamcarb
from wudingpv.predictandeval_util import segmentation
from guangfufadian import model_base as guangfufadian_model_base
from fenglifadian import model_base as fenglifadian_model_base
from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal
from work_util.logger import logger
import joblib
import mysql.connector
import uuid
import json
import zipfile
from pathlib import Path
from segment_anything_model import sam_annotator
from segment_anything_model.sam_config import sam_config, sam_api_config
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
import traceback
# "major.minor" interpreter version, reported by the root endpoint.
version = "{}.{}".format(*sys.version_info[:2])
# ASGI application that every route below is registered on.
app = FastAPI()
# Shared configuration object (upload size limits, class names, palettes, ...).
param = params.ModelParams()
def _load_resunet(weights_relpath, log_message):
    """Build a ``resUnetpamcarb`` network and load the weights at ``weights_relpath``.

    ``weights_relpath`` is resolved against ``current_dir``. The checkpoint is
    deserialised onto the CPU first (``map_location``) and its ``'net'`` entry
    is loaded. The model is then put in eval mode and moved to CUDA.

    Args:
        weights_relpath: checkpoint path relative to this script's directory.
        log_message: message logged once the weights are loaded.

    Returns:
        The ready-to-use model.
    """
    model = roof_resUnetpamcarb()
    weights_path = os.path.join(current_dir, weights_relpath)
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['net'])
    logger.info(log_message)
    model.eval()
    # NOTE(review): unconditionally requires a CUDA device (unchanged behaviour).
    model.cuda()
    return model


def get_roof_model():
    """Return the roof-segmentation model from ``wudingpv/models/roof_best.pth``."""
    return _load_resunet('wudingpv/models/roof_best.pth', "屋顶识别权重加载成功")


def get_pv_model():
    """Return the PV-panel segmentation model from ``wudingpv/models/pv_best.pth``.

    Logs a PV-specific message rather than the roof-model message.
    """
    return _load_resunet('wudingpv/models/pv_best.pth', "光伏识别权重加载成功")
# ---- Model instances: loaded once at import time and shared by all requests ----
# Land-cover segmentation model; constructed a single time so its weights
# are loaded only once.
dimaoshibie_SegFormer = segformer.SegFormer_Segmentation()
# Roof and PV-panel segmentation networks (their loaders move them to CUDA).
roof_model = get_roof_model()
pv_model = get_pv_model()
# Methane (CH4) models: flow ("liuliang") and gas-phase concentration
# ("qixiangnongdu"); presumably XGBoost models serialised with joblib.
# NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
ch4_model_flow = joblib.load(os.path.join(current_dir, 'jiawanyuce/liuliang_model/xgb_model_liuliang.pkl'))
ch4_model_gas = joblib.load(os.path.join(current_dir, 'jiawanyuce/qixiangnongdu_model/xgb_model_qixiangnongdu.pkl'))
# Photovoltaic power forecaster (Crossformer checkpoint).
pvfd_param = guangfufadian_model_base.guangfufadian_Args()
pvfd_model_path = os.path.join(pvfd_param.checkpoints, 'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth')  # adjust to the real checkpoint path
pvfd_model = guangfufadian_model_base.ModelInference(pvfd_model_path, pvfd_param)
# Wind power forecaster (Crossformer checkpoint).
windfd_args = fenglifadian_model_base.fenglifadian_Args()
windfd_model_path = os.path.join(windfd_args.checkpoints, 'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth')  # adjust to the real checkpoint path
windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args)
# Segment Anything (ViT-B) model and predictor.
checkpoint_path = os.path.join(current_dir, 'segment_anything_model/weights/vit_b.pth')
sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
# Ask PyTorch which device it can use. OpenCV's CUDA device count is 0 for
# OpenCV builds without CUDA even when PyTorch has a usable GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
_ = sam.to(device=device)
sam_predictor = SamPredictor(sam)
print(f"SAM模型已加载使用设备: {device}")
# Serve the /root/app directory as static files under /files.
app.mount("/files", StaticFiles(directory="/root/app"), name="files")
@app.get("/")
async def read_root():
    """Return a greeting that includes the running Python version."""
    greeting = "Hello world! From FastAPI running on Uvicorn with Gunicorn. Using Python " + version
    return {"message": greeting}
# Home page: list the application entries stored in ``app_shouye``.
@app.get("/ai-station-api/index/show")
async def get_source_index_info():
    """Return id, name, description and image URL of every home-page application."""
    rows = data_util.fetch_data("SELECT id,application_name, describe_data, img_url FROM app_shouye")
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Data page: list the downloadable sample datasets.
@app.get("/ai-station-api/data_source/show")
async def get_data_source_info():
    """Return the ``data_samples`` rows formatted for the data page.

    The query result is checked for ``None`` before it is passed to
    ``generate_json_data_source``. A failed query therefore yields the normal
    failure response; the formatter's own handling of ``None`` is not relied on.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time, download_url FROM data_samples"
    data = data_util.fetch_data(sql)
    if data is not None:
        data = data_util.generate_json_data_source(data)
    if data is None:
        return {
            "success": False,
            "msg": "获取信息列表失败",
            "data": None
        }
    return {"success": True,
            "msg": "获取信息成功",
            "data": data}
# Application page: list the application samples.
@app.get("/ai-station-api/app_source/show")
async def get_app_source_info():
    """Return the ``app_samples`` rows formatted for the application page.

    ``None`` from the query is detected before formatting, so it produces the
    failure response instead of being passed to ``generate_json_app_source``.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM app_samples"
    data = data_util.fetch_data(sql)
    if data is not None:
        data = data_util.generate_json_app_source(data)
    if data is None:
        return {
            "success": False,
            "msg": "获取信息列表失败",
            "data": None
        }
    return {"success": True,
            "msg": "获取信息成功",
            "data": data}
# Coal-based carbon material page: list its samples.
@app.get("/ai-station-api/mjtcl_source/show")
async def get_mjtcl_source_info():
    """Return the ``meijitancailiao_samples`` rows formatted for that page.

    ``None`` from the query is detected before formatting, so it produces the
    failure response instead of being passed to the formatter.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM meijitancailiao_samples"
    data = data_util.fetch_data(sql)
    if data is not None:
        data = data_util.generate_json_meijitancailiao_source(data)
    if data is None:
        return {
            "success": False,
            "msg": "获取信息列表失败",
            "data": None
        }
    return {"success": True,
            "msg": "获取信息成功",
            "data": data}
# Coal pyrolysis page: list its samples.
@app.get("/ai-station-api/mrj_source/show")
async def get_mrj_source_info():
    """Return the ``meirejie_samples`` rows formatted for that page.

    ``None`` from the query is detected before formatting, so it produces the
    failure response instead of being passed to the formatter.
    """
    sql = "SELECT id,application_name, task_type, sample_name, img_url, time FROM meirejie_samples"
    data = data_util.fetch_data(sql)
    if data is not None:
        data = data_util.generate_json_meirejie_source(data)
    if data is None:
        return {
            "success": False,
            "msg": "获取信息列表失败",
            "data": None
        }
    return {"success": True,
            "msg": "获取信息成功",
            "data": data}
# Input-feature descriptions for the coal-based carbon material models.
# ``type`` selects the feature set: zongkongtiji, zongbiaomianji, tancailiao,
# meiliqing or moniqi (matched against ``use_type``).
@app.get("/ai-station-api/mjt_feature/show")
async def get_mjt_feature_info(type: str = None):
    """Return the feature rows whose ``use_type`` equals ``type``."""
    query = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
    features = data_util.fetch_data_with_param(query, (type,))
    if features is not None:
        return {"success": True, "msg": "获取信息成功", "data": features}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Input-feature descriptions for the coal pyrolysis models.
# ``type`` is one of: tar, char, gas, water.
@app.get("/ai-station-api/mrj_feature/show")
async def get_mrj_feature_info(type: str = None):
    """Return the feature rows whose ``use_type`` equals ``type``.

    NOTE(review): this reads ``meijitancailiao_features`` (the carbon-material
    table), not a pyrolysis-specific table - confirm that is intended.
    """
    query = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
    features = data_util.fetch_data_with_param(query, (type,))
    if features is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": features}
# ============================ Coal-based carbon materials ============================
@app.post("/ai-station-api/mjt/predict")
async def mjt_models_predict(content: post_model.Zongbiaomianji):
    """Predict SSA for one submitted sample via ``model_deal.pred_single_ssa``.

    The payload field ``K_C`` is renamed to the model column ``K/C``. Columns
    are then ordered as the model expects.
    """
    meijiegou_test_content = pd.DataFrame([content.model_dump()])
    meijiegou_test_content = meijiegou_test_content.rename(columns={"K_C": "K/C"})
    new_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
    meijiegou_test_content = meijiegou_test_content.reindex(columns=new_order)
    ssa_result = model_deal.pred_single_ssa(meijiegou_test_content)
    # Log the endpoint that was actually hit (previously "Root endpoint", which was misleading).
    logger.info("mjt/predict endpoint was accessed")
    if ssa_result is None:
        return {
            "success": False,
            "msg": "获取信息列表失败",
            "data": None
        }
    return {"success": True,
            "msg": "获取信息成功",
            "data": ssa_result}
# =============== Coal-based carbon materials: total surface area (SSA) ===============
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mjt_models_ssa/predict")
async def mjt_models_predict_ssa(content: post_model.Zongbiaomianji):
    """Predict SSA for one submitted sample via ``model_deal.pred_single_ssa``."""
    columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
    sample = pd.DataFrame([content.model_dump()]).rename(columns={"K_C": "K/C"}).reindex(columns=columns)
    prediction = model_deal.pred_single_ssa(sample)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Upload a CSV file into a fresh per-request directory tmp/<type>/<uuid>/.
@app.post("/ai-station-api/document/upload")
async def upload_file(file: UploadFile = File(...), type: str = Form(...)):
    """Save an uploaded CSV and return its server-side location.

    Raises:
        HTTPException(400): the file is not a ``.csv``, or ``type`` is not a
            single plain directory name.
    """
    # Keep only the last component of the client-supplied name, so a name such
    # as "../../run.py" cannot write outside the upload directory.
    filename = os.path.basename(file.filename or "")
    if not filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
    # ``type`` becomes a directory name; reject separators and "."/"..".
    if not type or type in (".", "..") or os.path.basename(type) != type:
        raise HTTPException(status_code=400, detail="type 参数不合法")
    upload_dir = os.path.join(current_dir, 'tmp', type, str(uuid.uuid4()))
    os.makedirs(upload_dir, exist_ok=True)
    file_location = os.path.join(upload_dir, filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"location": file_location}}
# Batch prediction: run the chosen SSA model over an uploaded file.
@app.get("/ai-station-api/mjt_multi_ssa/predict")
async def mjt_multi_ssa_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_ssa``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_ssa(model, path)
    ok = outcome['status'] == True
    if not ok:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# =============== Coal-based carbon materials: total pore volume (TPV) ===============
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mjt_models_tpv/predict")
async def mjt_models_predict_tpv(content: post_model.Zongbiaomianji):
    """Predict TPV for one submitted sample via ``model_deal.pred_single_tpv``."""
    columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
    raw = pd.DataFrame([content.model_dump()])
    sample = raw.rename(columns={"K_C": "K/C"}).reindex(columns=columns)
    prediction = model_deal.pred_single_tpv(sample)
    if prediction is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": prediction}
# Batch prediction: run the chosen TPV model over an uploaded file.
@app.get("/ai-station-api/mjt_multi_tpv/predict")
async def mjt_multi_tpv_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_tpv``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_tpv(model, path)
    if outcome['status'] == True:
        response = {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    else:
        response = {"success": False, "msg": outcome['reason'], "data": None}
    return response
# =============== Coal-based carbon materials: coal material application ===============
@app.post("/ai-station-api/mjt_models_meitan/predict")
async def mjt_models_predict_meitan(content: post_model.Meitan):
    """Predict with ``model_deal.pred_single_meitan`` for one submitted sample.

    The payload field ``ID_IG`` maps to the model column ``ID/IG``.
    """
    columns = ["SSA", "TPV", "N", "O", "ID/IG", "J"]
    sample = pd.DataFrame([content.model_dump()]).rename(columns={"ID_IG": "ID/IG"}).reindex(columns=columns)
    prediction = model_deal.pred_single_meitan(sample)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Batch prediction for the coal material application models.
@app.get("/ai-station-api/mjt_multi_meitan/predict")
async def mjt_multi_meitan_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_meitan``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_meitan(model, path)
    ok = outcome['status'] == True
    if not ok:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# =============== Coal-based carbon materials: coal tar pitch application ===============
@app.post("/ai-station-api/mjt_models_meiliqing/predict")
async def mjt_models_predict_meiliqing(content: post_model.Meitan):
    """Predict with ``model_deal.pred_single_meiliqing`` for one submitted sample."""
    columns = ["SSA", "TPV", "N", "O", "ID/IG", "J"]
    raw = pd.DataFrame([content.model_dump()])
    sample = raw.rename(columns={"ID_IG": "ID/IG"}).reindex(columns=columns)
    prediction = model_deal.pred_single_meiliqing(sample)
    if prediction is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": prediction}
# Batch prediction for the coal tar pitch models.
@app.get("/ai-station-api/mjt_multi_meiliqing/predict")
async def mjt_multi_meiliqing_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_meiliqing``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_meiliqing(model, path)
    if outcome['status'] == True:
        response = {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    else:
        response = {"success": False, "msg": outcome['reason'], "data": None}
    return response
# ============== Coal-based carbon materials: preparation simulator ==============
# Forward simulator.
@app.post("/ai-station-api/mnq_model/analysis")
async def mnq_model_predict(return_count: int,  # number of rows to return
                            model_choice: str,  # which model to run
                            form_data: post_model.FormData  # form payload
                            ):
    """Run the forward simulator and return the top-ranked candidates.

    Steps:
    1. Build the prediction grid from ``form_data``.
    2. Predict with ``model_choice``.
    3. Sort by SSA, then TPV (both descending).
    4. Write the full sorted table to ``tmp/moniqi/<uuid>/<model_choice>_moniqi.csv``.

    Returns the first ``return_count`` rows (all rows if ``return_count`` is
    None) together with the CSV path.
    """
    form_dict = form_data.model_dump()
    # Named so it does not shadow the module-level ``params`` import.
    sim_params = prepare_data.get_params(form_dict)
    pred_data = prepare_data.create_pred_data(sim_params)
    result = model_deal.pred_func(model_choice, pred_data)
    sorted_result = result.sort_values(by=['SSA', 'TPV'], ascending=[False, False])

    upload_dir = os.path.join(current_dir, 'tmp', 'moniqi', str(uuid.uuid4()))
    os.makedirs(upload_dir, exist_ok=True)
    # model_choice comes from the client; keep only its last path component
    # so it cannot redirect the CSV outside upload_dir.
    file_name = os.path.basename(model_choice) + "_moniqi.csv"
    file_location = os.path.join(upload_dir, file_name)
    sorted_result.to_csv(file_location, index=False)

    shown = sorted_result if return_count is None else sorted_result.iloc[:return_count]
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"result": shown.to_dict(orient="records"), "path": file_location}}
# Reverse simulator.
@app.get("/ai-station-api/mnq_model/reverse_analysis")
async def mnq_model_reverse_analysis(model: str = None, path: str = None, type: int = 1, scale: float = None, num: int = 5):
    """Run the reverse simulator through ``prepare_data.moniqi_data_prepare``.

    This handler used to be named ``mjt_multi_meiliqing_pred``. That silently
    rebound the module-level name of the coal tar pitch batch handler; the
    route path and query parameters are unchanged.
    ``path`` presumably points at a forward-simulator CSV - TODO confirm.
    """
    result = prepare_data.moniqi_data_prepare(model, path, type, scale, num)
    if result['status'] == False:
        return {
            "success": False,
            "msg": result['reason'],
            "data": None
        }
    return {
        "success": True,
        "msg": "获取信息成功",
        "data": {"data": result['reason']}
    }
# ============================ Coal pyrolysis: tar ============================
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mrj_models_tar/predict")
async def mrj_models_predict_tar(content: post_model.Meirejie):
    """Predict tar via ``model_deal.pred_single_tar``; ``X_C`` fields map to ``X/C`` columns."""
    columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    ratio_names = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    sample = pd.DataFrame([content.model_dump()]).rename(columns=ratio_names).reindex(columns=columns)
    prediction = model_deal.pred_single_tar(sample)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Batch prediction for the tar models.
@app.get("/ai-station-api/mrj_multi_tar/predict")
async def mrj_multi_tar_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_tar``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_tar(model, path)
    ok = outcome['status'] == True
    if not ok:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# ============================ Coal pyrolysis: char ============================
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mrj_models_char/predict")
async def mrj_models_predict_char(content: post_model.Meirejie):
    """Predict char via ``model_deal.pred_single_char``; ``X_C`` fields map to ``X/C`` columns."""
    columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    raw = pd.DataFrame([content.model_dump()])
    sample = raw.rename(columns={"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}).reindex(columns=columns)
    prediction = model_deal.pred_single_char(sample)
    if prediction is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": prediction}
# Batch prediction for the char models.
@app.get("/ai-station-api/mrj_multi_char/predict")
async def mrj_multi_char_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_char``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_char(model, path)
    if outcome['status'] == True:
        response = {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    else:
        response = {"success": False, "msg": outcome['reason'], "data": None}
    return response
# ============================ Coal pyrolysis: water ============================
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mrj_models_water/predict")
async def mrj_models_predict_water(content: post_model.Meirejie):
    """Predict water via ``model_deal.pred_single_water``; ``X_C`` fields map to ``X/C`` columns."""
    columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    ratio_names = {"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}
    sample = pd.DataFrame([content.model_dump()]).rename(columns=ratio_names).reindex(columns=columns)
    prediction = model_deal.pred_single_water(sample)
    if prediction is not None:
        return {"success": True, "msg": "获取信息成功", "data": prediction}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Batch prediction for the water models.
@app.get("/ai-station-api/mrj_multi_water/predict")
async def mrj_multi_water_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_water``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_water(model, path)
    ok = outcome['status'] == True
    if not ok:
        return {"success": False, "msg": outcome['reason'], "data": None}
    return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
# ============================ Coal pyrolysis: gas ============================
# Combined model analysis for a single sample.
@app.post("/ai-station-api/mrj_models_gas/predict")
async def mrj_models_predict_gas(content: post_model.Meirejie):
    """Predict gas via ``model_deal.pred_single_gas``; ``X_C`` fields map to ``X/C`` columns."""
    columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T"]
    raw = pd.DataFrame([content.model_dump()])
    sample = raw.rename(columns={"H_C": "H/C", "O_C": "O/C", "N_C": "N/C"}).reindex(columns=columns)
    prediction = model_deal.pred_single_gas(sample)
    if prediction is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": prediction}
# Batch prediction for the gas models.
@app.get("/ai-station-api/mrj_multi_gas/predict")
async def mrj_multi_gas_pred(model: str = None, path: str = None):
    """Delegate to ``model_deal.get_excel_gas``; ``reason`` is the payload or the error text."""
    outcome = model_deal.get_excel_gas(model, path)
    if outcome['status'] == True:
        response = {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    else:
        response = {"success": False, "msg": outcome['reason'], "data": None}
    return response
# ============================ Land-cover recognition ============================
# Image upload (e.g. type = dimaoshibie).
@app.post("/ai-station-api/image/upload")
async def upload_image(file: UploadFile = File(...), type: str = Form(...)):
    """Save an uploaded image under ``tmp/<type>/<uuid>/`` and return its URL.

    Returns a failure response (not an HTTP error) in three cases:
    - the extension is rejected by ``data_util.allowed_file``;
    - the file exceeds ``param.MAX_FILE_SIZE``;
    - ``type`` is not a single plain directory name.
    """
    # Keep only the last component of the client-supplied name so it cannot
    # point outside the upload directory (e.g. "../../run.py").
    filename = os.path.basename(file.filename or "")
    if not data_util.allowed_file(filename):
        return {
            "success": False,
            "msg": "图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
            "data": None
        }
    size = file.size
    if size is None:
        # UploadFile.size is optional. Measure the spooled file instead of
        # comparing None with an int, which raises TypeError.
        file.file.seek(0, os.SEEK_END)
        size = file.file.tell()
        file.file.seek(0)
    if size > param.MAX_FILE_SIZE:
        return {
            "success": False,
            "msg": "图片大小不能大于100MB",
            "data": None
        }
    # ``type`` becomes a directory name; reject separators and "."/"..".
    if not type or type in (".", "..") or os.path.basename(type) != type:
        return {"success": False, "msg": "type 参数不合法", "data": None}
    upload_dir = os.path.join(current_dir, 'tmp', type, str(uuid.uuid4()))
    os.makedirs(upload_dir, exist_ok=True)
    file_location = os.path.join(upload_dir, filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    file_location = model_deal.get_pic_url(file_location)
    return {"success": True,
            "msg": "获取信息成功",
            "data": {"location": file_location}}
# Land-cover recognition on an uploaded image.
@app.get("/ai-station-api/dmsb_image/analysis")
async def dmsb_image_analysis(path: str = None):
    """Segment the image at URL ``path`` and return the URL of the first output file."""
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.dimaoshibie_pic(dimaoshibie_SegFormer, local_path, param.dmsb_count, param.dmsb_name_classes)
    if outcome['status'] == False:
        return {"success": False, "msg": outcome['reason'], "data": None}
    # ``reason`` is a list of output paths; only the first one is returned.
    first_output = outcome['reason'][0]
    return {"success": True, "msg": "获取信息成功", "data": {"result": model_deal.get_pic_url(first_output)}}
# Land-cover area from a segmentation result (the demo image is "z17").
@app.get("/ai-station-api/dmsb_image/calculate")
async def dmsb_image_calculate(path: str = None, scale: float = 1.92*1.92):
    """Compute per-class areas via ``model_deal.dimaoshibie_area`` and label them.

    ``scale`` is passed straight through; presumably area per pixel - TODO confirm.
    """
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.dimaoshibie_area(local_path, scale, param.dmsb_colors)
    logger.info(outcome)
    if outcome['status'] != True:
        return {"success": False, "msg": outcome['reason'], "data": None}
    labels = param.dmsb_type
    areas = {labels[k]: v for k, v in outcome['reason'].items()}
    return {"success": True, "msg": "获取信息成功", "data": {"result": areas}}
# Download: zip the directory containing ``path`` and send it back.
@app.get("/ai-station-api/download")
async def download_zip(path: str = None):
    """Return a ZIP of every file under the directory that contains ``path``.

    The archive is built in a unique temporary file, not a shared
    ``download.zip`` in the working directory that concurrent requests would
    overwrite. It is deleted once the response has been sent.

    Raises:
        HTTPException(500): building the archive failed.
    """
    import tempfile
    from fastapi import BackgroundTasks

    path = model_deal.get_pic_path(path)
    zip_filename = "download.zip"
    # NOTE(review): any directory reachable via get_pic_path can be zipped - confirm
    # that get_pic_path restricts paths to the upload area.
    dir_path = os.path.dirname(path)
    fd, zip_filepath = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(zip_filepath, 'w') as zip_file:
            for foldername, _subfolders, filenames in os.walk(dir_path):
                for filename in filenames:
                    file_path = os.path.join(foldername, filename)
                    # Store paths relative to dir_path inside the archive.
                    zip_file.write(file_path, os.path.relpath(file_path, dir_path))
    except Exception as e:
        os.remove(zip_filepath)
        raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}") from e
    # Remove the temporary archive after the response body has been streamed.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, zip_filepath)
    return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename, background=cleanup)
# ============================ Roof recognition ============================
# Images are uploaded via /ai-station-api/image/upload with type = roof.
@app.get("/ai-station-api/roof_image/analysis")
async def roof_image_analysis(path: str = None):
    """Run the roof model on the image at URL ``path``; return the first output's URL."""
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roof_pic(roof_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:
        return {"success": False, "msg": outcome['reason'], "data": None}
    first_output = outcome['reason'][0]
    return {"success": True, "msg": "获取信息成功", "data": {"result": model_deal.get_pic_url(first_output)}}
# Roof area from a segmentation result (the demo image is "z18").
@app.get("/ai-station-api/roof_image/calculate")
async def roof_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute per-class areas via ``model_deal.roof_area`` and label them with ``param.wd_type``.

    The "其他" (other) class is dropped from the result. It is popped with a
    default, so a result without that class no longer raises KeyError.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:
        res = result['reason']
        translated_dict = {param.wd_type[key]: value for key, value in res.items()}
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ============================ PV panel recognition ============================
# Images are uploaded via /ai-station-api/image/upload with type = pv.
@app.get("/ai-station-api/pv_image/analysis")
async def pv_image_analysis(path: str = None):
    """Run the PV model on the image at URL ``path``; return the first output's URL."""
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roof_pic(pv_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:
        return {"success": False, "msg": outcome['reason'], "data": None}
    first_output = outcome['reason'][0]
    return {"success": True, "msg": "获取信息成功", "data": {"result": model_deal.get_pic_url(first_output)}}
# PV panel area from a segmentation result (the demo image is "z18").
@app.get("/ai-station-api/pv_image/calculate")
async def pv_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute per-class areas via ``model_deal.roof_area`` and label them with ``param.pv_type``.

    The "其他" (other) class is dropped; a result without it no longer raises KeyError.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:
        res = result['reason']
        translated_dict = {param.pv_type[key]: value for key, value in res.items()}
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ============================ Roof + PV recognition ============================
# Images are uploaded via /ai-station-api/image/upload with type = roofpv.
@app.get("/ai-station-api/roofpv_image/analysis")
async def roofpv_image_analysis(path: str = None):
    """Run both roof and PV models, merge their outputs, and return the merged image URL."""
    local_path = model_deal.get_pic_path(path)
    outcome = model_deal.roofpv_pic(roof_model, pv_model, local_path, param.wdpv_palette)
    if outcome['status'] == False:
        return {"success": False, "msg": outcome['reason'], "data": None}
    merged = prepare_data.merge_binary(outcome['reason'])
    return {"success": True, "msg": "获取信息成功", "data": {"result": model_deal.get_pic_url(merged)}}
# Roof + PV area from a combined segmentation result (the demo image is "z18").
@app.get("/ai-station-api/roofpv_image/calculate")
async def roofpv_image_calculate(path: str = None, scale: float = 0.92*0.92):
    """Compute per-class areas via ``model_deal.roof_area_roofpv`` and label them with ``param.wdpv_type``.

    The "其他" (other) class is dropped; a result without it no longer raises KeyError.
    """
    path = model_deal.get_pic_path(path)
    result = model_deal.roof_area_roofpv(path, scale, param.wdpv_colors)
    logger.info(result)
    if result['status'] == True:
        res = result['reason']
        translated_dict = {param.wdpv_type[key]: value for key, value in res.items()}
        translated_dict.pop('其他', None)
        return {"success": True,
                "msg": "获取信息成功",
                "data": {"result": translated_dict}}
    return {"success": False,
            "msg": result['reason'],
            "data": None}
# ============================ Time-series forecasting ============================
# ============================ Methane (CH4) prediction ============================
# Feature list shown on the page.
@app.get("/ai-station-api/ch4_features/show")
async def get_ch4_features_info():
    """Return every row of ``ch4_features``."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM ch4_features;")
    if rows is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": rows}
@app.get("/ai-station-api/ch4_pic_table/show")
async def get_ch4_pic_table_info():
    """Return the plottable CH4 columns (every column except ``date_time``)."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name FROM ch4_features where col_name != 'date_time';")
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Return the data for the selected column so the front end can plot it.
@app.get("/ai-station-api/ch4_feature_pic/select")
async def get_ch4_feature_pic_select(type: str = None, path: str = None):
    """Return plotting data for column ``type`` via ``prepare_data.show_data_jiawanyuce``.

    The timestamp column ``date_time`` cannot be plotted and is rejected.
    """
    if type == 'date_time':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_jiawanyuce(type, path)
    if series is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": series}
# Paginated view of the uploaded CH4 data.
@app.get("/ai-station-api/ch4_data/show")
async def get_ch4_data(path: str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the table loaded by ``prepare_data.get_jiawanyuce_data``.

    Returns the standard failure response if the loader returns ``None``.
    An empty table yields an empty first page.

    Raises:
        HTTPException(404): ``page`` is beyond the last page.
    """
    data = prepare_data.get_jiawanyuce_data(path)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # An empty table still has a (blank) first page; only pages past it are 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    items = data.iloc[start_idx:end_idx].to_dict(orient="records")
    return {
        "success": True,
        "msg": "获取信息成功",
        "data": {
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items
        }
    }
# CSV upload goes through /ai-station-api/document/upload with type = "ch4".
# Prediction over the uploaded data.
@app.get("/ai-station-api/ch4_data/predict")
async def get_ch4_predict(path: str = None, start_time: str = None, end_time: str = None, type: int = 1, is_show: bool = True):
    """Run the CH4 flow/gas models via ``model_deal.start_predict_endpoint``."""
    outcome = model_deal.start_predict_endpoint(ch4_model_flow, ch4_model_gas, path, start_time, end_time, type, is_show)
    failed = outcome['status'] == False
    if not failed:
        return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    return {"success": False, "msg": outcome['reason'], "data": None}
# ============================ PV power forecasting ============================
# Feature list shown on the page.
@app.get("/ai-station-api/pvelectric_features/show")
async def get_pvelectric_features_info():
    """Return every row of ``pv_electric_features``."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM pv_electric_features;")
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
@app.get("/ai-station-api/pvelectric_pic_table/show")
async def get_pvelectric_pic_table_info():
    """Return the plottable PV columns (every column except ``date_time``)."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name FROM pv_electric_features where col_name != 'date_time';")
    if rows is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": rows}
# Return the data for the selected column so the front end can plot it.
@app.get("/ai-station-api/pvelectric_feature_pic/select")
async def get_pvelectric_feature_pic_select(type: str = None, path: str = None):
    """Return plotting data for column ``type`` via ``prepare_data.show_data_pvfd``.

    The timestamp column ``date_time`` cannot be plotted and is rejected.
    """
    if type == 'date_time':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_pvfd(path, type)
    if series is not None:
        return {"success": True, "msg": "获取信息成功", "data": series}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Paginated view of the uploaded PV data.
@app.get("/ai-station-api/pvelectric_data/show")
async def get_pvelectric_data(path: str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the table loaded by ``prepare_data.show_testdata_pvfd``.

    Returns the standard failure response if the loader returns ``None``.
    An empty table yields an empty first page.

    Raises:
        HTTPException(404): ``page`` is beyond the last page.
    """
    data = prepare_data.show_testdata_pvfd(path)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # An empty table still has a (blank) first page; only pages past it are 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    items = data.iloc[start_idx:end_idx].to_dict(orient="records")
    return {
        "success": True,
        "msg": "获取信息成功",
        "data": {
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items
        }
    }
# CSV upload goes through /ai-station-api/document/upload with type = "pvelectric".
# Prediction over the uploaded data.
@app.get("/ai-station-api/pvelectric_data/predict")
async def get_pvelectri_predict(path: str = None, start_time: str = None, end_time: str = None, is_show: bool = True):
    """Run the PV power model via ``model_deal.start_pvelectric_predict_endpoint``."""
    outcome = model_deal.start_pvelectric_predict_endpoint(pvfd_model, path, start_time, end_time, is_show)
    failed = outcome['status'] == False
    if not failed:
        return {"success": True, "msg": "获取信息成功", "data": outcome['reason']}
    return {"success": False, "msg": outcome['reason'], "data": None}
# ============================ Wind power forecasting ============================
# Feature list shown on the page.
@app.get("/ai-station-api/wind_electric_features/show")
async def get_wind_electric_features_info():
    """Return every row of ``wind_electric_features``."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name,type,data_type,unit,data_sample FROM wind_electric_features;")
    if rows is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": rows}
@app.get("/ai-station-api/wind_electric_pic_table/show")
async def get_wind_electric_pic_table_info():
    """Return the plottable wind columns (every column except ``date``)."""
    rows = data_util.fetch_data("SELECT chinese_name, col_name FROM wind_electric_features where col_name != 'date';")
    if rows is not None:
        return {"success": True, "msg": "获取信息成功", "data": rows}
    return {"success": False, "msg": "获取信息列表失败", "data": None}
# Return the data for the selected column so the front end can plot it.
@app.get("/ai-station-api/wind_electric_feature_pic/select")
async def get_wind_electric_feature_pic_select(type: str = None, path: str = None):
    """Return plotting data for column ``type`` via ``prepare_data.show_data_windfd``.

    The timestamp column ``date`` cannot be plotted and is rejected.
    """
    if type == 'date':
        return {"success": False, "msg": "时间列无法展示,请重新选择", "data": None}
    series = prepare_data.show_data_windfd(path, type)
    if series is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    return {"success": True, "msg": "获取信息成功", "data": series}
# Paginated view of the uploaded wind data.
@app.get("/ai-station-api/wind_electric_data/show")
async def get_wind_electric_data(path: str = None, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
    """Return one page of the table loaded by ``prepare_data.show_testdata_windfd``.

    Returns the standard failure response if the loader returns ``None``.
    An empty table yields an empty first page.

    Raises:
        HTTPException(404): ``page`` is beyond the last page.
    """
    data = prepare_data.show_testdata_windfd(path)
    if data is None:
        return {"success": False, "msg": "获取信息列表失败", "data": None}
    total_items = len(data)
    total_pages = (total_items + page_size - 1) // page_size  # ceiling division
    # An empty table still has a (blank) first page; only pages past it are 404.
    if page > max(total_pages, 1):
        raise HTTPException(status_code=404, detail="Page not found")
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    items = data.iloc[start_idx:end_idx].to_dict(orient="records")
    return {
        "success": True,
        "msg": "获取信息成功",
        "data": {
            "total_items": total_items,
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
            "items": items
        }
    }
"""
csv上传
"""
# @app.post("/ai-station-api/document/upload") type = "wind_electric"
"""
预测, type
"""
@app.get("/ai-station-api/wind_electric_data/predict")
async def get_wind_electri_predict(path: str = None, start_time: str = None, end_time: str = None, is_show: bool = True):
    """Run the wind-power prediction on an uploaded dataset.

    Delegates to ``model_deal.start_wind_electric_predict_endpoint`` with the
    module-level ``windfd_model``. In the helper's answer, ``status`` flags
    success and ``reason`` carries the error text or the prediction payload.
    """
    response = model_deal.start_wind_electric_predict_endpoint(
        windfd_model, path, start_time, end_time, is_show
    )
    if not (response['status'] == False):
        return {"success": True, "msg": "获取信息成功", "data": response['reason']}
    return {"success": False, "msg": response['reason'], "data": None}
#======================== SAM =============================================================
"""
文件上传
1、 创建 input、output、temp 目录
2、 将图片保存到 input 目录
3、 加载当前图像并初始化标注配置
4、 返回当前图像及会话目录路径(file_path)给前端
"""
@app.post("/ai-station-api/sam-image/upload")
async def upload_sam_image(file: UploadFile = File(...), type: str = Form(...)):
    """Start a new SAM annotation session from an uploaded image.

    Creates ``tmp/<type>/<uuid>/`` with ``input``, ``output`` and ``temp``
    sub-directories and stores the image in ``input``. It then builds a fresh
    annotation configuration from the defaults and pickles it to
    ``model_params.pickle``. Finally it writes an RGB preview to
    ``temp/output_image.jpg``.

    Args:
        file: the uploaded image.
        type: category name, used as a directory component under ``tmp``.

    Returns:
        The response envelope with the encoded preview under ``image`` and
        the session directory under ``file_path``; the other SAM endpoints
        take that directory as ``path``.
    """
    if not data_util.allowed_file(file.filename):
        return {"success": False, "msg": "图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾", "data": None}
    if file.size > param.MAX_FILE_SAM_SIZE:
        return {"success": False, "msg": "图片大小不能大于10MB", "data": None}
    # Both values come from the client and end up in filesystem paths. Keep
    # only their last path component so "../" cannot escape the session dir.
    safe_type = os.path.basename(type.replace("\\", "/"))
    safe_filename = os.path.basename(file.filename.replace("\\", "/"))
    if safe_type in ("", ".", "..") or safe_filename in ("", ".", ".."):
        return {"success": False, "msg": "文件名或类型不合法", "data": None}
    upload_dir = os.path.join(current_dir, 'tmp', safe_type, str(uuid.uuid4()))
    input_dir = os.path.join(upload_dir, 'input')
    output_dir = os.path.join(upload_dir, 'output')
    temp_dir = os.path.join(upload_dir, 'temp')
    for directory in (input_dir, output_dir, temp_dir):
        os.makedirs(directory, exist_ok=True)
    # Save the uploaded image into the input directory.
    file_location = os.path.join(input_dir, safe_filename)
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    config_path = os.path.join(upload_dir, 'model_params.pickle')
    # Every upload starts from a fresh copy of the default configuration.
    config = copy.deepcopy(sam_config)
    api_config = copy.deepcopy(sam_api_config)
    config['input_dir'] = input_dir
    config['output_dir'] = output_dir
    # Accept the extensions the error message above advertises ('.tif'
    # included). With only '.tiff' listed, a .tif upload left this list empty
    # and the indexing below raised IndexError.
    config['image_files'] = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))]
    image = None
    if config['image_files']:
        config['current_index'] = 0
        config['filename'] = config['image_files'][config['current_index']]
        image_path = os.path.join(config['input_dir'], config['filename'])
        image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None for unreadable data; drop the half-built session.
        shutil.rmtree(upload_dir, ignore_errors=True)
        return {"success": False, "msg": "图片读取失败，请检查图片文件", "data": None}
    config['image'] = image
    config['image_rgb'] = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
    api_config['output_dir'] = output_dir
    # Reset the per-class annotations in the API config (keep the classes,
    # clear masks and points) and the current annotation before persisting.
    config, api_config = sam_deal.reset_class_annotations(config, api_config)
    config = sam_deal.reset_annotation(config)
    with open(config_path, 'wb') as config_file:
        pickle.dump((config, api_config), config_file)
    # Write the preview into temp/ and return it encoded.
    pil_img = Image.fromarray(config['image_rgb'])
    tmp_path = os.path.join(temp_dir, 'output_image.jpg')
    pil_img.save(tmp_path)
    encoded_string = sam_deal.load_tmp_image(tmp_path)
    return {"success": True,
            "msg": "获取信息成功",
            "image": JSONResponse(content={"image_data": encoded_string}),
            'file_path': upload_dir}
"""
添加分类
添加分类,并将分类设置为最新的添加的分类;
要求针对返回current_index,将列表默认成选择current_index对应的分类
"""
@app.get("/ai-station-api/sam_class/create")
async def sam_class_set(
    class_name: str = None,
    color: Optional[List[int]] = Query(None, description="list of RGB color"),
    path: str = None
):
    """Add an annotation class and make it the currently selected one.

    Args:
        class_name: name of the new class.
        color: class colour as three RGB integers; it is converted to BGR
            before being handed to ``sam_deal.set_current_class``.
        path: session directory created by the upload endpoint.

    Returns:
        The envelope with the current preview image and, under ``data``,
        the class names, the index of the new class (``current_index``) and
        the class colours.
    """
    # The colour is unpacked into exactly three channels below; reject
    # anything else before add_class() runs, rather than crashing halfway.
    if color is None or len(color) != 3:
        return {"success": False, "msg": "颜色必须为 RGB 三个整数", "data": None}
    loaded_data, api_config = sam_deal.load_model(path)
    result = sam_deal.add_class(loaded_data, class_name, color)
    if result['status'] != True:
        return {"success": False, "msg": result['reason'], "data": None}
    loaded_data = result['reason']
    loaded_data['class_index'] = loaded_data['class_names'].index(class_name)
    r, g, b = [int(c) for c in color]
    bgr_color = (b, g, r)
    _, api_config = sam_deal.set_current_class(loaded_data, api_config, loaded_data['class_index'], color=bgr_color)
    # Persist the updated configuration.
    sam_deal.save_model(loaded_data, api_config, path)
    tmp_path = os.path.join(path, 'temp/output_image.jpg')
    encoded_string = sam_deal.load_tmp_image(tmp_path)
    return {"success": True,
            "msg": f"已添加类别: {class_name}, 颜色: {color}",
            "image": JSONResponse(content={"image_data": encoded_string}),
            "data": {"class_name_list": loaded_data['class_names'],
                     "current_index": loaded_data['class_index'],
                     "class_dict": loaded_data['class_colors'],
                     }}
"""
选择颜色,
current_index : 下拉列表中的分类索引
rgb_color :
"""
@app.get("/ai-station-api/sam_color/select")
async def set_sam_color(
    current_index: int = None,
    rgb_color: List[int] = Query(None, description="list of RGB color"),
    path: str = None
):
    """Change the colour of the class at ``current_index`` and select it.

    Args:
        current_index: index of the class in the front-end drop-down list.
        rgb_color: new colour as three RGB integers (converted to BGR).
        path: session directory created by the upload endpoint.

    Returns:
        The envelope with the refreshed preview image, or the failure reason
        reported by ``sam_deal.refresh_image``.
    """
    # Exactly three channels are unpacked below.
    if rgb_color is None or len(rgb_color) != 3:
        return {"success": False, "msg": "颜色必须为 RGB 三个整数", "image": None}
    loaded_data, api_config = sam_deal.load_model(path)
    r, g, b = [int(c) for c in rgb_color]
    bgr_color = (b, g, r)
    data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color)
    # Keep the api config *returned* by set_current_class (as the other
    # endpoints do). The original saved the pre-call object and could drop
    # the class selection.
    _, api = sam_deal.set_current_class(data, api, current_index, color=bgr_color)
    sam_deal.save_model(data, api, path)
    img = sam_deal.refresh_image(data, api, path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success": True,
            "msg": "",
            "image": JSONResponse(content={"image_data": encoded_string}),
        }
    return {
        "success": False,
        "msg": img['reason'],
        "image": None
    }
"""
选择分类
"""
@app.get("/ai-station-api/sam_classs/change")
async def on_class_selected(class_index: int = None, path: str = None):
    """Make the class at ``class_index`` the active one and redraw the preview.

    The session state is saved whether or not the selection succeeded; a
    falsy result from ``sam_deal.set_current_class`` is reported as an
    unrecognised class label.
    """
    model_state, api_state = sam_deal.load_model(path)
    selected, api_state = sam_deal.set_current_class(model_state, api_state, class_index, color=None)
    sam_deal.save_model(model_state, api_state, path)
    if not selected:
        return {"success": False, "msg": "分类标签识别错误", "image": None}
    rendered = sam_deal.refresh_image(model_state, api_state, path)
    if rendered['status'] == True:
        preview = sam_deal.load_tmp_image(rendered['reason'])
        return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": preview})}
    return {"success": False, "msg": rendered['reason'], "image": None}
"""
删除分类, 前端跳转为normal index=0
"""
@app.get("/ai-station-api/sam_class/delete")
async def sam_remove_class(path: str = None, select_index: int = None):
    """Delete the class at ``select_index`` and redraw the preview.

    Args:
        path: session directory created by the upload endpoint.
        select_index: 0-based index into the session's ``class_names``.

    Returns:
        The envelope with the refreshed preview, or a failure when the index
        is missing or out of range, or when the refresh fails.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    class_names = loaded_data['class_names']
    # Reject missing or out-of-range indices up front. A negative index would
    # otherwise silently delete a class counted from the end of the list.
    if select_index is None or not 0 <= select_index < len(class_names):
        return {"success": False, "msg": "分类索引超出范围", "image": None}
    class_name = class_names[select_index]
    loaded_data, api_config = sam_deal.remove_class(loaded_data, api_config, class_name)
    sam_deal.save_model(loaded_data, api_config, path)
    img = sam_deal.refresh_image(loaded_data, api_config, path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success": True,
            "msg": "",
            "image": JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success": False,
        "msg": img['reason'],
        "image": None
    }
"""
添加标注点 -左键
"""
@app.get("/ai-station-api/sam_point/left_add")
async def left_mouse_down(x: int = None, y: int = None, path: str = None):
    """Add a foreground prompt point at (x, y) for the active class.

    Requires a class to be selected. On success the session is saved and
    the refreshed preview image is returned.
    """
    model_state, api_state = sam_deal.load_model(path)
    if not api_state['current_class']:
        return {"success": False, "msg": "请先选择一个分类,在添加标点之前", "image": None}
    foreground = True  # left click marks a positive point
    added = sam_deal.add_annotation_point(api_state, x, y, foreground)
    if added['status'] == False:
        return {"success": False, "msg": added['reason'], "image": None}
    api_state = added['api']
    sam_deal.save_model(model_state, api_state, path)
    rendered = sam_deal.refresh_image(model_state, api_state, path)
    if not rendered['status'] == True:
        return {"success": False, "msg": rendered['reason'], "image": None}
    preview = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": preview})}
"""
添加标注点 -右键
"""
@app.get("/ai-station-api/sam_point/right_add")
async def right_mouse_down(x: int = None, y: int = None, path: str = None):
    """Add a background (negative) prompt point at (x, y) for the active class.

    Requires a class to be selected. On success the session is saved and
    the refreshed preview image is returned.
    """
    session, api = sam_deal.load_model(path)
    if not api['current_class']:
        return {"success": False, "msg": "请先选择一个分类,在添加标点之前", "image": None}
    outcome = sam_deal.add_annotation_point(api, x, y, False)  # False => background point
    if outcome['status'] == False:
        return {"success": False, "msg": outcome['reason'], "image": None}
    api = outcome['api']
    sam_deal.save_model(session, api, path)
    refreshed = sam_deal.refresh_image(session, api, path)
    if refreshed['status'] == True:
        encoded = sam_deal.load_tmp_image(refreshed['reason'])
        return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded})}
    return {"success": False, "msg": refreshed['reason'], "image": None}
"""
删除上一个点
"""
@app.get("/ai-station-api/sam_point/delete_last")
async def sam_delete_last_point(path: str = None):
    """Remove the most recently added prompt point and redraw the preview.

    Unlike the sibling point endpoints, the failure responses here carry a
    ``data`` key rather than ``image``; this is kept as-is for existing clients.
    """
    model_state, api_state = sam_deal.load_model(path)
    outcome = sam_deal.delete_last_point(api_state)
    if outcome['status'] != True:
        return {"success": False, "msg": outcome['reason'], "data": None}
    api_state = outcome['reason']
    sam_deal.save_model(model_state, api_state, path)
    rendered = sam_deal.refresh_image(model_state, api_state, path)
    if rendered['status'] != True:
        return {"success": False, "msg": rendered['reason'], "data": None}
    preview = sam_deal.load_tmp_image(rendered['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": preview})}
"""
删除所有点
"""
@app.get("/ai-station-api/sam_point/delete_all")
async def sam_clear_all_point(path: str = None):
    """Clear every prompt point of the active class and redraw the preview."""
    session, api = sam_deal.load_model(path)
    cleared = sam_deal.reset_current_class_points(api)
    if cleared['status'] != True:
        return {"success": False, "msg": cleared['reason'], "image": None}
    api = cleared['reason']
    sam_deal.save_model(session, api, path)
    refreshed = sam_deal.refresh_image(session, api, path)
    if refreshed['status'] != True:
        return {"success": False, "msg": refreshed['reason'], "image": None}
    encoded = sam_deal.load_tmp_image(refreshed['reason'])
    return {"success": True, "msg": "", "image": JSONResponse(content={"image_data": encoded})}
"""
模型预测
"""
@app.get("/ai-station-api/sam_model/predict")
async def sam_predict_mask(path: str = None):
    """Predict masks for the active class from its stored prompt points.

    Replays the class's points into a freshly reset prompt and runs
    ``sam_deal.predict_mask`` with the module-level ``sam_predictor``. It
    stores the masks, scores and selected mask on the class entry, saves
    the session and returns the refreshed preview.

    Args:
        path: session directory created by the upload endpoint.

    Returns:
        The envelope with the preview image, or a failure message when no
        points exist, when prediction fails or when refreshing fails.
        Unexpected errors are logged and reported rather than raised.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    class_data = api_config['class_annotations'].get(api_config['current_class'], {})
    if not class_data.get('points'):
        return {
            "success": False,
            "msg": "请在预测前添加至少一个预测样本点",
            "data": None
        }
    # Rebuild the predictor prompt from scratch out of this class's points.
    loaded_data = sam_deal.reset_annotation(loaded_data)
    for i, (x, y) in enumerate(class_data['points']):
        is_foreground = class_data['point_types'][i]
        loaded_data = sam_deal.add_point(loaded_data, x, y, is_foreground=is_foreground)
    try:
        prediction = sam_deal.predict_mask(loaded_data, sam_predictor)
        if prediction['status'] == False:
            return {
                "success": False,
                "msg": prediction['reason'],
                "data": None}
        result = prediction['reason']
        loaded_data = result['data']
        selected_index = result['selected_index']
        class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']]
        class_data['scores'] = result['scores']
        class_data['selected_mask_index'] = selected_index
        # When this run selects nothing, clear any mask chosen by an earlier
        # prediction so 'selected_mask' never disagrees with the index.
        class_data['selected_mask'] = class_data['masks'][selected_index] if selected_index >= 0 else None
        logger.info(f"predict: Predicted {len(result['masks'])} masks, selected index: {selected_index}")
        sam_deal.save_model(loaded_data, api_config, path)
        img = sam_deal.refresh_image(loaded_data, api_config, path)
        if img['status'] == True:
            encoded_string = sam_deal.load_tmp_image(img['reason'])
            return {
                "success": True,
                "msg": "",
                "image": JSONResponse(content={"image_data": encoded_string})
            }
        return {
            "success": False,
            "msg": img['reason'],
            "image": None
        }
    except Exception as e:
        logger.error(f"predict: Error during prediction: {str(e)}")
        traceback.print_exc()
        return {
            "success": False,
            "msg": f"predict: Error during prediction: {str(e)}",
            "data": None}
"""
清除所有信息
"""
@app.get("/ai-station-api/sam_model/clear_all")
async def sam_reset_annotation(path: str = None):
    """Clear all annotation state of the session and redraw the preview.

    Args:
        path: session directory created by the upload endpoint.

    Returns:
        The envelope with the refreshed preview image, or the refresh
        failure reason.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    loaded_data, api_config = sam_deal.reset_annotation_all(loaded_data, api_config)
    # Persist the cleared state like every other mutating endpoint does.
    # Otherwise the next request reloads the old annotations from disk and
    # the reset silently disappears.
    sam_deal.save_model(loaded_data, api_config, path)
    img = sam_deal.refresh_image(loaded_data, api_config, path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success": True,
            "msg": "",
            "image": JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success": False,
        "msg": img['reason'],
        "image": None
    }
"""
保存预测分类
"""
@app.get("/ai-station-api/sam_model/save_stage")
async def sam_add_to_class(path: str = None, class_index: int = None):
    """Pass the current prediction to ``sam_deal.add_to_class`` for one class.

    Args:
        path: session directory created by the upload endpoint.
        class_index: 0-based index into the session's ``class_names``.

    Returns:
        The envelope with the refreshed preview, or a failure when the index
        is missing or out of range, when ``add_to_class`` fails or when
        refreshing fails.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    class_names = loaded_data['class_names']
    # Reject missing or out-of-range indices; a negative index would
    # otherwise silently target a class counted from the end of the list.
    if class_index is None or not 0 <= class_index < len(class_names):
        return {"success": False, "msg": "分类索引超出范围", "image": None}
    class_name = class_names[class_index]
    result = sam_deal.add_to_class(api_config, class_name)
    if result['status'] != True:
        return {
            "success": False,
            "msg": result['reason'],
            "image": None
        }
    api_config = result['reason']
    sam_deal.save_model(loaded_data, api_config, path)
    img = sam_deal.refresh_image(loaded_data, api_config, path)
    if img['status'] == True:
        encoded_string = sam_deal.load_tmp_image(img['reason'])
        return {
            "success": True,
            "msg": "",
            "image": JSONResponse(content={"image_data": encoded_string})
        }
    return {
        "success": False,
        "msg": img['reason'],
        "image": None
    }
"""
保存结果
"""
def _collect_labelme_shapes(api_config, base_img):
    """Overlay each class's final mask on a copy of *base_img* and trace its polygons.

    Returns ``(vis_img, shapes)``. ``shapes`` is a list of LabelMe polygon
    dicts; contours that simplify to fewer than 3 points are dropped.
    """
    vis_img = base_img.copy()
    shapes = []
    for class_name, class_data in api_config['class_annotations'].items():
        final_mask = class_data.get('final_mask')
        if final_mask is None:
            continue
        color = api_config['class_colors'].get(class_name, (0, 255, 0))
        # Half-transparent colour overlay for the visualisation image.
        color_mask = np.zeros_like(vis_img)
        color_mask[final_mask > 0] = color
        vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0)
        binary_mask = (final_mask > 0).astype(np.uint8)
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            epsilon = 0.0001 * cv2.arcLength(contour, True)
            approx_contour = cv2.approxPolyDP(contour, epsilon, True)
            points = [[float(pt[0][0]), float(pt[0][1])] for pt in approx_contour]
            if len(points) >= 3:
                shapes.append({
                    "label": class_name,
                    "points": points,
                    "group_id": None,
                    "shape_type": "polygon",
                    "flags": {}
                })
    return vis_img, shapes


def _encode_image_base64(img):
    """Return *img* JPEG-encoded as base64 text, or "" when encoding fails."""
    try:
        is_success, buffer = cv2.imencode(".jpg", img)
    except Exception as e:
        logger.error(f"save_annotation: Could not encode image data: {str(e)}")
        return ""
    if not is_success:
        logger.error("save_annotation: Failed to encode image data")
        return ""
    return base64.b64encode(buffer.tobytes()).decode('utf-8')


def _zip_directory(src_dir, zip_path):
    """Write every file under *src_dir* into *zip_path*, with paths relative to *src_dir*."""
    zip_abs = os.path.abspath(zip_path)
    with zipfile.ZipFile(zip_path, 'w') as zip_file:
        for foldername, _subfolders, filenames in os.walk(src_dir):
            for filename in filenames:
                file_path = os.path.join(foldername, filename)
                if os.path.abspath(file_path) == zip_abs:
                    continue  # never add the archive to itself
                zip_file.write(file_path, os.path.relpath(file_path, src_dir))


@app.get("/ai-station-api/sam_model/save")
async def sam_save_annotation(path: str = None):
    """Export the session's final masks and return them as a ZIP download.

    For the current image, writes the following into
    ``<output_dir>/<image basename>/``:

    - the original image (``.jpg``);
    - a colour-overlay visualisation (``_mask.jpg``);
    - a LabelMe-style JSON file with one polygon per mask contour.

    The whole output directory is then zipped.

    Args:
        path: session directory created by the upload endpoint.

    Returns:
        A ``FileResponse`` with the ZIP archive, or a failure envelope when
        the output directory is unset, no final masks exist or no image info
        is available.

    Raises:
        HTTPException: 500 when the ZIP archive cannot be created.
    """
    loaded_data, api_config = sam_deal.load_model(path)
    if not api_config['output_dir']:
        logger.info("save_annotation: Output directory not set")
        return {
            "success": False,
            "msg": "save_annotation: Output directory not set",
            "data": None
        }
    has_annotations = any(
        class_data.get('final_mask') is not None
        for class_data in api_config['class_annotations'].values()
    )
    if not has_annotations:
        logger.info("save_annotation: No final masks to save")
        return {
            "success": False,
            "msg": "save_annotation: No final masks to save",
            "data": None
        }
    image_info = sam_deal.get_image_info(loaded_data)
    if not image_info:
        logger.info("save_annotation: No image info available")
        return {
            "success": False,
            "msg": "save_annotation: No image info available",
            "data": None
        }
    image_basename = os.path.splitext(image_info['filename'])[0]
    annotation_dir = os.path.join(api_config['output_dir'], image_basename)
    os.makedirs(annotation_dir, exist_ok=True)

    orig_img = loaded_data['image']
    cv2.imwrite(os.path.join(annotation_dir, f"{image_basename}.jpg"), orig_img)
    img_height, img_width = orig_img.shape[:2]

    vis_img, shapes = _collect_labelme_shapes(api_config, orig_img)
    cv2.imwrite(os.path.join(annotation_dir, f"{image_basename}_mask.jpg"), vis_img)

    labelme_data = {
        "version": "5.1.1",
        "flags": {},
        "shapes": shapes,
        "imagePath": f"{image_basename}.jpg",
        "imageData": _encode_image_base64(orig_img),
        "imageHeight": img_height,
        "imageWidth": img_width
    }
    json_path = os.path.join(annotation_dir, f"{image_basename}.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(labelme_data, f, indent=2)
    logger.info(f"save_annotation: Annotation saved to {annotation_dir}")

    # Build the archive inside this session's directory, not the process CWD.
    # A shared "download.zip" let concurrent requests overwrite (and serve)
    # each other's archives.
    zip_filename = "download.zip"
    zip_filepath = os.path.join(path, zip_filename)
    dir_path = os.path.dirname(annotation_dir)
    try:
        _zip_directory(dir_path, zip_filepath)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}")
    return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename)
# @app.post("/ai-station-api/items/")
# async def create_item(item: post_model.Zongbiaomianji):
# try:
# data = {
# "type": type(item).__name__, # 返回类型名称
# "received_data": item.model_dump() # 使用 model_dump() 方法
# }
# print(pd.DataFrame([item.model_dump()]))
# # 返回接收到的数据
# return data
# except Exception as e:
# # 记录错误信息
# logger.error(f"Error occurred: {e}")