# File viewer metadata: 110 lines, 3.7 KiB, Python
from transformers import AutoTokenizer, AutoModel
|
|||
|
import torch
|
|||
|
import os
|
|||
|
import logging
|
|||
|
from typing import List, Union, Optional
|
|||
|
|
|||
|
# Configure logging for the whole process: INFO level, with each record
# rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
class EmbeddingModel:
    """Text embedding model wrapper.

    Loads a pretrained tokenizer/model pair and exposes
    :meth:`get_embeddings`, which turns text into L2-normalised vectors.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer and the model stored at *model_path*.

        Args:
            model_path (str): Path of the pretrained model.

        Raises:
            Exception: Any error raised while loading is logged and then
                re-raised unchanged.
        """
        self.model_path = model_path
        try:
            logger.info(f"正在从{model_path}加载模型")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.eval()  # switch the model to evaluation mode
            logger.info("模型加载成功")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def get_embeddings(self, sentences: Union[str, List[str]]) -> List[List[float]]:
        """Encode one sentence or a list of sentences into unit-length vectors.

        Args:
            sentences (Union[str, List[str]]): A single sentence or a list
                of sentences.

        Returns:
            List[List[float]]: The L2-normalised embedding vectors as nested
            Python lists. An empty input list yields an empty list.

        Raises:
            Exception: Any error raised during tokenisation or the forward
                pass is logged and then re-raised unchanged.
        """
        # Wrap a lone string so the rest of the method only deals with lists.
        if isinstance(sentences, str):
            sentences = [sentences]

        # Nothing to encode: answer directly instead of pushing an empty
        # batch through the tokenizer and the model.
        if not sentences:
            return []

        try:
            encoded_input = self.tokenizer(
                sentences,
                padding=True,         # pad to a common length
                truncation=True,      # cut over-long inputs
                return_tensors='pt',  # PyTorch tensors
            )

            # Run on the GPU when one is available, otherwise on the CPU.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(device)
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

            with torch.no_grad():  # no gradients needed for inference
                model_output = self.model(**encoded_input)
                # Take position 0 along dim 1 of the first model output
                # (described as the [CLS] token vector by the original
                # author — depends on the tokenizer in use).
                embeddings = model_output[0][:, 0]

            # Scale every row to unit L2 norm.
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            # Tensor.tolist() produces nested Python floats directly, so the
            # intermediate NumPy array (and its dtype limits) is not needed.
            return embeddings.cpu().tolist()

        except Exception as e:
            logger.error(f"生成向量编码时出错: {str(e)}")
            raise
|
|||
|
|
|||
|
def main() -> int:
    """Demonstrate the embedding model on one sentence and on a small batch.

    The model path is read from the ``MODEL_PATH`` environment variable and
    falls back to a hard-coded default path when the variable is not set.

    Returns:
        int: 0 when both demo encodings succeed, 1 when any step raises.
    """
    model_path = os.environ.get(
        "MODEL_PATH",
        "/home/zhangxj/models/bge-large-zh-v1.5",
    )

    try:
        embedding_model = EmbeddingModel(model_path)

        # Example 1: encode a single sentence.
        single_sentence = "金刚烷"
        embeddings = embedding_model.get_embeddings(single_sentence)
        logger.info(f"单个句子向量维度: {len(embeddings[0])}")

        # Example 2: encode several sentences in one batch.
        sentences = [
            "The weather is so nice!",
            "It's so sunny outside!",
            "He drove to the stadium.",
        ]
        batch_embeddings = embedding_model.get_embeddings(sentences)
        logger.info(f"多个句子向量形状: ({len(batch_embeddings)}, {len(batch_embeddings[0])})")

    except Exception as e:
        # logger.exception keeps the traceback, which the plain message
        # alone does not carry; the failure is also reported to the caller.
        logger.exception(f"主函数执行出错: {str(e)}")
        return 1

    return 0
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Use whatever main() returns as the process exit status. A return
    # value of None makes SystemExit exit with status 0.
    raise SystemExit(main())
|