LCA-LLM/LLM-SQL/LLM-sql.ipynb

690 lines
29 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:02.254040400Z",
"start_time": "2024-11-15T08:05:00.664449900Z"
}
},
"outputs": [],
"source": [
"import os\n",
"from langchain_community.llms import QianfanLLMEndpoint,Tongyi\n",
"from langchain_community.chat_models import ChatZhipuAI\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"import psycopg2\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T09:27:27.129813200Z",
"start_time": "2024-11-15T09:27:27.095824600Z"
}
},
"outputs": [],
"source": [
"# 百度千帆\n",
"api = \"xxxxx\"\n",
"sk = \"xxxxxx\"\n",
"\n",
"llm = QianfanLLMEndpoint(model=\"ERNIE-4.0-8K\",qianfan_ak=api,qianfan_sk=sk)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:13.561075500Z",
"start_time": "2024-11-15T08:05:13.451555700Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 千问\n",
"# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'\n",
"# \n",
"# llm = Tongyi(model_name=\"qwen-max\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:18.389436200Z",
"start_time": "2024-11-15T08:05:15.487698800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。\n"
]
}
],
"source": [
"input_text = \"用50个字左右阐述生命的意义在于\"\n",
"response = llm.invoke(input_text)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:11:05.408878300Z",
"start_time": "2024-11-15T08:11:05.273677300Z"
}
},
"outputs": [],
"source": [
"# 数据库连接参数\n",
"dbname = \"gis_lca\"\n",
"user = \"postgres\"\n",
"password = \"xxxxxx\"\n",
"host = \"localhost\" # 或者是你 Docker 容器的 IP 地址,如果你在不同的机器上\n",
"port = \"5432\"\n",
"\n",
"# 连接字符串\n",
"conn_string = f\"host={host} dbname={dbname} user={user} password={password} port={port}\"\n",
"# 连接到数据库\n",
"conn = psycopg2.connect(conn_string)\n",
"cur = conn.cursor()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:11:53.090176Z",
"start_time": "2024-11-15T08:11:53.041140900Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 获取表名\n",
"def get_table_name(cur):\n",
" # 执行 SQL 查询以获取所有表的列表\n",
" cur.execute(\"\"\"\n",
" SELECT table_name \n",
" FROM information_schema.tables \n",
" WHERE table_schema = 'public';\n",
" \"\"\")\n",
" \n",
" # 获取查询结果\n",
" tables = cur.fetchall()\n",
" tname = []\n",
" # 打印表名\n",
" for table in tables:\n",
" tname.append(table[0])\n",
" \n",
" return tname"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:13:58.272398Z",
"start_time": "2024-11-15T08:13:58.238434200Z"
}
},
"outputs": [],
"source": [
"# 获取数据库列名。\n",
"def get_table_columns(cur,table_name):\n",
" try:\n",
" # 执行 SQL 查询以获取表的列名\n",
" cur.execute(f\"\"\"\n",
" SELECT column_name \n",
" FROM information_schema.columns \n",
" WHERE table_name = '{table_name}';\n",
" \"\"\")\n",
" # 获取查询结果\n",
" cols = [desc[0] for desc in cur.fetchall()]\n",
" columns = []\n",
" for col in cols:\n",
" # column = \"'\"+col+\"'\"\n",
" columns.append(col)\n",
" except Exception as e:\n",
" print(f\"An error occurred: {e}\")\n",
" return columns"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:13:59.601786Z",
"start_time": "2024-11-15T08:13:59.206151400Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 获取数据库表结构,表名对应列名\n",
"schema = dict()\n",
"\n",
"table_names = get_table_name(cur)\n",
"for name in table_names:\n",
" schema[name] = get_table_columns(cur,name)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:02.392006100Z",
"start_time": "2024-11-15T08:14:02.359287Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"def get_column_types(cur, table_name):\n",
" # 执行 SQL 查询以获取表的列名和数据类型\n",
" cur.execute(\"\"\"\n",
" SELECT column_name, data_type\n",
" FROM information_schema.columns\n",
" WHERE table_name = %s;\n",
" \"\"\", (table_name,))\n",
"\n",
" # 获取查询结果\n",
" results = cur.fetchall()\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:04.182996600Z",
"start_time": "2024-11-15T08:14:04.117421900Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]\n"
]
}
],
"source": [
"col = get_column_types(cur,\"tb_process\")\n",
"print(col)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:05.075951300Z",
"start_time": "2024-11-15T08:14:05.044963600Z"
}
},
"outputs": [],
"source": [
"def getSql(response):\n",
" pattern = r\"#(.*?)#\"\n",
" sql_list = re.findall(pattern, response)\n",
" sql_str = sql_list[-1]\n",
" # sql_str = sql_str.lower()\n",
" print(\"SQL:\",sql_str)\n",
"\n",
" return sql_str"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:29.507242700Z",
"start_time": "2024-11-15T12:08:29.440042800Z"
}
},
"outputs": [],
"source": [
"def query_database(text, schema):\n",
" # 将表结构信息包含在提示中\n",
" schema_info = \"\\n\".join([f\"Table {table}: {', '.join(columns)}\" for table, columns in schema.items()])\n",
" # print(schema_info)\n",
" # prompt= PromptTemplate(\n",
" # input_variables=[\"schema_info\",\"prompt\"],\n",
" # template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.\n",
" # 请根据以下描述生成一个符合SQLite数据库语法的SQL查询并且不能修改给出的数据表列名。\n",
" # 描述:{prompt}。\n",
" # 要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
" # #SELECT * FROM table#\n",
" # #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
" # 注意不要输出分析过程和其他内容直接给出SQL语句。\n",
" # '''\n",
" # )\n",
" prompt= PromptTemplate(\n",
" input_variables=[\"prompt\"],\n",
" template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。\n",
" 描述:{prompt}。\n",
" 并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
" #SELECT * FROM table#\n",
" #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
" 注意不要输出分析过程和其他内容直接给出SQL语句。\n",
" '''\n",
" )\n",
" # print(schema_info)\n",
" chain = LLMChain(llm=llm, prompt=prompt)\n",
" query = chain.run(\n",
" {\"prompt\":text}\n",
" )\n",
" # print(query)\n",
" sql = getSql(query)\n",
" # result = execute_query(sql)\n",
" return sql"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:30.524029800Z",
"start_time": "2024-11-15T12:08:30.209010600Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SQL</th>\n",
" <th>Question</th>\n",
" <th>table</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>select count( * ) as total from tb_process whe...</td>\n",
" <td>数据库tb_process表中有多少未删除的数据</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>select tags from tb_process where product_syst...</td>\n",
" <td>在数据库tb_process表中product_system_id为12983111603...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,pro...</td>\n",
" <td>在数据库tb_process表中process_type为unit_process且p...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
" <td>在数据库tb_process表中id为1300755365750636544的记录的id、...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
" <td>在数据库tb_process表中id为9237375的记录的id、ref_id、name、...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" SQL \n",
"0 select count( * ) as total from tb_process whe... \\\n",
"1 select tags from tb_process where product_syst... \n",
"2 select id,ref_id,name,synonyms,category_id,pro... \n",
"3 select id,ref_id,name,synonyms,category_id,des... \n",
"4 select id,ref_id,name,synonyms,category_id,des... \n",
"\n",
" Question table \n",
"0 数据库tb_process表中有多少未删除的数据 tb_process \n",
"1 在数据库tb_process表中product_system_id为12983111603... tb_process \n",
"2 在数据库tb_process表中process_type为unit_process且p... tb_process \n",
"3 在数据库tb_process表中id为1300755365750636544的记录的id、... tb_process \n",
"4 在数据库tb_process表中id为9237375的记录的id、ref_id、name、... tb_process "
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.read_excel(r\"./data/sql_ques_clear.xlsx\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:43.896102500Z",
"start_time": "2024-11-15T12:08:31.101109600Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}\n",
"SQL: SELECT COUNT(*) FROM tb_process\n",
"res: SELECT COUNT(*) FROM tb_process\n"
]
}
],
"source": [
"ques_test = \"数据库tb_process表中有多少未删除的数据\"\n",
"tmp_dict = dict()\n",
"tmp_dict[\"tb_process\"] = schema[\"tb_process\"]\n",
"print(tmp_dict)\n",
"res = query_database(ques_test,tmp_dict)\n",
"print(\"res:\",res)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:13:59.277771300Z",
"start_time": "2024-11-15T12:08:43.881680800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SQL: SELECT COUNT(*) FROM tb_process\n",
"SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'\n",
"SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' AND id!='1300755365750636544' AND is_deleted!='1'\n",
"SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1\n",
"SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'\n",
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC\n",
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000\n",
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000\n",
"SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'\n",
"SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')\n",
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'\n",
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'\n",
"SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'\n",
"SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'\n",
"SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')\n",
"SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'\n",
"SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'\n",
"SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id\n",
"SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'\n",
"SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')\n",
"SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'\n",
"SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY sort ASC\n",
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'\n",
"SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'\n",
"SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'\n",
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'\n",
"SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'\n",
"SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'\n",
"SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id\n",
"SQL: SELECT * FROM tb_process_doc WHERE copyright='1'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'\n",
"SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'\n",
"SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'\n",
"SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'\n",
"SQL: SELECT * FROM tb_process WHERE deleted='1'\n",
"SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'\n",
"SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'\n",
"SQL: SELECT * FROM tb_process WHERE category_id='66'\n",
"SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id\n",
"SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'\n",
"SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'\n",
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'\n",
"SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type\n",
"SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'\n",
"SQL: SELECT * FROM tb_process WHERE node_type='2'\n",
"SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version\n",
"SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'\n",
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'\n",
"SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'\n",
"SQL: SELECT * FROM tb_process WHERE process_source='1'\n",
"SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'\n",
"SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'\n"
]
}
],
"source": [
"sql_gold = []\n",
"sql_pred = []\n",
"\n",
"for index,item in df.iterrows():\n",
" prompt = item['Question']\n",
" table = item['table']\n",
" gold = item['SQL']\n",
" \n",
" if table not in table_names:\n",
" print(table)\n",
" continue\n",
" if 'join' in gold:\n",
" continue\n",
" colum_type = get_column_types(cur,table)\n",
" tmp = dict()\n",
" tmp[table] = schema[table]\n",
" sql = query_database(prompt,tmp)\n",
" sql_pred.append(sql)\n",
" sql_gold.append(gold)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:01.460453800Z",
"start_time": "2024-11-15T12:27:01.429826700Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"66"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(sql_gold)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:03.001823100Z",
"start_time": "2024-11-15T12:27:02.925201900Z"
}
},
"outputs": [],
"source": [
"def save_txt(save_list,path):\n",
" with open(path,\"w\",encoding=\"utf-8\") as f:\n",
" for item in save_list:\n",
" f.write(item+\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:03.826791400Z",
"start_time": "2024-11-15T12:27:03.758307800Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"save_txt(sql_gold,\"./data/Qianfan/gold.txt\")\n",
"save_txt(sql_pred,\"./data/Qianfan/pred.txt\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T03:32:33.063956100Z",
"start_time": "2024-11-13T03:26:33.872560100Z"
}
},
"source": [
"### 改路径!!!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T01:15:39.906977700Z",
"start_time": "2024-11-13T01:15:39.866974300Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"!python --version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 执行准确率"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T10:37:42.149121600Z",
"start_time": "2024-11-13T10:37:42.077659900Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"'last_login_time'\", \"'version'\", \"'deleted'\", \"'creator'\", \"'create_time'\", \"'updater'\", \"'update_time'\", \"'user_type'\", \"'person_auth_id'\", \"'enterprise_auth_id'\", \"'id'\", \"'gender'\", \"'org_id'\", \"'super_admin'\", \"'status'\", \"'username'\", \"'password'\", \"'real_name'\", \"'avatar'\", \"'open_id'\", \"'email'\", \"'mobile'\", \"'nickname'\", \"'id_number'\", \"'remark'\"]\n"
]
}
],
"source": [
"print(schema['sys_user'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Qwen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}