# onnx_gemma3_pipeline.py
import os

import numpy as np
import onnxruntime
import torch
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import models
from transformers import AutoTokenizer

# ONNX pipeline for Gemma3 embedding model
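# Tokenizes with the Gemma3 tokenizer, runs the exported encoder through ONNX
# Runtime, then applies pooling, two Dense projections and L2 normalization to
# reproduce the full sentence-transformers stack.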
model_dir = "embeddinggemma-300m"
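# The tokenizer comes from the Hugging Face checkpoint; the ONNX encoder graph
# is loaded from the local export directory below.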
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
device = "cuda" if torch.cuda.is_available() else "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(
    model_dir,
    file_name="model.onnx"
).to(device)

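# Drop-in replacement for the sentence-transformers Transformer module:
# tokenizes the input and runs the ONNX encoder instead of a PyTorch model.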
class ONNXTransformer:
    def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
        self.onnx_model = onnx_model
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def encode(self, sentences):
        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True,
                                truncation=True, max_length=self.max_seq_length)
        # Keep all inputs on the same device as the ONNX model.
        inputs = {k: v.to(self.onnx_model.device) for k, v in inputs.items()}
        input_ids = inputs['input_ids']
        sequence_length = input_ids.shape[1]
        # Build explicit position_ids for the exported graph.
        position_ids = torch.arange(sequence_length, device=input_ids.device)[None, :].expand(input_ids.shape[0], sequence_length)
        inputs['position_ids'] = position_ids
        with torch.no_grad():
            outputs = self.onnx_model(**inputs)
        # Return the attention mask too, so pooling can ignore padding tokens.
        return outputs.last_hidden_state, inputs['attention_mask']

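# Reassemble the remaining sentence-transformers modules (1_Pooling, 2_Dense,
# 3_Dense, 4_Normalize) around the ONNX encoder.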
modules = []
onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
modules.append(onnx_transformer)
for idx, name in [(1, "Pooling"), (2, "Dense"), (3, "Dense"), (4, "Normalize")]:
    module_path = os.path.join(model_dir, f"{idx}_{name}")
    if name == "Pooling":
        # Load the pooling config exported by sentence-transformers.
        modules.append(models.Pooling.load(module_path))
    elif name == "Dense":
        # Run the Dense layers with ONNX Runtime sessions.
        dense_onnx_path = os.path.join(model_dir, "onnx", f"dense{idx-1}.onnx")
        modules.append(onnxruntime.InferenceSession(dense_onnx_path, providers=["CPUExecutionProvider"]))
    elif name == "Normalize":
        modules.append(models.Normalize())

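# Chains the modules, converting between torch tensors and numpy arrays as the
# PyTorch and ONNX Runtime stages require.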
class ONNXSentenceTransformer:
    def __init__(self, modules):
        self.modules = modules

    def encode(self, sentences):
        features, attention_mask = self.modules[0].encode(sentences)
        for module in self.modules[1:]:
            if isinstance(module, models.Pooling):
                # Pool with the real attention mask so padding tokens are excluded.
                features = module({'token_embeddings': features, 'attention_mask': attention_mask})['sentence_embedding']
            elif isinstance(module, onnxruntime.InferenceSession):
                # The Dense ONNX graphs expect shape [1, in_features], so run each embedding separately.
                if isinstance(features, torch.Tensor):
                    features = features.detach().cpu().numpy()
                outputs = []
                for vec in features:
                    ort_inputs = {module.get_inputs()[0].name: vec.reshape(1, -1)}
                    out = module.run(None, ort_inputs)[0]
                    outputs.append(out.squeeze(0))
                features = np.stack(outputs, axis=0)
            elif isinstance(module, models.Normalize):
                # Normalize still runs in PyTorch.
                if not isinstance(features, torch.Tensor):
                    features = torch.from_numpy(features)
                features = module({'sentence_embedding': features})['sentence_embedding']
        if isinstance(features, torch.Tensor):
            return features.detach().cpu().numpy()
        return features

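# Note: the per-vector loop in encode() assumes the Dense graphs were exported
# with a fixed batch dimension of 1; if they have a dynamic batch axis, a
# single session.run() call over the whole batch would be equivalent and faster.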
onnx_st = ONNXSentenceTransformer(modules)
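# Since the pipeline ends with a Normalize step, cosine similarity reduces to
# a plain dot product; it is computed explicitly here for clarity.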
def cosine_similarity(a, b):
    a = a.flatten()
    b = b.flatten()
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

if __name__ == "__main__":
    words = ["apple", "banana", "car"]
    embeddings = onnx_st.encode(words)
    print(embeddings)
    for idx, embedding in enumerate(embeddings):
        print(f"Embedding {idx+1}: {embedding.shape}")
    print("\nCosine similarities:")
    print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
    print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
    print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")