In [1]:
import os
from langchain_community.llms import QianfanLLMEndpoint, Tongyi
from langchain_community.chat_models import ChatZhipuAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import psycopg2
import re
In [32]:
# Baidu Qianfan
api = "xxxxx"
sk = "xxxxxx"
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
In [3]:
# Tongyi Qianwen (alternative backend, commented out)
# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'
# llm = Tongyi(model_name="qwen-max")
In [5]:
input_text = "用50个字左右阐述,生命的意义在于"
response = llm.invoke(input_text)
print(response)
生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。
In [10]:
# Database connection parameters
dbname = "gis_lca"
user = "postgres"
password = "xxxxxx"
host = "localhost"  # or the IP address of your Docker container if it runs on another machine
port = "5432"

# Connection string
conn_string = f"host={host} dbname={dbname} user={user} password={password} port={port}"

# Connect to the database
conn = psycopg2.connect(conn_string)
cur = conn.cursor()
In [11]:
# Get all table names in the public schema
def get_table_name(cur):
    # Query information_schema for the list of tables
    cur.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public';
    """)
    tables = cur.fetchall()
    tname = []
    for table in tables:
        tname.append(table[0])
    return tname
In [12]:
# Get the column names of a table
def get_table_columns(cur, table_name):
    columns = []
    try:
        # Query information_schema for the table's column names (parameterized to avoid injection)
        cur.execute("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = %s;
        """, (table_name,))
        columns = [desc[0] for desc in cur.fetchall()]
    except Exception as e:
        print(f"An error occurred: {e}")
    return columns
In [13]:
# Build the database schema: table name -> list of column names
schema = dict()
table_names = get_table_name(cur)
for name in table_names:
    schema[name] = get_table_columns(cur, name)
In [14]:
def get_column_types(cur, table_name):
    # Query information_schema for the table's column names and data types
    cur.execute("""
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s;
    """, (table_name,))
    results = cur.fetchall()
    return results
In [15]:
col = get_column_types(cur, "tb_process")
print(col)
[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]
In [16]:
# Extract the SQL statement wrapped in '#...#' from the model response
def getSql(response):
    pattern = r"#(.*?)#"
    sql_list = re.findall(pattern, response)
    sql_str = sql_list[-1]
    # sql_str = sql_str.lower()
    print("SQL:", sql_str)
    return sql_str
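As a quick sanity check, here is a hypothetical response string (not produced by the notebook) showing how getSql picks out the text between the last pair of '#' markers:

demo_response = "分析过程省略。#SELECT COUNT(*) FROM tb_process WHERE deleted='0'#"
sql = getSql(demo_response)  # prints: SQL: SELECT COUNT(*) FROM tb_process WHERE deleted='0'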
In [40]:
def query_database(text, schema):
    # Include the table-structure information in the prompt
    schema_info = "\n".join([f"Table {table}: {', '.join(columns)}" for table, columns in schema.items()])
    # print(schema_info)

    # Alternative prompt that injects the schema information:
    # prompt = PromptTemplate(
    #     input_variables=["schema_info", "prompt"],
    #     template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.
    #     请根据以下描述生成一个符合SQLite数据库语法的SQL查询,并且不能修改给出的数据表列名。
    #     描述:{prompt}。
    #     要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:
    #     #SELECT * FROM table#
    #     #SELECT COUNT(*) FROM table where Colume_name='abc'#
    #     注意不要输出分析过程和其他内容,直接给出SQL语句。
    #     '''
    # )

    # Prompt actually used here: the description only, without the schema information
    prompt = PromptTemplate(
        input_variables=["prompt"],
        template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。
        描述:{prompt}。
        并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:
        #SELECT * FROM table#
        #SELECT COUNT(*) FROM table where Colume_name='abc'#
        注意不要输出分析过程和其他内容,直接给出SQL语句。
        '''
    )
    # print(schema_info)

    chain = LLMChain(llm=llm, prompt=prompt)
    query = chain.run({"prompt": text})
    # print(query)
    sql = getSql(query)
    # result = execute_query(sql)
    return sql
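Note that the active template above never uses schema_info, even though it is computed. A minimal sketch of how the commented-out, schema-aware variant could be wired up (the shortened template here is a hypothetical simplification, not the notebook's tested prompt):

# Hypothetical sketch: pass both the schema text and the question to a two-variable template
schema_prompt = PromptTemplate(
    input_variables=["schema_info", "prompt"],
    template="以下是数据库的表结构信息:{schema_info}\n描述:{prompt}\n请输出以#开头、以#结尾的SQL语句。",
)
schema_chain = LLMChain(llm=llm, prompt=schema_prompt)
# query = schema_chain.run({"schema_info": schema_info, "prompt": text})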
In [41]:
import pandas as pd

df = pd.read_excel(r"./data/sql_ques_clear.xlsx")
df.head()
Out[41]:
| | SQL | Question | table |
|---|---|---|---|
| 0 | select count( * ) as total from tb_process whe... | 数据库tb_process表中有多少未删除的数据? | tb_process |
| 1 | select tags from tb_process where product_syst... | 在数据库tb_process表中,product_system_id为12983111603... | tb_process |
| 2 | select id,ref_id,name,synonyms,category_id,pro... | 在数据库tb_process表中,process_type为‘unit_process’且p... | tb_process |
| 3 | select id,ref_id,name,synonyms,category_id,des... | 在数据库tb_process表中,id为1300755365750636544的记录的id、... | tb_process |
| 4 | select id,ref_id,name,synonyms,category_id,des... | 在数据库tb_process表中,id为9237375的记录的id、ref_id、name、... | tb_process |
In [42]:
ques_test = "数据库tb_process表中有多少未删除的数据?"
tmp_dict = dict()
tmp_dict["tb_process"] = schema["tb_process"]
print(tmp_dict)
res = query_database(ques_test, tmp_dict)
print("res:", res)
{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}
SQL: SELECT COUNT(*) FROM tb_process
res: SELECT COUNT(*) FROM tb_process
In [43]:
sql_gold = []
sql_pred = []
for index, item in df.iterrows():
    prompt = item['Question']
    table = item['table']
    gold = item['SQL']
    if table not in table_names:
        print(table)
        continue
    if 'join' in gold:
        continue
    colum_type = get_column_types(cur, table)
    tmp = dict()
    tmp[table] = schema[table]
    sql = query_database(prompt, tmp)
    sql_pred.append(sql)
    sql_gold.append(gold)
SQL: SELECT COUNT(*) FROM tb_process
SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'
SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'
SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' AND id!='1300755365750636544' AND is_deleted!='1'
SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1
SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'
SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1
SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC
SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'
SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'
SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000
SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000
SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'
SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')
SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'
SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'
SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'
SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'
SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')
SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'
SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'
SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id
SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'
SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')
SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'
SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY sort ASC
SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'
SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'
SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'
SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'
SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'
SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'
SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id
SQL: SELECT * FROM tb_process_doc WHERE copyright='1'
SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'
SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id
SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'
SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'
SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'
SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'
SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'
SQL: SELECT * FROM tb_process WHERE deleted='1'
SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'
SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'
SQL: SELECT * FROM tb_process WHERE category_id='66'
SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id
SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'
SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'
SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'
SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type
SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'
SQL: SELECT * FROM tb_process WHERE node_type='2'
SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version
SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'
SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'
SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'
SQL: SELECT * FROM tb_process WHERE process_source='1'
SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'
SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'
In [44]:
len(sql_gold)
Out[44]:
66
In [45]:
def save_txt(save_list, path):
    with open(path, "w", encoding="utf-8") as f:
        for item in save_list:
            f.write(item + "\n")
In [46]:
save_txt(sql_gold, "./data/Qianfan/gold.txt")
save_txt(sql_pred, "./data/Qianfan/pred.txt")
Change the path!!!
In [13]:
!python --version
In [ ]:
# Execution accuracy
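The execution-accuracy cell is left empty in the notebook. A minimal sketch of one way to fill it, assuming the same psycopg2 connection (conn, cur) and the gold.txt/pred.txt files written above; load_txt and run_query are hypothetical helpers, not part of the notebook:

# Hypothetical sketch: execution accuracy = fraction of predicted queries whose
# result set matches the gold query's result set on the same database.
def load_txt(path):
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def run_query(cur, sql):
    try:
        cur.execute(sql)
        # Compare rows order-insensitively via their string form (a simplification)
        return sorted(repr(row) for row in cur.fetchall())
    except Exception:
        conn.rollback()  # reset the transaction after a failed statement
        return None

gold_sqls = load_txt("./data/Qianfan/gold.txt")
pred_sqls = load_txt("./data/Qianfan/pred.txt")

matches = 0
for g, p in zip(gold_sqls, pred_sqls):
    gold_rows = run_query(cur, g)
    pred_rows = run_query(cur, p)
    if gold_rows is not None and gold_rows == pred_rows:
        matches += 1
print("Execution accuracy:", matches / len(gold_sqls))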
In [73]:
print(schema['sys_user'])
["'last_login_time'", "'version'", "'deleted'", "'creator'", "'create_time'", "'updater'", "'update_time'", "'user_type'", "'person_auth_id'", "'enterprise_auth_id'", "'id'", "'gender'", "'org_id'", "'super_admin'", "'status'", "'username'", "'password'", "'real_name'", "'avatar'", "'open_id'", "'email'", "'mobile'", "'nickname'", "'id_number'", "'remark'"]
In [ ]: