import numpy as np
import torch
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import models
from transformers import AutoTokenizer

# Load tokenizer and ONNX model
model_path = "./embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
device = "cuda" if torch.cuda.is_available() else "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_path).to(device)


class ONNXSentenceTransformer:
    """Encode sentences with an ONNX feature-extraction model and mean pooling.

    The wrapper tokenizes text, runs the model, and mean-pools the returned
    token embeddings into one vector per sentence.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer and build the mean-pooling layer.

        Args:
            model: Callable model whose output exposes ``last_hidden_state``.
            tokenizer: Tokenizer called with ``return_tensors="pt"``.
        """
        self.model = model
        self.tokenizer = tokenizer
        # NOTE(review): hard-coded; assumed to match the model's hidden
        # size -- confirm against the exported model's config.
        self.word_embedding_dimension = 768
        self.pooling = models.Pooling(
            word_embedding_dimension=self.word_embedding_dimension,
            pooling_mode_mean_tokens=True,
        )

    def encode(self, sentences, batch_size=32):
        """Return one pooled embedding per input sentence as a NumPy array.

        Args:
            sentences: A single string or a sequence of strings.
            batch_size: Number of sentences sent through the model at once.

        Returns:
            A 2-D array with one row per sentence. For an empty input the
            result has shape ``(0, word_embedding_dimension)``.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        # torch.cat() raises on an empty list, so answer empty input directly.
        # NOTE(review): float32 is assumed to match the model output dtype.
        if len(sentences) == 0:
            return np.empty((0, self.word_embedding_dimension), dtype=np.float32)

        embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            inputs = self.tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True
            )

            # Explicit 0..seq_len-1 positions for every row of the batch,
            # created on the same device as the token ids.
            input_ids = inputs['input_ids']
            batch_rows, sequence_length = input_ids.shape[0], input_ids.shape[1]
            position_ids = torch.arange(
                sequence_length, device=input_ids.device
            )[None, :].expand(batch_rows, sequence_length)
            inputs['position_ids'] = position_ids

            with torch.no_grad():
                outputs = self.model(**inputs)

            last_hidden = outputs.last_hidden_state
            # The mask must live on the same device as the hidden states
            # before the pooling layer combines them.
            attention_mask = inputs['attention_mask'].to(last_hidden.device)
            features = {
                'token_embeddings': last_hidden,
                'attention_mask': attention_mask,
            }
            pooled = self.pooling(features)['sentence_embedding']
            embeddings.append(pooled)

        return torch.cat(embeddings, dim=0).cpu().detach().numpy()


# Usage example
onnx_st = ONNXSentenceTransformer(onnx_model, tokenizer)
words = ["apple", "banana", "car"]
embeddings = onnx_st.encode(words)
print(embeddings)
for idx, embedding in enumerate(embeddings):
    print(f"Embedding {idx+1}: {embedding.shape}")


# Cosine similarity demonstration
def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    Both inputs are flattened first. If either vector has zero norm the
    similarity is undefined; 0.0 is returned instead of dividing by zero.
    """
    a = np.asarray(a).flatten()
    b = np.asarray(b).flatten()
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return np.dot(a, b) / norm_product


print("\nCosine similarities:")
print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")