# yxc97's picture
# Upload folder using huggingface_hub
# 62a2f1c verified
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_warn
# Names exported by `from <module> import *`; only `Atomref` is listed here.
__all__ = ["Atomref"]
class BasePrior(nn.Module, metaclass=ABCMeta):
    """
    Abstract interface shared by all prior models.

    Custom priors subclass this class and implement both `get_init_args` and
    `forward`. A prior typically takes some arguments and a dataset as input;
    the `torchmdnet.priors.Atomref` prior serves as a worked example.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_init_args(self):
        """
        Return every argument required to construct this prior object.

        The arguments are expected as a dict keyed by argument name. All
        values should be saveable in a .yaml file, because they are used to
        reconstruct the prior model from a checkpoint file.
        """
        return None

    @abstractmethod
    def forward(self, x, z):
        """
        Apply the prior to the model's predictions.

        Args:
            x (torch.Tensor): scalar atomwise predictions from the model.
            z (torch.Tensor): atom types of all atoms.

        Returns:
            torch.Tensor: updated scalar atomwise predictions
        """
        return None
class Atomref(BasePrior):
    """
    Prior that adds a per-atom-type reference value to the predictions.

    When a dataset is supplied, its class must implement `get_atomref`, which
    returns the atomic reference values as a tensor (or None if no reference
    is defined). Without a dataset, the reference values start as zeros with
    `max_z` rows.
    """

    def __init__(self, max_z=None, dataset=None):
        super().__init__()

        if dataset is not None:
            # The dataset takes precedence; `max_z` is not consulted here.
            reference = dataset.get_atomref()
            if reference is None:
                rank_zero_warn(
                    "The atomref returned by the dataset is None, defaulting to zeros with max. "
                    "atomic number 99. Maybe atomref is not defined for the current target."
                )
                reference = torch.zeros(100, 1)
        elif max_z is not None:
            reference = torch.zeros(max_z, 1)
        else:
            raise ValueError("Can't instantiate Atomref prior, all arguments are None.")

        # A flat vector of reference values becomes a single-column matrix.
        reference = reference.view(-1, 1) if reference.ndim == 1 else reference

        # Keep the starting values in a buffer so `reset_parameters` can
        # restore them after the embedding weights have changed.
        self.register_buffer("initial_atomref", reference)
        self.atomref = nn.Embedding(len(reference), 1)
        self.atomref.weight.data.copy_(reference)

    def reset_parameters(self):
        """Overwrite the embedding weights with the stored initial values."""
        weights = self.atomref.weight.data
        weights.copy_(self.initial_atomref)

    def get_init_args(self):
        """Return the constructor arguments: `max_z` as the number of stored rows."""
        return {"max_z": self.initial_atomref.size(0)}

    def forward(self, x, z):
        """
        Add the reference value looked up for each atom type to the predictions.

        Args:
            x (torch.Tensor): scalar atomwise predictions from the model.
            z (torch.Tensor): atom types of all atoms, used as embedding indices.

        Returns:
            torch.Tensor: `x` plus the embedded reference values for `z`.
        """
        offsets = self.atomref(z)
        return x + offsets