# embeddinggemma-300m-onnx / download_missing_hf_files.py
# Alex Sadleir
# add int4/int8
# a1edf95
from huggingface_hub import snapshot_download
import os
import shutil
from sentence_transformers import SentenceTransformer
import torch
# Model repo and local directory
repo_id = "google/embeddinggemma-300m-qat-q4_0-unquantized"
local_dir = "embeddinggemma-300m"

# Fetch every repo file except the large model.safetensors into a scratch
# directory, then move over only the top-level entries that local_dir does not
# already have. Everything (except model.safetensors) is downloaded; only the
# move is filtered. The comparison is by top-level name only: an existing
# subdirectory is skipped as a whole even if files inside it differ.
os.makedirs(local_dir, exist_ok=True)  # listdir would raise on a fresh checkout
existing_files = set(os.listdir(local_dir))

# Download snapshot to a temp dir. `resume_download` is deprecated in current
# huggingface_hub (downloads resume automatically), so it is no longer passed;
# a partially filled temp_dir left by a crashed run is simply reused.
temp_dir = "_hf_temp_download"
os.makedirs(temp_dir, exist_ok=True)
snapshot_download(
    repo_id,
    local_dir=temp_dir,
    ignore_patterns=["model.safetensors"],
)

# huggingface_hub stores download bookkeeping under <local_dir>/.cache when a
# local_dir is used; it describes temp_dir, not the model, so never copy it.
HF_METADATA_DIR = ".cache"

# Copy missing files (sorted for a deterministic log).
for fname in sorted(os.listdir(temp_dir)):
    if fname == HF_METADATA_DIR:
        continue
    if fname in existing_files:
        print(f"Already exists: {fname}")
    else:
        shutil.move(os.path.join(temp_dir, fname), os.path.join(local_dir, fname))
        print(f"Downloaded: {fname}")

# Cleanup deliberately runs only on success (not in a `finally`): if the
# download above fails, temp_dir is kept so the next run can resume it.
shutil.rmtree(temp_dir)
print("Done.")
# Export Dense layers from SentenceTransformer to ONNX.
# NOTE(review): modules 2 and 3 are assumed to be sentence_transformers Dense
# modules whose `.linear` is a torch.nn.Linear — confirm against modules.json.
# Only `.linear` is exported, so any activation_function on those Dense
# modules is dropped (presumably Identity for this model — verify).
st_model = SentenceTransformer(repo_id)
dense1 = st_model[2].linear
dense2 = st_model[3].linear
onnx_dir = os.path.join(local_dir, "onnx")
os.makedirs(onnx_dir, exist_ok=True)


def _export_linear(layer, out_dir, filename):
    """Export a linear layer to ``out_dir/filename`` as ONNX (opset 14).

    The dummy input only fixes the feature width (``layer.in_features``);
    axis 0 of both the input and the output is declared dynamic, so the
    exported graph accepts any batch size rather than exactly 1.

    Returns the layer after moving it to the dummy input's device.
    """
    dummy_input = torch.randn(1, layer.in_features)
    layer = layer.to(dummy_input.device)
    torch.onnx.export(
        layer,
        dummy_input,
        os.path.join(out_dir, filename),
        input_names=["input"],
        output_names=["output"],
        # Without this the exported graph is hard-wired to batch size 1.
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=14,
    )
    print(f"Exported {filename}")
    return layer


dense1 = _export_linear(dense1, onnx_dir, "dense1.onnx")
dense2 = _export_linear(dense2, onnx_dir, "dense2.onnx")
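# # Quantize dense1.onnx and dense2.onnx to int4 using ONNX Runtime's matmul_nbits_quantizer
# # NOTE(review): nn.Linear exported with a 2-D input typically becomes a Gemm node, not MatMul,
# # so op_types_to_quantize=("MatMul", "Gather") may match nothing here — confirm before re-enabling.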
# from onnxruntime.quantization import (
# matmul_nbits_quantizer,
# quant_utils
# )
# from pathlib import Path
# onnx_dir = Path(onnx_dir)
# for dense_name in ["dense1.onnx", "dense2.onnx"]:
# model_fp32_path = onnx_dir / dense_name
# model_int4_path = model_fp32_path # Overwrite original file
# quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig(
# block_size=128,
# is_symmetric=True,
# accuracy_level=4,
# quant_format=quant_utils.QuantFormat.QOperator,
# op_types_to_quantize=("MatMul", "Gather"),
# quant_axes=( ("MatMul", 0), ("Gather", 1) )
# )
# model = quant_utils.load_model_with_shape_infer(model_fp32_path)
# quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
# model,
# nodes_to_exclude=None,
# nodes_to_include=None,
# algo_config=quant_config,
# )
# quant.process()
# quant.model.save_model_to_file(
# str(model_int4_path),
# True
# )
# print(f"Quantized {dense_name} to int4 and overwrote original file.")