import numpy as np
import torch
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import models
from transformers import AutoTokenizer

# Local path to the exported ONNX model and the matching tokenizer from the Hub.
model_path = "./embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")

# Use the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_path).to(device)


class ONNXSentenceTransformer:
    """Minimal Sentence Transformers-style wrapper around the ONNX feature-extraction model."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # EmbeddingGemma-300m produces 768-dimensional token embeddings.
        self.word_embedding_dimension = 768
        self.pooling = models.Pooling(
            word_embedding_dimension=self.word_embedding_dimension,
            pooling_mode_mean_tokens=True,
        )

    def encode(self, sentences, batch_size=32):
        if isinstance(sentences, str):
            sentences = [sentences]
        embeddings = []
        for i in range(0, len(sentences), batch_size):
            batch = sentences[i:i + batch_size]
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True)

            # Build explicit position_ids for the padded batch, since the exported model
            # takes them as an input.
            input_ids = inputs['input_ids']
            sequence_length = input_ids.shape[1]
            position_ids = torch.arange(sequence_length)[None, :].expand(input_ids.shape[0], sequence_length)
            inputs['position_ids'] = position_ids

            # Keep all inputs on the same device as the ONNX model (a no-op on CPU).
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
            last_hidden = outputs.last_hidden_state

            # Mean-pool the token embeddings into one sentence embedding per input.
            attention_mask = inputs['attention_mask'].to(last_hidden.device)
            features = {'token_embeddings': last_hidden, 'attention_mask': attention_mask}
            pooled = self.pooling(features)['sentence_embedding']
            embeddings.append(pooled)
        return torch.cat(embeddings, dim=0).detach().cpu().numpy()

onnx_st = ONNXSentenceTransformer(onnx_model, tokenizer)

words = ["apple", "banana", "car"]
embeddings = onnx_st.encode(words)
print(embeddings)
for idx, embedding in enumerate(embeddings):
    print(f"Embedding {idx + 1}: {embedding.shape}")


def cosine_similarity(a, b):
    a = a.flatten()
    b = b.flatten()
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


print("\nCosine similarities:")
print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")
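
# Optional cross-check (not part of the original script): sentence_transformers ships a
# util.cos_sim helper that computes the same pairwise cosine similarities in one call.
# A minimal sketch, assuming `embeddings` is the (3, 768) NumPy array produced above.
from sentence_transformers import util

similarity_matrix = util.cos_sim(embeddings, embeddings)  # torch.Tensor of shape (3, 3)
print("\nPairwise similarity matrix:")
print(similarity_matrix)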