22-T67/prophet_city_hours.py


import datetime as dt
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from get_holiday_cn.client import getHoliday
from logzero import logger
from prophet import Prophet
def concat_date(x: str, y: str):
    """Combine a date string and an hour into a single datetime.

    Args:
        x (str): date in YYYYMMDD format
        y (str): hour of day

    Returns:
        datetime.datetime: the combined timestamp
    """
    time_str = f"{x} {y}:00:00"
    return dt.datetime.strptime(time_str, "%Y%m%d %H:%M:%S")
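# Illustrative example (not in the original file):
#   concat_date("20210101", "5") -> datetime.datetime(2021, 1, 1, 5, 0)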
def load_data():
    """Read every per-city CSV under ./data/ and merge the PM2.5/O3 rows into one frame."""
    data_folder = [x for x in os.listdir('./data/') if x.startswith('城市_')]
    data_folder.sort()
    total_data = pd.DataFrame()
    for folder in data_folder:
        files = os.listdir(f"./data/{folder}")
        files.sort()
        for file in files:
            if file.endswith('csv'):
                data = pd.read_csv(f'./data/{folder}/{file}')
                # keep only the hourly PM2.5 and O3 rows
                use_data = data[(data['type'] == 'PM2.5') | (data['type'] == 'O3')].copy()
                total_data = pd.concat([total_data, use_data])
    # build the timestamp column from the separate date and hour columns
    total_data['ds'] = total_data.apply(lambda x: concat_date(x.date, x.hour), axis=1)
    total_data.ds = pd.to_datetime(total_data.ds)
    total_data.sort_values(by='ds', ascending=True, inplace=True)
    total_data.reset_index(drop=True, inplace=True)
    logger.info(f"Total dataset size: {total_data.shape}")
    return total_data
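# For orientation only: judging from the code above and from build_model below, each raw
# CSV is assumed to carry a date, an hour, a pollutant type, and one column per city,
# roughly like this (values illustrative, column names inferred, not taken from the data):
#
#       date  hour   type  北京  ...
#   20210101     0  PM2.5    35  ...
#   20210101     0     O3    12  ...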
def build_model(city: str, data: pd.DataFrame, dtype: str, holiday_mode: pd.DataFrame, split_date="2021-01-01 00:00:00"):
    """Fit a Prophet model for one city and one pollutant.

    Args:
        city (str): city name (a column in the merged data)
        data (pd.DataFrame): merged hourly data
        dtype (str): pollutant type, "O3" or "PM2.5"
        holiday_mode (pd.DataFrame): holiday table passed to Prophet
        split_date (str, optional): date splitting train and test. Defaults to "2021-01-01 00:00:00".

    Returns:
        model: the fitted Prophet model
        forecast: the forecast for this series
    """
    logger.info(f"Selected the {dtype} data for {city}.")
    use_data = data[(data['type'] == dtype)][["ds", city]].copy()
    use_data.columns = ["ds", "y"]
    train_data = use_data[use_data.ds < split_date].copy()
    logger.info(train_data.iloc[-1].ds)
    test_data = use_data[use_data.ds >= split_date].copy()
    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_mode,
        n_changepoints=100,  # number of potential changepoints, default is 25
    )
    model.fit(train_data)
    # forecast one year of hourly values beyond the training period
    future = model.make_future_dataframe(365 * 24, freq='H', include_history=True)
    forecast = model.predict(future)
    model.plot_components(forecast)
    os.makedirs('./figure/', exist_ok=True)  # make sure the output folder exists
    plt.savefig(f'./figure/{city}_{dtype}_components.png')
    return model, forecast
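# Illustrative addition, not part of the original pipeline: build_model keeps a test
# split but never scores it. The helper below sketches one way to compare the held-out
# observations against Prophet's 'yhat' column; the function name is an assumption.
def evaluate_forecast(test_data: pd.DataFrame, forecast: pd.DataFrame) -> float:
    """Mean absolute error between observed y and predicted yhat on shared timestamps."""
    merged = test_data.merge(forecast[['ds', 'yhat']], on='ds', how='inner')
    return float((merged['y'] - merged['yhat']).abs().mean())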
def get_date_type(date: str, holiday_client: getHoliday):
    """Determine which kind of day a given date is.

    Args:
        date (str): "YYYY-MM-DD"
        holiday_client (getHoliday): object of the getHoliday class

    Returns:
        str: 'ordinary' for a regular day, otherwise the holiday name
    """
    rst = holiday_client.assemble_holiday_data(today=date)
    if rst.get('code') == 0:
        if rst.get('holiday') is None:
            return 'ordinary'
        else:
            return rst.get('holiday').get('name')
    # fall back to an ordinary day instead of implicitly returning None
    logger.warning(f"holiday lookup failed for {date}")
    return 'ordinary'
def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31"):
    """Collect every Chinese holiday between the two dates, including days that were
    originally holidays but became make-up working days.

    Args:
        start_date (str): "YYYY-MM-DD" string, defaults to 2015-01-01
        end_date (str): "YYYY-MM-DD" string, defaults to 2021-12-31

    Returns:
        pd.DataFrame: holiday table with columns 'ds' and 'holiday'
    """
    ds_list = pd.DataFrame(pd.date_range(start=start_date, end=end_date, freq='D'), columns=['date'])
    ds_list.date = ds_list.date.apply(lambda x: dt.datetime.strftime(x, format='%Y-%m-%d'))
    client = getHoliday()
    ds_list['day_type'] = ds_list.date.apply(lambda x: get_date_type(x, client))
    # keep only the special days; 'ordinary' matches the value returned by get_date_type
    special_date = ds_list[ds_list.day_type != 'ordinary'].copy()
    special_date.columns = ['ds', 'holiday']
    return special_date
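# For reference (values illustrative): the returned frame is the plain two-column table
# that Prophet's `holidays=` argument accepts,
#
#            ds   holiday
#    2021-01-01   元旦
#    2021-02-12   春节
#
# Prophet also supports optional 'lower_window' / 'upper_window' columns to widen a
# holiday's effect; this script does not use them.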
def train(data_type, city_list, data):
    """Fit, forecast, and persist a Prophet model for every city in city_list."""
    model_dict = dict()
    predict_dict = dict()
    holiday_data = build_holiday(data.ds.min(), data.ds.max())
    for city in city_list:
        model, pred = build_model(city, data, data_type, holiday_data, '2021-01-01')
        model_dict[city] = model
        predict_dict[city] = pred
        logger.info(f"{city} model built")
    # create the output folders on first use
    os.makedirs(f'./result/{data_type}/model/', exist_ok=True)
    os.makedirs(f'./result/{data_type}/data/', exist_ok=True)
    for city in predict_dict:
        city_pred = predict_dict.get(city)
        city_pred.to_csv(f'./result/{data_type}/data/{city}.csv', encoding='utf-8', index=False)
        logger.info(f"{city} forecast saved")
    for city in model_dict:
        city_model = model_dict.get(city)
        with open(f'./result/{data_type}/model/{city}.pkl', 'wb') as fwb:
            pickle.dump(city_model, fwb)
        logger.info(f"{city} model saved")
    return model_dict, predict_dict
if __name__ == '__main__':
    data_type = 'O3'      # change this to switch pollutant type
    city_list = ['北京']  # add more cities here
    if os.path.exists('./data/total_data.csv'):
        data = pd.read_csv('./data/total_data.csv')
        # timestamps come back as strings from CSV, so restore the datetime dtype
        data['ds'] = pd.to_datetime(data['ds'])
    else:
        data = load_data()
        data.to_csv('./data/total_data.csv', encoding='utf-8', index=False)
    model_dict, pred_list = train(data_type, city_list, data)
'''
# for testing: load a previously saved model from disk
with open('./result/O3/model/北京.pkl', 'rb') as fr:
    local_model = pickle.load(fr)
'''
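# A further sketch (not in the original script): once a model has been unpickled as above,
# new predictions only need a frame with a 'ds' column. The horizon below is an assumption
# for illustration.
#
#   future = local_model.make_future_dataframe(periods=24 * 7, freq='H', include_history=False)
#   new_forecast = local_model.predict(future)
#   print(new_forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail())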