"""Compare word-embedding cosine similarities across three EmbeddingGemma pipelines.

Encodes a few words with:

1. the original SentenceTransformer model in PyTorch,
2. the ``onnx_st`` pipeline from ``onnx_gemma3_pipeline``, and
3. a basic ONNX path that applies only mean pooling to the raw hidden states.

It then prints pairwise cosine similarities for each, so the three can be
compared side by side.
"""

from itertools import combinations

import numpy as np
import torch
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import SentenceTransformer, models
from transformers import AutoTokenizer

# NOTE(review): this module exposes a ready-made ``onnx_st`` object, so
# importing it may load a model as a side effect — confirm.
from onnx_gemma3_pipeline import onnx_st

# Words to compare (every unordered pair is scored).
WORDS = ["apple", "banana", "car"]
ST_MODEL_ID = "google/embeddinggemma-300m-qat-q4_0-unquantized"
ONNX_MODEL_DIR = "./embeddinggemma-300m"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def basic_mean_pooling(words, model_path=ONNX_MODEL_DIR):
    """Embed each word with the raw ONNX model followed by plain mean pooling.

    Only masked mean pooling is applied to ``last_hidden_state``. There is no
    further post-processing, so this serves as a "mean pooling only" baseline.

    Args:
        words: Iterable of strings. Each one is tokenized and run on its own
            (batch size 1), so no padding tokens enter the pooled mean.
        model_path: Directory holding both the tokenizer and the exported
            ONNX model. Defaults to ``"./embeddinggemma-300m"``.

    Returns:
        ``np.ndarray`` with one pooled embedding per word, stacked along
        axis 0.

    Raises:
        ValueError: If ``words`` is empty (raised by ``np.stack``).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = ORTModelForFeatureExtraction.from_pretrained(model_path)
    # Built once on the first output, when the hidden size becomes known.
    pooling = None

    embeddings = []
    for word in words:
        inputs = tokenizer(word, return_tensors="pt")
        batch_size, seq_len = inputs["input_ids"].shape
        # The exported graph is fed explicit 0..seq_len-1 position ids per row.
        inputs["position_ids"] = torch.arange(seq_len, dtype=torch.long).repeat(batch_size, 1)

        last_hidden = model(**inputs).last_hidden_state
        if pooling is None:
            pooling = models.Pooling(
                word_embedding_dimension=last_hidden.shape[-1],
                pooling_mode_mean_tokens=True,
            )
        features = {
            "token_embeddings": last_hidden,
            "attention_mask": inputs["attention_mask"],
        }
        pooled = pooling(features)["sentence_embedding"]
        # Batch size is 1, so row 0 is this word's embedding.
        embeddings.append(pooled[0].detach().cpu().numpy())
    return np.stack(embeddings)


def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors as a Python float.

    Both inputs are flattened first, so any shapes with the same number of
    elements are accepted. Returns ``nan`` when either vector has zero norm,
    rather than letting NumPy emit a divide-by-zero warning.
    """
    a = np.ravel(a)
    b = np.ravel(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return float("nan")
    return float(np.dot(a, b) / denom)


def print_similarities(title, embeddings, labels):
    """Print ``title`` and then the cosine similarity of every label pair.

    Pairs are visited in ``itertools.combinations`` order. For
    ``["apple", "banana", "car"]`` that order is apple/banana, apple/car,
    banana/car. ``embeddings[i]`` must correspond to ``labels[i]``.
    """
    print(title)
    for (i, left), (j, right) in combinations(enumerate(labels), 2):
        print(f"{left} vs {right}: {cosine_similarity(embeddings[i], embeddings[j]):.4f}")


def main():
    """Run all three pipelines on ``WORDS`` and print their similarity tables."""
    # Reference embeddings: original SentenceTransformer model in PyTorch.
    # ``encode`` handles gradient disabling internally.
    st_model = SentenceTransformer(ST_MODEL_ID, device=DEVICE)
    pt_embeddings = st_model.encode(WORDS, convert_to_numpy=True)

    onnx_embeddings = onnx_st.encode(WORDS)

    print_similarities("Safetensor Cosine similarities:", pt_embeddings, WORDS)
    print_similarities("\nONNX Cosine similarities:", onnx_embeddings, WORDS)

    basic_embeddings = basic_mean_pooling(WORDS)
    print_similarities(
        "\nBasic ONNX (mean pooling only) Cosine similarities:", basic_embeddings, WORDS
    )


if __name__ == "__main__":
    main()