import datetime as dt

import pandas as pd
from get_holiday_cn.client import getHoliday
from logzero import logger
from prophet import Prophet

# Label get_date_type returns for an ordinary (non-special) day. The spelling
# is part of get_date_type's public return value, so it is kept unchanged.
ORDINARY_DAY = 'oridinary'


def _forecast_end(last_date: pd.Timestamp, period: int, freq: str) -> pd.Timestamp:
    """Return the last future timestamp Prophet's make_future_dataframe will produce.

    Mirrors Prophet's own construction: build period + 1 dates starting at
    last_date, drop any that are not strictly after last_date, keep `period`.
    """
    dates = pd.date_range(start=last_date, periods=period + 1, freq=freq)
    dates = dates[dates > last_date][:period]
    return dates[-1] if len(dates) else last_date


def run_prophet(data: pd.DataFrame, period: int = 1, freq: str = 'D'):
    """Fit a Prophet model on `data` and forecast `period` steps ahead.

    Chinese special days (as reported by build_holiday) covering both the
    training range and the forecast horizon are passed to Prophet as holidays.

    Args:
        data (pd.DataFrame): Training data with a 'ds' column (timestamps or
            values parseable by pd.to_datetime) and a 'y' column (target).
            The caller's DataFrame is not modified.
        period (int, optional): Forecast horizon, i.e. how many steps into the
            future to predict. Must be > 0. Defaults to 1.
        freq (str, optional): Any frequency alias valid for pd.date_range,
            e.g. M, D, H, T, S. Defaults to 'D'.

    Returns:
        pd.DataFrame: The output of Prophet.predict over the history plus the
        `period` future steps. If 'ds' cannot be parsed as datetimes, the
        parsing exception is logged and returned (not raised) instead.

    Raises:
        AssertionError: If period <= 0 or 'ds' / 'y' columns are missing.
    """
    assert period > 0, 'period must be a positive integer'
    assert 'ds' in data.columns and 'y' in data.columns, "data needs 'ds' and 'y' columns"

    # Convert on a copy so the caller's DataFrame is left untouched.
    train_data = data.copy()
    try:
        train_data['ds'] = pd.to_datetime(train_data['ds'])
    except Exception as e:
        # Returned rather than raised to stay compatible with existing callers.
        logger.exception('Failed to parse the ds column as datetimes')
        return e

    # Prophet only applies holiday effects on dates present in the holidays
    # frame, so it must also cover the forecast horizon, not just the history.
    horizon_end = _forecast_end(train_data['ds'].max(), period, freq)
    holiday_data = build_holiday(train_data['ds'].min(), horizon_end)

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        # No special days found -> no holiday component at all.
        holidays=None if holiday_data.empty else holiday_data,
        n_changepoints=100,  # number of potential changepoints; Prophet's default is 25
    )
    model.fit(train_data)
    future = model.make_future_dataframe(periods=period, freq=freq, include_history=True)
    forecast = model.predict(future)
    return forecast


def get_date_type(date: str, holiday_client: getHoliday):
    """Classify one date as an ordinary day or a named special day.

    Args:
        date (str): Date formatted as "YYYY-MM-DD".
        holiday_client (getHoliday): Client whose assemble_holiday_data is
            queried for the date.

    Returns:
        str | None: ORDINARY_DAY ('oridinary') when the response's 'holiday'
        entry is None; otherwise the 'name' field of that entry. Returns None
        when the response's 'code' is not 0 (lookup failed).
    """
    rst = holiday_client.assemble_holiday_data(today=date)
    if rst.get('code') != 0:
        logger.warning('Holiday lookup failed for %s (code=%s)', date, rst.get('code'))
        return None
    holiday = rst.get('holiday')
    if holiday is None:
        return ORDINARY_DAY
    return holiday.get('name')


def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31"):
    """Build a Prophet holidays frame of Chinese special days in a date range.

    Every calendar day from start_date to end_date (inclusive, at day
    granularity) is classified with get_date_type. Only successfully
    classified non-ordinary days are kept. The author's intent is that these
    include public holidays as well as days that would normally be off but
    were turned into make-up working days; which days qualify depends on the
    getHoliday client.

    Args:
        start_date (str | pd.Timestamp): Range start, e.g. "YYYY-MM-DD".
            Defaults to "2015-01-01".
        end_date (str | pd.Timestamp): Range end, e.g. "YYYY-MM-DD".
            Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: Columns 'ds' ("YYYY-MM-DD" strings) and 'holiday'
        (special-day name), as expected by Prophet's `holidays` argument.
    """
    # Normalize to midnight so an intraday end timestamp still includes its day.
    start = pd.Timestamp(start_date).normalize()
    end = pd.Timestamp(end_date).normalize()
    days = list(pd.date_range(start=start, end=end, freq='D').strftime('%Y-%m-%d'))

    client = getHoliday()
    calendar = pd.DataFrame({
        'ds': days,
        'holiday': [get_date_type(day, client) for day in days],
    })
    # Drop ordinary days and failed lookups (None); only real special days remain.
    is_special = calendar['holiday'].notna() & (calendar['holiday'] != ORDINARY_DAY)
    return calendar[is_special].copy()


if __name__ == '__main__':
    pass