"""ONNX Runtime inference pipeline for the EmbeddingGemma sentence-embedding model.

The pipeline chains an ONNX transformer, a sentence-transformers pooling
module, two ONNX dense layers and a normalisation module. All models are
loaded when this module is imported.
"""

from sentence_transformers import models
import torch
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction
import numpy as np
import os
import onnxruntime

# ONNX pipeline for Gemma3 embedding model
model_dir = "embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
device = "cuda" if torch.cuda.is_available() else "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(
    model_dir,
    file_name="model.onnx",
).to(device)


class ONNXTransformer:
    """Tokenize sentences and run them through the ONNX transformer model."""

    def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
        """Store the model, the tokenizer and the truncation length.

        Args:
            onnx_model: Callable model whose output exposes ``last_hidden_state``.
            tokenizer: Tokenizer called with ``return_tensors="pt"``.
            max_seq_length: Maximum number of tokens kept per sentence.
        """
        self.onnx_model = onnx_model
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def encode_with_mask(self, sentences):
        """Return ``(last_hidden_state, attention_mask)`` for ``sentences``.

        The attention mask is the one produced by the tokenizer; callers need
        it to exclude padded positions when pooling token embeddings.
        """
        inputs = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )
        input_ids = inputs["input_ids"]
        batch_size, sequence_length = input_ids.shape[0], input_ids.shape[1]
        # The model is fed explicit position ids: 0..sequence_length-1 for every row.
        position_ids = torch.arange(sequence_length)[None, :].expand(batch_size, sequence_length)
        inputs["position_ids"] = position_ids.to(input_ids.device)
        with torch.no_grad():
            outputs = self.onnx_model(**inputs)
        return outputs.last_hidden_state, inputs["attention_mask"]

    def encode(self, sentences):
        """Return only the token embeddings (``last_hidden_state``) for ``sentences``."""
        return self.encode_with_mask(sentences)[0]


modules = []
onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
modules.append(onnx_transformer)
for idx, name in [(1, "Pooling"), (2, "Dense"), (3, "Dense"), (4, "Normalize")]:
    module_path = os.path.join(model_dir, f"{idx}_{name}")
    if name == "Pooling":
        if os.path.isfile(os.path.join(module_path, "config.json")):
            # Load the exported pooling configuration. Passing the path to the
            # constructor would not read it: the constructor's first positional
            # argument is the embedding dimension, not a directory.
            modules.append(models.Pooling.load(module_path))
        else:
            # No exported config found: keep the previous construction, which
            # relies on Pooling's default settings.
            # NOTE(review): the path ends up in the dimension argument here.
            modules.append(models.Pooling(module_path))
    elif name == "Dense":
        # Use ONNXRuntime for Dense layers
        dense_onnx_path = os.path.join(model_dir, "onnx", f"dense{idx-1}.onnx")
        modules.append(onnxruntime.InferenceSession(dense_onnx_path, providers=["CPUExecutionProvider"]))
    elif name == "Normalize":
        modules.append(models.Normalize())


class ONNXSentenceTransformer:
    """Run a list of pipeline modules in order to turn sentences into embeddings."""

    def __init__(self, modules):
        """Store the pipeline; ``modules[0]`` must provide ``encode(sentences)``."""
        self.modules = modules

    @staticmethod
    def _apply_dense(session, features):
        """Run an ONNX dense layer over each row of ``features`` and stack the results."""
        if isinstance(features, torch.Tensor):
            features = features.cpu().detach().numpy()
        # The input name does not change between rows, so look it up once.
        input_name = session.get_inputs()[0].name
        # ONNX Dense layer expects shape [1, in_features], so process each embedding separately
        rows = [
            session.run(None, {input_name: vec.reshape(1, -1)})[0].squeeze(0)
            for vec in features
        ]
        return np.stack(rows, axis=0)

    def encode(self, sentences):
        """Return the embeddings of ``sentences`` as a NumPy array.

        Pooling uses the tokenizer's attention mask when the first module
        exposes ``encode_with_mask``, so padded positions do not contribute to
        the pooled embedding. First modules that only offer ``encode`` keep
        the previous all-ones mask.
        """
        first = self.modules[0]
        if hasattr(first, "encode_with_mask"):
            features, attention_mask = first.encode_with_mask(sentences)
        else:
            features = first.encode(sentences)
            attention_mask = None

        for module in self.modules[1:]:
            if isinstance(module, models.Pooling):
                if attention_mask is None:
                    attention_mask = torch.ones(features.shape[:2], device=features.device)
                pooled = module({
                    "token_embeddings": features,
                    "attention_mask": attention_mask.to(features.device),
                })
                features = pooled["sentence_embedding"]
            elif isinstance(module, onnxruntime.InferenceSession):
                features = self._apply_dense(module, features)
            elif isinstance(module, models.Normalize):
                # Normalize still uses PyTorch
                if not isinstance(features, torch.Tensor):
                    features = torch.from_numpy(features)
                features = module({"sentence_embedding": features})["sentence_embedding"]

        if isinstance(features, torch.Tensor):
            return features.cpu().detach().numpy()
        return features


onnx_st = ONNXSentenceTransformer(modules)


def cosine_similarity(a, b):
    """Return the cosine similarity of arrays ``a`` and ``b`` after flattening both.

    No guard is applied for zero-norm inputs, which make the denominator zero.
    """
    a = a.flatten()
    b = b.flatten()
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


if __name__ == "__main__":
    words = ["apple", "banana", "car"]
    embeddings = onnx_st.encode(words)
    print(embeddings)
    for idx, embedding in enumerate(embeddings):
        print(f"Embedding {idx+1}: {embedding.shape}")
    print("\nCosine similarities:")
    print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
    print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
    print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")