yxc97's picture
Upload folder using huggingface_hub
62a2f1c verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius_graph
from torch_geometric.nn import MessagePassing
class CosineCutoff(nn.Module):
def __init__(self, cutoff):
super(CosineCutoff, self).__init__()
self.cutoff = cutoff
def forward(self, distances):
cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
cutoffs = cutoffs * (distances < self.cutoff).float()
return cutoffs
class ExpNormalSmearing(nn.Module):
def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
super(ExpNormalSmearing, self).__init__()
self.cutoff = cutoff
self.num_rbf = num_rbf
self.trainable = trainable
self.cutoff_fn = CosineCutoff(cutoff)
self.alpha = 5.0 / cutoff
means, betas = self._initial_params()
if trainable:
self.register_parameter("means", nn.Parameter(means))
self.register_parameter("betas", nn.Parameter(betas))
else:
self.register_buffer("means", means)
self.register_buffer("betas", betas)
def _initial_params(self):
start_value = torch.exp(torch.scalar_tensor(-self.cutoff))
means = torch.linspace(start_value, 1, self.num_rbf)
betas = torch.tensor([(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf)
return means, betas
def reset_parameters(self):
means, betas = self._initial_params()
self.means.data.copy_(means)
self.betas.data.copy_(betas)
def forward(self, dist):
dist = dist.unsqueeze(-1)
return self.cutoff_fn(dist) * torch.exp(-self.betas * (torch.exp(self.alpha * (-dist)) - self.means) ** 2)
class GaussianSmearing(nn.Module):
def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
super(GaussianSmearing, self).__init__()
self.cutoff = cutoff
self.num_rbf = num_rbf
self.trainable = trainable
offset, coeff = self._initial_params()
if trainable:
self.register_parameter("coeff", nn.Parameter(coeff))
self.register_parameter("offset", nn.Parameter(offset))
else:
self.register_buffer("coeff", coeff)
self.register_buffer("offset", offset)
def _initial_params(self):
offset = torch.linspace(0, self.cutoff, self.num_rbf)
coeff = -0.5 / (offset[1] - offset[0]) ** 2
return offset, coeff
def reset_parameters(self):
offset, coeff = self._initial_params()
self.offset.data.copy_(offset)
self.coeff.data.copy_(coeff)
def forward(self, dist):
dist = dist.unsqueeze(-1) - self.offset
return torch.exp(self.coeff * torch.pow(dist, 2))
rbf_class_mapping = {"gauss": GaussianSmearing, "expnorm": ExpNormalSmearing}
class ShiftedSoftplus(nn.Module):
def __init__(self):
super(ShiftedSoftplus, self).__init__()
self.shift = torch.log(torch.tensor(2.0)).item()
def forward(self, x):
return F.softplus(x) - self.shift
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
act_class_mapping = {"ssp": ShiftedSoftplus, "silu": nn.SiLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "swish": Swish}
class Sphere(nn.Module):
def __init__(self, l=2):
super(Sphere, self).__init__()
self.l = l
def forward(self, edge_vec):
edge_sh = self._spherical_harmonics(self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2])
return edge_sh
@staticmethod
def _spherical_harmonics(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
sh_1_0, sh_1_1, sh_1_2 = x, y, z
if lmax == 1:
return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1)
sh_2_0 = math.sqrt(3.0) * x * z
sh_2_1 = math.sqrt(3.0) * x * y
y2 = y.pow(2)
x2z2 = x.pow(2) + z.pow(2)
sh_2_2 = y2 - 0.5 * x2z2
sh_2_3 = math.sqrt(3.0) * y * z
sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2))
if lmax == 2:
return torch.stack([sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4], dim=-1)
class VecLayerNorm(nn.Module):
def __init__(self, hidden_channels, trainable, norm_type="max_min"):
super(VecLayerNorm, self).__init__()
self.hidden_channels = hidden_channels
self.eps = 1e-12
weight = torch.ones(self.hidden_channels)
if trainable:
self.register_parameter("weight", nn.Parameter(weight))
else:
self.register_buffer("weight", weight)
if norm_type == "rms":
self.norm = self.rms_norm
elif norm_type == "max_min":
self.norm = self.max_min_norm
else:
self.norm = self.none_norm
self.reset_parameters()
def reset_parameters(self):
weight = torch.ones(self.hidden_channels)
self.weight.data.copy_(weight)
def none_norm(self, vec):
return vec
def rms_norm(self, vec):
# vec: (num_atoms, 3 or 5, hidden_channels)
dist = torch.norm(vec, dim=1)
if (dist == 0).all():
return torch.zeros_like(vec)
dist = dist.clamp(min=self.eps)
dist = torch.sqrt(torch.mean(dist ** 2, dim=-1))
return vec / F.relu(dist).unsqueeze(-1).unsqueeze(-1)
def max_min_norm(self, vec):
# vec: (num_atoms, 3 or 5, hidden_channels)
dist = torch.norm(vec, dim=1, keepdim=True)
if (dist == 0).all():
return torch.zeros_like(vec)
dist = dist.clamp(min=self.eps)
direct = vec / dist
max_val, _ = torch.max(dist, dim=-1)
min_val, _ = torch.min(dist, dim=-1)
delta = (max_val - min_val).view(-1)
delta = torch.where(delta == 0, torch.ones_like(delta), delta)
dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)
return F.relu(dist) * direct
def forward(self, vec):
# vec: (num_atoms, 3 or 8, hidden_channels)
if vec.shape[1] == 3:
vec = self.norm(vec)
return vec * self.weight.unsqueeze(0).unsqueeze(0)
elif vec.shape[1] == 8:
vec1, vec2 = torch.split(vec, [3, 5], dim=1)
vec1 = self.norm(vec1)
vec2 = self.norm(vec2)
vec = torch.cat([vec1, vec2], dim=1)
return vec * self.weight.unsqueeze(0).unsqueeze(0)
else:
raise ValueError("VecLayerNorm only support 3 or 8 channels")
class Distance(nn.Module):
def __init__(self, cutoff, max_num_neighbors=32, loop=True):
super(Distance, self).__init__()
self.cutoff = cutoff
self.max_num_neighbors = max_num_neighbors
self.loop = loop
def forward(self, pos, batch):
edge_index = radius_graph(pos, r=self.cutoff, batch=batch, loop=self.loop, max_num_neighbors=self.max_num_neighbors)
edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
if self.loop:
mask = edge_index[0] != edge_index[1]
edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device)
edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
else:
edge_weight = torch.norm(edge_vec, dim=-1)
return edge_index, edge_weight, edge_vec
class NeighborEmbedding(MessagePassing):
def __init__(self, hidden_channels, num_rbf, cutoff, max_z=100):
super(NeighborEmbedding, self).__init__(aggr="add")
self.embedding = nn.Embedding(max_z, hidden_channels)
self.distance_proj = nn.Linear(num_rbf, hidden_channels)
self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
self.cutoff = CosineCutoff(cutoff)
self.reset_parameters()
def reset_parameters(self):
self.embedding.reset_parameters()
nn.init.xavier_uniform_(self.distance_proj.weight)
nn.init.xavier_uniform_(self.combine.weight)
self.distance_proj.bias.data.fill_(0)
self.combine.bias.data.fill_(0)
def forward(self, z, x, edge_index, edge_weight, edge_attr):
# remove self loops
mask = edge_index[0] != edge_index[1]
if not mask.all():
edge_index = edge_index[:, mask]
edge_weight = edge_weight[mask]
edge_attr = edge_attr[mask]
C = self.cutoff(edge_weight)
W = self.distance_proj(edge_attr) * C.view(-1, 1)
x_neighbors = self.embedding(z)
# propagate_type: (x: Tensor, W: Tensor)
x_neighbors = self.propagate(edge_index, x=x_neighbors, W=W, size=None)
x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))
return x_neighbors
def message(self, x_j, W):
return x_j * W
class EdgeEmbedding(MessagePassing):
def __init__(self, num_rbf, hidden_channels):
super(EdgeEmbedding, self).__init__(aggr=None)
self.edge_proj = nn.Linear(num_rbf, hidden_channels)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.edge_proj.weight)
self.edge_proj.bias.data.fill_(0)
def forward(self, edge_index, edge_attr, x):
# propagate_type: (x: Tensor, edge_attr: Tensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
return out
def message(self, x_i, x_j, edge_attr):
return (x_i + x_j) * self.edge_proj(edge_attr)
def aggregate(self, features, index):
# no aggregate
return features