LCA-GPT/LLM-SQL/LLM-sql.ipynb

29 KiB
Raw Permalink Blame History

In [1]:
import os
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
from langchain_community.chat_models import ChatZhipuAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import psycopg2
import re
In [32]:
# Baidu Qianfan credentials.
# Never hardcode credentials in a notebook: read them from the environment.
# The literal fallbacks are placeholders only and must be replaced before running.
api = os.environ.get("QIANFAN_AK", "xxxxx")
sk = os.environ.get("QIANFAN_SK", "xxxxxx")

llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
In [3]:
# 千问
# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'
# 
# llm = Tongyi(model_name="qwen-max")
In [5]:
# Smoke-test the LLM connection with a short Chinese prompt
# ("describe the meaning of life in about 50 characters").
input_text = "用50个字左右阐述生命的意义在于"
response = llm.invoke(input_text)
print(response)
生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。
In [10]:
# Database connection parameters.
dbname = "gis_lca"
user = "postgres"
# Never hardcode credentials: read the password from the environment.
# The literal fallback is a placeholder and must be replaced before running.
password = os.environ.get("PGPASSWORD", "xxxxxx")
host = "localhost"  # or the Docker container's IP if it runs on another machine
port = "5432"

# Build the libpq-style connection string.
conn_string = f"host={host} dbname={dbname} user={user} password={password} port={port}"
# Connect to the database; `conn` and `cur` are shared by the cells below.
conn = psycopg2.connect(conn_string)
cur = conn.cursor()
In [11]:
# Fetch the names of all user tables in the database.
def get_table_name(cur):
    """Return the names of all tables in the 'public' schema.

    cur: an open DB-API cursor.
    """
    cur.execute("""
        SELECT table_name 
        FROM information_schema.tables 
        WHERE table_schema = 'public';
    """)
    # Each fetched row is a 1-tuple; unpack the single column.
    return [row[0] for row in cur.fetchall()]
In [12]:
# Fetch the column names of a given table.
def get_table_columns(cur, table_name):
    """Return the list of column names for `table_name` (empty list on error).

    Fixes vs. the original:
    - uses a parameterized query instead of interpolating `table_name` into
      an f-string (SQL injection risk; also now consistent with
      get_column_types below);
    - initializes `columns` before the try block, so a failed query returns
      [] instead of raising UnboundLocalError from the bare `return columns`.
    """
    columns = []
    try:
        cur.execute("""
            SELECT column_name 
            FROM information_schema.columns 
            WHERE table_name = %s;
        """, (table_name,))
        # Each fetched row is a 1-tuple containing the column name.
        columns = [row[0] for row in cur.fetchall()]
    except Exception as e:
        # Best-effort: report and fall through to the empty list.
        print(f"An error occurred: {e}")
    return columns
In [13]:
# Build the database schema description: map each table name to its columns.
table_names = get_table_name(cur)
schema = {tbl: get_table_columns(cur, tbl) for tbl in table_names}
In [14]:
def get_column_types(cur, table_name):
    """Return (column_name, data_type) tuples for every column of `table_name`.

    cur: an open DB-API cursor; the table name is bound as a query parameter.
    """
    query = """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s;
    """
    cur.execute(query, (table_name,))
    return cur.fetchall()
In [15]:
# Inspect the column names/types of the main process table.
col = get_column_types(cur,"tb_process")
print(col)
[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]
In [16]:
def getSql(response):
    """Extract the last #...#-delimited SQL statement from an LLM response.

    The prompt instructs the model to wrap the SQL in '#' markers; the last
    match is taken in case the model emitted earlier marker-wrapped text.

    Raises ValueError (instead of an opaque IndexError) when no #...# pair
    is present, e.g. when the model ignored the requested output format.
    """
    pattern = r"#(.*?)#"
    # re.DOTALL lets '.' match newlines, so multi-line SQL is captured too.
    sql_list = re.findall(pattern, response, re.DOTALL)
    if not sql_list:
        raise ValueError(f"No #...#-delimited SQL found in response: {response!r}")
    sql_str = sql_list[-1]
    # sql_str = sql_str.lower()
    print("SQL:",sql_str)

    return sql_str
In [40]:
def query_database(text, schema):
    """Generate a SQL statement for the natural-language question `text`.

    schema: {table_name: [column_name, ...]} mapping.
    Returns the SQL string extracted from the LLM response via getSql().

    NOTE(review): `schema_info` is built below but the ACTIVE prompt template
    never references it (the template that injected it is commented out), so
    the `schema` argument currently has no effect on the generated SQL.
    """
    # Render the table structure as one line per table ("Table t: c1, c2, ...").
    schema_info = "\n".join([f"Table {table}: {', '.join(columns)}" for table, columns in schema.items()])
    # print(schema_info)
    # Earlier prompt variant that injected schema_info — kept for reference:
    # prompt= PromptTemplate(
    #     input_variables=["schema_info","prompt"],
    #     template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.
    #     请根据以下描述生成一个符合SQLite数据库语法的SQL查询并且不能修改给出的数据表列名。
    #     描述:{prompt}。
    #     要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:
    #     #SELECT * FROM table#
    #     #SELECT COUNT(*) FROM table where Colume_name='abc'#
    #     注意不要输出分析过程和其他内容直接给出SQL语句。
    #     '''
    # )
    # Active prompt: asks (in Chinese) for SQLite-syntax SQL wrapped in '#'
    # markers, with all parameter values treated as strings.
    prompt= PromptTemplate(
        input_variables=["prompt"],
        template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。
        描述:{prompt}
        并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:
        #SELECT * FROM table#
        #SELECT COUNT(*) FROM table where Colume_name='abc'#
        注意不要输出分析过程和其他内容直接给出SQL语句。
        '''
    )
    # print(schema_info)
    # NOTE(review): LLMChain/chain.run are deprecated in newer LangChain
    # releases; consider prompt | llm and .invoke() when upgrading.
    chain = LLMChain(llm=llm, prompt=prompt)
    query = chain.run(
        {"prompt":text}
    )
    # print(query)
    sql = getSql(query)
    # result = execute_query(sql)
    return sql
In [41]:
# Load the evaluation set: columns are SQL (gold query), Question, table.
# NOTE(review): import would normally live in the top import cell.
import pandas as pd
df = pd.read_excel(r"./data/sql_ques_clear.xlsx")
df.head()
Out[41]:
SQL Question table
0 select count( * ) as total from tb_process whe... 数据库tb_process表中有多少未删除的数据 tb_process
1 select tags from tb_process where product_syst... 在数据库tb_process表中product_system_id为12983111603... tb_process
2 select id,ref_id,name,synonyms,category_id,pro... 在数据库tb_process表中process_type为unit_process且p... tb_process
3 select id,ref_id,name,synonyms,category_id,des... 在数据库tb_process表中id为1300755365750636544的记录的id、... tb_process
4 select id,ref_id,name,synonyms,category_id,des... 在数据库tb_process表中id为9237375的记录的id、ref_id、name、... tb_process
In [42]:
# Single-question sanity check, passing only the tb_process schema.
ques_test = "数据库tb_process表中有多少未删除的数据"
tmp_dict = dict()
tmp_dict["tb_process"] = schema["tb_process"]
print(tmp_dict)
res = query_database(ques_test,tmp_dict)
print("res:",res)
{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}
SQL: SELECT COUNT(*) FROM tb_process
res: SELECT COUNT(*) FROM tb_process
In [43]:
# Run the LLM over every evaluation question and collect matched
# gold/predicted SQL pairs for the external evaluation step.
sql_gold = []
sql_pred = []

for index, item in df.iterrows():
    prompt = item['Question']
    table = item['table']
    gold = item['SQL']

    # Skip questions about tables that do not exist in this database.
    if table not in table_names:
        print(table)
        continue
    # Skip multi-table join queries (single-table evaluation only).
    # Lower-case first so 'JOIN' is caught as well as 'join'.
    if 'join' in gold.lower():
        continue
    # (The original also fetched column types here, but never used them.)
    tmp = {table: schema[table]}
    sql = query_database(prompt, tmp)
    sql_pred.append(sql)
    sql_gold.append(gold)
SQL: SELECT COUNT(*) FROM tb_process
SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'
SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'
SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' AND id!='1300755365750636544' AND is_deleted!='1'
SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1
SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'
SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1
SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'
SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC
SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'
SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'
SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000
SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000
SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'
SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')
SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'
SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'
SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'
SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'
SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')
SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'
SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'
SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id
SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'
SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')
SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'
SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY sort ASC
SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'
SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'
SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'
SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'
SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'
SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'
SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id
SQL: SELECT * FROM tb_process_doc WHERE copyright='1'
SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'
SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id
SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'
SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'
SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'
SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'
SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'
SQL: SELECT * FROM tb_process WHERE deleted='1'
SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'
SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'
SQL: SELECT * FROM tb_process WHERE category_id='66'
SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id
SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'
SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'
SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'
SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type
SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'
SQL: SELECT * FROM tb_process WHERE node_type='2'
SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version
SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'
SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'
SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'
SQL: SELECT * FROM tb_process WHERE process_source='1'
SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'
SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'
In [44]:
len(sql_gold)
Out[44]:
66
In [45]:
def save_txt(save_list, path):
    """Write each string in `save_list` to `path`, one per line, UTF-8 encoded."""
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(line + "\n" for line in save_list)
In [46]:
# Persist the gold/predicted SQL pairs for the external evaluation script.
# NOTE(review): paths are model-specific (Qianfan here) — update before re-running.
save_txt(sql_gold,"./data/Qianfan/gold.txt")
save_txt(sql_pred,"./data/Qianfan/pred.txt")

Reminder: change the output paths above (./data/Qianfan/...) to match the model being evaluated before re-running!

In [13]:
!python --version
In [ ]:
# Execution accuracy (to be computed from the saved gold/pred files)
In [73]:
print(schema['sys_user'])
["'last_login_time'", "'version'", "'deleted'", "'creator'", "'create_time'", "'updater'", "'update_time'", "'user_type'", "'person_auth_id'", "'enterprise_auth_id'", "'id'", "'gender'", "'org_id'", "'super_admin'", "'status'", "'username'", "'password'", "'real_name'", "'avatar'", "'open_id'", "'email'", "'mobile'", "'nickname'", "'id_number'", "'remark'"]
In [ ]: