LCA-GPT/LLM-SQL/exec_eval.py

275 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import asyncio
import sqlite3
import threading
from typing import Tuple, Any, List, Set
from itertools import product
from collections import defaultdict
import tqdm
import random
from parse import get_all_preds_for_execution, remove_distinct
import time
import pickle as pkl
import subprocess
from itertools import chain
# Module-level lock; it is not referenced anywhere in the visible part of this file.
threadLock = threading.Lock()
# NOTE(review): presumably an execution timeout in seconds; unused in the visible code -- confirm.
TIMEOUT = 10
# NOTE(review): presumably a scratch directory for execution artifacts; unused in the visible code -- confirm.
EXEC_TMP_DIR = 'tmp/'
def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Reorder ``element`` so that position ``k`` holds ``element[perm[k]]``.

    ``perm`` must have the same length as ``element``; this is checked with
    an ``assert``.
    """
    assert len(element) == len(perm)
    return tuple(element[source_index] for source_index in perm)
def unorder_row(row: Tuple) -> Tuple:
    """Return the values of ``row`` rearranged into a canonical order.

    Each value is ranked by the string ``str(value) + str(type(value))``.
    Because the sort key is always a string, values of different types can be
    ordered together without comparing them to each other directly.
    """
    def canonical_key(value):
        return str(value) + str(type(value))

    return tuple(sorted(row, key=canonical_key))
# Cheap necessary condition for two results to be equivalent in denotation:
# after the values inside every row are put into canonical order, both
# results must still contain the same rows.
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Return True when the two results survive the quick rejection test.

    Every row of both results is canonicalised with ``unorder_row``. When
    ``order_matters`` is true the canonical rows are compared as lists
    (same rows, same sequence); otherwise they are compared as sets, so row
    order and duplicate rows are ignored at this stage.
    """
    canonical1 = [unorder_row(row) for row in result1]
    canonical2 = [unorder_row(row) for row in result2]
    if not order_matters:
        return set(canonical1) == set(canonical2)
    return canonical1 == canonical2
# return whether two bags (multisets) of rows contain the same elements
def multiset_eq(l1: List, l2: List) -> bool:
    """Return True if ``l1`` and ``l2`` hold the same elements with the same multiplicities.

    Elements must be hashable. Lists of different length are rejected
    immediately. Otherwise every element of ``l1`` is counted and the elements
    of ``l2`` are subtracted one by one; a count dropping below zero means
    ``l2`` has more copies of that element than ``l1``.
    """
    if len(l1) != len(l2):
        return False
    remaining = defaultdict(int)
    for item in l1:
        remaining[item] += 1
    for item in l2:
        remaining[item] -= 1
        if remaining[item] < 0:
            return False
    # Equal lengths and no negative count imply every count is exactly zero.
    return True
def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Enumerate candidate column assignments between two result tables.

    ``tab1_sets_by_columns[i]`` is the set of values found in column ``i`` of
    the first table; ``result2`` holds the rows of the second table and must
    be non-empty (its first row determines the column count).

    Returns an ``itertools.product`` iterator of tuples in which entry ``i``
    is a column index of ``result2`` proposed for column ``i`` of table 1.
    The tuples are not guaranteed to be permutations (an index may repeat);
    filtering those out is left to the caller.
    """
    width = len(result2[0])
    allowed = [set(range(width)) for _ in range(width)]
    if width > 3:
        # With more than three columns, sample 20 rows (with replacement) and
        # drop every pairing whose sampled value never occurs in the
        # corresponding column of table 1.
        for _ in range(20):
            sampled_row = random.choice(result2)
            for col1, candidates in enumerate(allowed):
                impossible = {
                    col2
                    for col2 in candidates
                    if sampled_row[col2] not in tab1_sets_by_columns[col1]
                }
                candidates -= impossible
    return product(*allowed)
# check whether two denotations are equivalent
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Return True if ``result2`` equals ``result1`` up to a column permutation.

    When ``order_matters`` is false the rows are additionally compared as a
    bag (row order ignored, multiplicities respected); when it is true the
    rows must also appear in the same sequence.
    """
    row_count = len(result1)
    # different numbers of rows can never be the same bag of rows
    if row_count != len(result2):
        return False
    # two empty results are trivially equivalent
    if row_count == 0:
        return True

    num_cols = len(result1[0])
    # a different number of columns means a different denotation
    if len(result2[0]) != num_cols:
        return False

    # comparing the rows with their values canonically ordered already
    # rejects most non-equivalent pairs cheaply
    if not quick_rej(result1, result2, order_matters):
        return False

    # What remains is a search for a column permutation (plus, when order does
    # not matter, a row permutation) that turns result2 into result1.
    # get_constraint_permutation shrinks the space of column assignments that
    # has to be enumerated; a single successful assignment proves equivalence.
    tab1_sets_by_columns = [{row[col] for row in result1} for col in range(num_cols)]
    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        # skip assignments that reuse a column, i.e. are not permutations
        if len(set(perm)) != len(perm):
            continue
        if num_cols == 1:
            permuted2 = result2
        else:
            permuted2 = [permute_tuple(row, perm) for row in result2]
        if order_matters:
            matched = result1 == permuted2
        else:
            # the set comparison is implied by the multiset comparison, but it
            # is much cheaper and rejects impossible candidates quickly
            matched = set(result1) == set(permuted2) and multiset_eq(result1, permuted2)
        if matched:
            return True
    return False
def replace_cur_year(query: str) -> str:
    """Replace every ``YEAR(CURDATE())`` call in ``query`` with the literal ``2020``.

    Matching is case-insensitive and tolerates whitespace between the tokens
    of the call. Whitespace that follows the closing parenthesis is kept, so
    the inserted literal never fuses with the next token (for example
    ``YEAR(CURDATE()) AND`` becomes ``2020 AND`` rather than ``2020AND``).

    NOTE(review): presumably needed because the queries are executed on
    SQLite while the expression is MySQL syntax -- confirm.

    Args:
        query: SQL text to rewrite.

    Returns:
        The rewritten SQL text; unchanged if the expression does not occur.
    """
    # Raw string: "\s" and "\(" are regex escapes, not Python string escapes.
    return re.sub(
        r"YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)",
        "2020",
        query,
        flags=re.IGNORECASE,
    )
# open the sqlite database at the given path and hand back a cursor on it
def get_cursor_from_path(sqlite_path: str):
    """Connect to ``sqlite_path`` and return a cursor for the new connection.

    A notice is printed when nothing exists at ``sqlite_path`` yet; the
    connection is opened either way. If the existence check or the connection
    attempt raises, the path is printed and the exception is re-raised.

    The connection itself is not returned; callers reach it through
    ``cursor.connection`` and are responsible for closing both objects.
    """
    try:
        path_missing = not os.path.exists(sqlite_path)
        if path_missing:
            print("Openning a new connection %s" % sqlite_path)
        connection = sqlite3.connect(sqlite_path)
    except Exception:
        print(sqlite_path)
        raise
    return connection.cursor()
def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
    """Execute ``query`` on the sqlite database at ``sqlite_path``.

    The query is first passed through ``replace_cur_year``. A fresh
    connection is opened for the call and is always closed before returning.

    Args:
        sqlite_path: path of the sqlite database file.
        query: SQL text to execute.

    Returns:
        ``("result", rows)`` with every fetched row on success, or
        ``("exception", error)`` if executing or fetching raised.
    """
    query = replace_cur_year(query)
    cursor = get_cursor_from_path(sqlite_path)
    try:
        cursor.execute(query)
        return "result", cursor.fetchall()
    except Exception as e:
        return "exception", e
    finally:
        # Single cleanup path: runs after success, after a handled error, and
        # also when the call is interrupted by something that is not an
        # ``Exception`` subclass (e.g. KeyboardInterrupt).
        cursor.close()
        cursor.connection.close()
def exec_on_db(sqlite_path: str, query: str):
    """Run ``exec_on_db_`` and turn any exception it raises into a result pair.

    Returns whatever ``exec_on_db_`` returns, or ``("exception", error)`` when
    the call itself raises (for instance while opening the database).
    """
    try:
        outcome = exec_on_db_(sqlite_path, query)
    except Exception as err:
        outcome = ("exception", err)
    return outcome
def query_database(db_path, query):
    """Execute ``query`` on a sqlite database and return its first row.

    Args:
        db_path (str): path of the database file.
        query (str): SQL query to execute.

    Returns:
        ``("result", row)`` where ``row`` is the first row of the result
        (``None`` when the query yields no rows), or ``("exception", error)``
        when sqlite3 reports an error; the error is also printed.
    """
    # connect to the database
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    try:
        # run the query and keep only the first row
        cur.execute(query)
        first_row = cur.fetchone()
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        return "exception", e
    else:
        return "result", first_row
    finally:
        # always release the cursor and the connection
        cur.close()
        conn.close()
# normalise comparison operators in a model prediction to avoid execution errors,
# e.g. removing the space between ">" and "="
def postprocess(query: str) -> str:
    """Return ``query`` with split or alternative comparison operators normalised.

    ``> =``, ``< =`` and ``! =`` lose their inner space and ``<>`` is
    rewritten to ``!=``. The replacements are plain text substitutions applied
    in that order over the whole string.
    """
    rewrites = (
        ('> =', '>='),
        ('< =', '<='),
        ('! =', '!='),
        ('<>', '!='),
    )
    for spaced, compact in rewrites:
        query = query.replace(spaced, compact)
    return query
# approximate whether p_str and g_str are semantically equivalent
# db is the database path; both queries are executed on that database only
# (the scan for sibling databases in db's directory is disabled)
# 1 if the prediction matches the gold query
# 0 otherwise
# the meaning of each auxiliary argument can be seen in the parser definition in evaluation.py
def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int:
    """Compare a predicted SQL query with the gold query by executing both.

    Args:
        db: path of the sqlite database to execute on.
        p_str: predicted SQL query.
        g_str: gold SQL query.
        plug_value: when true, the candidates produced by
            ``get_all_preds_for_execution(g_str, p_str)`` are tried in
            addition to ``p_str`` itself.
        keep_distinct: when false, both queries are passed through
            ``remove_distinct`` before execution.
        progress_bar_for_each_datapoint: wrap the database loop in a tqdm
            progress bar.

    Returns:
        1 if at least one candidate prediction yields the same first row as
        the gold query on every database, 0 otherwise.

    Raises:
        AssertionError: if the gold query fails to execute.
    """
    # post-process both queries, e.g. removing spaces between ">" and "="
    p_str, g_str = postprocess(p_str), postprocess(g_str)
    if not keep_distinct:
        p_str = remove_distinct(p_str)
        g_str = remove_distinct(g_str)

    # NOTE: query_database returns only the first row of each result and the
    # two rows are compared with plain equality; the bag-semantics comparison
    # implemented by result_eq is not used here.
    db_paths = [db]

    preds = [p_str]
    # if plug in value (i.e. we do not consider value prediction correctness)
    # enumerate all ways to plug in values in the gold query to the model predictions
    # otherwise, we only evaluate the predicted query with its own value prediction
    if plug_value:
        _, preds = get_all_preds_for_execution(g_str, p_str)
        # also try the unmodified prediction first; this reduces
        # "false negatives" when values are substituted
        preds = chain([p_str], preds)

    for pred in preds:
        pred_passes = 1
        # wrap with progress bar if required
        ranger = tqdm.tqdm(db_paths) if progress_bar_for_each_datapoint else db_paths
        for db_path in ranger:
            g_flag, g_denotation = query_database(db_path, g_str)
            # execute the current candidate, not the raw prediction, so that
            # the plugged-value candidates are actually evaluated
            p_flag, p_denotation = query_database(db_path, pred)
            # the gold query is expected to execute successfully; raise
            # explicitly so the check is not stripped under ``python -O``
            if g_flag == 'exception':
                raise AssertionError('gold query %s has error on database file %s' % (g_str, db_path))
            # wrong if execution fails or the denotations differ
            if p_flag == 'exception' or g_denotation != p_denotation:
                pred_passes = 0
                break
        # this candidate has the same denotation as the gold for all databases
        if pred_passes == 1:
            return 1
    # none of the predictions passed
    return 0