import os
import re
import asyncio
import sqlite3
import threading
from typing import Tuple, Any, List, Set
from itertools import product, chain
from collections import defaultdict
import tqdm
import random
from parse import get_all_preds_for_execution, remove_distinct
import time
import pickle as pkl
import subprocess

threadLock = threading.Lock()
TIMEOUT = 10
EXEC_TMP_DIR = 'tmp/'


def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    assert len(element) == len(perm)
    return tuple([element[i] for i in perm])


def unorder_row(row: Tuple) -> Tuple:
    return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))


# unorder each row in the table
# [result_1 and result_2 have the same bag of unordered rows]
# is a necessary condition of
# [result_1 and result_2 are equivalent in denotation]
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    s1 = [unorder_row(row) for row in result1]
    s2 = [unorder_row(row) for row in result2]
    if order_matters:
        return s1 == s2
    else:
        return set(s1) == set(s2)


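# illustrative examples (not from the original file): with order_matters=False, both the
# row order and the order of values within each row are ignored at this rejection stage
# quick_rej([(1, 'a'), (2, 'b')], [(2, 'b'), (1, 'a')], order_matters=False) -> True
# quick_rej([(1, 'a')], [('a', 1)], order_matters=False) -> True
# quick_rej([(1, 'a')], [(2, 'b')], order_matters=False) -> False
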
# return whether two bags of rows are equivalent
def multiset_eq(l1: List, l2: List) -> bool:
    if len(l1) != len(l2):
        return False
    d = defaultdict(int)
    for e in l1:
        d[e] = d[e] + 1
    for e in l2:
        d[e] = d[e] - 1
        if d[e] < 0:
            return False
    return True


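# illustrative (not from the original file): equality as multisets, so duplicates count
# multiset_eq([1, 2, 2], [2, 1, 2]) -> True
# multiset_eq([1, 2], [2, 2]) -> False
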
def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    num_cols = len(result2[0])
    perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
    # with at most 3 columns the full permutation space is small enough to enumerate directly
    if num_cols <= 3:
        return product(*perm_constraints)

    # we sample 20 rows and constrain the space of permutations
    for _ in range(20):
        random_tab2_row = random.choice(result2)

        for tab1_col in range(num_cols):
            for tab2_col in set(perm_constraints[tab1_col]):
                if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
                    perm_constraints[tab1_col].remove(tab2_col)
    return product(*perm_constraints)


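# illustrative sketch (not from the original file): with <= 3 columns the unconstrained
# product is returned, including degenerate mappings that result_eq later skips via
# len(perm) != len(set(perm)); e.g., in CPython's usual small-int set iteration order,
# list(get_constraint_permutation([{1}, {'a'}], [(1, 'a')]))
# -> [(0, 0), (0, 1), (1, 0), (1, 1)]
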
# check whether two denotations are equivalent
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    if len(result1) == 0 and len(result2) == 0:
        return True

    # if the lengths are not the same, they are definitely different bags of rows
    if len(result1) != len(result2):
        return False

    num_cols = len(result1[0])

    # if the results do not have the same number of columns, they are different
    if len(result2[0]) != num_cols:
        return False

    # unorder each row and compare whether the denotations are the same
    # this can already find most pairs of denotations that are different
    if not quick_rej(result1, result2, order_matters):
        return False

    # the rest of the problem is in fact more complicated than one might think
    # we want to find a permutation of column order and a permutation of row order,
    # s.t. result_1 is the same as result_2
    # we return true if we can find such column & row permutations
    # and false if we cannot
    tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]

    # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
    # we decrease the size of the column permutation space with get_constraint_permutation
    # if one of the permutations makes result_1 and result_2 equivalent, then they are equivalent
    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        # skip degenerate permutations that map two source columns to the same target column
        if len(perm) != len(set(perm)):
            continue
        if num_cols == 1:
            result2_perm = result2
        else:
            result2_perm = [permute_tuple(element, perm) for element in result2]
        if order_matters:
            if result1 == result2_perm:
                return True
        else:
            # in fact the first condition must hold if the second condition holds
            # but the first is way more efficient implementation-wise
            # and we use it to quickly reject impossible candidates
            if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
                return True
    return False


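# worked example (illustrative, not from the original file): result2 differs from result1 by
# swapping the two columns and reordering the rows; the column permutation (1, 0) recovers result1
# result_eq([(1, 'a'), (2, 'b')], [('b', 2), ('a', 1)], order_matters=False) -> True
# result_eq([(1, 'a'), (2, 'b')], [('b', 2), ('a', 1)], order_matters=True) -> False
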
def replace_cur_year(query: str) -> str:
    return re.sub(
        r"YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
    )


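# illustrative (not from the original file): pins the MySQL-style current year to 2020
# so that query results are deterministic across runs
# replace_cur_year("SELECT YEAR(CURDATE())") -> "SELECT 2020"
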
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path: str):
    try:
        if not os.path.exists(sqlite_path):
            print("Opening a new connection %s" % sqlite_path)
        conn = sqlite3.connect(sqlite_path)
    except Exception as e:
        print(sqlite_path)
        raise e
    cursor = conn.cursor()
    return cursor


def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
    query = replace_cur_year(query)
    cursor = get_cursor_from_path(sqlite_path)
    # print("#### SQL query ####:", query)
    try:
        cursor.execute(query)
        result = cursor.fetchall()
        # print("#### SQL query result ####:", result)
        cursor.close()
        cursor.connection.close()
        return "result", result
    except Exception as e:
        cursor.close()
        cursor.connection.close()
        return "exception", e


def exec_on_db(sqlite_path: str, query: str):
    try:
        return exec_on_db_(sqlite_path, query)
    except Exception as e:
        return ("exception", e)


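# usage sketch (hypothetical path, not from the original file):
# flag, rows = exec_on_db('path/to/db.sqlite', 'SELECT count(*) FROM singer')
# flag is "result" on success, with rows holding cursor.fetchall(),
# or "exception" with the raised exception as the second element
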
def query_database(db_path, query):
    """
    Args:
        db_path (str): path to the database file.
        query (str): the SQL query to execute.

    Returns:
        result: the query result.
    """
    # connect to the database
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()

    try:
        # execute the SQL query
        cur.execute(query)

        # fetch the query result
        # note: unlike exec_on_db_, this only fetches the first row
        result = cur.fetchone()

        # return the result
        return "result", result

    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        return "exception", e

    finally:
        # close the cursor and the connection
        cur.close()
        conn.close()


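# usage sketch (hypothetical path, not from the original file):
# flag, row = query_database('path/to/db.sqlite', 'SELECT count(*) FROM singer')
# unlike exec_on_db, only the first row comes back, e.g. a single tuple such as (42,)
# rather than a list of tuples, so multi-row denotations lose rows here
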
# postprocess the model predictions to avoid execution errors
# e.g. removing spaces between ">" and "="
def postprocess(query: str) -> str:
    query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=').replace('<>', '!=')
    return query


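# illustrative (not from the original file):
# postprocess("SELECT * FROM t WHERE a > = 1 AND b <> 2")
# -> "SELECT * FROM t WHERE a >= 1 AND b != 2"
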
# approximate whether p_str and g_str are semantically equivalent
# db is the database path
# originally, equivalence was evaluated on all the databases in the same directory as db;
# in this version only db itself is used (see db_paths below)
# returns 1 if denotationally equivalent, 0 otherwise
# the meaning of each auxiliary argument can be seen in the parser definition in evaluation.py
def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int:
    # post-process the prediction
    # e.g. removing spaces between ">" and "="
    p_str, g_str = postprocess(p_str), postprocess(g_str)
    if not keep_distinct:
        p_str = remove_distinct(p_str)
        g_str = remove_distinct(g_str)

    # we decide whether two denotations are equivalent based on "bag semantics"
    # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf
    # if there is an order by in the query, then we assume the order of the rows matters
    # order by might also be used to find the max/min instead of sorting,
    # but in that case the result mostly contains only one row, so order_matters makes no difference
    order_matters = 'order by' in g_str.lower()

    # originally, all databases in the same directory were used:
    # db_dir = os.path.dirname(db)
    # db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename]
    db_paths = [db]
    # print("db_paths: ####", db_paths)

    preds = [p_str]
    # if plug_value is set (i.e. we do not consider value prediction correctness),
    # enumerate all the ways of plugging values from the gold query into the model prediction;
    # otherwise, we only evaluate the predicted query with its own value prediction
    if plug_value:
        _, preds = get_all_preds_for_execution(g_str, p_str)
        # we did not add this line in our EMNLP work
        # this reduces "false negatives" when value is substituted
        preds = chain([p_str], preds)

    for pred in preds:

        pred_passes = 1
        # compare the gold and predicted denotations on each database in db_paths
        # wrap with a progress bar if required
        if progress_bar_for_each_datapoint:
            ranger = tqdm.tqdm(db_paths)
        else:
            ranger = db_paths

        for db_path in ranger:
            # print("##### db_path ####:", db_path)
            g_flag, g_denotation = query_database(db_path, g_str)  # exec_on_db(db_path, g_str)
            # execute the current candidate, not p_str: otherwise the loop over preds has no effect
            p_flag, p_denotation = query_database(db_path, pred)  # exec_on_db(db_path, pred)

            # we expect the gold query to execute successfully on the database
            assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path)

            # wrong if execution fails
            if p_flag == 'exception':
                pred_passes = 0

            # if the denotations are not equivalent, the prediction must be wrong
            # elif not result_eq(g_denotation, p_denotation, order_matters=order_matters):
            elif g_denotation != p_denotation:
                pred_passes = 0
            if pred_passes == 0:
                break

        # the model prediction has the same denotation as the gold for all databases
        if pred_passes == 1:
            return 1

    # none of the predictions passed
    return 0
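

# minimal usage sketch (hypothetical path and queries, not part of the original file):
# if __name__ == '__main__':
#     db_file = 'path/to/database.sqlite'
#     gold = "SELECT name FROM singer ORDER BY age"
#     pred = "SELECT name FROM singer ORDER BY age"
#     print(eval_exec_match(db_file, pred, gold, plug_value=False,
#                           keep_distinct=False, progress_bar_for_each_datapoint=False))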