ai_platform_regression/prophet_predict/prophet_predict.py

83 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from prophet import Prophet
import datetime as dt
from get_holiday_cn.client import getHoliday
from logzero import logger
def run_prophet(data: pd.DataFrame, period: int = 1, freq: str = 'D'):
    """Fit a Prophet model on ``data`` and forecast ``period`` steps past its end.

    Holiday dates produced by ``build_holiday`` are supplied to Prophet for the
    whole span from the first training timestamp through the forecast horizon.

    Args:
        data (pd.DataFrame): Training data with a ``ds`` column (anything
            ``pd.to_datetime`` can parse) and a ``y`` target column. The
            caller's DataFrame is not modified.
        period (int, optional): Number of future steps to forecast; must be > 0.
            Defaults to 1.
        freq (str, optional): Any frequency alias valid for ``pd.date_range``
            (e.g. 'M', 'D', 'H', 'T', 'S'). Defaults to 'D'.

    Returns:
        pd.DataFrame: Output of ``Prophet.predict`` covering the full history
        plus ``period`` future steps. If ``ds`` cannot be converted to
        datetimes, the conversion exception is logged and *returned* (not
        raised); this is kept for backward compatibility with existing callers.

    Raises:
        AssertionError: If ``period`` is not positive or ``ds``/``y`` is
            missing (note: these checks are ``assert`` statements and are
            skipped under ``python -O``).
    """
    assert period > 0, "period must be a positive integer"
    assert 'ds' in data.columns and 'y' in data.columns, "data needs 'ds' and 'y' columns"

    # Convert on a copy so the caller's DataFrame is never mutated.
    train_data = data.copy()
    try:
        train_data['ds'] = pd.to_datetime(train_data['ds'])
    except Exception as e:  # broad and returned (not raised) to keep the existing contract
        logger.exception("run_prophet: could not convert column 'ds' to datetime")
        return e

    # Prophet only applies holiday effects to dates present in ``holidays``, so
    # the lookup must cover the forecast horizon, not just the history.
    # ``date_range(last, periods=period + 1, freq=freq)`` mirrors how
    # ``make_future_dataframe`` generates future dates; its last element is at
    # or beyond the final forecast date. Normalizing to midnight keeps the first
    # and last calendar days inside the daily range that ``build_holiday`` scans.
    last_ds = train_data['ds'].max()
    horizon_end = pd.date_range(start=last_ds, periods=period + 1, freq=freq)[-1]
    holiday_data = build_holiday(train_data['ds'].min().normalize(), horizon_end.normalize())

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        # No special days found: run without a holidays component rather than
        # handing Prophet an empty frame.
        holidays=holiday_data if not holiday_data.empty else None,
        n_changepoints=100,  # number of potential changepoints; Prophet's default is 25
    )
    model.fit(train_data)
    future = model.make_future_dataframe(periods=period, freq=freq, include_history=True)
    return model.predict(future)
def get_date_type(date: str, holiday_client: getHoliday):
    """Classify one calendar day using the holiday client.

    Args:
        date (str): Day to look up, formatted "YYYY-MM-DD".
        holiday_client (getHoliday): Client whose ``assemble_holiday_data``
            performs the lookup.

    Returns:
        str | None:
            - ``'oridinary'`` (sic; this exact spelling is the runtime value)
              when the response has ``code == 0`` and no ``holiday`` entry.
            - The ``name`` of the ``holiday`` entry when one is present.
            - ``None`` when the response ``code`` is anything other than 0.
    """
    response = holiday_client.assemble_holiday_data(today=date)
    if response.get('code') != 0:
        # Unsuccessful / unexpected status: no classification available.
        return None
    holiday_info = response.get('holiday')
    return 'oridinary' if holiday_info is None else holiday_info.get('name')
def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31"):
    """Build a Prophet ``holidays`` frame of the special days in a date range.

    Every calendar day from ``start_date`` to ``end_date`` (inclusive) is
    classified with ``get_date_type``. Two kinds of day are dropped:

    - Ordinary days, i.e. those for which it returns ``'oridinary'``.
    - Days whose lookup did not return ``code == 0``, for which it returns
      ``None``.

    Whatever remains is emitted with its holiday name. The original intent
    was to also capture days that would normally be off but became make-up
    working days. NOTE(review): whether such days appear depends on what
    ``getHoliday.assemble_holiday_data`` reports; confirm against that API.

    Args:
        start_date (str): First day, "YYYY-MM-DD" (any value accepted by
            ``pd.date_range``, e.g. a ``pd.Timestamp``, also works).
            Defaults to "2015-01-01".
        end_date (str): Last day, same format. Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: Columns ``ds`` ("YYYY-MM-DD" strings) and ``holiday``
        (holiday name), indexed 0..n-1. May be empty.
    """
    days = pd.DataFrame({'date': pd.date_range(start=start_date, end=end_date, freq='D')})
    days['date'] = days['date'].dt.strftime('%Y-%m-%d')

    client = getHoliday()
    # One assemble_holiday_data call per calendar day; may be slow for long
    # ranges if the client performs network I/O (NOTE(review): confirm).
    days['day_type'] = days['date'].apply(lambda day: get_date_type(day, client))

    # 'oridinary' is the sentinel get_date_type returns for normal days; the
    # previous filter compared against 'simple', which never matched, so every
    # day was emitted as a "holiday".
    is_special = days['day_type'].notna() & (days['day_type'] != 'oridinary')
    special_date = days.loc[is_special, ['date', 'day_type']].rename(
        columns={'date': 'ds', 'day_type': 'holiday'}
    )
    return special_date.reset_index(drop=True)
if __name__ == '__main__':
    # Intentionally does nothing when the module is run directly.
    ...