1.3 MiB
1.3 MiB
In [1]:
import pandas as pd import pandas as pd import numpy as np import matplotlib.pyplot as plt #新增加的两行 from pylab import mpl # 设置显示中文字体 mpl.rcParams["font.sans-serif"] = ["SimHei"] mpl.rcParams["axes.unicode_minus"] = False
In [2]:
data = pd.read_excel('./data/mod_merge_ssr&MEIC&BUGS.xlsx') data.head()
Out[2]:
date | PM2.5 | PM10 | SO2 | NO2 | O3 | O3_8h | CO | Ox | wind-U | ... | VOC_resdient | VOC_power | VOC_agricultural | PM2.5_industrial | PM2.5_transportation | PM2.5_resdient | PM2.5_power | PM2.5_agricultural | CO_Bio | VOCs_Bio | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2015-01-02 01:00:00 | 136.0 | 214.0 | 317.0 | 38.0 | 8.0 | 9.0 | 3.71 | 46 | 0.831775 | ... | 0.937173 | 0.037724 | 0 | 0.926851 | 0.077715 | 0.827110 | 0.436028 | 0 | 0.081546 | 4.217706 |
1 | 2015-01-02 02:00:00 | 114.0 | 176.0 | 305.0 | 38.0 | 8.0 | 9.0 | 3.55 | 46 | -0.695011 | ... | 0.937173 | 0.036215 | 0 | 0.926851 | 0.081248 | 0.827110 | 0.418587 | 0 | 0.080031 | 4.119807 |
2 | 2015-01-02 03:00:00 | 97.0 | 154.0 | 306.0 | 37.0 | 7.0 | 8.0 | 3.51 | 44 | -0.173311 | ... | 0.937173 | 0.035712 | 0 | 0.926851 | 0.088313 | 0.827110 | 0.412773 | 0 | 0.077761 | 3.973464 |
3 | 2015-01-02 04:00:00 | 87.0 | 141.0 | 316.0 | 38.0 | 7.0 | 8.0 | 3.55 | 45 | 0.000000 | ... | 0.937173 | 0.036718 | 0 | 0.926851 | 0.091256 | 0.827110 | 0.424400 | 0 | 0.076766 | 3.909235 |
4 | 2015-01-02 05:00:00 | 85.0 | 139.0 | 292.0 | 37.0 | 7.0 | 8.0 | 3.62 | 44 | 1.234518 | ... | 1.978475 | 0.039736 | 0 | 0.926851 | 0.092434 | 1.746121 | 0.459282 | 0 | 0.077119 | 3.930702 |
5 rows × 54 columns
In [3]:
data.drop(columns='O3_8h', inplace=True)
In [4]:
out_cols = data.columns[1:7].tolist() out_cols
Out[4]:
['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO']
In [15]:
date_range = pd.date_range(start=data.date.min(), end=data.date.max(), freq='H') data.date = pd.to_datetime(data.date) data = data.set_index('date').reindex(date_range)
In [16]:
import datetime as dt
打印输出列的分布
In [17]:
fig = plt.figure(figsize=(15, 10)) for index, col in enumerate(out_cols): try: plt.subplot(3,3,index+1) plt.title(col) plt.hist(data[col]) except: print(col)
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans. findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans. findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
打印特征列的分布
In [18]:
fea_cols = [x for x in data.columns if x not in out_cols]
In [19]:
fig = plt.figure(figsize=(24, 28)) for index, col in enumerate(fea_cols): try: plt.subplot(9,5,index+1) plt.title(col) plt.hist(data[col]) except: print(col) fig.savefig('fea.png')
VOCs_Bio
分别对每个输出找到相关性最大的特征
In [20]:
import seaborn as sns
In [21]:
k = 10 fig = plt.figure(figsize=(30, 30)) for i,u_col in enumerate(out_cols): use_cols = fea_cols + [u_col] corrmat = data[use_cols].corr() cols = corrmat.nlargest(k, u_col)[u_col].index cm = np.corrcoef(data[cols].values.T) sns.set(font_scale=1.25) plt.subplot(3,3,i+1) plt.title(u_col) hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={'size': 10}, yticklabels=cols.values, xticklabels=cols.values) fig.savefig('./cm.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py:198: RuntimeWarning: All-NaN slice encountered vmin = np.nanmin(calc_data) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py:203: RuntimeWarning: All-NaN slice encountered vmax = np.nanmax(calc_data)
直接看相关性太低了,将过去一个时刻的指标扔进来
In [28]:
data.index.name = 'date'
In [29]:
data.reset_index(inplace=True) data.date = pd.to_datetime(data.date) data['pre_time'] = data.date.apply(lambda x: x - dt.timedelta(hours=1)) pre_out = data[['date'] + out_cols].copy() pre_out.columns = ['pre_time'] + [f'pre_{x}' for x in out_cols] pre_out.pre_time = pd.to_datetime(pre_out.pre_time) pre_out.head()
Out[29]:
pre_time | pre_PM2.5 | pre_PM10 | pre_SO2 | pre_NO2 | pre_O3 | pre_CO | |
---|---|---|---|---|---|---|---|
0 | 2015-01-02 01:00:00 | 136.0 | 214.0 | 317.0 | 38.0 | 8.0 | 3.71 |
1 | 2015-01-02 02:00:00 | 114.0 | 176.0 | 305.0 | 38.0 | 8.0 | 3.55 |
2 | 2015-01-02 03:00:00 | 97.0 | 154.0 | 306.0 | 37.0 | 7.0 | 3.51 |
3 | 2015-01-02 04:00:00 | 87.0 | 141.0 | 316.0 | 38.0 | 7.0 | 3.55 |
4 | 2015-01-02 05:00:00 | 85.0 | 139.0 | 292.0 | 37.0 | 7.0 | 3.62 |
In [30]:
use_data = data.merge(pre_out, how='left', on='pre_time').dropna() use_data.shape
Out[30]:
(46455, 60)
定义新的特征列
In [31]:
new_fea_cols = [x for x in use_data.columns if x not in out_cols]
In [32]:
k = 10 fig = plt.figure(figsize=(30, 30)) for i,u_col in enumerate(out_cols): use_cols = new_fea_cols + [u_col] corrmat = use_data[use_cols].corr() cols = corrmat.nlargest(k, u_col)[u_col].index cm = np.corrcoef(use_data[cols].values.T) sns.set(font_scale=1.25) plt.subplot(3,3,i+1) plt.title(u_col) hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={'size': 10}, yticklabels=cols.values, xticklabels=cols.values)
果然和上一时刻是强相关的,这就尴尬了。先做特征工程试试
对输出列取对数化
In [33]:
for col in out_cols: use_data[col] = np.log1p(use_data[col])
In [34]:
from scipy.stats import norm import scipy.stats as stats
In [35]:
fig = plt.figure(figsize=(15, 10)) for index, col in enumerate(out_cols): try: plt.subplot(3,3,index+1) plt.title(col) sns.distplot(data[col], fit=norm) except: print(col) fig.savefig('no-log.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
In [36]:
fig = plt.figure(figsize=(15, 10)) for index, col in enumerate(out_cols): try: plt.subplot(3,3,index+1) plt.title(col) sns.distplot(use_data[col], fit=norm) except: print(col) fig.savefig('logify.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
In [37]:
fig = plt.figure(figsize=(15, 15)) for index, col in enumerate(out_cols): try: plt.subplot(3,3,index+1) plt.title(col) rest = stats.probplot(use_data[col], plot=plt) except: print(col)
如果做minmax,是否还保持高斯分布
In [38]:
maxs = use_data[out_cols].max() mins = use_data[out_cols].min()
In [39]:
out_data = list() for col in out_cols: print(col, end=' ') d = (use_data[col] - mins[col]) / (maxs[col] - mins[col]) out_data.append(d)
PM2.5 PM10 SO2 NO2 O3 CO
In [40]:
fig = plt.figure(figsize=(15, 15)) for index, col in enumerate(out_data): try: plt.subplot(3,3,index+1) rest = stats.probplot(col, plot=plt) except: print(col)
下面这几列,数据有问题,直接drop掉
In [41]:
drop_cols = [x for x in new_fea_cols if 'agricultural' in x] + ['NH3_power'] + ['CO_Bio', 'VOCs_Bio'] drop_cols.remove('NH3_agricultural')
In [42]:
fig = plt.figure(figsize=(15, 15)) for index, col in enumerate(drop_cols): try: plt.subplot(3,4,index+1) plt.title(col) plt.hist(data[col]) except: print(col)
In [43]:
use_data.drop(columns=drop_cols, inplace=True) use_data.drop(columns=['date', 'pre_time'], inplace=True)
In [1]:
use_data.reset_index().to_csv('./data/train_data_mod.csv', encoding='utf-8-sig', index=False)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) /tmp/ipykernel_1136037/1598183905.py in <module> ----> 1 use_data.reset_index().to_csv('./data/train_data_mod.csv', encoding='utf-8-sig', index=False) NameError: name 'use_data' is not defined
In [46]:
data.drop(columns=drop_cols).to_csv('./data/ori_data.csv', encoding='utf-8-sig', index=False)
In [ ]: