{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text-to-SQL over the `gis_lca` PostgreSQL database\n",
    "\n",
    "An LLM (Baidu Qianfan ERNIE by default, Qwen/Tongyi optional) turns natural-language questions into SQL; the database's table and column names are introspected from `information_schema`.\n",
    "\n",
    "Required environment variables: `QIANFAN_AK` and `QIANFAN_SK` (Qianfan API keys), `PGPASSWORD` (database password)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:05:02.254040400Z",
     "start_time": "2024-11-15T08:05:00.664449900Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_community.llms import QianfanLLMEndpoint,Tongyi\n",
    "from langchain_community.chat_models import ChatZhipuAI\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "import psycopg2\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T09:27:27.129813200Z",
     "start_time": "2024-11-15T09:27:27.095824600Z"
    }
   },
   "outputs": [],
   "source": [
    "# Baidu Qianfan (ERNIE). The API keys come from the environment so that no\n",
    "# credentials are stored in the notebook; a missing variable raises KeyError.\n",
    "api = os.environ[\"QIANFAN_AK\"]\n",
    "sk = os.environ[\"QIANFAN_SK\"]\n",
    "\n",
    "llm = QianfanLLMEndpoint(model=\"ERNIE-4.0-8K\", qianfan_ak=api, qianfan_sk=sk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:05:13.561075500Z",
     "start_time": "2024-11-15T08:05:13.451555700Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Qwen (Tongyi) alternative: uncomment to use it instead of Qianfan.\n",
    "# Tongyi picks up its key from the DASHSCOPE_API_KEY environment variable;\n",
    "# export it before starting Jupyter instead of hard-coding it here.\n",
    "# \n",
    "# llm = Tongyi(model_name=\"qwen-max\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:05:18.389436200Z",
     "start_time": "2024-11-15T08:05:15.487698800Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生命的意义在于追求个人成长与幸福,同时为他人和社会贡献价值,实现自我与世界的和谐共存。\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: one direct call to the configured LLM.\n",
    "input_text = \"用50个字左右阐述,生命的意义在于\"\n",
    "response = llm.invoke(input_text)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:11:05.408878300Z",
     "start_time": "2024-11-15T08:11:05.273677300Z"
    }
   },
   "outputs": [],
   "source": [
    "# Database connection parameters\n",
    "dbname = \"gis_lca\"\n",
    "user = \"postgres\"\n",
    "# The password is read from the environment, never stored in the notebook.\n",
    "password = os.environ[\"PGPASSWORD\"]\n",
    "host = \"localhost\"  # or the Docker container's IP address if the DB runs on another machine\n",
    "port = \"5432\"\n",
    "\n",
    "# Connection string. make_dsn quotes/escapes every value, so a password with\n",
    "# spaces, quotes or backslashes still yields a valid libpq DSN (a plain f-string does not).\n",
    "conn_string = psycopg2.extensions.make_dsn(host=host, dbname=dbname, user=user, password=password, port=port)\n",
    "# Connect to the database\n",
    "conn = psycopg2.connect(conn_string)\n",
    "# The queries in this notebook only read (SELECT). Autocommit stops one failed\n",
    "# statement from leaving the session in an aborted transaction, in which every\n",
    "# later query would fail until a rollback.\n",
    "conn.autocommit = True\n",
    "cur = conn.cursor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:11:53.090176Z",
     "start_time": "2024-11-15T08:11:53.041140900Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Get the table names\n",
    "def get_table_name(cur):\n",
    "    \"\"\"Return the names of all tables (and views) in the 'public' schema.\n",
    "\n",
    "    Args:\n",
    "        cur: an open psycopg2 cursor.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: table names, sorted alphabetically.\n",
    "    \"\"\"\n",
    "    cur.execute(\"\"\"\n",
    "        SELECT table_name\n",
    "        FROM information_schema.tables\n",
    "        WHERE table_schema = 'public'\n",
    "        ORDER BY table_name;\n",
    "    \"\"\")\n",
    "    # Each row is a 1-tuple (table_name,)\n",
    "    return [row[0] for row in cur.fetchall()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:13:58.272398Z",
     "start_time": "2024-11-15T08:13:58.238434200Z"
    }
   },
   "outputs": [],
   "source": [
    "# Get the column names of one table\n",
    "def get_table_columns(cur, table_name, table_schema=\"public\"):\n",
    "    \"\"\"Return the column names of `table_name`, in column-definition order.\n",
    "\n",
    "    Args:\n",
    "        cur: an open psycopg2 cursor.\n",
    "        table_name: table to describe (sent as a bind parameter).\n",
    "        table_schema: schema the table lives in; defaults to 'public', the\n",
    "            schema get_table_name lists. Without this filter, same-named\n",
    "            tables in other schemas would have their columns mixed in.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: column names; empty if the query fails (the error is printed).\n",
    "    \"\"\"\n",
    "    # Initialised up front so the error path returns [] instead of raising\n",
    "    # UnboundLocalError at `return columns`.\n",
    "    columns = []\n",
    "    try:\n",
    "        # Bind parameters instead of an f-string: no quoting bugs, no SQL injection.\n",
    "        cur.execute(\"\"\"\n",
    "            SELECT column_name\n",
    "            FROM information_schema.columns\n",
    "            WHERE table_schema = %s AND table_name = %s\n",
    "            ORDER BY ordinal_position;\n",
    "        \"\"\", (table_schema, table_name))\n",
    "        columns = [row[0] for row in cur.fetchall()]\n",
    "    except Exception as e:\n",
    "        print(f\"An error occurred: {e}\")\n",
    "    return columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:13:59.601786Z",
     "start_time": "2024-11-15T08:13:59.206151400Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Database structure: table name -> list of its column names\n",
    "table_names = get_table_name(cur)\n",
    "schema = {name: get_table_columns(cur, name) for name in table_names}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:14:02.392006100Z",
     "start_time":
"2024-11-15T08:14:02.359287Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Get the (column name, data type) pairs of one table\n",
    "def get_column_types(cur, table_name, table_schema=\"public\"):\n",
    "    \"\"\"Return the (column_name, data_type) rows of one table.\n",
    "\n",
    "    Args:\n",
    "        cur: an open psycopg2 cursor.\n",
    "        table_name: table to describe (sent as a bind parameter).\n",
    "        table_schema: schema the table lives in; defaults to 'public', the\n",
    "            schema get_table_name lists. Without this filter, same-named\n",
    "            tables in other schemas would have their columns mixed in.\n",
    "\n",
    "    Returns:\n",
    "        list of (column_name, data_type) tuples in column-definition order.\n",
    "    \"\"\"\n",
    "    cur.execute(\"\"\"\n",
    "        SELECT column_name, data_type\n",
    "        FROM information_schema.columns\n",
    "        WHERE table_schema = %s AND table_name = %s\n",
    "        ORDER BY ordinal_position;\n",
    "    \"\"\", (table_schema, table_name))\n",
    "    return cur.fetchall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:14:04.182996600Z",
     "start_time": "2024-11-15T08:14:04.117421900Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('description', 'text'), ('process_type', 'character varying')]\n"
     ]
    }
   ],
   "source": [
    "col = get_column_types(cur, \"tb_process\")\n",
    "print(col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T08:14:05.075951300Z",
     "start_time": "2024-11-15T08:14:05.044963600Z"
    }
   },
   "outputs": [],
   "source": [
"def getSql(response):\n",
    "    \"\"\"Extract the SQL statement the LLM wrapped in '#...#' markers.\n",
    "\n",
    "    The prompt asks the model to answer as #SELECT ...#. If the reply\n",
    "    contains several such spans, the last non-empty one is used.\n",
    "\n",
    "    Args:\n",
    "        response: raw text returned by the LLM.\n",
    "\n",
    "    Returns:\n",
    "        str: the SQL between the last pair of '#' markers, whitespace-stripped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the reply contains no non-empty '#...#' span.\n",
    "    \"\"\"\n",
    "    # DOTALL lets a query the model spread over several lines still match;\n",
    "    # blank spans (e.g. the empty match inside a markdown '##') are dropped.\n",
    "    spans = re.findall(r\"#(.*?)#\", response, flags=re.DOTALL)\n",
    "    candidates = [span.strip() for span in spans if span.strip()]\n",
    "    if not candidates:\n",
    "        raise ValueError(f\"No #...# delimited SQL found in LLM response: {response!r}\")\n",
    "    sql_str = candidates[-1]\n",
    "    print(\"SQL:\",sql_str)\n",
    "\n",
    "    return sql_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T12:08:29.507242700Z",
     "start_time": "2024-11-15T12:08:29.440042800Z"
    }
   },
   "outputs": [],
   "source": [
    "def query_database(text, schema):\n",
    "    \"\"\"Ask the LLM to translate a natural-language question into PostgreSQL SQL.\n",
    "\n",
    "    Args:\n",
    "        text: the natural-language question (may name the table, e.g. tb_process).\n",
    "        schema: dict mapping table name -> list of column names.\n",
    "\n",
    "    Returns:\n",
    "        str: the SQL extracted from the model's reply by getSql.\n",
    "    \"\"\"\n",
    "    # Describe only the tables the question names (falling back to every table)\n",
    "    # so the model sees the real column names while the prompt stays short.\n",
    "    relevant = {table: columns for table, columns in schema.items() if table in text} or schema\n",
    "    schema_info = \"\\n\".join(f\"Table {table}: {', '.join(columns)}\" for table, columns in relevant.items())\n",
    "    # The target database is PostgreSQL (queried through psycopg2), so ask for\n",
    "    # that dialect rather than SQLite.\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"schema_info\", \"prompt\"],\n",
    "        template='''以下是数据库的表结构信息:\n",
    "{schema_info}\n",
    "请根据以下描述生成一个符合PostgreSQL数据库语法的SQL查询,并且不能修改给出的数据表列名。\n",
    "描述:{prompt}。\n",
    "并且要求输出的SQL以#开头,以#结尾,参数类型一律按照字符串处理,样例如下:\n",
    "#SELECT * FROM table#\n",
    "#SELECT COUNT(*) FROM table where Column_name='abc'#\n",
    "注意不要输出分析过程和其他内容,直接给出SQL语句。\n",
    "'''\n",
    "    )\n",
    "    chain = LLMChain(llm=llm, prompt=prompt)\n",
    "    query = chain.run({\"schema_info\": schema_info, \"prompt\": text})\n",
    "    return getSql(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-15T12:08:30.524029800Z",
     "start_time": "2024-11-15T12:08:30.209010600Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", " | SQL | \n", "Question | \n", "table | \n", "
---|---|---|---|
0 | \n", "select count( * ) as total from tb_process whe... | \n", "数据库tb_process表中有多少未删除的数据? | \n", "tb_process | \n", "
1 | \n", "select tags from tb_process where product_syst... | \n", "在数据库tb_process表中,product_system_id为12983111603... | \n", "tb_process | \n", "
2 | \n", "select id,ref_id,name,synonyms,category_id,pro... | \n", "在数据库tb_process表中,process_type为‘unit_process’且p... | \n", "tb_process | \n", "
3 | \n", "select id,ref_id,name,synonyms,category_id,des... | \n", "在数据库tb_process表中,id为1300755365750636544的记录的id、... | \n", "tb_process | \n", "
4 | \n", "select id,ref_id,name,synonyms,category_id,des... | \n", "在数据库tb_process表中,id为9237375的记录的id、ref_id、name、... | \n", "tb_process | \n", "