|
import argparse |
|
import logging |
|
import os |
|
import sys |
|
import json |
|
import re |
|
import numpy as np |
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
from torch.autograd import grad |
|
from torch_geometric.data import Data |
|
from torch_geometric.nn import MessagePassing |
|
from torch_scatter import scatter |
|
from torch.nn.functional import l1_loss, mse_loss |
|
from torch.optim import AdamW |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
|
from pytorch_lightning.callbacks import EarlyStopping |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger |
|
from pytorch_lightning.strategies import DDPStrategy |
|
from pytorch_lightning.utilities import rank_zero_warn |
|
from pytorch_lightning import LightningModule |
|
|
|
from visnet import datasets, models, priors |
|
from visnet.data import DataModule |
|
from visnet.models import output_modules |
|
from visnet.utils import LoadFromCheckpoint, LoadFromFile, number, save_argparse |
|
|
|
from typing import Optional, Tuple , List |
|
from metrics import calculate_mae |
|
from visnet.models.utils import ( |
|
CosineCutoff, |
|
Distance, |
|
EdgeEmbedding, |
|
NeighborEmbedding, |
|
Sphere, |
|
VecLayerNorm, |
|
act_class_mapping, |
|
rbf_class_mapping, |
|
ExpNormalSmearing, |
|
GaussianSmearing |
|
) |
|
|
|
""" |
|
Models |
|
""" |
|
class ViSNetBlock(nn.Module):
    """ViSNet representation network.

    Embeds atomic numbers, builds the neighbour graph with ``Distance``, expands
    edge lengths with a radial basis and edge directions with ``Sphere``, then
    applies ``num_layers`` ViS message-passing layers. ``forward`` returns the
    normalised scalar node features ``x`` and vector node features ``vec``.

    Args:
        lmax: Maximum spherical-harmonic order; ``vec`` has
            ``(lmax + 1) ** 2 - 1`` components per channel.
        vecnorm_type: Normalisation type forwarded to ``VecLayerNorm``.
        trainable_vecnorm: Forwarded to ``VecLayerNorm(trainable=...)``.
        num_heads: Attention heads per message-passing layer.
        num_layers: Number of message-passing layers; the last one does not
            return an edge-feature update.
        hidden_channels: Width of node and edge features.
        num_rbf: Number of radial basis functions.
        rbf_type: Key into ``rbf_class_mapping``.
        trainable_rbf: Whether the radial basis is trainable.
        activation: Key into ``act_class_mapping`` for hidden activations.
        attn_activation: Key into ``act_class_mapping`` for attention scores.
        max_z: Size of the atomic-number embedding table.
        cutoff: Cutoff forwarded to ``Distance``, the radial basis and the layers.
        max_num_neighbors: Forwarded to ``Distance``.
        vertex_type: Key into ``VIS_MP_MAP`` ("None", "Edge" or "Node");
            unknown values fall back to the plain ``ViS_MP`` layer.
    """

    def __init__(
        self,
        lmax=2,
        vecnorm_type='none',
        trainable_vecnorm=False,
        num_heads=8,
        num_layers=9,
        hidden_channels=256,
        num_rbf=32,
        rbf_type="expnorm",
        trainable_rbf=False,
        activation="silu",
        attn_activation="silu",
        max_z=100,
        cutoff=5.0,
        max_num_neighbors=32,
        vertex_type="Edge",
    ):
        super(ViSNetBlock, self).__init__()
        self.lmax = lmax
        self.vecnorm_type = vecnorm_type
        self.trainable_vecnorm = trainable_vecnorm
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_rbf = num_rbf
        self.rbf_type = rbf_type
        self.trainable_rbf = trainable_rbf
        self.activation = activation
        self.attn_activation = attn_activation
        self.max_z = max_z
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

        # NOTE: module creation order is kept stable; it determines the order in
        # which parameters are initialised from the RNG.
        self.embedding = nn.Embedding(max_z, hidden_channels)
        # loop=True: self-loop edges are part of the graph (masked out below
        # before normalising edge directions).
        self.distance = Distance(cutoff, max_num_neighbors=max_num_neighbors, loop=True)
        self.sphere = Sphere(l=lmax)
        self.distance_expansion = rbf_class_mapping[rbf_type](cutoff, num_rbf, trainable_rbf)
        self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf, cutoff, max_z).jittable()
        self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels).jittable()

        self.vis_mp_layers = nn.ModuleList()
        vis_mp_kwargs = dict(
            num_heads=num_heads,
            hidden_channels=hidden_channels,
            activation=activation,
            attn_activation=attn_activation,
            cutoff=cutoff,
            vecnorm_type=vecnorm_type,
            trainable_vecnorm=trainable_vecnorm
        )
        vis_mp_class = VIS_MP_MAP.get(vertex_type, ViS_MP)
        for _ in range(num_layers - 1):
            layer = vis_mp_class(last_layer=False, **vis_mp_kwargs).jittable()
            self.vis_mp_layers.append(layer)
        # The final layer skips the edge-feature update.
        self.vis_mp_layers.append(vis_mp_class(last_layer=True, **vis_mp_kwargs).jittable())

        self.out_norm = nn.LayerNorm(hidden_channels)
        self.vec_out_norm = VecLayerNorm(hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        self.neighbor_embedding.reset_parameters()
        self.edge_embedding.reset_parameters()
        for layer in self.vis_mp_layers:
            layer.reset_parameters()
        self.out_norm.reset_parameters()
        self.vec_out_norm.reset_parameters()

    def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
        """Compute node representations.

        Args:
            data: Batch providing ``z`` (atomic numbers), ``pos`` and ``batch``.

        Returns:
            ``(x, vec)``: normalised scalar features with last dimension
            ``hidden_channels`` and vector features of shape
            ``(num_nodes, (lmax + 1) ** 2 - 1, hidden_channels)``.
        """
        z, pos, batch = data.z, data.pos, data.batch

        x = self.embedding(z)
        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
        edge_attr = self.distance_expansion(edge_weight)
        # Normalise direction vectors of non-self-loop edges only; a self loop
        # would divide by a zero norm.
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
        edge_vec = self.sphere(edge_vec)
        x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr)
        # Vector features start at zero. Use the parameter dtype rather than
        # torch's default so `vec` matches the model after e.g. `.double()`;
        # under autocast the parameters stay float32, as before.
        vec = torch.zeros(
            x.size(0),
            ((self.lmax + 1) ** 2) - 1,
            x.size(1),
            dtype=self.embedding.weight.dtype,
            device=x.device,
        )
        edge_attr = self.edge_embedding(edge_index, edge_attr, x)

        # Residual updates of node, vector and edge features.
        for attn in self.vis_mp_layers[:-1]:
            dx, dvec, dedge_attr = attn(x, vec, edge_index, edge_weight, edge_attr, edge_vec)
            x = x + dx
            vec = vec + dvec
            edge_attr = edge_attr + dedge_attr

        # The last layer returns no edge-feature update.
        dx, dvec, _ = self.vis_mp_layers[-1](x, vec, edge_index, edge_weight, edge_attr, edge_vec)
        x = x + dx
        vec = vec + dvec

        x = self.out_norm(x)
        vec = self.vec_out_norm(vec)

        return x, vec
|
|
|
class ViS_MP(MessagePassing):
    """Base ViSNet message-passing layer (used for ``vertex_type="None"``).

    Updates scalar node features ``x`` and vector node features ``vec`` with
    multi-head attention over edges. Layers built with ``last_layer=False``
    additionally return an edge-feature update computed in ``edge_update``.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False,
    ):
        # node_dim=0: node-indexed tensors are scattered along their first dim
        # (see aggregate()).
        super(ViS_MP, self).__init__(aggr="add", node_dim=0)
        assert hidden_channels % num_heads == 0, (
            f"The number of hidden channels ({hidden_channels}) "
            f"must be evenly divisible by the number of "
            f"attention heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.hidden_channels = hidden_channels
        self.head_dim = hidden_channels // num_heads
        self.last_layer = last_layer

        # Pre-normalisation of the layer inputs (applied first in forward()).
        self.layernorm = nn.LayerNorm(hidden_channels)
        self.vec_layernorm = VecLayerNorm(hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type)

        self.act = act_class_mapping[activation]()
        self.attn_activation = act_class_mapping[attn_activation]()

        # Function of the edge distance r_ij that scales attention weights in message().
        self.cutoff = CosineCutoff(cutoff)

        # Output is split into vec1, vec2, vec3 in forward().
        self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False)

        # Attention query/key/value projections of node features ...
        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
        self.v_proj = nn.Linear(hidden_channels, hidden_channels)
        # ... and edge-feature (f_ij) modulations of keys and values.
        self.dk_proj = nn.Linear(hidden_channels, hidden_channels)
        self.dv_proj = nn.Linear(hidden_channels, hidden_channels)

        # Output is split into the gates s1, s2 of the vector message in message().
        self.s_proj = nn.Linear(hidden_channels, hidden_channels * 2)
        if not self.last_layer:
            # Edge-update parameters; the last layer does not update f_ij.
            self.f_proj = nn.Linear(hidden_channels, hidden_channels)
            self.w_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.w_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

        # Output is split into o1, o2, o3 that form dx and dvec in forward().
        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3)

        self.reset_parameters()

    @staticmethod
    def vector_rejection(vec, d_ij):
        """Subtract from ``vec`` its projection onto ``d_ij``.

        The dot product is summed over dim 1 of ``vec`` and ``d_ij`` is
        broadcast over the last (channel) dim. This is the geometric rejection
        only if ``d_ij`` has unit norm along dim 1.
        NOTE(review): confirm this holds for the ``Sphere`` output when lmax > 1.
        """
        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
        return vec - vec_proj * d_ij.unsqueeze(2)

    def reset_parameters(self):
        """Reset both norms; Xavier-uniform weights and zero biases for all linears."""
        self.layernorm.reset_parameters()
        self.vec_layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.k_proj.weight)
        self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.s_proj.weight)
        self.s_proj.bias.data.fill_(0)

        # Edge-update layers only exist for non-final layers.
        if not self.last_layer:
            nn.init.xavier_uniform_(self.f_proj.weight)
            self.f_proj.bias.data.fill_(0)
            nn.init.xavier_uniform_(self.w_src_proj.weight)
            nn.init.xavier_uniform_(self.w_trg_proj.weight)

        nn.init.xavier_uniform_(self.vec_proj.weight)
        nn.init.xavier_uniform_(self.dk_proj.weight)
        self.dk_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.dv_proj.weight)
        self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Apply one message-passing layer.

        Args:
            x: Node scalar features (last dim ``hidden_channels``).
            vec: Node vector features (last dim ``hidden_channels``).
            edge_index: Graph connectivity.
            r_ij: Per-edge distances (ViSNetBlock passes ``edge_weight``).
            f_ij: Per-edge features (last dim ``hidden_channels``).
            d_ij: Per-edge direction features (ViSNetBlock passes the ``Sphere`` output).

        Returns:
            ``(dx, dvec, df_ij)`` residual updates; ``df_ij`` is None for the last layer.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        # Split hidden channels into (num_heads, head_dim) for attention.
        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        # Per-channel dot product of two vector projections (summed over dim 1).
        vec_dot = (vec1 * vec2).sum(dim=1)

        # PyG maps q/k/v/vec to per-edge q_i, k_j, v_j, vec_j according to the
        # argument names of message(); aggregate() returns a (x, vec) pair.
        x, vec_out = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            # Edge-feature update via edge_update(); uses the normalised vec.
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij):
        """Build per-edge scalar and vector messages."""
        # One attention score per edge and head, scaled by the distance cutoff.
        attn = (q_i * k_j * dk).sum(dim=-1)
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        # Edge-modulated values weighted by attention, flattened back to hidden_channels.
        v_j = v_j * dv
        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)

        # s1 gates the neighbour's vector features; s2 scales the edge direction term.
        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels, dim=1)
        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)

        return v_j, vec_j

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Edge-feature update from the rejections of target and source vectors."""
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)
        df_ij = self.act(self.f_proj(f_ij)) * w_dot
        return df_ij

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter both message tensors onto nodes (torch_scatter default reduction, sum)."""
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec

    def update(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Identity: return the aggregated (x, vec) pair unchanged."""
        return inputs
|
|
|
class ViS_MP_Vertex_Edge(ViS_MP):
    """ViS_MP variant whose edge update also includes a vertex-angle term.

    Non-final layers compute an extra ``t_dot`` from the ``t_trg_proj`` /
    ``t_src_proj`` projections and gate both ``w_dot`` and ``t_dot`` with a
    widened ``f_proj`` that outputs ``2 * hidden_channels``.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False,
    ):
        super().__init__(num_heads, hidden_channels, activation, attn_activation, cutoff, vecnorm_type, trainable_vecnorm, last_layer)

        if not self.last_layer:
            # Replaces the parent's f_proj: two gates (for w_dot and t_dot).
            self.f_proj = nn.Linear(hidden_channels, hidden_channels * 2)
            self.t_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.t_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

        # ViS_MP.__init__ ran reset_parameters() before the layers above were
        # (re)created; run it again so they get the same Xavier init.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all parameters, including the vertex-angle projections."""
        super().reset_parameters()
        # Guard: ViS_MP.__init__ calls this method before t_*_proj exist.
        if not self.last_layer and hasattr(self, "t_src_proj"):
            nn.init.xavier_uniform_(self.t_src_proj.weight)
            nn.init.xavier_uniform_(self.t_trg_proj.weight)

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Edge-feature update combining the ``w_dot`` and vertex ``t_dot`` terms."""
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)

        # NOTE(review): both t projections read vec_i (the target node), unlike
        # w_dot which pairs vec_i with vec_j; presumably intentional (angle at
        # vertex i) — confirm before changing.
        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)
        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)
        t_dot = (t1 * t2).sum(dim=1)

        f1, f2 = torch.split(self.act(self.f_proj(f_ij)), self.hidden_channels, dim=-1)

        return f1 * w_dot + f2 * t_dot

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Apply one layer; same computation as ``ViS_MP.forward``.

        Returns ``(dx, dvec, df_ij)``; ``df_ij`` is None for the last layer.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        vec_dot = (vec1 * vec2).sum(dim=1)

        x, vec_out = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None
|
|
|
class ViS_MP_Vertex_Node(ViS_MP):
    """ViS_MP variant that feeds a vertex-angle term into the node update.

    ``message`` additionally computes a per-edge ``t_dot`` from the
    ``t_trg_proj`` / ``t_src_proj`` projections. It is aggregated together with
    the regular messages and mixed into ``dx`` through a fourth ``o_proj`` chunk.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False,
    ):
        super().__init__(num_heads, hidden_channels, activation, attn_activation, cutoff, vecnorm_type, trainable_vecnorm, last_layer)

        self.t_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.t_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

        # Replaces the parent's o_proj: split into o1..o4 (o3 gates t_dot).
        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 4)

        # ViS_MP.__init__ ran reset_parameters() before the layers above were
        # (re)created; run it again so they get the same Xavier init.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all parameters, including the vertex-angle projections."""
        super().reset_parameters()
        # Guard: ViS_MP.__init__ calls this method before t_*_proj exist.
        if hasattr(self, "t_src_proj"):
            nn.init.xavier_uniform_(self.t_src_proj.weight)
            nn.init.xavier_uniform_(self.t_trg_proj.weight)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Apply one layer.

        Like ``ViS_MP.forward``, but ``propagate`` also returns the aggregated
        vertex term ``t_dot``, which contributes ``t_dot * o3`` to ``dx``.
        Returns ``(dx, dvec, df_ij)``; ``df_ij`` is None for the last layer.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        vec_dot = (vec1 * vec2).sum(dim=1)

        x, vec_out, t_dot = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3, o4 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + t_dot * o3 + o4
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Edge-feature update (same as ``ViS_MP.edge_update``)."""
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)
        df_ij = self.act(self.f_proj(f_ij)) * w_dot
        return df_ij

    def message(self, q_i, k_j, v_j, vec_i, vec_j, dk, dv, r_ij, d_ij):
        """Per-edge scalar/vector messages plus the vertex term ``t_dot``."""
        attn = (q_i * k_j * dk).sum(dim=-1)
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        v_j = v_j * dv
        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)

        # NOTE(review): both t projections read vec_i (the target node);
        # presumably intentional (angle at vertex i) — confirm before changing.
        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)
        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)
        t_dot = (t1 * t2).sum(dim=1)

        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels, dim=1)
        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)

        return v_j, vec_j, t_dot

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter the three per-edge message tensors onto nodes."""
        x, vec, t_dot = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        t_dot = scatter(t_dot, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec, t_dot

    def update(
        self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Identity; annotated for the 3-tuple produced by ``aggregate``."""
        return inputs
|
|
|
# Maps the `vertex_type` option ("--vertex-type") to the message-passing layer
# class instantiated by ViSNetBlock.
VIS_MP_MAP = {
    'Node': ViS_MP_Vertex_Node,
    'Edge': ViS_MP_Vertex_Edge,
    'None': ViS_MP,
}
|
|
|
def create_model(args, prior_model=None, mean=None, std=None):
    """Build a ``ViSNet`` model from a hyper-parameter mapping.

    Args:
        args: Dict-like hyper-parameters (accessed as ``args[key]``).
        prior_model: Pre-built prior. If None and ``args["prior_model"]`` is
            set, the prior is constructed from ``args["prior_args"]``.
        mean: Forwarded to ``ViSNet``.
        std: Forwarded to ``ViSNet``.

    Returns:
        The assembled ``ViSNet``.

    Raises:
        ValueError: If ``args["model"]`` is not ``"ViSNetBlock"``.
        AssertionError: If a requested prior is unknown or ``prior_args`` is missing.
    """
    # (ViSNetBlock keyword, hyper-parameter key); only hidden_channels differs.
    block_param_keys = (
        ("lmax", "lmax"),
        ("vecnorm_type", "vecnorm_type"),
        ("trainable_vecnorm", "trainable_vecnorm"),
        ("num_heads", "num_heads"),
        ("num_layers", "num_layers"),
        ("hidden_channels", "embedding_dimension"),
        ("num_rbf", "num_rbf"),
        ("rbf_type", "rbf_type"),
        ("trainable_rbf", "trainable_rbf"),
        ("activation", "activation"),
        ("attn_activation", "attn_activation"),
        ("max_z", "max_z"),
        ("cutoff", "cutoff"),
        ("max_num_neighbors", "max_num_neighbors"),
        ("vertex_type", "vertex_type"),
    )
    block_kwargs = {param: args[key] for param, key in block_param_keys}

    if args["model"] != "ViSNetBlock":
        raise ValueError(f"Unknown model {args['model']}.")
    representation_model = ViSNetBlock(**block_kwargs)

    # Build the prior from the stored arguments unless one was supplied.
    if args["prior_model"] and prior_model is None:
        assert "prior_args" in args, (
            f"Requested prior model {args['prior_model']} but the "
            'arguments are lacking the key "prior_args".'
        )
        assert hasattr(priors, args["prior_model"]), (
            f'Unknown prior model {args["prior_model"]}. '
            f'Available models are {", ".join(priors.__all__)}'
        )
        prior_cls = getattr(priors, args["prior_model"])
        prior_model = prior_cls(**args["prior_args"])

    # Output heads are looked up as "Equivariant<output_model>".
    output_cls = getattr(output_modules, "Equivariant" + args["output_model"])
    output_model = output_cls(args["embedding_dimension"], args["activation"])

    return ViSNet(
        representation_model,
        output_model,
        prior_model=prior_model,
        reduce_op=args["reduce_op"],
        mean=mean,
        std=std,
        derivative=args["derivative"],
    )
|
|
|
|
|
def load_model(filepath, args=None, device="cpu", **kwargs):
    """Rebuild a model from a checkpoint and load its weights.

    Args:
        filepath: Checkpoint path; loaded onto CPU first.
        args: Hyper-parameters to build the model with; defaults to the
            checkpoint's ``"hyper_parameters"``. Modified in place by ``kwargs``.
        device: Device the returned model is moved to.
        **kwargs: Hyper-parameter overrides; unknown keys trigger a warning but
            are still applied.

    Returns:
        The model with the checkpoint's ``state_dict`` loaded (leading
        ``"model."`` stripped from each key).
    """
    checkpoint = torch.load(filepath, map_location="cpu")
    hparams = checkpoint["hyper_parameters"] if args is None else args

    for name, value in kwargs.items():
        if name not in hparams:
            rank_zero_warn(f"Unknown hyperparameter: {name}={value}")
        hparams[name] = value

    model = create_model(hparams)

    # Lightning stores the wrapped network under the "model." prefix.
    prefix = "model."
    state_dict = {}
    for key, tensor in checkpoint["state_dict"].items():
        stripped = key[len(prefix):] if key.startswith(prefix) else key
        state_dict[stripped] = tensor
    model.load_state_dict(state_dict)

    return model.to(device)
|
|
|
|
|
class ViSNet(nn.Module):
    """Representation model + output head (+ optional prior) as one network.

    Per-atom outputs are scaled by ``std``, optionally passed through the prior,
    reduced per graph with ``reduce_op``, post-processed and shifted by
    ``mean``. With ``derivative=True`` the negative gradient of the output with
    respect to ``data.pos`` is returned as well.
    """

    def __init__(
        self,
        representation_model,
        output_model,
        prior_model=None,
        reduce_op="add",
        mean=None,
        std=None,
        derivative=False,
    ):
        super(ViSNet, self).__init__()
        self.representation_model = representation_model
        self.output_model = output_model

        self.prior_model = prior_model
        prior_not_allowed = not output_model.allow_prior_model and prior_model is not None
        if prior_not_allowed:
            self.prior_model = None
            rank_zero_warn(
                "Prior model was given but the output model does "
                "not allow prior models. Dropping the prior model."
            )

        self.reduce_op = reduce_op
        self.derivative = derivative

        # Identity standardisation (mean 0, std 1) unless provided.
        if mean is None:
            mean = torch.scalar_tensor(0)
        self.register_buffer("mean", mean)
        if std is None:
            std = torch.scalar_tensor(1)
        self.register_buffer("std", std)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the representation, output and (if any) prior models."""
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        if self.prior_model is not None:
            self.prior_model.reset_parameters()

    def forward(self, data: Data) -> Tuple[Tensor, Optional[Tensor]]:
        """Return ``(out, -d out / d pos)``, or ``(out, None)`` without derivative."""
        if self.derivative:
            data.pos.requires_grad_(True)

        x, v = self.representation_model(data)
        per_atom = self.output_model.pre_reduce(x, v, data.z, data.pos, data.batch) * self.std
        if self.prior_model is not None:
            per_atom = self.prior_model(per_atom, data.z)

        reduced = scatter(per_atom, data.batch, dim=0, reduce=self.reduce_op)
        out = self.output_model.post_reduce(reduced) + self.mean

        if not self.derivative:
            return out, None

        # create_graph=True keeps the derivative differentiable so it can be trained on.
        grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(out)]
        dy = grad(
            [out],
            [data.pos],
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]
        if dy is None:
            raise RuntimeError("Autograd returned None for the force prediction.")
        return out, -dy
|
|
|
class LNNP(LightningModule):
    """Lightning module that trains and evaluates a ViSNet model.

    The loss is ``energy_weight * loss_y + force_weight * loss_dy``, where
    ``loss_y`` compares predictions with ``batch.y`` and ``loss_dy`` (only when
    ``hparams.derivative``) compares the predicted derivative with ``batch.dy``.
    For train/val, ``loss_scale_y`` / ``loss_scale_dy`` < 1 blend each loss with
    its previous smoothed value stored in ``self.ema``.
    """

    def __init__(self, hparams, prior_model=None, mean=None, std=None):
        super(LNNP, self).__init__()

        self.save_hyperparameters(hparams)

        if self.hparams.load_model:
            self.model = load_model(self.hparams.load_model, args=self.hparams)
        else:
            self.model = create_model(self.hparams, prior_model, mean, std)

        self._reset_losses_dict()
        self._reset_ema_dict()
        self._reset_inference_results()

    def configure_optimizers(self):
        """AdamW with a ReduceLROnPlateau schedule on ``val_loss`` (checked each epoch)."""
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.hparams.lr_factor,
            patience=self.hparams.lr_patience,
            min_lr=self.hparams.lr_min,
        )
        lr_scheduler = {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler]

    def forward(self, data):
        return self.model(data)

    def training_step(self, batch, batch_idx):
        loss_fn = mse_loss if self.hparams.loss_type == 'MSE' else l1_loss
        return self.step(batch, loss_fn, "train")

    def validation_step(self, batch, batch_idx, *args):
        """Dataloader 0 (or a single loader) is validation (MSE); others are test (L1)."""
        if not args or args[0] == 0:
            return self.step(batch, mse_loss, "val")
        return self.step(batch, l1_loss, "test")

    def test_step(self, batch, batch_idx):
        return self.step(batch, l1_loss, "test")

    def step(self, batch, loss_fn, stage):
        """Compute, record and return the weighted loss for ``stage``."""
        # Gradients are needed for training, and for predicting derivatives.
        with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
            pred, deriv = self(batch)
            if stage == "test":
                self.inference_results['y_pred'].append(pred.squeeze(-1).detach().cpu())
                self.inference_results['y_true'].append(batch.y.squeeze(-1).detach().cpu())
                if self.hparams.derivative:
                    self.inference_results['dy_pred'].append(deriv.squeeze(-1).detach().cpu())
                    self.inference_results['dy_true'].append(batch.dy.squeeze(-1).detach().cpu())

        loss_y, loss_dy = 0, 0
        if self.hparams.derivative:
            if "y" not in batch:
                # Keep `pred` in the graph even when only derivatives are supervised.
                deriv = deriv + pred.sum() * 0

            loss_dy = loss_fn(deriv, batch.dy)

            if stage in ["train", "val"] and self.hparams.loss_scale_dy < 1:
                if self.ema[stage + "_dy"] is None:
                    self.ema[stage + "_dy"] = loss_dy.detach()

                loss_dy = (
                    self.hparams.loss_scale_dy * loss_dy
                    + (1 - self.hparams.loss_scale_dy) * self.ema[stage + "_dy"]
                )
                self.ema[stage + "_dy"] = loss_dy.detach()

            if self.hparams.force_weight > 0:
                self.losses[stage + "_dy"].append(loss_dy.detach())

        if "y" in batch:
            if batch.y.ndim == 1:
                batch.y = batch.y.unsqueeze(1)

            loss_y = loss_fn(pred, batch.y)

            if stage in ["train", "val"] and self.hparams.loss_scale_y < 1:
                if self.ema[stage + "_y"] is None:
                    self.ema[stage + "_y"] = loss_y.detach()

                loss_y = (
                    self.hparams.loss_scale_y * loss_y
                    + (1 - self.hparams.loss_scale_y) * self.ema[stage + "_y"]
                )
                self.ema[stage + "_y"] = loss_y.detach()

            if self.hparams.energy_weight > 0:
                self.losses[stage + "_y"].append(loss_y.detach())

        loss = loss_y * self.hparams.energy_weight + loss_dy * self.hparams.force_weight

        self.losses[stage].append(loss.detach())

        return loss

    def optimizer_step(self, *args, **kwargs):
        """Linear learning-rate warm-up over ``lr_warmup_steps``, then the normal step."""
        optimizer = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
        if self.trainer.global_step < self.hparams.lr_warmup_steps:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / float(self.hparams.lr_warmup_steps))
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.lr
        super().optimizer_step(*args, **kwargs)
        optimizer.zero_grad()

    def training_epoch_end(self, training_step_outputs):
        """Reload the validation dataloaders around test epochs (every ``test_interval``)."""
        dm = self.trainer.datamodule
        if hasattr(dm, "test_dataset") and len(dm.test_dataset) > 0:
            delta = 0 if self.hparams.reload == 1 else 1
            should_reset = (
                (self.current_epoch + delta + 1) % self.hparams.test_interval == 0
                or ((self.current_epoch + delta) % self.hparams.test_interval == 0 and self.current_epoch != 0)
            )
            if should_reset:
                self.trainer.reset_val_dataloader()
                self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop._reset_dl_batch_idx(len(self.trainer.val_dataloaders))

    def validation_epoch_end(self, validation_step_outputs):
        """Log epoch-mean losses, then clear the per-epoch loss and inference buffers."""
        if not self.trainer.sanity_checking:
            result_dict = {
                "epoch": float(self.current_epoch),
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
            }

            for stage in ("train", "val", "test"):
                # Only log series that received values this epoch: torch.stack
                # raises on an empty list (e.g. validation before any train step).
                if len(self.losses[stage]) > 0:
                    result_dict[f"{stage}_loss"] = torch.stack(self.losses[stage]).mean()

                # Energy/force components are logged only when both were recorded.
                y_losses = self.losses[f"{stage}_y"]
                dy_losses = self.losses[f"{stage}_dy"]
                if len(y_losses) > 0 and len(dy_losses) > 0:
                    result_dict[f"{stage}_loss_y"] = torch.stack(y_losses).mean()
                    result_dict[f"{stage}_loss_dy"] = torch.stack(dy_losses).mean()

            self.log_dict(result_dict, sync_dist=True)

        self._reset_losses_dict()
        self._reset_inference_results()

    def test_epoch_end(self, outputs) -> None:
        """Concatenate the collected per-batch inference results."""
        for key in self.inference_results.keys():
            if len(self.inference_results[key]) > 0:
                self.inference_results[key] = torch.cat(self.inference_results[key], dim=0)

    def _reset_losses_dict(self):
        self.losses = {
            "train": [], "val": [], "test": [],
            "train_y": [], "val_y": [], "test_y": [],
            "train_dy": [], "val_dy": [], "test_dy": [],
        }

    def _reset_inference_results(self):
        self.inference_results = {'y_pred': [], 'y_true': [], 'dy_pred': [], 'dy_true': []}

    def _reset_ema_dict(self):
        self.ema = {"train_y": None, "val_y": None, "train_dy": None, "val_dy": None}
|
|
|
|
|
def get_args():
    """Parse command-line (and optional YAML ``--conf``) training arguments.

    Side effects: with ``--redirect`` stdout/stderr are redirected to
    ``<log_dir>/log``; the parsed arguments (minus ``conf``) are written to
    ``<log_dir>/input.yaml`` via ``save_argparse``.

    Returns:
        The parsed ``argparse.Namespace``; ``inference_batch_size`` defaults to
        ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint')
    parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file')

    # Training / optimisation
    parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs')
    parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
    parser.add_argument('--lr-min', type=float, default=1e-6, help='Lower bound on the learning rate for the ReduceLROnPlateau scheduler')
    parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which the learning rate is multiplied when the validation loss plateaus')
    parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
    parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement')
    parser.add_argument('--loss-type', type=str, default='MSE', choices=['MSE', 'MAE'], help='Loss type')
    parser.add_argument('--loss-scale-y', type=float, default=1.0, help="Scale the loss y of the target")
    parser.add_argument('--loss-scale-dy', type=float, default=1.0, help="Scale the loss dy of the target")
    parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function')
    parser.add_argument('--force-weight', default=1.0, type=float, help='Weighting factor for forces in the loss function')

    # Dataset
    parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
    parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset argument')
    parser.add_argument('--dataset-root', default=None, type=str, help='Data storage directory')
    parser.add_argument('--derivative', default=False, action=argparse.BooleanOptionalAction, help='If true, take the derivative of the prediction w.r.t coordinates')
    parser.add_argument('--split-mode', default=None, type=str, help='Split mode for Molecule3D dataset')

    # Data loading
    parser.add_argument('--reload', type=int, default=0, help='Reload dataloaders every n epoch')
    parser.add_argument('--batch-size', default=32, type=int, help='batch size')
    parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
    parser.add_argument('--standardize', action=argparse.BooleanOptionalAction, default=False, help='If true, multiply prediction by dataset std and add mean')
    parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
    parser.add_argument('--train-size', type=number, default=950, help='Percentage/number of samples in training set (None to use all remaining samples)')
    parser.add_argument('--val-size', type=number, default=50, help='Percentage/number of samples in validation set (None to use all remaining samples)')
    parser.add_argument('--test-size', type=number, default=None, help='Percentage/number of samples in test set (None to use all remaining samples)')
    parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch')

    # Model selection
    parser.add_argument('--model', type=str, default='ViSNetBlock', choices=models.__all__, help='Which model to train')
    parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
    parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
    # `type=dict` cannot convert a command-line string; parse a JSON object instead.
    parser.add_argument('--prior-args', type=json.loads, default=None, help='Additional arguments for the prior model, as a JSON object')

    # Architecture
    parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
    parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
    parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
    parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
    parser.add_argument('--rbf-type', type=str, default='expnorm', choices=list(rbf_class_mapping.keys()), help='Type of distance expansion')
    parser.add_argument('--trainable-rbf', action=argparse.BooleanOptionalAction, default=False, help='If distance expansion functions should be trainable')
    parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--cutoff', type=float, default=5.0, help='Cutoff in model')
    parser.add_argument('--max-z', type=int, default=100, help='Maximum atomic number that fits in the embedding matrix')
    parser.add_argument('--max-num-neighbors', type=int, default=32, help='Maximum number of neighbors to consider in the network')
    parser.add_argument('--reduce-op', type=str, default='add', choices=['add', 'mean'], help='Reduce operation to apply to atomic predictions')
    parser.add_argument('--lmax', type=int, default=2, help='Max order of spherical harmonics')
    parser.add_argument('--vecnorm-type', type=str, default='max_min', help='Type of vector normalization')
    parser.add_argument('--trainable-vecnorm', action=argparse.BooleanOptionalAction, default=False, help='If vector normalization should be trainable')
    parser.add_argument('--vertex-type', type=str, default='Edge', choices=['None', 'Edge', 'Node'], help='Whether and where to add vertex angles: None (disabled), Edge (edge update) or Node (node update)')

    # Runtime / infrastructure
    parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
    parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
    parser.add_argument('--log-dir', type=str, default="aspirin_log", help='Log directory')
    parser.add_argument('--task', type=str, default='train', choices=['train', 'inference'], help='Train or inference')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend')
    parser.add_argument('--redirect', action=argparse.BooleanOptionalAction, default=False, help='Redirect stdout and stderr to log_dir/log')
    parser.add_argument('--accelerator', default='gpu', help='Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto")')
    parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)')
    parser.add_argument('--save-interval', type=int, default=2, help='Save interval, one save per n epochs (default: 2)')
    parser.add_argument("--out_dir", type=str, default="run_0")

    args = parser.parse_args()

    if args.redirect:
        os.makedirs(args.log_dir, exist_ok=True)
        sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
        sys.stderr = sys.stdout
        logging.getLogger("pytorch_lightning").addHandler(logging.StreamHandler(sys.stdout))

    if args.inference_batch_size is None:
        args.inference_batch_size = args.batch_size

    # log_dir may not exist yet unless --redirect created it; save_argparse
    # presumably opens the target file directly, so make sure it is there.
    os.makedirs(args.log_dir, exist_ok=True)
    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

    return args
|
|
|
def _run_name(args):
    """Return the per-run directory name that encodes the key hyper-parameters."""
    # GPU count = entries in CUDA_VISIBLE_DEVICES, falling back to every device
    # torch reports. NOTE(review): with no GPUs and the variable unset this
    # yields "".split(",") == [""], i.e. ngpus_1; kept as-is so existing run
    # directories (and their last.ckpt resume points) keep the same name.
    default = ",".join(str(i) for i in range(torch.cuda.device_count()))
    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
    return (
        f"output_ngpus_{len(cuda_visible_devices)}_bs_{args.batch_size}_lr_{args.lr}_seed_{args.seed}"
        f"_reload_{args.reload}_lmax_{args.lmax}_vnorm_{args.vecnorm_type}"
        f"_vertex_{args.vertex_type}_L{args.num_layers}_D{args.embedding_dimension}_H{args.num_heads}"
        f"_cutoff_{args.cutoff}_E{args.energy_weight}_F{args.force_weight}_loss_{args.loss_type}"
    )


def _prepare_log_dir(args, run_name):
    """Set ``args.log_dir`` to ``<out_dir>/<log_dir>/<run_name>`` and prepare an existing dir for reuse.

    If that directory already exists:
      * ``args.load_model`` is pointed at its ``last.ckpt`` (when present), so
        training resumes from it;
      * an existing ``metrics.csv`` is renamed to the first free
        ``metrics.csv.bak[.bak...]`` name, so the CSV logger starts a new file.
    """
    args.log_dir = os.path.join(args.out_dir, args.log_dir, run_name)
    if not os.path.exists(args.log_dir):
        return

    last_ckpt = os.path.join(args.log_dir, "last.ckpt")
    if os.path.exists(last_ckpt):
        args.load_model = last_ckpt

    metrics_csv = os.path.join(args.log_dir, "metrics.csv")
    if os.path.exists(metrics_csv):
        backup = metrics_csv + ".bak"
        while os.path.exists(backup):
            backup += ".bak"
        os.rename(metrics_csv, backup)


def _write_final_info(args, results):
    """Print the test MAEs and write them to ``<out_dir>/final_info.json``.

    The scalar MAE is computed from ``results['y_true']`` / ``results['y_pred']``.
    When ``args.derivative`` is set, a force MAE from ``results['dy_true']`` /
    ``results['dy_pred']`` is added. Values are stored as 6-decimal strings.
    """
    emae = calculate_mae(results['y_true'].numpy(), results['y_pred'].numpy())
    scalar_mae = "{:.6f}".format(emae)
    print('Scalar MAE: {:.6f}'.format(emae))
    means = {"Scalar MAE": scalar_mae}

    if args.derivative:
        fmae = calculate_mae(results['dy_true'].numpy(), results['dy_pred'].numpy())
        forces_mae = "{:.6f}".format(fmae)
        print('Forces MAE: {:.6f}'.format(fmae))
        means["Forces MAE"] = forces_mae

    final_infos = {"AutoMolecule3D": {"means": means}}

    # out_dir is not guaranteed to exist (e.g. inference with an explicit
    # --load-model never creates it).
    os.makedirs(args.out_dir, exist_ok=True)
    with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)


def main(args):
    """Train and/or evaluate a ViSNet model, then write a summary of its test metrics.

    Steps, in order:

    1. Seed all RNGs via ``pl.seed_everything`` and prepare the ``DataModule``.
    2. If ``args.load_model`` is None, derive ``args.log_dir`` from the
       hyper-parameters. An existing run directory is resumed from its
       ``last.ckpt``, and its old ``metrics.csv`` is moved aside.
    3. Build the optional prior model and the ``LNNP`` Lightning module.
    4. ``args.task == "train"``: fit with checkpointing and early stopping, then
       test the best checkpoint. ``"inference"``: test the checkpoint given by
       ``args.load_model`` (the model's current weights when it is None).
    5. Save ``model.inference_results`` to ``<log_dir>/inference_results.pt``
       and write the MAE summary to ``<out_dir>/final_info.json``.

    Args:
        args: Parsed command-line namespace from ``get_args``.

    Raises:
        ValueError: ``args.prior_model`` is not a name defined in ``priors``.
    """
    pl.seed_everything(args.seed, workers=True)

    data = DataModule(args)
    data.prepare_dataset()

    if args.load_model is None:
        _prepare_log_dir(args, _run_name(args))

    prior = None
    if args.prior_model:
        # A Namespace is not subscriptable, so the message must use attribute
        # access. Raise instead of assert so the check survives `python -O`.
        if not hasattr(priors, args.prior_model):
            raise ValueError(
                f"Unknown prior model {args.prior_model}. "
                f"Available models are {', '.join(priors.__all__)}"
            )
        prior = getattr(priors, args.prior_model)(dataset=data.dataset)
        args.prior_args = prior.get_init_args()

    model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)

    if args.task == "train":
        checkpoint_callback = ModelCheckpoint(
            dirpath=args.log_dir,
            monitor="val_loss",
            save_top_k=2,
            save_last=True,
            every_n_epochs=args.save_interval,
            filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
        )
        early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)

        tb_logger = TensorBoardLogger(os.getenv("TENSORBOARD_LOG_PATH", "/tensorboard_logs/"), name="", version="", default_hp_metric=False)
        csv_logger = CSVLogger(args.log_dir, name="", version="")
        ddp_plugin = DDPStrategy(find_unused_parameters=False)

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            gpus=args.ngpus,
            num_nodes=args.num_nodes,
            accelerator=args.accelerator,
            default_root_dir=args.log_dir,
            auto_lr_find=False,
            callbacks=[early_stopping, checkpoint_callback],
            logger=[tb_logger, csv_logger],
            reload_dataloaders_every_n_epochs=args.reload,
            precision=args.precision,
            strategy=ddp_plugin,
            enable_progress_bar=True,
        )
        # Resumes from args.load_model when it is set (possibly the last.ckpt
        # found by _prepare_log_dir).
        trainer.fit(model, datamodule=data, ckpt_path=args.load_model)

    # Evaluation runs on a single device.
    # inference_mode=False is presumably needed because the model computes
    # derivatives with autograd at test time. TODO confirm.
    test_trainer = pl.Trainer(
        logger=False,
        max_epochs=-1,
        num_nodes=1,
        gpus=1,
        default_root_dir=args.log_dir,
        enable_progress_bar=True,
        inference_mode=False,
    )

    if args.task == 'train':
        test_trainer.test(model=model, ckpt_path=trainer.checkpoint_callback.best_model_path, datamodule=data)
    elif args.task == 'inference':
        # Without ckpt_path the requested checkpoint was never loaded and the
        # freshly initialised weights were evaluated. With None, Lightning
        # uses the model's current weights, matching the old behaviour.
        test_trainer.test(model=model, ckpt_path=args.load_model, datamodule=data)

    torch.save(model.inference_results, os.path.join(args.log_dir, "inference_results.pt"))

    _write_final_info(args, model.inference_results)
|
|
|
|
|
if __name__ == "__main__":
    # Local import: the module header does not import traceback. Without it,
    # the handler below raised NameError and hid the real failure.
    import traceback

    args = get_args()
    try:
        main(args)
    except Exception:
        print("Origin error in main process:", flush=True)
        # Save the full traceback next to the run outputs. out_dir may not
        # exist yet if main() failed early (e.g. while preparing the dataset).
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as log_file:
            traceback.print_exc(file=log_file)
        # Re-raise the original exception so the process still exits non-zero.
        raise
|
|