from os.path import join

import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_scatter import scatter
from tqdm import tqdm

from visnet.datasets import *
from visnet.utils import MissingLabelException, make_splits


class DataModule(LightningDataModule):
    def __init__(self, hparams):
        super(DataModule, self).__init__()

        # Accept both argparse.Namespace-like objects and plain dicts.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)

        self._mean, self._std = None, None
        self._saved_dataloaders = dict()
        self.dataset = None

    def prepare_dataset(self):
        assert hasattr(self, f"_prepare_{self.hparams['dataset']}_dataset"), \
            f"Dataset {self.hparams['dataset']} not defined"

        dataset_factory = lambda t: getattr(self, f"_prepare_{t}_dataset")()
        self.idx_train, self.idx_val, self.idx_test = dataset_factory(self.hparams["dataset"])

        print(f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}")

        self.train_dataset = Subset(self.dataset, self.idx_train)
        self.val_dataset = Subset(self.dataset, self.idx_val)
        self.test_dataset = Subset(self.dataset, self.idx_test)

        if self.hparams["standardize"]:
            self._standardize()

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        loaders = [self._get_dataloader(self.val_dataset, "val")]
        # With reload enabled the epoch counter is offset by one, so shift
        # the test-interval check accordingly.
        delta = 1 if self.hparams["reload"] == 1 else 2
        if (
            len(self.test_dataset) > 0
            and (self.trainer.current_epoch + delta) % self.hparams["test_interval"] == 0
        ):
            loaders.append(self._get_dataloader(self.test_dataset, "test"))
        return loaders

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")

    @property
    def atomref(self):
        if hasattr(self.dataset, "get_atomref"):
            return self.dataset.get_atomref()
        return None

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        return self._std

    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        # Cached dataloaders are only reused when reload is disabled.
        store_dataloader = store_dataloader and not self.hparams["reload"]
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]

        if stage == "train":
            batch_size = self.hparams["batch_size"]
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.hparams["inference_batch_size"]
            shuffle = False
        else:
            raise ValueError(f"Unknown dataloader stage: {stage}")

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams["num_workers"],
            pin_memory=True,
        )

        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

    @rank_zero_only
    def _standardize(self):
        def get_label(batch, atomref):
            if batch.y is None:
                raise MissingLabelException()

            if atomref is None:
                return batch.y.clone()

            # Subtract the sum of single-atom reference energies per molecule.
            atomref_energy = scatter(atomref[batch.z], batch.batch, dim=0)
            return (batch.y.squeeze() - atomref_energy.squeeze()).clone()

        data = tqdm(
            self._get_dataloader(self.train_dataset, "val", store_dataloader=False),
            desc="computing mean and std",
        )
        try:
            atomref = self.atomref if self.hparams["prior_model"] == "Atomref" else None
            ys = torch.cat([get_label(batch, atomref) for batch in data])
        except MissingLabelException:
            rank_zero_warn(
                "Standardize is true but failed to compute dataset mean and "
                "standard deviation. Maybe the dataset only contains forces."
            )
            return None

        self._mean = ys.mean(dim=0)
        self._std = ys.std(dim=0)

    def _prepare_Chignolin_dataset(self):
        self.dataset = Chignolin(root=self.hparams["dataset_root"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test

    def _prepare_MD17_dataset(self):
        self.dataset = MD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test

    def _prepare_MD22_dataset(self):
        self.dataset = MD22(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        # MD22 fixes the train+val budget per molecule; split it 95/5.
        train_val_size = self.dataset.molecule_splits[self.hparams["dataset_arg"]]
        train_size = round(train_val_size * 0.95)
        val_size = train_val_size - train_size
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test

    def _prepare_Molecule3D_dataset(self):
        self.dataset = Molecule3D(root=self.hparams["dataset_root"])
        split_dict = self.dataset.get_idx_split(self.hparams["split_mode"])
        idx_train = split_dict["train"]
        idx_val = split_dict["valid"]
        idx_test = split_dict["test"]
        return idx_train, idx_val, idx_test

    def _prepare_QM9_dataset(self):
        self.dataset = QM9(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test

    def _prepare_rMD17_dataset(self):
        self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]
        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )
        return idx_train, idx_val, idx_test
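
# ---------------------------------------------------------------------------
# Usage sketch (a minimal, hypothetical example; not part of the module).
# The hparams keys below mirror the ones read by DataModule above; the
# concrete values and the explicit call to prepare_dataset() are assumptions
# about how the surrounding training script drives this class.
#
#   hparams = {
#       "dataset": "QM9",                 # selects _prepare_QM9_dataset
#       "dataset_arg": "energy_U0",       # hypothetical target label
#       "dataset_root": "data/qm9",
#       "train_size": 110000,
#       "val_size": 10000,
#       "seed": 1,
#       "log_dir": "logs",
#       "splits": None,
#       "standardize": True,
#       "prior_model": "Atomref",
#       "batch_size": 32,
#       "inference_batch_size": 64,
#       "num_workers": 4,
#       "reload": 0,
#       "test_interval": 10,
#   }
#   dm = DataModule(hparams)
#   dm.prepare_dataset()             # builds the splits, computes mean/std
#   mean, std = dm.mean, dm.std      # e.g. to de-standardize model outputs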