# Exact-match evaluation script for predicted SQL queries.
import argparse


class Evaluator:
    """Exact-match scorer for a predicted query against its reference."""

    def __init__(self):
        """Create an evaluator; it carries no state."""

    def eval_exact_match(self, pred, label):
        """Return the result of comparing ``pred`` with ``label`` using ``==``.

        The comparison is literal: no normalisation of whitespace, case or
        SQL structure is applied to either argument.
        """
        # Check whether the predicted SQL is identical to the gold SQL.
        is_identical = pred == label
        return is_identical


def evaluate(gold, predict, db):
    """Print the exact-match rate between a gold and a predicted SQL file.

    Both files are read line by line and surrounding whitespace is stripped
    from every line. Line *i* of ``predict`` is then compared with line *i*
    of ``gold`` through ``Evaluator.eval_exact_match``, and a summary line
    ``Exact Match Rate: <rate> (<matched>/<total>)`` is printed.

    Args:
        gold: Path to the file holding the gold SQL queries, one per line.
        predict: Path to the file holding the predicted SQL queries, one per
            line.
        db: Path to the database. It is not used by the comparison; the
            parameter is kept so existing callers keep working.

    Raises:
        AssertionError: If the two files do not have the same number of
            lines.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(gold, encoding="utf-8") as f:
        gold_sqls = [line.strip() for line in f]

    with open(predict, encoding="utf-8") as f:
        predict_sqls = [line.strip() for line in f]

    # Raised explicitly instead of using an `assert` statement: asserts are
    # removed under `python -O`, which would let `zip` silently truncate the
    # longer file. The exception type is the same as before.
    if len(gold_sqls) != len(predict_sqls):
        raise AssertionError("The number of gold and predict SQLs must be equal")

    evaluator = Evaluator()
    # The queries are compared as plain strings. A structured-parse step
    # (mainly for resolving table aliases) was dropped because this data
    # contains no aliases.
    exact_matches = [
        evaluator.eval_exact_match(p_sql, g_sql)
        for g_sql, p_sql in zip(gold_sqls, predict_sqls)
    ]

    # Report the exact-match result.
    matched = sum(exact_matches)
    total = len(exact_matches)
    # Two empty files would otherwise divide by zero; report 0.00 (0/0).
    exact_match_rate = matched / total if total else 0.0
    print(f"Exact Match Rate: {exact_match_rate:.2f} ({matched}/{total})")


if __name__ == "__main__":
    # Fallback locations used when a flag is not given on the command line.
    gold_path = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold.txt"
    pred_path = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/base_pred.txt"
    db_path = "/home/zhangxj/WorkFile/Text2Sql/data/spider_data/database/academic/academic.sqlite"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold",
        dest="gold",
        type=str,
        default=gold_path,
        help="the path to the gold sql queries",
    )
    parser.add_argument(
        "--pred",
        dest="pred",
        type=str,
        default=pred_path,
        help="the path to the predicted sql queries",
    )
    parser.add_argument(
        "--db",
        dest="db",
        type=str,
        default=db_path,
        help="the path to the database",
    )
    args = parser.parse_args()

    # Run the comparison with the resolved paths.
    evaluate(gold=args.gold, predict=args.pred, db=args.db)