# Source: ai-station-code/jiawanyuce/test_xgb.py
# (Header copied from a repository web view: 164 lines, 6.6 KiB, Python.)

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
import os
import joblib
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import JSONResponse
# Directory the service was started from; the dataset and checkpoint paths
# below are resolved relative to it.
current_directory = os.getcwd()
# Folder where uploaded files are saved and from which the endpoints read CSVs.
save_directory = os.path.join(current_directory, 'datasets/jiawan_data')

app = FastAPI()

# Pre-trained XGBoost models, loaded once when the module is imported.
# Flow ("liuliang") model: selected by the predict endpoint when type == 1.
ch4_model_flow = joblib.load(
    os.path.join(current_directory, 'checkpoints/jiawanyuce_liuliang/xgb_model_liuliang.pkl')
)
# Gas-phase concentration ("qixiangnongdu") model: selected for any other type.
ch4_model_gas = joblib.load(
    os.path.join(current_directory, 'checkpoints/jiawanyuce_qixiangnongdu/xgb_model_qixiangnongdu.pkl')
)
def data_feature(data,x_cols):
    """Append lag, difference and rolling-window features for each column in ``x_cols``.

    The naming (1 row = "15", 4 rows = "1h", 8 rows = "2h", 96 rows = "d")
    suggests rows are 15-minute samples. This is an assumption; confirm it with
    the data source.

    Columns are appended in a fixed order. The order matters because the
    prediction endpoint passes the frame to the model as a bare array, without
    column names. For every column ``c`` (first pass):

    * ``c_15_first``, ``c_30_first``, ``c_45_first``, ``c_60_first``: lags of 1-4 rows;
    * ``c_1h_<stat>``: rolling statistics over 4 rows;
    * ``c_1_diff``, ``c_2_diff``: 1- and 2-row differences;
    * ``c_2h_<stat>``: rolling statistics over 8 rows.

    A second pass over all columns then appends ``c_d_<stat>`` (daily
    statistics). ``<stat>`` runs through mean, max, min, median, std, var,
    skew, kurt.

    Args:
        data: DataFrame containing every column named in ``x_cols``. It is
            modified in place.
        x_cols: names of the columns to derive features from.

    Returns:
        ``data`` itself with the new columns appended. Rows without a complete
        window hold NaN in the corresponding features.
    """
    stat_names = ("mean", "max", "min", "median", "std", "var", "skew", "kurt")

    def add_rolling_stats(col, label, window, median_window=None):
        # One column per statistic, named "<col>_<label>_<stat>".
        for stat in stat_names:
            if stat == "median" and median_window is not None:
                size = median_window
            else:
                size = window
            data[col + f"_{label}_{stat}"] = getattr(data[col].rolling(size), stat)()

    for col in x_cols:
        for lag in range(1, 5):
            # Lag k rows -> "<15*k>_first" (15, 30, 45, 60).
            data[col + f"_{15 * lag}_first"] = data[col].shift(lag)
        add_rolling_stats(col, "1h", 4)
        for step in (1, 2):
            data[col + f"_{step}_diff"] = data[col].diff(periods=step)
        add_rolling_stats(col, "2h", 8)

    # Daily statistics come in a second pass, after every column's short-window
    # features. (An older comment said these were no longer wanted because the
    # window is long; they are still produced and become part of the model input.)
    for col in x_cols:
        # NOTE(review): the daily *median* uses a 4-row window (so it equals
        # c_1h_median), while every other c_d_* statistic uses 96 rows. This
        # looks like a copy/paste slip. It is preserved deliberately because the
        # saved models were presumably trained on features built this way.
        # Confirm against the training code before changing it.
        add_rolling_stats(col, "d", 4 * 24, median_window=4)
    return data
def data_result(data,target_1):
    """Append the next four values of ``target_1`` as separate columns.

    Adds ``<target_1>_1_after`` through ``<target_1>_4_after``. The k-th column
    is ``target_1`` shifted k rows upward, so its last k rows are NaN. These
    look like multi-step training labels; confirm against the training script.

    ``data`` is modified in place and also returned.
    """
    for horizon in range(1, 5):
        data[target_1 + f"_{horizon}_after"] = data[target_1].shift(-horizon)
    return data
def get_pred_data(data,start_index,end_index):
    """Build the single feature row used for one prediction.

    Selects the rows whose ``index`` column lies in ``[start_index, end_index]``
    (both ends inclusive). It then derives ``data_feature`` features for the
    model input columns over that window and returns only the last row.

    Args:
        data: the full uploaded DataFrame. It must have an ``index`` column and
            all the model input columns. It is not modified.
        start_index: lower bound (inclusive) on ``data['index']``.
        end_index: upper bound (inclusive) on ``data['index']``.

    Returns:
        A one-row DataFrame holding all original columns (including ``index``)
        plus the generated features. It is empty when no row falls in the window.
        Features whose window is longer than the selection are NaN; the daily
        features need 96 rows.
    """
    in_window = data['index'].between(start_index, end_index)
    # Explicit copy: data_feature adds columns in place, and doing that on a
    # boolean-indexed selection makes pandas emit SettingWithCopyWarning.
    window = data[in_window].copy()
    columns = ['X_ch', 'X_pr', 'X_li', 'X_I', 'Q', 'pH', 'Nm3d-1-ch4', 'S_gas_ch4']
    features = data_feature(window, columns)
    return features.iloc[-1:, :]
# Return the raw contents of an uploaded CH4 CSV for display by the front end.
@app.get("/ch4/get_ori_data")
async def get_ori_data(data_path: str):
    """Return every row of ``save_directory/<data_path>`` as JSON records.

    ``data_path`` comes from the query string and is untrusted. Paths that
    resolve outside ``save_directory`` (``..`` segments, absolute paths,
    symlinks pointing out) are rejected.

    NOTE: ``DataFrame.to_json`` returns a ``str``, and JSONResponse encodes that
    string once more. The response body is therefore a JSON string whose value
    is the records-array text, and clients must decode it twice. This is kept
    unchanged for compatibility with existing clients.

    Raises:
        HTTPException(400): ``data_path`` escapes ``save_directory``, or the CSV
            cannot be read.
    """
    base_dir = os.path.realpath(save_directory)
    try:
        csv_path = os.path.realpath(os.path.join(base_dir, data_path))
        inside = os.path.commonpath([base_dir, csv_path]) == base_dir
    except ValueError:  # embedded NUL byte, or paths on different Windows drives
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="invalid data_path")

    try:
        data = pd.read_csv(csv_path)
        data_json = data.to_json(orient='records')
        return JSONResponse(content=data_json)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
# Upload endpoint: store a file under save_directory; the other endpoints read it back as CSV.
@app.post("/uploadfile/ch4/")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded file into ``save_directory`` and return where it was stored.

    Only the final component of the client-supplied file name is used. Names
    such as ``../../app.py`` or absolute paths therefore cannot write outside
    the dataset folder. An existing file with the same name is overwritten.

    Returns:
        JSON ``{"file_path": ...}`` holding the saved path relative to the
        current working directory.

    Raises:
        HTTPException(400): the upload carries no usable file name.
    """
    # file.filename is client-controlled (and may be None); keep only its base name.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="invalid file name")

    # Create the dataset folder on first use instead of failing with a 500.
    os.makedirs(save_directory, exist_ok=True)
    file_path = os.path.join(save_directory, filename)
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    relative_path = os.path.relpath(file_path, start=os.getcwd())
    return JSONResponse(content={"file_path": relative_path})
# Slice the requested window out of an uploaded CSV and run the CH4 prediction.
def _ch4_window_error(n_rows, start_index, end_index, is_show):
    """Return the user-facing error message for an invalid request window, or "".

    With ``is_show`` the CSV must extend 4 rows past ``end_index`` (so the true
    values after the window can be returned) and hold at least 100 rows.
    Without it the CSV must hold at least 96 rows. In both modes
    ``end_index - start_index`` must be at least 96.
    """
    if is_show:
        if n_rows < end_index + 4:
            return "显示真实值需要保留最终后四节点作为展示信息,请调整结束信息"
        if n_rows < 100:
            return "上传信息长度应大于100"
    elif n_rows < 96:
        return "上传信息长度应大于96"
    if end_index - start_index < 96:
        return "截取步长应该超过96个步长"
    return ""


def _nan_to_none(values):
    """Return ``values`` with float NaN replaced by None (recursively for lists).

    Starlette's JSONResponse, which fastapi.responses re-exports, renders with
    ``allow_nan=False``. A NaN in the payload would otherwise turn a valid
    prediction into an error response.
    """
    import math

    if isinstance(values, list):
        return [_nan_to_none(v) for v in values]
    if isinstance(values, float) and math.isnan(values):
        return None
    return values


@app.get("/ch4/start_predict")
async def start_predict_endpoint(data_path: str, start_index: int, end_index: int, type: int, is_show: bool):
    """Predict CH4 from the rows ``[start_index, end_index]`` of an uploaded CSV.

    Query parameters:
        data_path: CSV file name relative to ``save_directory``. It must not
            resolve outside that folder.
        start_index, end_index: inclusive bounds on the CSV's ``index`` column.
        type: ``1`` selects the flow model (target ``Nm3d-1-ch4``). Any other
            value selects the gas-phase model (target ``S_gas_ch4``).
        is_show: when true, ``true_data`` also covers the 4 rows after ``end_index``.

    Returns:
        JSON ``{"error_info", "pred_data", "true_data"}``. An invalid window
        gives a non-empty ``error_info`` and null data, still with HTTP 200. NaN
        values in the data are sent as null.

    Raises:
        HTTPException(400): bad ``data_path``, unreadable CSV, or any failure
            while building features or predicting.
    """
    # data_path is untrusted query input: refuse anything resolving outside save_directory.
    base_dir = os.path.realpath(save_directory)
    try:
        csv_path = os.path.realpath(os.path.join(base_dir, data_path))
        inside = os.path.commonpath([base_dir, csv_path]) == base_dir
    except ValueError:  # embedded NUL byte, or paths on different Windows drives
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="invalid data_path")

    try:
        data = pd.read_csv(csv_path)
        message = _ch4_window_error(len(data), start_index, end_index, is_show)
        if message:
            return JSONResponse(content={"error_info": message, "pred_data": None, "true_data": None})

        features = get_pred_data(data, start_index, end_index).drop(columns='index')
        # Passed without column names, so the column order produced by
        # get_pred_data must match what the models were trained on.
        dmatrix = xgb.DMatrix(features.to_numpy())
        if type == 1:  # flow
            target, model = "Nm3d-1-ch4", ch4_model_flow
        else:  # gas phase
            target, model = "S_gas_ch4", ch4_model_gas
        result = model.predict(dmatrix)

        last_index = end_index + 4 if is_show else end_index
        in_range = (data['index'] >= start_index) & (data['index'] <= last_index)
        history = data.loc[in_range, target]
        return JSONResponse(content={
            "error_info": "",
            "pred_data": _nan_to_none(result[0].tolist()),
            "true_data": _nan_to_none(history.tolist()),
        })
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
if __name__ == '__main__':
    # Running the file directly serves the API on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")

    # One-off dataset preparation, kept for reference. The endpoints filter on
    # an 'index' column; this is how a 1-based one was added to a CSV:
    # data = pd.read_csv(os.path.join(save_directory,'jiawan_test.csv'))
    # data['index'] = range(1, len(data) + 1)
    # data.to_csv(os.path.join(save_directory,'jiawan_test.csv'),index=False)