diff --git a/LSTM.py b/LSTM.py new file mode 100644 index 0000000..79bf903 --- /dev/null +++ b/LSTM.py @@ -0,0 +1,173 @@ +import os,sys +os.chdir(sys.path[0]) +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +from datasets import EnNews + +import tensorflow as tf +from keras.backend.tensorflow_backend import set_session + +import numpy as np +import random as rn + +np.random.seed(421) +rn.seed(12345) + +import logging + +from keras import regularizers +from keras.layers import Bidirectional, Dense, Dropout, Embedding, LSTM, TimeDistributed + +from keras.models import Sequential, load_model + +from datasets import * +from eval import keras_metrics, metrics +from nlp import tokenizer as tk +from utils import info, preprocessing, postprocessing, plots + +# 记录配置 + +logging.basicConfig( + format='%(asctime)s\t%(levelname)s\t%(message)s', + level=logging.DEBUG) + +info.log_versions() + +# 全局变量 + +SAVE_MODEL = False +MODEL_PATH = "models/bilstm.h5" +SHOW_PLOTS = True + +# 数据集和超参数 +Dataset = EnNews + +rootpath = "/home/zhangxj/WorkFile/本科毕业设计" + +tokenizer = tk.tokenizers.nltk +DATASET_FOLDER = rootpath+"/EnergyNews" +MAX_DOCUMENT_LENGTH = 400 +MAX_VOCABULARY_SIZE = 20000 +EMBEDDINGS_SIZE = 50 +batch_size = 32 +epochs = 20 +KP_WEIGHT = 10 +STEM_MODE = metrics.stemMode.both +STEM_TEST = False + + + +# 加载数据集 +logging.info("Loading dataset...") + +data = Dataset(DATASET_FOLDER) + +train_doc_str, train_answer_str = data.load_train() +test_doc_str, test_answer_str = data.load_test() +val_doc_str, val_answer_str = data.load_validation() + +train_doc, train_answer = tk.tokenize_set(train_doc_str, train_answer_str, tokenizer) +test_doc, test_answer = tk.tokenize_set(test_doc_str, test_answer_str, tokenizer) +val_doc, val_answer = tk.tokenize_set(val_doc_str, val_answer_str, tokenizer) + +# 完整性检查 + +logging.info("Dataset loaded. Preprocessing data...") + +train_x, train_y, test_x, test_y, val_x, val_y, embedding_matrix = preprocessing. \ + prepare_sequential(train_doc, train_answer, test_doc, test_answer, val_doc, val_answer, + max_document_length=MAX_DOCUMENT_LENGTH, + max_vocabulary_size=MAX_VOCABULARY_SIZE, + embeddings_size=EMBEDDINGS_SIZE, + stem_test=STEM_TEST) + +# 权重训练示例:所有不是 kp的内容 +from sklearn.utils import class_weight + +train_y_weights = np.argmax(train_y, axis=2) +train_y_weights = np.reshape(class_weight.compute_sample_weight('balanced', train_y_weights.flatten()), + np.shape(train_y_weights)) + +logging.info("数据预处理完成") +logging.info("可能的最大召回率: %s", + metrics.recall(test_answer, + postprocessing.get_words(test_doc, postprocessing.undo_sequential(test_y)), + STEM_MODE)) + +if not SAVE_MODEL or not os.path.isfile(MODEL_PATH): + + logging.debug("建立网络...") + model = Sequential() + print("-------",np.shape(embedding_matrix)[0]) + embedding_layer = Embedding(np.shape(embedding_matrix)[0], + EMBEDDINGS_SIZE, + weights=[embedding_matrix], + input_length=MAX_DOCUMENT_LENGTH, + trainable=False) + + model.add(embedding_layer) + model.add(LSTM(300, activation='tanh', recurrent_activation='hard_sigmoid', return_sequences=True)) + model.add(Dropout(0.25)) + model.add(TimeDistributed(Dense(150, activation='relu', kernel_regularizer=regularizers.l2(0.01)))) + model.add(Dropout(0.25)) + model.add(TimeDistributed(Dense(2, activation='softmax'))) + + logging.info("编译网络...") + model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'], + sample_weight_mode="temporal") + print(model.summary()) + + metrics_callback = keras_metrics.MetricsCallback(val_x, val_y) + + logging.info("拟合网络...") + + history = model.fit(train_x, train_y, + validation_data=(val_x, val_y), + epochs=epochs, + batch_size=batch_size, + sample_weight=train_y_weights, + callbacks=[metrics_callback]) + + if SHOW_PLOTS: + plots.plot_accuracy(history) + plots.plot_loss(history) + plots.plot_prf(metrics_callback) + + if SAVE_MODEL: + model.save(MODEL_PATH) + logging.info("模型保存路径 in %s", MODEL_PATH) + +else: + logging.info("加载模型 %s...", MODEL_PATH) + model = load_model(MODEL_PATH) + logging.info("加载模型完成") + +logging.info("在测试集上预测...") +output = model.predict(x=test_x, verbose=1) +logging.debug("输出格式: %s", np.shape(output)) + +obtained_tokens = postprocessing.undo_sequential(output) +obtained_words = postprocessing.get_words(test_doc, obtained_tokens) + +precision = metrics.precision(test_answer, obtained_words,STEM_MODE) +recall = metrics.recall(test_answer, obtained_words,STEM_MODE) +f1 = metrics.f1(precision, recall) + +print("### 获得的分数 ###") +print("###") +print("### Precision : %.4f" % precision) +print("### Recall : %.4f" % recall) +print("### F1 : %.4f" % f1) +print("### ###") + +keras_precision = keras_metrics.keras_precision(test_y, output) +keras_recall = keras_metrics.keras_recall(test_y, output) +keras_f1 = keras_metrics.keras_f1(test_y, output) + +print("### 获得的分数 ###") +print("###") +print("### Precision : %.4f" % keras_precision) +print("### Recall : %.4f" % keras_recall) +print("### F1 : %.4f" % keras_f1) +print("### ###") +