44 lines
1.9 KiB
Python
44 lines
1.9 KiB
Python
|
import argparse
|
|||
|
|
|||
|
class Evaluator:
    """Score predicted SQL strings against gold SQL strings by exact match.

    The evaluator holds no state; the default ``object.__init__`` is enough
    to construct it.
    """

    def eval_exact_match(self, pred, label):
        """Report whether the predicted SQL is identical to the gold SQL.

        Args:
            pred: the predicted SQL query.
            label: the gold (reference) SQL query.

        Returns:
            ``pred == label``. No normalization (case, whitespace, aliases)
            is applied, so any textual difference counts as a mismatch.
        """
        is_identical = pred == label
        return is_identical
def _read_sql_lines(path):
    """Return the lines of the text file at *path*, each stripped of surrounding whitespace."""
    # Explicit UTF-8: the queries may contain non-ASCII text, and the platform
    # default encoding is not UTF-8 on every system (e.g. Windows).
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


def evaluate(gold, predict, db):
    """Print and return the exact-match rate of predicted SQL against gold SQL.

    Both files hold one query per line. Line *i* of *predict* is compared with
    line *i* of *gold* using ``Evaluator.eval_exact_match``. Before comparing,
    each line is stripped of leading and trailing whitespace.

    Args:
        gold: path to the file of gold (reference) SQL queries.
        predict: path to the file of predicted SQL queries.
        db: path to the database. Not used by this function; accepted so the
            signature stays compatible with existing callers.

    Returns:
        The fraction of matching lines (a float in [0, 1]); 0.0 when both
        files are empty.

    Raises:
        AssertionError: if the two files contain a different number of lines.
    """
    gold_sqls = _read_sql_lines(gold)
    predict_sqls = _read_sql_lines(predict)

    # Explicit check instead of `assert` so it still runs under `python -O`
    # (otherwise zip() would silently truncate). AssertionError is kept so
    # callers see the same exception type as before.
    if len(gold_sqls) != len(predict_sqls):
        raise AssertionError("The number of gold and predict SQLs must be equal")

    evaluator = Evaluator()
    # Structural parsing via get_sql(schema, ...) is deliberately skipped: its
    # main purpose is resolving table aliases, which this dataset does not use.
    exact_matches = [
        evaluator.eval_exact_match(p_sql, g_sql)
        for g_sql, p_sql in zip(gold_sqls, predict_sqls)
    ]

    total = len(exact_matches)
    matched = sum(exact_matches)
    # Guard against ZeroDivisionError when both input files are empty.
    exact_match_rate = matched / total if total else 0.0

    # Report the exact-match result.
    print(f"Exact Match Rate: {exact_match_rate:.2f} ({matched}/{total})")
    return exact_match_rate
if __name__ == "__main__":
    # Command-line options as (flag, dest, help text, default). The defaults
    # are absolute paths on the original author's machine; pass the flags to
    # point the script at other files.
    cli_options = (
        (
            '--gold',
            'gold',
            "the path to the gold sql queries",
            "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold.txt",
        ),
        (
            '--pred',
            'pred',
            "the path to the predicted sql queries",
            "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/base_pred.txt",
        ),
        (
            '--db',
            'db',
            "the path to the database",
            "/home/zhangxj/WorkFile/Text2Sql/data/spider_data/database/academic/academic.sqlite",
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, dest, help_text, default in cli_options:
        arg_parser.add_argument(flag, dest=dest, type=str, help=help_text, default=default)

    cli_args = arg_parser.parse_args()
    evaluate(cli_args.gold, cli_args.pred, cli_args.db)