LCA-GPT/LLM-SQL/executaion_eval.py

from exec_eval import eval_exec_match
import argparse
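# eval_exec_match is assumed to follow the Spider test-suite evaluation
# interface (exec_eval.py): it executes both queries against the database(s)
# under `db` and returns 1 when their execution results match, 0 otherwise.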


def evaluate(gold, predict, db_dir, plug_value, keep_distinct, progress_bar_for_each_datapoint):
    """Score each predicted query against its gold query by execution match.

    Both input files are expected to contain one SQL query per line, in the
    same order. Returns (number of matches, total number of pairs).
    """
    scores = []
    with open(gold) as f:
        glist = f.readlines()
    with open(predict) as f:
        plist = f.readlines()
    assert len(plist) == len(glist), "gold and predicted files must contain the same number of queries"
    for idx, (p_str, g_str) in enumerate(zip(plist, glist)):
        print("Index:", idx)
        # 1 if the predicted and gold queries produce the same execution
        # result, 0 otherwise.
        exec_score = eval_exec_match(db=db_dir, p_str=p_str, g_str=g_str, plug_value=plug_value,
                                     keep_distinct=keep_distinct,
                                     progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
        scores.append(exec_score)
    return sum(scores), len(scores)
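

# Minimal smoke-test sketch: writes a toy gold/pred file pair in the
# one-query-per-line format that evaluate() expects. The file names and
# queries below are illustrative assumptions, not part of the dataset.
def _write_toy_inputs(gold_path="toy_gold.txt", pred_path="toy_pred.txt"):
    queries = [
        "SELECT count(*) FROM process;",
        "SELECT name FROM process WHERE id = 1;",
    ]
    for path in (gold_path, pred_path):
        with open(path, "w") as f:
            # Identical gold and predicted queries, so evaluate() on these
            # files should report a perfect score.
            f.write("\n".join(queries) + "\n")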


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Default paths point at this project's local data; override via the CLI flags.
    gold = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold_no_err.txt"
    pred = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/pred_no_err.txt"
    db_dir = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/tb_process.db"
    parser.add_argument('--gold', dest='gold', type=str, default=gold,
                        help="the path to the gold queries")
    parser.add_argument('--pred', dest='pred', type=str, default=pred,
                        help="the path to the predicted queries")
    parser.add_argument('--db', dest='db', type=str, default=db_dir,
                        help="the directory that contains all the databases and test suites")
    # parser.add_argument('--table', dest='table', type=str, help="the tables.json schema file")
    parser.add_argument('--plug_value', default=False, action='store_true',
                        help="whether to plug the gold values into the predicted query; suitable if your model does not predict values")
    parser.add_argument('--keep_distinct', default=False, action='store_true',
                        help="whether to keep the DISTINCT keyword during evaluation; default is false")
    parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
                        help="whether to print a progress bar while running the test inputs for each datapoint")
    args = parser.parse_args()
    # Use distinct names so the builtins `sum` and `len` are not shadowed.
    n_correct, n_total = evaluate(args.gold, args.pred, args.db, args.plug_value,
                                  args.keep_distinct, args.progress_bar_for_each_datapoint)
    print(n_correct / n_total, f"({n_correct}/{n_total})")
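
# Example invocation (paths are illustrative):
#   python executaion_eval.py --gold gold_no_err.txt --pred pred_no_err.txt \
#       --db tb_process.db --plug_value
# Prints the execution accuracy followed by the raw counts, e.g. "0.85 (17/20)".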