81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
|
import pandas as pd
|
|||
|
from prophet import Prophet
|
|||
|
import datetime as dt
|
|||
|
from get_holiday_cn.client import getHoliday
|
|||
|
from logzero import logger
|
|||
|
|
|||
|
def run_prophet(data: pd.DataFrame, period: int = 1, freq: str = 'D'):
    """Fit a Prophet model on historical data and forecast future steps.

    Args:
        data (pd.DataFrame): Training data. Must contain a ``ds`` column
            (values convertible by ``pd.to_datetime``) and a ``y`` column.
            The frame passed in is left untouched; all work happens on a copy.
        period (int, optional): Forecast horizon, i.e. how many future steps
            to predict. Must be greater than 0. Defaults to 1.
        freq (str, optional): Any frequency string accepted by
            ``pd.date_range`` (e.g. M, D, H, T, S). Defaults to 'D'.

    Returns:
        The result of ``Prophet.predict`` over the history plus ``period``
        future steps. If ``ds`` cannot be converted to datetimes, the
        exception raised by the conversion is returned (not raised), so
        callers must check the type of the result.

    Raises:
        AssertionError: If ``period`` is not positive or ``data`` lacks the
            ``ds`` / ``y`` columns.
    """
    assert period > 0, "period must be greater than 0"
    assert 'ds' in data.columns and 'y' in data.columns, \
        "data must contain 'ds' and 'y' columns"

    # Copy *before* converting so the caller's DataFrame is never mutated.
    train_data = data.copy()
    try:
        train_data['ds'] = pd.to_datetime(train_data['ds'])
    except Exception as e:
        # Existing contract: conversion failures are handed back to the
        # caller as the exception object instead of being raised.
        return e

    # NOTE(review): holidays are only built for the training window. Prophet
    # can apply holiday effects to forecast dates only if those dates are in
    # the holidays frame - confirm whether the window should be extended to
    # cover the forecast horizon.
    holiday_data = build_holiday(train_data['ds'].min(), train_data['ds'].max())

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_data,
        n_changepoints=100,  # number of potential changepoints (library default is 25)
    )
    model.fit(train_data)

    # include_history=True: the forecast frame covers the training dates too.
    future = model.make_future_dataframe(periods=period, freq=freq, include_history=True)
    forecast = model.predict(future)
    return forecast
|
|||
|
|
|||
|
|
|||
|
def get_date_type(date:str, holiday_client:getHoliday):
    """Classify a single date using the holiday client.

    Args:
        date (str): Date formatted as "YYYY-MM-DD".
        holiday_client (getHoliday): Client whose ``assemble_holiday_data``
            method is queried for the date.

    Returns:
        str | None: ``'oridinary'`` when the lookup succeeds (``code == 0``)
        and reports no holiday entry; otherwise the ``name`` field of the
        holiday entry. ``None`` when the lookup does not report ``code == 0``.
    """
    response = holiday_client.assemble_holiday_data(today=date)

    # Any non-zero (or missing) code means there is no usable answer.
    if response.get('code') != 0:
        return None

    holiday_info = response.get('holiday')
    if holiday_info is None:
        return 'oridinary'
    return holiday_info.get('name')
|
|||
|
|
|||
|
|
|||
|
def build_holiday(start_date:str="2015-01-01", end_date:str="2021-12-31"):
    """Collect the special (non-ordinary) days between two dates.

    Every calendar day in the inclusive range is classified with
    ``get_date_type``; days classified as ordinary, and days for which no
    classification was returned, are dropped. According to the original
    author's note the result is meant to cover domestic (Chinese) holidays,
    including days that would normally be off but became working days due to
    schedule adjustments - that depends on what the holiday client reports.

    Args:
        start_date (str): First day of the range, "YYYY-MM-DD" (any value
            accepted by ``pd.date_range`` works). Defaults to "2015-01-01".
        end_date (str): Last day of the range, "YYYY-MM-DD" (any value
            accepted by ``pd.date_range`` works). Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: Two columns - ``ds`` (date as a "YYYY-MM-DD" string)
        and ``holiday`` (the name returned by ``get_date_type``) - with one
        row per special day.
    """
    days = pd.DataFrame(
        {'date': pd.date_range(start=start_date, end=end_date, freq='D')}
    )
    # get_date_type expects "YYYY-MM-DD" strings, not Timestamps.
    days['date'] = days['date'].dt.strftime('%Y-%m-%d')

    # One client instance is shared across all lookups.
    client = getHoliday()
    days['day_type'] = days['date'].apply(lambda day: get_date_type(day, client))

    # 'oridinary' (sic) is the exact sentinel get_date_type returns for a
    # normal day, so it is the value that has to be filtered out here. Rows
    # with a missing type (lookup did not succeed) carry no holiday name and
    # are dropped as well.
    is_special = days['day_type'].notna() & (days['day_type'] != 'oridinary')
    special_date = days.loc[is_special].rename(
        columns={'date': 'ds', 'day_type': 'holiday'}
    )
    return special_date
|
|||
|
|
|||
|
|
|||
|
# Entry-point guard: executing this file directly performs no action.
if __name__ == "__main__":
    pass
|