from abc import ABCMeta, abstractmethod
import ase
import torch
import torch.nn as nn
from torch_scatter import scatter
from visnet.models.utils import act_class_mapping
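# NOTE: "VectorOutput" is not a class defined in this module. These names
# appear to be the output-model options consumed by the model factory, which
# (as in TorchMD-Net) may resolve them with an "Equivariant" prefix, e.g.
# "VectorOutput" -> EquivariantVectorOutput.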
__all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent", "VectorOutput"]
class GatedEquivariantBlock(nn.Module):
"""
Gated Equivariant Block as defined in Schütt et al. (2021):
Equivariant message passing for the prediction of tensorial properties and molecular spectra
"""
def __init__(
self,
hidden_channels,
out_channels,
intermediate_channels=None,
activation="silu",
scalar_activation=False,
):
super(GatedEquivariantBlock, self).__init__()
self.out_channels = out_channels
if intermediate_channels is None:
intermediate_channels = hidden_channels
self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)
act_class = act_class_mapping[activation]
self.update_net = nn.Sequential(
nn.Linear(hidden_channels * 2, intermediate_channels),
act_class(),
nn.Linear(intermediate_channels, out_channels * 2),
)
self.act = act_class() if scalar_activation else None
def reset_parameters(self):
nn.init.xavier_uniform_(self.vec1_proj.weight)
nn.init.xavier_uniform_(self.vec2_proj.weight)
nn.init.xavier_uniform_(self.update_net[0].weight)
self.update_net[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.update_net[2].weight)
self.update_net[2].bias.data.fill_(0)
    def forward(self, x, v):
        # Rotation-invariant norms of the projected vector features, (N, hidden).
        vec1 = torch.norm(self.vec1_proj(v), dim=-2)
        # Vector features projected to the output width along the channel dim.
        vec2 = self.vec2_proj(v)
        x = torch.cat([x, vec1], dim=-1)
        # update_net emits out_channels scalars plus out_channels vector gates.
        x, v = torch.split(self.update_net(x), self.out_channels, dim=-1)
        # Gate the projected vectors with the scalars (broadcast over directions).
        v = v.unsqueeze(1) * vec2
        if self.act is not None:
            x = self.act(x)
        return x, v
class OutputModel(nn.Module, metaclass=ABCMeta):
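    """Base class for output heads.

    `pre_reduce` maps per-atom (scalar, vector) features to per-atom
    predictions; `post_reduce` is applied after the per-molecule reduction.
    `allow_prior_model` indicates whether an atomic prior (e.g. Atomref) may
    be attached to this head's output.
    """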
def __init__(self, allow_prior_model):
super(OutputModel, self).__init__()
self.allow_prior_model = allow_prior_model
def reset_parameters(self):
pass
@abstractmethod
def pre_reduce(self, x, v, z, pos, batch):
return
def post_reduce(self, x):
return x
class Scalar(OutputModel):
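    """Invariant output head: a two-layer MLP on the scalar features."""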
def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
super(Scalar, self).__init__(allow_prior_model=allow_prior_model)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2),
act_class(),
nn.Linear(hidden_channels // 2, 1),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
    def pre_reduce(self, x, v, z, pos, batch):
        # v is unused here; Scalar operates on invariant features only.
        return self.output_network(x)
class EquivariantScalar(OutputModel):
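    """Invariant output head that mixes scalar and vector features through
    two GatedEquivariantBlocks before reducing to a single scalar."""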
def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
super(EquivariantScalar, self).__init__(allow_prior_model=allow_prior_model)
self.output_network = nn.ModuleList([
GatedEquivariantBlock(
hidden_channels,
hidden_channels // 2,
activation=activation,
scalar_activation=True,
),
GatedEquivariantBlock(
hidden_channels // 2,
1,
activation=activation,
scalar_activation=False,
),
])
self.reset_parameters()
def reset_parameters(self):
for layer in self.output_network:
layer.reset_parameters()
def pre_reduce(self, x, v, z, pos, batch):
for layer in self.output_network:
x, v = layer(x, v)
# include v in output to make sure all parameters have a gradient
return x + v.sum() * 0
class DipoleMoment(Scalar):
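    """Dipole-moment head: per-atom scalars act as charge-like weights on
    positions relative to the molecular center of mass."""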
def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
super(DipoleMoment, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)
atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
self.register_buffer("atomic_mass", atomic_mass)
    def pre_reduce(self, x, v, z, pos, batch):
        # Per-atom scalar weights (charge-like coefficients).
        x = self.output_network(x)
        # Center positions on each molecule's center of mass so the predicted
        # dipole is translation-invariant.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        return x
def post_reduce(self, x):
return torch.norm(x, dim=-1, keepdim=True)
class EquivariantDipoleMoment(EquivariantScalar):
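    """Dipole-moment head that additionally adds the gated l=1 vector output
    to the charge-weighted relative positions."""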
def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
super(EquivariantDipoleMoment, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)
atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
self.register_buffer("atomic_mass", atomic_mass)
    def pre_reduce(self, x, v, z, pos, batch):
        # With lmax=2 the vector features carry 3 (l=1) + 5 (l=2) components;
        # only the l=1 part enters the dipole prediction.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v, [3, 5], dim=1)
        else:
            # new_zeros keeps device and dtype consistent with v.
            l1_v, l2_v = v, v.new_zeros(v.shape[0], 5, v.shape[2])
        for layer in self.output_network:
            x, l1_v = layer(x, l1_v)
        # Center positions on each molecule's center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        # l2_v.sum() * 0 keeps the unused l=2 features in the autograd graph
        # so all parameters receive a (zero) gradient.
        return x + l1_v.squeeze(-1) + l2_v.sum() * 0
def post_reduce(self, x):
return torch.norm(x, dim=-1, keepdim=True)
class ElectronicSpatialExtent(OutputModel):
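    """Electronic-spatial-extent head: per-atom scalars weight the squared
    distances from the molecular center of mass."""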
def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
super(ElectronicSpatialExtent, self).__init__(allow_prior_model=False)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2),
act_class(),
nn.Linear(hidden_channels // 2, 1),
)
atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
self.register_buffer("atomic_mass", atomic_mass)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
    def pre_reduce(self, x, v, z, pos, batch):
        x = self.output_network(x)
        # Center positions on each molecule's center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        # Weight squared distances from the center of mass by the per-atom scalars.
        x = torch.norm(pos - c[batch], dim=1, keepdim=True) ** 2 * x
        return x
class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):
    """The equivariant variant is identical: vector features are not used
    for this target."""
class EquivariantVectorOutput(EquivariantScalar):
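    """Vector output head: returns the gated l=1 vector features as per-atom
    3-vectors of shape (num_atoms, 3)."""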
def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
super(EquivariantVectorOutput, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)
    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)
        # Return shape: (num_atoms, 3). With lmax=2, keep only the l=1 part;
        # the x.sum() * 0 and l2_v.sum() * 0 terms keep all parameters in the
        # autograd graph.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v.squeeze(-1), [3, 5], dim=1)
            return l1_v + x.sum() * 0 + l2_v.sum() * 0
        else:
            return v.squeeze(-1) + x.sum() * 0
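
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). Shapes
    # assume 5 atoms across 2 molecules, hidden_channels=16, and lmax=2 vector
    # features (3 l=1 + 5 l=2 components, hence the size-8 middle dimension).
    num_atoms, hidden = 5, 16
    x = torch.randn(num_atoms, hidden)
    v = torch.randn(num_atoms, 8, hidden)
    z = torch.randint(1, 10, (num_atoms,))
    pos = torch.randn(num_atoms, 3)
    batch = torch.tensor([0, 0, 0, 1, 1])

    scalar_head = EquivariantScalar(hidden)
    print(scalar_head.pre_reduce(x, v, z, pos, batch).shape)  # expected: (5, 1)

    vector_head = EquivariantVectorOutput(hidden)
    print(vector_head.pre_reduce(x, v, z, pos, batch).shape)  # expected: (5, 3)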