166 lines
5.9 KiB
Python
166 lines
5.9 KiB
Python
from cmath import log
|
||
import pandas as pd
|
||
import os
|
||
import numpy as np
|
||
from prophet import Prophet
|
||
import datetime as dt
|
||
from get_holiday_cn.client import getHoliday
|
||
from logzero import logger
|
||
import pickle
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
def concat_date(x:str, y:str):
    """Combine a compact date and an hour of day into one ``datetime``.

    Args:
        x (str): Calendar date written as ``YYYYMMDD``.
        y (str): Hour of the day; it is placed in front of ``:00:00``.

    Returns:
        datetime.datetime: The timestamp at the start of that hour.

    Raises:
        ValueError: If the assembled text does not match
            ``%Y%m%d %H:%M:%S``.
    """
    stamp_format = "%Y%m%d %H:%M:%S"
    # Minutes and seconds are always zero: the inputs only carry hour
    # resolution.
    stamp_text = f"{x} {y}:00:00"
    parsed = dt.datetime.strptime(stamp_text, stamp_format)
    return parsed
|
||
|
||
|
||
def load_data():
    """Read every city CSV under ``./data/`` and merge the PM2.5 / O3 rows.

    Entries of ``./data/`` whose names start with ``城市_`` are visited in
    sorted order, and the files ending in ``csv`` inside each one are read
    in sorted order. Only rows whose ``type`` column equals ``PM2.5`` or
    ``O3`` are kept. A ``ds`` timestamp column is built from each row's
    ``date`` and ``hour`` values with ``concat_date``; the result is sorted
    by ``ds`` and given a fresh integer index.

    Returns:
        pd.DataFrame: All retained rows plus the ``ds`` column.

    Raises:
        ValueError: If no CSV file was found, so there is nothing to merge.
    """
    data_folder = sorted(
        x for x in os.listdir('./data/') if x.startswith('城市_')
    )

    # Collect the per-file frames and concatenate once at the end; calling
    # pd.concat inside the loop re-copies the accumulated data every time.
    frames = []
    for folder in data_folder:
        for file in sorted(os.listdir(f"./data/{folder}")):
            if not file.endswith('csv'):
                continue
            data = pd.read_csv(f'./data/{folder}/{file}')
            frames.append(data[data['type'].isin(['PM2.5', 'O3'])])

    if not frames:
        # Without this guard the failure only surfaces later as an obscure
        # pandas error when the ``ds`` column is assigned.
        raise ValueError("no CSV files found under ./data/城市_*")

    total_data = pd.concat(frames)
    total_data['ds'] = total_data.apply(
        lambda row: concat_date(row['date'], row['hour']), axis=1
    )
    total_data['ds'] = pd.to_datetime(total_data['ds'])
    total_data.sort_values(by='ds', ascending=True, inplace=True)
    total_data.reset_index(drop=True, inplace=True)
    logger.info(f"总数据集大小:{total_data.shape}")
    return total_data
|
||
|
||
|
||
def build_model(city: str, data: pd.DataFrame, dtype:str, holiday_mode:pd.DataFrame, split_date="2021-01-01 00:00:00"):
    """Fit a Prophet model for one city and one pollutant, then forecast.

    Rows of ``data`` whose ``type`` equals ``dtype`` are selected, the
    ``city`` column is used as the target ``y``, and only rows with
    ``ds < split_date`` are used for fitting. The forecast covers the
    training history plus ``365*24`` further hourly periods. The component
    plot is written to ``./figure/{city}_{dtype}_components.png``.

    Args:
        city (str): Name of the column in ``data`` holding this city's values.
        data (pd.DataFrame): Frame with ``ds``, ``type`` and per-city columns.
        dtype (str): Value of the ``type`` column to model (``O3`` or
            ``PM2.5``).
        holiday_mode (pd.DataFrame): Holiday table handed to Prophet's
            ``holidays`` argument (the caller in this file passes the frame
            returned by ``build_holiday``).
        split_date (str, optional): Rows strictly before this timestamp are
            used for training. Defaults to "2021-01-01 00:00:00".

    Returns:
        tuple: ``(model, forecast)`` - the fitted ``Prophet`` instance and
        the frame returned by ``model.predict``.

    Raises:
        ValueError: If no row falls before ``split_date``.
    """
    logger.info(f"选择了 {city} 的 {dtype} 数据,")
    use_data = data[data['type'] == dtype][["ds", city]].copy()
    use_data.columns = ["ds", "y"]

    train_data = use_data[use_data.ds < split_date].copy()
    if train_data.empty:
        # Fail with a clear message instead of an IndexError from iloc[-1].
        raise ValueError(
            f"no {dtype} rows for {city} before split date {split_date}"
        )
    # Log the last training timestamp as a sanity check on the split.
    logger.info(train_data.iloc[-1].ds)

    model = Prophet(
        growth="linear",
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode="multiplicative",
        seasonality_prior_scale=12,
        holidays=holiday_mode,
        n_changepoints=100,  # number of change points, library default is 25
    )
    model.fit(train_data)

    future = model.make_future_dataframe(365*24, freq='H', include_history=True)
    forecast = model.predict(future)

    # Create the output folder on demand so saving cannot fail on a fresh
    # checkout, and close the figure so repeated calls do not pile up
    # open matplotlib figures.
    os.makedirs('./figure', exist_ok=True)
    fig = model.plot_components(forecast)
    fig.savefig(f'./figure/{city}_{dtype}_components.png')
    plt.close(fig)
    return model, forecast
|
||
|
||
|
||
def get_date_type(date:str, holiday_client:getHoliday):
    """Classify a single date using the holiday client.

    Args:
        date (str): Date formatted as "YYYY-MM-DD".
        holiday_client (getHoliday): Client whose
            ``assemble_holiday_data(today=...)`` result is inspected.

    Returns:
        The literal ``'oridinary'`` (spelling kept as-is, callers may
        compare against it) when the response carries no ``holiday``
        entry, otherwise the ``name`` field of that entry. ``None`` when
        the response ``code`` is not ``0``.
    """
    response = holiday_client.assemble_holiday_data(today=date)

    # Any non-zero code falls through without a classification.
    if response.get('code') != 0:
        return None

    if response.get('holiday') is None:
        return 'oridinary'
    return response.get('holiday').get('name')
|
||
|
||
|
||
def build_holiday(start_date:str="2015-01-01", end_date:str="2021-12-31"):
    """Collect the special days (holidays and make-up workdays) in a range.

    Every calendar day between ``start_date`` and ``end_date`` is
    classified with ``get_date_type``. Ordinary days and days whose lookup
    returned nothing are dropped; the remaining rows are returned.

    Args:
        start_date (str): First day, as a "YYYY-MM-DD" string.
            Defaults to "2015-01-01".
        end_date (str): Last day, as a "YYYY-MM-DD" string.
            Defaults to "2021-12-31".

    Returns:
        pd.DataFrame: Two columns, ``ds`` (the "YYYY-MM-DD" date string) and
        ``holiday`` (the name returned by ``get_date_type``).
    """
    days = pd.date_range(start=start_date, end=end_date, freq='D')
    ds_list = pd.DataFrame({'date': days.strftime('%Y-%m-%d')})

    client = getHoliday()
    ds_list['day_type'] = ds_list['date'].apply(lambda x: get_date_type(x, client))

    # get_date_type yields None when the lookup does not succeed; such rows
    # carry no usable holiday name, so they are reported and skipped.
    failed = int(ds_list['day_type'].isna().sum())
    if failed:
        logger.warning(f"holiday lookup failed for {failed} day(s); skipped")

    # get_date_type labels normal days 'oridinary'. The previous filter
    # compared against 'simple', which never matched, so every ordinary day
    # was kept as a holiday. 'simple' stays excluded for compatibility.
    is_special = (
        ds_list['day_type'].notna()
        & ~ds_list['day_type'].isin(['oridinary', 'simple'])
    )
    special_date = ds_list[is_special].copy()
    special_date.columns = ['ds', 'holiday']
    return special_date
|
||
|
||
def train(data_type, city_list, data, split_date='2021-01-01'):
    """Train one model per city and persist the forecasts and models.

    The holiday table is built once from the span of ``data.ds`` and shared
    by all cities. Forecasts are written to
    ``./result/{data_type}/data/{city}.csv`` and the pickled models to
    ``./result/{data_type}/model/{city}.pkl``.

    Args:
        data_type: Value of the ``type`` column to model (passed on to
            ``build_model`` as ``dtype``).
        city_list: Iterable of city column names to model.
        data: Frame with a ``ds`` column, passed on to ``build_model``.
        split_date: Train/test split passed on to ``build_model``.
            Defaults to '2021-01-01'.

    Returns:
        tuple: ``(model_dict, predict_dict)``, both keyed by city.
    """
    model_dict = {}
    predict_dict = {}
    holiday_data = build_holiday(data.ds.min(), data.ds.max())

    for city in city_list:
        model, pred = build_model(city, data, data_type, holiday_data, split_date)
        model_dict[city] = model
        predict_dict[city] = pred
        logger.info(f"{city} 模型构建完成")

    # makedirs creates the intermediate folders too and tolerates ones that
    # already exist, replacing the exists()/mkdir() chain.
    os.makedirs(f'./result/{data_type}/model', exist_ok=True)
    os.makedirs(f'./result/{data_type}/data', exist_ok=True)

    for city, city_pred in predict_dict.items():
        city_pred.to_csv(f'./result/{data_type}/data/{city}.csv', encoding='utf-8', index=False)
        logger.info(f"{city} 预测数据保存完成")

    # NOTE(review): Prophet's documentation recommends its own JSON
    # serialization over pickle; pickle is kept here so existing readers of
    # the .pkl files keep working.
    for city, city_model in model_dict.items():
        with open(f'./result/{data_type}/model/{city}.pkl', 'wb') as fwb:
            pickle.dump(city_model, fwb)
        logger.info(f"{city} 模型保存完成")

    return model_dict, predict_dict
|
||
|
||
if __name__ == '__main__':
    data_type = 'O3'  # change this to switch the pollutant type
    city_list = ['北京']  # add more city column names here

    # Reuse the merged CSV when it exists; otherwise build it from the raw
    # files and cache it for the next run.
    if os.path.exists('./data/total_data.csv'):
        # Parse ``ds`` so the cached path yields the same datetime column
        # as the freshly built one instead of plain strings.
        data = pd.read_csv('./data/total_data.csv', parse_dates=['ds'])
    else:
        data = load_data()
        data.to_csv('./data/total_data.csv', encoding='utf-8', index=False)

    model_dict, pred_list = train(data_type, city_list, data)

    # Example: load a previously stored model for testing. Only unpickle
    # files you produced yourself.
    #
    # with open('./result/O3/model/北京.pkl', 'rb') as fr:
    #     local_model = pickle.load(fr)