|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch_cluster import radius_graph |
|
from torch_geometric.nn import MessagePassing |
|
|
|
|
|
class CosineCutoff(nn.Module):
    """Cosine envelope that is zero at and beyond a cutoff distance.

    For a distance ``d`` the module returns
    ``0.5 * (cos(d * pi / cutoff) + 1)`` when ``d < cutoff`` and ``0``
    otherwise.
    """

    def __init__(self, cutoff):
        """Store the cutoff distance used by :meth:`forward`."""
        super().__init__()
        self.cutoff = cutoff

    def forward(self, distances):
        """Return the envelope value for every entry of ``distances``.

        Args:
            distances: Tensor of distances (any shape).

        Returns:
            Tensor of the same shape holding the envelope values.
        """
        # 1.0 for entries strictly inside the cutoff, 0.0 everywhere else.
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
        return envelope * inside
|
|
|
|
|
class ExpNormalSmearing(nn.Module):
    """Radial basis expansion built from Gaussians in ``exp(-alpha * d)`` space.

    Each distance ``d`` is expanded into ``num_rbf`` values
    ``cutoff_fn(d) * exp(-betas * (exp(-alpha * d) - means) ** 2)``, where
    ``alpha = 5.0 / cutoff``. ``means`` and ``betas`` are learnable
    parameters when ``trainable`` is true and fixed buffers otherwise.
    """

    def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
        """Create the basis and register its ``means`` and ``betas`` tensors.

        Args:
            cutoff: Cutoff distance passed to the cosine envelope.
            num_rbf: Number of basis functions.
            trainable: Register the tensors as parameters instead of buffers.
        """
        super().__init__()
        self.cutoff = cutoff
        self.num_rbf = num_rbf
        self.trainable = trainable

        self.cutoff_fn = CosineCutoff(cutoff)
        self.alpha = 5.0 / cutoff

        means, betas = self._initial_params()
        # Registration order (means, then betas) is kept stable on purpose.
        for name, value in (("means", means), ("betas", betas)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return the initial ``(means, betas)`` tensors, each of length ``num_rbf``."""
        # Smallest centre is exp(-cutoff); centres are evenly spaced up to 1.
        start_value = torch.exp(torch.scalar_tensor(-self.cutoff))
        means = torch.linspace(start_value, 1, self.num_rbf)
        # Every basis function starts with the same width coefficient.
        shared_beta = (2 / self.num_rbf * (1 - start_value)) ** -2
        betas = torch.tensor([shared_beta] * self.num_rbf)
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means`` and ``betas`` in place with their initial values."""
        initial_means, initial_betas = self._initial_params()
        self.means.data.copy_(initial_means)
        self.betas.data.copy_(initial_betas)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``num_rbf``.

        Args:
            dist: Tensor of distances.

        Returns:
            Tensor shaped like ``dist`` with one extra trailing dimension.
        """
        dist = dist.unsqueeze(-1)
        decayed = torch.exp(self.alpha * (-dist))
        gaussians = torch.exp(-self.betas * (decayed - self.means) ** 2)
        return self.cutoff_fn(dist) * gaussians
|
|
|
|
|
class GaussianSmearing(nn.Module):
    """Radial basis expansion using Gaussians centred on an even grid.

    Centres (``offset``) are ``num_rbf`` evenly spaced values from 0 to
    ``cutoff``; ``coeff`` is ``-0.5 / spacing ** 2``. Both are learnable
    parameters when ``trainable`` is true and fixed buffers otherwise.
    """

    def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
        """Create the basis and register its ``coeff`` and ``offset`` tensors.

        Args:
            cutoff: Position of the last Gaussian centre.
            num_rbf: Number of basis functions.
            trainable: Register the tensors as parameters instead of buffers.
        """
        super().__init__()
        self.cutoff = cutoff
        self.num_rbf = num_rbf
        self.trainable = trainable

        offset, coeff = self._initial_params()
        # Registration order (coeff, then offset) is kept stable on purpose.
        for name, value in (("coeff", coeff), ("offset", offset)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return the initial ``(offset, coeff)`` pair."""
        centres = torch.linspace(0, self.cutoff, self.num_rbf)
        spacing = centres[1] - centres[0]
        coeff = -0.5 / spacing ** 2
        return centres, coeff

    def reset_parameters(self):
        """Overwrite ``offset`` and ``coeff`` in place with their initial values."""
        initial_offset, initial_coeff = self._initial_params()
        self.offset.data.copy_(initial_offset)
        self.coeff.data.copy_(initial_coeff)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``num_rbf``.

        Args:
            dist: Tensor of distances.

        Returns:
            ``exp(coeff * (dist[..., None] - offset) ** 2)``.
        """
        delta = dist.unsqueeze(-1) - self.offset
        return torch.exp(self.coeff * torch.pow(delta, 2))
|
|
|
|
|
# Lookup from a radial-basis name to the module class implementing it.
rbf_class_mapping = {
    "gauss": GaussianSmearing,
    "expnorm": ExpNormalSmearing,
}
|
|
|
|
|
class ShiftedSoftplus(nn.Module):
    """Softplus activation shifted down by ``log(2)``: ``softplus(x) - log(2)``."""

    def __init__(self):
        """Precompute the constant shift as a Python float."""
        super().__init__()
        # Computed through torch so the constant is log(2) as evaluated by torch.
        self.shift = torch.tensor(2.0).log().item()

    def forward(self, x):
        """Apply the shifted softplus element-wise to ``x``."""
        activated = F.softplus(x)
        return activated - self.shift
|
|
|
|
|
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` applied element-wise."""

    def __init__(self):
        """Initialise the stateless activation module."""
        super().__init__()

    def forward(self, x):
        """Return ``x`` gated by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
|
|
|
|
|
# Lookup from an activation name to the module class implementing it.
act_class_mapping = {
    "ssp": ShiftedSoftplus,
    "silu": nn.SiLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "swish": Swish,
}
|
|
|
|
|
class Sphere(nn.Module):
    """Map 3-D vectors to polynomial (spherical-harmonic style) features.

    The last axis of the input is read as Cartesian ``(x, y, z)``. For
    ``l=1`` the output has 3 components ``(x, y, z)``; for ``l=2`` it has 8
    components: those three followed by five degree-2 terms.
    """

    def __init__(self, l=2):
        """Store the maximum degree ``l`` (1 or 2) used by :meth:`forward`."""
        super(Sphere, self).__init__()
        self.l = l

    def forward(self, edge_vec):
        """Compute the features for ``edge_vec``.

        Args:
            edge_vec: Tensor whose last dimension holds ``x, y, z`` components.
                NOTE(review): callers presumably pass normalised edge
                vectors — confirm against the call site.

        Returns:
            Tensor with the last dimension replaced by 3 (``l=1``) or
            8 (``l=2``) feature components.

        Raises:
            ValueError: If ``self.l`` is neither 1 nor 2.
        """
        return self._spherical_harmonics(
            self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2]
        )

    @staticmethod
    def _spherical_harmonics(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Stack the degree-1 (and optionally degree-2) terms along a new last axis.

        Raises:
            ValueError: If ``lmax`` is not 1 or 2. Previously such values fell
                through every branch and silently returned ``None``.
        """
        # Fail loudly instead of returning None for unsupported degrees.
        if lmax not in (1, 2):
            raise ValueError(f"Sphere only supports l=1 or l=2, got l={lmax}")

        sh_1_0, sh_1_1, sh_1_2 = x, y, z

        if lmax == 1:
            return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1)

        sqrt3 = math.sqrt(3.0)
        x2 = x.pow(2)
        y2 = y.pow(2)
        z2 = z.pow(2)

        sh_2_0 = sqrt3 * x * z
        sh_2_1 = sqrt3 * x * y
        sh_2_2 = y2 - 0.5 * (x2 + z2)
        sh_2_3 = sqrt3 * y * z
        sh_2_4 = sqrt3 / 2.0 * (z2 - x2)

        return torch.stack(
            [sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4],
            dim=-1,
        )
|
|
|
|
|
class VecLayerNorm(nn.Module):
    """Normalise vector features and scale them by a per-channel weight.

    Inputs are indexed as ``vec[:, component, channel]``: axis 1 holds the
    vector components (3, or 8 split into 3 + 5) and the last axis is
    multiplied by a weight of length ``hidden_channels``.
    """

    def __init__(self, hidden_channels, trainable, norm_type="max_min"):
        """Set up the weight and choose the normalisation routine.

        Args:
            hidden_channels: Length of the per-channel weight vector.
            trainable: Register the weight as a parameter instead of a buffer.
            norm_type: ``"rms"`` or ``"max_min"``; any other value disables
                normalisation (the input is only scaled by the weight).
        """
        super().__init__()

        self.hidden_channels = hidden_channels
        self.eps = 1e-12

        initial_weight = torch.ones(self.hidden_channels)
        if trainable:
            self.register_parameter("weight", nn.Parameter(initial_weight))
        else:
            self.register_buffer("weight", initial_weight)

        if norm_type == "max_min":
            chosen_norm = self.max_min_norm
        elif norm_type == "rms":
            chosen_norm = self.rms_norm
        else:
            chosen_norm = self.none_norm
        self.norm = chosen_norm

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the weight in place to all ones."""
        self.weight.data.copy_(torch.ones(self.hidden_channels))

    def none_norm(self, vec):
        """Identity normalisation: return ``vec`` unchanged."""
        return vec

    def rms_norm(self, vec):
        """Divide each sample by the root-mean-square of its per-channel norms."""
        channel_norm = torch.norm(vec, dim=1)

        # An all-zero input has nothing to normalise.
        if (channel_norm == 0).all():
            return torch.zeros_like(vec)

        channel_norm = channel_norm.clamp(min=self.eps)
        rms = torch.sqrt(torch.mean(channel_norm ** 2, dim=-1))
        return vec / F.relu(rms)[..., None, None]

    def max_min_norm(self, vec):
        """Rescale per-channel norms to ``[0, 1]`` per sample, keeping directions."""
        magnitude = torch.norm(vec, dim=1, keepdim=True)

        # An all-zero input has nothing to normalise.
        if (magnitude == 0).all():
            return torch.zeros_like(vec)

        magnitude = magnitude.clamp(min=self.eps)
        direction = vec / magnitude

        largest = torch.max(magnitude, dim=-1).values
        smallest = torch.min(magnitude, dim=-1).values
        span = (largest - smallest).view(-1)
        # Guard against division by zero when all channels share one norm.
        span = torch.where(span == 0, torch.ones_like(span), span)
        rescaled = (magnitude - smallest.view(-1, 1, 1)) / span.view(-1, 1, 1)

        return F.relu(rescaled) * direction

    def forward(self, vec):
        """Normalise ``vec`` and apply the per-channel weight.

        Args:
            vec: Tensor whose axis 1 has size 3 or 8. With size 8 the first
                3 and the last 5 components are normalised independently.

        Returns:
            Tensor of the same shape as ``vec``.

        Raises:
            ValueError: If axis 1 of ``vec`` is neither 3 nor 8.
        """
        n_components = vec.shape[1]
        if n_components == 3:
            normed = self.norm(vec)
        elif n_components == 8:
            head, tail = torch.split(vec, [3, 5], dim=1)
            head = self.norm(head)
            tail = self.norm(tail)
            normed = torch.cat([head, tail], dim=1)
        else:
            raise ValueError("VecLayerNorm only support 3 or 8 channels")

        return normed * self.weight.unsqueeze(0).unsqueeze(0)
|
|
|
|
|
class Distance(nn.Module):
    """Build a radius graph over positions and return per-edge geometry."""

    def __init__(self, cutoff, max_num_neighbors=32, loop=True):
        """Store the graph-construction settings.

        Args:
            cutoff: Radius forwarded to ``radius_graph`` as ``r``.
            max_num_neighbors: Forwarded to ``radius_graph``.
            loop: Forwarded to ``radius_graph``; when true, self-loop edges
                are expected in the result and receive a distance of zero.
        """
        super(Distance, self).__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.loop = loop

    def forward(self, pos, batch):
        """Compute edges, edge lengths and edge vectors.

        Args:
            pos: Tensor of positions, indexed by node along axis 0.
            batch: Batch assignment forwarded to ``radius_graph``.

        Returns:
            Tuple ``(edge_index, edge_weight, edge_vec)`` where ``edge_vec``
            is ``pos[edge_index[0]] - pos[edge_index[1]]`` and ``edge_weight``
            is its norm along the last axis (zero for self-loops).
        """
        edge_index = radius_graph(
            pos,
            r=self.cutoff,
            batch=batch,
            loop=self.loop,
            max_num_neighbors=self.max_num_neighbors,
        )
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

        if self.loop:
            # Only take the norm on real edges; self-loop entries stay at zero
            # (presumably to avoid differentiating the norm at a zero vector).
            mask = edge_index[0] != edge_index[1]
            # new_zeros inherits dtype and device from edge_vec, so the masked
            # assignment below also works when pos is not default-dtype.
            edge_weight = edge_vec.new_zeros(edge_vec.size(0))
            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)

        return edge_index, edge_weight, edge_vec
|
|
|
|
|
class NeighborEmbedding(MessagePassing):
    """Combine node features with a distance-weighted sum of neighbour embeddings.

    Neighbour embeddings are looked up from ``z``, multiplied by a filter
    derived from ``edge_attr`` and a cosine cutoff of ``edge_weight``,
    aggregated with ``aggr="add"``, then concatenated with ``x`` and passed
    through a linear layer.
    """

    def __init__(self, hidden_channels, num_rbf, cutoff, max_z=100):
        """Create the embedding table, the projections and the cutoff envelope.

        Args:
            hidden_channels: Feature width of the embeddings and the output.
            num_rbf: Width of the edge attributes fed to ``distance_proj``.
            cutoff: Cutoff distance for the cosine envelope.
            max_z: Number of rows in the embedding table.
        """
        super().__init__(aggr="add")
        self.embedding = nn.Embedding(max_z, hidden_channels)
        self.distance_proj = nn.Linear(num_rbf, hidden_channels)
        self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
        self.cutoff = CosineCutoff(cutoff)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding and both linear layers.

        Linear weights use Xavier-uniform initialisation; biases are zeroed.
        """
        self.embedding.reset_parameters()
        for layer in (self.distance_proj, self.combine):
            nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, z, x, edge_index, edge_weight, edge_attr):
        """Return updated node features.

        Args:
            z: Indices into the embedding table, one per node.
            x: Current node features, concatenated with the aggregated
                neighbour features along dim 1.
            edge_index: Edge indices; self-loops are removed before use.
            edge_weight: Per-edge distances fed to the cutoff envelope.
            edge_attr: Per-edge attributes fed to ``distance_proj``.

        Returns:
            Output of the ``combine`` linear layer.
        """
        # Drop self-loops so that a node never aggregates its own embedding.
        not_self_loop = edge_index[0] != edge_index[1]
        if not not_self_loop.all():
            edge_index = edge_index[:, not_self_loop]
            edge_weight = edge_weight[not_self_loop]
            edge_attr = edge_attr[not_self_loop]

        envelope = self.cutoff(edge_weight)
        W = self.distance_proj(edge_attr) * envelope.view(-1, 1)

        neighbor_embed = self.embedding(z)
        aggregated = self.propagate(edge_index, x=neighbor_embed, W=W, size=None)

        return self.combine(torch.cat([x, aggregated], dim=1))

    def message(self, x_j, W):
        """Weight each source-node embedding by its edge filter."""
        return x_j * W
|
|
|
|
|
class EdgeEmbedding(MessagePassing):
    """Produce per-edge features from endpoint features and edge attributes.

    Each edge receives ``(x_i + x_j) * edge_proj(edge_attr)``. The
    ``aggregate`` override returns the messages unchanged, so ``propagate``
    yields per-edge outputs rather than reducing them.
    """

    def __init__(self, num_rbf, hidden_channels):
        """Create the linear projection applied to the edge attributes.

        Args:
            num_rbf: Width of the incoming edge attributes.
            hidden_channels: Width of the projected edge features.
        """
        super().__init__(aggr=None)
        self.edge_proj = nn.Linear(num_rbf, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise the projection weight and zero its bias."""
        nn.init.xavier_uniform_(self.edge_proj.weight)
        self.edge_proj.bias.data.fill_(0)

    def forward(self, edge_index, edge_attr, x):
        """Return the edge features computed by ``propagate``.

        Args:
            edge_index: Edge indices passed to ``propagate``.
            edge_attr: Per-edge attributes fed to ``edge_proj``.
            x: Node features gathered for both endpoints of every edge.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Scale the summed endpoint features by the projected edge attributes."""
        endpoint_sum = x_i + x_j
        return endpoint_sum * self.edge_proj(edge_attr)

    def aggregate(self, features, index):
        """Skip aggregation and hand the per-edge messages back as they are."""
        return features