1
0
Fork 0
wgz_forecast/pv/vmd_decom.py

58 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import pandas as pd
import numpy as np
from vmdpy import VMD
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
from skopt.utils import use_named_args
# Decompose the PV power series with VMD, tuning the VMD parameters by Bayesian optimisation.
def _mean_nonzero_spacing(modes):
    """Return the mean gap between consecutive non-zero samples, pooled over all modes.

    This is the metric the tuning loop optimises. For each mode it takes
    the indices of the non-zero samples, diffs them, and averages all of
    those gaps together. The per-mode gaps are concatenated rather than
    stacked, so modes with different numbers of non-zero samples no longer
    form a ragged array that ``np.mean`` rejects. When every mode has the
    same number of non-zero samples, the value is unchanged.

    NOTE(review): for modes without exact zeros every gap is 1, so this
    evaluates to 1.0 for nearly every parameter set and gives the optimiser
    almost no signal. The original intent was described as "mean bandwidth
    of the modes"; confirm which metric is actually wanted.
    """
    gaps = [np.diff(np.flatnonzero(mode)) for mode in modes]
    return float(np.mean(np.concatenate(gaps)))


def _attach_modes(dataframe, modes):
    """Return ``dataframe`` with one ``vmd_<k>`` column appended per mode.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Frame the decomposed signal came from; its index and columns are kept.
    modes : numpy.ndarray
        Decomposition output with one row per mode, i.e. shape
        ``(K, n_samples)``.

    Returns
    -------
    pandas.DataFrame
        New frame: ``dataframe`` concatenated column-wise with the modes.

    Notes
    -----
    VMD implementations may return fewer samples than the input signal had.
    vmdpy, for example, trims odd-length signals by one sample; verify
    against the installed version. Trailing rows without a decomposed value
    are therefore filled with NaN instead of failing on a length mismatch.
    """
    values = np.asarray(modes).T
    n_rows = len(dataframe)
    if values.shape[0] < n_rows:
        padded = np.full((n_rows, values.shape[1]), np.nan)
        padded[: values.shape[0]] = values
        values = padded
    mode_frame = pd.DataFrame(
        values,
        index=dataframe.index,
        columns=[f"vmd_{k}" for k in range(values.shape[1])],
    )
    return pd.concat([dataframe, mode_frame], axis=1)


def do_vmd(dataframe, *, column='power', n_calls=50, random_state=0):
    """Decompose ``dataframe[column]`` with VMD after tuning VMD's parameters.

    The six VMD parameters (alpha, tau, K, DC, init, tol) are searched with
    ``skopt.gp_minimize``. The signal is then decomposed once more with the
    best parameters found, and the resulting modes are appended to the frame
    as columns ``vmd_0 ... vmd_{K-1}``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Input frame containing the signal column.
    column : str, default 'power'
        Name of the column holding the signal to decompose.
    n_calls : int, default 50
        Number of objective evaluations passed to ``gp_minimize``.
    random_state : int, default 0
        Seed passed to ``gp_minimize`` so the search is reproducible.

    Returns
    -------
    pandas.DataFrame
        ``dataframe`` with one extra column per VMD mode.
    """
    signal = dataframe[column].values

    # Search space; the order must match VMD(f, alpha, tau, K, DC, init, tol).
    space = [
        Real(low=2, high=100, prior='log-uniform', name='alpha'),
        Real(low=0, high=1, name='tau'),
        Integer(low=2, high=10, name='K'),
        Categorical(categories=[0, 1], name='DC'),
        Categorical(categories=[0, 1], name='init'),
        Real(low=1e-6, high=1e-2, prior='log-uniform', name='tol'),
    ]

    @use_named_args(space)
    def objective_function(alpha, tau, K, DC, init, tol):
        modes, _, _ = VMD(signal, alpha, tau, K, DC, init, tol)
        # gp_minimize minimises, so negate in order to maximise the metric.
        return -_mean_nonzero_spacing(modes)

    result = gp_minimize(
        objective_function, space, n_calls=n_calls, random_state=random_state
    )
    print('最优参数:', result.x)
    print('最优目标函数值:', -result.fun)  # negate back to the metric's own sign

    # Final decomposition with the best parameters (result.x follows `space` order).
    alpha_opt, tau_opt, K_opt, DC_opt, init_opt, tol_opt = result.x
    u_opt, _, _ = VMD(signal, alpha_opt, tau_opt, K_opt, DC_opt, init_opt, tol_opt)
    return _attach_modes(dataframe, u_opt)
if __name__ == '__main__':
    # Script entry point: load the hourly PV data, append the VMD modes and
    # write the augmented table to CSV.
    # NOTE(review): `data2vmd` stays a module-level name because a version of
    # do_vmd in this file reads it as a global; confirm before renaming.
    data2vmd = pd.read_csv('./data/pv_data_hourly.csv', index_col=0)
    # NOTE(review): index=False discards the index read from column 0 of the
    # input (presumably timestamps); confirm that is intended.
    do_vmd(data2vmd).to_csv(
        './data/vmd_train.csv',
        index=False,
        encoding='utf-8-sig',
    )