# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a Hugging Face Voxtral checkpoint (e.g. an FP8 quantized one) back to Mistral's native format.

Written into the output directory:

* ``params.json`` -- copied from the unquantized Mistral-format model, with a ``"quantization"``
  section derived from the HF config;
* ``consolidated.safetensors`` -- the renamed (and, for language-model q/k projections, RoPE-permuted)
  tensors;
* the tokenizer files, unless ``--skip_tokenizer`` is passed.
"""
import argparse
import gc
import json
import os
import re

from huggingface_hub import snapshot_download
from safetensors.torch import safe_open, save_file

from transformers import VoxtralConfig


# HF key regex -> Mistral key replacement. Patterns are tried in insertion order and the first one that
# matches wins (see `map_hf_key_to_mistral`); the unmatched remainder of the key (e.g. ".weight") is kept.
# fmt: off
STATE_DICT_MAPPING = {
    r"^language_model\.lm_head": r"output",
    r"^language_model\.model\.norm": r"norm",
    r"^language_model\.model\.embed_tokens": r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm": r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm": r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj": r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj": r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj": r"layers.\1.feed_forward.w3",
    r"language_model.model.embed_tokens": r"tok_embeddings",
    r"audio_tower.conv1": r"mm_whisper_embeddings.whisper_encoder.conv_layers.0",
    r"audio_tower.conv2": r"mm_whisper_embeddings.whisper_encoder.conv_layers.1",
    r"audio_tower.layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.norm",
    r"audio_tower.layers.(\d+).self_attn.(q|k|v)_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.w\2",
    r"audio_tower.layers.(\d+).self_attn.out_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.wo",
    r"audio_tower.layers.(\d+).self_attn_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention_norm",
    r"audio_tower.layers.(\d+).fc(\d+)": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.feed_forward.w\2",
    r"audio_tower.layers.(\d+).final_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.ffn_norm",
    r"multi_modal_projector.linear_1": r"mm_whisper_embeddings.audio_language_projection.0",
    r"multi_modal_projector.linear_2": r"mm_whisper_embeddings.audio_language_projection.2",
}
# fmt: on

# HF tensors that are not written to the Mistral checkpoint at all.
SKIP_KEYS = ["audio_tower.embed_positions.weight"]


def add_quantization_config(config, hf_config: VoxtralConfig):
    """Add the HF quantization config to a Mistral ``params.json`` dict under ``"quantization"``.

    Module names listed under the quantization config's ``"ignore"`` entry are translated to Mistral
    naming with `map_hf_key_to_mistral`. A shallow copy of the HF quantization config is stored, so
    ``hf_config`` itself is left untouched.

    Args:
        config: Mistral params dict; modified in place and returned.
        hf_config: HF Voxtral config carrying a ``quantization_config`` mapping.

    Returns:
        The updated ``config``.

    Raises:
        ValueError: If ``hf_config`` has no quantization config.
    """
    quantization_config = getattr(hf_config, "quantization_config", None)
    if quantization_config is None:
        raise ValueError("The HF config has no `quantization_config`; expected a quantized Voxtral checkpoint.")

    # Copy so that loading/converting does not silently rewrite the caller's HF config object.
    quantization_config = dict(quantization_config)
    if "ignore" in quantization_config:
        quantization_config["ignore"] = [map_hf_key_to_mistral(hf_key) for hf_key in quantization_config["ignore"]]

    config["quantization"] = quantization_config
    return config


def map_hf_key_to_mistral(hf_key):
    """Map a key from HF format to Mistral format.

    The first `STATE_DICT_MAPPING` pattern that matches is applied; if none matches, the key is kept.
    In both cases every ``"weight_scale"`` substring is renamed to ``"qscale_weight"``.
    """
    mistral_key = hf_key
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, hf_key)
        if n_replace > 0:
            mistral_key = new_key
            break
    return mistral_key.replace("weight_scale", "qscale_weight")


def permute_for_mistral_rope(tensor, n_heads, dim1, dim2):
    """Reverse the HF RoPE row permutation to get back to Mistral format.

    The ``dim1`` rows are split into ``n_heads`` groups. Within each group, the first half of the rows
    and the second half are interleaved (row 0 of the first half, row 0 of the second half, row 1 of
    the first half, ...).

    Args:
        tensor: Contiguous tensor with ``dim1 * dim2`` elements, viewed as ``(dim1, dim2)``.
        n_heads: Number of row groups (attention heads).
        dim1: Total number of rows; expected to be divisible by ``2 * n_heads``.
        dim2: Number of columns.

    Returns:
        The permuted ``(dim1, dim2)`` tensor.
    """
    tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    tensor = tensor.transpose(1, 2)
    return tensor.reshape(dim1, dim2)


def _has_one_entry_per_row(tensor, n_rows):
    """Return True when ``tensor``'s first axis has ``n_rows`` entries (always False for 0-d tensors)."""
    return tensor.dim() > 0 and tensor.size(0) == n_rows


def _permute_rows_for_mistral_rope(tensor, n_heads, dim1):
    """Apply `permute_for_mistral_rope` along the first axis of ``tensor`` and restore its shape.

    Used for quantization scales that have one entry per output row, e.g. shaped ``(dim1, 1)`` or ``(dim1,)``.
    """
    original_shape = tensor.shape
    rows = tensor.reshape(dim1, -1)
    return permute_for_mistral_rope(rows, n_heads, dim1, rows.size(1)).reshape(original_shape)


def convert_state_dict(hf_state_dict, config):
    """Convert an HF Voxtral state dict to Mistral format.

    Keys are renamed with `map_hf_key_to_mistral` and `SKIP_KEYS` are dropped. For language-model
    q/k projections, the weights (and any weight scales with one entry per output row) have the RoPE
    row permutation reversed.

    Args:
        hf_state_dict: Mapping from HF tensor names to tensors.
        config: Mistral params dict; ``n_heads``, ``dim``, ``head_dim`` and ``n_kv_heads`` are read.

    Returns:
        A new dict mapping Mistral tensor names to tensors.
    """
    mistral_dict = {}
    num_attention_heads = config["n_heads"]
    hidden_size = config["dim"]
    head_dim = config["head_dim"]
    num_key_value_heads = config["n_kv_heads"]
    key_value_dim = head_dim * num_key_value_heads
    query_dim = head_dim * num_attention_heads

    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue

        mistral_key = map_hf_key_to_mistral(hf_key)

        # Only the language model uses RoPE; audio-tower projections are copied as-is.
        if "language_model" in hf_key:
            if hf_key.endswith("q_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
            elif hf_key.endswith("q_proj.weight_scale") and _has_one_entry_per_row(tensor, query_dim):
                # A scale with one entry per output row must follow the row permutation applied to the
                # weight it scales. A per-tensor scale (single value) needs no permutation.
                tensor = _permute_rows_for_mistral_rope(tensor, num_attention_heads, query_dim)
            elif hf_key.endswith("k_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
            elif hf_key.endswith("k_proj.weight_scale") and _has_one_entry_per_row(tensor, key_value_dim):
                tensor = _permute_rows_for_mistral_rope(tensor, num_key_value_heads, key_value_dim)

        mistral_dict[mistral_key] = tensor

    return mistral_dict


def _resolve_local_dir(path_or_repo, allow_patterns=None):
    """Return ``path_or_repo`` if it is a local directory, otherwise download it from the Hub.

    ``allow_patterns`` restricts which files are downloaded (it is forwarded to `snapshot_download`).
    """
    if os.path.isdir(path_or_repo):
        return path_or_repo
    return snapshot_download(path_or_repo, allow_patterns=allow_patterns)


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
    device="cpu",
):
    """Convert the HF Voxtral config and weights into Mistral format inside ``output_dir``.

    Writes ``params.json`` (the unquantized model's Mistral params plus the HF quantization section)
    and ``consolidated.safetensors``.

    Args:
        input_path_or_repo: Local directory or Hub repo id of the HF Voxtral checkpoint.
        output_dir: Directory to write into; created if missing.
        unquantized_model_path: Local directory or Hub repo id that provides the Mistral ``params.json``.
        device: Device the safetensors shards are loaded onto; defaults to ``"cpu"`` so no GPU is required.

    Raises:
        ValueError: If ``unquantized_model_path`` is not given.
        FileNotFoundError: If the input checkpoint contains no ``.safetensors`` files.
    """
    # Fail fast, before any download: params.json can only come from the unquantized model.
    if unquantized_model_path is None:
        raise ValueError(
            "`unquantized_model_path` is required: params.json is read from the unquantized Mistral-format model."
        )

    print("Converting HF Voxtral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    # Load the HF Voxtral model
    print(f"Loading HF Voxtral model from {input_path_or_repo}...")
    hf_config = VoxtralConfig.from_pretrained(input_path_or_repo)
    # Only the weight shards are needed from the input; the config was loaded above.
    local_path = _resolve_local_dir(input_path_or_repo, allow_patterns=["*.safetensors"])

    # Convert config: only params.json is fetched from the unquantized model, not its weights.
    unquantized_dir = _resolve_local_dir(unquantized_model_path, allow_patterns=["params.json"])
    with open(os.path.join(unquantized_dir, "params.json"), "r", encoding="utf-8") as f:
        config = json.load(f)
    config = add_quantization_config(config, hf_config)
    with open(os.path.join(output_dir, "params.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Convert state dict
    print("Converting state dict...")
    tensor_files = sorted(name for name in os.listdir(local_path) if name.endswith(".safetensors"))
    if not tensor_files:
        raise FileNotFoundError(f"No .safetensors files found in {local_path}")

    hf_state_dict = {}
    for file in tensor_files:
        with safe_open(os.path.join(local_path, file), framework="pt", device=device) as f:
            for key in f.keys():
                hf_state_dict[key] = f.get_tensor(key)

    mistral_state_dict = convert_state_dict(hf_state_dict, config)
    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))

    # Both dicts reference every tensor of the model; release them before returning.
    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")


def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Load the tokenizer of the Voxtral model and save it into ``output_dir``."""
    from transformers import MistralCommonTokenizer

    print("Extracting tokenizer...")
    tokenizer = MistralCommonTokenizer.from_pretrained(input_path_or_repo)
    tokenizer.save_pretrained(output_dir)
    print("Tokenizer saved successfully.")


def main():
    """Parse command-line arguments and run the model (and optionally tokenizer) conversion."""
    parser = argparse.ArgumentParser(description="Convert HF Voxtral weights to Mistral format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        default="RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic",
        help="Path or repo containing HF Voxtral model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="Voxtral-Mini-3B-2507-FP8-dynamic-converted",
        help="Location to write Mistral model and tokenizer",
    )
    parser.add_argument("--skip_tokenizer", action="store_true", help="Skip tokenizer conversion")
    parser.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Voxtral-Mini-3B-2507",
        help="Path to the unquantized model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to load the checkpoint shards onto (e.g. 'cpu' or 'cuda')",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.output_dir,
        unquantized_model_path=args.unquantized_model_path,
        device=args.device,
    )

    if not args.skip_tokenizer:
        write_tokenizer(args.input_path_or_repo, args.output_dir)


if __name__ == "__main__":
    main()