LCA-GPT/LLM-SQL/LLM-sql.ipynb

690 lines
29 KiB
Plaintext
Raw Permalink Normal View History

2024-12-29 16:18:16 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:02.254040400Z",
"start_time": "2024-11-15T08:05:00.664449900Z"
}
},
"outputs": [],
"source": [
"import os\n",
"from langchain_community.llms import QianfanLLMEndpoint,Tongyi\n",
"from langchain_community.chat_models import ChatZhipuAI\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"import psycopg2\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T09:27:27.129813200Z",
"start_time": "2024-11-15T09:27:27.095824600Z"
}
},
"outputs": [],
"source": [
"# 百度千帆\n",
"api = \"xxxxx\"\n",
"sk = \"xxxxxx\"\n",
"\n",
"llm = QianfanLLMEndpoint(model=\"ERNIE-4.0-8K\",qianfan_ak=api,qianfan_sk=sk)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:13.561075500Z",
"start_time": "2024-11-15T08:05:13.451555700Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 千问\n",
"# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'\n",
"# \n",
"# llm = Tongyi(model_name=\"qwen-max\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:05:18.389436200Z",
"start_time": "2024-11-15T08:05:15.487698800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。\n"
]
}
],
"source": [
"input_text = \"用50个字左右阐述生命的意义在于\"\n",
"response = llm.invoke(input_text)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:11:05.408878300Z",
"start_time": "2024-11-15T08:11:05.273677300Z"
}
},
"outputs": [],
"source": [
"# 数据库连接参数\n",
"dbname = \"gis_lca\"\n",
"user = \"postgres\"\n",
"password = \"xxxxxx\"\n",
"host = \"localhost\" # 或者是你 Docker 容器的 IP 地址,如果你在不同的机器上\n",
"port = \"5432\"\n",
"\n",
"# 连接字符串\n",
"conn_string = f\"host={host} dbname={dbname} user={user} password={password} port={port}\"\n",
"# 连接到数据库\n",
"conn = psycopg2.connect(conn_string)\n",
"cur = conn.cursor()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:11:53.090176Z",
"start_time": "2024-11-15T08:11:53.041140900Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 获取表名\n",
"def get_table_name(cur):\n",
" # 执行 SQL 查询以获取所有表的列表\n",
" cur.execute(\"\"\"\n",
" SELECT table_name \n",
" FROM information_schema.tables \n",
" WHERE table_schema = 'public';\n",
" \"\"\")\n",
" \n",
" # 获取查询结果\n",
" tables = cur.fetchall()\n",
" tname = []\n",
" # 打印表名\n",
" for table in tables:\n",
" tname.append(table[0])\n",
" \n",
" return tname"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:13:58.272398Z",
"start_time": "2024-11-15T08:13:58.238434200Z"
}
},
"outputs": [],
"source": [
"# 获取数据库列名。\n",
"def get_table_columns(cur,table_name):\n",
" try:\n",
" # 执行 SQL 查询以获取表的列名\n",
" cur.execute(f\"\"\"\n",
" SELECT column_name \n",
" FROM information_schema.columns \n",
" WHERE table_name = '{table_name}';\n",
" \"\"\")\n",
" # 获取查询结果\n",
" cols = [desc[0] for desc in cur.fetchall()]\n",
" columns = []\n",
" for col in cols:\n",
" # column = \"'\"+col+\"'\"\n",
" columns.append(col)\n",
" except Exception as e:\n",
" print(f\"An error occurred: {e}\")\n",
" return columns"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:13:59.601786Z",
"start_time": "2024-11-15T08:13:59.206151400Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# 获取数据库表结构,表名对应列名\n",
"schema = dict()\n",
"\n",
"table_names = get_table_name(cur)\n",
"for name in table_names:\n",
" schema[name] = get_table_columns(cur,name)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:02.392006100Z",
"start_time": "2024-11-15T08:14:02.359287Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"def get_column_types(cur, table_name):\n",
" # 执行 SQL 查询以获取表的列名和数据类型\n",
" cur.execute(\"\"\"\n",
" SELECT column_name, data_type\n",
" FROM information_schema.columns\n",
" WHERE table_name = %s;\n",
" \"\"\", (table_name,))\n",
"\n",
" # 获取查询结果\n",
" results = cur.fetchall()\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:04.182996600Z",
"start_time": "2024-11-15T08:14:04.117421900Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]\n"
]
}
],
"source": [
"col = get_column_types(cur,\"tb_process\")\n",
"print(col)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T08:14:05.075951300Z",
"start_time": "2024-11-15T08:14:05.044963600Z"
}
},
"outputs": [],
"source": [
"def getSql(response):\n",
" pattern = r\"#(.*?)#\"\n",
" sql_list = re.findall(pattern, response)\n",
" sql_str = sql_list[-1]\n",
" # sql_str = sql_str.lower()\n",
" print(\"SQL:\",sql_str)\n",
"\n",
" return sql_str"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:29.507242700Z",
"start_time": "2024-11-15T12:08:29.440042800Z"
}
},
"outputs": [],
"source": [
"def query_database(text, schema):\n",
" # 将表结构信息包含在提示中\n",
" schema_info = \"\\n\".join([f\"Table {table}: {', '.join(columns)}\" for table, columns in schema.items()])\n",
" # print(schema_info)\n",
" # prompt= PromptTemplate(\n",
" # input_variables=[\"schema_info\",\"prompt\"],\n",
" # template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.\n",
" # 请根据以下描述生成一个符合SQLite数据库语法的SQL查询并且不能修改给出的数据表列名。\n",
" # 描述:{prompt}。\n",
" # 要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
" # #SELECT * FROM table#\n",
" # #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
" # 注意不要输出分析过程和其他内容直接给出SQL语句。\n",
" # '''\n",
" # )\n",
" prompt= PromptTemplate(\n",
" input_variables=[\"prompt\"],\n",
" template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。\n",
" 描述:{prompt}。\n",
" 并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
" #SELECT * FROM table#\n",
" #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
" 注意不要输出分析过程和其他内容直接给出SQL语句。\n",
" '''\n",
" )\n",
" # print(schema_info)\n",
" chain = LLMChain(llm=llm, prompt=prompt)\n",
" query = chain.run(\n",
" {\"prompt\":text}\n",
" )\n",
" # print(query)\n",
" sql = getSql(query)\n",
" # result = execute_query(sql)\n",
" return sql"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:30.524029800Z",
"start_time": "2024-11-15T12:08:30.209010600Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SQL</th>\n",
" <th>Question</th>\n",
" <th>table</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>select count( * ) as total from tb_process whe...</td>\n",
" <td>数据库tb_process表中有多少未删除的数据</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>select tags from tb_process where product_syst...</td>\n",
" <td>在数据库tb_process表中product_system_id为12983111603...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,pro...</td>\n",
" <td>在数据库tb_process表中process_type为unit_process且p...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
" <td>在数据库tb_process表中id为1300755365750636544的记录的id、...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
" <td>在数据库tb_process表中id为9237375的记录的id、ref_id、name、...</td>\n",
" <td>tb_process</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" SQL \n",
"0 select count( * ) as total from tb_process whe... \\\n",
"1 select tags from tb_process where product_syst... \n",
"2 select id,ref_id,name,synonyms,category_id,pro... \n",
"3 select id,ref_id,name,synonyms,category_id,des... \n",
"4 select id,ref_id,name,synonyms,category_id,des... \n",
"\n",
" Question table \n",
"0 数据库tb_process表中有多少未删除的数据 tb_process \n",
"1 在数据库tb_process表中product_system_id为12983111603... tb_process \n",
"2 在数据库tb_process表中process_type为unit_process且p... tb_process \n",
"3 在数据库tb_process表中id为1300755365750636544的记录的id、... tb_process \n",
"4 在数据库tb_process表中id为9237375的记录的id、ref_id、name、... tb_process "
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.read_excel(r\"./data/sql_ques_clear.xlsx\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:08:43.896102500Z",
"start_time": "2024-11-15T12:08:31.101109600Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}\n",
"SQL: SELECT COUNT(*) FROM tb_process\n",
"res: SELECT COUNT(*) FROM tb_process\n"
]
}
],
"source": [
"ques_test = \"数据库tb_process表中有多少未删除的数据\"\n",
"tmp_dict = dict()\n",
"tmp_dict[\"tb_process\"] = schema[\"tb_process\"]\n",
"print(tmp_dict)\n",
"res = query_database(ques_test,tmp_dict)\n",
"print(\"res:\",res)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:13:59.277771300Z",
"start_time": "2024-11-15T12:08:43.881680800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SQL: SELECT COUNT(*) FROM tb_process\n",
"SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'\n",
"SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' AND id!='1300755365750636544' AND is_deleted!='1'\n",
"SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1\n",
"SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'\n",
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'\n",
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC\n",
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000\n",
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000\n",
"SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'\n",
"SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')\n",
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'\n",
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'\n",
"SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'\n",
"SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'\n",
"SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')\n",
"SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'\n",
"SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'\n",
"SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id\n",
"SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'\n",
"SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')\n",
"SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'\n",
"SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY sort ASC\n",
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'\n",
"SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'\n",
"SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'\n",
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'\n",
"SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'\n",
"SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'\n",
"SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id\n",
"SQL: SELECT * FROM tb_process_doc WHERE copyright='1'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'\n",
"SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'\n",
"SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'\n",
"SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'\n",
"SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'\n",
"SQL: SELECT * FROM tb_process WHERE deleted='1'\n",
"SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'\n",
"SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'\n",
"SQL: SELECT * FROM tb_process WHERE category_id='66'\n",
"SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id\n",
"SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'\n",
"SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'\n",
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'\n",
"SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type\n",
"SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'\n",
"SQL: SELECT * FROM tb_process WHERE node_type='2'\n",
"SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version\n",
"SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'\n",
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'\n",
"SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'\n",
"SQL: SELECT * FROM tb_process WHERE process_source='1'\n",
"SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'\n",
"SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'\n"
]
}
],
"source": [
"sql_gold = []\n",
"sql_pred = []\n",
"\n",
"for index,item in df.iterrows():\n",
" prompt = item['Question']\n",
" table = item['table']\n",
" gold = item['SQL']\n",
" \n",
" if table not in table_names:\n",
" print(table)\n",
" continue\n",
" if 'join' in gold:\n",
" continue\n",
" colum_type = get_column_types(cur,table)\n",
" tmp = dict()\n",
" tmp[table] = schema[table]\n",
" sql = query_database(prompt,tmp)\n",
" sql_pred.append(sql)\n",
" sql_gold.append(gold)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:01.460453800Z",
"start_time": "2024-11-15T12:27:01.429826700Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"66"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(sql_gold)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:03.001823100Z",
"start_time": "2024-11-15T12:27:02.925201900Z"
}
},
"outputs": [],
"source": [
"def save_txt(save_list,path):\n",
" with open(path,\"w\",encoding=\"utf-8\") as f:\n",
" for item in save_list:\n",
" f.write(item+\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-15T12:27:03.826791400Z",
"start_time": "2024-11-15T12:27:03.758307800Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"save_txt(sql_gold,\"./data/Qianfan/gold.txt\")\n",
"save_txt(sql_pred,\"./data/Qianfan/pred.txt\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T03:32:33.063956100Z",
"start_time": "2024-11-13T03:26:33.872560100Z"
}
},
"source": [
"### 改路径!!!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T01:15:39.906977700Z",
"start_time": "2024-11-13T01:15:39.866974300Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"!python --version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 执行准确率"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-13T10:37:42.149121600Z",
"start_time": "2024-11-13T10:37:42.077659900Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"'last_login_time'\", \"'version'\", \"'deleted'\", \"'creator'\", \"'create_time'\", \"'updater'\", \"'update_time'\", \"'user_type'\", \"'person_auth_id'\", \"'enterprise_auth_id'\", \"'id'\", \"'gender'\", \"'org_id'\", \"'super_admin'\", \"'status'\", \"'username'\", \"'password'\", \"'real_name'\", \"'avatar'\", \"'open_id'\", \"'email'\", \"'mobile'\", \"'nickname'\", \"'id_number'\", \"'remark'\"]\n"
]
}
],
"source": [
"print(schema['sys_user'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Qwen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}