lambertxiao committed
Commit 17c91d1 · verified · 1 Parent(s): 281e410

Update De_DiffusionV2_Image.py

Files changed (1):
  De_DiffusionV2_Image.py  (+19 -2)
De_DiffusionV2_Image.py CHANGED

@@ -145,8 +145,15 @@ class SDModel(PreTrainedModel):
         un_context_embeddings = self.text_encoder(un_token).last_hidden_state
         un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)
         if self.training_args.use_text_encoder:
+            # Create attention mask for conditional_context
+            context_attention_mask = torch.ones(
+                (batch_size, conditional_context.shape[1]),
+                dtype=torch.long,
+                device=self._device
+            )
             context_embeddings = self.text_encoder(
-                inputs_embeds=conditional_context.to(self._dtype)
+                inputs_embeds=conditional_context.to(self._dtype),
+                attention_mask=context_attention_mask
             ).last_hidden_state  # 1, 77, 1024

         latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
@@ -250,7 +257,17 @@ class SDModel(PreTrainedModel):
         conditional_context = self.language_proj(conditional_context)  # [b, 159, 1024]

         if self.training_args.use_text_encoder:
-            text_encoder_output = self.text_encoder(input_ids=None, inputs_embeds=conditional_context.to(self._dtype))
+            # Create attention mask for conditional_context
+            context_attention_mask = torch.ones(
+                (self.batch_size, conditional_context.shape[1]),
+                dtype=torch.long,
+                device=self._device
+            )
+            text_encoder_output = self.text_encoder(
+                input_ids=None,
+                inputs_embeds=conditional_context.to(self._dtype),
+                attention_mask=context_attention_mask
+            )
             pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=text_encoder_output.last_hidden_state.to(self._dtype)).to(self._dtype)
         else:
             pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=conditional_context.to(self._dtype)).to(self._dtype)
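Both hunks apply the same fix: when the projected language context is passed to the text encoder as inputs_embeds, an explicit all-ones attention mask now goes with it. Below is a minimal, self-contained sketch of that pattern. It uses a randomly initialized BertModel purely as a stand-in, chosen only because its forward() accepts both inputs_embeds and attention_mask; the repo's self.text_encoder is its own module, and every shape here is illustrative rather than taken from the repo.

    import torch
    from transformers import BertConfig, BertModel

    # Stand-in encoder (assumption: NOT the repo's text_encoder); chosen
    # because BertModel.forward() takes inputs_embeds and attention_mask.
    config = BertConfig(hidden_size=1024, num_attention_heads=8,
                        num_hidden_layers=2, intermediate_size=2048)
    encoder = BertModel(config)

    batch_size, seq_len = 2, 159  # illustrative; the diff notes [b, 159, 1024]
    conditional_context = torch.randn(batch_size, seq_len, config.hidden_size)

    # Every position of the projected context is real content, not padding,
    # so the mask is all ones, the same tensor the commit constructs.
    context_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    out = encoder(inputs_embeds=conditional_context,
                  attention_mask=context_attention_mask)
    print(out.last_hidden_state.shape)  # torch.Size([2, 159, 1024])

For a BERT-style encoder, omitting the mask already defaults to attending everywhere, so the explicit all-ones mask mainly documents intent and keeps behavior stable if padded batches are ever introduced.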
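Downstream of both hunks, the encoder's last_hidden_state (or the raw projected context) is handed to self._unet_pred_noise as context. The diff never shows that method's internals, so the following is only a hypothetical illustration of how a Stable Diffusion-style UNet typically consumes such a context: flattened image latents act as queries and the text sequence supplies keys and values in cross-attention. nn.MultiheadAttention stands in for the UNet's cross-attention blocks, and all names and dimensions below are assumed.

    import torch
    import torch.nn as nn

    # Hypothetical sketch: _unet_pred_noise's internals are not in the diff.
    b, seq_len, ctx_dim = 2, 159, 1024   # context shape per the diff's comment
    latent_tokens, model_dim = 64, 320   # assumed latent grid / channel dims

    cross_attn = nn.MultiheadAttention(
        embed_dim=model_dim, num_heads=8,
        kdim=ctx_dim, vdim=ctx_dim, batch_first=True,
    )
    latents = torch.randn(b, latent_tokens, model_dim)  # flattened image latents
    context = torch.randn(b, seq_len, ctx_dim)          # encoder last_hidden_state

    # Latents attend over the text sequence; output keeps the latent shape.
    out, _ = cross_attn(query=latents, key=context, value=context)
    print(out.shape)  # torch.Size([2, 64, 320])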