lambertxiao's picture
Upload folder using huggingface_hub
746c807 verified
raw
history blame
5.01 kB
"""Utility functions"""
import importlib
import random
import torch
import numpy as np
from PIL import Image
class UnNormalize(object):
    """Undo a mean/std normalization: ``image = (image * std) + mean``."""

    def __init__(self, mean, std):
        # Statistics are stored as tensors; they are moved to the input's
        # device on every call.
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, tensor):
        """Apply ``tensor * std + mean``.

        Args:
            tensor: A tensor of shape [C, H, W] or [N, C, H, W]

        Returns:
            tensor: A tensor of shape [C, H, W] or [N, C, H, W]
        """
        device = tensor.device
        scale = self.std.to(device)
        shift = self.mean.to(device)

        # Pick the view that lines the statistics up with the channel axis.
        # For any other rank the statistics are used as stored and plain
        # broadcasting applies.
        if tensor.ndim == 3:
            stat_shape = (-1, 1, 1)
        elif tensor.ndim == 4:
            stat_shape = (1, -1, 1, 1)
        else:
            stat_shape = None

        if stat_shape is not None:
            scale = scale.view(*stat_shape)
            shift = shift.view(*stat_shape)

        return (tensor * scale) + shift
class VQVAEUnNormalize(UnNormalize):
    """Two-step un-normalization.

    First:  image = (image * std) + mean   (done by ``UnNormalize``)
    Second: image = (image * 2) - 1
    """

    def __call__(self, tensor):
        """Un-normalize ``tensor`` and then apply ``2 * x - 1``.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W)
                to be unnormalized.

        Returns:
            Tensor: UnNormalized image.
        """
        restored = super().__call__(tensor)
        return 2 * restored - 1
def normalize(image, rescale=True):
    """Map an image to ``2 * image - 1``, optionally dividing by 255 first.

    Args:
        image: Input image. When ``rescale`` is true it must provide
            ``.float()`` (e.g. a torch tensor).
        rescale: If true, convert to float and divide by 255.0 before the
            affine map; if false, ``image`` is used as given.

    Returns:
        ``2 * image - 1`` computed on the (optionally rescaled) image.
    """
    # Convert to float and rescale to [0, 1] only when requested.
    unit_range = image.float() / 255.0 if rescale else image
    # Normalize to [-1, 1].
    return 2 * unit_range - 1
# train_transforms = transforms.Compose(
# [
# transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
# transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
# transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
# transforms.ToTensor(),
# transforms.Normalize([0.5], [0.5]),
# ]
# )
def mean_list(l):
    """Return the arithmetic mean of ``l`` after casting each item to ``int``.

    Each element is converted with ``int()`` before averaging, so fractional
    parts are truncated and numeric strings are accepted. The result is a
    float.
    """
    as_ints = list(map(int, l))
    return float(sum(as_ints)) / len(as_ints)
def segment_mean(x, index):
    """Average the rows of ``x`` that share the same segment id.

    Functions like ``tf.segment_mean``: ``x`` is flattened to rows of size
    ``x.shape[-1]`` and ``index`` is flattened to one segment id per row.

    Args:
        x: Tensor whose last dimension is the feature dimension.
        index: Integer tensor of segment ids with one entry per row of the
            flattened ``x``. ``scatter_add_`` requires it to be int64.

    Returns:
        Tensor of shape ``[index.max() + 1, x.shape[-1]]`` with the per-segment
        mean. Segment ids that never occur yield an all-zero row. An empty
        ``index`` yields a tensor of shape ``[0, x.shape[-1]]``.
    """
    feature_dim = x.shape[-1]
    # reshape (rather than view) so non-contiguous inputs, e.g. permuted
    # tensors, are accepted as well.
    flat_x = x.reshape(-1, feature_dim)
    flat_index = index.reshape(-1)

    # No segments at all: max() of an empty tensor would raise.
    if flat_index.numel() == 0:
        return flat_x.new_zeros((0, feature_dim))

    num_segments = int(flat_index.max()) + 1

    sum_x = torch.zeros((num_segments, feature_dim),
                        dtype=flat_x.dtype,
                        device=flat_x.device)
    counts = torch.zeros((num_segments,),
                         dtype=flat_x.dtype,
                         device=flat_x.device)

    # Number of rows that fall into each segment.
    counts.scatter_add_(
        0, flat_index, torch.ones_like(flat_index, dtype=flat_x.dtype))
    # Unused segment ids have a zero count; divide by one instead so their
    # (all-zero) sum stays zero rather than becoming NaN.
    counts = torch.where(torch.eq(counts, 0),
                         torch.ones_like(counts),
                         counts)

    # Broadcast each row's segment id across the feature dimension.
    index_2d = flat_index.unsqueeze(1).expand(-1, feature_dim)
    sum_x.scatter_add_(0, index_2d, flat_x)

    return sum_x.div_(counts.unsqueeze(1))
def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Args:
        step: An integer of the constant step
        total_timestep: An integer of the total timesteps of the diffusion model
        batch_size: An integer of the batch size
        config: A config object; ``config.tta.rand_timestep_equal_int`` and
            ``config.tta.random_timestep_per_iteration`` select the strategy.

    Returns:
        timesteps: A tensor of shape [batch_size,] of the time steps

    Raises:
        ValueError: If the evenly spaced strategy is selected and
            ``batch_size`` exceeds ``total_timestep``.
    """
    if config.tta.rand_timestep_equal_int:
        # Evenly spaced timesteps across the batch, with a random offset.
        interval_val = total_timestep // batch_size
        if interval_val < 1:
            raise ValueError(
                "batch_size must not exceed total_timestep when "
                "rand_timestep_equal_int is enabled "
                f"(got batch_size={batch_size}, "
                f"total_timestep={total_timestep})."
            )
        start_point = random.randint(0, interval_val - 1)
        # Stop after exactly batch_size entries. Running up to total_timestep
        # would yield extra timesteps whenever total_timestep is not a
        # multiple of batch_size. The largest value is
        # start_point + interval_val * (batch_size - 1) < total_timestep.
        stop_point = start_point + interval_val * batch_size
        return torch.arange(start_point, stop_point, interval_val,
                            dtype=torch.long)

    if config.tta.random_timestep_per_iteration:
        # Independent random timestep for each image in the batch (default).
        return torch.randint(0, total_timestep, (batch_size,)).long()

    # Fallback: every image in the batch uses the same constant `step`.
    return torch.tensor([step] * batch_size).long()
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; the resulting callable is invoked with
    ``config["params"]`` as keyword arguments (no arguments if absent).
    The sentinel values ``'__is_first_stage__'`` and
    ``"__is_unconditional__"`` produce ``None``.

    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py

    Raises:
        KeyError: If ``config`` has no ``target`` entry and is not one of the
            sentinel values.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # Without a target, only the two sentinel markers are accepted.
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None

    raise KeyError("Expected key `target` to instantiate.")
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    The text after the last dot is the attribute name; everything before it
    is the module to import.

    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py

    Args:
        string: Dotted path ``"<module>.<attribute>"``.
        reload: If true, the module is reloaded before the attribute lookup.

    Returns:
        The attribute named by the last path component.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        # Refresh the module first so the lookup sees its current source.
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)