# embeddinggemma-300m-onnx / gemma3_mean_pooling_basic.py
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import models
import torch
# Load the tokenizer and the ONNX-exported model from the local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained("./embeddinggemma-300m")
model = ORTModelForFeatureExtraction.from_pretrained("./embeddinggemma-300m")

# Tokenize a single word and inspect the encoded inputs
inputs = tokenizer("apple", return_tensors="pt")
print(inputs)
# The ONNX export takes explicit position_ids, so build them as 0..seq_len-1
# for every row in the batch
input_ids = inputs['input_ids']
batch_size, sequence_length = input_ids.shape
position_ids = torch.arange(sequence_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
inputs['position_ids'] = position_ids
# Run the ONNX model and grab the per-token embeddings
outputs = model(**inputs)
last_hidden = outputs.last_hidden_state  # shape: (batch, seq_len, hidden_dim)
attention_mask = inputs['attention_mask']
# Use SentenceTransformer's Pooling module for mean pooling
pooling = models.Pooling(word_embedding_dimension=last_hidden.shape[-1], pooling_mode_mean_tokens=True)
features = {'token_embeddings': last_hidden, 'attention_mask': attention_mask}
pooled = pooling(features)['sentence_embedding']
print("Mean pooled:", pooled[0][:5].detach().cpu().numpy())