import os, sys

os.chdir(sys.path[0])
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from datasets import EnNews

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
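
# NOTE: set_session is imported but never called below. A typical TF1-era
# use (a sketch, not part of the original script) would cap GPU memory
# growth:
#
#   config = tf.ConfigProto()
#   config.gpu_options.allow_growth = True
#   set_session(tf.Session(config=config))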

import numpy as np
import random as rn

np.random.seed(421)
rn.seed(12345)
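
# NOTE: these seeds fix NumPy and Python's random module only; full
# reproducibility would presumably also need a TensorFlow-level seed
# (tf.set_random_seed in TF1) and single-threaded execution, and GPU ops
# may stay nondeterministic regardless.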

import logging

from keras import regularizers
from keras.layers import Bidirectional, Dense, Dropout, Embedding, LSTM, TimeDistributed

from keras.models import Sequential, load_model

from datasets import *
from eval import keras_metrics, metrics
from nlp import tokenizer as tk
from utils import info, preprocessing, postprocessing, plots

# Logging configuration

logging.basicConfig(
    format='%(asctime)s\t%(levelname)s\t%(message)s',
    level=logging.DEBUG)

info.log_versions()

# Global variables

SAVE_MODEL = False
MODEL_PATH = "models/bilstm.h5"
SHOW_PLOTS = False

# Dataset and hyperparameters
Dataset = EnNews

rootpath = "/home/zhangxj/WorkFile/本科毕业设计"

tokenizer = tk.tokenizers.nltk
DATASET_FOLDER = rootpath + "/EnergyNews"
MAX_DOCUMENT_LENGTH = 400
MAX_VOCABULARY_SIZE = 20000
EMBEDDINGS_SIZE = 50
batch_size = 32
epochs = 20
KP_WEIGHT = 10
STEM_MODE = metrics.stemMode.both
STEM_TEST = False
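
# NOTE: KP_WEIGHT is never referenced below; class imbalance is instead
# handled by the 'balanced' sample weights computed after preprocessing.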

# Load the dataset
logging.info("Loading dataset...")

data = Dataset(DATASET_FOLDER)

train_doc_str, train_answer_str = data.load_train()
test_doc_str, test_answer_str = data.load_test()
val_doc_str, val_answer_str = data.load_validation()

train_doc, train_answer = tk.tokenize_set(train_doc_str, train_answer_str, tokenizer)
test_doc, test_answer = tk.tokenize_set(test_doc_str, test_answer_str, tokenizer)
val_doc, val_answer = tk.tokenize_set(val_doc_str, val_answer_str, tokenizer)

# Sanity check
logging.info("Dataset loaded. Preprocessing data...")

train_x, train_y, test_x, test_y, val_x, val_y, embedding_matrix = preprocessing. \
    prepare_sequential(train_doc, train_answer, test_doc, test_answer, val_doc, val_answer,
                       max_document_length=MAX_DOCUMENT_LENGTH,
                       max_vocabulary_size=MAX_VOCABULARY_SIZE,
                       embeddings_size=EMBEDDINGS_SIZE,
                       stem_test=STEM_TEST)
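
# Expected shapes, inferred from how these arrays are used below (not
# verified against prepare_sequential itself):
#   train_x: (n_docs, MAX_DOCUMENT_LENGTH) integer word indices
#   train_y: (n_docs, MAX_DOCUMENT_LENGTH, 2) one-hot keyphrase labels
#   embedding_matrix: (vocabulary_size, EMBEDDINGS_SIZE) pretrained embeddings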

# Weight the training examples: everything that is not a keyphrase
from sklearn.utils import class_weight

train_y_weights = np.argmax(train_y, axis=2)
train_y_weights = np.reshape(class_weight.compute_sample_weight('balanced', train_y_weights.flatten()),
                             np.shape(train_y_weights))
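
# sklearn's 'balanced' mode weights each token by
# n_samples / (n_classes * count(class)), so the rare keyphrase tokens get
# proportionally larger weights; the reshape restores the (n_docs, timesteps)
# layout that Keras expects for temporal sample weighting.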

logging.info("Data preprocessing complete.")
logging.info("Maximum possible recall: %s",
             metrics.recall(test_answer,
                            postprocessing.get_words(test_doc, postprocessing.undo_sequential(test_y)),
                            STEM_MODE))
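
# Scoring the gold labels themselves gives the recall ceiling: keyphrases
# lost to truncation at MAX_DOCUMENT_LENGTH (or to the capped vocabulary)
# can never be predicted, so this bounds every score reported below.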

if not SAVE_MODEL or not os.path.isfile(MODEL_PATH):

    logging.debug("Building the network...")
    model = Sequential()

    print("-------", np.shape(embedding_matrix)[0])
    embedding_layer = Embedding(np.shape(embedding_matrix)[0],
                                EMBEDDINGS_SIZE,
                                weights=[embedding_matrix],
                                input_length=MAX_DOCUMENT_LENGTH,
                                trainable=False)

    model.add(embedding_layer)
    model.add(Bidirectional(LSTM(300, activation='tanh', recurrent_activation='hard_sigmoid', return_sequences=True)))
    model.add(Dropout(0.25))
    model.add(TimeDistributed(Dense(150, activation='relu', kernel_regularizer=regularizers.l2(0.01))))
    model.add(Dropout(0.25))
    model.add(TimeDistributed(Dense(2, activation='softmax')))
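
    # Shape flow through the network (batch dimension omitted):
    #   400 token ids -> Embedding -> (400, 50) -> BiLSTM(300, concat merge)
    #   -> (400, 600) -> Dense(150) -> (400, 150) -> Dense(2, softmax)
    #   -> (400, 2) per-token keyphrase / non-keyphrase probabilities.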

    logging.info("Compiling the network...")
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'],
                  sample_weight_mode="temporal")
    print(model.summary())
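
    # sample_weight_mode="temporal" tells Keras to expect a 2D weight array
    # of shape (samples, timesteps), i.e. the train_y_weights computed above,
    # so the loss is reweighted per token rather than per document.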

    metrics_callback = keras_metrics.MetricsCallback(val_x, val_y)

    logging.info("Fitting the network...")

    history = model.fit(train_x, train_y,
                        validation_data=(val_x, val_y),
                        epochs=epochs,
                        batch_size=batch_size,
                        sample_weight=train_y_weights,
                        callbacks=[metrics_callback])

    if SHOW_PLOTS:
        plots.plot_accuracy(history)
        plots.plot_loss(history)
        plots.plot_prf(metrics_callback)

    if SAVE_MODEL:
        model.save(MODEL_PATH)
        logging.info("Model saved in %s", MODEL_PATH)

else:
    logging.info("Loading model %s...", MODEL_PATH)
    model = load_model(MODEL_PATH)
    logging.info("Model loaded.")
logging.info("在测试集上预测...")
|
|
output = model.predict(x=test_x, verbose=1)
|
|
logging.debug("输出格式: %s", np.shape(output))
|
|
|
|
obtained_tokens = postprocessing.undo_sequential(output)
|
|
obtained_words = postprocessing.get_words(test_doc, obtained_tokens)
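
# Assumption based on usage (these are project utilities, not inspected
# here): undo_sequential collapses the per-token softmax to 0/1 labels via
# argmax, and get_words maps the positively-labeled tokens back to the
# original word strings.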

precision = metrics.precision(test_answer, obtained_words, STEM_MODE)
recall = metrics.recall(test_answer, obtained_words, STEM_MODE)
f1 = metrics.f1(precision, recall)
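
# metrics.f1 presumably computes the harmonic mean F1 = 2 * P * R / (P + R).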

print("### Obtained scores ###")
print("###")
print("### Precision : %.4f" % precision)
print("### Recall    : %.4f" % recall)
print("### F1        : %.4f" % f1)
print("###                 ###")

keras_precision = keras_metrics.keras_precision(test_y, output)
keras_recall = keras_metrics.keras_recall(test_y, output)
keras_f1 = keras_metrics.keras_f1(test_y, output)

print("### Obtained scores (token level) ###")
print("###")
print("### Precision : %.4f" % keras_precision)
print("### Recall    : %.4f" % keras_recall)
print("### F1        : %.4f" % keras_f1)
print("###                               ###")