#! -*- coding: utf-8 -*-
# Reaches a score of about 0.77 on the validation set after 10 epochs
# (Accuracy=0.7282149325820084 F1=0.8207266829447049 Final=0.7744708077633566)

import json, os, re
os.environ['TF_KERAS'] = '1'
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda
from keras.models import Model
from tqdm import tqdm

import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
tf_session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
tf.compat.v1.keras.backend.set_session(tf_session)

max_p_len = 256
max_q_len = 64
max_a_len = 32
batch_size = 32
epochs = 10

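# Note (inferred from the data format used below, not a comment from the
# original script): max_p_len, max_q_len and max_a_len are the maximum passage,
# question and answer lengths in tokens. The model input is
# [CLS] + max_a_len * [MASK] + [SEP] + question + [SEP] + passage,
# i.e. at most max_a_len + max_q_len + max_p_len + 2 tokens long.
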
# # BERT configuration
# config_path = '../models/nezha_gpt/config.json'
# checkpoint_path = '../models/nezha_gpt/gpt.ckpt'
# dict_path = '../models/tokenizer/vocab.txt'

# # Labelled data
# webqa_data = json.load(open('../data/qa/WebQA.json'))
# sogou_data = json.load(open('../data/qa/SogouQA.json'))

# # Save a random order (used to split off the validation set)
# if not os.path.exists('../random_order.json'):
#     random_order = list(range(len(sogou_data)))
#     np.random.shuffle(random_order)
#     json.dump(random_order, open('../random_order.json', 'w'), indent=4)
# else:
#     random_order = json.load(open('../random_order.json'))

# # Split off the validation set
# train_data = [sogou_data[j] for i, j in enumerate(random_order) if i % 3 != 0]
# valid_data = [sogou_data[j] for i, j in enumerate(random_order) if i % 3 == 0]
# train_data.extend(train_data)
# train_data.extend(webqa_data)  # mix SogouQA and WebQA at a ratio of 2:1

# # Load and simplify the vocabulary, then build the tokenizer
# token_dict, keep_tokens = load_vocab(
#     dict_path=dict_path,
#     simplified=True,
#     startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
# )
# tokenizer = Tokenizer(token_dict, do_lower_case=True)


# class data_generator(DataGenerator):
#     """Data generator
#     """
#     def __iter__(self, random=False):
#         """A single sample has the format
#         input:  [CLS][MASK][MASK][SEP]question[SEP]passage[SEP]
#         output: answer
#         """
#         batch_token_ids, batch_segment_ids, batch_a_token_ids = [], [], []
#         for is_end, D in self.sample(random):
#             question = D['question']
#             answers = [p['answer'] for p in D['passages'] if p['answer']]
#             passage = np.random.choice(D['passages'])['passage']
#             passage = re.sub(u' |、|;|,', ',', passage)
#             final_answer = ''
#             for answer in answers:
#                 if all([
#                     a in passage[:max_p_len - 2] for a in answer.split(' ')
#                 ]):
#                     final_answer = answer.replace(' ', ',')
#                     break
#             a_token_ids, _ = tokenizer.encode(
#                 final_answer, maxlen=max_a_len + 1
#             )
#             q_token_ids, _ = tokenizer.encode(question, maxlen=max_q_len + 1)
#             p_token_ids, _ = tokenizer.encode(passage, maxlen=max_p_len + 1)
#             token_ids = [tokenizer._token_start_id]
#             token_ids += ([tokenizer._token_mask_id] * max_a_len)
#             token_ids += [tokenizer._token_end_id]
#             token_ids += (q_token_ids[1:] + p_token_ids[1:])
#             segment_ids = [0] * len(token_ids)
#             batch_token_ids.append(token_ids)
#             batch_segment_ids.append(segment_ids)
#             batch_a_token_ids.append(a_token_ids[1:])
#             if len(batch_token_ids) == self.batch_size or is_end:
#                 batch_token_ids = sequence_padding(batch_token_ids)
#                 batch_segment_ids = sequence_padding(batch_segment_ids)
#                 batch_a_token_ids = sequence_padding(
#                     batch_a_token_ids, max_a_len
#                 )
#                 yield [batch_token_ids, batch_segment_ids], batch_a_token_ids
#                 batch_token_ids, batch_segment_ids, batch_a_token_ids = [], [], []


def masked_cross_entropy(y_true, y_pred):
    """Cross entropy as the loss, with predictions at padding positions masked out.
    """
    y_true = K.reshape(y_true, [K.shape(y_true)[0], -1])
    y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
    cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred)
    cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)
    return cross_entropy

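# A note on the shapes in masked_cross_entropy (inferred from the model below,
# not part of the original script): y_true holds the padded answer token ids,
# shape (batch_size, max_a_len); y_pred is the MLM distribution over the answer
# positions, shape (batch_size, max_a_len, vocab_size). Padding positions
# (token id 0, i.e. [PAD]) are excluded from the average.

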
def build_reading_model(config_path: str, ckpt_path: str, keep_tokens: list, weight_path: str):
    model = build_transformer_model(
        config_path,
        ckpt_path,
        with_mlm=True,
        keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, simplifying the original vocabulary
    )
    output = Lambda(lambda x: x[:, 1:max_a_len + 1])(model.output)
    model = Model(model.input, output)
    model.compile(loss=masked_cross_entropy, optimizer=Adam(1e-5))
    model.load_weights(weight_path)
    return model

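# The Lambda layer above keeps only positions 1..max_a_len of the MLM output,
# i.e. the max_a_len [MASK] slots that hold the answer, so the model's output
# has shape (batch_size, max_a_len, vocab_size).

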
def get_ngram_set(x, n):
    """Collect the n-grams of x. The result has the format:
    {(n-1)-gram: set of tokens that follow it as the n-th token}
    """
    result = {}
    for i in range(len(x) - n + 1):
        k = tuple(x[i:i + n])
        if k[:-1] not in result:
            result[k[:-1]] = set()
        result[k[:-1]].add(k[-1])
    return result

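# A small illustration (not from the original script):
#   get_ngram_set([3, 4, 5, 3, 6], 2) == {(3,): {4, 6}, (4,): {5}, (5,): {3}}
# gen_answer below uses these sets to restrict decoding at each step, so the
# generated answer is always a contiguous fragment of the passage (or ends
# early with the [SEP] token).

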
def gen_answer(question, passages, model, tokenizer):
    """Since this is an MLM model, the answer can be decoded directly with argmax.
    """
    all_p_token_ids, token_ids, segment_ids = [], [], []
    for passage in passages:
        passage = re.sub(u' |、|;|,', ',', passage)
        p_token_ids, _ = tokenizer.encode(passage, maxlen=max_p_len + 1)
        q_token_ids, _ = tokenizer.encode(question, maxlen=max_q_len + 1)
        all_p_token_ids.append(p_token_ids[1:])
        token_ids.append([tokenizer._token_start_id])
        token_ids[-1] += ([tokenizer._token_mask_id] * max_a_len)
        token_ids[-1] += [tokenizer._token_end_id]
        token_ids[-1] += (q_token_ids[1:] + p_token_ids[1:])
        segment_ids.append([0] * len(token_ids[-1]))
    token_ids = sequence_padding(token_ids)
    segment_ids = sequence_padding(segment_ids)
    probas = model.predict([token_ids, segment_ids])
    results = {}
    for t, p in zip(all_p_token_ids, probas):
        a, score = tuple(), 0.
        for i in range(max_a_len):
            idxs = list(get_ngram_set(t, i + 1)[a])
            if tokenizer._token_end_id not in idxs:
                idxs.append(tokenizer._token_end_id)
            # pi zeroes out the probability of every token that does not appear in the passage
            pi = np.zeros_like(p[i])
            pi[idxs] = p[i, idxs]
            a = a + (pi.argmax(),)
            score += pi.max()
            if a[-1] == tokenizer._token_end_id:
                break
        score = score / (i + 1)
        a = tokenizer.decode(a)
        if a:
            results[a] = results.get(a, []) + [score]
    results = {
        k: (np.array(v)**2).sum() / (sum(v) + 1)
        for k, v in results.items()
    }
    return results

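# A note on the aggregation above (a reading of the formula, not a comment from
# the original script): each candidate answer decoded from several passages
# collects a list of scores v; sum(v_i^2) / (sum(v_i) + 1) favours answers
# supported by many passages with high scores, while the +1 in the denominator
# damps answers that appear only once with a low score.

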
def max_in_dict(d):
    """Return the key with the largest value, or None if d is empty."""
    if d:
        return sorted(d.items(), key=lambda s: -s[1])[0][0]


# def predict_to_file(data, filename):
#     """Write the predictions to a file for evaluation.
#     """
#     with open(filename, 'w', encoding='utf-8') as f:
#         for d in tqdm(iter(data), desc=u'Predicting (%s samples in total)' % len(data)):
#             q_text = d['question']
#             p_texts = [p['passage'] for p in d['passages']]
#             a = gen_answer(q_text, p_texts, model, tokenizer)
#             a = max_in_dict(a)
#             if a:
#                 s = u'%s\t%s\n' % (d['id'], a)
#             else:
#                 s = u'%s\t\n' % (d['id'])
#             f.write(s)
#             f.flush()


# class Evaluator(keras.callbacks.Callback):
#     """Evaluate and save the model.
#     """
#     def __init__(self):
#         self.lowest = 1e10
#
#     def on_epoch_end(self, epoch, logs=None):
#         # save the best model
#         if logs['loss'] <= self.lowest:
#             self.lowest = logs['loss']
#             model.save_weights('../models/qa/best_model.weights')


if __name__ == '__main__':
    # Rebuild the simplified vocabulary and tokenizer used during training
    # (paths follow the commented-out training configuration above)
    token_dict, keep_tokens = load_vocab(
        dict_path='../models/tokenizer/vocab.txt', simplified=True,
        startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
    )
    tokenizer = Tokenizer(token_dict, do_lower_case=True)

    model = build_reading_model(
        '../models/nezha_gpt/config.json', '../models/nezha_gpt/gpt.ckpt',
        keep_tokens, '../models/qa/best_model.weights'
    )

    question = "嬴政出生在哪里?"
    passages = ["秦始皇嬴政(前259年—前210年),嬴姓,赵氏 ,名政(一说名“正”),又称赵政 、祖龙 ,也有吕政一说(详见“人物争议-姓名之争”目录)。秦庄襄王和赵姬之子。中国古代杰出的政治家、战略家、改革家,首次完成中国大一统的政治人物,也是中国第一个称皇帝的君主。",
                "公元前221年,秦统一六国之后,秦王嬴政认为自己“德兼三皇,功过五帝”,遂采用三皇之“皇”、五帝之“帝”构成“皇帝”的称号,是中国历史上第一个使用“皇帝”称号的君主,所以自称“始皇帝”。",
                "秦始皇有二十余子。长子扶苏,少子胡亥。",
                "嬴政出生在当时赵国的邯郸廓城(在今城内中街以东,丛台西南的朱家巷一带),是当时的秦国王孙异人之子。"]
    print(gen_answer(question, passages, model, tokenizer))
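    # gen_answer returns a dict mapping each candidate answer to its aggregated
    # score, e.g. roughly {'邯郸': 0.9, ...} (illustrative values only);
    # max_in_dict picks the single best answer:
    # print(max_in_dict(gen_answer(question, passages, model, tokenizer)))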