lambertxiao committed on
Commit
10abe26
·
verified ·
1 Parent(s): 17c91d1

Update De_DiffusionV2_stage2.py

Browse files
Files changed (1) hide show
  1. De_DiffusionV2_stage2.py +12 -1
De_DiffusionV2_stage2.py CHANGED
@@ -285,7 +285,18 @@ class CLIPDecoder(nn.Module):
285
  """
286
  decoder_hidden_states = self.get_conditional_context(images, batch_size)
287
  context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
288
- clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
289
  # clip_text_embeds = clip_text_embeds.to(self._dtype)
290
  clip_text_embeds = self.mlp(clip_text_embeds)
291
  clip_text_embeds_attention_mask = torch.ones(
 
285
  """
286
  decoder_hidden_states = self.get_conditional_context(images, batch_size)
287
  context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
288
+
289
+ # Create attention mask for context_embeds
290
+ context_attention_mask = torch.ones(
291
+ (batch_size, context_embeds.shape[1]),
292
+ dtype=torch.long,
293
+ device=self.device
294
+ )
295
+
296
+ clip_text_embeds = self.VLV_model.text_encoder(
297
+ inputs_embeds=context_embeds,
298
+ attention_mask=context_attention_mask
299
+ ).last_hidden_state
300
  # clip_text_embeds = clip_text_embeds.to(self._dtype)
301
  clip_text_embeds = self.mlp(clip_text_embeds)
302
  clip_text_embeds_attention_mask = torch.ones(