# embeddinggemma-300m-onnx / pytorch_gemma3_pipeline.py
# Alex Sadleir
# Add ONNX and PyTorch pipelines for Gemma3 embedding model
# d5fb6c3
from sentence_transformers import models
import torch
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction
import numpy as np
# Load the tokenizer and the ONNX model, then place the model on CUDA when
# it is available (CPU otherwise).
# NOTE(review): the tokenizer is fetched by hub id while the model comes from
# a local directory; presumably the local export was made from the same
# checkpoint — confirm the vocabularies match.
model_path = "./embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained(
    "google/embeddinggemma-300m-qat-q4_0-unquantized"
)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_path)
onnx_model = onnx_model.to(device)
class ONNXSentenceTransformer:
    """Minimal SentenceTransformer-style ``encode`` wrapper around a feature-extraction model.

    Sentences are tokenized in batches and run through ``model``. The token
    embeddings are then mean-pooled with a sentence-transformers ``Pooling``
    module, which uses the attention mask so padding is ignored.

    NOTE(review): only mean pooling is applied here. If the reference
    sentence-transformers pipeline for this model has extra modules after
    pooling (e.g. Dense / Normalize), these embeddings will differ from it —
    confirm parity before relying on them.
    """

    def __init__(self, model, tokenizer, word_embedding_dimension=None):
        """Wrap ``model`` and ``tokenizer``.

        Args:
            model: Callable model whose output has a ``last_hidden_state``
                attribute (e.g. an optimum ``ORTModelForFeatureExtraction``).
            tokenizer: Hugging Face tokenizer matching ``model``.
            word_embedding_dimension: Hidden size of ``model``. If ``None``,
                it is read from ``model.config.hidden_size``, falling back
                to 768 when no config is available.
        """
        self.model = model
        self.tokenizer = tokenizer
        if word_embedding_dimension is None:
            config = getattr(model, "config", None)
            word_embedding_dimension = getattr(config, "hidden_size", 768)
        self.word_embedding_dimension = word_embedding_dimension
        self.pooling = models.Pooling(
            word_embedding_dimension=self.word_embedding_dimension,
            pooling_mode_mean_tokens=True,
        )

    def encode(self, sentences, batch_size=32):
        """Embed ``sentences`` and return one pooled vector per sentence.

        Args:
            sentences: A single string or an iterable of strings.
            batch_size: Number of sentences tokenized and run per forward pass.

        Returns:
            A numpy array of shape ``(len(sentences), word_embedding_dimension)``
            in input order. An empty input yields shape ``(0, dim)``.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        sentences = list(sentences)
        if not sentences:
            # torch.cat([]) raises, so return an empty result directly.
            return np.empty((0, self.word_embedding_dimension), dtype=np.float32)

        pooled_batches = [
            self._encode_batch(sentences[start:start + batch_size])
            for start in range(0, len(sentences), batch_size)
        ]
        return torch.cat(pooled_batches, dim=0).detach().cpu().numpy()

    def _encode_batch(self, batch):
        """Tokenize, run and mean-pool one batch; return a ``(len(batch), dim)`` tensor."""
        inputs = dict(
            self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        )
        # Feed inputs on the model's device (e.g. CUDA after ``model.to("cuda")``).
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

        input_ids = inputs["input_ids"]
        batch_len, seq_len = input_ids.shape
        # Position ids are passed explicitly: 0..seq_len-1 for every row.
        inputs["position_ids"] = (
            torch.arange(seq_len, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_len, seq_len)
        )

        with torch.no_grad():
            outputs = self.model(**inputs)
        last_hidden = outputs.last_hidden_state
        features = {
            "token_embeddings": last_hidden,
            "attention_mask": inputs["attention_mask"].to(last_hidden.device),
        }
        return self.pooling(features)["sentence_embedding"]
# Usage example: embed a few words, print the matrix, then each row's shape.
onnx_st = ONNXSentenceTransformer(onnx_model, tokenizer)
words = ["apple", "banana", "car"]
embeddings = onnx_st.encode(words)
print(embeddings)
for position, vector in enumerate(embeddings, start=1):
    print(f"Embedding {position}: {vector.shape}")
# Cosine similarity demonstration
def cosine_similarity(a, b):
    """Return the cosine similarity of ``a`` and ``b`` after flattening both.

    Computed as dot(a, b) / (|a| * |b|); a zero-norm input divides by zero.
    """
    left = a.ravel()
    right = b.ravel()
    denominator = np.linalg.norm(left) * np.linalg.norm(right)
    return np.dot(left, right) / denominator
# Print the cosine similarity of every unordered pair of embedded words.
# Labels and indices are derived from ``words`` so they cannot drift out of
# sync with what was actually embedded.
print("\nCosine similarities:")
for i, first in enumerate(words):
    for j in range(i + 1, len(words)):
        score = cosine_similarity(embeddings[i], embeddings[j])
        print(f"{first} vs {words[j]}: {score:.4f}")