# ai-station-code/work_util/model_deal.py
from PIL import Image
from work_util import prepare_data, params
from work_util.logger import logger
import os, sys
import xgboost as xgb
import uuid
from collections import defaultdict
import copy
import torch
import numpy as np
import pandas as pd
import cv2 as cv
import cv2
import torch.nn.functional as F
from joblib import dump, load
from fastapi import HTTPException
import pickle
import traceback

param = params.ModelParams()

################################################### Image-processing helpers ###########################################################################
# Create a random working directory so concurrent file processing does not collide.
def create_tmp_path(path):
    folder_name = str(uuid.uuid4())
    folder_path = os.path.join(path, 'tmp', folder_name)
    os.makedirs(folder_path)
    return folder_path
# Walk an image directory, load every image, and return them with their sizes.
def get_images(path):
    imglist = []
    lines = os.listdir(path)
    for ii, line in enumerate(lines):
        name = line
        _imagepath = os.path.join(path, name)
        assert os.path.isfile(_imagepath)
        image = Image.open(_imagepath)
        # NOTE: the misspelled "orininal_*" names are kept for compatibility with existing callers.
        orininal_h = image.size[1]
        orininal_w = image.size[0]
        item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
        imglist.append(item)
    return imglist
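# Minimal usage sketch (the "./tiles" directory is an assumption -- e.g. the tile
# folder produced by prepare_data.cut_big_image):
#     tiles = get_images("./tiles")
#     for t in tiles:
#         print(t["name"], t["orininal_w"], t["orininal_h"])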
# Landform recognition: run the model on a single image.
def dimaoshibie_cal(model, path, count=None, name_classes=None):
    image = Image.open(path)
    r_image, count_dict, classes_nums = model.detect_image(image, count=count, name_classes=name_classes)
    # Per-class pixel counts (eleven classes).
    logger.info(count_dict)
    # Total pixel counts for the eleven classes.
    logger.info(classes_nums)
    return r_image, count_dict, classes_nums
"""
1. 图片分割成规定大小
2. 遍历分割图片,进行预测
3. 图片合并成大图
4. 删除分割的小图片
"""
def dimaoshibie_pic(model,path,count=None, name_classes=None):
single = prepare_data.cut_big_image(path)
if single == False:
return {
"success": False,
"msg": "图片大小不合规请检查图片大小是否符合尺寸大于512*512文件大小小于10MB",
"data": None
}
else:
file_directory = os.path.dirname(path)
folder_path = os.path.join(file_directory,'ori')
lines = os.listdir(folder_path)
# 创建小图结果目录
output_floder_path = os.path.join(file_directory,'binary')
# 创建输出文件夹(如果不存在)
os.makedirs(output_floder_path, exist_ok=True)
total_piex = {
"_background_" : 0, # Background (黑色)
"Cropland" : 0, # Cropland (淡黄色)
"Forest": 0, # Forest (深绿色)
"Grass": 0, # Grass (浅绿色)
"Shrub": 0, # Shrub (浅蓝绿色)
"Wetland": 0, # Wetland (浅蓝色)
"Water": 0, # Water (深蓝色)
"Tundra": 0, # Tundra (土黄色)
"Impervious surface": 0, # Impervious surface (红色)
"Bareland": 0, # Bareland (灰色)
"Ice/snow": 0 # Ice/snow (浅天蓝色)
}
for ii, line in enumerate(lines):
_imagepath = os.path.join(folder_path, line)
r_image, count_dict, classes_nums = dimaoshibie_cal(model,_imagepath,count, name_classes)
final_path = os.path.join(output_floder_path,line)
# 保存融合后的图像
r_image.save(final_path) # 替换为你想保存的路径
for key in count_dict:
if key in total_piex:
total_piex[key] += count_dict[key]
# 将像素分类,存储
# 创建一个新的字典,用于存储中文键
translated_total_piex = {}
for key, value in total_piex.items():
if key in param.dmsb_type:
translated_total_piex[param.dmsb_type[key]] = value
output_file_path = os.path.join(file_directory,'result.txt')
# 写入字典到文本文件
with open(output_file_path, 'w') as file:
for key, value in translated_total_piex.items():
file.write(f"{key}: {value}\n")
# 图片合并
status = prepare_data.merge_pic_binary(path,['binary'])
if status['status'] == False:
return {
"success": False,
"msg": "图片合并失败,请稍后再试",
"data": None
}
target_path = status['path']
# 临时文件删除
tmp = prepare_data.delete_folder(folder_path)
if tmp==False:
return {
"success": False,
"msg": "临时文件删除失败",
"data": None
}
tmp = prepare_data.delete_folder(output_floder_path)
if tmp==False:
return {
"success": False,
"msg": "临时文件删除失败",
"data": None
}
# 返回图片地址
return {"status":True, "reason":target_path}
"""
计算每一类地貌的面积大小
1、读取前面的图片分割结果
2、计算每一类的图像像素点数量
3、计算每一类的像素点面积
"""
def dimaoshibie_area(path,scale,colors=None):
file_directory = os.path.dirname(path)
file_name = os.path.basename(path)
file_path = os.path.join(file_directory,"merge_binary_" + file_name)
if not os.path.exists(file_path):
return {"status": False, "reason": "没有找到对应文件,请先进行图像分割", "result":""}
image = Image.open(path)
try:
# 如果文件存在,读取图片
# logger.info(file_path)
image = Image.open(file_path)
image = image.convert('RGB')
color_count = defaultdict(int)
# 遍历图像中的每个像素
for x in range(image.width):
for y in range(image.height):
pixel_color = image.getpixel((x, y))
if pixel_color in colors:
color_count[pixel_color] += 1
else:
color_count[pixel_color] += 1
for k,v in color_count.items():
color_count[k] = v * scale
# logger.info(color_count)
result = {color: color_count[color] for color in colors}
return {"status": True, "reason": result}
except Exception as e:
return {"status": False, "reason": str(e)}
"""
1. 图片分割成规定大小
2. 遍历分割图片,进行预测
3. 图片合并成大图
4. 删除分割的小图片
"""
def roof_pic(net,path,palette):
single = prepare_data.cut_big_image(path)
logger.info("图片分割完成")
if single == False:
return {"status":False, "reason":"图片大小不合规请检查图片大小是否符合尺寸大于512*512文件大小小于10MB"}
else:
file_directory = os.path.dirname(path)
folder_path = os.path.join(file_directory,'ori')
# 创建小图结果目录
output_floder_path = os.path.join(file_directory,'binary')
# 创建输出文件夹(如果不存在)
os.makedirs(output_floder_path, exist_ok=True)
imglist = get_images(folder_path)
assert len(imglist) != 0
for i in imglist:
image = i["image"]
name = i["name"]
# orininal_w = i["orininal_w"]
# orininal_h = i["orininal_h"]
# old_img = copy.deepcopy(image)
imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
image_data = np.expand_dims(
np.transpose(prepare_data.roof_pv_preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
image = images.to(device="cuda", dtype=torch.float32)
model = net.to(device="cuda")
out = model(image) # batch_size, 2, 512, 512
if isinstance(out, list) or isinstance(out, tuple): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个)
out = out[0] # 就取第一个
out = out[0] # 去掉batch
pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
result = pr.argmax(axis=-1)
# 结果图
output_im = Image.fromarray(np.uint8(result)).convert('P')
output_im.putpalette(palette)
tmp_path_binary= os.path.join(output_floder_path,name)
output_im.save(tmp_path_binary)
# 图片合并
status = prepare_data.merge_pic_binary(path,['binary'])
if status['status'] == False:
return {"status":False, "reason":"图片合并失败"}
target_path = status['path']
# 临时文件删除
tmp = prepare_data.delete_folder(folder_path)
if tmp==False:
return {"status":False, "reason":"临时文件删除失败"}
tmp = prepare_data.delete_folder(output_floder_path)
if tmp==False:
return {"status":False, "reason":"临时文件删除失败"}
# 返回图片地址
return {"status":True, "reason":target_path}
"""
计算每一类地貌的面积大小
1、读取前面的图片分割结果
2、计算每一类的图像像素点数量
3、计算每一类的像素点面积
"""
def roof_area(path,scale,colors=None):
file_directory = os.path.dirname(path)
file_name = os.path.basename(path)
file_path = os.path.join(file_directory,"merge_binary_" + file_name)
if not os.path.exists(file_path):
return {"status": False, "reason": "没有找到对应文件,请先进行图像分割", "result":""}
image = Image.open(path)
try:
# 如果文件存在,读取图片
image = Image.open(file_path)
image = image.convert('RGB')
color_count = defaultdict(int)
# 遍历图像中的每个像素
for x in range(image.width):
for y in range(image.height):
pixel_color = image.getpixel((x, y))
if pixel_color in colors:
color_count[pixel_color] += 1
for k,v in color_count.items():
color_count[k] = v * scale
result = {color: color_count[color] for color in colors}
return {"status": True, "reason":result}
except Exception as e:
return {"status": False, "reason": str(e)}
def roof_area_roofpv(path,scale,colors=None):
file_directory = os.path.dirname(path)
file_name = os.path.basename(path)
file_path = os.path.join(file_directory,"merge_roofpv_binary_"+file_name)
if not os.path.exists(file_path):
return {"status": False, "reason": "没有找到对应文件,请先进行图像分割", "result":""}
image = Image.open(path)
try:
# 如果文件存在,读取图片
image = Image.open(file_path)
image = image.convert('RGB')
color_count = defaultdict(int)
# 遍历图像中的每个像素
for x in range(image.width):
for y in range(image.height):
pixel_color = image.getpixel((x, y))
if pixel_color in colors:
color_count[pixel_color] += 1
for k,v in color_count.items():
color_count[k] = v * scale
result = {color: color_count[color] for color in colors}
return {"status": True, "reason":result}
except Exception as e:
return {"status": False, "reason": str(e)}
"""
1. 图片分割成规定大小
2. 遍历分割图片,进行预测
3. 图片合并成大图
4. 删除分割的小图片
"""
def roofpv_pic(net_roof,net_pv,path,palette):
single = prepare_data.cut_big_image(path)
logger.info("图片分割完成")
if single == False:
return {"status":False, "reason":"图片大小不合规请检查图片大小是否符合尺寸大于512*512文件大小小于10MB"}
else:
file_directory = os.path.dirname(path)
folder_path = os.path.join(file_directory,'ori')
# 创建小图结果目录
output_floder_path_roof = os.path.join(file_directory, 'roof_binary')
output_floder_path_pv = os.path.join(file_directory, 'pv_binary')
# 创建输出文件夹(如果不存在)
os.makedirs(output_floder_path_roof, exist_ok=True)
os.makedirs(output_floder_path_pv, exist_ok=True)
imglist = get_images(folder_path)
assert len(imglist) != 0
for i in imglist:
image = i["image"]
name = i["name"]
# orininal_w = i["orininal_w"]
# orininal_h = i["orininal_h"]
# old_img = copy.deepcopy(image)
imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
image_data = np.expand_dims(
np.transpose(prepare_data.roof_pv_preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
image = images.to(device="cuda", dtype=torch.float32)
model = net_roof.to(device="cuda")
out = model(image) # batch_size, 2, 512, 512
if isinstance(out, list) or isinstance(out, tuple): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个)
out = out[0] # 就取第一个
out = out[0] # 去掉batch
pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
result = pr.argmax(axis=-1)
# 结果图
output_im = Image.fromarray(np.uint8(result)).convert('P')
output_im.putpalette(palette)
tmp_path_binary= os.path.join(output_floder_path_roof,name)
output_im.save(tmp_path_binary)
with torch.no_grad():
images = torch.from_numpy(image_data)
image = images.to(device="cuda", dtype=torch.float32)
model = net_pv.to(device="cuda")
out = model(image) # batch_size, 2, 512, 512
if isinstance(out, list) or isinstance(out, tuple): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个)
out = out[0] # 就取第一个
out = out[0] # 去掉batch
pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
result = pr.argmax(axis=-1)
# 结果图
output_im = Image.fromarray(np.uint8(result)).convert('P')
output_im.putpalette(palette)
tmp_path_binary= os.path.join(output_floder_path_pv,name)
output_im.save(tmp_path_binary)
# 图片合并
sataus = prepare_data.merge_pic_binary(path,['roof_binary','pv_binary'])
if sataus['status'] == False:
return {"status":False, "reason":"图片合并失败"}
target_path = sataus['path']
# 临时文件删除
tmp = prepare_data.delete_folder(folder_path)
if tmp==False:
return {"status":False, "reason":"临时文件删除失败"}
tmp = prepare_data.delete_folder(output_floder_path_roof)
if tmp==False:
return {"status":False, "reason":"临时文件删除失败"}
tmp = prepare_data.delete_folder(output_floder_path_pv)
if tmp==False:
return {"status":False, "reason":"临时文件删除失败"}
# 返回图片地址
return {"status":True, "reason":target_path}
# Methane (CH4) prediction - forecasts the next four 15-minute points.
def start_predict_endpoint(ch4_model_flow, ch4_model_gas, data_path, start_index, end_index, type, is_show):
    # NOTE: `type` shadows the builtin but is kept for API compatibility.
    try:
        data = pd.read_csv(data_path)
        data['date_time'] = pd.to_datetime(data['date_time'])
        max_date = data['date_time'].max()
        min_date = data['date_time'].min()
        start_index = pd.to_datetime(start_index)
        end_index = pd.to_datetime(end_index)
        if max_date < end_index:
            return {"reason": "Invalid end date: it is later than the newest date in the uploaded data", "status": False}
        if min_date > start_index:
            return {"reason": "Invalid start date: it is earlier than the oldest date in the uploaded data", "status": False}
        end_index_dt = pd.to_datetime(end_index)
        end_index_plus_one_hour = end_index_dt + pd.Timedelta(hours=1)
        filtered_data = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index)]
        if is_show:
            if len(filtered_data) < 96:
                return {"reason": "Invalid date range: the selected window must cover more than 96 steps", "status": False}
            if max_date < end_index_plus_one_hour:
                return {"reason": "Showing true values requires the last four points to be held out for display; please move the end date earlier", "status": False}
        else:
            if len(filtered_data) < 96:
                return {"reason": "The uploaded file must contain more than 96 valid rows", "status": False}
        train_data = prepare_data.get_pred_data(data, start_index, end_index)
        del train_data['date_time']
        train_data = np.array(train_data.values)
        train_data = xgb.DMatrix(train_data)
        target = None
        if type == 1:  # flow rate
            target = "Nm3d-1-ch4"
            result = ch4_model_flow.predict(train_data)
        else:  # gas phase
            target = "S_gas_ch4"
            result = ch4_model_gas.predict(train_data)
        if is_show:
            """Two record lists: history plus true values, and the prediction."""
            history = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index_plus_one_hour)]
            cols = ['date_time']
            cols.append(target)
            history = history[cols]
            # The last four rows of `history` are the held-out true values, so the
            # predictions are stamped onto the same four timestamps.
            last_date = history['date_time'].iloc[-5]
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(4)]
            new_data = pd.DataFrame({
                'date_time': new_dates,
                target: result[0]
            })
        else:
            history = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index)]
            history.reset_index(drop=True, inplace=True)
            cols = ['date_time']
            cols.append(target)
            history = history[cols]
            last_date = history['date_time'].iloc[-1]
            # Build the dates and values for the four predicted points.
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(4)]
            new_data = pd.DataFrame({
                'date_time': new_dates,
                target: result[0]
            })
        return {"status": True, "reason": [history.to_dict(orient='records'), new_data.to_dict(orient='records')]}
    except Exception as e:
        return {"reason": str(e), "status": False}
# Photovoltaic output forecast - prediction horizon of 96 points (one day).
def start_pvelectric_predict_endpoint(pvfd_model, data_path, start_index, end_index, is_show):
    try:
        data = pd.read_csv(data_path)
        data['date_time'] = pd.to_datetime(data['date_time'])
        max_date = data['date_time'].max()
        min_date = data['date_time'].min()
        start_index = pd.to_datetime(start_index)
        end_index = pd.to_datetime(end_index)
        if max_date < end_index:
            return {"reason": "Invalid end date: it is later than the newest date in the uploaded data", "status": False}
        if min_date > start_index:
            return {"reason": "Invalid start date: it is earlier than the oldest date in the uploaded data", "status": False}
        end_index_dt = pd.to_datetime(end_index)
        end_index_plus_one_day = end_index_dt + pd.Timedelta(hours=24)
        filtered_data = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index)]
        if is_show:
            if len(filtered_data) < 192:
                return {"reason": "Invalid date range: the selected window must cover more than 192 steps", "status": False}
            if max_date < end_index_plus_one_day:
                return {"reason": "Showing true values requires the final day (96 points) to be held out for display; please move the end date earlier", "status": False}
        else:
            if len(filtered_data) < 192:
                return {"reason": "The uploaded file must contain more than 192 valid rows", "status": False}
        predictions = pvfd_model.run_inference(filtered_data)
        if predictions['status'] != True:
            # Propagate inference failures instead of falling through silently.
            return {"status": False, "reason": predictions['reason']}
        predictions_value = np.array(predictions['reason']).flatten()
        predictions_value = [max(0, x) for x in predictions_value]
        target = "power"
        if is_show:
            history = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index_plus_one_day)]
            cols = ['date_time']
            cols.append(target)
            history = history[cols]
            # The last 96 rows of `history` are the held-out true values, so the
            # predictions are stamped onto the same 96 timestamps.
            last_date = history['date_time'].iloc[-97]
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(96)]
            new_data = pd.DataFrame({
                'date_time': new_dates,
                target: predictions_value
            })
        else:
            history = data[(data['date_time'] >= start_index) & (data['date_time'] <= end_index)]
            history.reset_index(drop=True, inplace=True)
            cols = ['date_time']
            cols.append(target)
            history = history[cols]
            last_date = history['date_time'].iloc[-1]
            # Build the dates and values for the 96 predicted points.
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(96)]
            new_data = pd.DataFrame({
                'date_time': new_dates,
                target: predictions_value
            })
        return {"status": True, "reason": [history.to_dict(orient='records'), new_data.to_dict(orient='records')]}
    except Exception as e:
        return {"reason": str(e), "status": False}
# Wind power output forecast - prediction horizon of 12 points (3 hours).
def start_wind_electric_predict_endpoint(windfd_model, data_path, start_index, end_index, is_show):
    try:
        data = pd.read_csv(data_path)
        data['date'] = pd.to_datetime(data['date'])
        max_date = data['date'].max()
        min_date = data['date'].min()
        start_index = pd.to_datetime(start_index)
        end_index = pd.to_datetime(end_index)
        if max_date < end_index:
            return {"reason": "Invalid end date: it is later than the newest date in the uploaded data", "status": False}
        if min_date > start_index:
            return {"reason": "Invalid start date: it is earlier than the oldest date in the uploaded data", "status": False}
        end_index_dt = pd.to_datetime(end_index)
        end_index_plus_three_hours = end_index_dt + pd.Timedelta(hours=3)
        filtered_data = data[(data['date'] >= start_index) & (data['date'] <= end_index)]
        if is_show:
            if len(filtered_data) < 192:
                return {"reason": "Invalid date range: the selected window must cover more than 192 steps", "status": False}
            if max_date < end_index_plus_three_hours:
                return {"reason": "Showing true values requires the final 12 points to be held out for display; please move the end date earlier", "status": False}
        else:
            if len(filtered_data) < 192:
                return {"reason": "The uploaded file must contain more than 192 valid rows", "status": False}
        predictions = windfd_model.run_inference(filtered_data)
        if predictions['status'] != True:
            # Propagate inference failures instead of falling through silently.
            return {"status": False, "reason": predictions['reason']}
        predictions_value = np.array(predictions['reason']).flatten()
        predictions_value = [max(0, x) for x in predictions_value]
        target = "Power"
        if is_show:
            history = data[(data['date'] >= start_index) & (data['date'] <= end_index_plus_three_hours)]
            cols = ['date']
            cols.append(target)
            history = history[cols]
            # The last 12 rows of `history` are the held-out true values, so the
            # predictions are stamped onto the same 12 timestamps.
            last_date = history['date'].iloc[-13]
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(12)]
            new_data = pd.DataFrame({
                'date': new_dates,
                target: predictions_value
            })
        else:
            history = data[(data['date'] >= start_index) & (data['date'] <= end_index)]
            history.reset_index(drop=True, inplace=True)
            cols = ['date']
            cols.append(target)
            history = history[cols]
            last_date = history['date'].iloc[-1]
            # Build the dates and values for the 12 predicted points.
            new_dates = [last_date + pd.Timedelta(minutes=15 * (i + 1)) for i in range(12)]
            new_data = pd.DataFrame({
                'date': new_dates,
                target: predictions_value
            })
        history.rename(columns={target: 'power'}, inplace=True)
        new_data.rename(columns={target: 'power'}, inplace=True)
        return {"status": True, "reason": [history.to_dict(orient='records'), new_data.to_dict(orient='records')]}
    except Exception as e:
        return {"reason": str(e), "status": False}
#========================================== Coal pyrolysis ====================================================
def pred_single_tar(test_content):
    tar_pred = []
    current_directory = os.getcwd()
    for name in param.meirejie_model_list_tar:
        model_path = os.path.join(current_directory, "meirejie", param.meirejie_model_dict[name])
        tar_model = load(model_path)
        pred = tar_model.predict(test_content)
        tar_pred.append(pred[0])
    result = [param.meirejie_tar_mae, param.meirejie_tar_r2, tar_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns)
    result = df.to_json(orient='index')
    return result

def pred_single_gas(test_content):
    gas_pred = []
    current_directory = os.getcwd()
    for name in param.meirejie_model_list_gas:
        model_path = os.path.join(current_directory, "meirejie", param.meirejie_model_dict[name])
        gas_model = load(model_path)
        pred = gas_model.predict(test_content)
        gas_pred.append(pred[0])
    result = [param.meirejie_gas_mae, param.meirejie_gas_r2, gas_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns)
    result = df.to_json(orient='index')
    return result

def pred_single_water(test_content):
    water_pred = []
    current_directory = os.getcwd()
    for name in param.meirejie_model_list_water:
        model_path = os.path.join(current_directory, "meirejie", param.meirejie_model_dict[name])
        water_model = load(model_path)
        pred = water_model.predict(test_content)
        water_pred.append(pred[0])
    result = [param.meirejie_water_mae, param.meirejie_water_r2, water_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns)
    result = df.to_json(orient='index')
    return result

def pred_single_char(test_content):
    char_pred = []
    current_directory = os.getcwd()
    for name in param.meirejie_model_list_char:
        model_path = os.path.join(current_directory, "meirejie", param.meirejie_model_dict[name])
        char_model = load(model_path)
        pred = char_model.predict(test_content)
        char_pred.append(pred[0])
    result = [param.meirejie_char_mae, param.meirejie_char_r2, char_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns)
    result = df.to_json(orient='index')
    return result
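# The four pred_single_* functions above differ only in their model list and
# metric constants. A shared helper could look like this (a sketch, not wired in;
# the param attributes are the ones already used above):
def _pred_single(model_names, mae, r2, test_content):
    preds = []
    current_directory = os.getcwd()
    for name in model_names:
        # Load each model and keep the first prediction, as the functions above do.
        model = load(os.path.join(current_directory, "meirejie", param.meirejie_model_dict[name]))
        preds.append(model.predict(test_content)[0])
    df = pd.DataFrame([mae, r2, preds], index=param.index, columns=param.columns)
    return df.to_json(orient='index')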
def choose_model_meirejie(name, data):
    current_directory = os.getcwd()
    model_path = os.path.join(current_directory, 'meirejie', param.meirejie_model_dict[name])
    model = load(model_path)
    pred = model.predict(data)
    return pred

def get_excel_tar(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T", "Tar"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['Tar']
    pred = choose_model_meirejie(model_name, test_data)
    test_data['tar_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_gas(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T", "Gas"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['Gas']
    pred = choose_model_meirejie(model_name, test_data)
    test_data['gas_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_char(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T", "Char"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['Char']
    pred = choose_model_meirejie(model_name, test_data)
    test_data['char_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_water(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "V", "FC", "C", "H", "N", "S", "O", "H/C", "O/C", "N/C", "Rt", "Hr", "dp", "T", "Water"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['Water']
    pred = choose_model_meirejie(model_name, test_data)
    test_data['water_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}
#========================================== Coal-based carbon materials ====================================================
def pred_single_ssa(test_content):
    ssa_pred = []
    current_directory = os.getcwd()
    for name in param.meijitancailiao_model_list_ssa:
        model_path = os.path.join(current_directory, "meijitancailiao", param.meijitancailiao_model_dict[name])
        ssa_model = load(model_path)
        pred = ssa_model.predict(test_content)
        ssa_pred.append(pred[0])
    result = [param.meijitancailiao_ssa_mae, param.meijitancailiao_ssa_r2, ssa_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns_ssa)
    result = df.to_json(orient='index')
    return result

def pred_single_tpv(test_content):
    tpv_pred = []
    current_directory = os.getcwd()
    for name in param.meijitancailiao_model_list_tpv:
        model_path = os.path.join(current_directory, "meijitancailiao", param.meijitancailiao_model_dict[name])
        tpv_model = load(model_path)
        pred = tpv_model.predict(test_content)
        tpv_pred.append(pred[0])
    result = [param.meijitancailiao_tpv_mae, param.meijitancailiao_tpv_r2, tpv_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns_tpv)
    result = df.to_json(orient='index')
    return result

def pred_single_meitan(test_content):
    meitan_pred = []
    current_directory = os.getcwd()
    for name in param.meijitancailiao_model_list_meitan:
        model_path = os.path.join(current_directory, "meijitancailiao", param.meijitancailiao_model_dict[name])
        meitan_model = load(model_path)
        pred = meitan_model.predict(test_content)
        meitan_pred.append(pred[0])
    result = [param.meijitancailiao_meitan_mae, param.meijitancailiao_meitan_r2, meitan_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns_meitan)
    result = df.to_json(orient='index')
    return result

def pred_single_meiliqing(test_content):
    meiliqing_pred = []
    current_directory = os.getcwd()
    for name in param.meijitancailiao_model_list_meiliqing:
        model_path = os.path.join(current_directory, "meijitancailiao", param.meijitancailiao_model_dict[name])
        meiliqing_model = load(model_path)
        pred = meiliqing_model.predict(test_content)
        meiliqing_pred.append(pred[0])
    result = [param.meijitancailiao_meiliqing_mae, param.meijitancailiao_meiliqing_r2, meiliqing_pred]
    # Build the result DataFrame.
    df = pd.DataFrame(result, index=param.index, columns=param.columns_meiliqing)
    result = df.to_json(orient='index')
    return result

def choose_model_meijitancailiao(name, data):
    current_directory = os.getcwd()
    model_path = os.path.join(current_directory, 'meijitancailiao', param.meijitancailiao_model_dict[name])
    model = load(model_path)
    pred = model.predict(data)
    return pred
def get_excel_ssa(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "SSA"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['SSA']
    pred = choose_model_meijitancailiao(model_name, test_data)
    test_data['ssa_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_tpv(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "TPV"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['TPV']
    pred = choose_model_meijitancailiao(model_name, test_data)
    test_data['tpv_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_meitan(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J", "C"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['C']
    pred = choose_model_meijitancailiao(model_name, test_data)
    test_data['C_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}

def get_excel_meiliqing(model_name, file_path):
    if not file_path.endswith('.csv'):
        return {"status": False, "reason": "Invalid upload: the file must be a CSV"}
    test_data = pd.read_csv(file_path)
    expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J", "C"]
    if list(test_data.columns) != expected_columns:
        return {"status": False, "reason": f"Column mismatch; the expected columns are: {expected_columns}"}
    del test_data['C']
    pred = choose_model_meijitancailiao(model_name, test_data)
    test_data['C_pred'] = pred
    return {"status": True, "reason": test_data.to_dict(orient='records')}
def pred_func(func_name, pred_data):
    current_script_directory = os.getcwd()
    ssa_model_path = os.path.join(current_script_directory, "meijitancailiao", param.simu_model_dict[func_name][0])
    tpv_model_path = os.path.join(current_script_directory, "meijitancailiao", param.simu_model_dict[func_name][1])
    ssa_model = load(ssa_model_path)
    tpv_model = load(tpv_model_path)
    # Predict with both models.
    pred_data = pd.DataFrame(pred_data, columns=['A', 'VM', 'K/C', 'MM', 'AT', 'At', 'Rt'])
    pred_ssa = ssa_model.predict(pred_data)
    pred_tpv = tpv_model.predict(pred_data)
    result = pd.DataFrame({
        "SSA": pred_ssa,
        "TPV": pred_tpv
    })
    result = pd.concat([pred_data, result], axis=1)
    return result
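# Usage sketch (func_name must be a key of param.simu_model_dict; the feature row
# below is illustrative only):
#     row = [[10.5, 30.2, 0.8, 5.0, 800, 60, 1.2]]  # A, VM, K/C, MM, AT, At, Rt
#     df = pred_func("some_model", row)             # returns the features plus SSA/TPV columns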
def get_pic_url(path):
    # 1. Strip the local root directory `/root/app`.
    logger.info(path)
    relative_path = path.replace("/root/app", "", 1)
    # 2. Prepend the server address and port.
    server = "http://124.16.151.196:13432"
    # 3. Assemble the full URL.
    url = f"{server}/files{relative_path}"
    return url

def get_pic_path(url):
    # 1. Strip the protocol and server address.
    path_without_server = url.replace("http://124.16.151.196:13432", "")
    # 2. Strip the /files prefix.
    relative_path = path_without_server.replace("/files", "", 1)
    # 3. Prepend the local root directory.
    local_path = f"/root/app{relative_path}"
    return local_path
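# Round-trip sketch (the path below is illustrative):
#     url = get_pic_url("/root/app/data/upload/scene.png")
#     # -> "http://124.16.151.196:13432/files/data/upload/scene.png"
#     assert get_pic_path(url) == "/root/app/data/upload/scene.png"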