ai-station-code/jiawanyuce/test.py

# import numpy as np
# import pandas as pd
# import xgboost as xgb
# from sklearn.model_selection import train_test_split
# import os
# import joblib
# from fastapi import FastAPI, HTTPException, File, UploadFile
# from fastapi.responses import JSONResponse
# def data_feature(data,x_cols):
#     for name in x_cols:
#         data[name+"_15_first"] = data[name].shift(1)
#         data[name+"_30_first"] = data[name].shift(2)
#         data[name+"_45_first"] = data[name].shift(3)
#         data[name+"_60_first"] = data[name].shift(4)
#         data[name+"_1h_mean"] = data[name].rolling(4).mean()
#         data[name+"_1h_max"] = data[name].rolling(4).max()
#         data[name+"_1h_min"] = data[name].rolling(4).min()
#         data[name+"_1h_median"] = data[name].rolling(4).median()
#         data[name+"_1h_std"] = data[name].rolling(4).std()
#         data[name+"_1h_var"] = data[name].rolling(4).var()
#         data[name+"_1h_skew"] = data[name].rolling(4).skew()
#         data[name+"_1h_kurt"] = data[name].rolling(4).kurt()
#         data[name+"_1_diff"] = data[name].diff(periods=1)
#         data[name+"_2_diff"] = data[name].diff(periods=2)
#         data[name+"_2h_mean"] = data[name].rolling(8).mean()
#         data[name+"_2h_max"] = data[name].rolling(8).max()
#         data[name+"_2h_min"] = data[name].rolling(8).min()
#         data[name+"_2h_median"] = data[name].rolling(8).median()
#         data[name+"_2h_std"] = data[name].rolling(8).std()
#         data[name+"_2h_var"] = data[name].rolling(8).var()
#         data[name+"_2h_skew"] = data[name].rolling(8).skew()
#         data[name+"_2h_kurt"] = data[name].rolling(8).kurt()
#     # Decided against the daily-window stats; the window is too long
#     for name in x_cols:
#         data[name+"_d_mean"] = data[name].rolling(4*24).mean()
#         data[name+"_d_max"] = data[name].rolling(4*24).max()
#         data[name+"_d_min"] = data[name].rolling(4*24).min()
# data[name+"_d_median"] = data[name].rolling(4).median()
# data[name+"_d_std"] = data[name].rolling(4*24).std()
# data[name+"_d_var"] = data[name].rolling(4*24).var()
# data[name+"_d_skew"] = data[name].rolling(4*24).skew()
# data[name+"_d_kurt"] = data[name].rolling(4*24).kurt()
# return data
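# # Note: the window sizes above assume 15-minute sampling, consistent with the
# # "_15_first"/"_30_first"/"_45_first"/"_60_first" shift names: rolling(4) spans
# # one hour, rolling(8) two hours, and rolling(4*24) one day.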
# def get_data_result(data,target_1):
#     data[target_1+"_1_after"] = data[target_1].shift(-1)
#     data[target_1+"_2_after"] = data[target_1].shift(-2)
#     data[target_1+"_3_after"] = data[target_1].shift(-3)
#     data[target_1+"_4_after"] = data[target_1].shift(-4)
#     return data
# def get_pred_data(data,start_index,end_index):
#     filtered_data = data[(data['index'] >= start_index) & (data['index'] <= end_index)]
#     columns = ['X_ch','X_pr','X_li','X_I','Q','pH','Nm3d-1-ch4','S_gas_ch4']
#     data = data_feature(filtered_data,columns)
#     return data.iloc[-1:,:]
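# # get_pred_data returns only the last engineered-feature row (iloc[-1:]), i.e. a single
# # sample that is later wrapped in an xgb.DMatrix and passed to the model for prediction.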
# # Get the current working directory
# current_directory = os.getcwd()
# # Build the data directory path
# save_directory = os.path.join(current_directory,'datasets/jiawan_data')
# ch4_model_flow = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_liuliang/xgb_model_liuliang.pkl'))
# ch4_model_gas = joblib.load(os.path.join(current_directory,'checkpoints/jiawanyuce_qixiangnongdu/xgb_model_qixiangnongdu.pkl'))
# is_show = True
# data = pd.read_csv(os.path.join(save_directory, 'jiawan_test.csv'))
# train_data = get_pred_data(data,35,35+100)
# del train_data['index']
# train_data = np.array(train_data.values)
# train_data = xgb.DMatrix(train_data)
# target = "Nm3d-1-ch4"
# start_index = 1
# end_index = 100
# if is_show:
#     # data_result = data[(data['index'] >= 35+100) & (data['index'] <= 35+100 + 4)]
#     # print(data_result.index.values)
#     # test_data = get_data_result(data_result,target)
#     # print(test_data.iloc[:1, -4:])
#     data_result = data[(data['index'] >= start_index) & (data['index'] <= end_index+4)]
#     print(len(data_result[target].values))
# else:
#     data_result = data[(data['index'] >= start_index) & (data['index'] <= end_index)]
#     print(len(data_result[target].values))
# result = ch4_model_flow.predict(train_data)
# history = data[(data['index'] >= start_index) & (data['index'] <= end_index + 4)]
# history = history[target].values
# print(result[0])
# print(history)
# print(JSONResponse(content={"error_info": "","pred_data": result[0].tolist(), "true_data": history.tolist()}))
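
# The commented-out block above is the original offline test that loaded the XGBoost
# checkpoints and ran a prediction directly; the active script below exercises the
# corresponding FastAPI endpoints over HTTP instead.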
import requests
# Base URL of the API service
BASE_URL = "http://127.0.0.1:8000"  # adjust to the actual host/port the service runs on
def test_get_ori_data(data_path):
    """Test the GET /ch4/get_ori_data endpoint."""
    response = requests.get(f"{BASE_URL}/ch4/get_ori_data?data_path={data_path}")
    if response.status_code == 200:
        print("GET /ch4/get_ori_data test passed")
        print("Response data:", response.json())
    else:
        print("GET /ch4/get_ori_data test failed")
        print("Status code:", response.status_code)
        print("Error message:", response.text)
def test_upload_file(file_path):
    """Test the POST /uploadfile/ch4/ endpoint by uploading a local file."""
    with open(file_path, 'rb') as f:
        response = requests.post(f"{BASE_URL}/uploadfile/ch4/", files={"file": f})
    if response.status_code == 200:
        print("POST /uploadfile/ch4/ test passed")
        print("Response data:", response.json())
    else:
        print("POST /uploadfile/ch4/ test failed")
        print("Status code:", response.status_code)
        print("Error message:", response.text)
def test_start_predict(data_path, start_index, end_index, type, is_show):
    """Test the GET /ch4/start_predict endpoint and print the prediction results."""
    response = requests.get(
        f"{BASE_URL}/ch4/start_predict?data_path={data_path}&start_index={start_index}"
        f"&end_index={end_index}&type={type}&is_show={is_show}"
    )
    if response.status_code == 200:
        result = response.json()
        print("GET /ch4/start_predict test passed")
        print(result)
        print("Predicted data:", result["pred_data"])
        print("True data:", result["true_data"])
        print(len(result["pred_data"]), len(result["true_data"]))
    else:
        print("GET /ch4/start_predict test failed")
        print("Status code:", response.status_code)
        print("Error message:", response.text)
if __name__ == "__main__":
    # Replace the data paths below to match the actual setup
    test_data_path = "jiawan_test.csv"  # path of the data file used by the tests
    test_file_path = "jiawan_data/test_file.csv"  # path of the file to upload
    # Test the raw-data retrieval endpoint
    # test_get_ori_data(test_data_path)
    # Test the file-upload endpoint
    # test_upload_file(test_file_path)
    # Test the prediction endpoint
    test_start_predict(test_data_path, start_index=1, end_index=105, type=2, is_show=False)
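
# To run these tests, the FastAPI service is assumed to be listening at BASE_URL
# (http://127.0.0.1:8000). Start it first (for example with `uvicorn main:app --port 8000`;
# the module name here is an assumption), then run: python test.py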