lambertxiao committed on
Commit 492f6af · verified · 1 Parent(s): 7ba7930

Overwrite with converted Qwen2.5-3B model files

README.md CHANGED
@@ -1,104 +1,255 @@
1
  ---
2
  license: apache-2.0
3
- language:
4
- - en
5
- base_model:
6
- - Qwen/Qwen2.5-3B-Instruct
7
- - microsoft/Florence-2-large
8
  pipeline_tag: image-to-text
9
  ---
10
 
11
- # Vision-Language-Vision Auto-Encoder: Scalable Knowledge Distillation from Diffusion Models
12
 
13
- [![Website](https://img.shields.io/badge/Project%20Page-Website-brightgreen?logo=googlechrome&logoColor=white)](https://lambert-x.github.io/Vision-Language-Vision/)
14
- [![arXiv](https://img.shields.io/badge/arXiv-2507.07104-B31B1B.svg?logo=arXiv&logoColor=white)](https://arxiv.org/abs/2507.07104)
15
- [![GitHub](https://img.shields.io/badge/Code-GitHub-black?logo=github)](https://github.com/Tiezheng11/Vision-Language-Vision)
16
- [![HF Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B)
17
- [![HF Dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-yellow)](https://huggingface.co/datasets/ccvl/LAION-High-Qualtiy-Pro-6M-VLV)
18
 
19
- ## VLV Captioner (Qwen 2.5 3B)
20
 
21
- This repository hosts the 3-billion-parameter **Vision-Language-Vision Captioner** model, distantly supervised by diffusion models and built on top of Qwen 2.5 3B.
22
- Checkpoint URL: **<https://huggingface.co/lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B>**
23
 
24
- ---
25
 
26
- ## 1 · Install Dependencies
27
 
28
  ```bash
29
- # inside your virtualenv / conda env
30
- pip install -r requirements.txt
31
  ```
32
 
33
- ## 2 · Example Usage
34
  ```python
35
- from transformers import AutoModel
36
- from PIL import Image
37
- import torch, numpy as np
38
 
39
- MODEL_NAME = "lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B"
40
- device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
- # ────── load model ──────
43
- model = (
44
- AutoModel.from_pretrained(
45
- MODEL_NAME,
46
- trust_remote_code=True,
47
- low_cpu_mem_usage=False,
48
- )
49
- .to(device)
50
- .eval()
51
- )
52
 
53
- # ────── helpers ──────
54
- def _trim_tail(text: str) -> str:
55
- """Remove an incomplete trailing sentence fragment, if any."""
56
- sentences = [s.strip() for s in text.split(".") if s.strip()]
57
- if not text.rstrip().endswith("."):
58
- sentences = sentences[:-1] # drop dangling fragment
59
- return ". ".join(sentences) + ("." if sentences else "")
60
 
61
- def caption_image(img: Image.Image, max_len: int = 77) -> str:
62
- """Generate a caption for one PIL image."""
63
- with torch.no_grad():
64
- raw = model([img], max_len).generated_text[0]
65
- return _trim_tail(raw)
66
-
67
- def caption_from_numpy(arr: np.ndarray, max_len: int = 77) -> str:
68
- """
69
- Wrapper for NumPy arrays.
70
- Accepts uint8 [0, 255] or float [0, 1] ranges.
71
- """
72
- if arr.dtype != np.uint8:
73
- arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
74
- return caption_image(Image.fromarray(arr, mode="RGB"), max_len)
75
  ```
76
 
77
 
78
- ## 3 · Quick Test
79
 
80
  ```python
81
- # caption a remote sample image (cat photo) in one cell
82
 
83
- import io, requests
84
- from PIL import Image
85
- from IPython.display import display # Jupyter/Colab only
86
 
87
- IMG_URL = "https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg"
88
 
89
- # download & open
90
- img = Image.open(io.BytesIO(requests.get(IMG_URL, timeout=10).content)).convert("RGB")
91
 
92
- display(img) # show the image
93
- print(caption_image(img)) # generate and print the caption
94
 
95
  ```
96
- ## 4 · Citation
 
97
 
98
  ```bibtex
99
- @article{zhang2025vision,
100
- title = {Vision-Language-Vision Auto-Encoder: Scalable Knowledge Distillation from Diffusion Models},
101
- author = {Zhang, Tiezheng and Li, Yitong and Chou, Yu-Cheng and Chen, Jieneng and Yuille, Alan and Wei, Chen and Xiao, Junfei},
102
- journal = {arXiv preprint arXiv:2507.07104},
103
- year = {2025}
104
  }
1
  ---
2
  license: apache-2.0
3
+ tags:
4
+ - image-captioning
5
+ - multimodal
6
+ - vision-language
7
+ - diffusion
8
+ - pytorch
9
+ - transformers
10
+ library_name: transformers
11
  pipeline_tag: image-to-text
12
+ datasets:
13
+ - conceptual_captions
14
+ - coco
15
+ model_type: VLV_decoder
16
  ---
17
 
18
+ # VLV Captioner Model
19
 
20
+ This is a VLV (Vision-Language-Vision) model for image captioning. It combines a Stable Diffusion-based image encoder with the Qwen2.5-3B language model to generate descriptive captions from images.
21
 
22
+ ## Model Description
23
 
24
+ The VLV Captioner is a multimodal model that:
25
+ - Uses a diffusion-based vision encoder to extract image features
26
+ - Employs the Qwen2.5-3B language model for text generation
27
+ - Generates natural language descriptions of input images
28
 
29
+ ## Model Architecture
30
+
31
+ - **Vision Encoder**: Stable Diffusion-based image encoder with Florence2 components
32
+ - **Language Model**: Qwen2.5-3B transformer model
33
+ - **Image Size**: 384x384 pixels
34
+ - **Max Caption Length**: 300 tokens
35
+ - **Precision**: Mixed precision (bfloat16/float32)
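+
+ For reference, the Hub wrapper resizes and center-crops inputs to this resolution before captioning (see `get_transform` in `VLV_stage2.py` in this commit). A minimal preprocessing sketch, for illustration only; the image path is a placeholder:
+
+ ```python
+ # Sketch of the preprocessing VLV_MODEL applies internally via get_transform.
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+ transform = transforms.Compose([
+     transforms.Resize(384),              # shorter side to 384 px
+     transforms.CenterCrop((384, 384)),   # square 384x384 crop
+     transforms.PILToTensor(),            # uint8 CHW tensor
+ ])
+
+ image = Image.open("path/to/your/image.jpg").convert("RGB")
+ pixels = transform(image)                # this is what the captioner sees
+ ```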
36
+
37
+ ## Usage
38
+
39
+ ### Method 1: Load from Hugging Face Hub
40
+
41
+ ```python
42
+ from transformers import AutoModel, AutoConfig
43
+ from PIL import Image
44
+ import torch
45
+ import os
46
+
47
+ # Optional: Set custom cache directory if needed
48
+ cache_dir = "/path/to/your/cache" # Use a directory with sufficient space
49
+ os.makedirs(cache_dir, exist_ok=True)
50
+
51
+ # Load the model with authentication token (if required)
52
+ token = os.getenv('HUGGINGFACE_TOKEN') # or your token string
53
+
54
+ print("Loading config...")
55
+ config = AutoConfig.from_pretrained(
56
+ "your-username/vlv-captioner",
57
+ trust_remote_code=True,
58
+ token=token,
59
+ cache_dir=cache_dir
60
+ )
61
+
62
+ print("Loading model...")
63
+ try:
64
+ model = AutoModel.from_pretrained(
65
+ "your-username/vlv-captioner",
66
+ trust_remote_code=True,
67
+ token=token,
68
+ cache_dir=cache_dir,
69
+ torch_dtype=torch.float32, # Specify dtype explicitly
70
+ low_cpu_mem_usage=True
71
+ # Note: Avoid device_map="auto" to prevent meta tensor issues
72
+ )
73
+ print("Model loaded successfully!")
74
+
75
+ # Load and process an image
76
+ image = Image.open("path/to/your/image.jpg")
77
+
78
+ # Move model to GPU if available
79
+ if torch.cuda.is_available():
80
+ model = model.to('cuda')
81
+ print("Model moved to GPU!")
82
+
83
+ # Generate caption
84
+ print("Generating caption...")
85
+ with torch.no_grad():
86
+ captions = model([image], max_length=300)
87
+
88
+ # Handle different possible output formats
89
+ if hasattr(captions, 'generated_text'):
90
+ print("Generated caption:", captions.generated_text[0])
91
+ elif isinstance(captions, list):
92
+ print("Generated caption:", captions[0])
93
+ else:
94
+ print("Generated caption:", captions)
95
+
96
+ except Exception as e:
97
+ print(f"Error during model loading or inference: {e}")
98
+ # If cached files are corrupted, try clearing cache and redownloading
99
+ import shutil
100
+ cache_path = f"{cache_dir}/modules/transformers_modules/your-username/vlv-captioner"
101
+ if os.path.exists(cache_path):
102
+ print(f"Clearing cache at {cache_path}")
103
+ shutil.rmtree(cache_path)
104
+
105
+ # Retry with force download
106
+ model = AutoModel.from_pretrained(
107
+ "your-username/vlv-captioner",
108
+ trust_remote_code=True,
109
+ token=token,
110
+ cache_dir=cache_dir,
111
+ force_download=True,
112
+ torch_dtype=torch.float32
113
+ )
114
+ ```
115
+
116
+ ### Method 2: Load from original checkpoint
117
+
118
+ ```python
119
+ from VLV_stage2 import VLV_MODEL
120
+
121
+ # Load from original .pt checkpoint file
122
+ model = VLV_MODEL.from_checkpoint("path/to/model.pt")
123
+
124
+ # Load and process an image
125
+ image = Image.open("path/to/your/image.jpg")
126
+
127
+ # Generate caption
128
+ with torch.no_grad():
129
+ captions = model([image], max_length=300)
130
+ print(captions.generated_text[0]) # Generated caption
131
+ ```
132
+
133
+ ## Model Details
134
+
135
+ - **Model Type**: Vision-Language Model
136
+ - **Architecture**: VLV_decoder
137
+ - **Language Backbone**: Qwen/Qwen2.5-3B
138
+ - **Vision Backbone**: Stable Diffusion + Florence2
139
+ - **Training Data**: Various image-caption datasets
140
+ - **Framework**: PyTorch, Transformers
141
+
142
+ ## Training Configuration
143
+
144
+ - **Batch Size**: 1 (inference)
145
+ - **Learnable Token Length**: 77
146
+ - **Guidance Scale**: 7.5
147
+ - **Inference Steps**: 50
148
+ - **Beam Search**: 4 beams
149
 
150
+ ## Requirements
151
 
152
  ```bash
153
+ pip install torch transformers safetensors torchvision pillow diffusers
 
154
  ```
155
 
156
+ ## Troubleshooting
157
+
158
+ ### Common Issues and Solutions
159
+
160
+ #### 1. Meta Tensor Issues
161
+ If you encounter meta tensor errors, avoid using `device_map="auto"` when loading the model:
162
+
163
  ```python
164
+ # Don't use this - can cause meta tensor issues
165
+ model = AutoModel.from_pretrained("model-name", device_map="auto")
 
166
 
167
+ # Use this instead
168
+ model = AutoModel.from_pretrained("model-name", torch_dtype=torch.float32, low_cpu_mem_usage=True)
169
+ if torch.cuda.is_available():
170
+ model = model.to('cuda')
171
+ ```
172
 
173
+ #### 2. Cache Issues
174
+ If you experience corrupted cache files, clear the cache and redownload:
175
 
176
+ ```python
177
+ import shutil
178
+ import os
179
 
180
+ cache_dir = "/your/cache/directory"
181
+ cache_path = f"{cache_dir}/modules/transformers_modules/your-username/model-name"
182
+ if os.path.exists(cache_path):
183
+ shutil.rmtree(cache_path)
184
+
185
+ # Then reload with force_download=True
186
+ model = AutoModel.from_pretrained("model-name", force_download=True)
187
  ```
188
 
189
+ #### 3. Authentication Issues
190
+ Make sure your Hugging Face token is properly set:
191
 
192
+ ```bash
193
+ # Option 1: Environment variable
194
+ export HUGGINGFACE_TOKEN="your_token_here"
195
+
196
+ # Option 2: Hugging Face CLI login
197
+ huggingface-cli login
198
+ ```
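+
+ Alternatively, assuming the `huggingface_hub` package is installed, you can log in from Python:
+
+ ```python
+ # Programmatic equivalent of `huggingface-cli login`; the token is a placeholder.
+ from huggingface_hub import login
+
+ login(token="your_token_here")
+ ```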
199
+
200
+ #### 4. Memory Issues
201
+ For large models, use a custom cache directory with sufficient space:
202
 
203
  ```python
204
+ cache_dir = "/path/to/large/storage"
205
+ os.makedirs(cache_dir, exist_ok=True)
206
+ model = AutoModel.from_pretrained("model-name", cache_dir=cache_dir, low_cpu_mem_usage=True)
207
+ ```
208
 
209
+ ## Advanced Usage
210
 
211
+ ### Batch Processing with Original Inference Script
212
 
213
+ For large-scale inference, you can use the original `Caption_inference.py` script from the training codebase:
 
214
 
215
+ ```bash
216
+ python Caption_inference.py \
217
+ --input_path /path/to/images \
218
+ --output_path captions.json \
219
+ --clip_decoder_checkpoint /path/to/model.pt \
220
+ --qwen_model Qwen/Qwen2.5-3B \
221
+ --stable_diffusion_model_path stabilityai/stable-diffusion-2-1-base \
222
+ --florence2_model_path microsoft/Florence-2-large \
223
+ --batch_size 4 \
224
+ --max_length 300 \
225
+ --num_beams 4 \
226
+ --image_size 384 \
227
+ --guidance_scale 7.5 \
228
+ --use_text_encoder \
229
+ --distributed # For multi-GPU inference
230
+ ```
231
 
232
+ ### Configuration Parameters
233
+
234
+ - `image_size`: Input image resolution (default: 384)
235
+ - `guidance_scale`: Diffusion guidance scale (default: 7.5)
236
+ - `learnable_token_length`: Number of vision tokens (default: 77)
237
+ - `max_length`: Maximum caption length (default: 300)
238
+ - `num_beams`: Beam search width (default: 4)
239
+ - `use_text_encoder`: Enable CLIP text encoder (recommended: True)
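+
+ These map onto the `VLV_Config` fields defined in `configuration_vlv.py` in this commit. A minimal sketch of building such a config and loading a local checkpoint with it (the checkpoint path below is a placeholder):
+
+ ```python
+ # Sketch only: VLV_Config and VLV_MODEL.from_checkpoint are defined in this
+ # commit; the values mirror the defaults listed above.
+ from configuration_vlv import VLV_Config
+ from VLV_stage2 import VLV_MODEL
+
+ config = VLV_Config(
+     image_size=384,
+     guidance_scale=7.5,
+     learnable_token_length=77,
+     max_length=300,
+     num_beams=4,
+     use_text_encoder=True,
+ )
+ model = VLV_MODEL.from_checkpoint("path/to/model.pt", config=config)
+ ```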
240
  ```
241
+
242
+ ## Citation
243
 
244
  ```bibtex
245
+ @article{vlv_autoencoder,
246
+ title={Vision-Language-Vision Auto-Encoder: Scalable Knowledge Distillation from Diffusion Models},
247
+ author={Zhang, Tiezheng and Li, Yitong and Chou, Yu-Cheng and Chen, Jieneng and Yuille, Alan L. and Wei, Chen and Xiao, Junfei},
248
+ journal={arXiv preprint arXiv:2507.07104},
249
+ year={2025}
250
  }
251
+ ```
252
+
253
+ ## License
254
+
255
+ This model is released under the Apache 2.0 license.
VLV_stage1.py ADDED
@@ -0,0 +1,257 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional
5
+ from dataclasses import dataclass
6
+ from transformers.utils import ModelOutput
7
+ from transformers.modeling_utils import PreTrainedModel
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from .build import load_sd_model, load_Florence2_model
10
+ from .vlv_utils import initiate_time_steps, normalize
11
+
12
+
13
+ class SDConfig(PretrainedConfig):
14
+ """Configuration class for SDModel."""
15
+ model_type = "sd"
16
+
17
+ def __init__(self, **kwargs):
18
+ super().__init__(**kwargs)
19
+
20
+
21
+ class MLP(nn.Module):
22
+ def __init__(self, input_dim, output_dim):
23
+ super().__init__()
24
+ self.layers = nn.Sequential(
25
+ nn.Linear(input_dim, output_dim),
26
+ nn.GELU(),
27
+ nn.Linear(output_dim, output_dim),
28
+ )
29
+
30
+ def forward(self, x):
31
+ return self.layers(x)
32
+
33
+ @dataclass
34
+ class SDOutput(ModelOutput):
35
+ loss: Optional[torch.FloatTensor] = None
36
+
37
+ class SDModel(PreTrainedModel):
38
+ config_class = SDConfig
39
+
40
+ def __init__(
41
+ self,
42
+ config=None,
43
+ training_args = None,
44
+ ):
45
+ if config is None:
46
+ config = SDConfig()
47
+ super().__init__(config)
48
+ self.training_args = training_args
49
+ if self.training_args.fp32:
50
+ self._dtype = torch.float32
51
+ else:
52
+ self._dtype = torch.bfloat16
53
+ self._device = torch.device(self.training_args.device if hasattr(self.training_args, 'device') else "cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler = load_sd_model(training_args)
56
+ torch.cuda.empty_cache()
57
+ self.unet.eval()
58
+ self.text_encoder.eval()
59
+ self.model, self.processor = load_Florence2_model(training_args)
60
+
61
+ self.unet = self.unet.to(self._dtype).to(device=self._device)
62
+ self.text_encoder = self.text_encoder.to(self._dtype).to_empty(device=self._device)
63
+ self.model = self.model.to(self._dtype).to_empty(device=self._device)
64
+ self.vae = self.vae.to(torch.float32).to_empty(device=self._device)
65
+
66
+ self.batch_size = self.training_args.batch_size
67
+
68
+ hidden_dim = 1024
69
+ self.language_proj = nn.Sequential(
70
+ nn.Linear(1024, hidden_dim, dtype=self._dtype),
71
+ nn.GELU(),
72
+ nn.Linear(hidden_dim, 1024, dtype=self._dtype)
73
+ ).to_empty(device=self._device)
74
+ for param in self.language_proj.parameters():
75
+ param.requires_grad = True
76
+
77
+ self.num_queries = self.training_args.learnable_token_length
78
+ self.query_embed = nn.Parameter(torch.randn(1, self.num_queries, 1024, dtype=self._dtype))
79
+ self.query_embed.requires_grad = True
80
+
81
+ self.unet.enable_gradient_checkpointing()
82
+
83
+ def _unet_pred_noise(self, x_start, t, noise, context):
84
+ t = t.to(dtype=torch.long)
85
+
86
+ dtype = self.unet.dtype
87
+ x_start = x_start.to(dtype)
88
+ noise = noise.to(dtype)
89
+ context = context.to(dtype)
90
+
91
+ nt = t.shape[0]
92
+ noised_latent = self.scheduler.add_noise(x_start, noise, t)
93
+
94
+ pred_noise = self.unet(
95
+ noised_latent,
96
+ t,
97
+ encoder_hidden_states=context.expand(nt, -1, -1)
98
+ ).sample
99
+
100
+ return pred_noise
101
+
102
+ def generate_images(self, images):
103
+ batch_size = self.training_args.eval_batch_size
104
+ prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
105
+ inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)
106
+
107
+ if inputs["input_ids"] is not None:
108
+ inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
109
+ if inputs["pixel_values"] is not None:
110
+ image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
111
+ inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
112
+ if inputs_embeds is not None:
113
+ attention_mask = attention_mask.to(inputs_embeds.dtype)
114
+ encoder_outputs = self.model.language_model.model.encoder(
115
+ inputs_embeds=inputs_embeds,
116
+ attention_mask=attention_mask,
117
+ output_hidden_states=True,
118
+ return_dict=True
119
+ )
120
+
121
+ decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
122
+ decoder_attention_mask = torch.ones(
123
+ (batch_size, self.num_queries),
124
+ dtype=self._dtype,
125
+ device=self._device
126
+ )
127
+
128
+ encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
129
+ decoder_input_embeds = decoder_input_embeds.to(self._dtype)
130
+ attention_mask = attention_mask.to(self._dtype)
131
+
132
+ decoder_outputs = self.model.language_model.model.decoder(
133
+ inputs_embeds=decoder_input_embeds,
134
+ attention_mask=decoder_attention_mask,
135
+ encoder_hidden_states=encoder_hidden_states,
136
+ encoder_attention_mask=attention_mask,
137
+ output_hidden_states=True,
138
+ return_dict=True
139
+ )
140
+
141
+ last_decoder_hidden_state = decoder_outputs.last_hidden_state
142
+ conditional_context = self.language_proj(last_decoder_hidden_state)
143
+
144
+ un_token = self.tokenizer("", padding="max_length", truncation=True,max_length=77, return_tensors="pt").input_ids.to(self._device)
145
+ un_context_embeddings = self.text_encoder(un_token).last_hidden_state
146
+ un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)
147
+ if self.training_args.use_text_encoder:
148
+ context_embeddings = self.text_encoder(
149
+ inputs_embeds=conditional_context.to(self._dtype)
150
+ ).last_hidden_state
151
+
152
+ latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
153
+ latents = torch.randn(latent_shape, device=self._device, dtype=self._dtype)
154
+
155
+ scheduler = self.scheduler
156
+ scheduler.set_timesteps(self.training_args.num_inference_steps)
157
+ with torch.no_grad():
158
+ for t in scheduler.timesteps:
159
+ latent_model_input = torch.cat([latents, latents], dim=0)
160
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
161
+
162
+ combined_embeddings = torch.cat([un_context_embeddings, context_embeddings], dim=0).to(self._dtype)
163
+ noise_pred = self.unet(
164
+ latent_model_input, t, encoder_hidden_states=combined_embeddings
165
+ )[0]
166
+
167
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
168
+ noise_pred = noise_pred_uncond + self.training_args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
169
+
170
+ latents = scheduler.step(noise_pred, t, latents)[0]
171
+
172
+ scaled_latents = latents / 0.18215
173
+ with torch.no_grad():
174
+ decoded_latents = self.vae.decode(scaled_latents.to(torch.float32))[0]
175
+
176
+ return decoded_latents
177
+
178
+ def get_conditional_context(self, images, batch_size=None):
179
+ if batch_size is None:
180
+ batch_size = self.batch_size
181
+ prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
182
+ inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)
183
+
184
+ if inputs["input_ids"] is not None:
185
+ inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
186
+ if inputs["pixel_values"] is not None:
187
+ image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
188
+ inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
189
+ if inputs_embeds is not None:
190
+ attention_mask = attention_mask.to(inputs_embeds.dtype)
191
+ encoder_outputs = self.model.language_model.model.encoder(
192
+ inputs_embeds=inputs_embeds,
193
+ attention_mask=attention_mask,
194
+ output_hidden_states=True,
195
+ return_dict=True
196
+ )
197
+
198
+ decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
199
+ decoder_attention_mask = torch.ones(
200
+ (batch_size, self.num_queries),
201
+ dtype=self._dtype,
202
+ device=self._device
203
+ )
204
+
205
+ encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
206
+ decoder_input_embeds = decoder_input_embeds.to(self._dtype)
207
+ attention_mask = attention_mask.to(self._dtype)
208
+
209
+ decoder_outputs = self.model.language_model.model.decoder(
210
+ inputs_embeds=decoder_input_embeds,
211
+ attention_mask=decoder_attention_mask,
212
+ encoder_hidden_states=encoder_hidden_states,
213
+ encoder_attention_mask=attention_mask,
214
+ output_hidden_states=True,
215
+ return_dict=True
216
+ )
217
+
218
+ last_decoder_hidden_state = decoder_outputs.last_hidden_state
219
+ return last_decoder_hidden_state
220
+
221
+ def forward(
222
+ self,
223
+ image=None,
224
+ filename=None,
225
+ **kwargs,
226
+ ) -> SDOutput:
227
+ images_for_language_model = image
228
+ normalize_images = normalize(image, rescale=True)
229
+ x0=self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
230
+ latent = x0 * 0.18215
231
+
232
+ total_timestep = self.scheduler.num_train_timesteps
233
+
234
+ timesteps = initiate_time_steps(0, total_timestep, self.batch_size, self.training_args).long()
235
+ timesteps = timesteps.to(self._device)
236
+ c, h, w = latent.shape[1:]
237
+ if not self.training_args.use_same_noise_among_timesteps:
238
+ noise = torch.randn((self.batch_size, c, h, w), device=self._device, dtype=self._dtype)
239
+ else:
240
+ noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
241
+ noise = noise.repeat(self.batch_size, 1, 1, 1)
242
+
243
+ conditional_context = self.get_conditional_context(images_for_language_model)
244
+ conditional_context = self.language_proj(conditional_context)
245
+
246
+ if self.training_args.use_text_encoder:
247
+ text_encoder_output = self.text_encoder(input_ids=None, inputs_embeds=conditional_context.to(self._dtype))
248
+ pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=text_encoder_output.last_hidden_state.to(self._dtype)).to(self._dtype)
249
+ else:
250
+ pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=conditional_context.to(self._dtype)).to(self._dtype)
251
+
252
+ if self.training_args.loss == "l1":
253
+ loss = torch.nn.functional.l1_loss(pred_noise, noise)
254
+ else:
255
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
256
+
257
+ return SDOutput(loss=loss)
VLV_stage2.py ADDED
@@ -0,0 +1,460 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Dict, Any, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from transformers.utils import ModelOutput
8
+ from transformers.modeling_utils import PreTrainedModel
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
10
+ from safetensors.torch import load_file
11
+ import torchvision.transforms as transforms
12
+ from .build import load_sd_model, load_Florence2_model
13
+ from .vlv_utils import initiate_time_steps, normalize, process_caption
14
+ from .VLV_stage1 import SDModel, SDConfig
15
+ from .configuration_vlv import VLV_Config
16
+ import os
17
+ import sys
18
+ import argparse
19
+
20
+ def handle_module_prefix(state_dict):
21
+ """Handle 'module.' prefix in state dict keys."""
22
+ if any(k.startswith('module.') for k in state_dict.keys()):
23
+ return {k.replace('module.', ''): v for k, v in state_dict.items()}
24
+ return state_dict
25
+
26
+ def create_model_args(args):
27
+ """Create model arguments needed by SDModel."""
28
+ model_args = argparse.Namespace()
29
+ model_args.use_text_encoder = args.use_text_encoder
30
+ model_args.batch_size = args.batch_size
31
+ model_args.eval_batch_size = args.batch_size
32
+ model_args.distributed_strategy = 'none'
33
+ model_args.fp32 = args.fp32
34
+ model_args.learnable_token_length = args.learnable_token_length
35
+ model_args.num_inference_steps = args.num_inference_steps
36
+ model_args.image_size = args.image_size
37
+ model_args.guidance_scale = args.guidance_scale
38
+ model_args.unfreeze_florence2_all = False
39
+ model_args.unfreeze_florence2_language_model = False
40
+ model_args.unfreeze_florence2_language_model_decoder = False
41
+ return model_args
42
+
43
+ def load_model_checkpoint(model, model_path, device):
44
+ """Load model checkpoint."""
45
+ try:
46
+ checkpoint = torch.load(model_path, map_location="cpu")
47
+
48
+ # Handle different checkpoint formats
49
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
50
+ state_dict = checkpoint['model_state_dict']
51
+ elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52
+ state_dict = checkpoint['state_dict']
53
+ else:
54
+ state_dict = checkpoint
55
+
56
+ state_dict = handle_module_prefix(state_dict)
57
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
58
+
59
+ if missing_keys:
60
+ print(f"Missing keys: {missing_keys[:10]}...") # Show first 10
61
+ if unexpected_keys:
62
+ print(f"Unexpected keys: {unexpected_keys[:10]}...") # Show first 10
63
+
64
+ print(f"Successfully loaded model from {model_path}")
65
+ except Exception as e:
66
+ print(f"Error loading model: {e}")
67
+ raise e
68
+
69
+ return model
70
+
71
+ def initialize_diffusion_model(args):
72
+ """Initialize the diffusion model."""
73
+ config = SDConfig()
74
+ diffusion_model_args = create_model_args(args)
75
+ diffusion_model = SDModel(config, diffusion_model_args)
76
+ _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16
77
+
78
+ # Delete components that aren't needed for inference
79
+ if hasattr(diffusion_model, 'vae'):
80
+ del diffusion_model.vae
81
+ if hasattr(diffusion_model, 'unet'):
82
+ del diffusion_model.unet
83
+
84
+ # Clear CUDA cache
85
+ torch.cuda.empty_cache()
86
+
87
+ diffusion_model = diffusion_model.to(_dtype)
88
+
89
+ # Freeze parameters that shouldn't be trained
90
+ for param in diffusion_model.language_proj.parameters():
91
+ param.requires_grad = False
92
+ diffusion_model.query_embed.requires_grad = False
93
+
94
+ return diffusion_model
95
+
96
+ class MLP(nn.Module):
97
+ def __init__(self, input_dim, output_dim):
98
+ super(MLP, self).__init__()
99
+ self.layers = nn.Sequential(
100
+ nn.Linear(input_dim, output_dim),
101
+ nn.GELU(),
102
+ nn.Linear(output_dim, output_dim),
103
+ )
104
+
105
+ def forward(self, x):
106
+ return self.layers(x)
107
+
108
+
109
+ @dataclass
110
+ class CLIPDecoderOutput(ModelOutput):
111
+ """
112
+ Output class for the CLIP Decoder model.
113
+ """
114
+ last_hidden_state: Optional[torch.FloatTensor] = None
115
+ generated_ids: Optional[torch.LongTensor] = None
116
+ generated_text: Optional[list] = None
117
+
118
+
119
+ class CLIPDecoder(nn.Module):
120
+
121
+ def __init__(
122
+ self,
123
+ language_model: str,
124
+ VLV_model: SDModel,
125
+ device: torch.device,
126
+ bf16: str,
127
+ qwen2_config: dict = None,
128
+ args: argparse.Namespace = None
129
+ ):
130
+ """
131
+ Initialize the CLIP Decoder model.
132
+
133
+ Args:
134
+ language_model: Path to the language model
135
+ VLV_model: The VLV model instance
136
+ device: The device to run the model on
137
+ bf16: Whether to use bfloat16 precision
138
+ qwen2_config: Optional qwen2 configuration dict
139
+ """
140
+ super(CLIPDecoder, self).__init__()
141
+
142
+ self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
143
+ self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)
144
+
145
+ self.qwen2_config = AutoConfig.from_pretrained(language_model)
146
+ self.qwen2_model = AutoModelForCausalLM.from_pretrained(
147
+ language_model,
148
+ torch_dtype=self._dtype,
149
+ device_map=None,
150
+ low_cpu_mem_usage=True
151
+ )
152
+
153
+ self.VLV_model = VLV_model # fp32 in this case
154
+ self.device = device
155
+ self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
156
+ self.ignore_token_id = -100
157
+
158
+
159
+ def get_conditional_context(self, images, batch_size):
160
+ """
161
+ Get conditional context from images using the diffusion model.
162
+
163
+ Args:
164
+ images: Input images
165
+ batch_size: Batch size
166
+
167
+ Returns:
168
+ Decoder hidden states from the diffusion model
169
+ """
170
+ prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
171
+ inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(self._dtype)
172
+
173
+ # Ensure all components are on the correct device
174
+ self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
175
+ self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
176
+ self.mlp = self.mlp.to(inputs["input_ids"].device)
177
+ self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)
178
+
179
+ if inputs["input_ids"] is not None:
180
+ inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device)
181
+
182
+ if inputs["pixel_values"] is not None:
183
+ image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
184
+ inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
185
+ image_features, inputs_embeds
186
+ )
187
+
188
+ if inputs_embeds is not None:
189
+ attention_mask = attention_mask.to(inputs_embeds.dtype)
190
+
191
+ encoder_outputs = self.VLV_model.model.language_model.model.encoder(
192
+ inputs_embeds=inputs_embeds,
193
+ attention_mask=attention_mask,
194
+ output_hidden_states=True,
195
+ return_dict=True
196
+ )
197
+
198
+ decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
199
+ decoder_attention_mask = torch.ones(
200
+ (batch_size, self.VLV_model.num_queries),
201
+ dtype=self._dtype,
202
+ device=self.device
203
+ )
204
+
205
+ encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
206
+ decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
207
+ attention_mask = attention_mask.to(self._dtype)
208
+
209
+ decoder_outputs = self.VLV_model.model.language_model.model.decoder(
210
+ inputs_embeds=decoder_input_embeds,
211
+ attention_mask=decoder_attention_mask,
212
+ encoder_hidden_states=encoder_hidden_states,
213
+ encoder_attention_mask=attention_mask,
214
+ output_hidden_states=True,
215
+ return_dict=True
216
+ )
217
+
218
+ return decoder_outputs.last_hidden_state
219
+
220
+ def process_image(self, images, batch_size):
221
+ """
222
+ Process images to get clip text embeddings.
223
+
224
+ Args:
225
+ images: Input images
226
+ batch_size: Batch size
227
+
228
+ Returns:
229
+ Processed clip text embeddings and attention mask
230
+ """
231
+ decoder_hidden_states = self.get_conditional_context(images, batch_size)
232
+ context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
233
+ clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
234
+ clip_text_embeds = self.mlp(clip_text_embeds)
235
+ clip_text_embeds_attention_mask = torch.ones(
236
+ (batch_size, self.VLV_model.num_queries),
237
+ dtype=torch.long,
238
+ device=self.device
239
+ )
240
+
241
+ return clip_text_embeds, clip_text_embeds_attention_mask
242
+
243
+ def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
244
+ """
245
+ Prepare inputs for text generation.
246
+
247
+ Args:
248
+ clip_text_embeds: Processed clip text embeddings
249
+ clip_text_attention_mask: Attention mask for clip text embeddings
250
+
251
+ Returns:
252
+ Dictionary of generation inputs
253
+ """
254
+ if clip_text_attention_mask is None:
255
+ clip_text_attention_mask = torch.ones(
256
+ (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
257
+ dtype=torch.long,
258
+ device=clip_text_embeds.device
259
+ )
260
+
261
+ return {
262
+ "inputs_embeds": clip_text_embeds,
263
+ "attention_mask": clip_text_attention_mask
264
+ }
265
+
266
+ def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
267
+ """
268
+ Generate text from images.
269
+
270
+ Args:
271
+ images: Input images
272
+ max_new_tokens: Maximum number of tokens to generate
273
+ num_beams: Number of beams for beam search
274
+ early_stopping: Whether to stop early in beam search
275
+
276
+ Returns:
277
+ CLIPDecoderOutput with generated ids and text
278
+ """
279
+ batch_size = len(images)
280
+ clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
281
+ generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)
282
+
283
+ generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
284
+ generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)
285
+
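+ # Decode with Qwen conditioned only on the projected vision embeddings (no text prompt tokens).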
286
+ generated_ids = self.qwen2_model.generate(
287
+ inputs_embeds=generation_inputs["inputs_embeds"],
288
+ attention_mask=generation_inputs["attention_mask"],
289
+ max_new_tokens=max_new_tokens,
290
+ num_beams=num_beams,
291
+ early_stopping=early_stopping
292
+ )
293
+
294
+ generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
295
+ processed_generated_text = [process_caption(text) for text in generated_text]
296
+
297
+ return CLIPDecoderOutput(
298
+ generated_ids=generated_ids,
299
+ generated_text=processed_generated_text
300
+ )
301
+
302
+ def forward(self, images, captions=None):
303
+ """
304
+ Forward pass for training.
305
+
306
+ Args:
307
+ images: Input images
308
+ captions: Target captions (optional, for training)
309
+
310
+ Returns:
311
+ CLIPDecoderOutput with loss and logits
312
+ """
313
+ batch_size = images.shape[0]
314
+
315
+ # Process images
316
+ clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
317
+
318
+ # If no captions provided, return embeddings for generation
319
+ if captions is None:
320
+ return CLIPDecoderOutput(
321
+ last_hidden_state=clip_text_embeds
322
+ )
323
+
324
+ assert len(captions) == batch_size
325
+ # Process captions for training
326
+ processed_captions = [process_caption(caption) for caption in captions]
327
+ qwen_input_ids = self.qwen2_tokenizer(
328
+ text=processed_captions,
329
+ truncation=True,
330
+ return_tensors="pt",
331
+ padding="max_length",
332
+ max_length=300,
333
+ return_token_type_ids=False,
334
+ ).input_ids
335
+
336
+ assert len(captions) == batch_size
337
+ qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)
338
+
339
+ # Prepare labels for training
340
+ labels = qwen_input_ids
341
+ labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
342
+ labels = labels.to(self.device)
343
+
344
+ # Get embeddings for captions to create the full input sequence
345
+ labels_for_embeddings = labels.clone()
346
+ labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
347
+ clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)
348
+
349
+ # Concatenate the embeddings and prepare attention mask
350
+ inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
351
+ clip_seq_len = clip_text_embeds.shape[1]
352
+ clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
353
+ combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)
354
+
355
+ attention_mask = torch.cat((
356
+ clip_text_attention_mask,
357
+ qwen_attention_mask
358
+ ), dim=1)
359
+
360
+ # Forward through language model
361
+ outputs = self.qwen2_model(
362
+ inputs_embeds=inputs_embeds,
363
+ labels=combined_labels,
364
+ attention_mask=attention_mask,
365
+ use_cache=False
366
+ )
367
+ return outputs
368
+
369
+
370
+ # HuggingFace Model Wrapper
371
+ class VLV_MODEL(PreTrainedModel):
372
+ config_class = VLV_Config
373
+ model_type = "VLV_decoder"
374
+
375
+ def __init__(self, config):
376
+ super().__init__(config)
377
+ """Load the CLIPDecoder model."""
378
+ # Initialize the diffusion model first
379
+ device = "cuda"
380
+ de_diffusion_model = initialize_diffusion_model(config)
381
+ clip_decoder_model = CLIPDecoder(
382
+ language_model=config.qwen_model,
383
+ VLV_model=de_diffusion_model,
384
+ device=device,
385
+ bf16=config.mixed_precision,
386
+ qwen2_config=config.qwen2_config
387
+ )
388
+
389
+ # Load the trained weights
390
+ # clip_decoder_model = load_model_checkpoint(clip_decoder_model, config.clip_decoder_checkpoint, device)
391
+
392
+ # Set to evaluation mode
393
+ clip_decoder_model.eval()
394
+
395
+ # Store components directly as attributes to match checkpoint structure
396
+ self.VLV_model = clip_decoder_model.VLV_model
397
+ self.qwen2_model = clip_decoder_model.qwen2_model
398
+ self.mlp = clip_decoder_model.mlp
399
+
400
+ # Keep the full model for methods
401
+ self._clip_decoder_model = clip_decoder_model
402
+ self.max_new_tokens = config.max_length
403
+ self.num_beams = config.num_beams
404
+ self.transform = self.get_transform(config.image_size)
405
+
406
+ def get_transform(self, image_size):
407
+ """Transformation pipeline for input images."""
408
+ return transforms.Compose([
409
+ transforms.Resize(image_size),
410
+ transforms.CenterCrop((image_size, image_size)),
411
+ transforms.PILToTensor(),
412
+ ])
413
+
414
+ @classmethod
415
+ def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
416
+ """
417
+ Load model from original training checkpoint.
418
+
419
+ Args:
420
+ checkpoint_path: Path to the original model.pt checkpoint
421
+ config: Optional VLV_Config, will create default if None
422
+ **kwargs: Additional arguments for model initialization
423
+ """
424
+ if config is None:
425
+ # Create default config
426
+ config = VLV_Config(
427
+ image_size=384,
428
+ guidance_scale=7.5,
429
+ learnable_token_length=77,
430
+ max_length=300,
431
+ num_beams=4,
432
+ **kwargs
433
+ )
434
+
435
+ # Initialize model
436
+ model = cls(config)
437
+
438
+ # Load checkpoint weights
439
+ device = "cuda" if torch.cuda.is_available() else "cpu"
440
+ load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)
441
+
442
+ return model
443
+
444
+ def forward(self, valid_images, max_length):
445
+ valid_images = [self.transform(img) for img in valid_images]
446
+ if hasattr(self._clip_decoder_model, 'module'):
447
+ outputs = self._clip_decoder_model.module.generate(
448
+ valid_images,
449
+ max_new_tokens=max_length,
450
+ num_beams=self.num_beams,
451
+ early_stopping=True
452
+ )
453
+ else:
454
+ outputs = self._clip_decoder_model.generate(
455
+ valid_images,
456
+ max_new_tokens=max_length,
457
+ num_beams=self.num_beams,
458
+ early_stopping=True
459
+ )
460
+ return outputs
build.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
3
+ from transformers import CLIPTokenizer, AutoProcessor
4
+ from .modeling_clip import CustomCLIPTextModel
5
+ from .modeling_florence2 import Florence2ForConditionalGeneration
6
+ from .configuration_florence2 import Florence2Config
7
+
8
+
9
+ def load_sd_model(training_args):
10
+ """Load Stable Diffusion model"""
11
+
12
+ repo_id = "stabilityai/stable-diffusion-2-1-base"
13
+
14
+ text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
15
+ tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
16
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae",revision=None)
17
+ scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
18
+ unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet",revision=None)
19
+
20
+ for m in [vae, text_encoder, unet]:
21
+ for param in m.parameters():
22
+ param.requires_grad = False
23
+
24
+ return (vae, tokenizer, text_encoder, unet, scheduler)
25
+
26
+
27
+ def load_Florence2_model(training_args):
28
+ config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
29
+ config.vision_config.model_type = "davit"
30
+ config._attn_implementation = "eager"
31
+
32
+ # Load the model with pre-trained weights
33
+ model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large", config=config)
34
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
35
+
36
+ # freeze the model
37
+ if training_args.unfreeze_florence2_all:
38
+ for param in model.parameters():
39
+ param.requires_grad = True
40
+ elif training_args.unfreeze_florence2_language_model:
41
+ for param in model.parameters():
42
+ param.requires_grad = False
43
+ for param in model.language_model.parameters():
44
+ param.requires_grad = True
45
+ for param in model.language_model.lm_head.parameters():
46
+ param.requires_grad = False
47
+
48
+ model.language_model.lm_head.weight = torch.nn.Parameter(
49
+ model.language_model.lm_head.weight.detach().clone())
50
+
51
+ for p in model.language_model.lm_head.parameters():
52
+ p.requires_grad = False
53
+
54
+
55
+ elif training_args.unfreeze_florence2_language_model_decoder:
56
+ # Create a separate embedding layer for decoder
57
+ original_embeddings = model.language_model.model.shared
58
+ new_decoder_embeddings = torch.nn.Embedding(
59
+ num_embeddings=original_embeddings.num_embeddings,
60
+ embedding_dim=original_embeddings.embedding_dim,
61
+ padding_idx=original_embeddings.padding_idx
62
+ )
63
+ # Copy the weights
64
+ new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()
65
+
66
+ # Replace the decoder embeddings
67
+ model.language_model.model.encoder.embed_tokens = original_embeddings
68
+ model.language_model.model.decoder.embed_tokens = new_decoder_embeddings
69
+ for param in model.parameters():
70
+ param.requires_grad = False
71
+ for param in model.language_model.model.decoder.parameters():
72
+ param.requires_grad = True
73
+ model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
74
+ else:
75
+ for param in model.parameters():
76
+ param.requires_grad = False
77
+
78
+ return model, processor
config.json CHANGED
@@ -3,27 +3,31 @@
3
  "VLV_MODEL"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "De_DiffusionV2_stage2.VLV_Config",
7
- "AutoModel": "De_DiffusionV2_stage2.VLV_MODEL",
8
- "AutoModelForCausalLM": "De_DiffusionV2_stage2.VLV_MODEL"
9
  },
10
  "model_type": "VLV_decoder",
11
  "batch_size": 1,
12
  "deepspeed": true,
13
  "distributed": true,
14
  "fp32": true,
15
- "guidance_scale": 2.0,
16
  "hidden_size": 128,
17
- "image_size": 768,
18
  "learnable_token_length": 77,
19
  "local_rank": 0,
20
- "mixed_precision": "bf16",
21
  "num_inference_steps": 50,
22
- "torch_dtype": "bfloat16",
23
  "transformers_version": "4.51.1",
24
  "use_text_encoder": true,
25
  "verbose": true,
26
  "qwen_model": "Qwen/Qwen2.5-3B",
27
  "qwen2_config":{
28
  "architectures": [
29
  "Qwen2ForCausalLM"
@@ -45,11 +49,11 @@
45
  "rope_theta": 1000000.0,
46
  "sliding_window": 32768,
47
  "tie_word_embeddings": true,
48
- "torch_dtype": "bfloat16",
49
  "transformers_version": "4.40.1",
50
  "use_cache": true,
51
  "use_mrope": false,
52
  "use_sliding_window": false,
53
  "vocab_size": 151936
54
  }
55
- }
 
3
  "VLV_MODEL"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_vlv.VLV_Config",
7
+ "AutoModel": "VLV_stage2.VLV_MODEL",
8
+ "AutoModelForCausalLM": "VLV_stage2.VLV_MODEL"
9
  },
10
  "model_type": "VLV_decoder",
11
  "batch_size": 1,
12
  "deepspeed": true,
13
  "distributed": true,
14
  "fp32": true,
15
+ "guidance_scale": 2.5,
16
  "hidden_size": 128,
17
+ "image_size": 384,
18
  "learnable_token_length": 77,
19
  "local_rank": 0,
20
+ "mixed_precision": "fp32",
21
  "num_inference_steps": 50,
22
+ "torch_dtype": "float32",
23
  "transformers_version": "4.51.1",
24
  "use_text_encoder": true,
25
  "verbose": true,
26
  "qwen_model": "Qwen/Qwen2.5-3B",
27
+ "stable_diffusion_model_path": "stabilityai/stable-diffusion-2-1-base",
28
+ "florence2_model_path": "microsoft/Florence-2-large",
29
+ "max_length": 300,
30
+ "num_beams": 4,
31
  "qwen2_config":{
32
  "architectures": [
33
  "Qwen2ForCausalLM"
 
49
  "rope_theta": 1000000.0,
50
  "sliding_window": 32768,
51
  "tie_word_embeddings": true,
52
+ "torch_dtype": "float32",
53
  "transformers_version": "4.40.1",
54
  "use_cache": true,
55
  "use_mrope": false,
56
  "use_sliding_window": false,
57
  "vocab_size": 151936
58
  }
59
+ }
configuration_vlv.py ADDED
@@ -0,0 +1,172 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 VLV Team and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """VLV model configuration"""
16
+
17
+ from typing import Optional, Dict, Any
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class VLV_Config(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`VLV_MODEL`]. It is used to instantiate a VLV model
27
+ according to the specified arguments, defining the model architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+ Args:
33
+ model_type (`str`, *optional*, defaults to "VLV_decoder"):
34
+ The model type identifier.
35
+ batch_size (`int`, *optional*, defaults to 1):
36
+ The batch size for inference.
37
+ deepspeed (`bool`, *optional*, defaults to True):
38
+ Whether to use deepspeed.
39
+ distributed (`bool`, *optional*, defaults to True):
40
+ Whether to use distributed training.
41
+ fp32 (`bool`, *optional*, defaults to True):
42
+ Whether to use fp32 precision.
43
+ guidance_scale (`float`, *optional*, defaults to 2.0):
44
+ The guidance scale for generation.
45
+ hidden_size (`int`, *optional*, defaults to 128):
46
+ The hidden size of the model.
47
+ image_size (`int`, *optional*, defaults to 768):
48
+ The size of input images.
49
+ learnable_token_length (`int`, *optional*, defaults to 77):
50
+ The length of learnable tokens.
51
+ local_rank (`int`, *optional*, defaults to 0):
52
+ The local rank for distributed training.
53
+ mixed_precision (`str`, *optional*, defaults to "bf16"):
54
+ The mixed precision mode.
55
+ num_inference_steps (`int`, *optional*, defaults to 50):
56
+ The number of inference steps.
57
+ torch_dtype (`str`, *optional*, defaults to "bfloat16"):
58
+ The torch dtype.
59
+ use_text_encoder (`bool`, *optional*, defaults to True):
60
+ Whether to use text encoder.
61
+ verbose (`bool`, *optional*, defaults to True):
62
+ Whether to enable verbose mode.
63
+ qwen_model (`str`, *optional*, defaults to "Qwen/Qwen2.5-3B"):
64
+ The Qwen model to use.
65
+ qwen2_config (`dict`, *optional*):
66
+ The Qwen2 configuration.
67
+ max_length (`int`, *optional*, defaults to 300):
68
+ Maximum length for generation.
69
+ num_beams (`int`, *optional*, defaults to 4):
70
+ Number of beams for beam search.
71
+ """
72
+
73
+ model_type = "VLV_decoder"
74
+ keys_to_ignore_at_inference = ["past_key_values"]
75
+
76
+ def __init__(
77
+ self,
78
+ model_type: str = "VLV_decoder",
79
+ batch_size: int = 1,
80
+ deepspeed: bool = True,
81
+ distributed: bool = True,
82
+ fp32: bool = True,
83
+ guidance_scale: float = 2.0,
84
+ hidden_size: int = 128,
85
+ image_size: int = 768,
86
+ learnable_token_length: int = 77,
87
+ local_rank: int = 0,
88
+ mixed_precision: str = "bf16",
89
+ num_inference_steps: int = 50,
90
+ torch_dtype: str = "bfloat16",
91
+ transformers_version: str = "4.51.1",
92
+ use_text_encoder: bool = True,
93
+ verbose: bool = True,
94
+ qwen_model: str = "Qwen/Qwen2.5-3B",
95
+ stable_diffusion_model_path: str = "stabilityai/stable-diffusion-2-1-base",
96
+ florence2_model_path: str = "microsoft/Florence-2-large",
97
+ qwen2_config: Optional[Dict[str, Any]] = None,
98
+ max_length: int = 300,
99
+ num_beams: int = 4,
100
+ **kwargs,
101
+ ):
102
+ self.model_type = model_type
103
+ self.batch_size = batch_size
104
+ self.deepspeed = deepspeed
105
+ self.distributed = distributed
106
+ self.fp32 = fp32
107
+ self.guidance_scale = guidance_scale
108
+ self.hidden_size = hidden_size
109
+ self.image_size = image_size
110
+ self.learnable_token_length = learnable_token_length
111
+ self.local_rank = local_rank
112
+ self.mixed_precision = mixed_precision
113
+ self.num_inference_steps = num_inference_steps
114
+ self.torch_dtype = torch_dtype
115
+ self.transformers_version = transformers_version
116
+ self.use_text_encoder = use_text_encoder
117
+ self.verbose = verbose
118
+ self.qwen_model = qwen_model
119
+ self.stable_diffusion_model_path = stable_diffusion_model_path
120
+ self.florence2_model_path = florence2_model_path
121
+ self.qwen2_config = qwen2_config or self._get_default_qwen2_config()
122
+ self.max_length = max_length
123
+ self.num_beams = num_beams
124
+
125
+ super().__init__(**kwargs)
126
+
127
+ def _get_default_qwen2_config(self):
128
+ """Get default Qwen2 configuration."""
129
+ return {
130
+ "architectures": ["Qwen2ForCausalLM"],
131
+ "attention_dropout": 0.0,
132
+ "bos_token_id": 151643,
133
+ "eos_token_id": 151643,
134
+ "hidden_act": "silu",
135
+ "hidden_size": 2048,
136
+ "initializer_range": 0.02,
137
+ "intermediate_size": 11008,
138
+ "max_position_embeddings": 32768,
139
+ "max_window_layers": 36,
140
+ "model_type": "qwen2",
141
+ "num_attention_heads": 16,
142
+ "num_hidden_layers": 36,
143
+ "num_key_value_heads": 2,
144
+ "rms_norm_eps": 1e-06,
145
+ "rope_theta": 1000000.0,
146
+ "sliding_window": 32768,
147
+ "tie_word_embeddings": True,
148
+ "torch_dtype": "bfloat16",
149
+ "transformers_version": "4.40.1",
150
+ "use_cache": True,
151
+ "use_mrope": False,
152
+ "use_sliding_window": False,
153
+ "vocab_size": 151936
154
+ }
155
+
156
+
157
+ class CLIPDecoderConfig(PretrainedConfig):
158
+ r"""
159
+ Configuration class for CLIPDecoder model (legacy support).
160
+ """
161
+
162
+ model_type = "vlv_stage2"
163
+
164
+ def __init__(
165
+ self,
166
+ input_dim: int = 1024,
167
+ bf16: bool = False,
168
+ **kwargs,
169
+ ):
170
+ self.input_dim = input_dim
171
+ self.bf16 = bf16
172
+ super().__init__(**kwargs)
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7460963b2ea4c7cde35d0c64c8d46d4a9324c7574433f8cf9878bbaf687f61b
3
+ size 622330008
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dca6a859202a8817026897383409ec85fb0a22d4b6527da6ab5f5e2ccd3745be
3
+ size 832409864
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30d67c2d202ae6c4166ba0b82310f19225665305e6fb3b22c66ff5318fbf6f50
3
+ size 210079920
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c015745c4638633cfb7d09e9b2b96bfa15fd21511fd74642d13296afc9423a4f
3
+ size 5215310704
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da8bab2f53dbd82612d2034d6e67724a171a44fc040198ca5fe9d6120cc3409e
3
+ size 5046894020
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_clip.py CHANGED
@@ -1,5 +1,5 @@
1
  from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPPreTrainedModel, CLIPTextConfig
2
- from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling
3
  from typing import Optional, Union, Tuple
4
  import torch
5
  from torch import nn
@@ -53,7 +53,8 @@ class CustomCLIPTextTransformer(nn.Module):
53
 
54
 
55
  if inputs_embeds is not None:
56
- inputs_embeds = self.embeddings(inputs_embeds=inputs_embeds)
 
57
  else:
58
  inputs_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)
59
 
@@ -134,9 +135,49 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
134
  output_hidden_states: Optional[bool] = None,
135
  return_dict: Optional[bool] = None,
136
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
 
138
 
139
- return self.text_model(
140
  input_ids=input_ids,
141
  attention_mask=attention_mask,
142
  position_ids=position_ids,
@@ -145,3 +186,19 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
145
  output_hidden_states=output_hidden_states,
146
  return_dict=return_dict,
147
  )
1
  from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPPreTrainedModel, CLIPTextConfig
2
+ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling, CLIPTextModelOutput
3
  from typing import Optional, Union, Tuple
4
  import torch
5
  from torch import nn
 
53
 
54
 
55
  if inputs_embeds is not None:
56
+ # inputs_embeds are already embeddings, just add positional embeddings
57
+ inputs_embeds = self.embeddings.position_embedding(self.embeddings.position_ids[:, :inputs_embeds.size(1)]) + inputs_embeds
58
  else:
59
  inputs_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)
60
 
 
135
  output_hidden_states: Optional[bool] = None,
136
  return_dict: Optional[bool] = None,
137
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
138
+ return self.text_model(
139
+ input_ids=input_ids,
140
+ attention_mask=attention_mask,
141
+ position_ids=position_ids,
142
+ inputs_embeds=inputs_embeds,
143
+ output_attentions=output_attentions,
144
+ output_hidden_states=output_hidden_states,
145
+ return_dict=return_dict,
146
+ )
147
 
148
 
149
+ class CustomCLIPTextModelWithProjection(CLIPPreTrainedModel):
150
+ config_class = CLIPTextConfig
151
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
152
+
153
+ def __init__(self, config: CLIPTextConfig):
154
+ super().__init__(config)
155
+ self.text_model = CustomCLIPTextTransformer(config)
156
+
157
+ # Add the projection layer for SDXL's second text encoder
158
+ projection_dim = getattr(config, 'projection_dim', config.hidden_size)
159
+ self.text_projection = nn.Linear(config.hidden_size, projection_dim, bias=False)
160
+
161
+ # Initialize weights and apply final processing
162
+ self.post_init()
163
+
164
+ def get_input_embeddings(self) -> nn.Module:
165
+ return self.text_model.embeddings.token_embedding
166
+
167
+ def set_input_embeddings(self, value):
168
+ self.text_model.embeddings.token_embedding = value
169
+
170
+ def forward(
171
+ self,
172
+ input_ids: Optional[torch.Tensor] = None,
173
+ attention_mask: Optional[torch.Tensor] = None,
174
+ position_ids: Optional[torch.Tensor] = None,
175
+ inputs_embeds: Optional[torch.FloatTensor] = None,
176
+ output_attentions: Optional[bool] = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ return_dict: Optional[bool] = None,
179
+ ) -> Union[Tuple, CLIPTextModelOutput]:
180
+ text_outputs = self.text_model(
181
  input_ids=input_ids,
182
  attention_mask=attention_mask,
183
  position_ids=position_ids,
 
186
  output_hidden_states=output_hidden_states,
187
  return_dict=return_dict,
188
  )
189
+
190
+ pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
191
+
192
+ # Apply the projection to the pooled output
193
+ text_embeds = self.text_projection(pooled_output)
194
+
195
+ if not return_dict:
196
+ # Include both last_hidden_state, pooler_output, text_embeds, and other outputs
197
+ return (text_outputs[0], text_outputs[1], text_embeds) + text_outputs[2:]
198
+
199
+ return CLIPTextModelOutput(
200
+ text_embeds=text_embeds, # Projected embeddings (for similarity)
201
+ last_hidden_state=text_outputs.last_hidden_state, # All token representations
202
+ hidden_states=text_outputs.hidden_states,
203
+ attentions=text_outputs.attentions,
204
+ )
vlv_utils.py ADDED
@@ -0,0 +1,71 @@
1
+ """Utility functions"""
2
+ import importlib
3
+ import random
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+
10
+ def normalize(image,rescale=True):
11
+
12
+ if rescale:
13
+ image = image.float() / 255.0 # Convert to float and rescale to [0, 1]
14
+ normalize_image = 2*image-1 # normalize to [-1, 1]
15
+
16
+ return normalize_image
17
+
18
+
19
+
20
+ def process_caption(caption):
21
+ """Process a caption to ensure proper formatting and remove duplicates.
22
+
23
+ Args:
24
+ caption: A string containing the caption text
25
+
26
+ Returns:
27
+ processed_caption: A string with processed caption
28
+ """
29
+ if not caption.endswith('.'):
30
+ last_period_index = caption.rfind('.')
31
+ if last_period_index != -1:
32
+ caption = caption[:last_period_index + 1]
33
+
34
+ sentences = re.split(r'(?<=[.!?])\s+', caption)
35
+
36
+ unique_sentences = []
37
+ for sentence in sentences:
38
+ if sentence and sentence not in unique_sentences:
39
+ unique_sentences.append(sentence)
40
+
41
+ processed_caption = ' '.join(unique_sentences)
42
+
43
+ return processed_caption
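+ # Illustrative example:
+ #   process_caption("A cat sits. A cat sits. It is") -> "A cat sits."
+ #   (trailing fragment dropped, duplicate sentence removed)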
44
+
45
+
46
+ def initiate_time_steps(step, total_timestep, batch_size, config):
47
+ """A helper function to initiate time steps for the diffusion model.
48
+
49
+ Args:
50
+ step: An integer of the constant step
51
+ total_timestep: An integer of the total timesteps of the diffusion model
52
+ batch_size: An integer of the batch size
53
+ config: A config object
54
+
55
+ Returns:
56
+ timesteps: A tensor of shape [batch_size,] of the time steps
57
+ """
58
+ if config.rand_timestep_equal_int:
59
+ # the same timestep for each image in the batch
60
+ interval_val = total_timestep // batch_size
61
+ start_point = random.randint(0, interval_val - 1)
62
+ timesteps = torch.tensor(
63
+ list(range(start_point, total_timestep, interval_val))
64
+ ).long()
65
+ return timesteps
66
+ elif config.random_timestep_per_iteration:
67
+ # random timestep for each image in the batch
68
+ return torch.randint(0, total_timestep, (batch_size,)).long() #default
69
+ else:
70
+ # why we need to do this?
71
+ return torch.tensor([step] * batch_size).long()