local_embedding/local_encoder.py

from transformers import AutoTokenizer, AutoModel
import torch


def load_model(path):
    """Load the tokenizer and embedding model from a local path or hub name."""
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModel.from_pretrained(path)
    model.eval()
    return tokenizer, model


def embedding(tokenizer, model, sentences):
    """Encode sentences into normalized sentence embeddings.

    Args:
        tokenizer: tokenizer returned by load_model
        model: embedding model returned by load_model
        sentences: list of sentences to encode

    Returns:
        list of 1024-dimensional embedding vectors, one per sentence
    """
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling. In this case, CLS pooling: take the hidden state of the first token.
    sentence_embeddings = model_output[0][:, 0]
    # Normalize embeddings to unit length so dot products equal cosine similarity.
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.cpu().numpy().tolist()
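

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of how load_model and embedding fit together.
# The model path below is an assumed placeholder for a 1024-dimensional
# encoder; substitute the actual local model directory this project uses.
if __name__ == "__main__":
    tok, mdl = load_model("/path/to/local/embedding/model")  # assumed path
    vectors = embedding(tok, mdl, ["hello world", "another example sentence"])
    print(len(vectors), len(vectors[0]))  # number of sentences, embedding dimension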