shai commited on
Commit
62d81c6
·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ library_name: transformers
5
+ license: apache-2.0
6
+ tags:
7
+ - gpt
8
+ - llm
9
+ - multimodal large language model
10
+ thumbnail: >-
11
+ https://h2o.ai/etc.clientlibs/h2o/clientlibs/clientlib-site/resources/images/favicon.ico
12
+ pipeline_tag: text-generation
13
+ ---
14
+ # Model Card
15
+ The H2OVL-Mississippi-800M is a compact yet powerful vision-language model from H2O.ai, featuring 0.8 billion parameters. Despite its small size, it delivers state-of-the-art performance in text recognition, excelling in the Text Recognition segment of OCRBench and outperforming much larger models in this domain. Built upon the robust architecture of our H2O-Danube language models, the Mississippi-800M extends their capabilities by seamlessly integrating vision and language tasks.
16
+
17
+ <div align="center">
18
+ <img src="./assets/text_recognition.png" alt="Mississippi-2B Benchmarks" width="600"/>
19
+ </div>
20
+
21
+
22
+ ## Key Features:
23
+
24
+ - 0.8 Billion Parameters: Balance between performance and efficiency, making it suitable for OCR and document processing.
25
+ - Trained on 19 million image-text pairs, with a focus on OCR, document comprehension, and chart, figure, and table interpretation, the model is optimized for superior OCR performance.
26
+
27
+ ## Usage
28
+
29
+ ### Install dependencies:
30
+ ```bash
31
+ pip install transformers torch torchvision einops timm peft sentencepiece flash_attn
32
+ ```
33
+
34
+ ### Sample demo:
35
+
36
+ ```python
37
+ import torch
38
+ from transformers import AutoModel, AutoTokenizer
39
+
40
+
41
+ # Set up the model and tokenizer
42
+ model_path = 'h2oai/h2o-mississippi-800m'
43
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
44
+ config.llm_config._attn_implementation = 'flash_attention_2'
45
+ model = AutoModel.from_pretrained(
46
+ model_path,
47
+ torch_dtype=torch.bfloat16,
48
+ config=config,
49
+ low_cpu_mem_usage=True,
50
+ trust_remote_code=True).eval().cuda()
51
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
52
+ generation_config = dict(max_new_tokens=1024, do_sample=True)
53
+
54
+ # pure-text conversation
55
+ question = 'Hello, how are you?'
56
+ response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
57
+ print(f'User: {question}\nAssistant: {response}')
58
+
59
+
60
+ # Example for single image
61
+ image_file = './examples/image.jpg'
62
+ question = '<image>\nRead the text in the image.'
63
+ response, history = model.chat(tokenizer, image_file, question, generation_config, history=None, return_history=True)
64
+ print(f'User: {question}\nAssistant: {response}')
65
+
66
+
67
+ ```
68
+
69
+
70
+ ## Benchmarks
71
+
72
+ ### 🤗 OpenVLM Leaderboard
73
+
74
+ | Benchmark | acc_n |
75
+ |:-------------------|:-----:|
76
+ | OCRBench | 75.1 |
77
+
78
+
79
+
80
+ ## Acknowledgments
81
+
82
+ We would like to express our gratitude to the [InternVL team at OpenGVLab](https://github.com/OpenGVLab/InternVL) for their research and codebases, upon which we have built and expanded. We also acknowledge the work of the [LLaVA team](https://github.com/haotian-liu/LLaVA) and the [Monkey team](https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey) for their insights and techniques used in improving multimodal models.
83
+
84
+ ## Disclaimer
85
+
86
+ Please read this disclaimer carefully before using the large language model provided in this repository. Your use of the model signifies your agreement to the following terms and conditions.
87
+
88
+ - Biases and Offensiveness: The large language model is trained on a diverse range of internet text data, which may contain biased, racist, offensive, or otherwise inappropriate content. By using this model, you acknowledge and accept that the generated content may sometimes exhibit biases or produce content that is offensive or inappropriate. The developers of this repository do not endorse, support, or promote any such content or viewpoints.
89
+ - Limitations: The large language model is an AI-based tool and not a human. It may produce incorrect, nonsensical, or irrelevant responses. It is the user's responsibility to critically evaluate the generated content and use it at their discretion.
90
+ - Use at Your Own Risk: Users of this large language model must assume full responsibility for any consequences that may arise from their use of the tool. The developers and contributors of this repository shall not be held liable for any damages, losses, or harm resulting from the use or misuse of the provided model.
91
+ - Ethical Considerations: Users are encouraged to use the large language model responsibly and ethically. By using this model, you agree not to use it for purposes that promote hate speech, discrimination, harassment, or any form of illegal or harmful activities.
92
+ - Reporting Issues: If you encounter any biased, offensive, or otherwise inappropriate content generated by the large language model, please report it to the repository maintainers through the provided channels. Your feedback will help improve the model and mitigate potential issues.
93
+ - Changes to this Disclaimer: The developers of this repository reserve the right to modify or update this disclaimer at any time without prior notice. It is the user's responsibility to periodically review the disclaimer to stay informed about any changes.
94
+
95
+ By using the large language model provided in this repository, you agree to accept and comply with the terms and conditions outlined in this disclaimer. If you do not agree with any part of this disclaimer, you should refrain from using the model and any content generated by it.
added_tokens.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</box>": 32008,
3
+ "</img>": 32001,
4
+ "</quad>": 32004,
5
+ "</ref>": 32006,
6
+ "<IMG_CONTEXT>": 32002,
7
+ "<box>": 32007,
8
+ "<img>": 32000,
9
+ "<quad>": 32003,
10
+ "<ref>": 32005,
11
+ "<|end|>": 32009
12
+ }
assets/text_recognition.png ADDED
config.json ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "H2OVLChatModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_h2ovl_chat.H2OVLChatConfig",
7
+ "AutoModel": "modelling_h2ovl_chat.H2OVLChatModel",
8
+ "AutoModelForCausalLM": "modelling_h2ovl_chat.H2OVLChatModel"
9
+ },
10
+ "downsample_ratio": 0.5,
11
+ "dynamic_image_size": true,
12
+ "force_image_size": 448,
13
+ "llm_config": {
14
+ "_name_or_path": "h2oai/h2o-danube3-500m-chat",
15
+ "add_cross_attention": false,
16
+ "architectures": [
17
+ "LlamaForCausalLM"
18
+ ],
19
+ "attention_bias": false,
20
+ "attention_dropout": 0.0,
21
+ "bos_token_id": 1,
22
+ "chunk_size_feed_forward": 0,
23
+ "cross_attention_hidden_size": null,
24
+ "decoder_start_token_id": null,
25
+ "encoder_no_repeat_ngram_size": 0,
26
+ "eos_token_id": 2,
27
+ "exponential_decay_length_penalty": null,
28
+ "finetuning_task": null,
29
+ "hidden_act": "silu",
30
+ "hidden_size": 1536,
31
+ "initializer_range": 0.02,
32
+ "intermediate_size": 4096,
33
+ "is_decoder": false,
34
+ "is_encoder_decoder": false,
35
+ "length_penalty": 1.0,
36
+ "max_position_embeddings": 8192,
37
+ "mlp_bias": false,
38
+ "model_type": "llama",
39
+ "num_attention_heads": 16,
40
+ "num_hidden_layers": 16,
41
+ "num_key_value_heads": 8,
42
+ "output_attentions": false,
43
+ "output_hidden_states": false,
44
+ "output_scores": false,
45
+ "pad_token_id": 0,
46
+ "prefix": null,
47
+ "pretraining_tp": 1,
48
+ "problem_type": null,
49
+ "pruned_heads": {},
50
+ "remove_invalid_values": false,
51
+ "return_dict": true,
52
+ "return_dict_in_generate": false,
53
+ "rms_norm_eps": 1e-05,
54
+ "rope_scaling": null,
55
+ "rope_theta": 100000,
56
+ "sep_token_id": null,
57
+ "sliding_window": null,
58
+ "suppress_tokens": null,
59
+ "task_specific_params": null,
60
+ "tie_encoder_decoder": false,
61
+ "tie_word_embeddings": false,
62
+ "tokenizer_class": null,
63
+ "torch_dtype": "bfloat16",
64
+ "torchscript": false,
65
+ "transformers_version": "4.42.0.dev0",
66
+ "use_bfloat16": false,
67
+ "use_cache": true,
68
+ "vocab_size": 32010
69
+ },
70
+ "max_dynamic_patch": 6,
71
+ "min_dynamic_patch": 1,
72
+ "model_type": "h2ovl_chat",
73
+ "pad2square": false,
74
+ "ps_version": "v2",
75
+ "select_layer": -1,
76
+ "template": "h2ogpt2",
77
+ "torch_dtype": "bfloat16",
78
+ "transformers_version": null,
79
+ "use_backbone_lora": 0,
80
+ "use_llm_lora": 0,
81
+ "use_thumbnail": true,
82
+ "use_msac": false,
83
+ "vision_config": {
84
+ "_name_or_path": "OpenGVLab/InternViT-300M-448px",
85
+ "add_cross_attention": false,
86
+ "architectures": [
87
+ "InternVisionModel"
88
+ ],
89
+ "attention_dropout": 0.0,
90
+ "auto_map": {
91
+ "AutoConfig": "OpenGVLab/InternViT-300M-448px--configuration_intern_vit.InternVisionConfig",
92
+ "AutoModel": "OpenGVLab/InternViT-300M-448px--modeling_intern_vit.InternVisionModel"
93
+ },
94
+ "bos_token_id": null,
95
+ "chunk_size_feed_forward": 0,
96
+ "cross_attention_hidden_size": null,
97
+ "decoder_start_token_id": null,
98
+ "drop_path_rate": 0.0,
99
+ "dropout": 0.0,
100
+ "eos_token_id": null,
101
+ "exponential_decay_length_penalty": null,
102
+ "finetuning_task": null,
103
+ "hidden_act": "gelu",
104
+ "hidden_size": 1024,
105
+ "image_size": 448,
106
+ "initializer_factor": 1.0,
107
+ "initializer_range": 0.02,
108
+ "intermediate_size": 4096,
109
+ "is_decoder": false,
110
+ "is_encoder_decoder": false,
111
+ "layer_norm_eps": 1e-06,
112
+ "length_penalty": 1.0,
113
+ "max_length": 20,
114
+ "min_length": 0,
115
+ "model_type": "intern_vit_6b",
116
+ "no_repeat_ngram_size": 0,
117
+ "norm_type": "layer_norm",
118
+ "num_attention_heads": 16,
119
+ "num_beam_groups": 1,
120
+ "num_beams": 1,
121
+ "num_channels": 3,
122
+ "num_hidden_layers": 24,
123
+ "num_return_sequences": 1,
124
+ "output_attentions": false,
125
+ "output_hidden_states": false,
126
+ "output_scores": false,
127
+ "pad_token_id": null,
128
+ "patch_size": 14,
129
+ "prefix": null,
130
+ "problem_type": null,
131
+ "pruned_heads": {},
132
+ "qk_normalization": false,
133
+ "qkv_bias": true,
134
+ "remove_invalid_values": false,
135
+ "repetition_penalty": 1.0,
136
+ "return_dict": true,
137
+ "return_dict_in_generate": false,
138
+ "sep_token_id": null,
139
+ "task_specific_params": null,
140
+ "tie_encoder_decoder": false,
141
+ "tie_word_embeddings": true,
142
+ "tokenizer_class": null,
143
+ "torch_dtype": "bfloat16",
144
+ "torchscript": false,
145
+ "transformers_version": "4.42.0.dev0",
146
+ "use_bfloat16": false,
147
+ "use_flash_attn": true
148
+ }
149
+ }
configuration_h2ovl_chat.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from transformers.configuration_utils import PretrainedConfig
3
+ from transformers.utils import logging
4
+ from transformers import AutoConfig
5
+ from transformers.models.auto import CONFIG_MAPPING
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+ class H2OVLChatConfig(PretrainedConfig):
10
+ model_type = 'h2ovl_chat'
11
+ is_composition = True
12
+
13
+ def __init__(
14
+ self,
15
+ vision_config=None,
16
+ llm_config=None,
17
+ use_backbone_lora=0,
18
+ use_llm_lora=0,
19
+ pad2square=False,
20
+ select_layer=-4,
21
+ force_image_size=None,
22
+ downsample_ratio=0.5,
23
+ template=None,
24
+ dynamic_image_size=False,
25
+ use_thumbnail=False,
26
+ ps_version='v1',
27
+ min_dynamic_patch=1,
28
+ max_dynamic_patch=6,
29
+ use_msac=False,
30
+ **kwargs):
31
+ super().__init__(**kwargs)
32
+
33
+ if vision_config["model_type"] in CONFIG_MAPPING:
34
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
35
+ else:
36
+ self.vision_config = AutoConfig.from_pretrained(vision_config["_name_or_path"], trust_remote_code=True)
37
+ self.vision_config.update(vision_config)
38
+
39
+ if llm_config["model_type"] in CONFIG_MAPPING:
40
+ self.llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)
41
+ else:
42
+ self.llm_config = AutoConfig.from_pretrained(llm_config["_name_or_path"], trust_remote_code=True)
43
+ self.llm_config.update(llm_config)
44
+
45
+ self.use_backbone_lora = use_backbone_lora
46
+ self.use_llm_lora = use_llm_lora
47
+ self.pad2square = pad2square
48
+ self.select_layer = select_layer
49
+ self.force_image_size = force_image_size
50
+ self.downsample_ratio = downsample_ratio
51
+ self.template = template
52
+ self.dynamic_image_size = dynamic_image_size
53
+ self.use_thumbnail = use_thumbnail
54
+ self.ps_version = ps_version # pixel shuffle version
55
+ self.min_dynamic_patch = min_dynamic_patch
56
+ self.max_dynamic_patch = max_dynamic_patch
57
+ self.use_msac = use_msac
58
+
59
+ logger.info(f'vision_select_layer: {self.select_layer}')
60
+ logger.info(f'ps_version: {self.ps_version}')
61
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
62
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
63
+
64
+ def to_dict(self):
65
+ """
66
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
67
+
68
+ Returns:
69
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
70
+ """
71
+ output = copy.deepcopy(self.__dict__)
72
+ output['vision_config'] = self.vision_config.to_dict()
73
+ output['llm_config'] = self.llm_config.to_dict()
74
+ output['model_type'] = self.__class__.model_type
75
+ output['use_backbone_lora'] = self.use_backbone_lora
76
+ output['use_llm_lora'] = self.use_llm_lora
77
+ output['pad2square'] = self.pad2square
78
+ output['select_layer'] = self.select_layer
79
+ output['force_image_size'] = self.force_image_size
80
+ output['downsample_ratio'] = self.downsample_ratio
81
+ output['template'] = self.template
82
+ output['dynamic_image_size'] = self.dynamic_image_size
83
+ output['use_thumbnail'] = self.use_thumbnail
84
+ output['ps_version'] = self.ps_version
85
+ output['min_dynamic_patch'] = self.min_dynamic_patch
86
+ output['max_dynamic_patch'] = self.max_dynamic_patch
87
+ output['use_msac'] = self.use_msac
88
+
89
+ return output
conversation.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ NO_COLON_SINGLE = auto()
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class Conversation:
22
+ """A class that manages prompt templates and keeps all conversation history."""
23
+
24
+ # The name of this template
25
+ name: str
26
+ # The template of the system prompt
27
+ system_template: str = '{system_message}'
28
+ # The system message
29
+ system_message: str = ''
30
+ # The names of two roles
31
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
32
+ # All messages. Each item is (role, message).
33
+ messages: List[List[str]] = ()
34
+ # The number of few shot examples
35
+ offset: int = 0
36
+ # The separator style and configurations
37
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
38
+ sep: str = '\n'
39
+ sep2: str = None
40
+ # Stop criteria (the default one is EOS token)
41
+ stop_str: Union[str, List[str]] = None
42
+ # Stops generation if meeting any token in this list
43
+ stop_token_ids: List[int] = None
44
+
45
+ def get_prompt(self) -> str:
46
+ """Get the prompt for generation."""
47
+ system_prompt = self.system_template.format(system_message=self.system_message)
48
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
49
+ ret = system_prompt + self.sep
50
+ for role, message in self.messages:
51
+ if message:
52
+ ret += role + ': ' + message + self.sep
53
+ else:
54
+ ret += role + ':'
55
+ return ret
56
+ if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
57
+ ret = system_prompt
58
+ for role, message in self.messages:
59
+ if message:
60
+ ret += role + message + self.sep
61
+ else:
62
+ ret += role
63
+ return ret
64
+ else:
65
+ raise ValueError(f'Invalid style: {self.sep_style}')
66
+
67
+ def set_system_message(self, system_message: str):
68
+ """Set the system message."""
69
+ self.system_message = system_message
70
+
71
+ def append_message(self, role: str, message: str):
72
+ """Append a new message."""
73
+ self.messages.append([role, message])
74
+
75
+ def update_last_message(self, message: str):
76
+ """Update the last output.
77
+
78
+ The last message is typically set to be None when constructing the prompt,
79
+ so we need to update it in-place after getting the response from a model.
80
+ """
81
+ self.messages[-1][1] = message
82
+
83
+ def to_gradio_chatbot(self):
84
+ """Convert the conversation to gradio chatbot format."""
85
+ ret = []
86
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
87
+ if i % 2 == 0:
88
+ ret.append([msg, None])
89
+ else:
90
+ ret[-1][-1] = msg
91
+ return ret
92
+
93
+ def to_openai_api_messages(self):
94
+ """Convert the conversation to OpenAI chat completion format."""
95
+ ret = [{'role': 'system', 'content': self.system_message}]
96
+
97
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
98
+ if i % 2 == 0:
99
+ ret.append({'role': 'user', 'content': msg})
100
+ else:
101
+ if msg is not None:
102
+ ret.append({'role': 'assistant', 'content': msg})
103
+ return ret
104
+
105
+ def copy(self):
106
+ return Conversation(
107
+ name=self.name,
108
+ system_template=self.system_template,
109
+ system_message=self.system_message,
110
+ roles=self.roles,
111
+ messages=[[x, y] for x, y in self.messages],
112
+ offset=self.offset,
113
+ sep_style=self.sep_style,
114
+ sep=self.sep,
115
+ sep2=self.sep2,
116
+ stop_str=self.stop_str,
117
+ stop_token_ids=self.stop_token_ids,
118
+ )
119
+
120
+ def dict(self):
121
+ return {
122
+ 'template_name': self.name,
123
+ 'system_message': self.system_message,
124
+ 'roles': self.roles,
125
+ 'messages': self.messages,
126
+ 'offset': self.offset,
127
+ }
128
+
129
+
130
+ # A global registry for all conversation templates
131
+ conv_templates: Dict[str, Conversation] = {}
132
+
133
+
134
+ def register_conv_template(template: Conversation, override: bool = False):
135
+ """Register a new conversation template."""
136
+ if not override:
137
+ assert (
138
+ template.name not in conv_templates
139
+ ), f'{template.name} has been registered.'
140
+
141
+ conv_templates[template.name] = template
142
+
143
+
144
+ def get_conv_template(name: str) -> Conversation:
145
+ """Get a conversation template."""
146
+ return conv_templates[name].copy()
147
+
148
+
149
+
150
+ register_conv_template(
151
+ Conversation(
152
+ name='h2ogpt2',
153
+ roles=('<|prompt|>', '<|answer|>'),
154
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
155
+ sep='<|end|>',
156
+ stop_token_ids=[
157
+ 2,
158
+ 32009
159
+ ]
160
+ )
161
+ )
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "repetition_penalty": 1.0,
4
+ "temperature": 0.01,
5
+ "top_p": 0.001,
6
+ "top_k": 1,
7
+ "max_length": 1024,
8
+ "eos_token_id": [
9
+ 2,
10
+ 32009
11
+ ],
12
+ "transformers_version": "4.44.0"
13
+ }
image_process.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ from torchvision.transforms.functional import InterpolationMode
5
+
6
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
7
+ IMAGENET_STD = (0.229, 0.224, 0.225)
8
+
9
+
10
+ def build_transform(input_size):
11
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
12
+ transform = T.Compose([
13
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
14
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
15
+ T.ToTensor(),
16
+ T.Normalize(mean=MEAN, std=STD)
17
+ ])
18
+ return transform
19
+
20
+
21
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
22
+ best_ratio_diff = float('inf')
23
+ best_ratio = (1, 1)
24
+ area = width * height
25
+ for ratio in target_ratios:
26
+ target_aspect_ratio = ratio[0] / ratio[1]
27
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
28
+ if ratio_diff < best_ratio_diff:
29
+ best_ratio_diff = ratio_diff
30
+ best_ratio = ratio
31
+ elif ratio_diff == best_ratio_diff:
32
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
33
+ best_ratio = ratio
34
+ return best_ratio
35
+
36
+
37
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
38
+ orig_width, orig_height = image.size
39
+ aspect_ratio = orig_width / orig_height
40
+
41
+ # calculate the existing image aspect ratio
42
+ target_ratios = set(
43
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
44
+ i * j <= max_num and i * j >= min_num)
45
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
46
+
47
+ # find the closest aspect ratio to the target
48
+ target_aspect_ratio = find_closest_aspect_ratio(
49
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
50
+
51
+ # calculate the target width and height
52
+ target_width = image_size * target_aspect_ratio[0]
53
+ target_height = image_size * target_aspect_ratio[1]
54
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
55
+
56
+ # resize the image
57
+ resized_img = image.resize((target_width, target_height))
58
+ processed_images = []
59
+ for i in range(blocks):
60
+ box = (
61
+ (i % (target_width // image_size)) * image_size,
62
+ (i // (target_width // image_size)) * image_size,
63
+ ((i % (target_width // image_size)) + 1) * image_size,
64
+ ((i // (target_width // image_size)) + 1) * image_size
65
+ )
66
+ # split the image
67
+ split_img = resized_img.crop(box)
68
+ processed_images.append(split_img)
69
+ assert len(processed_images) == blocks
70
+ if use_thumbnail and len(processed_images) != 1:
71
+ thumbnail_img = image.resize((image_size, image_size))
72
+ processed_images.append(thumbnail_img)
73
+ return processed_images, target_aspect_ratio
74
+
75
+
76
+ def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
77
+ orig_width, orig_height = image.size
78
+ aspect_ratio = orig_width / orig_height
79
+
80
+ # calculate the existing image aspect ratio
81
+ target_ratios = set(
82
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
83
+ i * j <= max_num and i * j >= min_num)
84
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
85
+
86
+ new_target_ratios = []
87
+ if prior_aspect_ratio is not None:
88
+ for i in target_ratios:
89
+ if prior_aspect_ratio[0]%i[0] != 0 and prior_aspect_ratio[1]%i[1] != 0:
90
+ new_target_ratios.append(i)
91
+ else:
92
+ continue
93
+
94
+ # find the closest aspect ratio to the target
95
+ target_aspect_ratio = find_closest_aspect_ratio(
96
+ aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
97
+
98
+ # calculate the target width and height
99
+ target_width = image_size * target_aspect_ratio[0]
100
+ target_height = image_size * target_aspect_ratio[1]
101
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
102
+
103
+ # resize the image
104
+ resized_img = image.resize((target_width, target_height))
105
+ processed_images = []
106
+ for i in range(blocks):
107
+ box = (
108
+ (i % (target_width // image_size)) * image_size,
109
+ (i // (target_width // image_size)) * image_size,
110
+ ((i % (target_width // image_size)) + 1) * image_size,
111
+ ((i // (target_width // image_size)) + 1) * image_size
112
+ )
113
+ # split the image
114
+ split_img = resized_img.crop(box)
115
+ processed_images.append(split_img)
116
+ assert len(processed_images) == blocks
117
+ if use_thumbnail and len(processed_images) != 1:
118
+ thumbnail_img = image.resize((image_size, image_size))
119
+ processed_images.append(thumbnail_img)
120
+ return processed_images
121
+
122
+ def load_image1(image_file, input_size=448, min_num=1, max_num=12):
123
+ image = Image.open(image_file).convert('RGB')
124
+ transform = build_transform(input_size=input_size)
125
+ images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
126
+ pixel_values = [transform(image) for image in images]
127
+ pixel_values = torch.stack(pixel_values)
128
+ return pixel_values, target_aspect_ratio
129
+
130
+ def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
131
+ image = Image.open(image_file).convert('RGB')
132
+ transform = build_transform(input_size=input_size)
133
+ images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
134
+ pixel_values = [transform(image) for image in images]
135
+ pixel_values = torch.stack(pixel_values)
136
+ return pixel_values
137
+
138
+ def load_single_image(file_name, max_num=6, msac=False):
139
+ pixel_values, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=max_num)
140
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
141
+ if not msac:
142
+ num_patches_list = [pixel_values.size(0)]
143
+ return pixel_values, num_patches_list
144
+
145
+ pixel_values2 = load_image2(file_name, min_num=3, max_num=max_num, target_aspect_ratio=target_aspect_ratio)
146
+ pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
147
+ pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], dim=0).to(torch.bfloat16).cuda()
148
+ num_patches_list = [pixel_values.size(0)] # The number of patches after MSAC
149
+ return pixel_values, num_patches_list
150
+
151
+ def load_multi_images(image_files, max_num=6):
152
+ pixel_values_list = []
153
+ num_patches_list = []
154
+ for image_file in image_files:
155
+ pixel_values, _ = load_image1(image_file, max_num=max_num)
156
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
157
+ pixel_values_list.append(pixel_values)
158
+ num_patches_list.append(pixel_values.size(0))
159
+ pixel_values = torch.cat(pixel_values_list, dim=0)
160
+
161
+ return pixel_values, num_patches_list
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:260fd8e6c1e92c974fd4fa1a8be5d26330e570c9e09281dd9abd1108b66514d2
3
+ size 1652650984
modelling_h2ovl_chat.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Any, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (AutoModel, GenerationConfig, AutoModelForCausalLM, LlamaForCausalLM)
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.utils import logging
11
+ from peft import LoraConfig, get_peft_model
12
+ import transformers
13
+ from .conversation import get_conv_template
14
+ from .configuration_h2ovl_chat import H2OVLChatConfig
15
+ from .image_process import load_single_image, load_multi_images
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ def version_cmp(v1, v2, op='eq'):
20
+ import operator
21
+
22
+ from packaging import version
23
+ op_func = getattr(operator, op)
24
+ return op_func(version.parse(v1), version.parse(v2))
25
+
26
+ class H2OVLChatModel(PreTrainedModel):
27
+ config_class = H2OVLChatConfig
28
+ main_input_name = 'pixel_values'
29
+ _supports_flash_attn_2 = True
30
+
31
+ def __init__(self, config: H2OVLChatConfig, vision_model=None, language_model=None):
32
+ super().__init__(config)
33
+
34
+ assert version_cmp(transformers.__version__, '4.37.0', 'ge')
35
+ image_size = config.force_image_size or config.vision_config.image_size
36
+ patch_size = config.vision_config.patch_size
37
+ self.patch_size = patch_size
38
+ self.select_layer = config.select_layer
39
+ self.template = config.template
40
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
41
+ self.downsample_ratio = config.downsample_ratio
42
+ self.ps_version = config.ps_version
43
+ self.use_msac = config.use_msac
44
+
45
+ logger.info(f'num_image_token: {self.num_image_token}')
46
+ logger.info(f'ps_version: {self.ps_version}')
47
+ if vision_model is not None:
48
+ self.vision_model = vision_model
49
+ else:
50
+ self.vision_model = AutoModel.from_config(config.vision_config, trust_remote_code=True)
51
+ if language_model is not None:
52
+ self.language_model = language_model
53
+ else:
54
+ self.language_model = AutoModelForCausalLM.from_config(config.llm_config, attn_implementation=config.llm_config._attn_implementation, trust_remote_code=True)
55
+
56
+ vit_hidden_size = config.vision_config.hidden_size
57
+ llm_hidden_size = config.llm_config.hidden_size
58
+
59
+ self.mlp1 = nn.Sequential(
60
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
61
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
62
+ nn.GELU(),
63
+ nn.Linear(llm_hidden_size, llm_hidden_size)
64
+ )
65
+
66
+ self.img_context_token_id = None
67
+ self.conv_template = get_conv_template(self.template)
68
+ if hasattr(config, 'system_message'):
69
+ self.system_message = config.system_message
70
+ else:
71
+ self.system_message = self.conv_template.system_message
72
+ self.num_samples = 0
73
+
74
+ if config.use_backbone_lora:
75
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
76
+
77
+ if config.use_llm_lora:
78
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
79
+
80
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
81
+ lora_config = LoraConfig(
82
+ r=r,
83
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
84
+ lora_alpha=lora_alpha,
85
+ lora_dropout=lora_dropout,
86
+ )
87
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
88
+ self.vision_model.print_trainable_parameters()
89
+
90
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
91
+ # Determine the target modules based on the architecture of the language model
92
+ if self.llm_arch_name == 'InternLM2ForCausalLM':
93
+ target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
94
+ elif self.llm_arch_name == 'Phi3ForCausalLM':
95
+ target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
96
+ elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM', 'MistralForCausalLM']:
97
+ target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
98
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
99
+ else:
100
+ raise NotImplemented
101
+ lora_config = LoraConfig(
102
+ r=r,
103
+ target_modules=target_modules,
104
+ lora_alpha=lora_alpha,
105
+ lora_dropout=lora_dropout,
106
+ task_type='CAUSAL_LM'
107
+ )
108
+ self.language_model = get_peft_model(self.language_model, lora_config)
109
+ self.language_model.enable_input_require_grads()
110
+ self.language_model.print_trainable_parameters()
111
+
112
+ def forward(
113
+ self,
114
+ pixel_values: torch.FloatTensor,
115
+ input_ids: torch.LongTensor = None,
116
+ attention_mask: Optional[torch.Tensor] = None,
117
+ position_ids: Optional[torch.LongTensor] = None,
118
+ image_flags: Optional[torch.LongTensor] = None,
119
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
120
+ labels: Optional[torch.LongTensor] = None,
121
+ use_cache: Optional[bool] = None,
122
+ output_attentions: Optional[bool] = None,
123
+ output_hidden_states: Optional[bool] = None,
124
+ return_dict: Optional[bool] = None,
125
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
126
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
127
+
128
+ image_flags = image_flags.squeeze(-1)
129
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
130
+
131
+ vit_embeds = self.extract_feature(pixel_values)
132
+ vit_embeds = vit_embeds[image_flags == 1]
133
+ vit_batch_size = pixel_values.shape[0]
134
+
135
+ B, N, C = input_embeds.shape
136
+ input_embeds = input_embeds.reshape(B * N, C)
137
+
138
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
139
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
140
+
141
+ input_ids = input_ids.reshape(B * N)
142
+ selected = (input_ids == self.img_context_token_id)
143
+ try:
144
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
145
+ ignore_flag = False
146
+ except Exception as e:
147
+ vit_embeds = vit_embeds.reshape(-1, C)
148
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
149
+ f'vit_embeds.shape={vit_embeds.shape}')
150
+ n_token = selected.sum()
151
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
152
+ ignore_flag = True
153
+
154
+ input_embeds = input_embeds.reshape(B, N, C)
155
+
156
+ outputs = self.language_model(
157
+ inputs_embeds=input_embeds,
158
+ attention_mask=attention_mask,
159
+ position_ids=position_ids,
160
+ past_key_values=past_key_values,
161
+ use_cache=use_cache,
162
+ output_attentions=output_attentions,
163
+ output_hidden_states=output_hidden_states,
164
+ return_dict=return_dict,
165
+ )
166
+ logits = outputs.logits
167
+
168
+ loss = None
169
+ if labels is not None:
170
+ # Shift so that tokens < n predict n
171
+ shift_logits = logits[..., :-1, :].contiguous()
172
+ shift_labels = labels[..., 1:].contiguous()
173
+ # Flatten the tokens
174
+ loss_fct = CrossEntropyLoss()
175
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
176
+ shift_labels = shift_labels.view(-1)
177
+ # Enable model parallelism
178
+ shift_labels = shift_labels.to(shift_logits.device)
179
+ loss = loss_fct(shift_logits, shift_labels)
180
+ if ignore_flag:
181
+ loss = loss * 0.0
182
+
183
+ if not return_dict:
184
+ output = (logits,) + outputs[1:]
185
+ return (loss,) + output if loss is not None else output
186
+
187
+ return CausalLMOutputWithPast(
188
+ loss=loss,
189
+ logits=logits,
190
+ past_key_values=outputs.past_key_values,
191
+ hidden_states=outputs.hidden_states,
192
+ attentions=outputs.attentions,
193
+ )
194
+
195
+ def pixel_shuffle(self, x, scale_factor=0.5):
196
+ n, w, h, c = x.size()
197
+ # N, W, H, C --> N, W, H * scale, C // scale
198
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
199
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
200
+ x = x.permute(0, 2, 1, 3).contiguous()
201
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
202
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
203
+ int(c / (scale_factor * scale_factor)))
204
+ if self.ps_version == 'v1':
205
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
206
+ 'which results in a transposed image.')
207
+ else:
208
+ x = x.permute(0, 2, 1, 3).contiguous()
209
+ return x
210
+
211
+ def extract_feature(self, pixel_values):
212
+ if self.select_layer == -1:
213
+ vit_embeds = self.vision_model(
214
+ pixel_values=pixel_values,
215
+ output_hidden_states=False,
216
+ return_dict=True).last_hidden_state
217
+ else:
218
+ vit_embeds = self.vision_model(
219
+ pixel_values=pixel_values,
220
+ output_hidden_states=True,
221
+ return_dict=True).hidden_states[self.select_layer]
222
+ vit_embeds = vit_embeds[:, 1:, :]
223
+
224
+ h = w = int(vit_embeds.shape[1] ** 0.5)
225
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
226
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
227
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
228
+ vit_embeds = self.mlp1(vit_embeds)
229
+ return vit_embeds
230
+
231
+ def chat(self, tokenizer, image_files, question, generation_config , max_tiles=6, history=None, return_history=False,
232
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
233
+ verbose=False):
234
+
235
+ if image_files:
236
+ if isinstance(image_files, list):
237
+ pixel_values, num_patches_list = load_multi_images(image_files, max_num=max_tiles) # Load multiple images
238
+ else:
239
+ pixel_values, num_patches_list = load_single_image(image_files, max_num=max_tiles, msac=self.use_msac) # Load single image
240
+ else:
241
+ pixel_values = None
242
+ num_patches_list = []
243
+
244
+
245
+ if history is None and pixel_values is not None and '<image>' not in question:
246
+ question = '<image>\n' + question
247
+
248
+ if num_patches_list is None:
249
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
250
+
251
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
252
+
253
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
254
+ self.img_context_token_id = img_context_token_id
255
+
256
+ template = get_conv_template(self.template)
257
+ template.system_message = self.system_message
258
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
259
+
260
+ history = [] if history is None else history
261
+ for (old_question, old_answer) in history:
262
+ template.append_message(template.roles[0], old_question)
263
+ template.append_message(template.roles[1], old_answer)
264
+ template.append_message(template.roles[0], question)
265
+ template.append_message(template.roles[1], None)
266
+ query = template.get_prompt()
267
+
268
+ if verbose and pixel_values is not None:
269
+ image_bs = pixel_values.shape[0]
270
+ print(f'dynamic ViT batch size: {image_bs}')
271
+
272
+ for num_patches in num_patches_list:
273
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
274
+ query = query.replace('<image>', image_tokens, 1)
275
+
276
+ model_inputs = tokenizer(query, return_tensors='pt')
277
+ input_ids = model_inputs['input_ids'].cuda()
278
+ attention_mask = model_inputs['attention_mask'].cuda()
279
+ generation_config['eos_token_id'] = eos_token_id
280
+ generation_output = self.generate(
281
+ pixel_values=pixel_values,
282
+ input_ids=input_ids,
283
+ attention_mask=attention_mask,
284
+ **generation_config
285
+ )
286
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
287
+ response = response.split(template.sep)[0].strip()
288
+ history.append((question, response))
289
+ if return_history:
290
+ return response, history
291
+ else:
292
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
293
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
294
+ if verbose:
295
+ print(query_to_print, response)
296
+ return response
297
+
298
+ @torch.no_grad()
299
+ def generate(
300
+ self,
301
+ pixel_values: Optional[torch.FloatTensor] = None,
302
+ input_ids: Optional[torch.FloatTensor] = None,
303
+ attention_mask: Optional[torch.LongTensor] = None,
304
+ visual_features: Optional[torch.FloatTensor] = None,
305
+ generation_config: Optional[GenerationConfig] = None,
306
+ output_hidden_states: Optional[bool] = None,
307
+ return_dict: Optional[bool] = None,
308
+ **generate_kwargs,
309
+ ) -> torch.LongTensor:
310
+
311
+ assert self.img_context_token_id is not None
312
+ if pixel_values is not None:
313
+ if visual_features is not None:
314
+ vit_embeds = visual_features
315
+ else:
316
+ vit_embeds = self.extract_feature(pixel_values)
317
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
318
+ B, N, C = input_embeds.shape
319
+ input_embeds = input_embeds.reshape(B * N, C)
320
+
321
+ input_ids = input_ids.reshape(B * N)
322
+ selected = (input_ids == self.img_context_token_id)
323
+ assert selected.sum() != 0
324
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
325
+
326
+ input_embeds = input_embeds.reshape(B, N, C)
327
+ else:
328
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
329
+
330
+ outputs = self.language_model.generate(
331
+ inputs_embeds=input_embeds,
332
+ attention_mask=attention_mask,
333
+ generation_config=generation_config,
334
+ output_hidden_states=output_hidden_states,
335
+ return_dict=return_dict,
336
+ use_cache=True,
337
+ **generate_kwargs,
338
+ )
339
+
340
+ return outputs
special_tokens_map.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<img>",
4
+ "</img>",
5
+ "<IMG_CONTEXT>",
6
+ "<quad>",
7
+ "</quad>",
8
+ "<ref>",
9
+ "</ref>",
10
+ "<box>",
11
+ "</box>",
12
+ "<|end|>"
13
+ ],
14
+ "bos_token": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "cls_token": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false
27
+ },
28
+ "eos_token": {
29
+ "content": "</s>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ },
35
+ "pad_token": {
36
+ "content": "<unk>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false
41
+ },
42
+ "sep_token": {
43
+ "content": "</s>",
44
+ "lstrip": false,
45
+ "normalized": false,
46
+ "rstrip": false,
47
+ "single_word": false
48
+ },
49
+ "unk_token": {
50
+ "content": "<unk>",
51
+ "lstrip": false,
52
+ "normalized": false,
53
+ "rstrip": false,
54
+ "single_word": false
55
+ }
56
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "32000": {
31
+ "content": "<img>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "32001": {
39
+ "content": "</img>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "32002": {
47
+ "content": "<IMG_CONTEXT>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "32003": {
55
+ "content": "<quad>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "32004": {
63
+ "content": "</quad>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "32005": {
71
+ "content": "<ref>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "32006": {
79
+ "content": "</ref>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "32007": {
87
+ "content": "<box>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "32008": {
95
+ "content": "</box>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "32009": {
103
+ "content": "<|end|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ }
110
+ },
111
+ "additional_special_tokens": [
112
+ "<img>",
113
+ "</img>",
114
+ "<IMG_CONTEXT>",
115
+ "<quad>",
116
+ "</quad>",
117
+ "<ref>",
118
+ "</ref>",
119
+ "<box>",
120
+ "</box>",
121
+ "<|end|>"
122
+ ],
123
+ "bos_token": "<s>",
124
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% if ((message['role'] == 'user') != (loop.index0 % 2 == 0)) or ((message['role'] == 'assistant') != (loop.index0 % 2 == 1)) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'].strip() + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|answer|>' }}{% endif %}",
125
+ "clean_up_tokenization_spaces": false,
126
+ "cls_token": "</s>",
127
+ "eos_token": "<|end|>",
128
+ "legacy": true,
129
+ "model_max_length": 8192,
130
+ "pad_token": "<unk>",
131
+ "sep_token": "</s>",
132
+ "sp_model_kwargs": {},
133
+ "spaces_between_special_tokens": false,
134
+ "tokenizer_class": "LlamaTokenizer",
135
+ "unk_token": "<unk>",
136
+ "use_default_system_prompt": false
137
+ }