@lru_cache()
def _embed_sentence(sentence):
    """Clean one sentence, keep its nouns, and embed the result.

    The cache lives on this single-string helper because ``lru_cache``
    requires hashable arguments; a list passed to a cached function raises
    ``TypeError`` before the body runs.

    Args:
        sentence: A single sentence (``str``).

    Returns:
        Whatever ``model.get_embeddings`` returns for one string: a list
        holding one embedding vector.
    """
    if has_no_chinese(sentence):
        # No Chinese characters found: use the English pipeline.
        clean_text = preprocess_eng(sentence)
        processed_text = get_noun_en(clean_text)
    else:
        # At least one Chinese character: use the Chinese pipeline.
        clean_text = preprocess_zh(sentence)
        processed_text = get_noun_zh(clean_text)

    # Cleaning / noun extraction may remove everything; fall back to the
    # raw input so the model still receives some text.
    if not processed_text.strip():
        processed_text = sentence

    return model.get_embeddings(processed_text)


def process_and_embed(sentence):
    """Embed one sentence or a list of sentences.

    Args:
        sentence: Either a single ``str`` or an iterable of ``str``. The
            parameter name is kept for backward compatibility.

    Returns:
        A list of embedding vectors: one row for a single string, one row
        per element (in order) for a list. Each element is processed and
        cached independently.
    """
    if isinstance(sentence, str):
        return _embed_sentence(sentence)

    rows = []
    for item in sentence:
        rows.extend(_embed_sentence(item))
    return rows


# Keep the cache-introspection helpers that the previously decorated
# function exposed.
process_and_embed.cache_info = _embed_sentence.cache_info
process_and_embed.cache_clear = _embed_sentence.cache_clear


@app.route('/embedding/', methods=["POST"])
def run_cls():
    """Handle ``POST /embedding/``.

    Expects a JSON body of the form ``{"sentences": <str or list of str>}``.

    Returns:
        A JSON response. On success the body is ``{"code": 200, "data":
        [...]}``; on a missing, empty or malformed payload it is
        ``{"code": 406, "msg": ...}``. The HTTP status is always 200 and
        the outcome is carried in the body's ``code`` field, which is the
        contract this endpoint has always exposed.
    """
    resp_info = dict()

    # silent=True yields None instead of raising when the body is absent
    # or is not valid JSON.
    payload = request.get_json(silent=True)
    sentences = payload.get('sentences') if isinstance(payload, dict) else None

    if isinstance(sentences, str):
        is_valid = bool(sentences)
    elif isinstance(sentences, list):
        is_valid = bool(sentences) and all(isinstance(s, str) for s in sentences)
    else:
        is_valid = False

    if is_valid:
        logger.info(sentences)
        resp_info["code"] = 200
        resp_info["data"] = process_and_embed(sentences)
    else:
        resp_info["msg"] = "Input is None, please check !"
        resp_info["code"] = 406

    resp = make_response(json.dumps(resp_info))
    # The body is JSON, so advertise it as such.
    resp.headers["Content-Type"] = "application/json"
    resp.status_code = 200
    return resp
class EmbeddingModel:
    """Text-embedding model wrapper.

    Loads a tokenizer and encoder from ``model_path`` and turns sentences
    into L2-normalized vectors taken from the first token position of the
    encoder output.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer and model and place the model on its device.

        Args:
            model_path: Path of the pretrained model to load.

        Raises:
            Exception: Any error raised while loading is logged and
                re-raised unchanged.
        """
        self.model_path = model_path
        try:
            logger.info(f"正在从{model_path}加载模型")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.eval()  # inference mode
            # Move the weights once here instead of on every request.
            self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            logger.info("模型加载成功")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def get_embeddings(self, sentences: Union[str, List[str]]) -> List[List[float]]:
        """Encode one sentence or a batch of sentences.

        Args:
            sentences: A single string or a list of strings.

        Returns:
            One L2-normalized vector (list of floats) per input sentence.
            An empty list or tuple yields ``[]`` without calling the model.

        Raises:
            Exception: Any tokenizer or model error is logged and
                re-raised unchanged.
        """
        # Wrap a lone string so the rest of the method handles one shape.
        if isinstance(sentences, str):
            sentences = [sentences]

        # Nothing to encode: skip the tokenizer, which rejects empty batches.
        if isinstance(sentences, (list, tuple)) and not sentences:
            return []

        try:
            encoded_input = self.tokenizer(
                sentences,
                padding=True,         # pad to the longest sentence in the batch
                truncation=True,      # cut inputs that exceed the model limit
                return_tensors='pt',  # PyTorch tensors
            )

            # Send the inputs to whichever device currently holds the weights.
            device = next(self.model.parameters()).device
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

            with torch.no_grad():  # inference only, no gradients
                model_output = self.model(**encoded_input)
                # First output tensor, first token position of each sequence.
                embeddings = model_output[0][:, 0]

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            return embeddings.cpu().numpy().tolist()

        except Exception as e:
            logger.error(f"生成向量编码时出错: {str(e)}")
            raise


def main():
    """Demonstrate the model on one sentence and on a small batch.

    The model path comes from the ``MODEL_PATH`` environment variable, with
    a hard-coded default. Any failure is logged rather than propagated.
    """
    model_path = os.environ.get(
        "MODEL_PATH",
        "/home/zhangxj/models/bge-large-zh-v1.5"
    )

    try:
        embedding_model = EmbeddingModel(model_path)

        # Example 1: a single sentence.
        single_sentence = "金刚烷"
        embeddings = embedding_model.get_embeddings(single_sentence)
        logger.info(f"单个句子向量维度: {len(embeddings[0])}")

        # Example 2: several sentences at once.
        sentences = [
            "The weather is so nice!",
            "It's so sunny outside!",
            "He drove to the stadium.",
        ]
        batch_embeddings = embedding_model.get_embeddings(sentences)
        logger.info(f"多个句子向量形状: ({len(batch_embeddings)}, {len(batch_embeddings[0])})")

    except Exception as e:
        logger.error(f"主函数执行出错: {str(e)}")
# Punctuation becomes a space so adjacent words are not fused
# ("heat-treated" -> "heat treated"). The apostrophe is dropped instead so
# contractions stay in one piece ("It's" -> "Its").
_PUNCT_TABLE = str.maketrans({**{ch: ' ' for ch in string.punctuation}, "'": None})
_DIGITS_RE = re.compile(r'\d+')
_NON_ALNUM_RE = re.compile(r'[^A-Za-z0-9\s]')
_SPACES_RE = re.compile(r'\s+')

# Code-point ranges treated as "Chinese" by both preprocess_zh and
# has_no_chinese.
_CHINESE_RANGES = (
    r'\u2e80-\u2fdf'          # CJK radicals supplement + Kangxi radicals
    r'\u3100-\u312f'          # Bopomofo
    r'\u31a0-\u31bf'          # Bopomofo extended
    r'\u3400-\u4dbf'          # CJK unified ideographs extension A
    r'\u4e00-\u9fff'          # CJK unified ideographs
    r'\uf900-\ufaff'          # CJK compatibility ideographs
    r'\U00020000-\U000323af'  # supplementary-plane extensions (B and later)
)
_CHINESE_CHAR_RE = re.compile('[' + _CHINESE_RANGES + ']')
_NON_CHINESE_RE = re.compile('[^' + _CHINESE_RANGES + ']+')


def preprocess_eng(text):
    """Clean English text, keeping only ASCII letters separated by single spaces.

    Punctuation (except the apostrophe) and digit runs become spaces, any
    other non-ASCII-alphanumeric character is removed, whitespace runs are
    collapsed, and the ends are trimmed. Case is left unchanged; no
    stemming or lemmatization is applied.

    Args:
        text: Input text; converted with ``str()`` first.

    Returns:
        The cleaned string (possibly empty).
    """
    text = str(text).translate(_PUNCT_TABLE)
    text = _DIGITS_RE.sub(' ', text)
    text = _NON_ALNUM_RE.sub('', text)
    return _SPACES_RE.sub(' ', text).strip()


def preprocess_zh(text):
    """Clean Chinese text, keeping only Chinese characters.

    Everything outside the Chinese code-point ranges (Latin letters,
    digits, punctuation, symbols and whitespace) is removed.

    Args:
        text: Input text; converted with ``str()`` first.

    Returns:
        The remaining Chinese characters joined without separators.
    """
    return _NON_CHINESE_RE.sub('', str(text))


def get_noun_en(text):
    """Extract English nouns from ``text``.

    Args:
        text: Pre-cleaned English text.

    Returns:
        The tokens whose part-of-speech tag starts with ``NN``, joined by
        single spaces; an empty string when there are none.
    """
    if not text.strip():
        return ''
    tagged = pos_tag(word_tokenize(text))
    return ' '.join(word for word, tag in tagged if tag.startswith('NN'))


def get_noun_zh(text):
    """Extract Chinese nouns from ``text``.

    Args:
        text: Pre-cleaned Chinese text. A float NaN (stringified as
            ``'nan'``) is treated as missing.

    Returns:
        The segments whose part-of-speech flag starts with ``n``, joined by
        single spaces; an empty string for empty or NaN input.
    """
    x = str(text)
    if not x or x == 'nan':
        return ''
    # Segment the stringified value so non-str inputs are handled too.
    return ' '.join(word for word, flag in pseg.cut(x) if flag.startswith('n'))


def has_no_chinese(text):
    """Return True when ``text`` contains no Chinese character.

    Args:
        text: Text to inspect; converted with ``str()`` first.

    Returns:
        ``True`` if no character falls in the Chinese code-point ranges
        (including supplementary-plane ideographs), otherwise ``False``.
    """
    return _CHINESE_CHAR_RE.search(str(text)) is None