from os.path import join
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_scatter import scatter
from tqdm import tqdm
from visnet.datasets import *
from visnet.utils import MissingLabelException, make_splits


class DataModule(LightningDataModule):
    """Lightning data module that wraps the ViSNet datasets and their splits."""

    def __init__(self, hparams):
        super(DataModule, self).__init__()
        # Accept both argparse.Namespace-like objects and plain dicts.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)
        self._mean, self._std = None, None
        self._saved_dataloaders = dict()
        self.dataset = None

    def prepare_dataset(self):
        assert hasattr(
            self, f"_prepare_{self.hparams['dataset']}_dataset"
        ), f"Dataset {self.hparams['dataset']} not defined"
        # Dispatch to the matching _prepare_<name>_dataset method below.
        dataset_factory = lambda t: getattr(self, f"_prepare_{t}_dataset")()
        self.idx_train, self.idx_val, self.idx_test = dataset_factory(self.hparams["dataset"])
        print(f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}")
        self.train_dataset = Subset(self.dataset, self.idx_train)
        self.val_dataset = Subset(self.dataset, self.idx_val)
        self.test_dataset = Subset(self.dataset, self.idx_test)
        if self.hparams["standardize"]:
            self._standardize()
    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        loaders = [self._get_dataloader(self.val_dataset, "val")]
        # The epoch offset depends on whether dataloaders are reloaded every
        # epoch, since that changes when this hook is called relative to the
        # epoch counter.
        delta = 1 if self.hparams["reload"] == 1 else 2
        if (
            len(self.test_dataset) > 0
            and (self.trainer.current_epoch + delta) % self.hparams["test_interval"] == 0
        ):
            # Periodically evaluate on the test set during validation.
            loaders.append(self._get_dataloader(self.test_dataset, "test"))
        return loaders

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")
    @property
    def atomref(self):
        if hasattr(self.dataset, "get_atomref"):
            return self.dataset.get_atomref()
        return None

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        return self._std
    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        # Only cache dataloaders when Lightning is not reloading them each epoch.
        store_dataloader = store_dataloader and not self.hparams["reload"]
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]

        if stage == "train":
            batch_size = self.hparams["batch_size"]
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.hparams["inference_batch_size"]
            shuffle = False

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams["num_workers"],
            pin_memory=True,
        )
        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl
    @rank_zero_only
    def _standardize(self):
        def get_label(batch, atomref):
            if batch.y is None:
                raise MissingLabelException()
            if atomref is None:
                return batch.y.clone()
            # Subtract the summed per-atom reference energies from each
            # molecular label before computing statistics.
            atomref_energy = scatter(atomref[batch.z], batch.batch, dim=0)
            return (batch.y.squeeze() - atomref_energy.squeeze()).clone()

        # Iterate over the training set with the "val" stage settings
        # (inference batch size, no shuffling), without caching the loader.
        data = tqdm(
            self._get_dataloader(self.train_dataset, "val", store_dataloader=False),
            desc="computing mean and std",
        )
        try:
            atomref = self.atomref if self.hparams["prior_model"] == "Atomref" else None
            ys = torch.cat([get_label(batch, atomref) for batch in data])
        except MissingLabelException:
            rank_zero_warn(
                "Standardize is true but failed to compute dataset mean and "
                "standard deviation. Maybe the dataset only contains forces."
            )
            return None

        self._mean = ys.mean(dim=0)
        self._std = ys.std(dim=0)
    def _prepare_Chignolin_dataset(self):
        self.dataset = Chignolin(root=self.hparams["dataset_root"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
    def _prepare_MD17_dataset(self):
        self.dataset = MD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
    def _prepare_MD22_dataset(self):
        self.dataset = MD22(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        # MD22 prescribes a per-molecule train+val budget; split it 95/5
        # between training and validation.
        train_val_size = self.dataset.molecule_splits[self.hparams["dataset_arg"]]
        train_size = round(train_val_size * 0.95)
        val_size = train_val_size - train_size
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
    def _prepare_Molecule3D_dataset(self):
        self.dataset = Molecule3D(root=self.hparams["dataset_root"])
        # Molecule3D ships predefined index splits, selected by split_mode.
        split_dict = self.dataset.get_idx_split(self.hparams["split_mode"])
        idx_train = split_dict["train"]
        idx_val = split_dict["valid"]
        idx_test = split_dict["test"]
        return idx_train, idx_val, idx_test
    def _prepare_QM9_dataset(self):
        self.dataset = QM9(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
    def _prepare_rMD17_dataset(self):
        self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
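

# A minimal usage sketch, assuming a QM9 run. The keys below mirror the
# hyperparameters accessed throughout this module, but the concrete values,
# the "energy_U0" target, and the data/log paths are illustrative
# assumptions, not shipped defaults. Running this will download and process
# QM9 under dataset_root.
if __name__ == "__main__":
    hparams = {
        "dataset": "QM9",            # dispatches to _prepare_QM9_dataset
        "dataset_arg": "energy_U0",  # assumed QM9 target label
        "dataset_root": "data/qm9",
        "train_size": 110000,
        "val_size": 10000,
        "seed": 1,
        "log_dir": "logs",
        "splits": None,
        "standardize": True,
        "prior_model": "Atomref",
        "reload": 0,
        "test_interval": 10,
        "batch_size": 32,
        "inference_batch_size": 64,
        "num_workers": 4,
    }
    dm = DataModule(hparams)
    dm.prepare_dataset()          # builds splits and computes mean/std
    train_loader = dm.train_dataloader()
    print(f"mean={dm.mean}, std={dm.std}, batches={len(train_loader)}")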