import os
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from safetensors.torch import load_file
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from .build_unfreeze import load_sd_model, load_Florence2_model
from .utils import initiate_time_steps, normalize


# Scale applied to VAE latents: multiplied in after encoding in ``forward`` and
# divided out before decoding in ``generate_images``.
_VAE_LATENT_SCALE = 0.18215


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Linear."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class SDOutput(ModelOutput):
    """Output of ``SDModel.forward``; carries only the diffusion loss."""

    loss: Optional[torch.FloatTensor] = None


def _move_module(module: nn.Module, device: torch.device, dtype: torch.dtype) -> nn.Module:
    """Return ``module`` cast to ``dtype`` and moved to ``device`` with its weights intact.

    ``nn.Module.to_empty`` is deliberately not used: it allocates new,
    uninitialized storage for every parameter and buffer and does not copy
    the existing values, which would discard loaded weights.

    Raises:
        RuntimeError: if ``module`` still holds tensors on the ``meta`` device,
            i.e. weights that were never materialized and cannot be copied.
    """
    meta_names = [name for name, p in module.named_parameters() if p.is_meta]
    meta_names += [name for name, b in module.named_buffers() if b.is_meta]
    if meta_names:
        raise RuntimeError(
            f"{type(module).__name__} has tensors on the meta device "
            f"(e.g. {meta_names[:5]}); materialize/load its weights before moving it."
        )
    return module.to(device=device, dtype=dtype)


class SDModel(PreTrainedModel):
    """Stable Diffusion UNet conditioned on Florence-2 features via learnable queries.

    A fixed set of learnable query embeddings is run through the Florence-2
    language decoder, cross-attending to the Florence-2 encoder output for
    the input images. The decoder output is projected by ``language_proj`` and,
    when ``training_args.use_text_encoder`` is set, passed through the SD text
    encoder as ``inputs_embeds``. The result is the UNet cross-attention
    context. ``forward`` returns the noise-prediction loss; ``generate_images``
    samples images with classifier-free guidance.
    """

    def __init__(
        self,
        config=None,
        training_args=None,
    ):
        """Load the SD and Florence-2 components and create the trainable adapters.

        Args:
            config: model config. ``config.device`` is used when present;
                ``forward`` also reads ``config.tta``.
            training_args: run arguments. Reads ``fp32``, ``batch_size`` and
                ``learnable_token_length`` here, and is passed to the loaders.
        """
        super().__init__(config)  # also sets ``self.config``
        self.training_args = training_args
        self._dtype = torch.float32 if self.training_args.fp32 else torch.bfloat16

        # Named ``_device`` to avoid clashing with nn.Module's own ``device``.
        if hasattr(config, "device"):
            device = config.device
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        self._device = torch.device(device)

        (
            self.vae,
            self.tokenizer,
            self.text_encoder,
            self.unet,
            self.scheduler,
        ) = load_sd_model(training_args)
        torch.cuda.empty_cache()
        self.unet.eval()
        self.text_encoder.eval()
        self.model, self.processor = load_Florence2_model(training_args)

        # Move models with ``.to`` (not ``.to_empty``) so pretrained weights are kept.
        self.unet = _move_module(self.unet, self._device, self._dtype)
        self.text_encoder = _move_module(self.text_encoder, self._device, self._dtype)
        self.model = _move_module(self.model, self._device, self._dtype)
        # The VAE always runs in float32 regardless of ``training_args.fp32``.
        self.vae = _move_module(self.vae, self._device, torch.float32)

        self.batch_size = self.training_args.batch_size

        # Trainable projection from the Florence-2 decoder width (1024) into
        # the context width (1024). Built then moved with ``.to`` so the default
        # Linear initialization survives.
        hidden_dim = 1024
        self.language_proj = nn.Sequential(
            nn.Linear(1024, hidden_dim, dtype=self._dtype),
            nn.GELU(),
            nn.Linear(hidden_dim, 1024, dtype=self._dtype),
        ).to(self._device)
        self.language_proj.requires_grad_(True)

        # Learnable decoder queries, shape (1, num_queries, 1024).
        self.num_queries = self.training_args.learnable_token_length
        self.query_embed = nn.Parameter(
            torch.randn(1, self.num_queries, 1024, dtype=self._dtype).to(self._device)
        )

        self.unet.enable_gradient_checkpointing()

    def _unet_pred_noise(self, x_start, t, noise, context):
        """Noise ``x_start`` at timesteps ``t`` and return the UNet's noise prediction.

        All float inputs are cast to the UNet's dtype. ``context`` is expanded
        along dim 0 to the number of timesteps, so it must have batch size 1 or
        ``len(t)``.
        """
        t = t.to(dtype=torch.long)
        unet_dtype = self.unet.dtype
        x_start = x_start.to(unet_dtype)
        noise = noise.to(unet_dtype)
        context = context.to(unet_dtype)

        num_steps = t.shape[0]
        noised_latent = self.scheduler.add_noise(x_start, noise, t)
        return self.unet(
            noised_latent,
            t,
            encoder_hidden_states=context.expand(num_steps, -1, -1),
        ).sample

    @torch.no_grad()
    def generate_images(self, images):
        """Sample images conditioned on ``images`` with classifier-free guidance.

        Uses ``training_args.eval_batch_size``, ``image_size``,
        ``num_inference_steps``, ``guidance_scale`` and ``use_text_encoder``.
        The scheduler's inference timesteps are (re)set on ``self.scheduler``.
        Runs entirely without gradient tracking.

        Returns:
            The first element of ``self.vae.decode`` applied to the final
            latents (decoded in float32).
        """
        batch_size = self.training_args.eval_batch_size

        decoder_hidden = self.get_conditional_context(images, batch_size=batch_size)
        conditional_context = self.language_proj(decoder_hidden)

        # Unconditional branch: text-encoder output for the empty prompt.
        un_token = self.tokenizer(
            "", padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        ).input_ids.to(self._device)
        un_context_embeddings = self.text_encoder(un_token).last_hidden_state
        un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)

        if self.training_args.use_text_encoder:
            context_embeddings = self.text_encoder(
                inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state
        else:
            # Same fallback as ``forward``: use the projected context directly.
            context_embeddings = conditional_context

        # NOTE: concatenating along the batch dim requires the conditional
        # context to have the same sequence length (77) as the unconditional one.
        combined_embeddings = torch.cat(
            [un_context_embeddings, context_embeddings], dim=0
        ).to(self._dtype)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.training_args.num_inference_steps)

        latent_side = self.training_args.image_size // 8
        latents = torch.randn(
            (batch_size, 4, latent_side, latent_side), device=self._device, dtype=self._dtype
        )
        # Scale initial noise as the scheduler expects (1.0 for DDPM/DDIM-style).
        latents = latents * getattr(scheduler, "init_noise_sigma", 1.0)

        guidance_scale = self.training_args.guidance_scale
        for t in scheduler.timesteps:
            latent_model_input = scheduler.scale_model_input(
                torch.cat([latents, latents], dim=0), t
            )
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=combined_embeddings
            )[0]
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents)[0]

        scaled_latents = latents / _VAE_LATENT_SCALE
        return self.vae.decode(scaled_latents.to(torch.float32))[0]

    def get_conditional_context(self, images, batch_size=None):
        """Decode the learnable queries against the Florence-2 encoding of ``images``.

        An empty text prompt per image is processed together with the images.
        The token embeddings are merged with the image features and run through
        the Florence-2 encoder. ``query_embed`` (expanded to ``batch_size``) is
        then decoded with cross-attention to that encoder output.

        Args:
            images: images accepted by the Florence-2 processor.
            batch_size: number of prompts/queries; defaults to ``self.batch_size``.

        Returns:
            The decoder's ``last_hidden_state``, presumably of shape
            (batch_size, num_queries, 1024).

        Raises:
            ValueError: if the processor output lacks ``input_ids`` or ``pixel_values``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        prompt = [""] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt")
        inputs = inputs.to(self._device).to(self._dtype)

        input_ids = inputs.get("input_ids")
        pixel_values = inputs.get("pixel_values")
        if input_ids is None or pixel_values is None:
            # Both are required to build the merged encoder input below.
            raise ValueError("Florence-2 processor must return both 'input_ids' and 'pixel_values'.")

        inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids).to(self._dtype)
        image_features = self.model._encode_image(pixel_values).to(self._dtype)
        inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(
            image_features, inputs_embeds
        )
        attention_mask = attention_mask.to(inputs_embeds.dtype)

        encoder_outputs = self.model.language_model.model.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1).to(self._dtype)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries), dtype=self._dtype, device=self._device
        )
        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs.last_hidden_state.to(self._dtype),
            encoder_attention_mask=attention_mask.to(self._dtype),
            output_hidden_states=True,
            return_dict=True,
        )
        return decoder_outputs.last_hidden_state

    def forward(
        self,
        image=None,
        filename=None,
        **kwargs,
    ) -> SDOutput:
        """Compute the diffusion noise-prediction loss for a batch of images.

        ``image`` is normalized and VAE-encoded to latents. Those latents are
        noised at timesteps from ``initiate_time_steps``, and the UNet predicts
        the noise, conditioned on the Florence-2 derived context. ``filename`` and
        any extra keyword arguments are accepted and ignored.

        Returns:
            SDOutput whose ``loss`` is the L1 (``config.tta.loss == "l1"``) or
            MSE loss between predicted and sampled noise, computed in float32.
        """
        images_for_language_model = image
        normalize_images = normalize(image, rescale=True)
        x0 = self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
        latent = x0 * _VAE_LATENT_SCALE

        # Use the real batch size so a short final batch still works; equals
        # ``self.batch_size`` for full batches.
        batch_size = latent.shape[0]

        total_timestep = self.scheduler.num_train_timesteps
        timesteps = initiate_time_steps(0, total_timestep, batch_size, self.config).long()
        timesteps = timesteps.to(self._device)

        c, h, w = latent.shape[1:]
        if self.config.tta.use_same_noise_among_timesteps:
            noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
            noise = noise.repeat(batch_size, 1, 1, 1)
        else:
            noise = torch.randn((batch_size, c, h, w), device=self._device, dtype=self._dtype)

        conditional_context = self.get_conditional_context(
            images_for_language_model, batch_size=batch_size
        )
        conditional_context = self.language_proj(conditional_context)

        if self.training_args.use_text_encoder:
            context = self.text_encoder(
                input_ids=None, inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state
        else:
            context = conditional_context

        pred_noise = self._unet_pred_noise(
            x_start=latent, t=timesteps, noise=noise, context=context.to(self._dtype)
        ).to(self._dtype)

        # Reduce in float32 to avoid bf16 precision loss in the loss value.
        if self.config.tta.loss == "l1":
            loss = F.l1_loss(pred_noise.float(), noise.float())
        else:
            loss = F.mse_loss(pred_noise.float(), noise.float())
        return SDOutput(loss=loss)