225 lines
8.0 KiB
Python
225 lines
8.0 KiB
Python
import re
|
|
import sqlparse
|
|
from typing import List, Tuple, Set, Iterator, Dict, Any, Union
|
|
from sqlparse.sql import Comparison, Identifier
|
|
from sqlparse.tokens import Whitespace
|
|
import itertools
|
|
from collections import namedtuple
|
|
|
|
# Lightweight (ttype, value) pair mirroring a sqlparse leaf token; built by tokenize().
Token = namedtuple('Token', 'ttype value')

# Placeholder substituted for every literal value stripped out of a query.
VALUE_NUM_SYMBOL = 'VALUERARE'

# Quote characters that process_str_value strips and un-doubles.
QUOTE_CHARS = set('`\'"')
|
|
|
|
|
|
def tokenize(query: str) -> List[Token]:
    """Split *query* into sqlparse leaf tokens wrapped as ``Token`` tuples.

    Only the first statement returned by ``sqlparse.parse`` is tokenized.
    """
    first_statement = sqlparse.parse(query)[0]
    return [Token(leaf.ttype, leaf.value) for leaf in first_statement.flatten()]
|
|
|
|
|
|
def join_tokens(tokens: List[Token]) -> str:
    """Concatenate the ``value`` of each token into one normalized string.

    Leading and trailing whitespace is stripped, and every run of two or more
    space characters is collapsed into a single space. Tabs and newlines are
    left as they are.
    """
    text = ''.join(tok.value for tok in tokens).strip()
    # The former ``.replace(' ', ' ')`` swapped a space for a space, which is a
    # no-op. Collapse repeated spaces, as that call was presumably meant to do.
    return re.sub(r' {2,}', ' ', text)
|
|
|
|
|
|
def round_trip_test(query: str) -> None:
    """Check that tokenizing *query* and re-joining the token values is lossless.

    Raises:
        AssertionError: if the concatenated token values differ from *query*
            (only when assertions are enabled).
    """
    rebuilt = ''.join(tok.value for tok in tokenize(query))
    assert query == rebuilt, "Round trip test fails for string %s" % query
|
|
|
|
|
|
def postprocess(query: str) -> str:
    """Re-join comparison operators that were split by a single space.

    ``'> ='``, ``'< ='`` and ``'! ='`` become ``'>='``, ``'<='`` and ``'!='``,
    applied in that order.
    """
    for spaced, joined in (('> =', '>='), ('< =', '<='), ('! =', '!=')):
        query = query.replace(spaced, joined)
    return query
|
|
|
|
|
|
# strip_query, reformat_query and replace_values
# were implemented by Yu Tao for processing CoSQL
|
|
def strip_query(query: str) -> Tuple[List[str], List[str]]:
    """Replace literal values in *query* with a placeholder and tokenize it.

    Steps, in order:

    1. Every token that sqlparse classifies as ``Literal.String.Single`` or
       ``Literal.String.Symbol`` is recorded. Its text, via plain substring
       replacement, is replaced by ``VALUE_NUM_SYMBOL``.
    2. The query is split on whitespace. Tokens equal to a float literal found
       in the query are recorded and replaced by the placeholder.
    3. Tokens equal to an integer match are recorded and replaced by the
       placeholder.
    4. Every token is lower-cased. Tokens containing a ``T<n>.`` alias
       qualifier are additionally split around ``.``.

    Returns:
        A pair ``(query_keywords, all_values)``. The first item holds the
        lower-cased query tokens; placeholders therefore appear as
        ``VALUE_NUM_SYMBOL.lower()``. The second holds the extracted values in
        the order they were collected.
    """
    query_keywords, all_values = [], []

    # 1) String literals, as classified by sqlparse.
    leaves = sqlparse.parse(query)[0].flatten()
    values = [t.value for t in leaves
              if t.ttype == sqlparse.tokens.Literal.String.Single
              or t.ttype == sqlparse.tokens.Literal.String.Symbol]
    for val in values:
        all_values.append(val)
        query = query.replace(val.strip(), VALUE_NUM_SYMBOL)

    # 2) Float literals, matched against whole whitespace-separated tokens.
    #    Raw strings: "\d" / "\." in plain literals are invalid escape sequences.
    query_tokenized = query.split()
    float_nums = re.findall(r"[-+]?\d*\.\d+", query)
    all_values += [qt for qt in query_tokenized if qt in float_nums]
    query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized]

    # 3) Integer literals. The pattern consumes one leading non-"t"/"T"
    #    character (hence strip()). Presumably this keeps the digit of
    #    T1-style aliases from being treated as a value.
    query = " ".join(query_tokenized)
    int_nums = [i.strip() for i in re.findall(r"[^tT]\d+", query)]
    all_values += [qt for qt in query_tokenized if qt in int_nums]
    query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized]

    # 4) Lower-case everything and split "T1.col"-style tokens around ".".
    #    Tokens from split() are never empty, so no emptiness checks are needed.
    for tok in query_tokenized:
        if re.search(r"[Tt]\d+\.", tok):
            query_keywords.extend(part.lower() for part in tok.replace(".", " . ").split())
        else:
            query_keywords.append(tok.lower())

    return query_keywords, all_values
|
|
|
|
|
|
def reformat_query(query: str) -> str:
    """Normalize *query* into one space-separated string of tokens.

    The steps are:

    1. Semicolons and tab characters are deleted.
    2. The query is tokenized.
    3. Tokens whose ttype equals sqlparse's ``Whitespace`` are dropped, and the
       rest are joined with single spaces.
    4. The literal substrings ``t1.*``, ``t2.*``, ``t3.*``, ``T1.*``, ``T2.*``
       and ``T3.*`` are replaced by ``*``.
    """
    cleaned = query.strip().replace(";", "").replace("\t", "")
    kept = [tok.value for tok in tokenize(cleaned) if tok.ttype != sqlparse.tokens.Whitespace]
    result = ' '.join(kept)
    # NOTE(review): tokenize() yields leaf tokens, so "T1.*" is presumably
    # re-joined above as "T1 . *". These patterns may then only ever match
    # inside string literals. Confirm before relying on this normalization.
    for alias in ("t1", "t2", "t3", "T1", "T2", "T3"):
        result = result.replace(alias + ".*", "*")
    return result
|
|
|
|
|
|
def replace_values(sql: str) -> Tuple[List[str], Set[str]]:
    """Upper-case SQL keywords in *sql*, then strip its values out.

    Returns:
        The lower-cased query tokens with values replaced by the placeholder
        (as produced by ``strip_query``), and the set of distinct values.
    """
    formatted = sqlparse.format(sql, reindent=False, keyword_case='upper')
    # Glue an alias qualifier to the following name, e.g. "T1. name" -> "T1.name".
    formatted = re.sub(r"(T\d+\.)\s", r"\1", formatted)
    tokens_without_values, found_values = strip_query(formatted)
    return tokens_without_values, set(found_values)
|
|
|
|
|
|
def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]:
    """Extract the non-value tokens and the set of values from a SQL query.

    The query is normalized with ``reformat_query`` and then passed through
    ``replace_values``.

    Returns:
        ``(query_value_replaced, values)``: the value-slotted token list and
        the set of extracted values.
    """
    normalized = reformat_query(query=sql)
    slotted_tokens, found_values = replace_values(normalized)
    return slotted_tokens, found_values
|
|
|
|
|
|
def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str:
    """Plug values into the placeholder slots of a value-stripped query.

    Args:
        query_value_replaced: Query tokens in which value slots are tokens
            equal to ``VALUE_NUM_SYMBOL.lower()``.
        values_in_order: One value per slot, in slot order.

    Returns:
        The filled-in tokens joined with single spaces.

    Raises:
        AssertionError: If the number of values differs from the number of
            slots (only when assertions are enabled).
    """
    filled = query_value_replaced[:]
    slot = VALUE_NUM_SYMBOL.lower()
    slot_positions = [pos for pos, tok in enumerate(query_value_replaced) if tok == slot]

    assert len(slot_positions) == len(values_in_order)

    for pos, value in zip(slot_positions, values_in_order):
        filled[pos] = value
    return ' '.join(filled)
|
|
|
|
|
|
def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) -> Iterator[str]:
    """Lazily yield every way of filling the query's value slots.

    Each slot independently takes any element of *values*. This is a
    Cartesian product with repetition, not a permutation, so
    ``len(values) ** num_slots`` queries are produced.
    """
    slot = VALUE_NUM_SYMBOL.lower()
    num_slots = sum(1 for tok in query_value_replaced if tok == slot)
    # One candidate pool per slot. The loop variable below no longer shadows
    # the *values* parameter.
    pools = [list(values) for _ in range(num_slots)]
    for assignment in itertools.product(*pools):
        yield plugin(query_value_replaced, list(assignment))
|
|
|
|
|
|
def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]:
    """Build every value-filled variant of *pred* using the values of *gold*.

    Values are extracted from the gold query. The predicted query is reduced
    to its value-slotted token list.

    Returns:
        A pair ``(num_alternatives, candidates)``. The first item is the number
        of ways to plug gold values into the prediction's slots
        (``len(gold_values) ** num_slots``). The second is a lazy iterator over
        those filled-in queries.
    """
    _, gold_values = extract_query_values(gold)
    pred_tokens, _ = extract_query_values(pred)

    slot = VALUE_NUM_SYMBOL.lower()
    num_slots = sum(1 for tok in pred_tokens if tok == slot)

    return len(gold_values) ** num_slots, plugin_all_permutations(pred_tokens, gold_values)
|
|
|
|
|
|
def remove_distinct(s):
    """Return *s* without any token whose text is ``distinct`` (case-insensitive).

    Only the first statement parsed by sqlparse is kept. Neighbouring
    whitespace tokens are left untouched, so doubled spaces may remain.
    """
    leaves = sqlparse.parse(s)[0].flatten()
    return ''.join(leaf.value for leaf in leaves if leaf.value.lower() != 'distinct')
|
|
|
|
|
|
def extract_all_comparison_from_node(node: sqlparse.sql.Token) -> List[Comparison]:
    """Collect every ``Comparison`` group in the subtree rooted at *node*.

    Children are visited before the node itself, and results come back in that
    post-order.

    Args:
        node: A sqlparse token or token group. Groups expose ``.tokens``. This
            is not the module-level ``Token`` namedtuple, which the previous
            annotation named by mistake.
    """
    comparison_list = []
    for child in getattr(node, 'tokens', ()):
        comparison_list.extend(extract_all_comparison_from_node(child))
    if isinstance(node, Comparison):
        comparison_list.append(node)
    return comparison_list
|
|
|
|
|
|
def extract_all_comparison(query: str) -> List[Comparison]:
    """Return every ``Comparison`` in the first statement of *query*, in post-order."""
    return extract_all_comparison_from_node(sqlparse.parse(query)[0])
|
|
|
|
|
|
def extract_toks_from_comparison(comparison_node: Comparison) -> List[sqlparse.sql.Token]:
    """Return the direct children of *comparison_node* that are not whitespace.

    Callers unpack the result as ``left, op, right``.
    """
    # ``is_whitespace`` covers the whole Whitespace ttype family, including
    # newlines. The former ``ttype != Whitespace`` test let Newline tokens
    # through. That broke the three-way unpacking for comparisons written
    # across lines.
    return [t for t in comparison_node.tokens if not t.is_whitespace]
|
|
|
|
|
|
def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]:
    """Describe a ``left <op> right`` comparison.

    The result always has ``'left'`` and ``'right'`` (the sqlparse tokens) and
    ``'op'`` (the operator text). When the left side is an ``Identifier`` and
    the right side is a single plain token, two more keys are added:

    * ``'table_col'``: ``(table, COLUMN)``, with the column name upper-cased.
      *table* is the lower-cased first child of a three-child identifier
      (presumably ``qualifier . column``). It is ``None`` otherwise, or when
      that qualifier is a one-digit alias such as ``T1``.
    * ``'val'``: the right-hand text passed through ``process_str_value``.

    Raises:
        ValueError: If the comparison does not have exactly three
            non-whitespace children.
    """
    left, op, right = extract_toks_from_comparison(comparison_node)
    info = {'left': left, 'op': op.value, 'right': right}

    if type(left) != Identifier:
        return info

    # Keep real table names; treat single-digit T<n> aliases as "no table".
    has_real_table = len(left.tokens) == 3 and not re.match('^[tT][0-9]$', left.tokens[0].value)
    table = left.tokens[0].value.lower() if has_real_table else None
    col = left.tokens[-1].value

    # Accept only a bare token, or an Identifier wrapping exactly one bare token.
    # Exact type checks are deliberate: token groups subclass sqlparse.sql.Token.
    if type(right) == Identifier:
        inner = right.tokens
        if len(inner) != 1 or type(inner[0]) != sqlparse.sql.Token:
            return info
        right_val = inner[0].value
    elif type(right) == sqlparse.sql.Token:
        right_val = right.value
    else:
        return info

    info['table_col'], info['val'] = (table, col.upper()), process_str_value(right_val)
    return info
|
|
|
|
|
|
def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]:
    """Return ``extract_info_from_comparison`` for each comparison in *query*, in post-order."""
    return list(map(extract_info_from_comparison, extract_all_comparison(query)))
|
|
|
|
|
|
def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]:
    """Collect ``((table, COLUMN), value)`` pairs constrained by *query*.

    Two sources are combined, in this order:

    1. Every comparison for which ``extract_all_comparison_from_query``
       produced a ``'table_col'`` entry.
    2. Every ``[table.]column BETWEEN low AND high`` clause matched
       case-insensitively by a regex. Each clause contributes one pair per
       bound, and the bounds are taken verbatim (quotes are not stripped).

    Table names are lower-cased (``None`` when absent). Column names from the
    BETWEEN clauses are upper-cased.
    """
    typed_values = [(cmp['table_col'], cmp['val'])
                    for cmp in extract_all_comparison_from_query(query)
                    if 'table_col' in cmp]

    # Raw string: "\." and "\s" in a plain literal are invalid escape
    # sequences (SyntaxWarning since Python 3.12). The pattern itself is unchanged.
    between = r'(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)'
    for table, col, low, high in re.findall(between, query, re.IGNORECASE):
        # The optional table group yields '' when absent.
        key = (table.lower() if table else None, col.upper())
        typed_values.append((key, low))
        typed_values.append((key, high))
    return typed_values
|
|
|
|
|
|
def process_str_value(v: str) -> str:
    """Strip surrounding quotes from *v* and collapse doubled quote characters.

    At most one leading and one trailing character is removed, each only if it
    is in ``QUOTE_CHARS``; the two need not match. Afterwards every doubled
    quote character (for example ``''``) becomes a single one, for all kinds in
    ``QUOTE_CHARS``.
    """
    # Slicing keeps this safe for short strings: ''[:1] == '' is not a quote char.
    text = v[1:] if v[:1] in QUOTE_CHARS else v
    text = text[:-1] if text[-1:] in QUOTE_CHARS else text
    # NOTE(review): the un-doubling applies to every quote kind, not only the
    # enclosing one. So 'a""b' (single-quoted) becomes a"b. Confirm that is intended.
    for quote in QUOTE_CHARS:
        text = text.replace(quote + quote, quote)
    return text
|