ai_platform_regression/prophet_predict/prophet_predict.py

83 lines
2.7 KiB
Python
Raw Permalink Normal View History

2022-12-07 10:43:52 +08:00
import pandas as pd
from prophet import Prophet
import datetime as dt
from get_holiday_cn.client import getHoliday
from logzero import logger
2022-12-08 16:22:29 +08:00
def run_prophet(data: pd.DataFrame, period: int = 1, freq: str = 'D'):
    """Fit a Prophet model (with Chinese holidays) on ``data`` and forecast ``period`` steps ahead.

    Args:
        data (pd.DataFrame): Training data. Must contain a ``ds`` column
            (timestamps, or anything ``pd.to_datetime`` accepts) and a ``y``
            column (target values). The caller's DataFrame is not modified.
        period (int, optional): Forecast horizon, i.e. how many future steps
            to predict. Must be > 0. Defaults to 1.
        freq (str, optional): Any frequency accepted by ``pd.date_range``,
            e.g. 'M', 'D', 'H', 'T', 'S'. Defaults to 'D'.

    Returns:
        pd.DataFrame: Prophet's forecast frame covering the history plus
        ``period`` future steps. If ``ds`` cannot be parsed as datetimes, the
        parsing exception is logged and *returned* (not raised); this is kept
        for compatibility with existing callers.

    Raises:
        AssertionError: if ``period`` is not positive or ``ds``/``y`` are missing.
    """
    assert period > 0, "period must be a positive number of steps"
    assert 'ds' in data.columns and 'y' in data.columns, "data needs 'ds' and 'y' columns"

    # Convert on a copy so the caller's DataFrame is left untouched.
    train_data = data.copy()
    try:
        train_data['ds'] = pd.to_datetime(train_data['ds'])
    except Exception as e:
        logger.exception("run_prophet: could not parse the 'ds' column as datetimes")
        return e

    # Prophet only applies a holiday effect on dates listed in `holidays`, so the
    # table must cover the forecast horizon as well as the history. The horizon
    # end is computed the same way make_future_dataframe builds future dates.
    # Dates are normalized to midnight so every calendar day in range is looked up.
    last_ds = train_data['ds'].max()
    horizon_end = pd.date_range(start=last_ds, periods=period + 1, freq=freq)[-1]
    holiday_data = build_holiday(train_data['ds'].min().normalize(), horizon_end.normalize())

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_data,
        n_changepoints=100,  # change points num, default=25
    )
    model.fit(train_data)
    future = model.make_future_dataframe(periods=period, freq=freq, include_history=True)
    forecast = model.predict(future)
    return forecast
2022-12-08 16:22:29 +08:00
def get_date_type(date: str, holiday_client: getHoliday):
    """Classify one date with the Chinese holiday client.

    Args:
        date (str): Date to look up, formatted "YYYY-MM-DD".
        holiday_client (getHoliday): Client whose ``assemble_holiday_data``
            performs the lookup.

    Returns:
        str | None: On a successful lookup (``code == 0``), returns
        ``'oridinary'`` when the response has no ``holiday`` entry; otherwise
        it returns that entry's ``name``. The misspelling is kept verbatim
        because it is a runtime value callers may match on. Returns ``None``
        when the lookup reports a non-zero ``code``.
    """
    rst = holiday_client.assemble_holiday_data(today=date)
    if rst.get('code') != 0:
        # Lookup failed: make the None explicit and leave a trace instead of
        # silently falling through.
        logger.warning("holiday lookup for %s failed: %s", date, rst)
        return None
    holiday = rst.get('holiday')
    if holiday is None:
        return 'oridinary'
    return holiday.get('name')
2022-12-08 16:22:29 +08:00
def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31"):
    """Collect the special (non-ordinary) days in a date range, shaped as a Prophet holidays table.

    Every calendar day in [start_date, end_date] is classified with
    ``get_date_type``. Days classified as ordinary are dropped, as are days
    whose lookup failed (``None``). Per the original author's note, this is
    meant to also include make-up working days (days that would normally be
    off but were switched to working days). That presumably depends on what
    getHoliday reports; verify against get_holiday_cn.

    Args:
        start_date (str): First day, "YYYY-MM-DD" (anything accepted by
            ``pd.date_range`` also works). Defaults to "2015-01-01".
        end_date (str): Last day, inclusive, same format. Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: Columns ``ds`` ("YYYY-MM-DD" strings) and ``holiday``
        (the name from ``get_date_type``), with one row per special day.
    """
    days = pd.DataFrame(pd.date_range(start=start_date, end=end_date, freq='D'), columns=['date'])
    days['date'] = days['date'].dt.strftime('%Y-%m-%d')
    client = getHoliday()
    days['day_type'] = days['date'].apply(lambda day: get_date_type(day, client))

    failed = days['day_type'].isna()
    if failed.any():
        logger.warning(
            "build_holiday: holiday lookup failed for %d day(s); treating them as ordinary days",
            int(failed.sum()),
        )
    # 'oridinary' (sic) is the exact value get_date_type returns for normal days;
    # the previous comparison against 'simple' matched nothing, so every day
    # ended up in the holidays table.
    special_date = days[~failed & (days['day_type'] != 'oridinary')].copy()
    special_date.columns = ['ds', 'holiday']
    return special_date
if __name__ == "__main__":
    # No command-line behaviour is defined; running this file directly is a no-op.
    ...