import json
import os

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import JSONResponse
from sklearn.model_selection import train_test_split
|
|
|
# Current working directory; all dataset/checkpoint paths below are relative to it,
# so the service must be launched from the project root.
current_directory = os.getcwd()

# Directory where uploaded / served CH4 (methane) CSV datasets are stored.
save_directory = os.path.join(current_directory,'datasets/jiawan_data')

app = FastAPI()

# Pre-trained XGBoost regressors, loaded once at startup.
# Flow model: used when type == 1 to predict 'Nm3d-1-ch4' (see start_predict_endpoint).
ch4_model_flow = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_liuliang/xgb_model_liuliang.pkl'))

# Gas-phase model: used otherwise to predict 'S_gas_ch4'.
ch4_model_gas = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_qixiangnongdu/xgb_model_qixiangnongdu.pkl'))
|
|
|
|
def data_feature(data, x_cols):
    """Add lag, difference and rolling-window statistics for every column in
    *x_cols*, in place, and return the same DataFrame.

    Rows appear to be 15-minute samples (lag columns are named 15/30/45/60),
    so 4 rows = 1 hour, 8 rows = 2 hours and 4*24 = 96 rows = 1 day.

    Column creation order is part of the contract: the downstream XGBoost
    models consume the feature matrix positionally, so it must not change.
    """
    stat_names = ("mean", "max", "min", "median", "std", "var", "skew", "kurt")

    for col in x_cols:
        series = data[col]

        # Lags of 1..4 steps (15/30/45/60 minutes back).
        for steps, minutes in enumerate(("15", "30", "45", "60"), start=1):
            data[col + "_" + minutes + "_first"] = series.shift(steps)

        # 1-hour rolling statistics (window of 4 samples).
        hour_window = series.rolling(4)
        for stat in stat_names:
            data[col + "_1h_" + stat] = getattr(hour_window, stat)()

        # First and second differences.
        data[col + "_1_diff"] = series.diff(periods=1)
        data[col + "_2_diff"] = series.diff(periods=2)

        # 2-hour rolling statistics (window of 8 samples).
        two_hour_window = series.rolling(8)
        for stat in stat_names:
            data[col + "_2h_" + stat] = getattr(two_hour_window, stat)()

    # Daily (4*24-sample) rolling statistics.
    for col in x_cols:
        series = data[col]
        day_window = series.rolling(4 * 24)
        for stat in stat_names:
            if stat == "median":
                # NOTE(review): the daily median uses a 4-sample window, unlike
                # its siblings (4*24). This looks like a copy-paste slip, but
                # the deployed models were trained on this exact feature, so it
                # is reproduced as-is — do not "fix" without retraining.
                data[col + "_d_median"] = series.rolling(4).median()
            else:
                data[col + "_d_" + stat] = getattr(day_window, stat)()

    return data
|
|
|
|
|
|
def data_result(data, target_1):
    """Append future-label columns for *target_1*, in place, and return the
    same DataFrame.

    Creates '<target_1>_<k>_after' for k in 1..4, each being the target
    shifted k steps into the future; the last k rows of each are NaN.
    """
    for horizon in (1, 2, 3, 4):
        data["%s_%d_after" % (target_1, horizon)] = data[target_1].shift(-horizon)
    return data
|
|
|
|
|
|
|
|
def get_pred_data(data, start_index, end_index):
    """Build the single feature row the models score for a requested window.

    Filters rows whose 'index' column lies in [start_index, end_index]
    (inclusive on both ends), runs the full feature engineering over that
    slice, and returns only the last (most recent) row as a one-row DataFrame.
    """
    # .copy() is required: the boolean-mask selection returns a view-like
    # slice, and data_feature() adds columns in place — without the copy,
    # pandas raises SettingWithCopyWarning and the writes may silently
    # not stick (or mutate the caller's frame).
    filtered_data = data[(data['index'] >= start_index) & (data['index'] <= end_index)].copy()
    # Raw input columns; feature order must match what the models were trained on.
    columns = ['X_ch','X_pr','X_li','X_I','Q','pH','Nm3d-1-ch4','S_gas_ch4']
    featured = data_feature(filtered_data, columns)
    return featured.iloc[-1:, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fetch a methane dataset for front-end display.
@app.get("/ch4/get_ori_data")
async def get_ori_data(data_path: str):
    """Return the CSV at *data_path* (joined under save_directory) as a JSON
    array of row-record objects.

    Any failure (missing file, parse error, ...) is surfaced as HTTP 400 with
    the underlying error message.

    NOTE(review): data_path is joined unchecked, so '../' can escape
    save_directory — consider restricting to a base name like the upload
    endpoint does; confirm what paths the front end actually sends.
    """
    try:
        data = pd.read_csv(os.path.join(save_directory, data_path))
        # to_json() handles NaN -> null; parse the string back so JSONResponse
        # sends a real JSON array. Passing the string directly (as the code
        # previously did) made JSONResponse re-serialize it, delivering a
        # double-encoded JSON string to the client.
        records = json.loads(data.to_json(orient='records'))
        return JSONResponse(content=records)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
|
|
# Dataset upload endpoint.
@app.post("/uploadfile/ch4/")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file in save_directory and return its path relative
    to the current working directory.

    The client-supplied filename is reduced to its base name so a crafted
    name such as '../../etc/passwd' cannot escape the upload directory
    (path-traversal fix).
    """
    # Sanitize: drop any directory components from the untrusted filename.
    safe_name = os.path.basename(file.filename)
    # Make sure the target directory exists (first upload on a fresh deploy).
    os.makedirs(save_directory, exist_ok=True)
    file_path = os.path.join(save_directory, safe_name)

    # Write the upload body to disk.
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    # Return the CWD-relative path.
    # NOTE(review): the prediction endpoints join their data_path argument
    # onto save_directory again — confirm the front end passes back only the
    # file name, not this full relative path.
    relative_path = os.path.relpath(file_path, start=os.getcwd())
    return JSONResponse(content={"file_path": relative_path})
|
|
|
|
|
|
|
|
# Slice a data range and run a prediction.
@app.get("/ch4/start_predict")
async def start_predict_endpoint(data_path: str, start_index: int, end_index: int, type: int, is_show: bool):
    """Predict the methane target from the window [start_index, end_index].

    Query parameters:
        data_path:   CSV file name under save_directory.
        start_index, end_index: inclusive bounds on the CSV's 'index' column.
        type:        1 -> flow model (target 'Nm3d-1-ch4'); any other value ->
                     gas-phase model (target 'S_gas_ch4'). (The name shadows
                     the builtin `type`, but renaming it would change the
                     public query-parameter name.)
        is_show:     when true, the 4 rows after end_index are also returned
                     as ground truth for display, and validation requires
                     them to exist.

    Returns JSON {error_info, pred_data, true_data}. Validation failures are
    reported via a non-empty error_info (Chinese, user-facing) with null
    data; hard errors become HTTP 400.
    """
    try:
        data = pd.read_csv(os.path.join(save_directory, data_path))
        # --- input validation (messages are user-facing strings) ---
        if is_show:
            if len(data) < end_index + 4:
                # Need 4 extra rows past end_index to display true values.
                return JSONResponse(content={"error_info": "显示真实值需要保留最终后四节点作为展示信息,请调整结束信息","pred_data": None, "true_data": None})
            elif len(data) < 100:
                return JSONResponse(content={"error_info": "上传信息长度应大于100","pred_data": None, "true_data": None})
            elif end_index - start_index < 96:
                # 96 steps — presumably one day of 15-minute samples, matching
                # the 4*24 rolling windows in data_feature; TODO confirm.
                return JSONResponse(content={"error_info": "截取步长应该超过96个步长","pred_data": None, "true_data": None})
            else:
                pass
        else:
            if len(data) < 96:
                return JSONResponse(content={"error_info": "上传信息长度应大于96","pred_data": None, "true_data": None})
            elif end_index - start_index < 96:
                return JSONResponse(content={"error_info": "截取步长应该超过96个步长","pred_data": None, "true_data": None})
            else:
                pass
        print("start")
        # Build the single most-recent feature row for the selected window.
        train_data = get_pred_data(data,start_index,end_index)
        # 'index' is a bookkeeping column, not a model feature.
        del train_data['index']
        # Models consume the features positionally via a raw matrix.
        train_data = np.array(train_data.values)
        train_data = xgb.DMatrix(train_data)
        target = None
        if type == 1:  # flow
            target = "Nm3d-1-ch4"
            result = ch4_model_flow.predict(train_data)
        else:  # gas-phase concentration
            target = "S_gas_ch4"
            result = ch4_model_gas.predict(train_data)
        # History of the target over the requested window (+4 future points
        # when is_show) for plotting against the prediction.
        if is_show:
            history = data[(data['index'] >= start_index) & (data['index'] <= end_index + 4)]
            history = history[target].values
        else:
            history = data[(data['index'] >= start_index) & (data['index'] <= end_index)]
            history = history[target].values
        # result has one row; result[0] is presumably the 4-step-ahead output
        # vector — TODO confirm against the training code.
        return JSONResponse(content={"error_info": "","pred_data": result[0].tolist(), "true_data": history.tolist()})
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
if __name__ == '__main__':
    # Development entry point: serve the API on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
# data = pd.read_csv(os.path.join(save_directory,'jiawan_test.csv'))
|
|
# data['index'] = range(1, len(data) + 1)
|
|
# data.to_csv(os.path.join(save_directory,'jiawan_test.csv'),index=False)
|