"""Evaluate predicted SQL queries against gold queries by execution-match accuracy."""
from exec_eval import eval_exec_match
import argparse


def evaluate(gold, predict, db_dir, plug_value, keep_distinct, progress_bar_for_each_datapoint):
    """Compare each predicted query with its gold query via execution match.

    Returns (number of matching pairs, total number of query pairs).
    """
    scores = []

    # Read gold and predicted queries, one query per line.
    with open(gold) as f:
        glist = f.readlines()
    with open(predict) as f:
        plist = f.readlines()

    assert len(plist) == len(glist), "the number of predicted queries must equal the number of gold queries"

    for idx, (p_str, g_str) in enumerate(zip(plist, glist)):
        print("index:", idx)
        exec_score = eval_exec_match(db=db_dir, p_str=p_str, g_str=g_str, plug_value=plug_value,
                                     keep_distinct=keep_distinct,
                                     progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
        scores.append(exec_score)

    return sum(scores), len(scores)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    gold = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold_no_err.txt"
    pred = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/pred_no_err.txt"
    db_dir = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/tb_process.db"
    parser.add_argument('--gold', dest='gold', type=str, default=gold,
                        help="the path to the gold queries")
    parser.add_argument('--pred', dest='pred', type=str, default=pred,
                        help="the path to the predicted queries")
    parser.add_argument('--db', dest='db', type=str, default=db_dir,
                        help="the directory that contains all the databases and test suites")
    parser.add_argument('--plug_value', default=False, action='store_true',
                        help="whether to plug in the gold value into the predicted query; suitable if your model does not predict values.")
    parser.add_argument('--keep_distinct', default=False, action='store_true',
                        help="whether to keep the distinct keyword during evaluation; default is false.")
    parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
                        help="whether to print a progress bar of running test inputs for each datapoint")
    args = parser.parse_args()

    total, count = evaluate(args.gold, args.pred, args.db, args.plug_value, args.keep_distinct,
                            args.progress_bar_for_each_datapoint)
    # Report execution-match accuracy as a fraction and as a matches/total ratio.
    print(total / count, f"({total}/{count})")