LCA-GPT/LLM-SQL/match_eval.py

44 lines
1.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
class Evaluator:
    """Exact-match scorer for a predicted SQL string against a reference one."""

    def __init__(self):
        """Create an evaluator; no state is stored on the instance."""

    def eval_exact_match(self, pred, label):
        """Return the result of ``pred == label``.

        Args:
            pred: The predicted SQL.
            label: The reference (gold) SQL.

        Returns:
            Whatever ``pred == label`` evaluates to; no normalisation is
            applied to either argument before the comparison.
        """
        # Check whether the predicted SQL is exactly identical to the gold SQL.
        is_identical = pred == label
        return is_identical
def _read_stripped_lines(path):
    """Return every line of the text file at ``path`` with surrounding whitespace removed."""
    # Explicit encoding so the result does not depend on the platform's
    # locale default. NOTE(review): assumes the query files are UTF-8 — confirm.
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


def evaluate(gold, predict, db):
    """Print the exact-match rate between a gold and a predicted SQL file.

    Each file is read line by line and every line is stripped of surrounding
    whitespace. Line *i* of ``predict`` is then compared with line *i* of
    ``gold`` via ``Evaluator.eval_exact_match``, and the fraction of matching
    pairs is printed as ``Exact Match Rate: <rate> (<matched>/<total>)``.

    Args:
        gold: Path to the file containing the gold SQL queries, one per line.
        predict: Path to the file containing the predicted SQL queries, one
            per line.
        db: Path to the database. Kept for interface compatibility; the
            exact-match comparison does not use it.

    Raises:
        AssertionError: If the two files do not contain the same number of
            lines.
    """
    gold_sqls = _read_stripped_lines(gold)
    predict_sqls = _read_stripped_lines(predict)

    # Raised explicitly instead of through an ``assert`` statement so the
    # check still runs under ``python -O``; without it ``zip`` below would
    # silently truncate to the shorter file. The exception type is unchanged.
    if len(gold_sqls) != len(predict_sqls):
        raise AssertionError("The number of gold and predict SQLs must be equal")

    evaluator = Evaluator()
    # Queries are compared as raw strings. The original author noted that
    # parsing them first (get_sql, whose main job is resolving table aliases)
    # is skipped because this data contains no table aliases.
    exact_matches = [
        evaluator.eval_exact_match(p_sql, g_sql)
        for g_sql, p_sql in zip(gold_sqls, predict_sqls)
    ]

    # Report the exact-match result.
    matched = sum(exact_matches)
    total = len(exact_matches)
    # Two empty files would otherwise divide by zero; report 0.00 (0/0).
    exact_match_rate = matched / total if total else 0.0
    print(f"Exact Match Rate: {exact_match_rate:.2f} ({matched}/{total})")
if __name__ == "__main__":
    # Command-line entry point: every flag is optional and falls back to the
    # hard-coded default path listed below.
    # Each entry is (flag name, help text, default path), in registration order.
    option_specs = (
        (
            "gold",
            "the path to the gold sql queries",
            "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold.txt",
        ),
        (
            "pred",
            "the path to the predicted sql queries",
            "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/base_pred.txt",
        ),
        (
            "db",
            "the path to the database",
            "/home/zhangxj/WorkFile/Text2Sql/data/spider_data/database/academic/academic.sqlite",
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, description, fallback in option_specs:
        cli.add_argument(
            f"--{flag}",
            dest=flag,
            type=str,
            help=description,
            default=fallback,
        )

    options = cli.parse_args()
    evaluate(options.gold, options.pred, options.db)