from transformers import AutoTokenizer, AutoModel
import torch
import os
import sys
import logging
from typing import List, Union, Optional

# Configure logging for this script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class EmbeddingModel:
    """
    Text embedding model wrapper.

    Loads a pretrained tokenizer/model pair once (via ``AutoTokenizer`` /
    ``AutoModel``), moves the model to a device once, and exposes
    ``get_embeddings`` which returns L2-normalized first-token ([CLS])
    vectors as plain Python lists.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Load the tokenizer and model and place the model on its device.

        Args:
            model_path (str): Path (or identifier) of the pretrained model,
                passed to ``from_pretrained``.
            device (Optional[str]): Torch device string such as ``"cpu"`` or
                ``"cuda:0"``. When omitted, CUDA is used if available,
                otherwise CPU.

        Raises:
            Exception: any error raised while resolving the device or loading
                the model is logged and re-raised unchanged.
        """
        self.model_path = model_path
        try:
            # Resolve the device once here instead of on every encode call.
            if device is None:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = torch.device(device)

            logger.info("正在从%s加载模型", model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()  # disable dropout etc. for inference
            logger.info("模型加载成功")
        except Exception as e:
            logger.exception("模型加载失败: %s", e)
            raise

    def get_embeddings(
        self,
        sentences: Union[str, List[str]],
        batch_size: Optional[int] = None,
    ) -> List[List[float]]:
        """
        Encode text into L2-normalized embedding vectors.

        Args:
            sentences (Union[str, List[str]]): A single sentence or a list of
                sentences.
            batch_size (Optional[int]): Maximum number of sentences per forward
                pass. ``None`` (default) encodes everything in a single batch.

        Returns:
            List[List[float]]: One normalized vector per input sentence, in
            input order. An empty input list yields ``[]``.

        Raises:
            ValueError: if ``batch_size`` is given and is not positive.
            Exception: tokenizer/model errors are logged and re-raised.
        """
        # Normalize input to a list so single strings and lists share one path.
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)

        # Nothing to encode: avoid calling the tokenizer on an empty batch.
        if not sentences:
            return []

        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        step = batch_size if batch_size is not None else len(sentences)

        try:
            results: List[List[float]] = []
            for start in range(0, len(sentences), step):
                results.extend(self._encode_batch(sentences[start:start + step]))
            return results
        except Exception as e:
            logger.exception("生成向量编码时出错: %s", e)
            raise

    def _encode_batch(self, batch: List[str]) -> List[List[float]]:
        """Tokenize one batch, run the model, and return normalized first-token vectors."""
        encoded_input = self.tokenizer(
            batch,
            padding=True,         # pad to the longest sentence in the batch
            truncation=True,      # truncate to the tokenizer's max length
            return_tensors='pt',  # return PyTorch tensors
        )
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        with torch.no_grad():  # inference only; no gradient tracking
            model_output = self.model(**encoded_input)

        # Take the first token ([CLS]) of the first model output. Assumes
        # model_output[0] is the last hidden state shaped (batch, seq, hidden),
        # as for BERT-style AutoModel outputs -- TODO confirm for other models.
        embeddings = model_output[0][:, 0]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        # Tensor.tolist() yields Python floats directly and, unlike
        # .numpy(), also works for bfloat16 tensors.
        return embeddings.cpu().tolist()


def main() -> int:
    """
    Demonstrate encoding a single sentence and a batch of sentences.

    The model path is read from the ``MODEL_PATH`` environment variable,
    falling back to a hard-coded default.

    Returns:
        int: process exit status, 0 on success and 1 if any step failed.
    """
    model_path = os.environ.get(
        "MODEL_PATH",
        "/home/zhangxj/models/bge-large-zh-v1.5"
    )

    try:
        embedding_model = EmbeddingModel(model_path)

        # Example 1: encode a single sentence.
        single_sentence = "金刚烷"
        embeddings = embedding_model.get_embeddings(single_sentence)
        logger.info("单个句子向量维度: %d", len(embeddings[0]))

        # Example 2: encode several sentences in one call.
        sentences = [
            "The weather is so nice!",
            "It's so sunny outside!",
            "He drove to the stadium.",
        ]
        batch_embeddings = embedding_model.get_embeddings(sentences)
        logger.info("多个句子向量形状: (%d, %d)", len(batch_embeddings), len(batch_embeddings[0]))
    except Exception as e:
        # Log with traceback and report failure through the exit status
        # instead of silently exiting 0.
        logger.exception("主函数执行出错: %s", e)
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())