Alex Sadleir committed
Commit a1edf95 · 1 Parent(s): d5fb6c3

add int4/int8

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
README.md CHANGED
@@ -33,7 +33,7 @@ This demonstrates how ONNX conversion can offload more computation for faster, h
 ```
 2. Export the ONNX model:
 ```sh
- optimum-cli export onnx --model google/embeddinggemma-300m-qat-q4_0-unquantized embeddinggemma-300m-onnx
+ optimum-cli export onnx --model google/embeddinggemma-300m-qat-q4_0-unquantized --optimize O3 --slim embeddinggemma-300m-onnx
 python download_missing_hf_files.py
 ```
 
__pycache__/onnx_gemma3_pipeline.cpython-312.pyc DELETED
Binary file (6.63 kB)
 
download_missing_hf_files.py CHANGED
@@ -67,3 +67,36 @@ torch.onnx.export(
     opset_version=14
 )
 print("Exported dense2.onnx")
+
+ # # Quantize dense1.onnx and dense2.onnx to int4 using ONNX Runtime matmul_nbits_quantizer
+ # from onnxruntime.quantization import (
+ #     matmul_nbits_quantizer,
+ #     quant_utils
+ # )
+ # from pathlib import Path
+
+ # onnx_dir = Path(onnx_dir)
+ # for dense_name in ["dense1.onnx", "dense2.onnx"]:
+ #     model_fp32_path = onnx_dir / dense_name
+ #     model_int4_path = model_fp32_path  # Overwrite original file
+ #     quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig(
+ #         block_size=128,
+ #         is_symmetric=True,
+ #         accuracy_level=4,
+ #         quant_format=quant_utils.QuantFormat.QOperator,
+ #         op_types_to_quantize=("MatMul", "Gather"),
+ #         quant_axes=(("MatMul", 0), ("Gather", 1))
+ #     )
+ #     model = quant_utils.load_model_with_shape_infer(model_fp32_path)
+ #     quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
+ #         model,
+ #         nodes_to_exclude=None,
+ #         nodes_to_include=None,
+ #         algo_config=quant_config,
+ #     )
+ #     quant.process()
+ #     quant.model.save_model_to_file(
+ #         str(model_int4_path),
+ #         True
+ #     )
+ #     print(f"Quantized {dense_name} to int4 and overwrote original file.")
embeddinggemma-300m/model.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
- oid sha256:dee985629c11dd0f70531093aeb8e8f7f5ddfb403f6c2705db340d58e4e03ffb
- size 1212541436
+ oid sha256:511afd7b7ed2b58a61876a6aef4c1113a93b2afad17cf363753eef48f4669a41
+ size 1212258625
embeddinggemma-300m/onnx/model_int4.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f79136a20219163cdf5701a85422ccde126db44d2d51f6bcfc07b63edc0efab9
+ size 869894596
embeddinggemma-300m/onnx/model_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c74d91f9c44be1b3a7cbe5875d61597b493ae2789e7a62f540b351eaf0cdc57
+ size 265916124
float16.py ADDED
@@ -0,0 +1,878 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) Microsoft Corporation, Hugging Face. All rights reserved.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+
24
+ from typing import Optional
25
+ import itertools
26
+ import numpy as np
27
+ import onnx
28
+ import packaging.version as pv
29
+ import warnings
30
+ from onnx import helper, numpy_helper
31
+ from onnx import onnx_pb as onnx_proto
32
+ import onnxslim.third_party.onnx_graphsurgeon as gs
33
+
34
+
35
+ FLOAT32 = 1
36
+ FLOAT16 = 10
37
+
38
+
39
+ def _npfloat16_to_int(np_list):
40
+ """
41
+ Convert numpy float16 to python int.
42
+
43
+ :param np_list: numpy float16 list
44
+ :return int_list: python int list
45
+ """
46
+ return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]
47
+
48
+
49
+ def convert_np_to_float16(np_array, min_positive_val=1e-7, max_finite_val=1e4):
50
+ """
51
+ Convert float32 numpy array to float16 without changing sign or finiteness.
52
+ Positive values less than min_positive_val are mapped to min_positive_val.
53
+ Positive finite values greater than max_finite_val are mapped to max_finite_val.
54
+ Similar for negative values. NaN, 0, inf, and -inf are unchanged.
55
+ """
56
+
57
+ def between(a, b, c):
58
+ return np.logical_and(a < b, b < c)
59
+
60
+ positive_values = np_array[np.where(np_array > 0)]
61
+ if positive_values.shape[0] > 0:
62
+ pos_max = positive_values.max()
63
+ pos_min = positive_values.min()
64
+
65
+ if pos_max >= max_finite_val:
66
+ warnings.warn(
67
+ "the float32 number {} will be truncated to {}".format(
68
+ pos_max, max_finite_val
69
+ )
70
+ )
71
+
72
+ if pos_min <= min_positive_val:
73
+ warnings.warn(
74
+ "the float32 number {} will be truncated to {}".format(
75
+ pos_min, min_positive_val
76
+ )
77
+ )
78
+
79
+ negative_values = np_array[np.where(np_array < 0)]
80
+ if negative_values.shape[0] > 0:
81
+ neg_max = negative_values.max()
82
+ neg_min = negative_values.min()
83
+
84
+ if neg_min <= -max_finite_val:
85
+ warnings.warn(
86
+ "the float32 number {} will be truncated to {}".format(
87
+ neg_min, -max_finite_val
88
+ )
89
+ )
90
+
91
+ if neg_max >= -min_positive_val:
92
+ warnings.warn(
93
+ "the float32 number {} will be truncated to {}".format(
94
+ neg_max, -min_positive_val
95
+ )
96
+ )
97
+
98
+ np_array = np.where(
99
+ between(0, np_array, min_positive_val), min_positive_val, np_array
100
+ )
101
+ np_array = np.where(
102
+ between(-min_positive_val, np_array, 0), -min_positive_val, np_array
103
+ )
104
+ np_array = np.where(
105
+ between(max_finite_val, np_array, float("inf")), max_finite_val, np_array
106
+ )
107
+ np_array = np.where(
108
+ between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array
109
+ )
110
+ return np.float16(np_array)
111
+
112
+
113
+ def convert_tensor_float_to_float16(tensor, min_positive_val=1e-7, max_finite_val=1e4):
114
+ """
115
+ Convert tensor float to float16.
116
+
117
+ :param tensor: TensorProto object
118
+ :return tensor_float16: converted TensorProto object
119
+ """
120
+ if not isinstance(tensor, onnx_proto.TensorProto):
121
+ raise ValueError(
122
+ "Expected input type is an ONNX TensorProto but got %s" % type(tensor)
123
+ )
124
+
125
+ if tensor.data_type == onnx_proto.TensorProto.FLOAT:
126
+ tensor.data_type = onnx_proto.TensorProto.FLOAT16
127
+ # convert float_data (float type) to float16 and write to int32_data
128
+ if tensor.float_data:
129
+ float16_data = convert_np_to_float16(
130
+ np.array(tensor.float_data), min_positive_val, max_finite_val
131
+ )
132
+ int_list = _npfloat16_to_int(float16_data)
133
+ tensor.int32_data[:] = int_list
134
+ tensor.float_data[:] = []
135
+ # convert raw_data (bytes type)
136
+ if tensor.raw_data:
137
+ # convert tensor.raw_data (bytes) to a float32 array
138
+ float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
139
+ # convert float to float16
140
+ float16_list = convert_np_to_float16(
141
+ float32_list, min_positive_val, max_finite_val
142
+ )
143
+ # convert float16 to bytes and write back to raw_data
144
+ tensor.raw_data = float16_list.tobytes()
145
+ return tensor
146
+
147
+
148
+ def make_value_info_from_tensor(tensor):
149
+ shape = numpy_helper.to_array(tensor).shape
150
+ return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape)
151
+
152
+
153
+ DEFAULT_OP_BLOCK_LIST = [
154
+ "ArrayFeatureExtractor",
155
+ "Binarizer",
156
+ "CastMap",
157
+ "CategoryMapper",
158
+ "DictVectorizer",
159
+ "FeatureVectorizer",
160
+ "Imputer",
161
+ "LabelEncoder",
162
+ "LinearClassifier",
163
+ "LinearRegressor",
164
+ "Normalizer",
165
+ "OneHotEncoder",
166
+ "RandomUniformLike",
167
+ "SVMClassifier",
168
+ "SVMRegressor",
169
+ "Scaler",
170
+ "TreeEnsembleClassifier",
171
+ "TreeEnsembleRegressor",
172
+ "ZipMap",
173
+ "NonMaxSuppression",
174
+ "TopK",
175
+ "RoiAlign",
176
+ "Resize",
177
+ # 'Range',
178
+ "CumSum",
179
+ "Min",
180
+ "Max",
181
+ "Upsample",
182
+ # NEW:
183
+ "RandomNormalLike",
184
+ # TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
185
+ # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
186
+ # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
187
+ # However, without it, the graphs produced are invalid. Eventually, we will resolve this.
188
+ "Cast",
189
+ ]
190
+
191
+
192
+ def initial_checking(model, disable_shape_infer):
193
+ func_infer_shape = None
194
+ if not disable_shape_infer and pv.Version(onnx.__version__) >= pv.Version("1.2"):
195
+ try:
196
+ from onnx.shape_inference import infer_shapes
197
+
198
+ func_infer_shape = infer_shapes
199
+ finally:
200
+ pass
201
+
202
+ if not isinstance(model, onnx_proto.ModelProto):
203
+ raise ValueError(
204
+ "Expected model type is an ONNX ModelProto but got %s" % type(model)
205
+ )
206
+
207
+ if func_infer_shape is not None:
208
+ model = func_infer_shape(model)
209
+
210
+ is_fp16_ready_flag = check_if_fp16_ready(model.graph)
211
+
212
+ return model, func_infer_shape, is_fp16_ready_flag
213
+
214
+
215
+ def convert_float_to_float16(
216
+ model,
217
+ min_positive_val=1e-7,
218
+ max_finite_val=1e4,
219
+ keep_io_types=False,
220
+ disable_shape_infer=False,
221
+ op_block_list=None,
222
+ node_block_list=None,
223
+ check_fp16_ready=True,
224
+ ):
225
+
226
+ # create blocklists
227
+ if op_block_list is None:
228
+ op_block_list = DEFAULT_OP_BLOCK_LIST
229
+ if node_block_list is None:
230
+ node_block_list = []
231
+ op_block_list = set(op_block_list)
232
+ node_block_list = set(node_block_list)
233
+
234
+ global_input_name_dict = (
235
+ {}
236
+ ) # key: input name, value: new output name after Cast node
237
+ # basic checking, including shape inference
238
+ model, func_infer_shape, is_fp16_ready_flag = initial_checking(
239
+ model, disable_shape_infer
240
+ )
241
+ if is_fp16_ready_flag and check_fp16_ready:
242
+ raise ValueError(
243
+ "The model is already converted to float16, if convert again, the model might be wrong. \n If you are sure to convert again, please set check_fp16_ready=False."
244
+ )
245
+
246
+ graph_stack = [model.graph]
247
+
248
+ is_top_level = True
249
+ while graph_stack:
250
+ next_level = []
251
+ for curr_graph in graph_stack:
252
+ process_graph_input(
253
+ curr_graph, is_top_level, keep_io_types, global_input_name_dict
254
+ )
255
+ value_info_block_list = process_tensor_in_node(
256
+ curr_graph,
257
+ op_block_list,
258
+ node_block_list,
259
+ min_positive_val,
260
+ max_finite_val,
261
+ )
262
+ process_value_info(curr_graph, value_info_block_list)
263
+ process_node_in_block_list(
264
+ curr_graph, global_input_name_dict, op_block_list, node_block_list
265
+ )
266
+ process_initializers(
267
+ curr_graph,
268
+ op_block_list,
269
+ node_block_list,
270
+ min_positive_val,
271
+ max_finite_val,
272
+ )
273
+ process_graph_output(curr_graph, is_top_level, keep_io_types)
274
+ sub_graph_list = get_next_level_graph(
275
+ curr_graph, op_block_list, node_block_list
276
+ )
277
+ if len(sub_graph_list) > 0:
278
+ next_level.extend(sub_graph_list)
279
+
280
+ if not is_top_level:
281
+ process_node_input_output(curr_graph, global_input_name_dict)
282
+ is_top_level = False # Going to process sub-graph
283
+ graph_stack = next_level
284
+
285
+ remove_unnecessary_cast_node(model.graph)
286
+
287
+ # Topologically sort the graph
288
+ # NOTE: We do not perform another round of optimization as the model is already optimized
289
+ graph = gs.import_onnx(model)
290
+ graph.toposort()
291
+ model = gs.export_onnx(graph)
292
+
293
+ return model
294
+
295
+
296
+ # Change the input/output of the node to the new output name after Cast node for sub-graph
297
+ # because sub-graphs have no value_info entries to start from
298
+ def process_node_input_output(
299
+ graph: onnx_proto.GraphProto, global_input_name_dict: dict
300
+ ):
301
+ for node in graph.node:
302
+ for i, input_name in enumerate(node.input):
303
+ if input_name in global_input_name_dict:
304
+ node.input[i] = global_input_name_dict[input_name]
305
+ for i, output_name in enumerate(node.output):
306
+ if output_name in global_input_name_dict:
307
+ node.output[i] = global_input_name_dict[output_name]
308
+
309
+
310
+ def process_graph_input(
311
+ graph: onnx_proto.GraphProto,
312
+ is_top_level: bool,
313
+ is_io_fp32: bool,
314
+ global_input_name_dict: dict,
315
+ ):
316
+ # The input dtype is float32, need to cast to fp16
317
+ if is_top_level and is_io_fp32:
318
+ for graph_input in graph.input: # n_input is ValueInfoProto
319
+ if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
320
+ downstream_nodes = find_downstream_node_by_input_name(
321
+ graph, graph_input.name
322
+ )
323
+ for d_node in downstream_nodes:
324
+ # More than one node may consume the model input, so we only create
325
+ # a single cast node, and then reuse this node when needed.
326
+ cast_exists = graph_input.name in global_input_name_dict
327
+ if cast_exists:
328
+ cast_node_output_name = global_input_name_dict[graph_input.name]
329
+ else:
330
+ cast_node_output_name = graph_input.name + "_fp16"
331
+ add_cast_node(
332
+ graph,
333
+ [graph_input.name],
334
+ [cast_node_output_name],
335
+ cast_node_output_name, # Set node name same as output name
336
+ FLOAT16,
337
+ )
338
+ add_new_value_info(
339
+ graph,
340
+ graph_input,
341
+ cast_node_output_name,
342
+ onnx_proto.TensorProto.FLOAT16,
343
+ )
344
+ for i, input_name in enumerate(d_node.input):
345
+ if input_name == graph_input.name:
346
+ d_node.input[i] = (
347
+ cast_node_output_name # Change the input of the second node
348
+ )
349
+ global_input_name_dict[graph_input.name] = (
350
+ cast_node_output_name
351
+ )
352
+
353
+ # For the sub-graph, don't do cast
354
+ else: # Change the input dtype to fp16 without any cast
355
+ for graph_input in graph.input:
356
+ if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
357
+ graph_input.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
358
+
359
+
360
+ def process_graph_output(
361
+ graph: onnx_proto.GraphProto, is_top_level: bool, is_io_fp32: bool
362
+ ):
363
+ if is_top_level and is_io_fp32: # the output dtype is float32, need to cast to fp16
364
+ for i, graph_output in enumerate(graph.output):
365
+ if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
366
+ new_producer_name = graph_output.name + "_fp16"
367
+ original_name = graph_output.name # The correct output name
368
+
369
+ # Get the node(s) that produce the model output
370
+ # These will most likely be fp16, but could be fp32 if the previous node is in block_list
371
+ upstream_nodes = find_upstream_node_by_output_name(graph, original_name)
372
+ assert len(upstream_nodes) == 1 # Should be only one node
373
+
374
+ producer_node = upstream_nodes[0]
375
+
376
+ for i, output_name in enumerate(producer_node.output):
377
+ if output_name == original_name:
378
+ producer_node.output[i] = new_producer_name
379
+
380
+ cast_node_name = new_producer_name + "_input_cast" + str(i)
381
+ add_cast_node(
382
+ graph,
383
+ [new_producer_name],
384
+ [original_name],
385
+ cast_node_name,
386
+ onnx_proto.TensorProto.FLOAT,
387
+ )
388
+ for value_info in graph.value_info:
389
+ if original_name == value_info.name:
390
+ value_info.type.tensor_type.elem_type = (
391
+ onnx_proto.TensorProto.FLOAT
392
+ )
393
+
394
+ # Get the node(s) that consume the model output
395
+ downstream_nodes = find_downstream_node_by_input_name(
396
+ graph,
397
+ original_name,
398
+ include_subgraphs=False,
399
+ )
400
+
401
+ # It is possible that the producer node is also input to downstream nodes
402
+ # So, we update the inputs of these downstream nodes
403
+ for d_node in downstream_nodes:
404
+ for i, input_name in enumerate(d_node.input):
405
+ if input_name == original_name:
406
+ d_node.input[i] = new_producer_name
407
+
408
+ else: # change the output dtype to fp16 in tensor
409
+ for graph_output in graph.output:
410
+ if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
411
+ graph_output.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
412
+
413
+
414
+ def process_node_in_block_list(
415
+ graph: onnx_proto.GraphProto,
416
+ global_input_name_dict: dict,
417
+ op_block_list: list,
418
+ node_block_list: list,
419
+ ):
420
+ # NB: Important to create a copy of the nodes in the graph to avoid modifying
421
+ # the graph in-place while iterating (causing an infinite loop)
422
+ for node in list(graph.node):
423
+ if (node.op_type in op_block_list) or (node.name in node_block_list):
424
+ insert_cast32_before_node(graph, node, global_input_name_dict)
425
+ insert_cast16_after_node(graph, node, global_input_name_dict)
426
+
427
+
428
+ # Todo: global_input_name_dict still not fill value
429
+ def insert_cast32_before_node(
430
+ graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
431
+ ):
432
+ for i, input_name in enumerate(node.input):
433
+ for value_info in itertools.chain(graph.value_info, graph.input):
434
+ if input_name == value_info.name:
435
+ if (
436
+ value_info.type.tensor_type.elem_type
437
+ != onnx_proto.TensorProto.FLOAT16
438
+ ):
439
+ break
440
+ cast_output_name = node.name + "_input_cast_" + str(i)
441
+ add_new_value_info(
442
+ graph, value_info, cast_output_name, onnx_proto.TensorProto.FLOAT
443
+ )
444
+ cast_node_name = node.name + "_input_cast" + str(i)
445
+ add_cast_node(
446
+ graph,
447
+ [input_name],
448
+ [cast_output_name],
449
+ cast_node_name,
450
+ onnx_proto.TensorProto.FLOAT,
451
+ )
452
+ node.input[i] = cast_output_name
453
+ break
454
+
455
+
456
+ # Todo: global_input_name_dict still not fill value
457
+ def insert_cast16_after_node(
458
+ graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
459
+ ):
460
+ for i, output_name in enumerate(node.output):
461
+ for value_info in itertools.chain(graph.value_info, graph.output):
462
+ if output_name == value_info.name:
463
+ if (
464
+ value_info.type.tensor_type.elem_type
465
+ != onnx_proto.TensorProto.FLOAT
466
+ ):
467
+ break
468
+ cast_input_name = node.name + "_output_cast_" + str(i)
469
+ add_new_value_info(
470
+ graph, value_info, cast_input_name, onnx_proto.TensorProto.FLOAT
471
+ )
472
+ value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
473
+ cast_node_name = node.name + "_output_cast" + str(i)
474
+ add_cast_node(
475
+ graph,
476
+ [cast_input_name],
477
+ [output_name],
478
+ cast_node_name,
479
+ onnx_proto.TensorProto.FLOAT16,
480
+ )
481
+ node.output[i] = cast_input_name
482
+ break
483
+
484
+
485
+ # Process tensor data in attribute of the node
486
+ def process_tensor_in_node(
487
+ graph: onnx_proto.GraphProto,
488
+ op_block_list: list,
489
+ node_block_list: list,
490
+ min_positive_val,
491
+ max_finite_val,
492
+ ):
493
+ value_info_block_list = set() # This is for later use, not in this step
494
+ for node in graph.node:
495
+ # NOTE: "Cast" operation cannot change its output type because it is strongly typed.
496
+ if (
497
+ (node.op_type in op_block_list)
498
+ or (node.name in node_block_list)
499
+ or (node.op_type == "Cast")
500
+ ):
501
+ # if (node.op_type in op_block_list) or (node.name in node_block_list):
502
+ # Only need to block the output value_info changing
503
+ for output_name in node.output:
504
+ value_info_block_list.add(output_name)
505
+ else:
506
+ for attr in node.attribute:
507
+ # one tensor
508
+ if attr.t.data_type == onnx_proto.TensorProto.FLOAT:
509
+ attr.t.CopyFrom(
510
+ convert_tensor_float_to_float16(
511
+ attr.t, min_positive_val, max_finite_val
512
+ )
513
+ )
514
+ # list of tensor
515
+ for t in attr.tensors:
516
+ if t.data_type == onnx_proto.TensorProto.FLOAT:
517
+ t.CopyFrom(
518
+ convert_tensor_float_to_float16(
519
+ t, min_positive_val, max_finite_val
520
+ )
521
+ )
522
+ return value_info_block_list
523
+
524
+
525
+ # Change all the value info type from float32 to float16 if not in block list
526
+ def process_value_info(graph: onnx_proto.GraphProto, value_info_block_list: list):
527
+ for value_info in graph.value_info:
528
+ if value_info.name in value_info_block_list:
529
+ continue
530
+ else:
531
+ if value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
532
+ value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
533
+
534
+
535
+ # Initializer is 'edge' type, so doesn't have value_info
536
+ def process_initializers(
537
+ graph: onnx_proto.GraphProto,
538
+ op_block_list,
539
+ node_block_list,
540
+ min_positive_val,
541
+ max_finite_val,
542
+ ):
543
+ # Find the input of the block node, don't need to change this kind of initializer
544
+ initializer_block_list = set()
545
+ for node in graph.node:
546
+ if (node.op_type in op_block_list) or (node.name in node_block_list):
547
+ for (
548
+ input_name
549
+ ) in (
550
+ node.input
551
+ ): # some is initializer, some is value_info, can't distinguish but doesn't matter
552
+ initializer_block_list.add(input_name)
553
+ # Process initializers
554
+ for initializer in graph.initializer:
555
+ if initializer.name not in initializer_block_list:
556
+ if initializer.data_type == onnx_proto.TensorProto.FLOAT:
557
+ convert_tensor_float_to_float16(
558
+ initializer, min_positive_val, max_finite_val
559
+ )
560
+
561
+
562
+ def get_next_level_graph(
563
+ graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list
564
+ ):
565
+ sub_graph_list = []
566
+ for node in graph.node:
567
+ if node.op_type in op_block_list or node.name in node_block_list:
568
+ continue
569
+ for attr in node.attribute:
570
+ # Check if sub-graph exist
571
+ if len(attr.g.node) > 0: # single sub-graph
572
+ sub_graph_list.append(attr.g)
573
+ for g in attr.graphs:
574
+ if len(g.node) > 0: # multiple sub-graphs
575
+ sub_graph_list.append(g)
576
+ return sub_graph_list
577
+
578
+
579
+ def add_cast_node(
580
+ graph: onnx_proto.GraphProto,
581
+ inputs: list,
582
+ outputs: list,
583
+ node_name: str,
584
+ to_type: int,
585
+ ):
586
+ new_node = [helper.make_node("Cast", inputs, outputs, to=to_type, name=node_name)]
587
+ graph.node.extend(new_node)
588
+
589
+
590
+ def add_new_value_info(
591
+ graph: onnx_proto.GraphProto,
592
+ exist_value_info: onnx_proto.ValueInfoProto,
593
+ name: str,
594
+ dtype: int,
595
+ ):
596
+ new_value_info = graph.value_info.add()
597
+ new_value_info.CopyFrom(exist_value_info)
598
+ new_value_info.name = name
599
+ new_value_info.type.tensor_type.elem_type = dtype
600
+
601
+
602
+ # Find the node that has the specified output name
603
+ def find_upstream_node_by_output_name(graph: onnx_proto.GraphProto, output_name: str):
604
+ nodes = []
605
+ for node in graph.node:
606
+ if output_name in node.output:
607
+ nodes.append(node)
608
+ assert len(nodes) <= 1  # Expect at most one producing node
609
+ return nodes
610
+
611
+
612
+ # Find the node that has the specified input name, including in subgraphs
613
+ def find_downstream_node_by_input_name(
614
+ graph: onnx_proto.GraphProto, input_name: str, include_subgraphs=True
615
+ ):
616
+ nodes = []
617
+
618
+ # Check nodes in current graph
619
+ for node in graph.node:
620
+ if input_name in node.input:
621
+ nodes.append(node)
622
+
623
+ if not include_subgraphs:
624
+ continue
625
+
626
+ # Recursively check subgraphs in node attributes
627
+ for attr in node.attribute:
628
+ if attr.type == onnx_proto.AttributeProto.GRAPH:
629
+ # Single subgraph
630
+ if len(attr.g.node) > 0:
631
+ nodes.extend(find_downstream_node_by_input_name(attr.g, input_name))
632
+
633
+ # Multiple subgraphs
634
+ if attr.type == onnx_proto.AttributeProto.GRAPHS:
635
+ for g in attr.graphs:
636
+ if len(g.node) > 0:
637
+ nodes.extend(find_downstream_node_by_input_name(g, input_name))
638
+
639
+ return nodes
640
+
641
+
642
+ # Remove identity node
643
+ def remove_identity_node_from_model(model: onnx_proto.ModelProto):
644
+ remove_identity_node_from_graph(model.graph)
645
+ try:
646
+ from onnx.shape_inference import infer_shapes
647
+
648
+ func_infer_shape = infer_shapes
649
+ model = func_infer_shape(model)
650
+ return model
651
+ finally:
652
+ pass
653
+
654
+
655
+ # Remove identity node
656
+ def remove_identity_node_from_graph(graph: onnx_proto.GraphProto):
657
+ for curr_node in graph.node:
658
+ if curr_node.op_type == "Identity":
659
+ for input_name in curr_node.input:
660
+ upstream_nodes = find_upstream_node_by_output_name(graph, input_name)
661
+ for u_node in upstream_nodes:
662
+ if u_node is not None:
663
+ u_node.output[0] = curr_node.output[0]
664
+ graph.node.remove(curr_node)
665
+
666
+
667
+ def convert_float_to_float16_model_path(
668
+ model_path, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False
669
+ ):
670
+ """
671
+ Convert tensor float type in the ONNX Model to tensor float16.
672
+ *It is to fix an issue that infer_shapes func cannot be used to infer >2GB models.
673
+ *But this function can be applied to all model sizes.
674
+ :param model_path: ONNX Model path
675
+ :return: converted ONNX ModelProto object
676
+ Examples
677
+ ::
678
+ #Convert to ONNX ModelProto object and save model binary file:
679
+ from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
680
+ new_onnx_model = convert_float_to_float16_model_path('model.onnx')
681
+ onnx.save(new_onnx_model, 'new_model.onnx')
682
+ """
683
+
684
+ disable_shape_infer = False
685
+ if pv.Version(onnx.__version__) >= pv.Version("1.8"):
686
+ try:
687
+ # infer_shapes_path can be applied to all model sizes
688
+ from onnx.shape_inference import infer_shapes_path
689
+ import tempfile
690
+ import os
691
+
692
+ # shape_infer_model_path should be in the same folder of model_path
693
+ with tempfile.NamedTemporaryFile(
694
+ dir=os.path.dirname(model_path)
695
+ ) as tmpfile:
696
+ shape_infer_model_path = tmpfile.name
697
+ infer_shapes_path(model_path, shape_infer_model_path)
698
+ model = onnx.load(shape_infer_model_path)
699
+ disable_shape_infer = True
700
+ finally:
701
+ pass
702
+ if not disable_shape_infer:
703
+ model = onnx.load(model_path)
704
+ return convert_float_to_float16(
705
+ model, min_positive_val, max_finite_val, keep_io_types, disable_shape_infer
706
+ )
707
+
708
+
709
+ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
710
+ # 1. find all cast nodes in the graph
711
+ cast_node_list = []
712
+ input_name_to_cast_node_dict = {}
713
+ output_name_to_cast_node_dict = {}
714
+ # using name as key to point to a node. because node object cannot be key
715
+ name_to_node_dict = {}
716
+ for node in graph_proto.node:
717
+ if node.op_type == "Cast":
718
+ # if node.name not in ["graph_input_cast0", "graph_output_cast0"]:
719
+ cast_node_list.append(node)
720
+
721
+ name_to_node_dict[node.name] = node
722
+ for input_name in node.input:
723
+ input_name_to_cast_node_dict[input_name] = node
724
+ for output_name in node.output:
725
+ output_name_to_cast_node_dict[output_name] = node
726
+
727
+ # 2. find upstream and downstream node of the cast node
728
+ cast_node_upstream_dict = {} # mapping cast node(name) to its upstream node
729
+ cast_node_downstream_dict = {} # mapping cast node(name) to its downstream node
730
+ for current_node in graph_proto.node:
731
+ # find the downstream node(s)
732
+ for input_name in current_node.input:
733
+ if input_name in output_name_to_cast_node_dict:
734
+ # found the downstream node of the cast node, might be multiple
735
+ cast_node = output_name_to_cast_node_dict[input_name]
736
+ if cast_node.name not in cast_node_downstream_dict:
737
+ cast_node_downstream_dict[cast_node.name] = current_node
738
+ else: # already exists one downstream node, make it a list
739
+ existing_downstream_nodes = cast_node_downstream_dict[
740
+ cast_node.name
741
+ ]
742
+ if isinstance(existing_downstream_nodes, list):
743
+ existing_downstream_nodes.append(current_node)
744
+ else: # make a list
745
+ existing_downstream_nodes = [
746
+ existing_downstream_nodes,
747
+ current_node,
748
+ ]
749
+ cast_node_downstream_dict[cast_node.name] = (
750
+ existing_downstream_nodes
751
+ )
752
+ # find the upstream node
753
+ for output_name in current_node.output:
754
+ if output_name in input_name_to_cast_node_dict:
755
+ # found the upstream node of the cast node, should be unique
756
+ cast_node = input_name_to_cast_node_dict[output_name]
757
+ cast_node_upstream_dict[cast_node.name] = current_node
758
+
759
+ # 3. remove the cast node which upstream is 'Constant'
760
+ for cast_node_name, upstream_node in cast_node_upstream_dict.items():
761
+ cast_node = name_to_node_dict[cast_node_name]
762
+ if upstream_node.op_type == "Constant":
763
+ cast_node_list.remove(cast_node)
764
+
765
+ # 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32.
766
+ remove_candidate = []
767
+
768
+ name_to_value_info = {
769
+ value_info.name: value_info
770
+ for value_info in itertools.chain(graph_proto.value_info, graph_proto.input)
771
+ }
772
+
773
+ def get_type(name: str) -> Optional[int]:
774
+ if name in name_to_value_info:
775
+ return name_to_value_info[name].type
776
+ else:
777
+ # `name` has no value info.
778
+ return None
779
+
780
+ for cast_node_name, downstream_node in cast_node_downstream_dict.items():
781
+ cast_node = name_to_node_dict[cast_node_name]
782
+ if len(cast_node.input) != 1:
783
+ raise RuntimeError(
784
+ f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}."
785
+ )
786
+
787
+ input_type = get_type(cast_node.input[0])
788
+ if input_type != onnx_proto.TensorProto.FLOAT:
789
+ continue
790
+ if isinstance(downstream_node, list):
791
+ for dn in downstream_node:
792
+ if (
793
+ dn.op_type == "Cast"
794
+ and dn.attribute[0].i == 32
795
+ and cast_node.attribute[0].i == 16
796
+ and dn in cast_node_list
797
+ and cast_node in cast_node_list
798
+ ):
799
+ remove_candidate.append((cast_node, dn))
800
+ else:
801
+ if (
802
+ downstream_node.op_type == "Cast"
803
+ and cast_node.attribute[0].i == FLOAT16
804
+ and downstream_node.attribute[0].i == FLOAT32
805
+ and downstream_node in cast_node_list
806
+ and cast_node in cast_node_list
807
+ ):
808
+ remove_candidate.append((cast_node, downstream_node))
809
+
810
+ # 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to
811
+ # "upstream --fp32--> downstream".
812
+ for cast_node_pair in remove_candidate:
813
+ first_cast_node = cast_node_pair[0]
814
+ second_cast_node = cast_node_pair[1]
815
+ upstream_node = cast_node_upstream_dict.get(first_cast_node.name)
816
+ downstream_node = cast_node_downstream_dict.get(second_cast_node.name)
817
+ if upstream_node is None and downstream_node is not None:
818
+ # The upstream_node should be graph input
819
+ out = first_cast_node.input[0]
820
+ for i, input_name in enumerate(downstream_node.input):
821
+ for output_name in second_cast_node.output:
822
+ if input_name == output_name:
823
+ # change the input as the upstream node's output
824
+ downstream_node.input[i] = out
825
+ elif upstream_node is not None and downstream_node is None:
826
+ raise ValueError(
827
+ "The downstream node of the second cast node should be graph output"
828
+ )
829
+ else:
830
+ # find the upstream node's output to first_cast_node
831
+ out = None
832
+ for output_name in upstream_node.output:
833
+ if output_name == first_cast_node.input[0]:
834
+ out = output_name
835
+ break
836
+ # find the downstream node's input as second_cast_node's output
837
+ for i, input_name in enumerate(downstream_node.input):
838
+ for output_name in second_cast_node.output:
839
+ if input_name == output_name:
840
+ # change the input as the upstream node's output
841
+ downstream_node.input[i] = out
842
+
843
+ # 6. remove the cast node pair
844
+ for cast_node_pair in remove_candidate:
845
+ graph_proto.node.remove(cast_node_pair[0])
846
+ graph_proto.node.remove(cast_node_pair[1])
847
+
848
+
849
+ # Check if the model is already converted to float16
850
+ def check_if_fp16_ready(graph_proto):
851
+ # Check graph input and output
852
+ is_value_info_fp16 = False
853
+ for value_info in itertools.chain(
854
+ graph_proto.output, graph_proto.input, graph_proto.value_info
855
+ ):
856
+ if value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT16:
857
+ is_value_info_fp16 = True
858
+ break
859
+
860
+ # Check initializer
861
+ is_initializer_fp16 = False
862
+ for initializer in graph_proto.initializer:
863
+ if initializer.data_type == onnx_proto.TensorProto.FLOAT16:
864
+ is_initializer_fp16 = True
865
+ break
866
+
867
+ # Check cast node
868
+ has_cast_node_fp16 = False
869
+ for node in graph_proto.node:
870
+ if node.op_type == "Cast" and node.attribute[0].i == FLOAT16:
871
+ has_cast_node_fp16 = True
872
+ break
873
+
874
+ # Any of above flags is True, return True
875
+ if is_value_info_fp16 or is_initializer_fp16 or has_cast_node_fp16:
876
+ return True # already converted to float16
877
+ else:
878
+ return False # not converted to float16 yet
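
`float16.py` above provides the fp16 converter that `quantize_extended.py` (added below) drives for its `fp16` mode. A minimal usage sketch, assuming an fp32 input model that fits comfortably in memory; the output filename is illustrative:

```python
import onnx

import float16

fp32_model = onnx.load("embeddinggemma-300m/model.onnx")
fp16_model = float16.convert_float_to_float16(fp32_model, keep_io_types=True)  # keep fp32 graph inputs/outputs
onnx.save(fp16_model, "embeddinggemma-300m/onnx/model_fp16.onnx")  # illustrative output path
```

For models too large to shape-infer in memory, the file also exposes `convert_float_to_float16_model_path`, which runs shape inference against the on-disk path first.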
quantize_extended.py ADDED
@@ -0,0 +1,189 @@
1
+ """
2
+ Script to quantize ONNX models to additional formats: int4, int8, etc.
3
+ Based on transformers.js/scripts/quantize.py, extended for more quantization options.
4
+ """
5
+
6
+ from enum import Enum
7
+ from tqdm import tqdm
8
+ from typing import Set, List, Optional
9
+ import onnx
10
+ import os
11
+ from dataclasses import dataclass, field
12
+ from transformers import HfArgumentParser
13
+ from onnxruntime.quantization import QuantType, QuantizationMode
14
+ from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
15
+ from onnxruntime.quantization.registry import IntegerOpsRegistry
16
+ from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
17
+ from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
18
+ import float16
19
+ import utils
20
+
21
+ class QuantMode(Enum):
22
+ FP16 = "fp16"
23
+ Q8 = "q8"
24
+ QI8 = "int8"
25
+ QU8 = "uint8"
26
+ Q4 = "q4"
27
+ Q4F16 = "q4f16"
28
+ BNB4 = "bnb4"
29
+ INT4 = "int4"
30
+ INT8 = "int8"
31
+
32
+ QUANTIZE_SUFFIX_MAPPING = {
33
+ QuantMode.Q8: "quantized",
34
+ QuantMode.INT4: "int4",
35
+ QuantMode.INT8: "int8",
36
+ }
37
+
38
+ QUANTIZE_OPTIONS = tuple(x.value for x in QuantMode)
39
+ QUINT8_OPS = (
40
+ "Conv",
41
+ "GroupQueryAttention",
42
+ "MultiHeadAttention",
43
+ )
44
+
45
+ @dataclass
46
+ class IOArguments:
47
+ input_folder: str = field(metadata={"help": "Path of the input folder containing the .onnx models to quantize"})
48
+ output_folder: str = field(metadata={"help": "Path of the output folder where the quantized .onnx models will be saved"})
49
+
50
+ @dataclass
51
+ class QuantizationArguments:
52
+ modes: QuantMode = field(default=QUANTIZE_OPTIONS, metadata={"help": "Quantization mode to use.", "choices": QUANTIZE_OPTIONS, "nargs": "+",})
53
+ per_channel: bool = field(default=None, metadata={"help": "Whether to quantize weights per channel"})
54
+ reduce_range: bool = field(default=None, metadata={"help": "Whether to quantize weights with 7-bits."})
55
+ block_size: int = field(default=None, metadata={"help": "Block size for blockwise quantization."})
56
+ is_symmetric: bool = field(default=True, metadata={"help": "Indicate whether to quantize the model symmetrically"})
57
+ accuracy_level: int = field(default=None, metadata={"help": "Accuracy level of the 4-bit quantized MatMul computation."})
58
+ quant_type: int = field(default=MatMulBnb4Quantizer.NF4, metadata={"help": "Quantization data type. 0: FP4, 1: NF4", "choices": [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],})
59
+ op_block_list: List[str] = field(default=None, metadata={"help": "List of operators to exclude from quantization.", "nargs": "+",})
60
+
61
+ def quantize_int4(
62
+ model: onnx.ModelProto,
63
+ save_path: str,
64
+ block_size: int = 32,
65
+ is_symmetric: bool = True,
66
+ accuracy_level: int = 4,
67
+ ):
68
+ """
69
+ Quantize the weights of the model from float32 to 4-bit int using MatMulNBitsQuantizer
70
+ """
71
+ quantizer = MatMulNBitsQuantizer(
72
+ model=model,
73
+ block_size=block_size,
74
+ is_symmetric=is_symmetric,
75
+ accuracy_level=accuracy_level,
76
+ )
77
+ quantizer.process()
78
+ utils.check_and_save_model(quantizer.model.model, save_path)
79
+ return quantizer.model.model
80
+
81
+ def quantize_int8(
82
+ model: onnx.ModelProto,
83
+ save_path: str,
84
+ per_channel: bool = False,
85
+ reduce_range: bool = False,
86
+ weight_type: QuantType = QuantType.QInt8,
87
+ op_block_list: Optional[List[str]] = None,
88
+ ):
89
+ """
90
+ Quantize the weights of the model from float32 to int8
91
+ """
92
+ op_types_to_quantize = set(IntegerOpsRegistry.keys())
93
+ if op_block_list is not None:
94
+ op_types_to_quantize.difference_update(op_block_list)
95
+
96
+ quantizer = ONNXQuantizer(
97
+ model,
98
+ per_channel,
99
+ reduce_range,
100
+ mode=QuantizationMode.IntegerOps,
101
+ static=False,
102
+ weight_qType=weight_type,
103
+ activation_qType=QuantType.QUInt8,
104
+ tensors_range=None,
105
+ nodes_to_quantize=[],
106
+ nodes_to_exclude=[],
107
+ op_types_to_quantize=op_types_to_quantize,
108
+ extra_options=dict(EnableSubgraph=True, MatMulConstBOnly=True),
109
+ )
110
+ quantizer.quantize_model()
111
+ utils.check_and_save_model(quantizer.model.model, save_path)
112
+ return quantizer.model.model
113
+
114
+ def main():
115
+ parser = HfArgumentParser((IOArguments, QuantizationArguments))
116
+ io_args, quantization_args = parser.parse_args_into_dataclasses()
117
+ input_folder = io_args.input_folder
118
+ output_folder = io_args.output_folder
119
+ if not quantization_args.modes:
120
+ raise ValueError("At least one quantization mode must be specified")
121
+
122
+ if not os.path.exists(input_folder):
123
+ raise ValueError(f"Input folder {input_folder} does not exist")
124
+
125
+ model_names_or_paths = [
126
+ os.path.join(input_folder, file)
127
+ for file in os.listdir(input_folder)
128
+ if file.endswith(".onnx")
129
+ ]
130
+ if not model_names_or_paths:
131
+ raise ValueError(f"No .onnx models found in {input_folder}")
132
+
133
+ os.makedirs(output_folder, exist_ok=True)
134
+
135
+ for model_path in tqdm(model_names_or_paths, desc="Models"):
136
+ file_name_without_extension = os.path.splitext(os.path.basename(model_path))[0]
137
+ model = onnx.load_model(model_path)
138
+ for mode in tqdm(quantization_args.modes, desc="Modes"):
139
+ try:
140
+ suffix = QUANTIZE_SUFFIX_MAPPING.get(QuantMode(mode), mode)
141
+ except Exception:
142
+ suffix = mode
143
+ save_path = os.path.join(output_folder, f"{file_name_without_extension}_{suffix}.onnx")
144
+ mode_enum = QuantMode(mode)
145
+ try:
146
+ if mode_enum == QuantMode.FP16:
147
+ model_fp16 = float16.convert_float_to_float16(
+     model,
+     keep_io_types=True,
+     disable_shape_infer=False,
+     op_block_list=quantization_args.op_block_list or [],
+ )
+ utils.check_and_save_model(model_fp16, save_path)
153
+
154
+ elif mode_enum == QuantMode.INT4 or mode_enum == QuantMode.Q4:
155
+ quantize_int4(
156
+ model,
157
+ save_path,
158
+ block_size=quantization_args.block_size or 32,
159
+ is_symmetric=quantization_args.is_symmetric,
160
+ accuracy_level=quantization_args.accuracy_level or 0,
161
+ )
162
+
163
+ elif mode_enum == QuantMode.INT8 or mode_enum == QuantMode.QI8:
164
+ quantize_int8(
165
+ model,
166
+ save_path,
167
+ per_channel=quantization_args.per_channel or False,
168
+ reduce_range=quantization_args.reduce_range or False,
169
+ weight_type=QuantType.QInt8,
170
+ op_block_list=quantization_args.op_block_list,
171
+ )
172
+
173
+ elif mode_enum == QuantMode.Q8:
174
+ quantize_int8(
175
+ model,
176
+ save_path,
177
+ per_channel=quantization_args.per_channel or False,
178
+ reduce_range=quantization_args.reduce_range or False,
179
+ weight_type=QuantType.QUInt8,
180
+ op_block_list=quantization_args.op_block_list,
181
+ )
182
+
183
+ # Add other modes as needed (Q4F16, BNB4, QU8, etc.)
184
+ except Exception as e:
185
+ print(f"[WARN] Quantization mode '{mode}' failed for model '{model_path}': {e}")
186
+ continue
187
+
188
+ if __name__ == "__main__":
189
+ main()
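
`quantize_extended.py` is normally invoked from the command line via `HfArgumentParser` (flags such as `--input_folder`, `--output_folder`, and `--modes int4 int8` follow the dataclass fields above), but its helpers can also be called directly. A hypothetical programmatic sketch producing the int4/int8 artifacts added in this commit; the paths are illustrative:

```python
import onnx

from quantize_extended import quantize_int4, quantize_int8

src = "embeddinggemma-300m/onnx/model.onnx"  # illustrative fp32 source model

# 4-bit weight-only quantization via MatMulNBitsQuantizer.
quantize_int4(onnx.load_model(src), "embeddinggemma-300m/onnx/model_int4.onnx",
              block_size=32, is_symmetric=True, accuracy_level=4)

# 8-bit dynamic (integer-ops) weight quantization.
quantize_int8(onnx.load_model(src), "embeddinggemma-300m/onnx/model_int8.onnx",
              per_channel=False, reduce_range=False)
```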
utils.py ADDED
@@ -0,0 +1,67 @@
1
+ import onnx
2
+ from typing import Optional, Union
3
+ from pathlib import Path
4
+ import os
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ # https://github.com/onnx/onnx/pull/6556
11
+ MAXIMUM_PROTOBUF = 2147483648 # 2GiB
12
+
13
+
14
+ def strict_check_model(model_or_path: Union[onnx.ModelProto, str, Path]):
15
+ try:
16
+ onnx.checker.check_model(model_or_path, full_check=True)
17
+ except Exception as e:
18
+ if "No Op registered for" in str(e):
19
+ pass
20
+ else:
21
+ raise e
22
+
23
+
24
+ def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
25
+ if model.ByteSize() < MAXIMUM_PROTOBUF:
26
+ strict_check_model(model)
27
+ if save_path:
28
+ save_path = Path(save_path).as_posix()
29
+ external_file_name = os.path.basename(save_path) + "_data"
30
+ external_path = os.path.join(os.path.dirname(save_path), external_file_name)
31
+
32
+ if save_path.endswith(".onnx") and os.path.isfile(save_path):
33
+ os.remove(save_path)
34
+ if os.path.isfile(external_path):
35
+ os.remove(external_path)
36
+
37
+ onnx.save(
38
+ model,
39
+ save_path,
40
+ convert_attribute=True,
41
+ )
42
+ elif save_path is not None:
43
+ # path/to/model.onnx
44
+ save_path = Path(save_path).as_posix()
45
+
46
+ external_file_name = os.path.basename(save_path) + "_data"
47
+ # path/to/model.onnx_data
48
+ external_path = os.path.join(os.path.dirname(save_path), external_file_name)
49
+
50
+ if save_path.endswith(".onnx") and os.path.isfile(save_path):
51
+ os.remove(save_path)
52
+ if os.path.isfile(external_path):
53
+ os.remove(external_path)
54
+
55
+ onnx.save(
56
+ model,
57
+ save_path,
58
+ save_as_external_data=True,
59
+ all_tensors_to_one_file=True,
60
+ location=external_file_name,
61
+ convert_attribute=True,
62
+ )
63
+
64
+ else:
65
+ logger.info(
66
+ "Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given."
67
+ )
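
`check_and_save_model` runs the strict ONNX checker for protos under 2GB and, for larger models, saves the weights to a sibling `<name>.onnx_data` file instead of checking in memory. A small sketch of calling it directly; the paths are illustrative:

```python
import onnx

from utils import check_and_save_model

model = onnx.load_model("embeddinggemma-300m/onnx/model.onnx")  # illustrative input
check_and_save_model(model, "embeddinggemma-300m/onnx/model_checked.onnx")
```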