Gemma3 Embedding Model: ONNX Conversion Demonstration
This repository demonstrates converting the Gemma3 embedding model from Hugging Face to ONNX format with optimum-onnx and comparing the converted model against the original. It includes scripts for ONNX and PyTorch inference pipelines, as well as a script that compares their outputs.
Files
- onnx_gemma3_pipeline.py: Runs the Gemma3 embedding model using ONNXRuntime, including the post-processing steps (Pooling, Dense, Normalize) with ONNX-exported layers.
- pytorch_gemma3_pipeline.py: Runs the original Gemma3 embedding model using PyTorch and SentenceTransformer for reference.
- compare_gemma3_onnx_vs_pytorch.py: Compares the output embeddings and cosine similarities between the ONNX and PyTorch pipelines.
- download_missing_hf_files.py: Downloads the required files from Hugging Face and exports the Dense layers to ONNX.
- gemma3_mean_pooling_basic.py: The most basic implementation, running Gemma3 ONNX inference with only mean pooling (no Dense or Normalize stages).
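Mean pooling averages the per-token vectors while ignoring padding. The sketch below is a minimal NumPy illustration of that step, not the code in gemma3_mean_pooling_basic.py. It assumes the transformer output is a (batch, seq_len, hidden) array such as last_hidden_state and that padding positions have attention_mask 0. The function name mean_pool is hypothetical.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over non-padding positions (hypothetical helper).

    token_embeddings: (batch, seq_len, hidden), e.g. last_hidden_state
    attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    """
    # Expand the mask so it broadcasts over the hidden dimension
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    # Zero out padding tokens, then sum over the sequence axis
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, hidden)
    # Number of real tokens per sentence; clip to avoid division by zero
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # (batch, 1)
    return summed / counts


# Toy input: 1 sentence, 3 token slots (the last one is padding), hidden size 2
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pool(tokens, mask))  # [[2. 3.]]
```

The padding row [100, 100] does not affect the result because its mask value is 0.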
Pipeline Differences
Both pipelines use ONNXRuntime for transformer inference via ORTModelForFeatureExtraction. The key difference is in post-processing:
- ONNX pipeline (onnx_gemma3_pipeline.py): Uses ONNXRuntime for both the transformer and the Dense layers (exported to ONNX), so most of the pipeline is ONNX-based; only normalization runs outside ONNX.
- PyTorch pipeline (pytorch_gemma3_pipeline.py): Uses ONNXRuntime for the transformer, but performs all post-processing (Pooling, Dense, Normalize) with PyTorch modules from SentenceTransformer.
This shows how ONNX export can move more of the computation into ONNXRuntime, which can enable faster, hardware-agnostic inference, while the PyTorch pipeline serves as the reference implementation.
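The post-processing chain that both pipelines apply after pooling can be sketched in NumPy as a sequence of linear projections followed by L2 normalization. This is an illustration under stated assumptions, not the repository's code: it assumes each Dense stage is a plain linear layer (x @ W^T, identity activation, no bias), and the layer sizes and random weights are hypothetical placeholders. In practice the weights would come from the model's Dense modules or their ONNX exports.

```python
import numpy as np


def dense(x: np.ndarray, weight: np.ndarray, bias=None) -> np.ndarray:
    # Linear projection x @ W^T (+ b); identity activation is an assumption
    out = x @ weight.T
    if bias is not None:
        out = out + bias
    return out


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Scale each row to unit length so dot products equal cosine similarity
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def post_process(pooled: np.ndarray, dense_weights: list) -> np.ndarray:
    # Pooled embedding -> Dense stages in order -> Normalize
    x = pooled
    for w in dense_weights:
        x = dense(x, w)
    return l2_normalize(x)


rng = np.random.default_rng(0)
hidden, intermediate = 8, 16  # hypothetical sizes for illustration only
pooled = rng.standard_normal((2, hidden))
weights = [
    rng.standard_normal((intermediate, hidden)),  # hidden -> intermediate
    rng.standard_normal((hidden, intermediate)),  # intermediate -> hidden
]
emb = post_process(pooled, weights)
print(emb.shape)  # (2, 8)
print(np.linalg.norm(emb, axis=1))  # each row has unit norm
```

Whether the Dense step runs in ONNXRuntime or PyTorch, the arithmetic is the same matrix product, which is why the two pipelines should produce closely matching embeddings.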
Setup
- Install dependencies:
pip install git+https://github.com/simondanielsson/optimum-onnx.git@feature/add-gemma3-export
pip install git+https://github.com/huggingface/[email protected]
pip install sentence-transformers onnxruntime safetensors huggingface_hub
- Export the ONNX model, then download the remaining Hugging Face files and export the Dense layers:
optimum-cli export onnx --model google/embeddinggemma-300m-qat-q4_0-unquantized --optimize O3 --slim embeddinggemma-300m-onnx
python download_missing_hf_files.py
Usage
- Run the ONNX pipeline:
python onnx_gemma3_pipeline.py
- Run the PyTorch pipeline:
python pytorch_gemma3_pipeline.py
- Compare outputs:
python compare_gemma3_onnx_vs_pytorch.py
Results
The comparison script prints the cosine similarities between sample word embeddings (e.g., "apple", "banana", "car") for both the ONNX and PyTorch pipelines, so you can check that the ONNX conversion reproduces the PyTorch reference outputs.
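One way to quantify agreement between the two pipelines is sketched below in NumPy. It is not the logic of compare_gemma3_onnx_vs_pytorch.py. The embeddings here are synthetic stand-ins (random vectors plus a tiny perturbation), and the metric names are hypothetical. With real outputs you would pass the arrays produced by each pipeline for the same word list.

```python
import numpy as np


def cosine_matrix(emb: np.ndarray) -> np.ndarray:
    # Pairwise cosine similarity between the rows of emb
    unit = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    return unit @ unit.T


def compare(onnx_emb: np.ndarray, torch_emb: np.ndarray) -> dict:
    # Word-by-word similarity tables from each pipeline
    sims_onnx = cosine_matrix(onnx_emb)
    sims_torch = cosine_matrix(torch_emb)
    # Cosine between the ONNX and PyTorch embedding of the same word
    cross = np.sum(onnx_emb * torch_emb, axis=1) / (
        np.linalg.norm(onnx_emb, axis=1) * np.linalg.norm(torch_emb, axis=1)
    )
    return {
        "max_abs_embedding_diff": float(np.max(np.abs(onnx_emb - torch_emb))),
        "max_abs_similarity_diff": float(np.max(np.abs(sims_onnx - sims_torch))),
        "min_cross_cosine": float(cross.min()),
    }


words = ["apple", "banana", "car"]
rng = np.random.default_rng(42)
torch_emb = rng.standard_normal((len(words), 4))  # synthetic stand-in for PyTorch output
onnx_emb = torch_emb + 1e-4 * rng.standard_normal(torch_emb.shape)  # synthetic ONNX output
report = compare(onnx_emb, torch_emb)
for key, value in report.items():
    print(f"{key}: {value:.6f}")
```

Comparing the similarity tables, not just the raw vectors, checks that the ONNX model preserves the relationships between words, which is what downstream retrieval depends on.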