|
import argparse |
|
import os |
|
from os.path import dirname |
|
|
|
import numpy as np |
|
import torch |
|
import yaml |
|
from pytorch_lightning.utilities import rank_zero_warn |
|
|
|
|
|
def train_val_test_split(dset_len, train_size, val_size, test_size, seed):
    """Shuffle ``range(dset_len)`` and cut it into train/val/test index arrays.

    Each size may be an ``int`` (absolute number of samples), a ``float``
    (fraction of ``dset_len``, converted with ``round``) or ``None``.  At most
    one size may be ``None``; that split receives whatever the other two
    leave over.

    Args:
        dset_len: Number of samples in the dataset.
        train_size: Size of the training split.
        val_size: Size of the validation split.
        test_size: Size of the test split.
        seed: Seed handed to ``np.random.default_rng`` for the permutation.

    Returns:
        ``(idx_train, idx_val, idx_test)`` as int64 ``np.ndarray`` objects,
        taken as consecutive slices of one seeded permutation.

    Raises:
        AssertionError: If more than one size is ``None``, if a split ends
            up negative, or if the splits together exceed ``dset_len``.
    """
    requested = (train_size, val_size, test_size)
    assert sum(size is None for size in requested) <= 1, "Only one of train_size, val_size, test_size is allowed to be None."

    # Remember which sizes were fractions before they become sample counts.
    from_fraction = [isinstance(size, float) for size in requested]
    counts = [
        round(dset_len * size) if is_fraction else size
        for size, is_fraction in zip(requested, from_fraction)
    ]

    # The single unspecified split (if any) absorbs the remainder.
    for slot, count in enumerate(counts):
        if count is None:
            first, second = (c for i, c in enumerate(counts) if i != slot)
            counts[slot] = dset_len - first - second
            break

    # Rounding fractions can overshoot the dataset; trim one sample from the
    # last split that was given as a fraction (test, then val, then train).
    if sum(counts) > dset_len:
        for slot in (2, 1, 0):
            if from_fraction[slot]:
                counts[slot] -= 1
                break

    n_train, n_val, n_test = counts
    assert all(count >= 0 for count in (n_train, n_val, n_test)), (
        f"One of training ({n_train}), validation ({n_val}) or "
        f"testing ({n_test}) splits ended up with a negative size."
    )

    total = n_train + n_val + n_test
    assert dset_len >= total, f"The dataset ({dset_len}) is smaller than the combined split sizes ({total})."

    leftover = dset_len - total
    if leftover > 0:
        rank_zero_warn(f"{leftover} samples were excluded from the dataset")

    shuffled = np.random.default_rng(seed).permutation(np.arange(dset_len, dtype=np.int64))

    val_end = n_train + n_val
    return (
        np.array(shuffled[:n_train]),
        np.array(shuffled[n_train:val_end]),
        np.array(shuffled[val_end:total]),
    )
|
|
|
|
|
def make_splits(dataset_len, train_size, val_size, test_size, seed, filename=None, splits=None):
    """Return train/val/test index tensors, loaded from disk or freshly drawn.

    Args:
        dataset_len: Number of samples; only used when ``splits`` is None.
        train_size: Training split size, forwarded to ``train_val_test_split``.
        val_size: Validation split size, forwarded likewise.
        test_size: Test split size, forwarded likewise.
        seed: Random seed, forwarded likewise.
        filename: If given, the index arrays are written there with
            ``np.savez`` under the keys ``idx_train``, ``idx_val`` and
            ``idx_test`` (also when they were just loaded from ``splits``).
        splits: If given, path or file object of an ``.npz`` archive holding
            those three keys; the split is read from it instead of drawn.

    Returns:
        ``(idx_train, idx_val, idx_test)`` as tensors created with
        ``torch.from_numpy``.
    """
    if splits is not None:
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the three arrays have been read into memory.
        with np.load(splits) as stored:
            idx_train = stored["idx_train"]
            idx_val = stored["idx_val"]
            idx_test = stored["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split(dataset_len, train_size, val_size, test_size, seed)

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return torch.from_numpy(idx_train), torch.from_numpy(idx_val), torch.from_numpy(idx_test)
|
|
|
|
|
class LoadFromFile(argparse.Action):
    """argparse action that merges a YAML config file into the namespace.

    The option value must be an open file object exposing ``.name`` and
    usable as a context manager (presumably produced by
    ``argparse.FileType`` -- the parser setup is not visible here).
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Load the YAML file ``values`` and copy its entries onto ``namespace``.

        Raises:
            ValueError: If the file name does not end with ``yaml``/``yml``,
                if the document is not a mapping, or if it contains a key
                that is not already present in ``namespace``.
        """
        if not values.name.endswith(("yaml", "yml")):
            raise ValueError("Configuration file must end with yaml or yml")

        with values as f:
            # NOTE(security): FullLoader is kept so existing config files keep
            # parsing the same way; only use this on trusted files and
            # consider yaml.safe_load if configs can come from elsewhere.
            config = yaml.load(f, Loader=yaml.FullLoader)

        # An empty (or comment-only) YAML document parses to None: nothing
        # to merge, so leave the namespace untouched instead of crashing.
        if config is None:
            return
        if not isinstance(config, dict):
            raise ValueError("Configuration file must contain a mapping of argument names to values")

        for key in config:
            if key not in namespace:
                raise ValueError(f"Unknown argument in config file: {key}")
        namespace.__dict__.update(config)
|
|
|
|
|
class LoadFromCheckpoint(argparse.Action):
    """argparse action that restores arguments from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must contain a
    ``"hyper_parameters"`` entry whose keys all exist in the namespace.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Copy the checkpoint's hyper-parameters onto ``namespace``.

        ``values`` is passed to ``torch.load`` and is also stored on the
        namespace as ``load_model``.

        Raises:
            ValueError: If a hyper-parameter name is not in ``namespace``.
        """
        hparams = torch.load(values, map_location="cpu")["hyper_parameters"]

        for name in hparams:
            if name not in namespace:
                raise ValueError(f"Unknown argument in the model checkpoint: {name}")

        # Keyword entries are applied after the mapping, so load_model always
        # ends up pointing at the checkpoint that was just read.
        vars(namespace).update(hparams, load_model=values)
|
|
|
|
|
def save_argparse(args, filename, exclude=None):
    """Write the attributes of ``args`` to a YAML file.

    Args:
        args: Object whose ``__dict__`` is dumped (e.g. an
            ``argparse.Namespace``).
        filename: Destination path as a string; must end with ``yaml`` or
            ``yml``.  Missing parent directories are created.
        exclude: Attribute name, or iterable of names, to leave out of the
            file.  ``None`` (the default) excludes nothing.

    Raises:
        ValueError: If ``filename`` does not end with ``yaml``/``yml``.
        KeyError: If an excluded name is not an attribute of ``args``.
    """
    # Validate first so a bad file name has no side effects on disk.
    if not filename.endswith(("yaml", "yml")):
        raise ValueError("Configuration file should end with yaml or yml")

    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        exclude = [exclude]

    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    directory = dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Work on a copy so the caller's namespace keeps the excluded entries.
    config = args.__dict__.copy()
    for key in exclude:
        del config[key]

    with open(filename, "w") as f:
        yaml.dump(config, f)
|
|
|
|
|
def number(text):
    """Parse ``text`` into ``None``, an ``int`` or a ``float``.

    ``None`` and the string ``"None"`` map to ``None``.  Otherwise the value
    is returned as an ``int`` when ``int(text)`` succeeds and equals
    ``float(text)``, and as a ``float`` in every other case.

    Raises:
        ValueError: If ``text`` cannot be converted by ``float``.
    """
    if text is None or text == "None":
        return None

    try:
        whole = int(text)
    except ValueError:
        whole = None
    real = float(text)

    # Fall back to the float whenever there is no matching integer form.
    if whole is None or whole != real:
        return real
    return whole
|
|
|
|
|
class MissingLabelException(Exception):
    """Exception type declared for missing-label errors.

    NOTE(review): nothing in the visible code raises it; presumably it
    signals that a required label is absent -- confirm against the raise
    sites elsewhere in the project.
    """