LCA_LLM_application/Retrieval_new/local_encoder.py

from transformers import AutoTokenizer, AutoModel
import torch
import os
import logging
from typing import List, Union

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class EmbeddingModel:
    """
    Text embedding model.

    Wraps loading a pretrained model and generating sentence embeddings.
    """

    def __init__(self, model_path: str):
        """
        Initialize the embedding model.

        Args:
            model_path (str): Path to the pretrained model.
        """
        self.model_path = model_path
        try:
            logger.info(f"Loading model from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.eval()  # Set the model to evaluation mode
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def get_embeddings(self, sentences: Union[str, List[str]]) -> List[List[float]]:
        """
        Generate embeddings for the input text.

        Args:
            sentences (Union[str, List[str]]): A single sentence or a list of sentences.

        Returns:
            List[List[float]]: A list of L2-normalized embedding vectors.
        """
        # Wrap a single string in a list for uniform handling
        if isinstance(sentences, str):
            sentences = [sentences]
        try:
            # Tokenize the input
            encoded_input = self.tokenizer(
                sentences,
                padding=True,        # Pad to the longest sequence in the batch
                truncation=True,     # Truncate to the model's maximum length
                return_tensors='pt'  # Return PyTorch tensors
            )
            # Use the GPU if one is available
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(device)  # No-op once the model is already on the device
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
            # Generate the embeddings
            with torch.no_grad():  # Disable gradient computation
                model_output = self.model(**encoded_input)
                # Take the [CLS] token representation
                embeddings = model_output[0][:, 0]
                # L2-normalize the vectors
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            # Convert to a plain Python list and return
            return embeddings.cpu().numpy().tolist()
        except Exception as e:
            logger.error(f"Failed to generate embeddings: {str(e)}")
            raise
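

# --- Illustrative helper (not part of the original module) ---
# Because get_embeddings returns L2-normalized vectors, cosine similarity
# reduces to a plain dot product. This is a minimal sketch; the name
# `cosine_similarity` and its signature are assumptions, not an existing API.
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Dot product of two equal-length vectors; equals cosine similarity for unit vectors."""
    return sum(a * b for a, b in zip(vec_a, vec_b))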


def main():
    """
    Entry point demonstrating how to use the model.
    """
    # Read the model path from an environment variable, falling back to a default
    model_path = os.environ.get(
        "MODEL_PATH",
        "/home/zhangxj/models/bge-large-zh-v1.5"
    )
    try:
        # Initialize the model
        embedding_model = EmbeddingModel(model_path)

        # Example 1: embed a single sentence
        single_sentence = "金刚烷"  # "adamantane", used as sample input
        embeddings = embedding_model.get_embeddings(single_sentence)
        logger.info(f"Single-sentence embedding dimension: {len(embeddings[0])}")

        # Example 2: embed multiple sentences
        sentences = [
            "The weather is so nice!",
            "It's so sunny outside!",
            "He drove to the stadium.",
        ]
        batch_embeddings = embedding_model.get_embeddings(sentences)
        logger.info(f"Batch embedding shape: ({len(batch_embeddings)}, {len(batch_embeddings[0])})")
    except Exception as e:
        logger.error(f"Error while running main: {str(e)}")


if __name__ == '__main__':
    main()