|
import os.path as osp |
|
|
|
import numpy as np |
|
import torch |
|
from pytorch_lightning.utilities import rank_zero_warn |
|
from torch_geometric.data import Data, InMemoryDataset, download_url |
|
from tqdm import tqdm |
|
|
|
|
|
class MD17(InMemoryDataset): |
|
""" |
|
Machine learning of accurate energy-conserving molecular force fields (Chmiela et al. 2017) |
|
This class provides functionality for loading MD trajectories from the original dataset, not the revised versions. |
|
See http://www.quantum-machine.org/gdml/#datasets for details. |
|
""" |
|
|
|
raw_url = "http://www.quantum-machine.org/gdml/data/npz/" |
|
|
|
molecule_files = dict( |
|
aspirin="md17_aspirin.npz", |
|
ethanol="md17_ethanol.npz", |
|
malonaldehyde="md17_malonaldehyde.npz", |
|
naphthalene="md17_naphthalene.npz", |
|
salicylic_acid="md17_salicylic.npz", |
|
toluene="md17_toluene.npz", |
|
uracil="md17_uracil.npz", |
|
) |
|
|
|
available_molecules = list(molecule_files.keys()) |
|
|
|
def __init__(self, root, transform=None, pre_transform=None, dataset_arg=None): |
|
assert dataset_arg is not None, ( |
|
"Please provide the desired comma separated molecule(s) through" |
|
f"'dataset_arg'. Available molecules are {', '.join(MD17.available_molecules)} " |
|
"or 'all' to train on the combined dataset." |
|
) |
|
|
|
if dataset_arg == "all": |
|
dataset_arg = ",".join(MD17.available_molecules) |
|
self.molecules = dataset_arg.split(",") |
|
|
|
if len(self.molecules) > 1: |
|
rank_zero_warn( |
|
"MD17 molecules have different reference energies, " |
|
"which is not accounted for during training." |
|
) |
|
|
|
super(MD17, self).__init__(osp.join(root, dataset_arg), transform, pre_transform) |
|
|
|
self.offsets = [0] |
|
self.data_all, self.slices_all = [], [] |
|
for path in self.processed_paths: |
|
data, slices = torch.load(path) |
|
self.data_all.append(data) |
|
self.slices_all.append(slices) |
|
self.offsets.append(len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1]) |
|
|
|
def len(self): |
|
return sum(len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all) |
|
|
|
def get(self, idx): |
|
data_idx = 0 |
|
while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: |
|
data_idx += 1 |
|
self.data = self.data_all[data_idx] |
|
self.slices = self.slices_all[data_idx] |
|
return super(MD17, self).get(idx - self.offsets[data_idx]) |
|
|
|
@property |
|
def raw_file_names(self): |
|
return [MD17.molecule_files[mol] for mol in self.molecules] |
|
|
|
@property |
|
def processed_file_names(self): |
|
return [f"md17-{mol}.pt" for mol in self.molecules] |
|
|
|
def download(self): |
|
for file_name in self.raw_file_names: |
|
download_url(MD17.raw_url + file_name, self.raw_dir) |
|
|
|
def process(self): |
|
for path, processed_path in zip(self.raw_paths, self.processed_paths): |
|
data_npz = np.load(path) |
|
z = torch.from_numpy(data_npz["z"]).long() |
|
positions = torch.from_numpy(data_npz["R"]).float() |
|
energies = torch.from_numpy(data_npz["E"]).float() |
|
forces = torch.from_numpy(data_npz["F"]).float() |
|
|
|
samples = [] |
|
for pos, y, dy in tqdm(zip(positions, energies, forces), total=energies.size(0)): |
|
|
|
data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy) |
|
|
|
if self.pre_filter is not None: |
|
data = self.pre_filter(data) |
|
|
|
if self.pre_transform is not None: |
|
data = self.pre_transform(data) |
|
|
|
samples.append(data) |
|
|
|
data, slices = self.collate(samples) |
|
torch.save((data, slices), processed_path) |
|
|