from abc import ABCMeta, abstractmethod

import ase
import torch
import torch.nn as nn
from torch_scatter import scatter

from visnet.models.utils import act_class_mapping

__all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent", "VectorOutput"]


class GatedEquivariantBlock(nn.Module):
    """Gated Equivariant Block as defined in Schütt et al. (2021):
    Equivariant message passing for the prediction of tensorial properties
    and molecular spectra.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        intermediate_channels=None,
        activation="silu",
        scalar_activation=False,
    ):
        super(GatedEquivariantBlock, self).__init__()
        self.out_channels = out_channels

        if intermediate_channels is None:
            intermediate_channels = hidden_channels

        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)

        act_class = act_class_mapping[activation]
        self.update_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, intermediate_channels),
            act_class(),
            nn.Linear(intermediate_channels, out_channels * 2),
        )

        self.act = act_class() if scalar_activation else None

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.vec1_proj.weight)
        nn.init.xavier_uniform_(self.vec2_proj.weight)
        nn.init.xavier_uniform_(self.update_net[0].weight)
        self.update_net[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.update_net[2].weight)
        self.update_net[2].bias.data.fill_(0)

    def forward(self, x, v):
        # Invariant features from the vector channel: the norm is taken over
        # the spatial dimension, so the result is rotation invariant.
        vec1 = torch.norm(self.vec1_proj(v), dim=-2)
        vec2 = self.vec2_proj(v)

        x = torch.cat([x, vec1], dim=-1)
        # update_net emits 2 * out_channels features: the first half updates
        # the scalars, the second half gates the projected vectors channel-wise.
        x, v = torch.split(self.update_net(x), self.out_channels, dim=-1)
        v = v.unsqueeze(1) * vec2

        if self.act is not None:
            x = self.act(x)
        return x, v


class OutputModel(nn.Module, metaclass=ABCMeta):
    def __init__(self, allow_prior_model):
        super(OutputModel, self).__init__()
        self.allow_prior_model = allow_prior_model

    def reset_parameters(self):
        pass

    @abstractmethod
    def pre_reduce(self, x, v, z, pos, batch):
        return

    def post_reduce(self, x):
        return x


class Scalar(OutputModel):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super(Scalar, self).__init__(allow_prior_model=allow_prior_model)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            act_class(),
            nn.Linear(hidden_channels // 2, 1),
        )

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v, z, pos, batch):
        # Only the invariant features x are used; v, z, pos and batch are ignored.
        return self.output_network(x)


class EquivariantScalar(OutputModel):
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super(EquivariantScalar, self).__init__(allow_prior_model=allow_prior_model)
        self.output_network = nn.ModuleList([
            GatedEquivariantBlock(
                hidden_channels,
                hidden_channels // 2,
                activation=activation,
                scalar_activation=True,
            ),
            GatedEquivariantBlock(
                hidden_channels // 2,
                1,
                activation=activation,
                scalar_activation=False,
            ),
        ])

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.output_network:
            layer.reset_parameters()

    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)
        # Include v in the output to make sure all parameters have a gradient.
        return x + v.sum() * 0
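# Usage sketch (illustrative comments only; the shapes are assumptions
# inferred from the forward passes above, not part of the original module):
# the backbone hands each head per-atom scalar features x of shape
# (num_atoms, hidden_channels) and vector features v of shape
# (num_atoms, 3, hidden_channels); `pre_reduce` maps them to per-atom
# predictions, which the caller aggregates per molecule before `post_reduce`:
#
#     head = EquivariantScalar(hidden_channels=128)
#     x = torch.randn(10, 128)
#     v = torch.randn(10, 3, 128)
#     out = head.pre_reduce(x, v, z=None, pos=None, batch=None)
#     assert out.shape == (10, 1)  # one scalar per atom, reduced later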
activation="silu", allow_prior_model=False): super(DipoleMoment, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model) atomic_mass = torch.from_numpy(ase.data.atomic_masses).float() self.register_buffer("atomic_mass", atomic_mass) def pre_reduce(self, x, v, z, pos, batch): x = self.output_network(x) # Get center of mass. mass = self.atomic_mass[z].view(-1, 1) c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0) x = x * (pos - c[batch]) return x def post_reduce(self, x): return torch.norm(x, dim=-1, keepdim=True) class EquivariantDipoleMoment(EquivariantScalar): def __init__(self, hidden_channels, activation="silu", allow_prior_model=False): super(EquivariantDipoleMoment, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model) atomic_mass = torch.from_numpy(ase.data.atomic_masses).float() self.register_buffer("atomic_mass", atomic_mass) def pre_reduce(self, x, v, z, pos, batch): if v.shape[1] == 8: l1_v, l2_v = torch.split(v, [3, 5], dim=1) else: l1_v, l2_v = v, torch.zeros(v.shape[0], 5, v.shape[2]) for layer in self.output_network: x, l1_v = layer(x, l1_v) # Get center of mass. mass = self.atomic_mass[z].view(-1, 1) c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0) x = x * (pos - c[batch]) return x + l1_v.squeeze() + l2_v.sum() * 0 def post_reduce(self, x): return torch.norm(x, dim=-1, keepdim=True) class ElectronicSpatialExtent(OutputModel): def __init__(self, hidden_channels, activation="silu", allow_prior_model=False): super(ElectronicSpatialExtent, self).__init__(allow_prior_model=False) act_class = act_class_mapping[activation] self.output_network = nn.Sequential( nn.Linear(hidden_channels, hidden_channels // 2), act_class(), nn.Linear(hidden_channels // 2, 1), ) atomic_mass = torch.from_numpy(ase.data.atomic_masses).float() self.register_buffer("atomic_mass", atomic_mass) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.output_network[0].weight) self.output_network[0].bias.data.fill_(0) nn.init.xavier_uniform_(self.output_network[2].weight) self.output_network[2].bias.data.fill_(0) def pre_reduce(self, x, v, z, pos, batch): x = self.output_network(x) # Get center of mass. mass = self.atomic_mass[z].view(-1, 1) c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0) x = torch.norm(pos - c[batch], dim=1, keepdim=True) ** 2 * x return x class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent): pass class EquivariantVectorOutput(EquivariantScalar): def __init__(self, hidden_channels, activation="silu", allow_prior_model=False): super(EquivariantVectorOutput, self).__init__(hidden_channels, activation, allow_prior_model=allow_prior_model) def pre_reduce(self, x, v, z, pos, batch): for layer in self.output_network: x, v = layer(x, v) # Return shape: (num_atoms, 3) if v.shape[1] == 8: l1_v, l2_v = torch.split(v.squeeze(), [3, 5], dim=1) return l1_v + x.sum() * 0 + l2_v.sum() * 0 else: return v + x.sum() * 0