from huggingface_hub import snapshot_download
import os
import shutil
from sentence_transformers import SentenceTransformer
import torch

# Model repo and local directory
repo_id = "google/embeddinggemma-300m-qat-q4_0-unquantized"
local_dir = "embeddinggemma-300m"

# Download all files except model.safetensors and those already present locally
os.makedirs(local_dir, exist_ok=True)
existing_files = set(os.listdir(local_dir))

# Download the snapshot to a temp dir first, so a partial run never clobbers local_dir
temp_dir = "_hf_temp_download"
os.makedirs(temp_dir, exist_ok=True)
snapshot_download(
    repo_id,
    local_dir=temp_dir,
    ignore_patterns=["model.safetensors"],
)

# Move over the files that are missing from local_dir (skip the hidden
# .cache metadata directory that huggingface_hub writes alongside the snapshot)
for fname in os.listdir(temp_dir):
    if fname.startswith("."):
        continue
    if fname not in existing_files:
        shutil.move(os.path.join(temp_dir, fname), os.path.join(local_dir, fname))
        print(f"Downloaded: {fname}")
    else:
        print(f"Already exists: {fname}")

# Clean up temp dir
shutil.rmtree(temp_dir)
print("Done.")

# Export the two Dense projection layers of the SentenceTransformer to ONNX.
# Loading from the Hub (not local_dir, which deliberately lacks model.safetensors)
# gives the full pipeline; modules 2 and 3 are the Dense blocks, and .linear is
# the underlying torch.nn.Linear inside each.
st_model = SentenceTransformer(repo_id)
dense1 = st_model[2].linear
dense2 = st_model[3].linear

onnx_dir = os.path.join(local_dir, "onnx")
os.makedirs(onnx_dir, exist_ok=True)

# Export dense1, traced on CPU with a dummy batch of size 1
dense1 = dense1.to("cpu").eval()
dummy_input1 = torch.randn(1, dense1.in_features)
torch.onnx.export(
    dense1,
    dummy_input1,
    os.path.join(onnx_dir, "dense1.onnx"),
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
)
print("Exported dense1.onnx")

# Export dense2
dense2 = dense2.to("cpu").eval()
dummy_input2 = torch.randn(1, dense2.in_features)
torch.onnx.export(
    dense2,
    dummy_input2,
    os.path.join(onnx_dir, "dense2.onnx"),
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
)
print("Exported dense2.onnx")

# # Optionally quantize dense1.onnx and dense2.onnx to int4 with ONNX Runtime's
# # MatMulNBitsQuantizer (module matmul_nbits_quantizer; requires a recent
# # onnxruntime). Kept commented out; the exports above remain fp32 unless enabled.
# from pathlib import Path
# from onnxruntime.quantization import matmul_nbits_quantizer, quant_utils
#
# onnx_dir = Path(onnx_dir)
# for dense_name in ["dense1.onnx", "dense2.onnx"]:
#     model_fp32_path = onnx_dir / dense_name
#     model_int4_path = model_fp32_path  # Overwrite the original file
#     quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig(
#         block_size=128,
#         is_symmetric=True,
#         accuracy_level=4,
#         quant_format=quant_utils.QuantFormat.QOperator,
#         op_types_to_quantize=("MatMul", "Gather"),
#         quant_axes=(("MatMul", 0), ("Gather", 1)),
#     )
#     model = quant_utils.load_model_with_shape_infer(model_fp32_path)
#     quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
#         model,
#         nodes_to_exclude=None,
#         nodes_to_include=None,
#         algo_config=quant_config,
#     )
#     quant.process()
#     # Second argument enables the external data format for large initializers
#     quant.model.save_model_to_file(str(model_int4_path), True)
#     print(f"Quantized {dense_name} to int4 and overwrote the original file.")
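
# Optional sanity check: compare each exported ONNX graph against its PyTorch
# Linear layer on a random input. A minimal sketch, assuming onnxruntime and
# numpy are installed; it is not part of the original export flow above.
import numpy as np
import onnxruntime

for name, layer in [("dense1.onnx", dense1), ("dense2.onnx", dense2)]:
    x = np.random.randn(1, layer.in_features).astype(np.float32)
    sess = onnxruntime.InferenceSession(os.path.join(onnx_dir, name))
    onnx_out = sess.run(None, {"input": x})[0]
    with torch.no_grad():
        torch_out = layer(torch.from_numpy(x)).numpy()
    print(f"{name}: max abs diff = {np.abs(onnx_out - torch_out).max():.2e}")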