from abc import ABCMeta, abstractmethod

import ase
import torch
import torch.nn as nn
from torch_scatter import scatter

from visnet.models.utils import act_class_mapping

__all__ = [
    "Scalar",
    "EquivariantScalar",
    "DipoleMoment",
    "EquivariantDipoleMoment",
    "ElectronicSpatialExtent",
    "EquivariantElectronicSpatialExtent",
    "EquivariantVectorOutput",
]


class GatedEquivariantBlock(nn.Module):
    """Gated Equivariant Block as defined in Schütt et al. (2021):
    Equivariant message passing for the prediction of tensorial properties
    and molecular spectra.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        intermediate_channels=None,
        activation="silu",
        scalar_activation=False,
    ):
        super(GatedEquivariantBlock, self).__init__()
        self.out_channels = out_channels

        if intermediate_channels is None:
            intermediate_channels = hidden_channels

        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)

        act_class = act_class_mapping[activation]
        self.update_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, intermediate_channels),
            act_class(),
            nn.Linear(intermediate_channels, out_channels * 2),
        )

        self.act = act_class() if scalar_activation else None

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.vec1_proj.weight)
        nn.init.xavier_uniform_(self.vec2_proj.weight)
        nn.init.xavier_uniform_(self.update_net[0].weight)
        self.update_net[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.update_net[2].weight)
        self.update_net[2].bias.data.fill_(0)

    def forward(self, x, v):
        # Rotationally invariant norm of the projected vector features.
        vec1 = torch.norm(self.vec1_proj(v), dim=-2)
        vec2 = self.vec2_proj(v)

        x = torch.cat([x, vec1], dim=-1)
        x, v = torch.split(self.update_net(x), self.out_channels, dim=-1)
        # Gate the projected vector features with the predicted scalars.
        v = v.unsqueeze(1) * vec2

        if self.act is not None:
            x = self.act(x)
        return x, v
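
# Shape sketch (illustrative): with per-atom inputs x of shape
# (num_atoms, hidden_channels) and v of shape (num_atoms, vec_dim, hidden_channels)
# (vec_dim is 3 for l=1 features, or 8 when l=1 and l=2 features are stacked),
# the block returns tensors of shape (num_atoms, out_channels) and
# (num_atoms, vec_dim, out_channels):
#
#     block = GatedEquivariantBlock(hidden_channels=128, out_channels=64)
#     x_out, v_out = block(torch.randn(5, 128), torch.randn(5, 3, 128))
#     # x_out: (5, 64), v_out: (5, 3, 64)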


class OutputModel(nn.Module, metaclass=ABCMeta):
    """Base class for output heads: `pre_reduce` maps per-atom features to
    per-atom contributions, which are aggregated over each molecule in the
    batch before `post_reduce` is applied to the reduced result.
    """

    def __init__(self, allow_prior_model):
        super(OutputModel, self).__init__()
        self.allow_prior_model = allow_prior_model

    def reset_parameters(self):
        pass

    @abstractmethod
    def pre_reduce(self, x, v, z, pos, batch):
        return

    def post_reduce(self, x):
        return x


class Scalar(OutputModel):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super(Scalar, self).__init__(allow_prior_model=allow_prior_model)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            act_class(),
            nn.Linear(hidden_channels // 2, 1),
        )

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v, z, pos, batch):
        return self.output_network(x)
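
# A minimal sketch of how an output head is typically driven (hypothetical
# caller for illustration; the actual reduction happens in the model that owns
# the head, not in this module):
#
#     head = Scalar(hidden_channels=128)
#     per_atom = head.pre_reduce(x, v, z, pos, batch)          # (num_atoms, 1)
#     per_mol = scatter(per_atom, batch, dim=0, reduce="sum")  # (num_molecules, 1)
#     out = head.post_reduce(per_mol)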


class EquivariantScalar(OutputModel):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super(EquivariantScalar, self).__init__(allow_prior_model=allow_prior_model)
        self.output_network = nn.ModuleList([
            GatedEquivariantBlock(
                hidden_channels,
                hidden_channels // 2,
                activation=activation,
                scalar_activation=True,
            ),
            GatedEquivariantBlock(
                hidden_channels // 2,
                1,
                activation=activation,
                scalar_activation=False,
            ),
        ])

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.output_network:
            layer.reset_parameters()

    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)

        # Include v in the output (weighted by zero) so that every parameter
        # stays connected to the autograd graph.
        return x + v.sum() * 0


class DipoleMoment(Scalar):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super(DipoleMoment, self).__init__(
            hidden_channels, activation, allow_prior_model=allow_prior_model
        )
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

    def pre_reduce(self, x, v, z, pos, batch):
        # Per-atom scalars act as latent charges.
        x = self.output_network(x)

        # Shift positions to the center of mass of each molecule.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        return x

    def post_reduce(self, x):
        return torch.norm(x, dim=-1, keepdim=True)
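
# In effect, the predicted dipole magnitude is (a sketch of the formula the
# code above implements, with q_i the per-atom scalar and r_c the per-molecule
# center of mass):
#
#     mu = || sum_i q_i * (r_i - r_c) ||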


class EquivariantDipoleMoment(EquivariantScalar):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super(EquivariantDipoleMoment, self).__init__(
            hidden_channels, activation, allow_prior_model=allow_prior_model
        )
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

    def pre_reduce(self, x, v, z, pos, batch):
        # Split the vector features into l=1 (3 components) and l=2
        # (5 components) parts when both are present.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v, [3, 5], dim=1)
        else:
            l1_v = v
            l2_v = torch.zeros(v.shape[0], 5, v.shape[2], device=v.device, dtype=v.dtype)

        for layer in self.output_network:
            x, l1_v = layer(x, l1_v)

        # Shift positions to the center of mass of each molecule.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        # The zero-weighted l=2 term only keeps those features in the autograd graph.
        return x + l1_v.squeeze() + l2_v.sum() * 0

    def post_reduce(self, x):
        return torch.norm(x, dim=-1, keepdim=True)


class ElectronicSpatialExtent(OutputModel):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super(ElectronicSpatialExtent, self).__init__(allow_prior_model=False)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            act_class(),
            nn.Linear(hidden_channels // 2, 1),
        )
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v, z, pos, batch):
        # Per-atom scalars act as latent charges.
        x = self.output_network(x)

        # Shift positions to the center of mass of each molecule.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)

        x = torch.norm(pos - c[batch], dim=1, keepdim=True) ** 2 * x
        return x
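
# Each atom contributes q_i * |r_i - r_c|^2 (a sketch of the formula above,
# with q_i the per-atom scalar and r_c the center of mass); summing these
# contributions over a molecule gives the electronic spatial extent.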


class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):
    pass


class EquivariantVectorOutput(EquivariantScalar):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super(EquivariantVectorOutput, self).__init__(
            hidden_channels, activation, allow_prior_model=allow_prior_model
        )

    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)

        # Return the per-atom l=1 vector; the zero-weighted terms only keep the
        # scalar and l=2 outputs attached to the autograd graph.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v.squeeze(), [3, 5], dim=1)
            return l1_v + x.sum() * 0 + l2_v.sum() * 0
        else:
            return v.squeeze() + x.sum() * 0
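
# Shape sketch (illustrative): `pre_reduce` yields one 3-vector per atom, e.g.
#
#     head = EquivariantVectorOutput(hidden_channels=128)
#     vec = head.pre_reduce(x, v, z, pos, batch)   # (num_atoms, 3)
#
# regardless of whether v carries only l=1 channels (vec_dim == 3) or stacked
# l=1 and l=2 channels (vec_dim == 8).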