83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
import pandas as pd
|
||
from prophet import Prophet
|
||
import datetime as dt
|
||
from get_holiday_cn.client import getHoliday
|
||
from logzero import logger
|
||
|
||
|
||
def run_prophet(data: pd.DataFrame, period: int = 1, freq: str = 'D'):
    """Fit a Prophet model on ``data`` and forecast ``period`` steps ahead.

    Holidays returned by ``build_holiday`` for the span from the first
    training date through the end of the forecast horizon are passed to
    Prophet as holiday effects.

    Args:
        data (pd.DataFrame): training data with a ``ds`` (date/time) column
            and a ``y`` (target) column. The caller's frame is not modified.
        period (int, optional): forecast horizon, i.e. how many future steps
            to predict; must be > 0. Defaults to 1.
        freq (str, optional): any frequency valid for ``pd.date_range``,
            e.g. M, D, H, T, S. Defaults to 'D'.

    Returns:
        pd.DataFrame: the output of ``Prophet.predict`` for the history plus
        ``period`` future steps. If ``ds`` cannot be parsed as datetimes, the
        parsing exception object is returned instead of a frame (kept for
        backward compatibility) and the error is logged.
    """
    assert period > 0, "period must be a positive integer"
    assert 'ds' in data.columns and 'y' in data.columns, "data needs 'ds' and 'y' columns"

    # Convert on a copy: assigning into `data` would mutate the caller's frame.
    train_data = data.copy()
    try:
        train_data['ds'] = pd.to_datetime(train_data['ds'])
    except Exception as e:
        logger.error("run_prophet: cannot parse 'ds' as datetime: %s", e)
        return e

    # The holiday table must also cover the forecast horizon; holidays limited
    # to the training range leave the future rows without holiday effects.
    # NOTE(review): assumes make_future_dataframe generates its future dates
    # with pd.date_range from the last training date; the last element of this
    # range is at or beyond the final forecast date.
    horizon = pd.date_range(start=train_data['ds'].max(), periods=period + 1, freq=freq)
    # normalize() so the daily range in build_holiday includes the last day
    # even when the timestamps carry a time-of-day component.
    holiday_data = build_holiday(train_data['ds'].min().normalize(), horizon[-1].normalize())
    if holiday_data.empty:
        # No special days in range: give Prophet no holiday table at all.
        holiday_data = None

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_data,
        n_changepoints=100,  # number of changepoints, Prophet default is 25
    )
    model.fit(train_data)

    future = model.make_future_dataframe(periods=period, freq=freq, include_history=True)
    forecast = model.predict(future)
    return forecast
|
||
|
||
|
||
def get_date_type(date: str, holiday_client: "getHoliday"):
    """Classify one calendar day using ``holiday_client``.

    Args:
        date (str): the day to classify, formatted "YYYY-MM-DD".
        holiday_client (getHoliday): client whose ``assemble_holiday_data``
            method is queried for ``date``.

    Returns:
        str | None: ``'oridinary'`` (sic; callers match this exact string)
        when the response has no ``holiday`` entry, the holiday's ``name``
        when it does, or ``None`` when the response ``code`` is not 0.
    """
    rst = holiday_client.assemble_holiday_data(today=date)
    if rst.get('code') != 0:
        # Lookup failed: return None explicitly instead of falling off the end.
        return None

    holiday = rst.get('holiday')
    if holiday is None:
        return 'oridinary'
    return holiday.get('name')
|
||
|
||
|
||
def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31"):
    """Collect the special (non-ordinary) days between two dates, inclusive.

    Every calendar day in ``[start_date, end_date]`` is classified with
    ``get_date_type``. The days that are not ordinary are returned as a
    holiday table suitable for Prophet. The original intent is to also
    include days that would normally be off but became adjusted working
    days. NOTE(review): whether those appear depends on what the
    ``get_holiday_cn`` client reports; confirm.

    Args:
        start_date (str): first day, "YYYY-MM-DD" (anything ``pd.date_range``
            accepts also works, e.g. a ``pd.Timestamp``). Defaults to
            "2015-01-01".
        end_date (str): last day, same format. Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: columns ``ds`` ("YYYY-MM-DD" strings) and ``holiday``
        (name reported by the client), one row per special day. Days whose
        lookup failed are dropped, and a warning is logged.
    """
    days = pd.DataFrame({'date': pd.date_range(start=start_date, end=end_date, freq='D')})
    days['date'] = days['date'].dt.strftime('%Y-%m-%d')

    client = getHoliday()
    days['day_type'] = days['date'].apply(lambda d: get_date_type(d, client))

    # get_date_type returns None when the client answers with a non-zero code.
    failed = days['day_type'].isna()
    if failed.any():
        logger.warning("build_holiday: holiday lookup failed for %d day(s); they are skipped.",
                       int(failed.sum()))

    # 'oridinary' (sic) is the exact value get_date_type returns for normal days.
    # Comparing against 'simple' never matched, so every day was kept as a holiday.
    special_date = days[~failed & (days['day_type'] != 'oridinary')].copy()
    special_date.columns = ['ds', 'holiday']
    return special_date
|
||
|
||
|
||
if __name__ == "__main__":
    ...  # import-only module: running it as a script does nothing
|