from transformers import AutoTokenizer, AutoModel
import torch


def load_model(path):
    """Load a tokenizer and embedding model from a local path or model hub id."""
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModel.from_pretrained(path)
    model.eval()  # inference mode: disables dropout etc.
    return tokenizer, model


def embedding(tokenizer, model, sentences):
    """Encode sentences into normalized sentence embeddings.

    Args:
        tokenizer: the tokenizer returned by load_model.
        model: the embedding model returned by load_model.
        sentences: a list of sentences to encode.

    Returns:
        A list of embedding vectors, one per sentence, each of length 1024.
    """
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling. In this case, CLS pooling: take the hidden state of the first token.
    sentence_embeddings = model_output[0][:, 0]
    # L2-normalize the embeddings so cosine similarity reduces to a plain dot product.
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.cpu().numpy().tolist()
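
# Usage sketch. Assumption: the model path is not given in the source; the
# 1024-dim output plus CLS pooling and L2 normalization matches a BGE-style
# encoder such as "BAAI/bge-large-zh-v1.5", so that id is used here as a
# placeholder -- substitute whatever model path you actually load.
if __name__ == "__main__":
    tokenizer, model = load_model("BAAI/bge-large-zh-v1.5")
    vectors = embedding(tokenizer, model, ["how are you", "你好"])
    print(len(vectors), len(vectors[0]))  # expect: 2 1024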