import argparse

from exec_eval import eval_exec_match
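
# Note: eval_exec_match is assumed here to follow the Spider test-suite
# evaluation interface (test-suite-sql-eval): it executes the predicted and
# gold queries against the database and returns 1 if the result sets match,
# else 0.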


def evaluate(gold, predict, db_dir, plug_value, keep_distinct, progress_bar_for_each_datapoint):
    """Score each predicted SQL query against its gold query by execution match.

    Returns (number of matching pairs, total number of pairs).
    """
    scores = []

    # Read the gold queries, one per line.
    with open(gold) as f:
        glist = [line.strip() for line in f]

    # Read the predicted queries, one per line.
    with open(predict) as f:
        plist = [line.strip() for line in f]

    # The two files must be line-aligned.
    assert len(plist) == len(glist), "number of predictions must equal the number of gold queries"

    for idx, (p_str, g_str) in enumerate(zip(plist, glist)):
        print("index:", idx)

        # p_str = p_str.replace("value", "1")
        exec_score = eval_exec_match(
            db=db_dir, p_str=p_str, g_str=g_str, plug_value=plug_value,
            keep_distinct=keep_distinct,
            progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
        scores.append(exec_score)

    return sum(scores), len(scores)
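
# Input format note: both files are expected to hold one SQL query per line,
# line-aligned so that line i of the predictions corresponds to line i of the
# gold queries. Illustrative contents (assumed, not taken from the real files):
#   gold_no_err.txt:  SELECT name FROM tb_process WHERE id = 1;
#   pred_no_err.txt:  SELECT name FROM tb_process WHERE id = 1;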


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
gold = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/gold_no_err.txt"
|
|
pred = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/pred_no_err.txt"
|
|
db_dir = "/home/zhangxj/WorkFile/LCA-GPT/LLM-SQL/data/tb_process/tb_process.db"
|
|
|
|

    parser.add_argument('--gold', dest='gold', type=str, default=gold,
                        help="the path to the gold queries")
    parser.add_argument('--pred', dest='pred', type=str, default=pred,
                        help="the path to the predicted queries")
    parser.add_argument('--db', dest='db', type=str, default=db_dir,
                        help="the directory that contains all the databases and test suites")
    # parser.add_argument('--table', dest='table', type=str, help="the tables.json schema file")
    parser.add_argument('--plug_value', default=False, action='store_true',
                        help="whether to plug the gold values into the predicted query; "
                             "suitable if your model does not predict values")
    parser.add_argument('--keep_distinct', default=False, action='store_true',
                        help="whether to keep the DISTINCT keyword during evaluation (default: false)")
    parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
                        help="whether to print a progress bar while running the test inputs for each datapoint")
    args = parser.parse_args()

    # Use names that do not shadow the sum/len builtins (the original rebound them here).
    matched, total = evaluate(args.gold, args.pred, args.db, args.plug_value,
                              args.keep_distinct, args.progress_bar_for_each_datapoint)
    print(matched / total, f"({matched}/{total})")
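
# Example invocation (assuming this file is saved as evaluate.py; the data
# paths are illustrative):
#   python evaluate.py --gold gold_no_err.txt --pred pred_no_err.txt \
#       --db tb_process.db --plug_value
# Prints execution accuracy followed by the raw counts, e.g. "0.85 (17/20)".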