"""Utility functions"""
import importlib
import random

import torch
import numpy as np
from PIL import Image


class UnNormalize(object):
    """Unnormalize an image as: image = (image * std) + mean"""

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, tensor):
        """
        Args:
            tensor: A tensor of shape [C, H, W] or [N, C, H, W]

        Returns:
            tensor: The unnormalized tensor, same shape as the input
        """
        std = self.std.to(tensor.device)
        mean = self.mean.to(tensor.device)
        # Reshape mean/std so they broadcast over the channel dimension.
        if tensor.ndim == 3:
            std, mean = std.view(-1, 1, 1), mean.view(-1, 1, 1)
        elif tensor.ndim == 4:
            std, mean = std.view(1, -1, 1, 1), mean.view(1, -1, 1, 1)
        tensor = (tensor * std) + mean
        return tensor


class VQVAEUnNormalize(UnNormalize):
    """Unnormalize an image as:
    First: image = (image * std) + mean
    Second: image = (image * 2) - 1
    """

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W)
                to be unnormalized.

        Returns:
            Tensor: Unnormalized image, rescaled to [-1, 1].
        """
        tensor = super().__call__(tensor)
        tensor = 2 * tensor - 1
        return tensor
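

# Illustrative usage sketch (not part of the original module): undo a
# torchvision-style dataset normalization. The ImageNet mean/std values below
# are assumptions chosen for the example, not constants defined elsewhere here.
def _example_unnormalize():
    unnorm = VQVAEUnNormalize(
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    batch = torch.randn(4, 3, 224, 224)  # normalized images, [N, C, H, W]
    # First maps back to the raw pixel range, then rescales to [-1, 1]
    # as expected by a VQVAE-style encoder.
    return unnorm(batch)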


def normalize(image, rescale=True):
    """Map an image to the [-1, 1] range.

    Args:
        image: An image tensor; if `rescale` is True it is assumed to hold
            values in [0, 255] and is first scaled to [0, 1].
    """
    if rescale:
        image = image.float() / 255.0
    normalize_image = 2 * image - 1
    return normalize_image
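

# Illustrative sketch (assumption: `image` is a uint8 tensor in [0, 255],
# e.g. a PIL image converted with `torchvision.transforms.PILToTensor`):
# `normalize` rescales it to [0, 1] and then maps it to [-1, 1].
def _example_normalize():
    image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
    return normalize(image)  # float tensor in [-1, 1]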


def mean_list(l):
    """Return the mean of a list of integer-castable values as a float."""
    l = [int(_l) for _l in l]
    return float(sum(l)) / len(l)


def segment_mean(x, index):
    """Average the rows of `x` that share the same segment id in `index`,
    analogous to `tf.math.segment_mean`.
    """
    x = x.view(-1, x.shape[-1])
    index = index.view(-1)

    max_index = index.max() + 1
    sum_x = torch.zeros((max_index, x.shape[-1]),
                        dtype=x.dtype,
                        device=x.device)
    num_index = torch.zeros((max_index,),
                            dtype=x.dtype,
                            device=x.device)

    # Count how many rows fall into each segment; guard empty segments with 1
    # to avoid division by zero.
    num_index = num_index.scatter_add_(
        0, index, torch.ones_like(index, dtype=x.dtype))
    num_index = torch.where(torch.eq(num_index, 0),
                            torch.ones_like(num_index, dtype=x.dtype),
                            num_index)

    # Sum the rows per segment, then divide by the per-segment counts.
    index_2d = index.view(-1, 1).expand(-1, x.shape[-1])
    sum_x = sum_x.scatter_add_(0, index_2d, x)
    mean_x = sum_x.div_(num_index.view(-1, 1))

    return mean_x
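

# Illustrative sketch: pool feature rows by segment id. The tensors below are
# example values only.
def _example_segment_mean():
    x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    index = torch.tensor([0, 0, 1])
    # Rows 0 and 1 share segment 0, row 2 forms segment 1:
    # -> tensor([[2., 3.], [5., 6.]])
    return segment_mean(x, index)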


def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Args:
        step: An integer of the constant time step
        total_timestep: An integer of the total number of timesteps of the
            diffusion model
        batch_size: An integer of the batch size
        config: A config object

    Returns:
        timesteps: A tensor of shape [batch_size,] of the time steps
    """
    if config.tta.rand_timestep_equal_int:
        # One time step per batch element, equally spaced over the schedule
        # with a random offset.
        interval_val = total_timestep // batch_size
        start_point = random.randint(0, interval_val - 1)
        timesteps = torch.tensor(
            list(range(start_point, total_timestep, interval_val))
        ).long()
        return timesteps
    elif config.tta.random_timestep_per_iteration:
        # Each time step sampled uniformly at random.
        return torch.randint(0, total_timestep, (batch_size,)).long()
    else:
        # The same constant time step for the whole batch.
        return torch.tensor([step] * batch_size).long()
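

# Illustrative sketch (assumption: `config` mirrors the nested attributes this
# module reads, mocked here with `SimpleNamespace`): draw one time step per
# batch element, equally spaced across a 1000-step diffusion schedule.
def _example_initiate_time_steps():
    from types import SimpleNamespace
    config = SimpleNamespace(
        tta=SimpleNamespace(
            rand_timestep_equal_int=True,
            random_timestep_per_iteration=False))
    # e.g. tensor([ 73, 323, 573, 823]) for a random start point of 73
    return initiate_time_steps(
        step=0, total_timestep=1000, batch_size=4, config=config)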


def instantiate_from_config(config):
    """A helper function to instantiate a class from a config object.
    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
    """
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    """A helper function to import an object from a `module.ClassName` string.
    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
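

# Illustrative sketch: both helpers resolve dotted import paths, following the
# config convention of the Stable Diffusion codebase. The target used below
# (`collections.OrderedDict`) is just a stand-in for a real model class.
def _example_instantiate_from_config():
    cls = get_obj_from_str("collections.OrderedDict")  # -> the OrderedDict class
    obj = instantiate_from_config(
        {"target": "collections.OrderedDict", "params": {}})
    return cls, obj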