22-T67/数据预处理+特征工程.ipynb

1.3 MiB
Raw Permalink Blame History

In [1]:
import pandas as pd
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
#新增加的两行
from pylab import mpl
# 设置显示中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]

mpl.rcParams["axes.unicode_minus"] = False
In [2]:
data = pd.read_excel('./data/mod_merge_ssr&MEIC&BUGS.xlsx')
data.head()
Out[2]:
date PM2.5 PM10 SO2 NO2 O3 O3_8h CO Ox wind-U ... VOC_resdient VOC_power VOC_agricultural PM2.5_industrial PM2.5_transportation PM2.5_resdient PM2.5_power PM2.5_agricultural CO_Bio VOCs_Bio
0 2015-01-02 01:00:00 136.0 214.0 317.0 38.0 8.0 9.0 3.71 46 0.831775 ... 0.937173 0.037724 0 0.926851 0.077715 0.827110 0.436028 0 0.081546 4.217706
1 2015-01-02 02:00:00 114.0 176.0 305.0 38.0 8.0 9.0 3.55 46 -0.695011 ... 0.937173 0.036215 0 0.926851 0.081248 0.827110 0.418587 0 0.080031 4.119807
2 2015-01-02 03:00:00 97.0 154.0 306.0 37.0 7.0 8.0 3.51 44 -0.173311 ... 0.937173 0.035712 0 0.926851 0.088313 0.827110 0.412773 0 0.077761 3.973464
3 2015-01-02 04:00:00 87.0 141.0 316.0 38.0 7.0 8.0 3.55 45 0.000000 ... 0.937173 0.036718 0 0.926851 0.091256 0.827110 0.424400 0 0.076766 3.909235
4 2015-01-02 05:00:00 85.0 139.0 292.0 37.0 7.0 8.0 3.62 44 1.234518 ... 1.978475 0.039736 0 0.926851 0.092434 1.746121 0.459282 0 0.077119 3.930702

5 rows × 54 columns

In [3]:
data.drop(columns='O3_8h', inplace=True)
In [4]:
out_cols = data.columns[1:7].tolist()
out_cols
Out[4]:
['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO']
In [15]:
date_range = pd.date_range(start=data.date.min(), end=data.date.max(), freq='H')
data.date = pd.to_datetime(data.date)
data = data.set_index('date').reindex(date_range)
In [16]:
import datetime as dt

打印输出列的分布

In [17]:
fig = plt.figure(figsize=(15, 10))
for index, col in enumerate(out_cols):
    try:
        plt.subplot(3,3,index+1)
        plt.title(col)
        plt.hist(data[col])
    except:
        print(col)
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
No description has been provided for this image

打印特征列的分布

In [18]:
fea_cols = [x for x in data.columns if x not in out_cols]
In [19]:
fig = plt.figure(figsize=(24, 28))
for index, col in enumerate(fea_cols):
    try:
        plt.subplot(9,5,index+1)
        plt.title(col)
        plt.hist(data[col])
    except:
        print(col)
fig.savefig('fea.png')
VOCs_Bio
No description has been provided for this image

分别对每个输出找到相关性最大的特征

In [20]:
import seaborn as sns
In [21]:
k = 10
fig = plt.figure(figsize=(30, 30))
for i,u_col in enumerate(out_cols):
    use_cols = fea_cols + [u_col]
    corrmat = data[use_cols].corr()
    cols = corrmat.nlargest(k, u_col)[u_col].index
    cm = np.corrcoef(data[cols].values.T)
    sns.set(font_scale=1.25)
    plt.subplot(3,3,i+1)
    plt.title(u_col)
    hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={'size': 10}, 
                    yticklabels=cols.values, xticklabels=cols.values)
fig.savefig('./cm.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py:198: RuntimeWarning: All-NaN slice encountered
  vmin = np.nanmin(calc_data)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py:203: RuntimeWarning: All-NaN slice encountered
  vmax = np.nanmax(calc_data)
No description has been provided for this image

直接看相关性太低了,将过去一个时刻的指标扔进来

In [28]:
data.index.name = 'date'
In [29]:
data.reset_index(inplace=True)
data.date = pd.to_datetime(data.date)
data['pre_time'] = data.date.apply(lambda x: x - dt.timedelta(hours=1))
pre_out = data[['date'] + out_cols].copy()
pre_out.columns = ['pre_time'] + [f'pre_{x}' for x in out_cols]
pre_out.pre_time = pd.to_datetime(pre_out.pre_time)
pre_out.head()
Out[29]:
pre_time pre_PM2.5 pre_PM10 pre_SO2 pre_NO2 pre_O3 pre_CO
0 2015-01-02 01:00:00 136.0 214.0 317.0 38.0 8.0 3.71
1 2015-01-02 02:00:00 114.0 176.0 305.0 38.0 8.0 3.55
2 2015-01-02 03:00:00 97.0 154.0 306.0 37.0 7.0 3.51
3 2015-01-02 04:00:00 87.0 141.0 316.0 38.0 7.0 3.55
4 2015-01-02 05:00:00 85.0 139.0 292.0 37.0 7.0 3.62
In [30]:
use_data = data.merge(pre_out, how='left', on='pre_time').dropna()
use_data.shape
Out[30]:
(46455, 60)

定义新的特征列

In [31]:
new_fea_cols = [x for x in use_data.columns if x not in out_cols]
In [32]:
k = 10
fig = plt.figure(figsize=(30, 30))
for i,u_col in enumerate(out_cols):
    use_cols = new_fea_cols + [u_col]
    corrmat = use_data[use_cols].corr()
    cols = corrmat.nlargest(k, u_col)[u_col].index
    cm = np.corrcoef(use_data[cols].values.T)
    sns.set(font_scale=1.25)
    plt.subplot(3,3,i+1)
    plt.title(u_col)
    hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={'size': 10}, 
                    yticklabels=cols.values, xticklabels=cols.values)
No description has been provided for this image

果然和上一时刻是强相关的,这就尴尬了。先做特征工程试试

对输出列取对数化

In [33]:
for col in out_cols:
    use_data[col] = np.log1p(use_data[col])
In [34]:
from scipy.stats import norm
import scipy.stats as stats
In [35]:
fig = plt.figure(figsize=(15, 10))
for index, col in enumerate(out_cols):
    try:
        plt.subplot(3,3,index+1)
        plt.title(col)
        sns.distplot(data[col], fit=norm)
    except:
        print(col)
fig.savefig('no-log.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
No description has been provided for this image
In [36]:
fig = plt.figure(figsize=(15, 10))
for index, col in enumerate(out_cols):
    try:
        plt.subplot(3,3,index+1)
        plt.title(col)
        sns.distplot(use_data[col], fit=norm)
    except:
        print(col)
fig.savefig('logify.png')
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
No description has been provided for this image
In [37]:
fig = plt.figure(figsize=(15, 15))
for index, col in enumerate(out_cols):
    try:
        plt.subplot(3,3,index+1)
        plt.title(col)
        rest = stats.probplot(use_data[col], plot=plt)
    except:
        print(col)
No description has been provided for this image

如果做minmax是否还保持高斯分布

In [38]:
maxs = use_data[out_cols].max()
mins = use_data[out_cols].min()
In [39]:
out_data = list()
for col in out_cols:
    print(col, end=' ')
    d = (use_data[col] - mins[col]) / (maxs[col] - mins[col])
    out_data.append(d)
PM2.5 PM10 SO2 NO2 O3 CO 
In [40]:
fig = plt.figure(figsize=(15, 15))
for index, col in enumerate(out_data):
    try:
        plt.subplot(3,3,index+1)
        rest = stats.probplot(col, plot=plt)
    except:
        print(col)
No description has been provided for this image

下面这几列数据有问题直接drop掉

In [41]:
drop_cols = [x for x in new_fea_cols if 'agricultural' in x] + ['NH3_power'] + ['CO_Bio', 'VOCs_Bio']
drop_cols.remove('NH3_agricultural')
In [42]:
fig = plt.figure(figsize=(15, 15))
for index, col in enumerate(drop_cols):
    try:
        plt.subplot(3,4,index+1)
        plt.title(col)
        plt.hist(data[col])
    except:
        print(col)
No description has been provided for this image
In [43]:
use_data.drop(columns=drop_cols, inplace=True)
use_data.drop(columns=['date', 'pre_time'], inplace=True)
In [1]:
use_data.reset_index().to_csv('./data/train_data_mod.csv', encoding='utf-8-sig', index=False)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_1136037/1598183905.py in <module>
----> 1 use_data.reset_index().to_csv('./data/train_data_mod.csv', encoding='utf-8-sig', index=False)

NameError: name 'use_data' is not defined
In [46]:
data.drop(columns=drop_cols).to_csv('./data/ori_data.csv', encoding='utf-8-sig', index=False)
In [ ]: