# EnergyNewsKeyword/Bi-LSTM.py
#
# (source-viewer metadata: 174 lines, 5.4 KiB, Python)

import os,sys
os.chdir(sys.path[0])
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from datasets import EnNews
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import numpy as np
import random as rn
np.random.seed(421)
rn.seed(12345)
import logging
from keras import regularizers
from keras.layers import Bidirectional, Dense, Dropout, Embedding, LSTM, TimeDistributed
from keras.models import Sequential, load_model
from datasets import *
from eval import keras_metrics, metrics
from nlp import tokenizer as tk
from utils import info, preprocessing, postprocessing, plots
# Logging setup: emit every record at DEBUG level or above as a
# tab-separated "timestamp, level, message" line, then let the project's
# utils.info helper log version information (presumably of the libraries
# in use -- see info.log_versions).
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s\t%(levelname)s\t%(message)s",
)
info.log_versions()
# Global switches.
# SAVE_MODEL: when True, a model already present at MODEL_PATH is reloaded
# instead of retrained, and a freshly trained model is written there.
SAVE_MODEL = False
# Relative path; resolved against the working directory, which the script
# sets to sys.path[0] at startup.
MODEL_PATH = "models/bilstm.h5"
# When True, accuracy/loss/PRF curves are plotted after training.
SHOW_PLOTS = False

# Dataset and hyper-parameters.
Dataset = EnNews
# Root folder holding the data. The default is the original author's
# machine; set ENERGY_NEWS_ROOT to run elsewhere without editing the code.
rootpath = os.environ.get("ENERGY_NEWS_ROOT", "/home/zhangxj/WorkFile/本科毕业设计")
tokenizer = tk.tokenizers.nltk
DATASET_FOLDER = os.path.join(rootpath, "EnergyNews")
# Sequence length passed to prepare_sequential and used as the Embedding
# layer's input_length.
MAX_DOCUMENT_LENGTH = 400
MAX_VOCABULARY_SIZE = 20000
# Embedding vector dimension (must match the embedding matrix built during
# preprocessing).
EMBEDDINGS_SIZE = 50
batch_size = 32
epochs = 20
# NOTE(review): KP_WEIGHT is not referenced anywhere in this script; training
# uses sklearn "balanced" per-token sample weights instead.
KP_WEIGHT = 10
# Stemming mode used when comparing predicted and gold keyphrases.
STEM_MODE = metrics.stemMode.both
# Forwarded to prepare_sequential as stem_test.
STEM_TEST = False
# Load the raw train / test / validation splits (in that order), then run
# each split's documents and answer lists through tk.tokenize_set with the
# configured tokenizer, again in train / test / validation order.
logging.info("Loading dataset...")
data = Dataset(DATASET_FOLDER)
(train_doc_str, train_answer_str), (test_doc_str, test_answer_str), (val_doc_str, val_answer_str) = (
    data.load_train(),
    data.load_test(),
    data.load_validation(),
)
(train_doc, train_answer), (test_doc, test_answer), (val_doc, val_answer) = (
    tk.tokenize_set(split_docs, split_answers, tokenizer)
    for split_docs, split_answers in (
        (train_doc_str, train_answer_str),
        (test_doc_str, test_answer_str),
        (val_doc_str, val_answer_str),
    )
)
# Preprocessing: turn the tokenized splits into model inputs (x), per-token
# label arrays (y; the class lives on axis 2, see the argmax below) and the
# embedding matrix later used to initialise the Embedding layer.
logging.info("Dataset loaded. Preprocessing data...")
train_x, train_y, test_x, test_y, val_x, val_y, embedding_matrix = preprocessing.prepare_sequential(
    train_doc, train_answer,
    test_doc, test_answer,
    val_doc, val_answer,
    max_document_length=MAX_DOCUMENT_LENGTH,
    max_vocabulary_size=MAX_VOCABULARY_SIZE,
    embeddings_size=EMBEDDINGS_SIZE,
    stem_test=STEM_TEST,
)

# Per-token sample weights: sklearn's 'balanced' mode weights each class
# inversely to its frequency, so the minority class (presumably the
# keyphrase tokens) is not swamped by the majority class. The flat weight
# vector is reshaped back to the label grid for temporal sample weighting.
from sklearn.utils import class_weight
train_labels = np.argmax(train_y, axis=2)
train_y_weights = class_weight.compute_sample_weight(
    'balanced', train_labels.ravel()).reshape(train_labels.shape)
logging.info("数据预处理完成")

# Upper bound on achievable recall: decode the gold test labels themselves
# back into words and score them; anything below 1.0 means some gold answers
# cannot be recovered from the preprocessed sequences.
best_case_words = postprocessing.get_words(test_doc, postprocessing.undo_sequential(test_y))
logging.info("可能的最大召回率: %s", metrics.recall(test_answer, best_case_words, STEM_MODE))
# Build and train the Bi-LSTM tagger (optionally plotting and saving it), or
# reload a previously saved model. A saved model is reused only when
# SAVE_MODEL is True AND MODEL_PATH already exists; otherwise the network is
# always trained from scratch.
if not SAVE_MODEL or not os.path.isfile(MODEL_PATH):
    logging.debug("建立网络...")
    vocabulary_rows = np.shape(embedding_matrix)[0]
    logging.debug("Embedding matrix rows: %d", vocabulary_rows)
    model = Sequential()
    # Frozen embedding layer initialised from the preprocessing embedding matrix.
    embedding_layer = Embedding(vocabulary_rows,
                                EMBEDDINGS_SIZE,
                                weights=[embedding_matrix],
                                input_length=MAX_DOCUMENT_LENGTH,
                                trainable=False)
    model.add(embedding_layer)
    model.add(Bidirectional(LSTM(300, activation='tanh', recurrent_activation='hard_sigmoid',
                                 return_sequences=True)))
    model.add(Dropout(0.25))
    model.add(TimeDistributed(Dense(150, activation='relu',
                                    kernel_regularizer=regularizers.l2(0.01))))
    model.add(Dropout(0.25))
    # Per-token 2-way softmax (presumably keyphrase vs. other, matching the
    # label encoding produced by prepare_sequential).
    model.add(TimeDistributed(Dense(2, activation='softmax')))
    logging.info("编译网络...")
    # "temporal" mode lets fit() accept one weight per (sample, timestep),
    # i.e. the 2-D train_y_weights grid.
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'],
                  sample_weight_mode="temporal")
    # summary() prints the table itself and returns None, so it is not
    # wrapped in print().
    model.summary()
    metrics_callback = keras_metrics.MetricsCallback(val_x, val_y)
    logging.info("拟合网络...")
    history = model.fit(train_x, train_y,
                        validation_data=(val_x, val_y),
                        epochs=epochs,
                        batch_size=batch_size,
                        sample_weight=train_y_weights,
                        callbacks=[metrics_callback])
    if SHOW_PLOTS:
        plots.plot_accuracy(history)
        plots.plot_loss(history)
        plots.plot_prf(metrics_callback)
    if SAVE_MODEL:
        # Writing the HDF5 file fails if the parent directory is missing.
        model_dir = os.path.dirname(MODEL_PATH)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        model.save(MODEL_PATH)
        logging.info("模型保存路径 in %s", MODEL_PATH)
else:
    logging.info("加载模型 %s...", MODEL_PATH)
    model = load_model(MODEL_PATH)
    logging.info("加载模型完成")
def _print_scores(title, p, r, f):
    """Print a precision / recall / F1 table headed by ``title``."""
    print("### %s ###" % title)
    print("###")
    print("### Precision : %.4f" % p)
    print("### Recall : %.4f" % r)
    print("### F1 : %.4f" % f)
    print("### ###")


# Predict on the test set and report scores twice: first by comparing the
# decoded words against the gold answers (using STEM_MODE), then via the
# keras_metrics helpers applied directly to the label/prediction arrays.
logging.info("在测试集上预测...")
output = model.predict(x=test_x, verbose=1)
logging.debug("输出格式: %s", np.shape(output))
obtained_tokens = postprocessing.undo_sequential(output)
obtained_words = postprocessing.get_words(test_doc, obtained_tokens)
precision = metrics.precision(test_answer, obtained_words, STEM_MODE)
recall = metrics.recall(test_answer, obtained_words, STEM_MODE)
f1 = metrics.f1(precision, recall)
_print_scores("获得的分数", precision, recall, f1)

# Scores computed on test_y vs. raw model output; given a distinct header so
# this table is not confused with the word-level one above.
keras_precision = keras_metrics.keras_precision(test_y, output)
keras_recall = keras_metrics.keras_recall(test_y, output)
keras_f1 = keras_metrics.keras_f1(test_y, output)
_print_scores("获得的分数 (Keras metrics)", keras_precision, keras_recall, keras_f1)