LCA-GPT/LLM-SQL/getPredSql.py

92 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re

import pandas as pd
import psycopg2
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import QianfanLLMEndpoint, Tongyi
def getSql(response):
    """Extract the SQL text that a model reply wraps in ``#`` markers.

    Every non-overlapping ``#...#`` span in *response* is collected (a span
    cannot cross a line break, because ``.`` does not match newlines here),
    and the captured texts are concatenated in order with no separator.

    The combined string is printed and then returned; it is empty when
    *response* contains no such span.
    """
    fragments = [hit.group(1) for hit in re.finditer(r"#(.*?)#", response)]
    combined = "".join(fragments)
    print(combined)
    return combined
def query_database(text, schema):
    """Ask the module-level ``llm`` for a SQL statement that answers *text*.

    Args:
        text: Natural-language description of the query to generate.
        schema: Mapping of table name to an iterable of column-name strings.

    Returns:
        The SQL string that ``getSql`` extracts from the model reply.
    """
    # Render one "Table <name>: <col>, <col>, ..." line per table so the
    # table structure can be embedded in the prompt.
    table_lines = []
    for table_name, column_names in schema.items():
        table_lines.append(f"Table {table_name}: {', '.join(column_names)}")
    schema_info = "\n".join(table_lines)

    # NOTE(review): the template requests SQLite syntax while this file
    # connects through psycopg2 -- confirm which SQL dialect is intended.
    template_text = '''以下是数据库的表结构信息:{schema_info}
分析表结构信息请根据以下描述生成一个符合SQLite数据库语法的SQL查询描述{prompt}
并且要求输出的SQL以#开头,以#结尾,样例如下:
#SELECT * FROM table#
#SELECT COUNT(*) FROM table#
注意不要输出分析过程和其他内容直接给出SQL语句。
'''
    sql_prompt = PromptTemplate(
        template=template_text,
        input_variables=["schema_info", "prompt"],
    )

    chain_inputs = {"schema_info": schema_info, "prompt": text}
    reply = LLMChain(llm=llm, prompt=sql_prompt).run(chain_inputs)

    # The model is told to wrap its SQL in '#'; keep only that part.
    return getSql(reply)
# Script section: build the LLM client, read the database schema, generate one
# SQL statement per question from the spreadsheet, and save the predictions.


def _require_env(name):
    """Return environment variable *name*; raise RuntimeError if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} must be set before running this script"
        )
    return value


# Secrets are read from the environment rather than being committed to source
# control. The Qianfan key pair is required because `llm` is the model used by
# query_database(). DASHSCOPE_API_KEY is no longer assigned here; no Tongyi
# model is built in this file, so export it yourself if you switch to one.
llm = QianfanLLMEndpoint(
    model="ERNIE-4.0-8K",
    qianfan_ak=_require_env("QIANFAN_AK"),
    qianfan_sk=_require_env("QIANFAN_SK"),
)

# Database connection parameters. The non-secret values keep their previous
# defaults and can be overridden through the usual PG* environment variables.
dbname = os.environ.get("PGDATABASE", "gis_lca")
user = os.environ.get("PGUSER", "postgres")
# Or the IP address of your Docker container if it runs on another machine.
host = os.environ.get("PGHOST", "localhost")
port = os.environ.get("PGPORT", "5432")
password = _require_env("PGPASSWORD")

# Keyword arguments avoid hand-building a connection string, which breaks when
# a value (typically the password) contains spaces or quotes.
conn = psycopg2.connect(
    host=host, dbname=dbname, user=user, password=password, port=port
)
try:
    with conn.cursor() as cur:
        # Database structure: table name -> column names.
        # NOTE(review): get_table_name / get_table_columns are not defined or
        # imported in the visible part of this file -- confirm they are in scope.
        schema = dict()
        table_names = get_table_name(cur)
        for name in table_names:
            schema[name] = get_table_columns(cur, name)
finally:
    # The connection is only used for the schema lookup above, so release it
    # before the (slow) LLM calls start.
    conn.close()

# Make sure the output directory exists up front, so a missing folder cannot
# discard the results after every LLM call has already been made.
path = "./data/SQL_pred.xlsx"
os.makedirs(os.path.dirname(path), exist_ok=True)

# Load the questions.
df = pd.read_excel(r"D:\python\LCA-GPT\LLM-SQL\Log\sql_ques_all.xlsx")
question = df['Question'].tolist()

# One lower-cased predicted SQL string per question, in spreadsheet order.
sqllist = [query_database(ques, schema).lower() for ques in question]

sql_data = pd.DataFrame(sqllist, columns=["SQL_pred"])
sql_data.to_excel(path, index=False)