import argparse


class Evaluator:
    """A simple evaluator for exact-match comparison of SQL strings."""

    def eval_exact_match(self, pred, label):
        """Return True if the predicted SQL string equals the gold SQL string.

        This is a plain string equality check: no normalization of whitespace,
        keyword case, or table aliases is performed.
        """
        return pred == label


def _read_sql_lines(path):
    """Read one SQL query per line from *path*, stripping surrounding whitespace."""
    # Explicit UTF-8 so results do not depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


def evaluate(gold, predict, db):
    """Compute and print the exact-match rate between gold and predicted SQL.

    Args:
        gold: path to a file with one gold SQL query per line.
        predict: path to a file with one predicted SQL query per line,
            aligned line-by-line with *gold*.
        db: path to the database. Currently unused; kept for interface
            compatibility (no structured comparison against a schema is done).

    Returns:
        The exact-match rate as a float in [0, 1]; 0.0 when both files are empty.

    Raises:
        ValueError: if the two files contain a different number of queries.
    """
    gold_sqls = _read_sql_lines(gold)
    predict_sqls = _read_sql_lines(predict)
    # Raise rather than assert: assertions are stripped under ``python -O``,
    # and zip() would then silently truncate to the shorter file.
    if len(gold_sqls) != len(predict_sqls):
        raise ValueError("The number of gold and predict SQLs must be equal")

    evaluator = Evaluator()
    # Structured parsing (get_sql) is intentionally skipped: its main purpose
    # is resolving table aliases, and this data contains none.
    num_matches = sum(
        evaluator.eval_exact_match(p_sql, g_sql)
        for g_sql, p_sql in zip(gold_sqls, predict_sqls)
    )
    total = len(gold_sqls)
    # Guard against ZeroDivisionError when both input files are empty.
    exact_match_rate = num_matches / total if total else 0.0

    # Report the exact-match result.
    print(f"Exact Match Rate: {exact_match_rate:.2f} ({num_matches}/{total})")
    return exact_match_rate


if __name__ == "__main__":
    # Default paths point at the original author's local data layout;
    # override them with --gold / --pred / --db.
    gold_path = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold.txt"
    pred_path = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/base_pred.txt"
    db_path = "/home/zhangxj/WorkFile/Text2Sql/data/spider_data/database/academic/academic.sqlite"

    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str, help="the path to the gold sql queries", default=gold_path)
    parser.add_argument('--pred', dest='pred', type=str, help="the path to the predicted sql queries", default=pred_path)
    parser.add_argument('--db', dest='db', type=str, help="the path to the database", default=db_path)
    args = parser.parse_args()

    evaluate(args.gold, args.pred, args.db)