"""Generate predicted SQL for a list of natural-language questions.

The script reads the table/column layout of the ``gis_lca`` PostgreSQL
database, asks an ERNIE model (through Baidu Qianfan, via LangChain) to
translate each question from an Excel sheet into SQL, and writes the
predictions to ``./data/SQL_pred.xlsx`` (one row per question, column
``SQL_pred``). The generated SQL is not executed.

Credentials are read from the environment instead of being hard-coded:

* ``QIANFAN_AK`` / ``QIANFAN_SK`` -- Qianfan access / secret key.
* ``PGPASSWORD`` (or ``~/.pgpass``) -- database password, read by libpq.

NOTE(review): the API keys and database password that used to be written
literally in this file were committed to source control and should be
rotated.
"""
import functools
import os
import re

import pandas as pd
import psycopg2
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import QianfanLLMEndpoint, Tongyi  # noqa: F401  (Tongyi kept for switching models)

# Database connection parameters. The password is deliberately not set here:
# libpq falls back to PGPASSWORD / ~/.pgpass when none is passed.
DB_NAME = "gis_lca"
DB_USER = "postgres"
DB_HOST = "localhost"  # or the Docker container's IP if the DB runs on another machine
DB_PORT = "5432"

LLM_MODEL = "ERNIE-4.0-8K"

QUESTIONS_PATH = r"D:\python\LCA-GPT\LLM-SQL\Log\sql_ques_all.xlsx"
OUTPUT_PATH = "./data/SQL_pred.xlsx"

# NOTE(review): the schema is read from PostgreSQL, but the prompt asks for
# SQLite syntax. Kept as "SQLite" so predictions stay comparable with earlier
# runs; switch to "PostgreSQL" if the SQL is meant to run against this DB.
SQL_DIALECT = "SQLite"

# The model is told to wrap its SQL in '#...#'. DOTALL lets the captured SQL
# span several lines; without it, multi-line SQL never matched and was lost.
_SQL_FENCE_RE = re.compile(r"#(.*?)#", re.DOTALL)

# Prompt, in English: "Here is the database schema: {schema_info}. Analyse it
# and, for the following description, write one SQL query in {dialect}
# syntax: {prompt}. The SQL must start and end with '#', e.g. <examples>.
# Do not output any analysis or other content, only the SQL statement."
_PROMPT_TEMPLATE = """以下是数据库的表结构信息:{schema_info}。
分析表结构信息,请根据以下描述生成一个符合{dialect}数据库语法的SQL查询,描述:{prompt}。
并且要求输出的SQL以#开头,以#结尾,样例如下:
#SELECT * FROM table#
#SELECT COUNT(*) FROM table#
注意不要输出分析过程和其他内容,直接给出SQL语句。
"""


def _require_env(name):
    """Return environment variable *name*; raise RuntimeError if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"environment variable {name} must be set")
    return value


@functools.lru_cache(maxsize=None)
def get_llm():
    """Build (once) and return the Qianfan ERNIE endpoint used to generate SQL.

    Raises:
        RuntimeError: if QIANFAN_AK or QIANFAN_SK is not set.
    """
    return QianfanLLMEndpoint(
        model=LLM_MODEL,
        qianfan_ak=_require_env("QIANFAN_AK"),
        qianfan_sk=_require_env("QIANFAN_SK"),
    )


def get_table_name(cur, table_schema="public"):
    """Return the names of the base tables in *table_schema*, sorted by name.

    NOTE(review): assumes the tables of interest live in the ``public``
    schema -- pass ``table_schema`` if they do not.
    """
    cur.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = %s AND table_type = 'BASE TABLE' "
        "ORDER BY table_name",
        (table_schema,),
    )
    return [row[0] for row in cur.fetchall()]


def get_table_columns(cur, table_name, table_schema="public"):
    """Return the column names of *table_name*, in table-definition order."""
    cur.execute(
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_schema = %s AND table_name = %s "
        "ORDER BY ordinal_position",
        (table_schema, table_name),
    )
    return [row[0] for row in cur.fetchall()]


def load_schema():
    """Return ``{table_name: [column_name, ...]}`` for the configured database."""
    conn = psycopg2.connect(host=DB_HOST, dbname=DB_NAME, user=DB_USER, port=DB_PORT)
    try:
        with conn.cursor() as cur:
            return {name: get_table_columns(cur, name) for name in get_table_name(cur)}
    finally:
        # Only the schema is needed, so release the connection before the
        # (slow) per-question LLM calls start.
        conn.close()


def getSql(response):
    """Extract the SQL that the model wrapped in ``#...#`` markers.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The text between each ``#`` pair, stripped of surrounding whitespace,
        with empty fragments dropped and the rest joined by a single space
        (normally there is exactly one pair). Empty string if none is found.
    """
    fragments = (match.strip() for match in _SQL_FENCE_RE.findall(response))
    sql_str = " ".join(fragment for fragment in fragments if fragment)
    print(sql_str)  # progress output while the batch runs
    return sql_str


def query_database(text, schema, llm=None, dialect="SQLite"):
    """Ask the LLM to translate a natural-language description into SQL.

    Args:
        text: The question / description to turn into SQL.
        schema: Mapping of table name to an iterable of column-name strings;
            rendered into the prompt as ``Table <name>: <col>, <col>, ...``.
        llm: LangChain LLM to use. Defaults to the shared Qianfan endpoint
            returned by :func:`get_llm`.
        dialect: SQL dialect named in the prompt. Defaults to "SQLite", which
            reproduces the original prompt text.

    Returns:
        The SQL extracted by :func:`getSql` ("" if the model did not wrap any
        SQL in ``#`` markers). The SQL is not executed.
    """
    if llm is None:
        llm = get_llm()
    schema_info = "\n".join(
        f"Table {table}: {', '.join(columns)}" for table, columns in schema.items()
    )
    prompt = PromptTemplate(
        input_variables=["schema_info", "prompt", "dialect"],
        template=_PROMPT_TEMPLATE,
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run({"schema_info": schema_info, "prompt": text, "dialect": dialect})
    return getSql(response)


def main():
    """Generate SQL for every question in QUESTIONS_PATH and save it to OUTPUT_PATH."""
    llm = get_llm()
    schema = load_schema()

    questions = pd.read_excel(QUESTIONS_PATH)["Question"].tolist()
    sqllist = []
    for ques in questions:
        sql = query_database(ques, schema, llm=llm, dialect=SQL_DIALECT)
        # Predictions are lower-cased to normalise them, as before.
        # NOTE(review): this also lower-cases string literals
        # (e.g. WHERE name = 'Beijing'), so a prediction may not return the
        # intended rows if executed -- confirm this is wanted.
        sqllist.append(sql.lower())

    # to_excel does not create missing parent directories.
    os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
    pd.DataFrame(sqllist, columns=["SQL_pred"]).to_excel(OUTPUT_PATH, index=False)


if __name__ == "__main__":
    main()