wgz_forecast/pv/pv_train.py

35 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# Set CUDA_VISIBLE_DEVICES to "-1" before the remaining imports run.
# NOTE(review): presumably meant to hide all GPUs so that any CUDA-aware
# library loaded later falls back to CPU -- confirm against the rest of the file.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import pandas as pd
import numpy as np
def time_series_to_supervised(data, n_in=10, n_out=1, dropnan=True):
    """Frame a time series as a supervised-learning table of shifted copies.

    The result contains, in this column order:

    1. the current observation of every variable (named after the column),
    2. lagged copies from the oldest lag ``(t-(n_in-1))`` down to ``(t-1)``,
    3. future copies ``(t+1)`` up to ``(t+n_out)``.

    Note that ``n_in`` counts the current time step as one of the inputs, so
    only ``n_in - 1`` lagged column groups are produced.

    :param data: Sequence of observations; anything ``pandas.DataFrame``
        accepts (1D/2D list, NumPy array, Series, DataFrame).
    :param n_in: Number of input time steps including the current one.
        Values below 1 are clamped to 1. Defaults to 10.
    :param n_out: Number of future time steps to append as outputs.
        Negative values are clamped to 0. Defaults to 1.
    :param dropnan: Whether to drop rows containing NaN values (the rows at
        both ends that shifting leaves incomplete). Defaults to True.
    :return: ``pandas.DataFrame`` with one column per (variable, time offset)
        pair, named ``<column>``, ``<column>(t-k)`` or ``<column>(t+k)``.
    """
    df = pd.DataFrame(data)
    # Take the variable names (and therefore their count) from the frame that
    # was actually built. Guessing the width from the input type breaks for
    # 2D lists (several columns, but only one name generated) and for 1D
    # arrays/Series (no second shape dimension).
    base_names = [str(col) for col in df.columns]

    n_in = max(1, n_in)
    n_out = max(n_out, 0)

    # Current time step first; shift(0) yields an unshifted copy of the frame.
    cols = [df.shift(0)]
    names = list(base_names)

    # Lagged inputs, oldest first: (t-(n_in-1)), ..., (t-1).
    for lag in range(n_in - 1, 0, -1):
        cols.append(df.shift(lag))
        names.extend("%s(t-%d)" % (name, lag) for name in base_names)

    # Future outputs: (t+1), ..., (t+n_out).
    for step in range(1, n_out + 1):
        cols.append(df.shift(-step))
        names.extend("%s(t+%d)" % (name, step) for name in base_names)

    agg = pd.concat(cols, axis=1)
    agg.columns = names
    if dropnan:
        agg.dropna(inplace=True)
    return agg