22-T67/prophet_city_hours.py

166 lines
5.9 KiB
Python
Raw Permalink Normal View History

2023-03-30 10:25:44 +08:00
from cmath import log
import pandas as pd
import os
import numpy as np
from prophet import Prophet
import datetime as dt
from get_holiday_cn.client import getHoliday
from logzero import logger
import pickle
import matplotlib.pyplot as plt
def concat_date(x:str, y:str):
    """Combine a compact date and an hour of day into one datetime.

    Args:
        x (str): date as ``YYYYMMDD`` (year-month-day). Non-str values such as
            the int ``20150101`` also work, since the value is formatted into
            a string before parsing.
        y (str): hour of day, e.g. ``"0"`` .. ``"23"``.

    Returns:
        datetime.datetime: ``x`` at hour ``y`` with zero minutes and seconds.
    """
    stamp = "{} {}:00:00".format(x, y)
    parsed = dt.datetime.strptime(stamp, "%Y%m%d %H:%M:%S")
    return parsed
def load_data():
    """Merge every city CSV under ``./data/城市_*`` into one PM2.5/O3 table.

    Sub-folders of ``./data/`` whose names start with ``城市_`` are visited in
    sorted order. Within each folder, every file whose name ends in ``csv`` is
    read, also in sorted order. Only rows whose ``type`` column is ``PM2.5`` or
    ``O3`` are kept, and a ``ds`` timestamp column is built from the ``date``
    and ``hour`` columns via ``concat_date``.

    Returns:
        pd.DataFrame: the merged rows sorted by ``ds``, with a fresh 0..n-1 index.

    Raises:
        ValueError: if no PM2.5/O3 rows were found under ``./data/城市_*``.
    """
    frames = []
    data_folder = sorted(x for x in os.listdir('./data/') if x.startswith('城市_'))
    for folder in data_folder:
        for file in sorted(os.listdir(f"./data/{folder}")):
            if not file.endswith('csv'):
                continue
            data = pd.read_csv(f'./data/{folder}/{file}')
            # keep only the two pollutants the models are trained on
            frames.append(data[data['type'].isin(['PM2.5', 'O3'])])
    # Concatenate once: calling pd.concat inside the loop re-copied the whole
    # accumulated frame for every file (quadratic in the number of files).
    total_data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    if total_data.empty:
        # fail early with a clear message instead of letting the empty frame
        # trip up the apply/sort below
        raise ValueError("在 ./data/城市_*/ 下没有找到 PM2.5/O3 数据")
    total_data['ds'] = total_data.apply(lambda x: concat_date(x.date, x.hour), axis=1)
    total_data['ds'] = pd.to_datetime(total_data['ds'])
    # stable sort: rows sharing a timestamp keep their file/row order, so the
    # result is deterministic between runs
    total_data.sort_values(by='ds', ascending=True, inplace=True, kind='stable')
    total_data.reset_index(drop=True, inplace=True)
    logger.info(f"总数据集大小:{total_data.shape}")
    return total_data
def build_model(city: str, data: pd.DataFrame, dtype: str, holiday_mode: pd.DataFrame, split_date="2021-01-01 00:00:00"):
    """Fit a Prophet model for one city/pollutant and forecast one year ahead.

    Args:
        city (str): city name; must be a column of ``data`` holding the values.
        data (pd.DataFrame): merged data with ``ds`` and ``type`` columns plus
            one column per city.
        dtype (str): pollutant to model, matched against ``data['type']``
            (e.g. ``'O3'`` or ``'PM2.5'``).
        holiday_mode (pd.DataFrame): holiday table with ``ds`` and ``holiday``
            columns, passed to ``Prophet(holidays=...)``.
        split_date (str, optional): rows with ``ds`` strictly before this
            moment are used for training. Defaults to "2021-01-01 00:00:00".

    Returns:
        model: the fitted Prophet model.
        forecast: predictions for the training history plus 365*24 hourly
            steps after the last training timestamp.

    Raises:
        ValueError: if no rows of ``dtype`` fall before ``split_date``.

    Side effects:
        Writes ``./figure/{city}_{dtype}_components.png``, creating
        ``./figure/`` if needed.
    """
    logger.info(f"选择了 {city}{dtype} 数据,")
    use_data = data[data['type'] == dtype][["ds", city]].copy()
    use_data.columns = ["ds", "y"]
    # ds is text when the data came from the cached total_data.csv; normalise
    # it so the split below is a real datetime comparison in every case
    use_data['ds'] = pd.to_datetime(use_data['ds'])
    train_data = use_data[use_data.ds < pd.Timestamp(split_date)].copy()
    if train_data.empty:
        raise ValueError(f"{city} {dtype}: no training rows before {split_date}")
    logger.info(train_data.iloc[-1].ds)
    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_mode,
        n_changepoints=100,  # change points num, default=25
    )
    model.fit(train_data)
    future = model.make_future_dataframe(365*24, freq='H', include_history=True)
    forecast = model.predict(future)
    # savefig does not create missing directories
    os.makedirs('./figure', exist_ok=True)
    fig = model.plot_components(forecast)
    fig.savefig(f'./figure/{city}_{dtype}_components.png')
    # close the figure so training many cities does not accumulate open figures
    plt.close(fig)
    return model, forecast
def get_date_type(date:str, holiday_client:getHoliday):
    """Classify a single calendar day using the holiday API client.

    Args:
        date (str): the day as "YYYY-MM-DD".
        holiday_client (getHoliday): client whose ``assemble_holiday_data`` is
            called with ``today=date``.

    Returns:
        str | None:
        - ``'oridinary'`` (sic; the literal spelling returned) when the
          response has ``code == 0`` and no ``holiday`` entry.
        - The holiday entry's ``name`` when the response has one.
        - ``None`` when the response ``code`` is anything other than 0.
    """
    response = holiday_client.assemble_holiday_data(today=date)
    if response.get('code') != 0:
        # unsuccessful response: the day cannot be classified
        return None
    holiday_info = response.get('holiday')
    if holiday_info is None:
        return 'oridinary'
    return holiday_info.get('name')
def build_holiday(start_date:str="2015-01-01", end_date:str="2021-12-31"):
    """List every special (non-ordinary) day between two dates.

    Each day in ``[start_date, end_date]`` is classified with
    ``get_date_type`` through a ``getHoliday`` client (one query per day).
    Every day the client names is kept. According to the original docstring,
    that includes days that were holidays but became make-up working days.

    Args:
        start_date (str): first day, anything ``pd.date_range`` accepts
            (e.g. "YYYY-MM-DD"). Defaults to "2015-01-01".
        end_date (str): last day, same format. Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: columns ``ds`` ("YYYY-MM-DD" strings) and ``holiday``
        (the day's name), in the form expected by ``Prophet(holidays=...)``.
    """
    days = pd.date_range(start=start_date, end=end_date, freq='D')
    ds_list = pd.DataFrame({'date': days.strftime('%Y-%m-%d')})
    client = getHoliday()
    ds_list['day_type'] = ds_list.date.apply(lambda x: get_date_type(x, client))
    # get_date_type labels ordinary days 'oridinary' and returns None when the
    # API call fails. The old filter compared against 'simple', so every
    # ordinary day (and every failed lookup) leaked into the holiday table.
    # 'simple' is still excluded in case the classifier ever returns it.
    is_special = ds_list.day_type.notna() & ~ds_list.day_type.isin(['oridinary', 'simple'])
    special_date = ds_list[is_special].copy()
    special_date.columns = ['ds', 'holiday']
    return special_date
def train(data_type, city_list, data, split_date='2021-01-01'):
    """Train one Prophet model per city, save the results and return them.

    Args:
        data_type (str): pollutant to model (e.g. 'O3' or 'PM2.5').
        city_list (list[str]): city column names present in ``data``.
        data (pd.DataFrame): merged data with ``ds``, ``type`` and one column
            per city.
        split_date (str, optional): train/test boundary forwarded to
            ``build_model``. Defaults to '2021-01-01', the value this function
            previously hard-coded.

    Returns:
        tuple[dict, dict]: ``(model_dict, predict_dict)``. Both map a city to
        its fitted model and to its forecast DataFrame, respectively.

    Side effects:
        Writes ``./result/{data_type}/data/{city}.csv`` and
        ``./result/{data_type}/model/{city}.pkl``.
    """
    model_dir = f'./result/{data_type}/model'
    data_dir = f'./result/{data_type}/data'
    # create the output tree up front so a permissions/path problem surfaces
    # before the (slow) training instead of after it
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    model_dict = dict()
    predict_dict = dict()
    holiday_data = build_holiday(data.ds.min(), data.ds.max())
    for city in city_list:
        model, pred = build_model(city, data, data_type, holiday_data, split_date)
        model_dict[city] = model
        predict_dict[city] = pred
        logger.info(f"{city} 模型构建完成")

    for city, city_pred in predict_dict.items():
        city_pred.to_csv(f'{data_dir}/{city}.csv', encoding='utf-8', index=False)
        logger.info(f"{city} 预测数据保存完成")
    for city, city_model in model_dict.items():
        # NOTE(review): pickle ties the files to the installed Prophet version;
        # Prophet's own model_to_json is more portable if that ever matters
        with open(f'{model_dir}/{city}.pkl', 'wb') as fwb:
            pickle.dump(city_model, fwb)
        logger.info(f"{city} 模型保存完成")
    return model_dict, predict_dict
if __name__ == '__main__':
    data_type = 'O3'  # change this to switch the pollutant type
    city_list = ['北京']  # add more cities here
    if os.path.exists('./data/total_data.csv'):
        # the cache stores ds as text; parse it back so downstream code sees
        # the same datetime column that load_data() produces
        data = pd.read_csv('./data/total_data.csv', parse_dates=['ds'])
    else:
        data = load_data()
        data.to_csv('./data/total_data.csv', encoding='utf-8', index=False)
    model_dict, pred_list = train(data_type, city_list, data)
    # Example (for testing): load a stored model back from disk
    # with open('./result/O3/model/北京.pkl', 'rb') as fr:
    #     local_model = pickle.load(fr)