---
base_model:
- google/embeddinggemma-300m-qat-q4_0-unquantized
---
# Gemma3 Embedding Model: ONNX Conversion Demonstration

This repository demonstrates how to convert the Gemma3 embedding model from Hugging Face to ONNX with optimum-onnx and how to compare the converted model against the original. It includes scripts for an ONNX and a PyTorch inference pipeline, plus a script that compares their outputs.

## Files

- `onnx_gemma3_pipeline.py`: Runs the Gemma3 embedding model with ONNX Runtime, including the post-processing steps (Pooling, Dense, Normalize), with the Dense layers exported to ONNX.
- `pytorch_gemma3_pipeline.py`: Reference pipeline that performs the post-processing (Pooling, Dense, Normalize) with PyTorch modules from SentenceTransformer.
- `compare_gemma3_onnx_vs_pytorch.py`: Compares the output embeddings and cosine similarities between the ONNX and PyTorch pipelines.
- `download_missing_hf_files.py`: Downloads required files from Hugging Face and exports Dense layers to ONNX.
- `gemma3_mean_pooling_basic.py`: The most basic implementation, running Gemma3 ONNX inference with only mean pooling (no Dense or Normalize stages).
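For orientation, mean pooling averages the transformer's per-token vectors into one sentence vector, skipping padding tokens via the attention mask. The sketch below is a minimal plain-Python illustration, not the code in `gemma3_mean_pooling_basic.py`: it assumes one sequence's token vectors and its attention mask are already available as lists (in practice they come from the ONNX model output and the tokenizer), and the `mean_pool` name is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def mean_pool(token_embeddings: Matrix, attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1.

    token_embeddings: one hidden-state vector per token (seq_len x hidden_dim),
    e.g. one sequence taken from the model's last hidden state.
    attention_mask: 1 for real tokens, 0 for padding.
    """
    if len(token_embeddings) != len(attention_mask):
        raise ValueError("token_embeddings and attention_mask must have the same length")
    hidden_dim = len(token_embeddings[0])
    summed = [0.0] * hidden_dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                summed[i] += value
    # Guard against an all-padding sequence to avoid division by zero.
    count = max(count, 1)
    return [s / count for s in summed]


# Two real tokens and one padding token: the padding row is ignored.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 3.0]
```

Without the mask, the padding row would drag the average toward `[100.0, 100.0]`, which is why batched inference has to pool with the mask rather than a plain mean.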

## Pipeline Differences

Both pipelines use ONNX Runtime for transformer inference via `ORTModelForFeatureExtraction`. The key difference is in post-processing:

- **ONNX pipeline** (`onnx_gemma3_pipeline.py`): Uses ONNX Runtime for both the transformer and the Dense layers (exported to ONNX), so most of the pipeline is ONNX-based; only normalization is handled outside ONNX.
- **PyTorch pipeline** (`pytorch_gemma3_pipeline.py`): Uses ONNX Runtime for the transformer, but performs all post-processing (Pooling, Dense, Normalize) with PyTorch modules from SentenceTransformer.

This demonstrates how ONNX conversion can offload more of the computation to ONNX Runtime, which may speed up inference and makes it easier to run on different hardware, while the PyTorch pipeline serves as the reference implementation.
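After pooling, both pipelines apply the same arithmetic: one or more Dense (linear) projections, then L2 normalization. The plain-Python sketch below shows that chain on toy sizes. It is not taken from either script; it assumes bias-free Dense layers with an identity activation (check the model's SentenceTransformer Dense configuration for the actual settings), and the function names are hypothetical.

```python
import math
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]


def dense(x: Vector, weight: Matrix, bias: Optional[Vector] = None) -> Vector:
    """Linear layer y = W x (+ b); weight has shape (out_features, in_features).

    An identity activation is assumed, so no nonlinearity is applied.
    """
    y = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in weight]
    if bias is not None:
        y = [y_i + b_i for y_i, b_i in zip(y, bias)]
    return y


def l2_normalize(x: Vector, eps: float = 1e-12) -> Vector:
    """Scale x to unit Euclidean length (the Normalize stage)."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def post_process(pooled: Vector, dense_weights: List[Matrix]) -> Vector:
    """Apply the Dense layers in order, then normalize."""
    x = pooled
    for weight in dense_weights:
        x = dense(x, weight)
    return l2_normalize(x)


# Toy weights projecting 2 -> 3 -> 2 (the real layers are far larger).
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
embedding = post_process([3.0, 1.0], [w1, w2])
print(embedding)
```

In the ONNX pipeline the `dense` steps correspond to the layers exported to ONNX, while the final normalization is a cheap vector operation that can stay outside ONNX Runtime.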

## Setup

1. Install dependencies:
	 ```sh
	 pip install git+https://github.com/simondanielsson/optimum-onnx.git@feature/add-gemma3-export
	 pip install git+https://github.com/huggingface/[email protected]
	 pip install sentence-transformers onnxruntime safetensors huggingface_hub
	 ```
2. Export the transformer to ONNX, then fetch the remaining files and export the Dense layers:
	 ```sh
	 optimum-cli export onnx --model google/embeddinggemma-300m-qat-q4_0-unquantized --optimize O3 --slim embeddinggemma-300m-onnx
	 python download_missing_hf_files.py
	 ```

## Usage

- Run the ONNX pipeline:
	```sh
	python onnx_gemma3_pipeline.py
	```
- Run the PyTorch pipeline:
	```sh
	python pytorch_gemma3_pipeline.py
	```
- Compare outputs:
	```sh
	python compare_gemma3_onnx_vs_pytorch.py
	```

## Results

The comparison script prints the cosine similarities between sample word embeddings (e.g., "apple", "banana", "car") from both the ONNX and PyTorch pipelines, so you can check that the ONNX conversion reproduces the reference results.
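One way to quantify that agreement is to compute the pairwise cosine similarities from each pipeline and report the largest difference between them. The sketch below assumes each pipeline's output is available as a word-to-vector dictionary; the function names and the toy vectors are hypothetical and are not output from this repository.

```python
import math
from itertools import combinations
from typing import Dict, List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def pairwise_similarities(embeddings: Dict[str, Vector]) -> Dict[Tuple[str, str], float]:
    """Cosine similarity for every unordered pair of words."""
    return {
        (w1, w2): cosine(embeddings[w1], embeddings[w2])
        for w1, w2 in combinations(sorted(embeddings), 2)
    }


def max_similarity_gap(onnx_emb: Dict[str, Vector], torch_emb: Dict[str, Vector]) -> float:
    """Largest absolute difference between the two pipelines' pairwise similarities."""
    onnx_sims = pairwise_similarities(onnx_emb)
    torch_sims = pairwise_similarities(torch_emb)
    return max(abs(onnx_sims[pair] - torch_sims[pair]) for pair in onnx_sims)


# Made-up 3-d vectors standing in for real embeddings; "banana" differs slightly.
onnx_out = {"apple": [0.9, 0.1, 0.0], "banana": [0.8, 0.2, 0.1], "car": [0.0, 0.1, 0.9]}
torch_out = {"apple": [0.9, 0.1, 0.0], "banana": [0.8, 0.2, 0.1001], "car": [0.0, 0.1, 0.9]}
gap = max_similarity_gap(onnx_out, torch_out)
print(f"max abs cosine difference = {gap:.2e}")
```

Comparing similarities rather than raw vectors checks what downstream users rely on; a direct element-wise comparison of the embeddings is a stricter complementary check.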

## References
- [Optimum-ONNX Gemma3 PR](https://github.com/huggingface/optimum-onnx/pull/50)
- [Gemma3 Model](https://huggingface.co/google/embeddinggemma-300m-qat-q4_0-unquantized)