Overwrite with converted Qwen2.5-3B model files
- README.md +222 -71
- VLV_stage1.py +257 -0
- VLV_stage2.py +460 -0
- build.py +78 -0
- config.json +13 -9
- configuration_vlv.py +172 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_clip.py +60 -3
- vlv_utils.py +71 -0
README.md
CHANGED
@@ -1,104 +1,255 @@
The previous README is replaced wholesale. It carried the same `license: apache-2.0` / `pipeline_tag: image-to-text` front matter; badges linking to the arXiv paper (https://arxiv.org/abs/2507.07104), the GitHub repository (https://github.com/Tiezheng11/Vision-Language-Vision), the Hugging Face model (https://huggingface.co/lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B), and the dataset (https://huggingface.co/datasets/ccvl/LAION-High-Qualtiy-Pro-6M-VLV); an installation step (`pip install -r requirements.txt`); a Python quick-start that loaded the model with `AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, low_cpu_mem_usage=False)`, trimmed dangling sentence fragments from generated captions, and accepted NumPy arrays (uint8 [0, 255] or float [0, 1]) as input; an example that fetched an image from a URL with `requests` and PIL; and a BibTeX stub. The new README follows.
---
license: apache-2.0
tags:
- image-captioning
- multimodal
- vision-language
- diffusion
- pytorch
- transformers
library_name: transformers
pipeline_tag: image-to-text
datasets:
- conceptual_captions
- coco
model_type: VLV_decoder
---

# VLV Captioner Model

This is a VLV (Vision-Language-Vision) model for image captioning. It combines a Stable Diffusion-based image encoder with a Qwen language model to generate descriptive captions from images.

## Model Description

The VLV Captioner is a multimodal model that:
- Uses a diffusion-based vision encoder to extract image features
- Employs the Qwen2.5-3B language model for text generation
- Generates natural language descriptions of input images

## Model Architecture

- **Vision Encoder**: Stable Diffusion-based image encoder with Florence2 components
- **Language Model**: Qwen2.5-3B transformer model
- **Image Size**: 384x384 pixels
- **Max Caption Length**: 300 tokens
- **Precision**: Mixed precision (bfloat16/float32)

## Usage

### Method 1: Load from Hugging Face Hub

```python
from transformers import AutoModel, AutoConfig
from PIL import Image
import torch
import os

# Optional: set a custom cache directory if needed
cache_dir = "/path/to/your/cache"  # use a directory with sufficient space
os.makedirs(cache_dir, exist_ok=True)

# Load the model with an authentication token (if required)
token = os.getenv('HUGGINGFACE_TOKEN')  # or your token string

print("Loading config...")
config = AutoConfig.from_pretrained(
    "your-username/vlv-captioner",
    trust_remote_code=True,
    token=token,
    cache_dir=cache_dir
)

print("Loading model...")
try:
    model = AutoModel.from_pretrained(
        "your-username/vlv-captioner",
        trust_remote_code=True,
        token=token,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,  # specify dtype explicitly
        low_cpu_mem_usage=True
        # Note: avoid device_map="auto" to prevent meta tensor issues
    )
    print("Model loaded successfully!")

    # Load and process an image
    image = Image.open("path/to/your/image.jpg")

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.to('cuda')
        print("Model moved to GPU!")

    # Generate caption
    print("Generating caption...")
    with torch.no_grad():
        captions = model([image], max_length=300)

    # Handle different possible output formats
    if hasattr(captions, 'generated_text'):
        print("Generated caption:", captions.generated_text[0])
    elif isinstance(captions, list):
        print("Generated caption:", captions[0])
    else:
        print("Generated caption:", captions)

except Exception as e:
    print(f"Error during model loading or inference: {e}")
    # If cached files are corrupted, clear the cache and redownload
    import shutil
    cache_path = f"{cache_dir}/modules/transformers_modules/your-username/vlv-captioner"
    if os.path.exists(cache_path):
        print(f"Clearing cache at {cache_path}")
        shutil.rmtree(cache_path)

    # Retry with force download
    model = AutoModel.from_pretrained(
        "your-username/vlv-captioner",
        trust_remote_code=True,
        token=token,
        cache_dir=cache_dir,
        force_download=True,
        torch_dtype=torch.float32
    )
```

### Method 2: Load from original checkpoint

```python
import torch
from PIL import Image
from VLV_stage2 import VLV_MODEL

# Load from the original .pt checkpoint file
model = VLV_MODEL.from_checkpoint("path/to/model.pt")

# Load and process an image
image = Image.open("path/to/your/image.jpg")

# Generate caption
with torch.no_grad():
    captions = model([image], max_length=300)
    print(captions.generated_text[0])  # generated caption
```

## Model Details

- **Model Type**: Vision-Language Model
- **Architecture**: VLV_decoder
- **Language Backbone**: Qwen/Qwen2.5-3B
- **Vision Backbone**: Stable Diffusion + Florence2
- **Training Data**: Various image-caption datasets
- **Framework**: PyTorch, Transformers

## Training Configuration

- **Batch Size**: 1 (inference)
- **Learnable Token Length**: 77
- **Guidance Scale**: 7.5
- **Inference Steps**: 50
- **Beam Search**: 4 beams

## Requirements

```bash
pip install torch transformers safetensors torchvision pillow diffusers
```

## Troubleshooting

### Common Issues and Solutions

#### 1. Meta Tensor Issues
If you encounter meta tensor errors, avoid `device_map="auto"` when loading the model:

```python
# ❌ Don't use this - it can cause meta tensor issues
model = AutoModel.from_pretrained("model-name", device_map="auto")

# ✅ Use this instead
model = AutoModel.from_pretrained("model-name", torch_dtype=torch.float32, low_cpu_mem_usage=True)
if torch.cuda.is_available():
    model = model.to('cuda')
```

#### 2. Cache Issues
If you run into corrupted cache files, clear the cache and redownload:

```python
import shutil
import os

cache_dir = "/your/cache/directory"
cache_path = f"{cache_dir}/modules/transformers_modules/your-username/model-name"
if os.path.exists(cache_path):
    shutil.rmtree(cache_path)

# Then reload with force_download=True
model = AutoModel.from_pretrained("model-name", force_download=True)
```

#### 3. Authentication Issues
Make sure your Hugging Face token is properly set:

```bash
# Option 1: environment variable
export HUGGINGFACE_TOKEN="your_token_here"

# Option 2: Hugging Face CLI login
huggingface-cli login
```

#### 4. Memory Issues
For large models, use a custom cache directory with sufficient space:

```python
cache_dir = "/path/to/large/storage"
os.makedirs(cache_dir, exist_ok=True)
model = AutoModel.from_pretrained("model-name", cache_dir=cache_dir, low_cpu_mem_usage=True)
```

## Advanced Usage

### Batch Processing with the Original Inference Script

For large-scale inference, you can use the inference script from the training repository:

```bash
python Caption_inference.py \
    --input_path /path/to/images \
    --output_path captions.json \
    --clip_decoder_checkpoint /path/to/model.pt \
    --qwen_model Qwen/Qwen2.5-3B \
    --stable_diffusion_model_path stabilityai/stable-diffusion-2-1-base \
    --florence2_model_path microsoft/Florence-2-large \
    --batch_size 4 \
    --max_length 300 \
    --num_beams 4 \
    --image_size 384 \
    --guidance_scale 7.5 \
    --use_text_encoder \
    --distributed  # for multi-GPU inference
```

### Configuration Parameters

- `image_size`: Input image resolution (default: 384)
- `guidance_scale`: Diffusion guidance scale (default: 7.5)
- `learnable_token_length`: Number of vision tokens (default: 77)
- `max_length`: Maximum caption length (default: 300)
- `num_beams`: Beam search width (default: 4)
- `use_text_encoder`: Enable the CLIP text encoder (recommended: True)

## Citation

```bibtex
@article{vlv_autoencoder,
  title={Vision-Language-Vision Auto-Encoder: Scalable Knowledge Distillation from Diffusion Models},
  author={Zhang, Tiezheng and Li, Yitong and Chou, Yu-Cheng and Chen, Jieneng and Yuille, Alan L. and Wei, Chen and Xiao, Junfei},
  journal={arXiv preprint arXiv:2507.07104},
  year={2025}
}
```

## License

This model is released under the Apache 2.0 license.
VLV_stage1.py
ADDED
@@ -0,0 +1,257 @@
import os
import torch
import torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize


class SDConfig(PretrainedConfig):
    """Configuration class for SDModel."""
    model_type = "sd"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class SDOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None


class SDModel(PreTrainedModel):
    config_class = SDConfig

    def __init__(
        self,
        config=None,
        training_args=None,
    ):
        if config is None:
            config = SDConfig()
        super().__init__(config)
        self.training_args = training_args
        if self.training_args.fp32:
            self._dtype = torch.float32
        else:
            self._dtype = torch.bfloat16
        self._device = torch.device(self.training_args.device if hasattr(self.training_args, 'device') else "cuda" if torch.cuda.is_available() else "cpu")

        self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler = load_sd_model(training_args)
        torch.cuda.empty_cache()
        self.unet.eval()
        self.text_encoder.eval()
        self.model, self.processor = load_Florence2_model(training_args)

        self.unet = self.unet.to(self._dtype).to(device=self._device)
        self.text_encoder = self.text_encoder.to(self._dtype).to_empty(device=self._device)
        self.model = self.model.to(self._dtype).to_empty(device=self._device)
        self.vae = self.vae.to(torch.float32).to_empty(device=self._device)

        self.batch_size = self.training_args.batch_size

        hidden_dim = 1024
        self.language_proj = nn.Sequential(
            nn.Linear(1024, hidden_dim, dtype=self._dtype),
            nn.GELU(),
            nn.Linear(hidden_dim, 1024, dtype=self._dtype)
        ).to_empty(device=self._device)
        for param in self.language_proj.parameters():
            param.requires_grad = True

        self.num_queries = self.training_args.learnable_token_length
        self.query_embed = nn.Parameter(torch.randn(1, self.num_queries, 1024, dtype=self._dtype))
        self.query_embed.requires_grad = True

        self.unet.enable_gradient_checkpointing()

    def _unet_pred_noise(self, x_start, t, noise, context):
        t = t.to(dtype=torch.long)

        dtype = self.unet.dtype
        x_start = x_start.to(dtype)
        noise = noise.to(dtype)
        context = context.to(dtype)

        nt = t.shape[0]
        noised_latent = self.scheduler.add_noise(x_start, noise, t)

        pred_noise = self.unet(
            noised_latent,
            t,
            encoder_hidden_states=context.expand(nt, -1, -1)
        ).sample

        return pred_noise

    def generate_images(self, images):
        batch_size = self.training_args.eval_batch_size
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
            inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
            encoder_outputs = self.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_input_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        last_decoder_hidden_state = decoder_outputs.last_hidden_state
        conditional_context = self.language_proj(last_decoder_hidden_state)

        un_token = self.tokenizer("", padding="max_length", truncation=True, max_length=77, return_tensors="pt").input_ids.to(self._device)
        un_context_embeddings = self.text_encoder(un_token).last_hidden_state
        un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)
        if self.training_args.use_text_encoder:
            context_embeddings = self.text_encoder(
                inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state

        latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
        latents = torch.randn(latent_shape, device=self._device, dtype=self._dtype)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.training_args.num_inference_steps)
        with torch.no_grad():
            for t in scheduler.timesteps:
                latent_model_input = torch.cat([latents, latents], dim=0)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                # Classifier-free guidance: run unconditional and conditional contexts in one batch
                combined_embeddings = torch.cat([un_context_embeddings, context_embeddings], dim=0).to(self._dtype)
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=combined_embeddings
                )[0]

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncond + self.training_args.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents)[0]

        # Undo the Stable Diffusion latent scaling factor before decoding with the VAE
        scaled_latents = latents / 0.18215
        with torch.no_grad():
            decoded_latents = self.vae.decode(scaled_latents.to(torch.float32))[0]

        return decoded_latents

    def get_conditional_context(self, images, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
            inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
            encoder_outputs = self.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_input_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        last_decoder_hidden_state = decoder_outputs.last_hidden_state
        return last_decoder_hidden_state

    def forward(
        self,
        image=None,
        filename=None,
        **kwargs,
    ) -> SDOutput:
        images_for_language_model = image
        normalize_images = normalize(image, rescale=True)
        x0 = self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
        latent = x0 * 0.18215

        total_timestep = self.scheduler.num_train_timesteps

        timesteps = initiate_time_steps(0, total_timestep, self.batch_size, self.training_args).long()
        timesteps = timesteps.to(self._device)
        c, h, w = latent.shape[1:]
        if not self.training_args.use_same_noise_among_timesteps:
            noise = torch.randn((self.batch_size, c, h, w), device=self._device, dtype=self._dtype)
        else:
            noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
            noise = noise.repeat(self.batch_size, 1, 1, 1)

        conditional_context = self.get_conditional_context(images_for_language_model)
        conditional_context = self.language_proj(conditional_context)

        if self.training_args.use_text_encoder:
            text_encoder_output = self.text_encoder(input_ids=None, inputs_embeds=conditional_context.to(self._dtype))
            pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=text_encoder_output.last_hidden_state.to(self._dtype)).to(self._dtype)
        else:
            pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=conditional_context.to(self._dtype)).to(self._dtype)

        if self.training_args.loss == "l1":
            loss = torch.nn.functional.l1_loss(pred_noise, noise)
        else:
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

        return SDOutput(loss=loss)
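`SDModel` takes no explicit hyperparameters of its own: everything is read off the `training_args` namespace, and constructing it downloads the frozen Stable Diffusion 2.1-base and Florence-2-large weights via `build.py`. The following is a minimal sketch, not part of the repository, of the attributes that `VLV_stage1.py` and `build.py` appear to read; the attribute values and the `vlv_repo` package name are illustrative assumptions.

```python
# Hypothetical sketch: the smallest training_args namespace read by SDModel / build.py.
import argparse
import torch

from vlv_repo.VLV_stage1 import SDConfig, SDModel  # "vlv_repo" is a placeholder package name;
                                                   # the modules use relative imports.

args = argparse.Namespace(
    fp32=True,                      # True -> self._dtype = torch.float32
    batch_size=1,
    eval_batch_size=1,              # used by generate_images()
    learnable_token_length=77,      # number of learnable query tokens
    image_size=384,                 # latents are image_size // 8 on each side
    guidance_scale=7.5,             # classifier-free guidance weight in generate_images()
    num_inference_steps=50,
    use_text_encoder=True,
    use_same_noise_among_timesteps=False,
    loss="mse",                     # anything other than "l1" selects MSE in forward()
    # flags read by build.load_Florence2_model:
    unfreeze_florence2_all=False,
    unfreeze_florence2_language_model=False,
    unfreeze_florence2_language_model_decoder=False,
)

model = SDModel(SDConfig(), args)   # downloads SD 2.1-base and Florence-2-large components
# model(image=batch) expects a float image batch; forward() normalizes it, encodes it with
# the VAE, and returns SDOutput(loss=...) for the noise-prediction distillation objective.
```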
VLV_stage2.py
ADDED
@@ -0,0 +1,460 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
from safetensors.torch import load_file
import torchvision.transforms as transforms
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize, process_caption
from .VLV_stage1 import SDModel, SDConfig
from .configuration_vlv import VLV_Config
import os
import sys
import argparse

def handle_module_prefix(state_dict):
    """Handle 'module.' prefix in state dict keys."""
    if any(k.startswith('module.') for k in state_dict.keys()):
        return {k.replace('module.', ''): v for k, v in state_dict.items()}
    return state_dict

def create_model_args(args):
    """Create model arguments needed by SDModel."""
    model_args = argparse.Namespace()
    model_args.use_text_encoder = args.use_text_encoder
    model_args.batch_size = args.batch_size
    model_args.eval_batch_size = args.batch_size
    model_args.distributed_strategy = 'none'
    model_args.fp32 = args.fp32
    model_args.learnable_token_length = args.learnable_token_length
    model_args.num_inference_steps = args.num_inference_steps
    model_args.image_size = args.image_size
    model_args.guidance_scale = args.guidance_scale
    model_args.unfreeze_florence2_all = False
    model_args.unfreeze_florence2_language_model = False
    model_args.unfreeze_florence2_language_model_decoder = False
    return model_args

def load_model_checkpoint(model, model_path, device):
    """Load model checkpoint."""
    try:
        checkpoint = torch.load(model_path, map_location="cpu")

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        state_dict = handle_module_prefix(state_dict)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}...")  # Show first 10
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}...")  # Show first 10

        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

    return model

def initialize_diffusion_model(args):
    """Initialize the diffusion model."""
    config = SDConfig()
    diffusion_model_args = create_model_args(args)
    diffusion_model = SDModel(config, diffusion_model_args)
    _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16

    # Delete components that aren't needed for inference
    if hasattr(diffusion_model, 'vae'):
        del diffusion_model.vae
    if hasattr(diffusion_model, 'unet'):
        del diffusion_model.unet

    # Clear CUDA cache
    torch.cuda.empty_cache()

    diffusion_model = diffusion_model.to(_dtype)

    # Freeze parameters that shouldn't be trained
    for param in diffusion_model.language_proj.parameters():
        param.requires_grad = False
    diffusion_model.query_embed.requires_grad = False

    return diffusion_model

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class CLIPDecoderOutput(ModelOutput):
    """
    Output class for the CLIP Decoder model.
    """
    last_hidden_state: Optional[torch.FloatTensor] = None
    generated_ids: Optional[torch.LongTensor] = None
    generated_text: Optional[list] = None


class CLIPDecoder(nn.Module):

    def __init__(
        self,
        language_model: str,
        VLV_model: SDModel,
        device: torch.device,
        bf16: str,
        qwen2_config: dict = None,
        args: argparse.Namespace = None
    ):
        """
        Initialize the CLIP Decoder model.

        Args:
            language_model: Path to the language model
            VLV_model: The VLV model instance
            device: The device to run the model on
            bf16: Whether to use bfloat16 precision
            qwen2_config: Optional qwen2 configuration dict
        """
        super(CLIPDecoder, self).__init__()

        self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
        self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)

        self.qwen2_config = AutoConfig.from_pretrained(language_model)
        self.qwen2_model = AutoModelForCausalLM.from_pretrained(
            language_model,
            torch_dtype=self._dtype,
            device_map=None,
            low_cpu_mem_usage=True
        )

        self.VLV_model = VLV_model  # fp32 in this case
        self.device = device
        self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
        self.ignore_token_id = -100

    def get_conditional_context(self, images, batch_size):
        """
        Get conditional context from images using the diffusion model.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Decoder hidden states from the diffusion model
        """
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(self._dtype)

        # Ensure all components are on the correct device
        self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
        self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
        self.mlp = self.mlp.to(inputs["input_ids"].device)
        self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device)

        if inputs["pixel_values"] is not None:
            image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
            inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )

        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

            encoder_outputs = self.VLV_model.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=self._dtype,
            device=self.device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.VLV_model.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        return decoder_outputs.last_hidden_state

    def process_image(self, images, batch_size):
        """
        Process images to get clip text embeddings.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Processed clip text embeddings and attention mask
        """
        decoder_hidden_states = self.get_conditional_context(images, batch_size)
        context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
        clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
        clip_text_embeds = self.mlp(clip_text_embeds)
        clip_text_embeds_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=torch.long,
            device=self.device
        )

        return clip_text_embeds, clip_text_embeds_attention_mask

    def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
        """
        Prepare inputs for text generation.

        Args:
            clip_text_embeds: Processed clip text embeddings
            clip_text_attention_mask: Attention mask for clip text embeddings

        Returns:
            Dictionary of generation inputs
        """
        if clip_text_attention_mask is None:
            clip_text_attention_mask = torch.ones(
                (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
                dtype=torch.long,
                device=clip_text_embeds.device
            )

        return {
            "inputs_embeds": clip_text_embeds,
            "attention_mask": clip_text_attention_mask
        }

    def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
        """
        Generate text from images.

        Args:
            images: Input images
            max_new_tokens: Maximum number of tokens to generate
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop early in beam search

        Returns:
            CLIPDecoderOutput with generated ids and text
        """
        batch_size = len(images)
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
        generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)

        generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
        generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)

        generated_ids = self.qwen2_model.generate(
            inputs_embeds=generation_inputs["inputs_embeds"],
            attention_mask=generation_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=early_stopping
        )

        generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        processed_generated_text = [process_caption(text) for text in generated_text]

        return CLIPDecoderOutput(
            generated_ids=generated_ids,
            generated_text=processed_generated_text
        )

    def forward(self, images, captions=None):
        """
        Forward pass for training.

        Args:
            images: Input images
            captions: Target captions (optional, for training)

        Returns:
            CLIPDecoderOutput with loss and logits
        """
        batch_size = images.shape[0]

        # Process images
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)

        # If no captions provided, return embeddings for generation
        if captions is None:
            return CLIPDecoderOutput(
                last_hidden_state=clip_text_embeds
            )

        assert len(captions) == batch_size
        # Process captions for training
        processed_captions = [process_caption(caption) for caption in captions]
        qwen_input_ids = self.qwen2_tokenizer(
            text=processed_captions,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            max_length=300,
            return_token_type_ids=False,
        ).input_ids

        assert len(captions) == batch_size
        qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)

        # Prepare labels for training
        labels = qwen_input_ids
        labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
        labels = labels.to(self.device)

        # Get embeddings for captions to create the full input sequence
        labels_for_embeddings = labels.clone()
        labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
        clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)

        # Concatenate the embeddings and prepare attention mask
        inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
        clip_seq_len = clip_text_embeds.shape[1]
        clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
        combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)

        attention_mask = torch.cat((
            clip_text_attention_mask,
            qwen_attention_mask
        ), dim=1)

        # Forward through language model
        outputs = self.qwen2_model(
            inputs_embeds=inputs_embeds,
            labels=combined_labels,
            attention_mask=attention_mask,
            use_cache=False
        )
        return outputs


# HuggingFace Model Wrapper
class VLV_MODEL(PreTrainedModel):
    config_class = VLV_Config
    model_type = "VLV_decoder"

    def __init__(self, config):
        super().__init__(config)
        """Load the CLIPDecoder model."""
        # Initialize the diffusion model first
        device = "cuda"
        de_diffusion_model = initialize_diffusion_model(config)
        clip_decoder_model = CLIPDecoder(
            language_model=config.qwen_model,
            VLV_model=de_diffusion_model,
            device=device,
            bf16=config.mixed_precision,
            qwen2_config=config.qwen2_config
        )

        # Load the trained weights
        # clip_decoder_model = load_model_checkpoint(clip_decoder_model, config.clip_decoder_checkpoint, device)

        # Set to evaluation mode
        clip_decoder_model.eval()

        # Store components directly as attributes to match checkpoint structure
        self.VLV_model = clip_decoder_model.VLV_model
        self.qwen2_model = clip_decoder_model.qwen2_model
        self.mlp = clip_decoder_model.mlp

        # Keep the full model for methods
        self._clip_decoder_model = clip_decoder_model
        self.max_new_tokens = config.max_length
        self.num_beams = config.num_beams
        self.transform = self.get_transform(config.image_size)

    def get_transform(self, image_size):
        """Transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
            transforms.PILToTensor(),
        ])

    @classmethod
    def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
        """
        Load model from original training checkpoint.

        Args:
            checkpoint_path: Path to the original model.pt checkpoint
            config: Optional VLV_Config, will create default if None
            **kwargs: Additional arguments for model initialization
        """
        if config is None:
            # Create default config
            config = VLV_Config(
                image_size=384,
                guidance_scale=7.5,
                learnable_token_length=77,
                max_length=300,
                num_beams=4,
                **kwargs
            )

        # Initialize model
        model = cls(config)

        # Load checkpoint weights
        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)

        return model

    def forward(self, valid_images, max_length):
        valid_images = [self.transform(img) for img in valid_images]
        if hasattr(self._clip_decoder_model, 'module'):
            outputs = self._clip_decoder_model.module.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        else:
            outputs = self._clip_decoder_model.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        return outputs
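A minimal end-to-end sketch of the wrapper above (not part of the repository): `VLV_MODEL` builds the whole stack from a `VLV_Config`, and its `forward` applies `get_transform` to PIL images before beam-search generation. The `vlv_repo` package name, the image path, and the checkpoint path are placeholders, and a CUDA device is assumed because the constructor pins the decoder to `"cuda"`.

```python
import torch
from PIL import Image

from vlv_repo.configuration_vlv import VLV_Config   # placeholder package name
from vlv_repo.VLV_stage2 import VLV_MODEL

config = VLV_Config(image_size=384, guidance_scale=7.5,
                    learnable_token_length=77, max_length=300, num_beams=4)
# Plain VLV_MODEL(config) does not load the trained VLV weights; from_checkpoint() does.
model = VLV_MODEL.from_checkpoint("model.pt", config=config)   # placeholder checkpoint path

image = Image.open("example.jpg").convert("RGB")               # placeholder image path
with torch.no_grad():
    out = model([image], max_length=300)   # forward(): get_transform -> CLIPDecoder.generate
print(out.generated_text[0])
```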
build.py
ADDED
@@ -0,0 +1,78 @@
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTokenizer, AutoProcessor
from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .configuration_florence2 import Florence2Config


def load_sd_model(training_args):
    """Load Stable Diffusion model"""

    repo_id = "stabilityai/stable-diffusion-2-1-base"

    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)

    for m in [vae, text_encoder, unet]:
        for param in m.parameters():
            param.requires_grad = False

    return (vae, tokenizer, text_encoder, unet, scheduler)


def load_Florence2_model(training_args):
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    config._attn_implementation = "eager"

    # Load the model with pre-trained weights
    model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large", config=config)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    # freeze the model
    if training_args.unfreeze_florence2_all:
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True
        for param in model.language_model.lm_head.parameters():
            param.requires_grad = False

        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone())

        for p in model.language_model.lm_head.parameters():
            p.requires_grad = False

    elif training_args.unfreeze_florence2_language_model_decoder:
        # Create a separate embedding layer for decoder
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx
        )
        # Copy the weights
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()

        # Replace the decoder embeddings
        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        for param in model.parameters():
            param.requires_grad = False

    return model, processor
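As a usage note (a sketch, not repository code): `load_sd_model` never reads attributes off its `training_args` argument, and `load_Florence2_model` only reads the three `unfreeze_florence2_*` flags, so a bare namespace is enough to call the loaders directly. The `vlv_repo` package name is a placeholder.

```python
# Hypothetical sketch: calling the two loaders from build.py directly.
import argparse
from vlv_repo.build import load_sd_model, load_Florence2_model

args = argparse.Namespace(
    unfreeze_florence2_all=False,
    unfreeze_florence2_language_model=False,
    unfreeze_florence2_language_model_decoder=False,
)

vae, tokenizer, text_encoder, unet, scheduler = load_sd_model(args)   # frozen SD 2.1-base parts
florence2, processor = load_Florence2_model(args)                     # frozen Florence-2-large
print(sum(p.numel() for p in unet.parameters()) / 1e6, "M UNet parameters")
```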
config.json
CHANGED
@@ -3,27 +3,31 @@
     "VLV_MODEL"
   ],
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-    "AutoModelForCausalLM": "
+    "AutoConfig": "configuration_vlv.VLV_Config",
+    "AutoModel": "VLV_stage2.VLV_MODEL",
+    "AutoModelForCausalLM": "VLV_stage2.VLV_MODEL"
   },
   "model_type": "VLV_decoder",
   "batch_size": 1,
   "deepspeed": true,
   "distributed": true,
   "fp32": true,
-  "guidance_scale": 2.
+  "guidance_scale": 2.5,
   "hidden_size": 128,
-  "image_size":
+  "image_size": 384,
   "learnable_token_length": 77,
   "local_rank": 0,
-  "mixed_precision": "
+  "mixed_precision": "fp32",
   "num_inference_steps": 50,
-  "torch_dtype": "
+  "torch_dtype": "float32",
   "transformers_version": "4.51.1",
   "use_text_encoder": true,
   "verbose": true,
   "qwen_model": "Qwen/Qwen2.5-3B",
+  "stable_diffusion_model_path": "stabilityai/stable-diffusion-2-1-base",
+  "florence2_model_path": "microsoft/Florence-2-large",
+  "max_length": 300,
+  "num_beams": 4,
   "qwen2_config": {
     "architectures": [
       "Qwen2ForCausalLM"
@@ -45,11 +49,11 @@
     "rope_theta": 1000000.0,
     "sliding_window": 32768,
     "tie_word_embeddings": true,
-    "torch_dtype": "
+    "torch_dtype": "float32",
     "transformers_version": "4.40.1",
     "use_cache": true,
     "use_mrope": false,
     "use_sliding_window": false,
     "vocab_size": 151936
   }
-}
+}
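With the `auto_map` entries now pointing at the modules shipped in this repository, the stock `Auto*` classes can resolve the custom config and model when `trust_remote_code=True`. A small sketch, reusing the same `your-username/vlv-captioner` placeholder as the README, that reads back the newly added keys:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("your-username/vlv-captioner", trust_remote_code=True)
print(type(config).__name__)                        # VLV_Config (from configuration_vlv.py)
print(config.image_size, config.guidance_scale)     # 384, 2.5
print(config.stable_diffusion_model_path)           # stabilityai/stable-diffusion-2-1-base
print(config.florence2_model_path)                  # microsoft/Florence-2-large
print(config.max_length, config.num_beams)          # 300, 4
```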
configuration_vlv.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 VLV Team and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""VLV model configuration"""
|
16 |
+
|
17 |
+
from typing import Optional, Dict, Any
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class VLV_Config(PretrainedConfig):
|
25 |
+
r"""
|
26 |
+
    This is the configuration class to store the configuration of a [`VLV_MODEL`]. It is used to instantiate a VLV model
    according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        model_type (`str`, *optional*, defaults to "VLV_decoder"):
            The model type identifier.
        batch_size (`int`, *optional*, defaults to 1):
            The batch size for inference.
        deepspeed (`bool`, *optional*, defaults to True):
            Whether to use DeepSpeed.
        distributed (`bool`, *optional*, defaults to True):
            Whether to use distributed training.
        fp32 (`bool`, *optional*, defaults to True):
            Whether to use fp32 precision.
        guidance_scale (`float`, *optional*, defaults to 2.0):
            The guidance scale for generation.
        hidden_size (`int`, *optional*, defaults to 128):
            The hidden size of the model.
        image_size (`int`, *optional*, defaults to 768):
            The size of input images.
        learnable_token_length (`int`, *optional*, defaults to 77):
            The length of the learnable tokens.
        local_rank (`int`, *optional*, defaults to 0):
            The local rank for distributed training.
        mixed_precision (`str`, *optional*, defaults to "bf16"):
            The mixed-precision mode.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of inference steps.
        torch_dtype (`str`, *optional*, defaults to "bfloat16"):
            The torch dtype.
        transformers_version (`str`, *optional*, defaults to "4.51.1"):
            The transformers version recorded in the configuration.
        use_text_encoder (`bool`, *optional*, defaults to True):
            Whether to use the text encoder.
        verbose (`bool`, *optional*, defaults to True):
            Whether to enable verbose mode.
        qwen_model (`str`, *optional*, defaults to "Qwen/Qwen2.5-3B"):
            The Qwen model to use.
        stable_diffusion_model_path (`str`, *optional*, defaults to "stabilityai/stable-diffusion-2-1-base"):
            The Stable Diffusion model path.
        florence2_model_path (`str`, *optional*, defaults to "microsoft/Florence-2-large"):
            The Florence-2 model path.
        qwen2_config (`dict`, *optional*):
            The Qwen2 configuration.
        max_length (`int`, *optional*, defaults to 300):
            The maximum length for generation.
        num_beams (`int`, *optional*, defaults to 4):
            The number of beams for beam search.
    """

    model_type = "VLV_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type: str = "VLV_decoder",
        batch_size: int = 1,
        deepspeed: bool = True,
        distributed: bool = True,
        fp32: bool = True,
        guidance_scale: float = 2.0,
        hidden_size: int = 128,
        image_size: int = 768,
        learnable_token_length: int = 77,
        local_rank: int = 0,
        mixed_precision: str = "bf16",
        num_inference_steps: int = 50,
        torch_dtype: str = "bfloat16",
        transformers_version: str = "4.51.1",
        use_text_encoder: bool = True,
        verbose: bool = True,
        qwen_model: str = "Qwen/Qwen2.5-3B",
        stable_diffusion_model_path: str = "stabilityai/stable-diffusion-2-1-base",
        florence2_model_path: str = "microsoft/Florence-2-large",
        qwen2_config: Optional[Dict[str, Any]] = None,
        max_length: int = 300,
        num_beams: int = 4,
        **kwargs,
    ):
        self.model_type = model_type
        self.batch_size = batch_size
        self.deepspeed = deepspeed
        self.distributed = distributed
        self.fp32 = fp32
        self.guidance_scale = guidance_scale
        self.hidden_size = hidden_size
        self.image_size = image_size
        self.learnable_token_length = learnable_token_length
        self.local_rank = local_rank
        self.mixed_precision = mixed_precision
        self.num_inference_steps = num_inference_steps
        self.torch_dtype = torch_dtype
        self.transformers_version = transformers_version
        self.use_text_encoder = use_text_encoder
        self.verbose = verbose
        self.qwen_model = qwen_model
        self.stable_diffusion_model_path = stable_diffusion_model_path
        self.florence2_model_path = florence2_model_path
        self.qwen2_config = qwen2_config or self._get_default_qwen2_config()
        self.max_length = max_length
        self.num_beams = num_beams

        super().__init__(**kwargs)

    def _get_default_qwen2_config(self):
        """Return the default Qwen2 configuration used when none is provided."""
        return {
            "architectures": ["Qwen2ForCausalLM"],
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151643,
            "hidden_act": "silu",
            "hidden_size": 2048,
            "initializer_range": 0.02,
            "intermediate_size": 11008,
            "max_position_embeddings": 32768,
            "max_window_layers": 36,
            "model_type": "qwen2",
            "num_attention_heads": 16,
            "num_hidden_layers": 36,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": True,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.40.1",
            "use_cache": True,
            "use_mrope": False,
            "use_sliding_window": False,
            "vocab_size": 151936,
        }


class CLIPDecoderConfig(PretrainedConfig):
    r"""
    Configuration class for the CLIPDecoder model (legacy support).
    """

    model_type = "vlv_stage2"

    def __init__(
        self,
        input_dim: int = 1024,
        bf16: bool = False,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.bf16 = bf16
        super().__init__(**kwargs)
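For reference, a minimal sketch of loading this configuration and overriding a few generation-related fields before building the model. The repo id and the use of `AutoConfig` with `trust_remote_code` are assumptions for illustration, not part of this commit:

```python
# Minimal sketch (assumed repo id; assumes the config above is exposed via trust_remote_code).
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B",  # assumed repo id
    trust_remote_code=True,
)
config.num_inference_steps = 25  # fewer diffusion steps for faster runs
config.max_length = 150          # shorter captions
config.num_beams = 4             # beam-search width (default shown above)

print(config.qwen_model, config.image_size, config.learnable_token_length)
```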
model-00001-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7460963b2ea4c7cde35d0c64c8d46d4a9324c7574433f8cf9878bbaf687f61b
size 622330008

model-00002-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dca6a859202a8817026897383409ec85fb0a22d4b6527da6ab5f5e2ccd3745be
size 832409864

model-00003-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30d67c2d202ae6c4166ba0b82310f19225665305e6fb3b22c66ff5318fbf6f50
size 210079920

model-00004-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c015745c4638633cfb7d09e9b2b96bfa15fd21511fd74642d13296afc9423a4f
size 5215310704

model-00005-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da8bab2f53dbd82612d2034d6e67724a171a44fc040198ca5fe9d6120cc3409e
size 5046894020

model.safetensors.index.json
CHANGED
The diff for this file is too large to render. See raw diff.
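Because the index diff cannot be rendered here, a small sketch of how to inspect it locally. It assumes `model.safetensors.index.json` follows the standard Hugging Face sharded-checkpoint layout with `metadata` and `weight_map` keys:

```python
# Minimal sketch: inspect which shard stores which tensor (assumed standard index layout).
import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])   # total parameter bytes across the five shards
weight_map = index["weight_map"]         # maps each tensor name to its shard file
first_tensor = next(iter(weight_map))
print(first_tensor, "->", weight_map[first_tensor])
```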
modeling_clip.py
CHANGED
@@ -1,5 +1,5 @@
 from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPPreTrainedModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling, CLIPTextModelOutput
 from typing import Optional, Union, Tuple
 import torch
 from torch import nn
@@ -53,7 +53,8 @@ class CustomCLIPTextTransformer(nn.Module):
         if inputs_embeds is not None:
-            inputs_embeds
+            # inputs_embeds are already embeddings, just add positional embeddings
+            inputs_embeds = self.embeddings.position_embedding(self.embeddings.position_ids[:, :inputs_embeds.size(1)]) + inputs_embeds
         else:
             inputs_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)
@@ -134,9 +135,49 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        return self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class CustomCLIPTextModelWithProjection(CLIPPreTrainedModel):
+    config_class = CLIPTextConfig
+    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__(config)
+        self.text_model = CustomCLIPTextTransformer(config)
+
+        # Add the projection layer for SDXL's second text encoder
+        projection_dim = getattr(config, 'projection_dim', config.hidden_size)
+        self.text_projection = nn.Linear(config.hidden_size, projection_dim, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, value):
+        self.text_model.embeddings.token_embedding = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CLIPTextModelOutput]:
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
@@ -145,3 +186,19 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
+
+        pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
+
+        # Apply the projection to the pooled output
+        text_embeds = self.text_projection(pooled_output)
+
+        if not return_dict:
+            # Return last_hidden_state, pooler_output, text_embeds, and any remaining outputs
+            return (text_outputs[0], text_outputs[1], text_embeds) + text_outputs[2:]
+
+        return CLIPTextModelOutput(
+            text_embeds=text_embeds,  # projected embeddings (for similarity)
+            last_hidden_state=text_outputs.last_hidden_state,  # all token representations
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
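The `inputs_embeds` branch above is the key change: VLV's learnable query tokens arrive as continuous embeddings, so they skip the token-embedding lookup and only receive positional embeddings. A standalone sketch of that step, with illustrative sizes that are not taken from the checkpoint:

```python
# Standalone sketch of the inputs_embeds path added above. The hidden size and
# sequence length are illustrative placeholders, not the model's real values.
import torch
from torch import nn

hidden_size, max_position_embeddings = 1024, 77
position_embedding = nn.Embedding(max_position_embeddings, hidden_size)
position_ids = torch.arange(max_position_embeddings).unsqueeze(0)  # [1, 77]

inputs_embeds = torch.randn(1, 77, hidden_size)  # e.g. 77 learnable tokens
# Same operation as in CustomCLIPTextTransformer: add positional embeddings only,
# since the inputs are already continuous embeddings rather than token ids.
inputs_embeds = position_embedding(position_ids[:, :inputs_embeds.size(1)]) + inputs_embeds
print(inputs_embeds.shape)  # torch.Size([1, 77, 1024])
```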
vlv_utils.py
ADDED
@@ -0,0 +1,71 @@
"""Utility functions"""
import importlib
import random
import re
import torch
import numpy as np
from PIL import Image


def normalize(image, rescale=True):
    """Map an image tensor to the [-1, 1] range expected by the diffusion model."""
    if rescale:
        image = image.float() / 255.0  # convert to float and rescale to [0, 1]
    normalize_image = 2 * image - 1  # normalize to [-1, 1]
    return normalize_image


def process_caption(caption):
    """Process a caption to ensure proper formatting and remove duplicate sentences.

    Args:
        caption: A string containing the caption text.

    Returns:
        processed_caption: The processed caption string.
    """
    # Drop a trailing, unfinished sentence fragment.
    if not caption.endswith('.'):
        last_period_index = caption.rfind('.')
        if last_period_index != -1:
            caption = caption[:last_period_index + 1]

    # Split into sentences and keep only the first occurrence of each.
    sentences = re.split(r'(?<=[.!?])\s+', caption)
    unique_sentences = []
    for sentence in sentences:
        if sentence and sentence not in unique_sentences:
            unique_sentences.append(sentence)

    processed_caption = ' '.join(unique_sentences)
    return processed_caption


def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Args:
        step: An integer giving the constant step.
        total_timestep: An integer giving the total number of timesteps of the diffusion model.
        batch_size: An integer giving the batch size.
        config: A config object.

    Returns:
        timesteps: A tensor of shape [batch_size,] containing the time steps.
    """
    if config.rand_timestep_equal_int:
        # Evenly spaced timesteps across the batch, sharing one random offset.
        interval_val = total_timestep // batch_size
        start_point = random.randint(0, interval_val - 1)
        timesteps = torch.tensor(
            list(range(start_point, total_timestep, interval_val))
        ).long()
        return timesteps
    elif config.random_timestep_per_iteration:
        # A random timestep for each image in the batch (default).
        return torch.randint(0, total_timestep, (batch_size,)).long()
    else:
        # Fall back to the same constant timestep for every image in the batch.
        return torch.tensor([step] * batch_size).long()
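A short usage sketch for these helpers, run from the repository root so that `vlv_utils.py` is importable. The `SimpleNamespace` config is a stand-in introduced here for illustration; it only needs the two flags the function reads:

```python
# Usage sketch for the helpers above (stand-in config object, illustrative shapes).
from types import SimpleNamespace
import torch
from vlv_utils import normalize, process_caption, initiate_time_steps

print(process_caption("A dog on grass. A dog on grass. It is running"))
# -> "A dog on grass."  (duplicate sentence removed, trailing fragment dropped)

cfg = SimpleNamespace(rand_timestep_equal_int=False, random_timestep_per_iteration=True)
timesteps = initiate_time_steps(step=0, total_timestep=1000, batch_size=4, config=cfg)
print(timesteps.shape)  # torch.Size([4]) -- one random timestep per image

image_uint8 = torch.randint(0, 256, (3, 768, 768), dtype=torch.uint8)
normalized = normalize(image_uint8)
print(normalized.min().item(), normalized.max().item())  # roughly -1.0 and 1.0
```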