from os.path import join
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_scatter import scatter
from tqdm import tqdm
from visnet.datasets import MD17, MD22, QM9, Chignolin, Molecule3D, rMD17
from visnet.utils import MissingLabelException, make_splits
class DataModule(LightningDataModule):
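    """LightningDataModule that builds a molecular dataset, its train/val/test
    splits, and the corresponding dataloaders, optionally standardizing labels."""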
def __init__(self, hparams):
        super().__init__()
        # Accept both argparse.Namespace-style objects and plain dicts.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)
self._mean, self._std = None, None
self._saved_dataloaders = dict()
self.dataset = None
def prepare_dataset(self):
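        """Instantiate the dataset named by hparams['dataset'] via its matching
        _prepare_*_dataset method and split it into train/val/test subsets."""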
        dataset_name = self.hparams["dataset"]
        assert hasattr(self, f"_prepare_{dataset_name}_dataset"), f"Dataset {dataset_name} not defined"
        self.idx_train, self.idx_val, self.idx_test = getattr(self, f"_prepare_{dataset_name}_dataset")()
        print(f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}")
self.train_dataset = Subset(self.dataset, self.idx_train)
self.val_dataset = Subset(self.dataset, self.idx_val)
self.test_dataset = Subset(self.dataset, self.idx_test)
if self.hparams["standardize"]:
self._standardize()
def train_dataloader(self):
return self._get_dataloader(self.train_dataset, "train")
def val_dataloader(self):
loaders = [self._get_dataloader(self.val_dataset, "val")]
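        # Attach the test loader every `test_interval` epochs; the epoch offset differs by
        # one depending on whether dataloaders are rebuilt each epoch (reload == 1).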
delta = 1 if self.hparams['reload'] == 1 else 2
if (
len(self.test_dataset) > 0
and (self.trainer.current_epoch + delta) % self.hparams["test_interval"] == 0
):
loaders.append(self._get_dataloader(self.test_dataset, "test"))
return loaders
def test_dataloader(self):
return self._get_dataloader(self.test_dataset, "test")
@property
def atomref(self):
if hasattr(self.dataset, "get_atomref"):
return self.dataset.get_atomref()
return None
@property
def mean(self):
return self._mean
@property
def std(self):
return self._std
def _get_dataloader(self, dataset, stage, store_dataloader=True):
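        # Cache one dataloader per stage unless hparams["reload"] forces rebuilding each epoch.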
        store_dataloader = store_dataloader and not self.hparams["reload"]
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]
        if stage == "train":
            batch_size = self.hparams["batch_size"]
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.hparams["inference_batch_size"]
            shuffle = False
        else:
            # Guard against silently undefined batch_size/shuffle for unexpected stages.
            raise ValueError(f"Unknown dataloader stage: {stage}")
dl = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=self.hparams["num_workers"],
pin_memory=True,
)
if store_dataloader:
self._saved_dataloaders[stage] = dl
return dl
@rank_zero_only
def _standardize(self):
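        """Compute the mean and std of the training labels, subtracting per-atom
        reference energies when an Atomref prior model is used."""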
def get_label(batch, atomref):
if batch.y is None:
raise MissingLabelException()
if atomref is None:
return batch.y.clone()
atomref_energy = scatter(atomref[batch.z], batch.batch, dim=0)
return (batch.y.squeeze() - atomref_energy.squeeze()).clone()
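        # Iterate the training subset with "val"-stage settings (inference batch size, no shuffling).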
data = tqdm(
self._get_dataloader(self.train_dataset, "val", store_dataloader=False),
desc="computing mean and std",
)
try:
atomref = self.atomref if self.hparams["prior_model"] == "Atomref" else None
ys = torch.cat([get_label(batch, atomref) for batch in data])
except MissingLabelException:
rank_zero_warn(
"Standardize is true but failed to compute dataset mean and "
"standard deviation. Maybe the dataset only contains forces."
)
return None
self._mean = ys.mean(dim=0)
self._std = ys.std(dim=0)
def _prepare_Chignolin_dataset(self):
self.dataset = Chignolin(root=self.hparams["dataset_root"])
train_size = self.hparams["train_size"]
val_size = self.hparams["val_size"]
idx_train, idx_val, idx_test = make_splits(
len(self.dataset),
train_size,
val_size,
None,
self.hparams["seed"],
join(self.hparams["log_dir"], "splits.npz"),
self.hparams["splits"],
)
return idx_train, idx_val, idx_test
def _prepare_MD17_dataset(self):
self.dataset = MD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
train_size = self.hparams["train_size"]
val_size = self.hparams["val_size"]
idx_train, idx_val, idx_test = make_splits(
len(self.dataset),
train_size,
val_size,
None,
self.hparams["seed"],
join(self.hparams["log_dir"], "splits.npz"),
self.hparams["splits"],
)
return idx_train, idx_val, idx_test
def _prepare_MD22_dataset(self):
self.dataset = MD22(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
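        # MD22 prescribes a per-molecule train+val size; use 95% for training, the rest for validation.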
train_val_size = self.dataset.molecule_splits[self.hparams["dataset_arg"]]
train_size = round(train_val_size * 0.95)
val_size = train_val_size - train_size
idx_train, idx_val, idx_test = make_splits(
len(self.dataset),
train_size,
val_size,
None,
self.hparams["seed"],
join(self.hparams["log_dir"], "splits.npz"),
self.hparams["splits"],
)
return idx_train, idx_val, idx_test
def _prepare_Molecule3D_dataset(self):
self.dataset = Molecule3D(root=self.hparams["dataset_root"])
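        # Molecule3D ships its own precomputed split indices, selected via hparams["split_mode"].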
split_dict = self.dataset.get_idx_split(self.hparams['split_mode'])
idx_train = split_dict['train']
idx_val = split_dict['valid']
idx_test = split_dict['test']
return idx_train, idx_val, idx_test
def _prepare_QM9_dataset(self):
self.dataset = QM9(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
train_size = self.hparams["train_size"]
val_size = self.hparams["val_size"]
idx_train, idx_val, idx_test = make_splits(
len(self.dataset),
train_size,
val_size,
None,
self.hparams["seed"],
join(self.hparams["log_dir"], "splits.npz"),
self.hparams["splits"],
)
return idx_train, idx_val, idx_test
def _prepare_rMD17_dataset(self):
self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
train_size = self.hparams["train_size"]
val_size = self.hparams["val_size"]
idx_train, idx_val, idx_test = make_splits(
len(self.dataset),
train_size,
val_size,
None,
self.hparams["seed"],
join(self.hparams["log_dir"], "splits.npz"),
self.hparams["splits"],
)
return idx_train, idx_val, idx_test
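

# Usage sketch (not part of the original module): one way to drive this DataModule
# standalone. The hparams keys below are assumptions inferred from the lookups in
# this file, not a confirmed configuration schema; adjust values to your setup.
if __name__ == "__main__":
    hparams = {
        "dataset": "MD17",
        "dataset_arg": "aspirin",
        "dataset_root": "data/md17",
        "train_size": 950,
        "val_size": 50,
        "seed": 1,
        "log_dir": "logs",
        "splits": None,
        "standardize": True,
        "prior_model": None,
        "batch_size": 32,
        "inference_batch_size": 64,
        "num_workers": 4,
        "reload": 0,
        "test_interval": 10,
    }
    dm = DataModule(hparams)
    dm.prepare_dataset()
    print(f"label mean={dm.mean}, std={dm.std}")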