690 lines
29 KiB
Plaintext
690 lines
29 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:05:02.254040400Z",
|
||
"start_time": "2024-11-15T08:05:00.664449900Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import os\n",
"import re\n",
"import sys\n",
"\n",
"import pandas as pd\n",
"import psycopg2\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain_community.chat_models import ChatZhipuAI\n",
"from langchain_community.llms import QianfanLLMEndpoint,Tongyi"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T09:27:27.129813200Z",
|
||
"start_time": "2024-11-15T09:27:27.095824600Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Baidu Qianfan (ERNIE) backend.\n",
"# The key pair is read from the environment so that real credentials never\n",
"# have to be typed into (or committed with) the notebook; the literals are\n",
"# only placeholders used when the variables are not set.\n",
"api = os.environ.get(\"QIANFAN_AK\", \"xxxxx\")\n",
"sk = os.environ.get(\"QIANFAN_SK\", \"xxxxxx\")\n",
"\n",
"llm = QianfanLLMEndpoint(model=\"ERNIE-4.0-8K\", qianfan_ak=api, qianfan_sk=sk)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:05:13.561075500Z",
|
||
"start_time": "2024-11-15T08:05:13.451555700Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
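"# Qwen (Tongyi Qianwen): alternative LLM backend - uncomment the lines below to use it instead of Qianfan.\n",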
|
||
"# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'\n",
|
||
"# \n",
|
||
"# llm = Tongyi(model_name=\"qwen-max\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:05:18.389436200Z",
|
||
"start_time": "2024-11-15T08:05:15.487698800Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Connectivity check: send one short prompt to the configured model and\n",
"# print whatever it answers.\n",
"input_text = \"用50个字左右阐述,生命的意义在于\"\n",
"response = llm.invoke(input_text)\n",
"print(response)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:11:05.408878300Z",
|
||
"start_time": "2024-11-15T08:11:05.273677300Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Database connection parameters.\n",
"dbname = \"gis_lca\"\n",
"user = \"postgres\"\n",
"# The password comes from the environment; the literal is only a placeholder\n",
"# so that no real secret is stored in the notebook.\n",
"password = os.environ.get(\"PGPASSWORD\", \"xxxxxx\")\n",
"host = \"localhost\"  # or the IP address of the Docker container when the DB runs on another machine\n",
"port = \"5432\"\n",
"\n",
"# Pass the settings as keyword arguments instead of assembling a connection\n",
"# string by hand: values containing spaces or quotes would otherwise need\n",
"# escaping, and the password is not copied into an extra string.\n",
"conn = psycopg2.connect(host=host, dbname=dbname, user=user, password=password, port=port)\n",
"cur = conn.cursor()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:11:53.090176Z",
|
||
"start_time": "2024-11-15T08:11:53.041140900Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_table_name(cur):\n",
"    \"\"\"Return every ``table_name`` that ``information_schema.tables`` lists\n",
"    for the ``public`` schema.\n",
"\n",
"    Args:\n",
"        cur: Open database cursor used to run the query.\n",
"\n",
"    Returns:\n",
"        list: The names, in the order the rows were returned.\n",
"    \"\"\"\n",
"    cur.execute(\"\"\"\n",
"        SELECT table_name \n",
"        FROM information_schema.tables \n",
"        WHERE table_schema = 'public';\n",
"    \"\"\")\n",
"\n",
"    # Each fetched row is a one-element tuple; keep only the name itself.\n",
"    return [row[0] for row in cur.fetchall()]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:13:58.272398Z",
|
||
"start_time": "2024-11-15T08:13:58.238434200Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_table_columns(cur,table_name):\n",
"    \"\"\"Return the column names recorded for ``table_name``.\n",
"\n",
"    The lookup filters ``information_schema.columns`` by table name only\n",
"    (no schema filter).\n",
"\n",
"    Args:\n",
"        cur: Open database cursor used to run the query.\n",
"        table_name (str): Name of the table to inspect.\n",
"\n",
"    Returns:\n",
"        list: Column names in the order the rows were returned. An empty\n",
"        list is returned when the query fails; the error is printed.\n",
"    \"\"\"\n",
"    # Defined before the try block so the function still has something to\n",
"    # return when the query raises.\n",
"    columns = []\n",
"    try:\n",
"        # The table name is bound as a query parameter rather than formatted\n",
"        # into the SQL text, so quotes in the name cannot alter the statement.\n",
"        cur.execute(\"\"\"\n",
"            SELECT column_name \n",
"            FROM information_schema.columns \n",
"            WHERE table_name = %s;\n",
"        \"\"\", (table_name,))\n",
"        columns = [row[0] for row in cur.fetchall()]\n",
"    except Exception as e:\n",
"        # Best effort: report the problem and fall through with an empty list.\n",
"        print(f\"An error occurred: {e}\")\n",
"    return columns"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:13:59.601786Z",
|
||
"start_time": "2024-11-15T08:13:59.206151400Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Database structure: map each table name to the list of its column names.\n",
"table_names = get_table_name(cur)\n",
"schema = {name: get_table_columns(cur, name) for name in table_names}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:14:02.392006100Z",
|
||
"start_time": "2024-11-15T08:14:02.359287Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_column_types(cur, table_name):\n",
"    \"\"\"Fetch the name and data type of every column of ``table_name``.\n",
"\n",
"    Args:\n",
"        cur: Open database cursor used to run the query.\n",
"        table_name (str): Table to inspect; bound as a query parameter.\n",
"\n",
"    Returns:\n",
"        The rows exactly as returned by ``cur.fetchall()``, one\n",
"        ``(column_name, data_type)`` pair per column.\n",
"    \"\"\"\n",
"    query = \"\"\"\n",
"        SELECT column_name, data_type\n",
"        FROM information_schema.columns\n",
"        WHERE table_name = %s;\n",
"    \"\"\"\n",
"    cur.execute(query, (table_name,))\n",
"    return cur.fetchall()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:14:04.182996600Z",
|
||
"start_time": "2024-11-15T08:14:04.117421900Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Look at the (column, data type) pairs reported for tb_process.\n",
"col = get_column_types(cur, \"tb_process\")\n",
"print(col)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T08:14:05.075951300Z",
|
||
"start_time": "2024-11-15T08:14:05.044963600Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def getSql(response):\n",
"    \"\"\"Extract the SQL statement the model wrapped in ``#`` markers.\n",
"\n",
"    Args:\n",
"        response (str): Raw text returned by the model.\n",
"\n",
"    Returns:\n",
"        str: The last non-blank ``#...#`` span found on a single line of\n",
"        ``response``, with surrounding whitespace removed.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``response`` contains no non-blank ``#...#`` span.\n",
"    \"\"\"\n",
"    pattern = r\"#(.*?)#\"\n",
"    # Blank spans (for example the \"##\" produced by adjacent markers) are not\n",
"    # SQL, so they are discarded before the last candidate is picked.\n",
"    candidates = [span.strip() for span in re.findall(pattern, response)]\n",
"    candidates = [span for span in candidates if span]\n",
"    if not candidates:\n",
"        raise ValueError(f\"no '#...#' delimited SQL found in model response: {response!r}\")\n",
"\n",
"    sql_str = candidates[-1]\n",
"    print(\"SQL:\",sql_str)\n",
"\n",
"    return sql_str"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:08:29.507242700Z",
|
||
"start_time": "2024-11-15T12:08:29.440042800Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def query_database(text, schema, include_schema=False):\n",
"    \"\"\"Ask the LLM to turn a natural-language description into one SQL statement.\n",
"\n",
"    Both prompt variants request SQLite-syntax SQL wrapped in ``#`` markers.\n",
"    NOTE(review): the rest of this notebook talks to PostgreSQL through\n",
"    psycopg2 - confirm that SQLite is the intended dialect.\n",
"\n",
"    Args:\n",
"        text (str): The question / description to translate.\n",
"        schema (dict): Mapping of table name to its list of column names.\n",
"            Only read when ``include_schema`` is true.\n",
"        include_schema (bool): When true, the table structure is put into the\n",
"            prompt and the model is told not to alter the given column names.\n",
"            The default (False) sends the description only, which is what this\n",
"            function did before the switch existed.\n",
"\n",
"    Returns:\n",
"        str: The SQL that ``getSql`` extracts from the model's reply.\n",
"    \"\"\"\n",
"    if include_schema:\n",
"        # One line per table: \"Table <name>: <col>, <col>, ...\".\n",
"        schema_info = \"\\n\".join(f\"Table {table}: {', '.join(columns)}\" for table, columns in schema.items())\n",
"        prompt_template = PromptTemplate(\n",
"            input_variables=[\"schema_info\",\"prompt\"],\n",
"            template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.\n",
"    请根据以下描述生成一个符合SQLite数据库语法的SQL查询,并且不能修改给出的数据表列名。\n",
"    描述:{prompt}。\n",
"    要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
"    #SELECT * FROM table#\n",
"    #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
"    注意不要输出分析过程和其他内容,直接给出SQL语句。\n",
"    '''\n",
"        )\n",
"        variables = {\"schema_info\": schema_info, \"prompt\": text}\n",
"    else:\n",
"        prompt_template = PromptTemplate(\n",
"            input_variables=[\"prompt\"],\n",
"            template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。\n",
"    描述:{prompt}。\n",
"    并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
"    #SELECT * FROM table#\n",
"    #SELECT COUNT(*) FROM table where Colume_name='abc'#\n",
"    注意不要输出分析过程和其他内容,直接给出SQL语句。\n",
"    '''\n",
"        )\n",
"        variables = {\"prompt\": text}\n",
"\n",
"    chain = LLMChain(llm=llm, prompt=prompt_template)\n",
"    reply = chain.run(variables)\n",
"    return getSql(reply)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:08:30.524029800Z",
|
||
"start_time": "2024-11-15T12:08:30.209010600Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>SQL</th>\n",
|
||
" <th>Question</th>\n",
|
||
" <th>table</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>select count( * ) as total from tb_process whe...</td>\n",
|
||
" <td>数据库tb_process表中有多少未删除的数据?</td>\n",
|
||
" <td>tb_process</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>select tags from tb_process where product_syst...</td>\n",
|
||
" <td>在数据库tb_process表中,product_system_id为12983111603...</td>\n",
|
||
" <td>tb_process</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>select id,ref_id,name,synonyms,category_id,pro...</td>\n",
|
||
" <td>在数据库tb_process表中,process_type为‘unit_process’且p...</td>\n",
|
||
" <td>tb_process</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
|
||
" <td>在数据库tb_process表中,id为1300755365750636544的记录的id、...</td>\n",
|
||
" <td>tb_process</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>select id,ref_id,name,synonyms,category_id,des...</td>\n",
|
||
" <td>在数据库tb_process表中,id为9237375的记录的id、ref_id、name、...</td>\n",
|
||
" <td>tb_process</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" SQL \n",
|
||
"0 select count( * ) as total from tb_process whe... \\\n",
|
||
"1 select tags from tb_process where product_syst... \n",
|
||
"2 select id,ref_id,name,synonyms,category_id,pro... \n",
|
||
"3 select id,ref_id,name,synonyms,category_id,des... \n",
|
||
"4 select id,ref_id,name,synonyms,category_id,des... \n",
|
||
"\n",
|
||
" Question table \n",
|
||
"0 数据库tb_process表中有多少未删除的数据? tb_process \n",
|
||
"1 在数据库tb_process表中,product_system_id为12983111603... tb_process \n",
|
||
"2 在数据库tb_process表中,process_type为‘unit_process’且p... tb_process \n",
|
||
"3 在数据库tb_process表中,id为1300755365750636544的记录的id、... tb_process \n",
|
||
"4 在数据库tb_process表中,id为9237375的记录的id、ref_id、name、... tb_process "
|
||
]
|
||
},
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Evaluation set; the preview shows the columns SQL, Question and table.\n",
"# pandas is imported in the notebook's import cell at the top.\n",
"df = pd.read_excel(\"./data/sql_ques_clear.xlsx\")\n",
"df.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:08:43.896102500Z",
|
||
"start_time": "2024-11-15T12:08:31.101109600Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_process\n",
|
||
"res: SELECT COUNT(*) FROM tb_process\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Try a single question end to end, handing over only the tb_process columns.\n",
"ques_test = \"数据库tb_process表中有多少未删除的数据?\"\n",
"tmp_dict = {\"tb_process\": schema[\"tb_process\"]}\n",
"print(tmp_dict)\n",
"res = query_database(ques_test, tmp_dict)\n",
"print(\"res:\", res)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:13:59.277771300Z",
|
||
"start_time": "2024-11-15T12:08:43.881680800Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"SQL: SELECT COUNT(*) FROM tb_process\n",
|
||
"SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' AND id!='1300755365750636544' AND is_deleted!='1'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1\n",
|
||
"SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'\n",
|
||
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
|
||
"SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'\n",
|
||
"SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n",
|
||
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000\n",
|
||
"SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000\n",
|
||
"SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'\n",
|
||
"SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')\n",
|
||
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'\n",
|
||
"SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'\n",
|
||
"SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'\n",
|
||
"SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'\n",
|
||
"SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')\n",
|
||
"SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'\n",
|
||
"SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'\n",
|
||
"SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id\n",
|
||
"SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'\n",
|
||
"SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')\n",
|
||
"SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'\n",
|
||
"SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY sort ASC\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'\n",
|
||
"SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE copyright='1'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'\n",
|
||
"SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'\n",
|
||
"SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE deleted='1'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'\n",
|
||
"SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE category_id='66'\n",
|
||
"SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id\n",
|
||
"SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'\n",
|
||
"SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type\n",
|
||
"SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE node_type='2'\n",
|
||
"SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version\n",
|
||
"SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE process_source='1'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'\n",
|
||
"SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Run every usable question through the model and collect the reference SQL\n",
"# and the predicted SQL in two lists that stay index-aligned.\n",
"sql_gold = []\n",
"sql_pred = []\n",
"\n",
"for _, item in df.iterrows():\n",
"    question = item['Question']\n",
"    table = item['table']\n",
"    gold = item['SQL']\n",
"\n",
"    # Skip (and report) questions whose table is not in the connected database.\n",
"    if table not in table_names:\n",
"        print(table)\n",
"        continue\n",
"    # Skip reference queries that use a join; compare case-insensitively so an\n",
"    # upper-case JOIN is filtered as well.\n",
"    if 'join' in gold.lower():\n",
"        continue\n",
"\n",
"    sql = query_database(question, {table: schema[table]})\n",
"    sql_pred.append(sql)\n",
"    sql_gold.append(gold)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:27:01.460453800Z",
|
||
"start_time": "2024-11-15T12:27:01.429826700Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"66"
|
||
]
|
||
},
|
||
"execution_count": 44,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Both lists are saved below as parallel one-item-per-line files, so they\n",
"# must have the same length.\n",
"assert len(sql_gold) == len(sql_pred), \"sql_gold and sql_pred are out of sync\"\n",
"len(sql_gold)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:27:03.001823100Z",
|
||
"start_time": "2024-11-15T12:27:02.925201900Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def save_txt(save_list,path):\n",
"    \"\"\"Write the items of ``save_list`` to ``path``, exactly one per line.\n",
"\n",
"    Args:\n",
"        save_list: Iterable of items to write; each one is converted with\n",
"            ``str()``.\n",
"        path (str): Target file. It is overwritten if it exists, and missing\n",
"            parent directories are created.\n",
"    \"\"\"\n",
"    parent = os.path.dirname(path)\n",
"    if parent:\n",
"        os.makedirs(parent, exist_ok=True)\n",
"\n",
"    with open(path,\"w\",encoding=\"utf-8\") as f:\n",
"        for item in save_list:\n",
"            # Line breaks inside an item are collapsed to spaces; otherwise one\n",
"            # item would occupy several lines and shift every line after it.\n",
"            f.write(\" \".join(str(item).splitlines())+\"\\n\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 46,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-15T12:27:03.826791400Z",
|
||
"start_time": "2024-11-15T12:27:03.758307800Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Write the reference and the predicted SQL to two parallel text files.\n",
"save_txt(save_list=sql_gold, path=\"./data/Qianfan/gold.txt\")\n",
"save_txt(save_list=sql_pred, path=\"./data/Qianfan/pred.txt\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-13T03:32:33.063956100Z",
|
||
"start_time": "2024-11-13T03:26:33.872560100Z"
|
||
}
|
||
},
|
||
"source": [
|
||
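"### Note: change the output path (`./data/Qianfan/`) in the cell above before running with a different model - existing files are overwritten!"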
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-13T01:15:39.906977700Z",
|
||
"start_time": "2024-11-13T01:15:39.866974300Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Report the version of the interpreter this kernel runs on; a shell call to\n",
"# `python` reports whichever executable comes first on PATH instead.\n",
"# `sys` is imported in the notebook's import cell at the top.\n",
"print(\"Python\", sys.version.split()[0])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Execution accuracy"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 73,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-11-13T10:37:42.149121600Z",
|
||
"start_time": "2024-11-13T10:37:42.077659900Z"
|
||
},
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[\"'last_login_time'\", \"'version'\", \"'deleted'\", \"'creator'\", \"'create_time'\", \"'updater'\", \"'update_time'\", \"'user_type'\", \"'person_auth_id'\", \"'enterprise_auth_id'\", \"'id'\", \"'gender'\", \"'org_id'\", \"'super_admin'\", \"'status'\", \"'username'\", \"'password'\", \"'real_name'\", \"'avatar'\", \"'open_id'\", \"'email'\", \"'mobile'\", \"'nickname'\", \"'id_number'\", \"'remark'\"]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Show the column list collected for the sys_user table.\n",
"print(schema[\"sys_user\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Qwen",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.14"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|