166 lines
5.9 KiB
Python
166 lines
5.9 KiB
Python
|
from cmath import log
|
|||
|
import pandas as pd
|
|||
|
import os
|
|||
|
import numpy as np
|
|||
|
from prophet import Prophet
|
|||
|
import datetime as dt
|
|||
|
from get_holiday_cn.client import getHoliday
|
|||
|
from logzero import logger
|
|||
|
import pickle
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
|
|||
|
|
|||
|
def concat_date(x: str, y: str):
    """Combine a date and an hour of day into a single ``datetime``.

    Args:
        x (str): the date as ``YYYYMMDD``. Integers and whole-number floats
            (e.g. ``20150101`` or ``20150101.0``) are accepted as well, since
            pandas reads this CSV column as numbers.
        y (str): the hour of day (0-23); numeric values are accepted as well.

    Returns:
        datetime.datetime: the timestamp at the start of that hour.

    Raises:
        ValueError: if either value is not a valid date / hour.
    """
    # Normalise through int() first: when a column contains NaNs pandas loads it
    # as float, and f"{3.0}" would produce the unparseable "3.0:00:00".
    date_str = str(int(float(x)))
    hour = int(float(y))
    return dt.datetime.strptime(f"{date_str} {hour:02d}:00:00", "%Y%m%d %H:%M:%S")
|
|||
|
|
|||
|
|
|||
|
def load_data(data_dir: str = './data/'):
    """Read every city CSV under ``data_dir`` and merge the PM2.5 / O3 rows.

    Scans the sub-folders of ``data_dir`` whose names start with ``城市_`` and
    reads every ``.csv`` file in them, in sorted order. Only rows whose ``type``
    column is ``PM2.5`` or ``O3`` are kept. A ``ds`` datetime column is built
    from the ``date`` and ``hour`` columns via ``concat_date``.

    Args:
        data_dir (str): root data directory. Defaults to ``./data/``.

    Returns:
        pd.DataFrame: the merged rows sorted by ``ds``, with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: if no matching CSV file is found.
    """
    city_folders = sorted(x for x in os.listdir(data_dir) if x.startswith('城市_'))

    frames = []
    for folder in city_folders:
        folder_path = os.path.join(data_dir, folder)
        for file in sorted(os.listdir(folder_path)):
            if not file.endswith('.csv'):
                continue
            data = pd.read_csv(os.path.join(folder_path, file))
            frames.append(data[data['type'].isin(['PM2.5', 'O3'])])

    if not frames:
        raise FileNotFoundError(
            f"no CSV files found under {os.path.join(data_dir, '城市_*')}"
        )

    # Concatenate once at the end. Concatenating inside the loop would copy the
    # accumulated frame for every file, which is quadratic in the number of files.
    total_data = pd.concat(frames, ignore_index=True)

    # Build ds with a plain zip instead of a row-wise apply(axis=1). This is
    # faster, and it still yields a (possibly empty) datetime column when no
    # rows matched.
    total_data['ds'] = pd.to_datetime(
        [concat_date(d, h) for d, h in zip(total_data['date'], total_data['hour'])]
    )
    # A stable sort keeps rows sharing a timestamp in file order.
    total_data.sort_values(by='ds', ascending=True, inplace=True, kind='stable')
    total_data.reset_index(drop=True, inplace=True)
    logger.info(f"总数据集大小:{total_data.shape}")
    return total_data
|
|||
|
|
|||
|
|
|||
|
def build_model(city: str, data: pd.DataFrame, dtype: str, holiday_mode: pd.DataFrame, split_date="2021-01-01 00:00:00"):
    """Fit a Prophet model on one city's series and forecast ahead.

    Args:
        city (str): city column in ``data`` used as the target ``y``.
        data (pd.DataFrame): merged data with ``ds``, ``type`` and city columns.
        dtype (str): which pollutant rows to use, ``'O3'`` or ``'PM2.5'``.
        holiday_mode (pd.DataFrame): holiday table (``ds`` / ``holiday`` columns)
            passed to Prophet's ``holidays`` argument, e.g. from ``build_holiday``.
        split_date (str, optional): rows strictly before this timestamp are used
            for training. Defaults to ``"2021-01-01 00:00:00"``.

    Returns:
        model: the fitted Prophet model.
        forecast: Prophet's prediction over the training history plus the next
            365*24 hourly steps.

    Raises:
        ValueError: if there are no training rows before ``split_date``.

    Side effects:
        Saves the component plot to ``./figure/{city}_{dtype}_components.png``.
    """
    logger.info(f"选择了 {city} 的 {dtype} 数据,")
    use_data = data[data['type'] == dtype][["ds", city]].copy()
    use_data.columns = ["ds", "y"]
    # ds is a string column when data comes from the cached total_data.csv.
    # Parse it so the split below is a real time comparison, not a string one.
    use_data['ds'] = pd.to_datetime(use_data['ds'])

    train_data = use_data[use_data.ds < pd.Timestamp(split_date)].copy()
    if train_data.empty:
        raise ValueError(f"no {dtype} training rows for {city} before {split_date}")
    logger.info(train_data.iloc[-1].ds)

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_mode,
        n_changepoints=100,  # number of trend change points (Prophet default is 25)
    )
    model.fit(train_data)

    future = model.make_future_dataframe(365*24, freq='H', include_history=True)
    forecast = model.predict(future)

    # plot_components returns its own figure. Save that figure explicitly and
    # close it, so repeated calls (one per city) do not accumulate open figures.
    fig = model.plot_components(forecast)
    os.makedirs('./figure', exist_ok=True)
    fig.savefig(f'./figure/{city}_{dtype}_components.png')
    plt.close(fig)
    return model, forecast
|
|||
|
|
|||
|
|
|||
|
def get_date_type(date: str, holiday_client: "getHoliday"):
    """Classify one calendar day using the holiday client.

    Args:
        date (str): the day as ``"YYYY-MM-DD"``.
        holiday_client (getHoliday): client whose ``assemble_holiday_data`` is
            called with ``today=date``; its result is read as a dict.

    Returns:
        str | None: one of three values:

        - ``'oridinary'`` (sic — callers match this exact spelling) when the
          response has ``code == 0`` and no ``holiday`` entry;
        - the holiday's ``name`` when the response has a ``holiday`` entry;
        - ``None`` when ``code`` is not 0, i.e. the lookup failed.
    """
    rst = holiday_client.assemble_holiday_data(today=date)
    if rst.get('code') != 0:
        # Failed lookup: return None explicitly rather than falling off the end.
        return None
    holiday = rst.get('holiday')
    if holiday is None:
        return 'oridinary'
    return holiday.get('name')
|
|||
|
|
|||
|
|
|||
|
def build_holiday(start_date: str = "2015-01-01", end_date: str = "2021-12-31", client=None):
    """Collect the special (non-ordinary) days between two dates.

    Each day in ``[start_date, end_date]`` is classified with ``get_date_type``.
    Ordinary days and days whose lookup failed are dropped. The remaining days
    are returned in the two-column layout used for Prophet's ``holidays``
    argument. The original intent is to also include days that were turned into
    make-up working days (调休); whether those appear depends on what the holiday
    service reports — TODO confirm.

    Args:
        start_date (str): first day as ``"YYYY-MM-DD"``; defaults to 2015-01-01.
            Anything ``pd.date_range`` accepts (e.g. a Timestamp) also works.
        end_date (str): last day as ``"YYYY-MM-DD"``; defaults to 2021-12-31.
        client (getHoliday, optional): holiday client to reuse; a new
            ``getHoliday()`` is created when omitted.

    Returns:
        pd.DataFrame: columns ``ds`` (``"YYYY-MM-DD"`` strings) and ``holiday``
        (the day's holiday name), one row per special day, index 0..n-1.
    """
    if client is None:
        client = getHoliday()

    days = pd.date_range(start=start_date, end=end_date, freq='D').strftime('%Y-%m-%d')
    ds_list = pd.DataFrame({'date': days})
    ds_list['day_type'] = ds_list.date.apply(lambda x: get_date_type(x, client))

    # get_date_type labels normal days 'oridinary' (sic) and failed lookups None.
    # The previous filter compared against 'simple', a value that is never
    # produced. As a result every ordinary day was passed to Prophet as a holiday.
    is_special = ds_list.day_type.notna() & (ds_list.day_type != 'oridinary')
    special_date = ds_list[is_special].reset_index(drop=True)
    special_date.columns = ['ds', 'holiday']
    return special_date
|
|||
|
|
|||
|
def train(data_type, city_list, data, split_date='2021-01-01'):
    """Fit one Prophet model per city and persist the forecasts and models.

    Args:
        data_type (str): pollutant to model, ``'O3'`` or ``'PM2.5'``.
        city_list (list[str]): city column names in ``data``.
        data (pd.DataFrame): merged data (needs ``ds``, ``type`` and city columns).
        split_date (str): train/test split passed to ``build_model``. Defaults
            to ``'2021-01-01'``, the previously hard-coded value.

    Returns:
        tuple[dict, dict]: ``(model_dict, predict_dict)``, both keyed by city.

    Side effects:
        Writes ``./result/{data_type}/data/{city}.csv`` and
        ``./result/{data_type}/model/{city}.pkl``, creating directories as needed.
    """
    model_dict = {}
    predict_dict = {}
    # One holiday table covering the whole data range, shared by every city.
    holiday_data = build_holiday(data.ds.min(), data.ds.max())
    for city in city_list:
        model, pred = build_model(city, data, data_type, holiday_data, split_date)
        model_dict[city] = model
        predict_dict[city] = pred
        logger.info(f"{city} 模型构建完成")

    result_dir = f'./result/{data_type}'
    # makedirs also creates the intermediate ./result/ and ./result/{data_type}/.
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(f'{result_dir}/model', exist_ok=True)
    os.makedirs(f'{result_dir}/data', exist_ok=True)

    for city, city_pred in predict_dict.items():
        city_pred.to_csv(f'{result_dir}/data/{city}.csv', encoding='utf-8', index=False)
        logger.info(f"{city} 预测数据保存完成")

    for city, city_model in model_dict.items():
        # NOTE(review): Prophet's docs advise against pickling fitted models
        # (version-fragile). prophet.serialize.model_to_json is the portable
        # alternative; kept as pickle here for compatibility with existing files.
        with open(f'{result_dir}/model/{city}.pkl', 'wb') as fwb:
            pickle.dump(city_model, fwb)
        logger.info(f"{city} 模型保存完成")

    return model_dict, predict_dict
|
|||
|
|
|||
|
if __name__ == '__main__':
    data_type = 'O3'  # change this to switch the pollutant ('O3' or 'PM2.5')
    city_list = ['北京']  # add more city columns here
    cache_path = './data/total_data.csv'

    if os.path.exists(cache_path):
        # Parse ds back into datetimes so the cached path yields the same dtypes
        # as a fresh load_data() call; read_csv would otherwise leave strings.
        data = pd.read_csv(cache_path, parse_dates=['ds'])
    else:
        data = load_data()
        data.to_csv(cache_path, encoding='utf-8', index=False)

    model_dict, pred_dict = train(data_type, city_list, data)

    # Example (testing): load a stored model back from disk.
    # with open('./result/O3/model/北京.pkl', 'rb') as fr:
    #     local_model = pickle.load(fr)
|