wgz_forecast/wgz_forecast(lh)/carbon/carbon_train.py


import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
from logzero import logger


def time_series_to_supervised(data, columns, n_in=24, n_out=1, dropnan=True):
    """Frame a time series as a supervised-learning dataset.

    :param data: sequence of observations as a list or 2D NumPy array. Required.
    :param columns: column names of the observations, used to label the framed features.
    :param n_in: number of lag observations used as input (X). May range over [1..len(data)]. Optional, defaults to 24.
    :param n_out: number of future observations used as output (y). May range over [0..len(data)]. Optional, defaults to 1.
    :param dropnan: whether to drop rows containing NaN values. Optional, defaults to True.
    :return: DataFrame framed for supervised learning.
    """
    logger.info(f"Processing training data, size {data.shape}")
    n_vars = 1 if type(data) is list else data.shape[1]
    df = pd.DataFrame(data)
    origNames = columns
    cols, names = list(), list()
    # Current observations (t), kept under their original column names
    cols.append(df.shift(0))
    names += [('%s' % origNames[j]) for j in range(n_vars)]
    # Lagged inputs: (t-(n_in-1)) .. (t-1)
    n_in = max(1, n_in)
    for i in range(n_in - 1, 0, -1):
        time = '(t-%d)' % i
        cols.append(df.shift(i))
        names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]
    # Future outputs: (t+1) .. (t+n_out)
    n_out = max(n_out, 0)
    for i in range(1, n_out + 1):
        time = '(t+%d)' % i
        cols.append(df.shift(-i))
        names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]
    agg = pd.concat(cols, axis=1)
    agg.columns = names
    if dropnan:
        agg.dropna(inplace=True)
    return agg
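

# Illustrative sanity check (a sketch, not part of the original script): the
# 'demo' frame, its column names, and the expected output noted below are
# hypothetical; they only show how the framing above orders its columns.
def _demo_framing():
    demo = pd.DataFrame({'co2': [1, 2, 3, 4], 'load': [10, 20, 30, 40]})
    framed = time_series_to_supervised(demo.values, demo.columns, n_in=2, n_out=1)
    # Expected column order: current values first, then lags (oldest to newest), then leads:
    # ['co2', 'load', 'co2(t-1)', 'load(t-1)', 'co2(t+1)', 'load(t+1)']
    # The first and last rows contain NaN after shifting and are dropped, leaving 2 rows.
    logger.info(f"Framed columns: {framed.columns.tolist()}, shape: {framed.shape}")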


def train_model(train_data: pd.DataFrame):
    """Train the model; adjust this as needed for the actual model type.

    Args:
        train_data (pd.DataFrame): training set; the last column is the target.
    """
    # Feature and output column names; adapt to the business scenario as needed
    fea_cols = train_data.columns[:-1].tolist()
    out_cols = train_data.columns[-1:].tolist()
    logger.info(f"Feature columns: {fea_cols}, output columns: {out_cols}")
    X = train_data[fea_cols]
    y = train_data[out_cols]
    # 80/10/10 split into train / validation / test sets
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2, random_state=42)
    valid_X, test_X, valid_y, test_y = train_test_split(test_X, test_y, test_size=0.5, random_state=42)
    # XGBoost hyperparameters
    other_params = {'learning_rate': 0.1, 'n_estimators': 150, 'max_depth': 10, 'min_child_weight': 1,
                    'seed': 0, 'subsample': 0.8, 'colsample_bytree': 0.8, 'gamma': 0,
                    'reg_alpha': 0, 'reg_lambda': 1}
    print(train_X.shape, train_y.shape)
    gbm = xgb.XGBRegressor(objective='reg:squarederror', early_stopping_rounds=20, **other_params)
    gbm.fit(train_X.values, train_y.values, eval_set=[(valid_X.values, valid_y.values)])
    y_pred = gbm.predict(test_X.values)
    logger.info(f"Root Mean Squared Error on Test set: {np.sqrt(mean_squared_error(test_y, y_pred))}")
    logger.info(f"R2 score on Test set: {r2_score(test_y, y_pred)}")
    joblib.dump(gbm, './models/carbon_pred.joblib')
    logger.info("save_path: ./models/carbon_pred.joblib")


if __name__ == '__main__':
    data = pd.read_csv('./data/carbon_data_hourly.csv', index_col=0)
    agg = time_series_to_supervised(data.values, data.columns, 24, 1)
    train_model(agg)