92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
import os
import re

import pandas as pd
import psycopg2
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import QianfanLLMEndpoint, Tongyi
|
||
def getSql(response):
    """Extract the SQL delimited by ``#`` markers from an LLM response.

    The prompt instructs the model to wrap its SQL as ``#...#``; every such
    delimited fragment is pulled out, the fragments are concatenated,
    echoed to stdout for debugging, and returned.  Returns an empty string
    when no markers are present.
    """
    fragments = re.findall(r"#(.*?)#", response)
    extracted = "".join(fragments)
    print(extracted)
    return extracted
|
||
|
||
|
||
def query_database(text, schema):
    """Generate a SQL statement for a natural-language request.

    Builds a prompt that lists every table and its columns, asks the
    module-level ``llm`` (via an ``LLMChain``) for a ``#``-delimited SQL
    statement, and returns the extracted SQL string.

    Args:
        text: Natural-language description of the desired query.
        schema: Mapping of table name -> iterable of column names.

    Returns:
        The SQL string extracted from the model's response ("" if the
        model produced no ``#...#`` block).
    """
    # Render the schema as one "Table <name>: col1, col2, ..." line per table.
    table_lines = [
        f"Table {table}: {', '.join(columns)}"
        for table, columns in schema.items()
    ]
    schema_info = "\n".join(table_lines)

    sql_prompt = PromptTemplate(
        input_variables=["schema_info", "prompt"],
        template='''以下是数据库的表结构信息:{schema_info}。
分析表结构信息,请根据以下描述生成一个符合SQLite数据库语法的SQL查询,描述:{prompt}。
并且要求输出的SQL以#开头,以#结尾,样例如下:
#SELECT * FROM table#
#SELECT COUNT(*) FROM table#
注意不要输出分析过程和其他内容,直接给出SQL语句。
''',
    )

    # NOTE(review): relies on the module-level `llm` defined later in this
    # file; fine at runtime because the function is only called after that
    # assignment executes.
    chain = LLMChain(llm=llm, prompt=sql_prompt)
    response = chain.run({"schema_info": schema_info, "prompt": text})

    # Strip the '#...#' delimiters the prompt asked the model to emit.
    return getSql(response)
|
||
|
||
|
||
# --- LLM credentials & client -------------------------------------------
# SECURITY NOTE(review): live API keys are hard-coded in source control.
# They should be loaded from environment variables or a secrets manager,
# and these now-exposed keys rotated.  Left byte-identical here because
# replacing them would change runtime behavior.
os.environ['DASHSCOPE_API_KEY'] = 'sk-c5f441f863f44094b0ddb96c831b5002'
# os.environ["QIANFAN_ACCESS_KEY"] = "zLiAbXCHVdJqQaZAREMpXxAc"
# os.environ["QIANFAN_SECRET_KEY"] = "wod9Pc2bZAyFVVOC9ypleJOAc4GoA4rT"

# Qianfan (Baidu ERNIE) access-key / secret-key pair.
api = "zLiAbXCHVdJqQaZAREMpXxAc"
sk = "wod9Pc2bZAyFVVOC9ypleJOAc4GoA4rT"

# Module-level LLM client consumed by query_database() above.
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
|
||
|
||
# --- PostgreSQL connection & schema introspection -----------------------
# Connection parameters for the local database.
dbname = "gis_lca"
user = "postgres"
password = "Qibebt+123"
host = "localhost"  # or the Docker container's IP when running on another machine
port = "5432"

# libpq-style connection string.
conn_string = f"host={host} dbname={dbname} user={user} password={password} port={port}"

# Open the connection and a cursor; both remain open for the whole script.
conn = psycopg2.connect(conn_string)
cur = conn.cursor()

# Map each table name to its list of column names so the LLM prompt can
# describe the full schema.
# NOTE(review): get_table_name / get_table_columns are not defined in this
# file as shown — presumably imported or defined elsewhere; confirm.
table_names = get_table_name(cur)
schema = {name: get_table_columns(cur, name) for name in table_names}
|
||
|
||
|
||
# --- Batch inference -----------------------------------------------------
# Read benchmark questions, generate one SQL prediction per question, and
# persist the predictions to an Excel file.

# Load the evaluation questions (column "Question", one per row).
df = pd.read_excel(r"D:\python\LCA-GPT\LLM-SQL\Log\sql_ques_all.xlsx")
question = df['Question'].tolist()

# Lower-case each generated SQL once, at creation time.
# FIX(review): the original had a second loop
#   for sql in sqllist: sql = sql.lower()
# which only rebound the loop variable and never modified the list — a
# no-op, removed.  The append loop is now a comprehension.
sqllist = [query_database(ques, schema).lower() for ques in question]
result = []  # kept for compatibility; not populated in this script

sql_data = pd.DataFrame(sqllist, columns=["SQL_pred"])
path = "./data/SQL_pred.xlsx"
sql_data.to_excel(path, index=False)

# FIX(review): release database resources — the original never closed the
# cursor or connection opened above.
cur.close()
conn.close()
|
||
|
||
|