""" |
|
Script to quantize ONNX models to additional formats: int4, int8, etc. |
|
Based on transformers.js/scripts/quantize.py, extended for more quantization options. |
|
""" |
|
|
|
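
# Example invocation (illustrative paths; flag names are generated by HfArgumentParser from
# the dataclasses below, and the script is assumed to be saved as quantize.py):
#
#   python quantize.py --input_folder ./onnx --output_folder ./onnx/quantized \
#       --modes fp16 q8 int8 int4 --block_size 32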

import os
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

import onnx
from tqdm import tqdm
from transformers import HfArgumentParser
from onnxruntime.quantization import QuantType, QuantizationMode
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.registry import IntegerOpsRegistry
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer

import float16
import utils


class QuantMode(Enum):
    FP16 = "fp16"
    Q8 = "q8"
    QI8 = "int8"
    QU8 = "uint8"
    Q4 = "q4"
    Q4F16 = "q4f16"
    BNB4 = "bnb4"
    INT4 = "int4"
    INT8 = "int8"  # NOTE: same value as QI8, so this member is an alias of QI8


# Filename suffixes for the produced models; modes not listed here use their
# string value (e.g. "fp16", "uint8") as the suffix.
QUANTIZE_SUFFIX_MAPPING = {
    QuantMode.Q8: "quantized",
    QuantMode.INT4: "int4",
    QuantMode.INT8: "int8",
}
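# For example, model.onnx is written as model_quantized.onnx in q8 mode,
# model_int8.onnx in int8 mode, and model_fp16.onnx in fp16 mode.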

QUANTIZE_OPTIONS = tuple(x.value for x in QuantMode)

# Operators that are typically only supported with unsigned (QUInt8) 8-bit weights;
# when any of them is present in the graph, the q8 mode falls back to QUInt8 weights.
QUINT8_OPS = (
    "Conv",
    "GroupQueryAttention",
    "MultiHeadAttention",
)


@dataclass
class IOArguments:
    input_folder: str = field(
        metadata={"help": "Path of the input folder containing the .onnx models to quantize"},
    )
    output_folder: str = field(
        metadata={"help": "Path of the output folder where the quantized .onnx models will be saved"},
    )


@dataclass
class QuantizationArguments:
    modes: QuantMode = field(
        default=QUANTIZE_OPTIONS,
        metadata={
            "help": "Quantization modes to use.",
            "choices": QUANTIZE_OPTIONS,
            "nargs": "+",
        },
    )
    per_channel: bool = field(
        default=None,
        metadata={"help": "Whether to quantize weights per channel"},
    )
    reduce_range: bool = field(
        default=None,
        metadata={"help": "Whether to quantize weights with 7 bits (reduced range)."},
    )
    block_size: int = field(
        default=None,
        metadata={"help": "Block size for blockwise quantization."},
    )
    is_symmetric: bool = field(
        default=True,
        metadata={"help": "Whether to quantize the model symmetrically"},
    )
    accuracy_level: int = field(
        default=None,
        metadata={"help": "Accuracy level of the 4-bit quantized MatMul computation."},
    )
    quant_type: int = field(
        default=MatMulBnb4Quantizer.NF4,
        metadata={
            "help": "Quantization data type. 0: FP4, 1: NF4",
            "choices": [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
        },
    )
    op_block_list: List[str] = field(
        default=None,
        metadata={
            "help": "List of operators to exclude from quantization.",
            "nargs": "+",
        },
    )


def quantize_int4(
    model: onnx.ModelProto,
    save_path: str,
    block_size: int = 32,
    is_symmetric: bool = True,
    accuracy_level: int = 4,
):
    """
    Quantize the weights of the model from float32 to 4-bit int using MatMulNBitsQuantizer.
    """
    quantizer = MatMulNBitsQuantizer(
        model=model,
        block_size=block_size,
        is_symmetric=is_symmetric,
        accuracy_level=accuracy_level,
    )
    quantizer.process()
    utils.check_and_save_model(quantizer.model.model, save_path)
    return quantizer.model.model
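
# Direct (non-CLI) use of the helper above, with hypothetical paths:
#
#   model = onnx.load_model("model.onnx")
#   quantize_int4(model, "model_int4.onnx", block_size=32, is_symmetric=True, accuracy_level=4)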


def quantize_int8(
    model: onnx.ModelProto,
    save_path: str,
    per_channel: bool = False,
    reduce_range: bool = False,
    weight_type: QuantType = QuantType.QInt8,
    op_block_list: Optional[List[str]] = None,
):
    """
    Quantize the weights of the model from float32 to int8.
    """
    op_types_to_quantize = set(IntegerOpsRegistry.keys())
    if op_block_list is not None:
        op_types_to_quantize.difference_update(op_block_list)

    quantizer = ONNXQuantizer(
        model,
        per_channel,
        reduce_range,
        mode=QuantizationMode.IntegerOps,
        static=False,
        weight_qType=weight_type,
        activation_qType=QuantType.QUInt8,
        tensors_range=None,
        nodes_to_quantize=[],
        nodes_to_exclude=[],
        op_types_to_quantize=op_types_to_quantize,
        extra_options=dict(EnableSubgraph=True, MatMulConstBOnly=True),
    )
    quantizer.quantize_model()
    utils.check_and_save_model(quantizer.model.model, save_path)
    return quantizer.model.model
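

# The bnb4 mode is exposed on the command line (and MatMulBnb4Quantizer is imported, with a
# --quant_type argument defined above), but it is not wired into main(). The helper below is a
# best-effort sketch of that path, mirroring quantize_int4(); the default block_size of 64 is
# an assumption, and the function is not called anywhere in this script.
def quantize_bnb4(
    model: onnx.ModelProto,
    save_path: str,
    block_size: int = 64,
    quant_type: int = MatMulBnb4Quantizer.NF4,
):
    """
    Quantize the weights of the model from float32 to 4-bit (FP4/NF4) using MatMulBnb4Quantizer.
    """
    quantizer = MatMulBnb4Quantizer(
        model,
        quant_type=quant_type,
        block_size=block_size,
    )
    quantizer.process()
    utils.check_and_save_model(quantizer.model.model, save_path)
    return quantizer.model.model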


def main():
    parser = HfArgumentParser((IOArguments, QuantizationArguments))
    io_args, quantization_args = parser.parse_args_into_dataclasses()
    input_folder = io_args.input_folder
    output_folder = io_args.output_folder

    if not quantization_args.modes:
        raise ValueError("At least one quantization mode must be specified")

    if not os.path.exists(input_folder):
        raise ValueError(f"Input folder {input_folder} does not exist")

    model_names_or_paths = [
        os.path.join(input_folder, file)
        for file in os.listdir(input_folder)
        if file.endswith(".onnx")
    ]
    if not model_names_or_paths:
        raise ValueError(f"No .onnx models found in {input_folder}")

    os.makedirs(output_folder, exist_ok=True)

    for model_path in tqdm(model_names_or_paths, desc="Models"):
        file_name_without_extension = os.path.splitext(os.path.basename(model_path))[0]
        for mode in tqdm(quantization_args.modes, desc="Modes"):
            mode_enum = QuantMode(mode)
            suffix = QUANTIZE_SUFFIX_MAPPING.get(mode_enum, mode_enum.value)
            save_path = os.path.join(
                output_folder, f"{file_name_without_extension}_{suffix}.onnx"
            )

            # Reload the model for every mode: the quantizers modify the ModelProto in
            # place, so reusing a single instance would stack quantizations.
            model = onnx.load_model(model_path)

            try:
                if mode_enum == QuantMode.FP16:
                    model_fp16 = float16.convert_float_to_float16(
                        model,
                        keep_io_types=True,
                        disable_shape_infer=False,
                        op_block_list=quantization_args.op_block_list or [],
                    )
                    utils.check_and_save_model(model_fp16, save_path)

                elif mode_enum in (QuantMode.INT4, QuantMode.Q4):
                    quantize_int4(
                        model,
                        save_path,
                        block_size=quantization_args.block_size or 32,
                        is_symmetric=quantization_args.is_symmetric,
                        accuracy_level=quantization_args.accuracy_level or 0,
                    )

                elif mode_enum in (QuantMode.INT8, QuantMode.QI8):
                    quantize_int8(
                        model,
                        save_path,
                        per_channel=quantization_args.per_channel or False,
                        reduce_range=quantization_args.reduce_range or False,
                        weight_type=QuantType.QInt8,
                        op_block_list=quantization_args.op_block_list,
                    )

                elif mode_enum == QuantMode.QU8:
                    # Plain unsigned 8-bit mode: same dynamic quantization, QUInt8 weights.
                    quantize_int8(
                        model,
                        save_path,
                        per_channel=quantization_args.per_channel or False,
                        reduce_range=quantization_args.reduce_range or False,
                        weight_type=QuantType.QUInt8,
                        op_block_list=quantization_args.op_block_list,
                    )

                elif mode_enum == QuantMode.Q8:
                    # Default 8-bit mode: prefer signed weights, but fall back to QUInt8
                    # when the (top-level) graph contains operators from QUINT8_OPS.
                    op_types = {node.op_type for node in model.graph.node}
                    weight_type = (
                        QuantType.QUInt8
                        if any(op in op_types for op in QUINT8_OPS)
                        else QuantType.QInt8
                    )
                    quantize_int8(
                        model,
                        save_path,
                        per_channel=quantization_args.per_channel or False,
                        reduce_range=quantization_args.reduce_range or False,
                        weight_type=weight_type,
                        op_block_list=quantization_args.op_block_list,
                    )

                else:
                    print(f"[WARN] Quantization mode '{mode}' is not wired up in this script; skipping")

            except Exception as e:
                print(f"[WARN] Quantization mode '{mode}' failed for model '{model_path}': {e}")
                continue


if __name__ == "__main__":
    main()