{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:05:02.254040400Z", "start_time": "2024-11-15T08:05:00.664449900Z" } }, "outputs": [], "source": [ "import os\n", "from langchain_community.llms import QianfanLLMEndpoint,Tongyi\n", "from langchain_community.chat_models import ChatZhipuAI\n", "from langchain.chains import LLMChain\n", "from langchain.prompts import PromptTemplate\n", "import psycopg2\n", "import re" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T09:27:27.129813200Z", "start_time": "2024-11-15T09:27:27.095824600Z" } }, "outputs": [], "source": [ "# 百度千帆\n", "# API key / secret key are read from the environment (e.g. `export QIANFAN_AK=...`)\n", "# so real credentials are never saved into the notebook file.\n", "api = os.environ[\"QIANFAN_AK\"]\n", "sk = os.environ[\"QIANFAN_SK\"]\n", "\n", "llm = QianfanLLMEndpoint(model=\"ERNIE-4.0-8K\",qianfan_ak=api,qianfan_sk=sk)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:05:13.561075500Z", "start_time": "2024-11-15T08:05:13.451555700Z" }, "collapsed": false }, "outputs": [], "source": [ "# 千问\n", "# os.environ['DASHSCOPE_API_KEY'] = 'sk-xxxxxxxxx'\n", "# \n", "# llm = Tongyi(model_name=\"qwen-max\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:05:18.389436200Z", "start_time": "2024-11-15T08:05:15.487698800Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。\n" ] } ], "source": [ "input_text = \"用50个字左右阐述,生命的意义在于\"\n", "response = llm.invoke(input_text)\n", "print(response)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:11:05.408878300Z", "start_time": "2024-11-15T08:11:05.273677300Z" } }, "outputs": [], "source": [ "# 数据库连接参数\n", "# The password is read from the PGPASSWORD environment variable instead of being stored here.\n", "dbname = \"gis_lca\"\n", "user = \"postgres\"\n", "password = os.environ[\"PGPASSWORD\"]\n", "host = \"localhost\" # 或者是你 Docker 容器的 IP 地址,如果你在不同的机器上\n", "port = \"5432\"\n", "\n", "# 连接字符串\n", "conn_string = 
f\"host={host} dbname={dbname} user={user} password={password} port={port}\"\n", "# 连接到数据库\n", "conn = psycopg2.connect(conn_string)\n", "cur = conn.cursor()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:11:53.090176Z", "start_time": "2024-11-15T08:11:53.041140900Z" }, "collapsed": false }, "outputs": [], "source": [ "# 获取表名\n", "def get_table_name(cur):\n", "    \"\"\"Return the names of all tables in the `public` schema, in the order the database returns them.\"\"\"\n", "    cur.execute(\"\"\"\n", "        SELECT table_name \n", "        FROM information_schema.tables \n", "        WHERE table_schema = 'public';\n", "    \"\"\")\n", "    # Each fetched row is a 1-tuple (table_name,).\n", "    return [row[0] for row in cur.fetchall()]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:13:58.272398Z", "start_time": "2024-11-15T08:13:58.238434200Z" } }, "outputs": [], "source": [ "# 获取数据库列名。\n", "def get_table_columns(cur,table_name):\n", "    \"\"\"Return the column names of `table_name` in the `public` schema.\n", "\n", "    The table name is sent as a bound query parameter instead of being\n", "    formatted into the SQL text, so a name containing a quote cannot break\n", "    (or inject into) the query. Only `public` is searched, matching the\n", "    tables that `get_table_name` lists. On a database error the message is\n", "    printed, the failed transaction is rolled back so later queries on this\n", "    connection keep working, and an empty list is returned.\n", "    \"\"\"\n", "    try:\n", "        cur.execute(\"\"\"\n", "            SELECT column_name \n", "            FROM information_schema.columns \n", "            WHERE table_schema = 'public' AND table_name = %s;\n", "        \"\"\", (table_name,))\n", "        return [row[0] for row in cur.fetchall()]\n", "    except Exception as e:\n", "        print(f\"An error occurred: {e}\")\n", "        # psycopg2 leaves the transaction aborted after a failed statement; without\n", "        # a rollback every later execute() on this connection would fail as well.\n", "        cur.connection.rollback()\n", "        return []" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:13:59.601786Z", "start_time": "2024-11-15T08:13:59.206151400Z" }, "collapsed": false }, "outputs": [], "source": [ "# 获取数据库表结构,表名对应列名\n", "table_names = get_table_name(cur)\n", "schema = {name: get_table_columns(cur,name) for name in table_names}" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:14:02.392006100Z", "start_time": 
"2024-11-15T08:14:02.359287Z" }, "collapsed": false }, "outputs": [], "source": [ "def get_column_types(cur, table_name):\n", "    \"\"\"Return `(column_name, data_type)` rows for `table_name` in the `public` schema.\n", "\n", "    Restricting to `public` keeps the result consistent with `get_table_name`,\n", "    which only lists public tables; without it a same-named table in another\n", "    schema would have its columns mixed into the result.\n", "    \"\"\"\n", "    cur.execute(\"\"\"\n", "        SELECT column_name, data_type\n", "        FROM information_schema.columns\n", "        WHERE table_schema = 'public' AND table_name = %s;\n", "    \"\"\", (table_name,))\n", "    return cur.fetchall()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:14:04.182996600Z", "start_time": "2024-11-15T08:14:04.117421900Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]\n" ] } ], "source": [ "col = get_column_types(cur,\"tb_process\")\n", "print(col)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T08:14:05.075951300Z", "start_time": "2024-11-15T08:14:05.044963600Z" } }, "outputs": [], "source": [ 
"def getSql(response):\n", "    \"\"\"Extract the SQL statement the model wrapped in `#...#`.\n", "\n", "    The prompt asks for answers shaped like `#SELECT ...#`; if several `#...#`\n", "    spans are present the last one is used. `re.DOTALL` lets a statement span\n", "    several lines, and line breaks are folded into single spaces so each SQL\n", "    fits on one line of the saved pred.txt. If the model ignored the markers,\n", "    the stripped raw response is returned instead of raising IndexError\n", "    (which would abort the whole evaluation loop).\n", "    \"\"\"\n", "    sql_list = re.findall(r\"#(.*?)#\", response, flags=re.DOTALL)\n", "    sql_str = sql_list[-1] if sql_list else response\n", "    sql_str = re.sub(r\"\\s*\\n\\s*\", \" \", sql_str).strip()\n", "    print(\"SQL:\",sql_str)\n", "\n", "    return sql_str" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:08:29.507242700Z", "start_time": "2024-11-15T12:08:29.440042800Z" } }, "outputs": [], "source": [ "def query_database(text, schema, with_schema=False):\n", "    \"\"\"Ask the global `llm` to translate the question `text` into SQL.\n", "\n", "    `schema` maps table name -> list of column names. It is put into the\n", "    prompt only when `with_schema=True`; the default reproduces the\n", "    schema-free prompt behind the saved results, where the model has to\n", "    guess column names. Returns the SQL extracted by `getSql`.\n", "    \"\"\"\n", "    if with_schema:\n", "        # One \"Table <name>: col1, col2, ...\" line per table.\n", "        schema_info = \"\\n\".join([f\"Table {table}: {', '.join(columns)}\" for table, columns in schema.items()])\n", "        prompt = PromptTemplate(\n", "            input_variables=[\"schema_info\",\"prompt\"],\n", "            template='''以下是数据库的表结构信息:{schema_info},分析表结构信息.\n", "        请根据以下描述生成一个符合SQLite数据库语法的SQL查询,并且不能修改给出的数据表列名。\n", "        描述:{prompt}。\n", "        要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n", "        #SELECT * FROM table#\n", "        #SELECT COUNT(*) FROM table where Colume_name='abc'#\n", "        注意不要输出分析过程和其他内容,直接给出SQL语句。\n", "        '''\n", "        )\n", "        inputs = {\"schema_info\": schema_info, \"prompt\": text}\n", "    else:\n", "        prompt = PromptTemplate(\n", "            input_variables=[\"prompt\"],\n", "            template='''请根据以下描述生成一个符合SQLite数据库语法的SQL查询。\n", "        描述:{prompt}。\n", "        并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n", "        #SELECT * FROM table#\n", "        #SELECT COUNT(*) FROM table where Colume_name='abc'#\n", "        注意不要输出分析过程和其他内容,直接给出SQL语句。\n", "        '''\n", "        )\n", "        inputs = {\"prompt\": text}\n", "    chain = LLMChain(llm=llm, prompt=prompt)\n", "    query = chain.run(inputs)\n", "    return getSql(query)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:08:30.524029800Z", "start_time": "2024-11-15T12:08:30.209010600Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SQLQuestiontable
0select count( * ) as total from tb_process whe...数据库tb_process表中有多少未删除的数据?tb_process
1select tags from tb_process where product_syst...在数据库tb_process表中,product_system_id为12983111603...tb_process
2select id,ref_id,name,synonyms,category_id,pro...在数据库tb_process表中,process_type为‘unit_process’且p...tb_process
3select id,ref_id,name,synonyms,category_id,des...在数据库tb_process表中,id为1300755365750636544的记录的id、...tb_process
4select id,ref_id,name,synonyms,category_id,des...在数据库tb_process表中,id为9237375的记录的id、ref_id、name、...tb_process
\n", "
" ], "text/plain": [ " SQL \n", "0 select count( * ) as total from tb_process whe... \\\n", "1 select tags from tb_process where product_syst... \n", "2 select id,ref_id,name,synonyms,category_id,pro... \n", "3 select id,ref_id,name,synonyms,category_id,des... \n", "4 select id,ref_id,name,synonyms,category_id,des... \n", "\n", " Question table \n", "0 数据库tb_process表中有多少未删除的数据? tb_process \n", "1 在数据库tb_process表中,product_system_id为12983111603... tb_process \n", "2 在数据库tb_process表中,process_type为‘unit_process’且p... tb_process \n", "3 在数据库tb_process表中,id为1300755365750636544的记录的id、... tb_process \n", "4 在数据库tb_process表中,id为9237375的记录的id、ref_id、name、... tb_process " ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_excel(r\"./data/sql_ques_clear.xlsx\")\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:08:43.896102500Z", "start_time": "2024-11-15T12:08:31.101109600Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tb_process': ['share', 'deleted', 'update_time', 'create_time', 'create_user', 'reference_process_id', 'parent_id', 'node_type', 'extend', 'id', 'version', 'last_change', 'category_id', 'infrastructure_process', 'quantitative_reference_id', 'location_id', 'process_doc_id', 'dq_system_id', 'exchange_dq_system_id', 'social_dq_system_id', 'last_internal_id', 'currency_id', 'product_system_id', 'process_source', 'node', 'ref_id', 'name', 'default_allocation_method', 'synonyms', 'dq_entry', 'tags', 'library', 'description', 'process_type']}\n", "SQL: SELECT COUNT(*) FROM tb_process\n", "res: SELECT COUNT(*) FROM tb_process\n" ] } ], "source": [ "# Smoke test: one question, with the schema narrowed to the single table it is about.\n", "ques_test = \"数据库tb_process表中有多少未删除的数据?\"\n", "tmp_dict = {\"tb_process\": schema[\"tb_process\"]}\n", "print(tmp_dict)\n", "res = query_database(ques_test, tmp_dict)\n", "print(\"res:\", res)" ] }, { 
"cell_type": "code", "execution_count": 43, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:13:59.277771300Z", "start_time": "2024-11-15T12:08:43.881680800Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SQL: SELECT COUNT(*) FROM tb_process\n", "SQL: SELECT tags FROM tb_process WHERE product_system_id='1298311160348545024' AND is_deleted <> '1'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE process_type='unit_process' AND process_source='0'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='1300755365750636544'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id='9237375'\n", "SQL: SELECT COUNT(*) FROM tb_process WHERE name='铸铁' AND product_system_id='1300754248748761088' 
AND id!='1300755365750636544' AND is_deleted!='1'\n", "SQL: SELECT * FROM tb_process WHERE ref_id='49e7b885-8db7-3d63-8fce-5e32b8a71391' AND process_type='unit_process' LIMIT 1\n", "SQL: SELECT location_id FROM tb_process WHERE id='3448019508185273680'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1276473778162892800' AND deleted='0'\n", "SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n", "SQL: SELECT id, name, node_type, tags, process_type, node FROM tb_process WHERE parent_id='1250448580720721920' AND deleted<>1\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE id IN ('1176899308570542080','1176898365472899072')\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, 
update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE synonyms LIKE '%欧元%' AND process_source='0'\n", "SQL: SELECT id, ref_id, name, synonyms, category_id, description, process_type, default_allocation_method, infrastructure_process, quantitative_reference_id, dq_entry, dq_system_id, exchange_dq_system_id, social_dq_system_id, location_id, process_doc_id, currency_id, version, last_change, tags, library, product_system_id, process_source, reference_process_id, node, parent_id, node_type, deleted, create_time, update_time, create_user, extend, share FROM tb_process WHERE product_system_id='1298206895525330944' AND (node_type='0' OR node_type='1') AND deleted='0' AND parent_id IS NULL ORDER BY id DESC\n", "SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n", "SQL: SELECT COUNT(*) FROM tb_flow WHERE deleted='0'\n", "SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend FROM tb_flow WHERE deleted='0' LIMIT 20000\n", "SQL: SELECT id, ref_id, name, flow_type, cas_number, description, tags, library, synonyms, infrastructure_flow, formula, category_id, reference_flow_property_id, source, version, last_change, create_time, create_user, update_user, deleted, update_time, location_id, name_cn, synonyms_cn, is_system, extend 
FROM tb_flow WHERE deleted='0' LIMIT 20000 OFFSET 60000\n", "SQL: SELECT COUNT(*) AS total FROM tb_calculation_task WHERE product_system_id='1298311160348545024'\n", "SQL: INSERT INTO tb_calculation_task (id, status, product_system_id, final_demand, method_id, method_name, calculate_type, create_user) VALUES ('1301153780011634688', '0', '1276473778162892800', 'JSON字符串', '44645931', 'ipcc 2013 gwp 100a', 'lci,lcia,contribution,sensitivity', '10078')\n", "SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE creator='10043' AND is_deleted<>'1'\n", "SQL: SELECT COUNT(*) AS total FROM tb_product_system WHERE deleted='0'\n", "SQL: SELECT id, ref_id, name, description, location_id, category_id, cutoff, reference_process_id, reference_exchange_id, target_amount, target_flow_property_factor_id, target_unit_id, model_info, project_type, create_user, create_time, update_time, deleted FROM tb_product_system WHERE id='1298311160348545024'\n", "SQL: SELECT COUNT(*) AS total FROM sys_user WHERE delete_flag = '0'\n", "SQL: SELECT id FROM sys_user WHERE org_id IN ('11', '12')\n", "SQL: SELECT id, product_system_id, form_id, data, extend FROM tb_form_data WHERE product_system_id='1247208720564224000' AND form_id='4'\n", "SQL: SELECT id FROM tb_form_data WHERE form_id='14' AND product_system_id='1299311046711840768'\n", "SQL: SELECT id, name, method_id, direction, reference_unit FROM tb_impact_category WHERE method_id='43262324' ORDER BY id\n", "SQL: SELECT SUM(amount) AS total FROM tb_lcia WHERE task_id='1298311273208877056' AND impact_category_id='43262501'\n", "SQL: SELECT COUNT(*) FROM (SELECT * FROM tb_parameter WHERE create_user = '10028' AND scope = 'global' UNION ALL SELECT * FROM tb_parameter WHERE process_id = '3448019511314224627' AND create_user = '10028')\n", "SQL: SELECT id, type, name, sort, template, script, script_url, extend FROM tb_product_system_form WHERE id='14'\n", "SQL: SELECT id, name, sort, extend FROM tb_product_system_form WHERE type='cbam' ORDER BY 
sort ASC\n", "SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2020-01-01'\n", "SQL: SELECT * FROM tb_process_doc WHERE reviewer_id='1212'\n", "SQL: SELECT COUNT(*) FROM tb_process_doc WHERE project='asdfasdf'\n", "SQL: SELECT * FROM tb_process_doc WHERE valid_from > '2019-01-01' AND valid_until < '2022-12-31'\n", "SQL: SELECT * FROM tb_process_doc WHERE geography='17 European plants'\n", "SQL: SELECT * FROM tb_process_doc WHERE technology LIKE '%Acid-catalysed esterification of free-fatty acids.%'\n", "SQL: SELECT reviewer_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY reviewer_id\n", "SQL: SELECT * FROM tb_process_doc WHERE copyright='1'\n", "SQL: SELECT * FROM tb_process_doc WHERE data_generator_id='1319'\n", "SQL: SELECT dataset_owner_id, COUNT(*) as record_count FROM tb_process_doc GROUP BY dataset_owner_id\n", "SQL: SELECT * FROM tb_process_doc WHERE data_documentor_id='1217'\n", "SQL: SELECT * FROM tb_process_doc WHERE publication_id='1500'\n", "SQL: SELECT * FROM tb_process_doc WHERE preceding_dataset='6006a8fc-1a6c-55ee-a49d-fc9342f5d451'\n", "SQL: SELECT * FROM tb_process_doc WHERE data_treatment LIKE '%Approximation%'\n", "SQL: SELECT * FROM tb_process_doc WHERE data_collection_period LIKE '%2020-2021%'\n", "SQL: SELECT * FROM tb_process WHERE deleted='1'\n", "SQL: SELECT * FROM tb_process WHERE update_time > '2024-10-15'\n", "SQL: SELECT COUNT(*) FROM tb_process WHERE create_user='10000'\n", "SQL: SELECT * FROM tb_process WHERE category_id='66'\n", "SQL: SELECT location_id, COUNT(*) FROM tb_process GROUP BY location_id\n", "SQL: SELECT * FROM tb_process WHERE dq_system_id='129124'\n", "SQL: SELECT * FROM tb_process WHERE tags LIKE '%10煤电%'\n", "SQL: SELECT * FROM tb_process WHERE default_allocation_method='economic allocation'\n", "SQL: SELECT process_type, COUNT(*) as count FROM tb_process GROUP BY process_type\n", "SQL: SELECT * FROM tb_process WHERE name LIKE '%AAAAsadfasd%'\n", "SQL: SELECT * FROM tb_process WHERE node_type='2'\n", 
"SQL: SELECT version, COUNT(*) as count FROM tb_process GROUP BY version\n", "SQL: SELECT * FROM tb_process WHERE process_doc_id='594641'\n", "SQL: SELECT * FROM tb_process WHERE default_allocation_method='CAUSAL'\n", "SQL: SELECT * FROM tb_process WHERE create_time BETWEEN '2022-01-01' AND '2023-01-01'\n", "SQL: SELECT * FROM tb_process WHERE process_source='1'\n", "SQL: SELECT * FROM tb_process WHERE ref_id='d4730e0a-9bae-3fde-93e1-cac96735aa76'\n", "SQL: SELECT * FROM tb_process WHERE infrastructure_process='0'\n" ] } ], "source": [ "sql_gold = []\n", "sql_pred = []\n", "\n", "for _, item in df.iterrows():\n", "    prompt = item['Question']\n", "    table = item['table']\n", "    gold = item['SQL']\n", "\n", "    # Only single-table questions on tables that exist in the database are evaluated.\n", "    if table not in table_names:\n", "        print(table)\n", "        continue\n", "    # Case-insensitive, so gold SQL written with `JOIN` is skipped as well.\n", "    if 'join' in str(gold).lower():\n", "        continue\n", "    try:\n", "        sql = query_database(prompt, {table: schema[table]})\n", "    except Exception as e:\n", "        # One failed LLM call should not throw away the rest of a long run;\n", "        # the pair is skipped, so sql_gold and sql_pred stay aligned.\n", "        print(f\"Failed on {prompt!r}: {e}\")\n", "        continue\n", "    sql_pred.append(sql)\n", "    sql_gold.append(gold)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:27:01.460453800Z", "start_time": "2024-11-15T12:27:01.429826700Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "66" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(sql_gold)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:27:03.001823100Z", "start_time": "2024-11-15T12:27:02.925201900Z" } }, "outputs": [], "source": [ "def save_txt(save_list,path):\n", "    \"\"\"Write `save_list` to `path` (UTF-8), one item per line.\n", "\n", "    gold.txt and pred.txt are written from parallel lists, so line i of each\n", "    must belong to the same question. Line breaks inside an item are therefore\n", "    replaced by spaces; a multi-line SQL would otherwise shift every later pair.\n", "    The parent directory is created if it does not exist yet.\n", "    \"\"\"\n", "    os.makedirs(os.path.dirname(path) or \".\", exist_ok=True)\n", "    with open(path,\"w\",encoding=\"utf-8\") as f:\n", "        for item in save_list:\n", "            f.write(\" \".join(str(item).splitlines())+\"\\n\")" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "ExecuteTime": { "end_time": "2024-11-15T12:27:03.826791400Z", "start_time": "2024-11-15T12:27:03.758307800Z" }, "collapsed": false }, "outputs": [], "source": [ "save_txt(sql_gold,\"./data/Qianfan/gold.txt\")\n", 
"save_txt(sql_pred,\"./data/Qianfan/pred.txt\")" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2024-11-13T03:32:33.063956100Z", "start_time": "2024-11-13T03:26:33.872560100Z" } }, "source": [ "### 改路径!!!\n", "\n", "切换模型后(例如从千帆换成千问),请先修改上面 `save_txt` 的保存目录(当前为 `./data/Qianfan/`),以免覆盖已有的 gold/pred 结果。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2024-11-13T01:15:39.906977700Z", "start_time": "2024-11-13T01:15:39.866974300Z" }, "collapsed": false }, "outputs": [], "source": [ "!python --version" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 执行准确率" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(schema['sys_user'])" ] } ], "metadata": { "kernelspec": { "display_name": "Qwen", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }