|
from abc import ABCMeta, abstractmethod |
|
|
|
import torch |
|
import torch.nn as nn |
|
from pytorch_lightning.utilities import rank_zero_warn |
|
|
|
# Names exported by a star-import of this module: only the concrete
# Atomref prior is listed; BasePrior is not.
__all__ = [
    "Atomref",
]
|
|
|
|
|
class BasePrior(nn.Module, metaclass=ABCMeta):
    """
    Abstract parent of all prior models.

    Subclass it to build a custom prior; such a prior is constructed from
    some arguments and a dataset. `torchmdnet.priors.Atomref` serves as a
    reference implementation.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_init_args(self):
        """
        Return every argument needed to construct this prior.

        The result is expected to be a dict keyed by argument name. Each
        value should be storable in a .yaml file, because the dict is used
        to rebuild the prior model from a checkpoint file.
        """
        return None

    @abstractmethod
    def forward(self, x, z):
        """
        Apply the prior to the model output.

        Args:
            x (torch.Tensor): scalar atomwise predictions from the model.
            z (torch.Tensor): atom types of all atoms.

        Returns:
            torch.Tensor: updated scalar atomwise predictions
        """
        return None
|
|
|
|
|
class Atomref(BasePrior):
    """
    Prior that adds a per-atom-type reference value to atomwise predictions.

    If a dataset is supplied, its class has to provide `get_atomref`, which
    returns the atomic reference values as a tensor.
    """

    def __init__(self, max_z=None, dataset=None):
        """
        Build the reference table from `max_z` or from `dataset`.

        Args:
            max_z (int, optional): number of rows of the all-zero table that
                is created when no dataset is given.
            dataset (optional): object whose `get_atomref()` result is used
                as the table. When given, `max_z` is not consulted.

        Raises:
            ValueError: if both `max_z` and `dataset` are None.
        """
        super().__init__()

        if dataset is None:
            if max_z is None:
                raise ValueError(
                    "Can't instantiate Atomref prior, all arguments are None."
                )
            reference = torch.zeros(max_z, 1)
        else:
            reference = dataset.get_atomref()
            if reference is None:
                # The dataset has no reference values: warn and fall back
                # to a fixed 100-row table of zeros.
                rank_zero_warn(
                    "The atomref returned by the dataset is None, defaulting to zeros with max. "
                    "atomic number 99. Maybe atomref is not defined for the current target."
                )
                reference = torch.zeros(100, 1)

        # A flat vector is turned into a single-column table.
        if reference.ndim == 1:
            reference = reference.view(-1, 1)

        # Keep the starting values in a buffer so reset_parameters can
        # restore them; the embedding holds the working copy.
        self.register_buffer("initial_atomref", reference)
        self.atomref = nn.Embedding(len(reference), 1)
        self.atomref.weight.data.copy_(reference)

    def reset_parameters(self):
        """Copy the stored initial reference values back into the embedding."""
        weights = self.atomref.weight.data
        weights.copy_(self.initial_atomref)

    def get_init_args(self):
        """Return the constructor arguments as a dict holding only `max_z`."""
        return {"max_z": self.initial_atomref.size(0)}

    def forward(self, x, z):
        """
        Add the reference value looked up for each entry of `z` to `x`.

        Args:
            x (torch.Tensor): scalar atomwise predictions from the model.
            z (torch.Tensor): atom types of all atoms, used as indices into
                the embedding.

        Returns:
            torch.Tensor: `x` plus the looked-up reference values.
        """
        offsets = self.atomref(z)
        return x + offsets
|
|