from transformers import AutoTokenizer, AutoModel
import torch


def load_model(path):
    """Load a tokenizer and embedding model from *path* and put the model in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModel.from_pretrained(path)
    model.eval()
    return tokenizer, model


def embedding(tokenizer, model, sentences):
    """Encode sentences into normalized sentence embeddings.

    Args:
        tokenizer (PreTrainedTokenizer): the tokenizer
        model (PreTrainedModel): the embedding model
        sentences (list[str]): the sentences to encode

    Returns:
        list[list[float]]: one embedding per sentence, each a list of 1024 floats
    """
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling. In this case, CLS pooling: take the hidden state of the first token.
    sentence_embeddings = model_output[0][:, 0]
    # L2-normalize the embeddings so that dot products equal cosine similarity.
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.cpu().numpy().tolist()
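

# A minimal usage sketch (illustrative, not part of the original file): the
# model path and sentences below are placeholders. Any BGE-style encoder whose
# CLS embedding is 1024-dimensional (e.g. a bge-large variant) matches the
# docstring above.
if __name__ == "__main__":
    tokenizer, model = load_model("/path/to/embedding-model")  # placeholder path
    vectors = embedding(tokenizer, model, ["hello world", "sentence embeddings"])
    print(len(vectors), len(vectors[0]))  # e.g. 2 1024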