parent 5029796101
commit 4998a82e39
@@ -0,0 +1,61 @@
import os
import json
from flask import Flask, request, make_response
from logzero import logger
from functools import lru_cache

from utils import *
from local_encoder import EmbeddingModel

# current_path = os.path.dirname(os.path.abspath(__file__))  # for local
current_path = os.getcwd()  # for docker
logger.info(f"{current_path}")

app = Flask(__name__)
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


path = "/home/zhangxj/models/bge-large-zh-v1.5"
model = EmbeddingModel(path)


@lru_cache()
def process_and_embed(sentence):
    # Route by language: does the text contain Chinese characters?
    if has_no_chinese(sentence):
        # English pipeline
        clean_text = preprocess_eng(sentence)
        processed_text = get_noun_en(clean_text)
    else:
        # Chinese pipeline
        clean_text = preprocess_zh(sentence)
        processed_text = get_noun_zh(clean_text)

    # Fall back to the original text if preprocessing left nothing
    if not processed_text.strip():
        processed_text = sentence

    # Encode the processed text into embeddings
    embeddings = model.get_embeddings(processed_text)
    return embeddings


@app.route('/embedding/', methods=["POST"])
def run_cls():
    resp_info = dict()
    if request.method == "POST":
        sentences = request.json.get('sentences')
        if sentences:
            logger.info(sentences)
            resp_info["code"] = 200
            # Note: lru_cache requires hashable arguments, so this endpoint
            # supports a single string, not a list of sentences.
            resp_info["data"] = process_and_embed(sentences)
        else:
            resp_info["msg"] = "Input is None, please check!"
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5163, debug=False)
    # res = process_and_embed("土豆")
    # print(res)
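One caveat worth flagging in this file: process_and_embed is wrapped in functools.lru_cache, which only memoizes hashable arguments, so the endpoint effectively accepts a single string rather than a list. A minimal sketch of a per-string batch helper (embed_batch is illustrative, not part of this commit):

# lru_cache memoizes hashable arguments only:
# process_and_embed("土豆")     # fine; repeat calls hit the cache
# process_and_embed(["土豆"])   # TypeError: unhashable type: 'list'

# Hypothetical helper that accepts a list but still caches per sentence:
def embed_batch(sentences):
    # get_embeddings returns a list of vectors; take the single row per input
    return [process_and_embed(s)[0] for s in sentences]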
@@ -0,0 +1,110 @@
from transformers import AutoTokenizer, AutoModel
import torch
import os
import logging
from typing import List, Union

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class EmbeddingModel:
    """
    Text embedding model wrapper.
    Handles loading the model and generating sentence embeddings.
    """
    def __init__(self, model_path: str):
        """
        Initialize the embedding model.

        Args:
            model_path (str): Path to the pretrained model.
        """
        self.model_path = model_path
        try:
            logger.info(f"Loading model from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.eval()  # switch to inference mode
            # Pick the device once at load time rather than on every call
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def get_embeddings(self, sentences: Union[str, List[str]]) -> List[List[float]]:
        """
        Generate embeddings for the input text.

        Args:
            sentences (Union[str, List[str]]): A single sentence or a list of sentences.

        Returns:
            List[List[float]]: L2-normalized embedding vectors.
        """
        # Wrap a single string in a list for uniform handling
        if isinstance(sentences, str):
            sentences = [sentences]

        try:
            # Tokenize the input
            encoded_input = self.tokenizer(
                sentences,
                padding=True,        # pad to the longest sequence in the batch
                truncation=True,     # truncate to the model's max length
                return_tensors='pt'  # return PyTorch tensors
            )
            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

            # Generate embeddings
            with torch.no_grad():  # no gradients needed for inference
                model_output = self.model(**encoded_input)
                # Use the [CLS] token representation
                embeddings = model_output[0][:, 0]

            # L2-normalize the embeddings
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            # Convert to a plain Python list and return
            return embeddings.cpu().numpy().tolist()

        except Exception as e:
            logger.error(f"Failed to generate embeddings: {str(e)}")
            raise


def main():
    """
    Demonstrate basic usage of the model.
    """
    # Read the model path from the environment, falling back to a default
    model_path = os.environ.get(
        "MODEL_PATH",
        "/home/zhangxj/models/bge-large-zh-v1.5"
    )

    try:
        # Initialize the model
        embedding_model = EmbeddingModel(model_path)

        # Example 1: embed a single sentence
        single_sentence = "金刚烷"  # "adamantane"
        embeddings = embedding_model.get_embeddings(single_sentence)
        logger.info(f"Single-sentence embedding dimension: {len(embeddings[0])}")

        # Example 2: embed a batch of sentences
        sentences = [
            "The weather is so nice!",
            "It's so sunny outside!",
            "He drove to the stadium.",
        ]
        batch_embeddings = embedding_model.get_embeddings(sentences)
        logger.info(f"Batch embedding shape: ({len(batch_embeddings)}, {len(batch_embeddings[0])})")

    except Exception as e:
        logger.error(f"Error in main: {str(e)}")


if __name__ == '__main__':
    main()
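Because get_embeddings() L2-normalizes its output, cosine similarity between two texts reduces to a plain dot product. A minimal sketch, assuming the default model path above is available locally:

import numpy as np

model = EmbeddingModel("/home/zhangxj/models/bge-large-zh-v1.5")
vecs = np.array(model.get_embeddings(["The weather is so nice!",
                                      "It's so sunny outside!"]))
# Dot product of unit vectors equals cosine similarity, in [-1, 1]
print(f"cosine similarity: {float(vecs[0] @ vecs[1]):.4f}")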
@@ -0,0 +1,41 @@
import requests
import json

# API endpoint
API_URL = "http://0.0.0.0:5163/embedding/"

# Payload to send (JSON)
data = {
    "sentences": "土豆"  # single-sentence example; the server's lru_cache expects a string, not a list
}

# Send the POST request
try:
    # Explicitly declare the content type as JSON
    headers = {"Content-Type": "application/json"}

    # POST the payload; data must be serialized to a JSON string
    response = requests.post(
        url=API_URL,
        data=json.dumps(data),  # convert the dict to a JSON string
        headers=headers
    )

    # Check the HTTP status code
    if response.status_code == 200:
        # Parse the returned JSON
        result = response.json()
        print("API Response:", result)

        # Inspect the application-level code in the body
        if result.get("code") == 200:
            embeddings = result.get("data")
            print("Embeddings:", embeddings)
        else:
            print("Error from API:", result.get("msg"))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print("Response text:", response.text)

except requests.exceptions.RequestException as e:
    print(f"Error connecting to API: {e}")
@@ -0,0 +1,87 @@
import re
import string

import nltk
from nltk.tokenize import word_tokenize
from nltk import pos_tag
import jieba.posseg as pseg

# Download the required NLTK data
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')


def preprocess_eng(text):
    '''
    English text preprocessing: remove punctuation, digits, and special
    characters, keeping only words.
    Spelling correctness: assumed, since the data is imported from
    ecoinvent and contains no spelling errors.
    Stemming/lemmatization: considered, but skipped because the results
    are sometimes inaccurate.
    '''
    # Remove punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))
    # Remove digits
    text = re.sub(r'\d+', ' ', text)
    # Remove remaining special characters
    text = re.sub(r'[^A-Za-z0-9\s]', '', text)
    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text)
    return text


def preprocess_zh(text):
    '''
    Chinese text preprocessing: keep only Chinese content; remove
    English, digits, and punctuation.
    '''
    text = str(text)
    # Remove English letters
    text = re.sub(r'[a-zA-Z]', ' ', text)
    # Remove digits
    text = re.sub(r'\d', ' ', text)
    # Remove Chinese punctuation
    text = re.sub(r'[,。!?、;:“”()《》【】-]', ' ', text)
    # Remove English punctuation
    text = re.sub(r'[.,!?;:"\'\(\)\[\]{}]', ' ', text)
    # Remove all whitespace (Chinese text needs no spaces)
    text = re.sub(r'\s+', '', text)

    return text


# English noun extraction
def get_noun_en(text):
    # Tokenize
    words = word_tokenize(text)
    # Part-of-speech tagging
    tagged = pos_tag(words)

    # Keep only nouns (Penn Treebank tags NN, NNS, NNP, NNPS)
    nouns = [word for word, tag in tagged if tag.startswith('NN')]
    noun = ' '.join(nouns)
    return noun


# Chinese noun extraction
def get_noun_zh(text):
    x = str(text)
    if x == 'nan':
        return ''
    words = pseg.cut(text)
    # Keep words whose jieba POS flag starts with 'n' (nouns)
    nouns = [word for word, flag in words if flag.startswith('n')]
    noun = ' '.join(nouns)
    return noun


def has_no_chinese(text):
    """
    Check whether a text contains no Chinese characters.

    Args:
        text (str): Text to check.

    Returns:
        bool: True if the text contains no Chinese characters, False otherwise.
    """
    for char in text:
        if '\u4e00' <= char <= '\u9fff' or \
           '\u3400' <= char <= '\u4dbf' or \
           '\u2f00' <= char <= '\u2fdf' or \
           '\u3100' <= char <= '\u312f' or \
           '\u31a0' <= char <= '\u31bf':
            return False
    return True
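A quick sketch of how these helpers compose, mirroring the language routing in the Flask app's process_and_embed (the sample strings are illustrative):

if __name__ == '__main__':
    for s in ["Electricity production, hard coal", "电力生产,硬煤"]:
        if has_no_chinese(s):
            # English branch: clean, then keep NLTK noun tags
            print(get_noun_en(preprocess_eng(s)))
        else:
            # Chinese branch: clean, then keep jieba noun flags
            print(get_noun_zh(preprocess_zh(s)))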