LSTM

parent 95f3e4c98b
commit c102001c6d

@@ -0,0 +1,173 @@
import os, sys

os.chdir(sys.path[0])
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from datasets import EnNews

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session  # imported but never called below

import numpy as np
import random as rn

# Fix the NumPy and Python RNG seeds for reproducibility
np.random.seed(421)
rn.seed(12345)
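# Note (not in the original script): only NumPy and Python's `random` are
# seeded here; a fully reproducible Keras/TF run would also need the
# TensorFlow graph-level seed, e.g. tf.set_random_seed(421) on TF 1.x.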
import logging

from keras import regularizers
from keras.layers import Bidirectional, Dense, Dropout, Embedding, LSTM, TimeDistributed
from keras.models import Sequential, load_model

from datasets import *
from eval import keras_metrics, metrics
from nlp import tokenizer as tk
from utils import info, preprocessing, postprocessing, plots

# Logging configuration
logging.basicConfig(
    format='%(asctime)s\t%(levelname)s\t%(message)s',
    level=logging.DEBUG)

info.log_versions()

# Global variables
SAVE_MODEL = False
MODEL_PATH = "models/bilstm.h5"
SHOW_PLOTS = True

# Dataset and hyperparameters
Dataset = EnNews

rootpath = "/home/zhangxj/WorkFile/本科毕业设计"

tokenizer = tk.tokenizers.nltk
DATASET_FOLDER = rootpath + "/EnergyNews"
MAX_DOCUMENT_LENGTH = 400
MAX_VOCABULARY_SIZE = 20000
EMBEDDINGS_SIZE = 50
batch_size = 32
epochs = 20
KP_WEIGHT = 10
STEM_MODE = metrics.stemMode.both
STEM_TEST = False
# Load the dataset
logging.info("Loading dataset...")

data = Dataset(DATASET_FOLDER)

train_doc_str, train_answer_str = data.load_train()
test_doc_str, test_answer_str = data.load_test()
val_doc_str, val_answer_str = data.load_validation()

train_doc, train_answer = tk.tokenize_set(train_doc_str, train_answer_str, tokenizer)
test_doc, test_answer = tk.tokenize_set(test_doc_str, test_answer_str, tokenizer)
val_doc, val_answer = tk.tokenize_set(val_doc_str, val_answer_str, tokenizer)
# Sanity check
logging.info("Dataset loaded. Preprocessing data...")

train_x, train_y, test_x, test_y, val_x, val_y, embedding_matrix = preprocessing. \
    prepare_sequential(train_doc, train_answer, test_doc, test_answer, val_doc, val_answer,
                       max_document_length=MAX_DOCUMENT_LENGTH,
                       max_vocabulary_size=MAX_VOCABULARY_SIZE,
                       embeddings_size=EMBEDDINGS_SIZE,
                       stem_test=STEM_TEST)
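# Shape conventions (inferred from the usage below, not stated in the script):
# train_x/test_x/val_x are (n_docs, MAX_DOCUMENT_LENGTH) matrices of word
# indices, train_y/test_y/val_y are one-hot (n_docs, MAX_DOCUMENT_LENGTH, 2)
# token labels, and embedding_matrix is (vocabulary_size, EMBEDDINGS_SIZE).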
# Weight the training examples: everything that is not a keyphrase (kp)
from sklearn.utils import class_weight

train_y_weights = np.argmax(train_y, axis=2)
train_y_weights = np.reshape(class_weight.compute_sample_weight('balanced', train_y_weights.flatten()),
                             np.shape(train_y_weights))
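# Illustration (not part of the original script): 'balanced' weighting gives
# the rare keyphrase tokens proportionally larger per-token weights, e.g.
#   class_weight.compute_sample_weight('balanced', np.array([0, 0, 0, 1]))
#   ≈ [0.667, 0.667, 0.667, 2.0]
# since each weight is n_samples / (n_classes * count(class)); the lone
# positive token here counts three times as much as each negative one.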
logging.info("Data preprocessing complete")
logging.info("Maximum possible recall: %s",
             metrics.recall(test_answer,
                            postprocessing.get_words(test_doc, postprocessing.undo_sequential(test_y)),
                            STEM_MODE))
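# This oracle recall runs the gold labels themselves back through the
# postprocessing pipeline; it falls below 1.0 whenever gold keyphrases lie
# beyond the first MAX_DOCUMENT_LENGTH tokens or are otherwise lost in
# preprocessing, so it is the ceiling against which the model's recall
# should be read.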
if not SAVE_MODEL or not os.path.isfile(MODEL_PATH):

    logging.debug("Building the network...")
    model = Sequential()
    print("-------", np.shape(embedding_matrix)[0])
    embedding_layer = Embedding(np.shape(embedding_matrix)[0],
                                EMBEDDINGS_SIZE,
                                weights=[embedding_matrix],
                                input_length=MAX_DOCUMENT_LENGTH,
                                trainable=False)  # keep the pretrained embeddings frozen

    model.add(embedding_layer)
    model.add(LSTM(300, activation='tanh', recurrent_activation='hard_sigmoid', return_sequences=True))
    model.add(Dropout(0.25))
    model.add(TimeDistributed(Dense(150, activation='relu', kernel_regularizer=regularizers.l2(0.01))))
    model.add(Dropout(0.25))
    model.add(TimeDistributed(Dense(2, activation='softmax')))
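    # The stack emits one softmax over {non-keyphrase, keyphrase} per token,
    # i.e. an output of shape (MAX_DOCUMENT_LENGTH, 2) per document. Note that
    # despite the "bilstm.h5" model path and the Bidirectional import, this
    # version uses a plain unidirectional LSTM.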
    logging.info("Compiling the network...")
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'],
                  sample_weight_mode="temporal")
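    # sample_weight_mode="temporal" is what allows the 2-D
    # (n_docs, MAX_DOCUMENT_LENGTH) train_y_weights matrix to weight each
    # token position in the loss individually, instead of one scalar weight
    # per document.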
    print(model.summary())

    metrics_callback = keras_metrics.MetricsCallback(val_x, val_y)

    logging.info("Fitting the network...")

    history = model.fit(train_x, train_y,
                        validation_data=(val_x, val_y),
                        epochs=epochs,
                        batch_size=batch_size,
                        sample_weight=train_y_weights,
                        callbacks=[metrics_callback])
    if SHOW_PLOTS:
        plots.plot_accuracy(history)
        plots.plot_loss(history)
        plots.plot_prf(metrics_callback)

    if SAVE_MODEL:
        model.save(MODEL_PATH)
        logging.info("Model saved in %s", MODEL_PATH)

else:
    logging.info("Loading model %s...", MODEL_PATH)
    model = load_model(MODEL_PATH)
    logging.info("Model loaded")
logging.info("Predicting on the test set...")
output = model.predict(x=test_x, verbose=1)
logging.debug("Output shape: %s", np.shape(output))

obtained_tokens = postprocessing.undo_sequential(output)
obtained_words = postprocessing.get_words(test_doc, obtained_tokens)
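# undo_sequential presumably collapses each per-token softmax back to a hard
# 0/1 label (argmax over the class axis), mirroring its use on the one-hot
# test_y above; get_words then maps the positive positions back onto the
# tokenized test documents to recover the predicted keyphrase words.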
precision = metrics.precision(test_answer, obtained_words, STEM_MODE)
recall = metrics.recall(test_answer, obtained_words, STEM_MODE)
f1 = metrics.f1(precision, recall)

print("### Obtained scores ###")
print("###")
print("### Precision : %.4f" % precision)
print("### Recall    : %.4f" % recall)
print("### F1        : %.4f" % f1)
print("### ###")
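# metrics.f1 takes the already-computed precision and recall, so it is
# presumably the usual harmonic mean: 2 * P * R / (P + R).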
keras_precision = keras_metrics.keras_precision(test_y, output)
keras_recall = keras_metrics.keras_recall(test_y, output)
keras_f1 = keras_metrics.keras_f1(test_y, output)

print("### Obtained scores (Keras metrics) ###")
print("###")
print("### Precision : %.4f" % keras_precision)
print("### Recall    : %.4f" % keras_recall)
print("### F1        : %.4f" % keras_f1)
print("### ###")