LCA-GPT/LLM-SQL/getPredSql.py

92 lines
2.7 KiB
Python
Raw Normal View History

2024-12-29 16:18:16 +08:00
import os
import re

import pandas as pd
import psycopg2
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import QianfanLLMEndpoint, Tongyi
def getSql(response: str) -> str:
    """Extract the SQL text that the model wrapped in ``#`` markers.

    Every ``#...#`` span found in ``response`` is collected, the captured
    pieces are concatenated (with no separator) into a single string, and
    that string is printed and returned.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The concatenated contents of all ``#...#`` spans, or an empty string
        when ``response`` contains no such span.
    """
    # re.DOTALL lets "." match newlines, so a statement that the model
    # formats over several lines is captured instead of being dropped.
    sql_list = re.findall(r"#(.*?)#", response, flags=re.DOTALL)
    # NOTE(review): several spans are joined without a separator, so two
    # statements in one answer run together -- confirm this is intended.
    sql_str = ''.join(sql_list)
    print(sql_str)
    return sql_str
def query_database(text, schema):
    """Ask the LLM to turn a natural-language request into a SQL statement.

    The table structure is rendered as one ``Table <name>: <col>, <col>``
    line per table and placed in the prompt together with ``text``. The
    prompt is run through the module-level ``llm`` with an ``LLMChain`` and
    the answer is passed to ``getSql`` to pull out the SQL between ``#``
    markers.

    Args:
        text: The request to translate; it fills the ``{prompt}`` slot of
            the template.
        schema: Mapping of table name to an iterable of column-name strings.

    Returns:
        Whatever ``getSql`` extracts from the model's answer.
    """
    # Put the table structure information into the prompt.
    table_lines = []
    for table, columns in schema.items():
        table_lines.append(f"Table {table}: {', '.join(columns)}")
    schema_info = "\n".join(table_lines)

    # NOTE(review): the template asks for SQLite syntax although this file
    # connects to PostgreSQL through psycopg2 -- confirm which dialect is
    # actually wanted.
    template_text = '''以下是数据库的表结构信息:{schema_info}
分析表结构信息请根据以下描述生成一个符合SQLite数据库语法的SQL查询描述{prompt}
并且要求输出的SQL以#开头,以#结尾,样例如下:
#SELECT * FROM table#
#SELECT COUNT(*) FROM table#
注意不要输出分析过程和其他内容直接给出SQL语句
'''
    prompt = PromptTemplate(
        template=template_text,
        input_variables=["schema_info", "prompt"],
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    answer = chain.run({"schema_info": schema_info, "prompt": text})
    return getSql(answer)
# Script body: configure the LLM, read the database schema, generate one SQL
# prediction per question from the spreadsheet and save them to an Excel file.

# NOTE(review): the API keys and the database password below are committed in
# plain text. Rotate them and load them from the environment or a secrets
# store instead of the source file.
os.environ['DASHSCOPE_API_KEY'] = 'sk-c5f441f863f44094b0ddb96c831b5002'
api = "zLiAbXCHVdJqQaZAREMpXxAc"
sk = "wod9Pc2bZAyFVVOC9ypleJOAc4GoA4rT"
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)

# Database connection parameters
dbname = "gis_lca"
user = "postgres"
password = "Qibebt+123"
host = "localhost"  # or the IP address of the Docker container when it runs on another machine
port = "5432"
# Connection string
conn_string = f"host={host} dbname={dbname} user={user} password={password} port={port}"

# Connect to the database. The connection is only needed while the schema is
# read, so the cursor and the connection are closed as soon as that is done
# (also when reading the schema fails) instead of being left open.
conn = psycopg2.connect(conn_string)
try:
    cur = conn.cursor()
    try:
        # Collect the table structure: table name -> column names.
        # NOTE(review): get_table_name / get_table_columns are neither defined
        # nor imported in the visible part of this file -- confirm where they
        # are supposed to come from.
        schema = dict()
        table_names = get_table_name(cur)
        for name in table_names:
            schema[name] = get_table_columns(cur, name)
    finally:
        cur.close()
finally:
    conn.close()

# Read the questions
df = pd.read_excel(r"D:\python\LCA-GPT\LLM-SQL\Log\sql_ques_all.xlsx")
question = df['Question'].tolist()

# One lower-cased predicted statement per question, in spreadsheet order.
sqllist = [query_database(ques, schema).lower() for ques in question]

sql_data = pd.DataFrame(sqllist, columns=["SQL_pred"])
path = "./data/SQL_pred.xlsx"
# Create the output directory first so that the predictions are not lost to
# a missing-directory error after all the LLM calls have completed.
os.makedirs(os.path.dirname(path), exist_ok=True)
sql_data.to_excel(path, index=False)