162 lines
6.7 KiB
Python
162 lines
6.7 KiB
Python
# import numpy as np
|
|
# import pandas as pd
|
|
# import xgboost as xgb
|
|
# from sklearn.model_selection import train_test_split
|
|
# import os
|
|
# import joblib
|
|
# from fastapi import FastAPI, HTTPException, File, UploadFile
|
|
# from fastapi.responses import JSONResponse
|
|
|
|
# def data_feature(data,x_cols):
|
|
# for name in x_cols:
|
|
|
|
# data[name+"_15_first"] = data[name].shift(1)
|
|
# data[name+"_30_first"] = data[name].shift(2)
|
|
# data[name+"_45_first"] = data[name].shift(3)
|
|
# data[name+"_60_first"] = data[name].shift(4)
|
|
|
|
# data[name+"_1h_mean"] = data[name].rolling(4).mean()
|
|
# data[name+"_1h_max"] = data[name].rolling(4).max()
|
|
# data[name+"_1h_min"] = data[name].rolling(4).min()
|
|
# data[name+"_1h_median"] = data[name].rolling(4).median()
|
|
# data[name+"_1h_std"] = data[name].rolling(4).std()
|
|
# data[name+"_1h_var"] = data[name].rolling(4).var()
|
|
# data[name+"_1h_skew"] = data[name].rolling(4).skew()
|
|
# data[name+"_1h_kurt"] = data[name].rolling(4).kurt()
|
|
# data[name+"_1_diff"] = data[name].diff(periods=1)
|
|
# data[name+"_2_diff"] = data[name].diff(periods=2)
|
|
|
|
# data[name+"_2h_mean"] = data[name].rolling(8).mean()
|
|
# data[name+"_2h_max"] = data[name].rolling(8).max()
|
|
# data[name+"_2h_min"] = data[name].rolling(8).min()
|
|
# data[name+"_2h_median"] = data[name].rolling(8).median()
|
|
# data[name+"_2h_std"] = data[name].rolling(8).std()
|
|
# data[name+"_2h_var"] = data[name].rolling(8).var()
|
|
# data[name+"_2h_skew"] = data[name].rolling(8).skew()
|
|
# data[name+"_2h_kurt"] = data[name].rolling(8).kurt()
|
|
|
|
|
|
|
|
# # 不想要日均的了,太长了
|
|
# for name in x_cols:
|
|
# data[name+"_d_mean"] = data[name].rolling(4*24).mean()
|
|
# data[name+"_d_max"] = data[name].rolling(4*24).max()
|
|
# data[name+"_d_min"] = data[name].rolling(4*24).min()
|
|
# data[name+"_d_median"] = data[name].rolling(4).median()
|
|
# data[name+"_d_std"] = data[name].rolling(4*24).std()
|
|
# data[name+"_d_var"] = data[name].rolling(4*24).var()
|
|
# data[name+"_d_skew"] = data[name].rolling(4*24).skew()
|
|
# data[name+"_d_kurt"] = data[name].rolling(4*24).kurt()
|
|
|
|
# return data
|
|
|
|
|
|
# def get_data_result(data,target_1):
|
|
# data[target_1+"_1_after"] = data[target_1].shift(-1)
|
|
# data[target_1+"_2_after"] = data[target_1].shift(-2)
|
|
# data[target_1+"_3_after"] = data[target_1].shift(-3)
|
|
# data[target_1+"_4_after"] = data[target_1].shift(-4)
|
|
# return data
|
|
|
|
|
|
|
|
# def get_pred_data(data,start_index,end_index):
|
|
|
|
# filtered_data = data[(data['index'] >= start_index) & (data['index'] <= end_index)]
|
|
# columns= ['X_ch','X_pr','X_li','X_I','Q','pH','Nm3d-1-ch4','S_gas_ch4']
|
|
# data = data_feature(filtered_data,columns)
|
|
# return data.iloc[-1:,:]
|
|
|
|
|
|
# # 获取当前工作目录
|
|
# current_directory = os.getcwd()
|
|
# # 获取文件存储目录
|
|
# save_directory = os.path.join(current_directory,'datasets/jiawan_data')
|
|
|
|
# ch4_model_flow = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_liuliang/xgb_model_liuliang.pkl'))
|
|
# ch4_model_gas = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_qixiangnongdu/xgb_model_qixiangnongdu.pkl'))
|
|
|
|
# is_show = True
|
|
|
|
# data = pd.read_csv(os.path.join(save_directory, 'jiawan_test.csv'))
|
|
# train_data = get_pred_data(data,35,35+100)
|
|
# del train_data['index']
|
|
# train_data = np.array(train_data.values)
|
|
# train_data = xgb.DMatrix(train_data)
|
|
|
|
# target = "Nm3d-1-ch4"
|
|
|
|
# start_index = 1
|
|
# end_index = 100
|
|
|
|
# if is_show:
|
|
# # data_result = data[(data['index'] >= 35+100) & (data['index'] <= 35+100 + 4)]
|
|
# # print(data_result.index.values)
|
|
# # test_data = get_data_result(data_result,target)
|
|
# # print(test_data.iloc[:1, -4:])
|
|
# data_result = data[(data['index'] >= start_index) & (data['index'] <= end_index+4)]
|
|
# print(len(data_result[target].values))
|
|
# else:
|
|
# data_result = data[(data['index'] >= start_index) & (data['index'] <= end_index)]
|
|
# print(len(data_result[target].values))
|
|
|
|
# result = ch4_model_flow.predict(train_data)
|
|
# history = data[(data['index'] >= start_index) & (data['index'] <= end_index + 4)]
|
|
# history = history[target].values
|
|
# print(result[0])
|
|
# print(history)
|
|
# print(JSONResponse(content={"error_info": "","pred_data": result[0].tolist(), "true_data": history.tolist()}))
|
|
|
|
import os

import requests
|
|
|
|
# Base URL of the API under test. Defaults to a local server; set the
# CH4_API_BASE_URL environment variable to target another deployment without
# editing this file. A trailing "/" is stripped because every request below
# appends a path that already starts with "/".
BASE_URL = os.environ.get("CH4_API_BASE_URL", "http://127.0.0.1:8000").rstrip("/")
|
|
|
|
def test_get_ori_data(data_path, timeout=60):
    """Smoke-test ``GET /ch4/get_ori_data`` and print the outcome.

    Args:
        data_path: Value sent as the ``data_path`` query parameter.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. ``None`` waits indefinitely (the old behavior).

    On HTTP 200, prints a pass message and the decoded JSON body. Otherwise,
    prints the status code and raw response text. Returns ``None``.
    """
    # Passing the query through ``params`` lets requests percent-encode the
    # value, so paths containing '&', '#', '+' or spaces reach the server
    # intact. Hand-building the query string would silently corrupt them.
    response = requests.get(
        f"{BASE_URL}/ch4/get_ori_data",
        params={"data_path": data_path},
        timeout=timeout,
    )
    if response.status_code == 200:
        print("GET /ch4/get_ori_data 测试通过")
        print("返回数据:", response.json())
    else:
        print("GET /ch4/get_ori_data 测试失败")
        print("状态码:", response.status_code)
        print("错误信息:", response.text)
|
|
|
|
def test_upload_file(file_path, timeout=60):
    """Smoke-test ``POST /uploadfile/ch4/`` by uploading a local file.

    Args:
        file_path: Local file, opened in binary mode and sent as the
            multipart form field ``file``.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. ``None`` waits indefinitely (the old behavior).

    On HTTP 200, prints a pass message and the decoded JSON body. Otherwise,
    prints the status code and raw response text. Returns ``None``. Errors
    from ``open`` (e.g. a missing file) propagate to the caller.
    """
    # The file only needs to stay open while the request body is sent, so
    # the response is reported after the handle has been closed.
    with open(file_path, "rb") as f:
        response = requests.post(
            f"{BASE_URL}/uploadfile/ch4/",
            files={"file": f},
            timeout=timeout,
        )
    if response.status_code == 200:
        print("POST /uploadfile/ch4/ 测试通过")
        print("返回数据:", response.json())
    else:
        print("POST /uploadfile/ch4/ 测试失败")
        print("状态码:", response.status_code)
        print("错误信息:", response.text)
|
|
|
|
def test_start_predict(data_path, start_index, end_index, type, is_show, timeout=60):
    """Smoke-test ``GET /ch4/start_predict`` and print the prediction result.

    Args:
        data_path: Sent as the ``data_path`` query parameter.
        start_index: Sent as ``start_index``.
        end_index: Sent as ``end_index``.
        type: Sent as ``type``. The name shadows the builtin but is kept
            because callers pass it by keyword.
        is_show: Sent as ``is_show``. A Python bool is serialized as
            ``"True"``/``"False"``, exactly as the previous f-string did.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. ``None`` waits indefinitely (the old behavior).

    On HTTP 200, prints the full JSON body, then its ``pred_data`` and
    ``true_data`` entries and their lengths. A missing key raises
    ``KeyError``. On any other status, prints the status code and raw text.
    Returns ``None``.
    """
    # ``params`` percent-encodes every value (notably ``data_path``), which
    # the hand-built query string did not.
    params = {
        "data_path": data_path,
        "start_index": start_index,
        "end_index": end_index,
        "type": type,
        "is_show": is_show,
    }
    response = requests.get(
        f"{BASE_URL}/ch4/start_predict",
        params=params,
        timeout=timeout,
    )
    if response.status_code != 200:
        print("GET /ch4/start_predict 测试失败")
        print("状态码:", response.status_code)
        print("错误信息:", response.text)
        return

    print("GET /ch4/start_predict 测试通过")
    # Decode the body once instead of re-parsing it for every field.
    payload = response.json()
    print(payload)
    pred_data = payload["pred_data"]
    true_data = payload["true_data"]
    print("预测数据:", pred_data)
    print("真实数据:", true_data)
    print(len(pred_data), len(true_data))
|
|
|
|
if __name__ == "__main__":
    # Adjust these paths to match your environment.
    test_file_path = "jiawan_data/test_file.csv"  # local file used by the upload check
    test_data_path = "jiawan_test.csv"  # passed to the endpoints as `data_path`

    # Only the prediction endpoint is exercised by default. Uncomment the
    # lines below to also check the raw-data and upload endpoints.
    # test_get_ori_data(test_data_path)
    # test_upload_file(test_file_path)

    test_start_predict(
        test_data_path,
        start_index=1,
        end_index=105,
        type=2,
        is_show=False,
    )
|