Update De_DiffusionV2_Image.py
De_DiffusionV2_Image.py  +19 -2  CHANGED
@@ -145,8 +145,15 @@ class SDModel(PreTrainedModel):
         un_context_embeddings = self.text_encoder(un_token).last_hidden_state
         un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)
         if self.training_args.use_text_encoder:
+            # Create attention mask for conditional_context
+            context_attention_mask = torch.ones(
+                (batch_size, conditional_context.shape[1]),
+                dtype=torch.long,
+                device=self._device
+            )
             context_embeddings = self.text_encoder(
-                inputs_embeds=conditional_context.to(self._dtype)
+                inputs_embeds=conditional_context.to(self._dtype),
+                attention_mask=context_attention_mask
             ).last_hidden_state  # 1, 77, 1024
 
         latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
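Both hunks follow the same pattern: instead of passing only inputs_embeds to the text encoder, the commit builds an explicit all-ones attention mask over the projected context and passes it alongside the embeddings. The sketch below is a minimal, self-contained illustration of that pattern, not the repository's code; the diff does not show which class self.text_encoder is, so T5EncoderModel (which accepts inputs_embeds and attention_mask) stands in for it, and the batch/sequence/hidden sizes are illustrative.

# Minimal sketch of the attention-mask + inputs_embeds pattern added by this commit.
# T5EncoderModel is only a stand-in assumption for self.text_encoder; sizes are illustrative.
import torch
from transformers import T5EncoderModel

encoder = T5EncoderModel.from_pretrained("t5-small")      # stand-in text encoder
batch_size, ctx_len, hidden = 2, 159, encoder.config.d_model

# Projected context embeddings (in the model these come from language_proj).
conditional_context = torch.randn(batch_size, ctx_len, hidden)

# All-ones mask: every position of the projected context is attended to.
context_attention_mask = torch.ones(
    (batch_size, conditional_context.shape[1]),
    dtype=torch.long,
    device=conditional_context.device,
)

context_embeddings = encoder(
    inputs_embeds=conditional_context,
    attention_mask=context_attention_mask,
).last_hidden_state                                        # [batch, ctx_len, hidden]
print(context_embeddings.shape)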
@@ -250,7 +257,17 @@ class SDModel(PreTrainedModel):
         conditional_context = self.language_proj(conditional_context)  # [b, 159, 1024]
 
         if self.training_args.use_text_encoder:
-
+            # Create attention mask for conditional_context
+            context_attention_mask = torch.ones(
+                (self.batch_size, conditional_context.shape[1]),
+                dtype=torch.long,
+                device=self._device
+            )
+            text_encoder_output = self.text_encoder(
+                input_ids=None,
+                inputs_embeds=conditional_context.to(self._dtype),
+                attention_mask=context_attention_mask
+            )
             pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=text_encoder_output.last_hidden_state.to(self._dtype)).to(self._dtype)
         else:
             pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=conditional_context.to(self._dtype)).to(self._dtype)
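The second hunk then feeds text_encoder_output.last_hidden_state as the cross-attention context into self._unet_pred_noise. That helper is not part of this diff; the sketch below only illustrates, under the assumption of a diffusers-style UNet2DConditionModel and DDPMScheduler, what such a noise-prediction step typically looks like (noise the clean latent at timestep t, then predict that noise conditioned on the encoder states). The argument names mirror the call site in the diff; everything else is an assumption.

# Minimal sketch (not the repo's implementation) of a helper like _unet_pred_noise,
# assuming a diffusers UNet2DConditionModel and DDPMScheduler.
import torch
from diffusers import DDPMScheduler, UNet2DConditionModel


def unet_pred_noise_sketch(
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    x_start: torch.Tensor,   # clean latents, [b, 4, image_size//8, image_size//8]
    t: torch.Tensor,         # per-sample timesteps, [b]
    noise: torch.Tensor,     # gaussian noise, same shape as x_start
    context: torch.Tensor,   # text-encoder last_hidden_state, [b, seq_len, dim]
) -> torch.Tensor:
    # Forward diffusion: mix the clean latent with noise at timestep t.
    noisy_latents = scheduler.add_noise(x_start, noise, t)
    # Predict that noise, conditioning the UNet's cross-attention on the encoder states.
    return unet(noisy_latents, t, encoder_hidden_states=context).sample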