# yxc97's picture
# Upload folder using huggingface_hub
# 62a2f1c verified
import argparse
import logging
import os
import sys
import json
import re
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd import grad
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torch.nn.functional import l1_loss, mse_loss
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning import LightningModule
from visnet import datasets, models, priors
from visnet.data import DataModule
from visnet.models import output_modules
from visnet.utils import LoadFromCheckpoint, LoadFromFile, number, save_argparse
from typing import Optional, Tuple , List
from metrics import calculate_mae
from visnet.models.utils import (
CosineCutoff,
Distance,
EdgeEmbedding,
NeighborEmbedding,
Sphere,
VecLayerNorm,
act_class_mapping,
rbf_class_mapping,
ExpNormalSmearing,
GaussianSmearing
)
"""
Models
"""
class ViSNetBlock(nn.Module):
    """ViSNet representation network.

    Embeds atoms, builds a neighbour graph over ``data.pos`` and runs a stack of
    ViS-MP message-passing layers. ``forward`` returns per-atom scalar features
    ``x`` and per-atom vector features ``vec``.
    """

    def __init__(
        self,
        lmax=2,
        vecnorm_type='none',
        trainable_vecnorm=False,
        num_heads=8,
        num_layers=9,
        hidden_channels=256,
        num_rbf=32,
        rbf_type="expnorm",
        trainable_rbf=False,
        activation="silu",
        attn_activation="silu",
        max_z=100,
        cutoff=5.0,
        max_num_neighbors=32,
        vertex_type="Edge",
    ):
        """Build the embeddings and ``num_layers`` ViS-MP layers.

        Args:
            lmax: order passed to ``Sphere``; the vector features have
                ``(lmax + 1) ** 2 - 1`` components (see ``forward``).
            vecnorm_type: ``norm_type`` for every ``VecLayerNorm``.
            trainable_vecnorm: whether the ``VecLayerNorm`` layers are trainable.
            num_heads: attention heads per ViS-MP layer; must divide
                ``hidden_channels`` (asserted in ``ViS_MP``).
            num_layers: number of ViS-MP layers; only the last one is built
                with ``last_layer=True``.
            hidden_channels: width of the scalar and vector features.
            num_rbf: number of radial basis functions.
            rbf_type: key into ``rbf_class_mapping``.
            trainable_rbf: whether the radial basis is trainable.
            activation: key into ``act_class_mapping`` for feature MLPs.
            attn_activation: key into ``act_class_mapping`` for attention scores.
            max_z: size of the atomic-number embedding table.
            cutoff: neighbour cutoff radius passed to ``Distance`` and the layers.
            max_num_neighbors: neighbour cap passed to ``Distance``.
            vertex_type: key into ``VIS_MP_MAP`` ('Node', 'Edge' or 'None');
                unknown keys silently fall back to the plain ``ViS_MP`` layer.
        """
        super(ViSNetBlock, self).__init__()
        # Store hyper-parameters on the instance (``lmax`` is read again in forward).
        self.lmax = lmax
        self.vecnorm_type = vecnorm_type
        self.trainable_vecnorm = trainable_vecnorm
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_rbf = num_rbf
        self.rbf_type = rbf_type
        self.trainable_rbf = trainable_rbf
        self.activation = activation
        self.attn_activation = attn_activation
        self.max_z = max_z
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

        # NOTE: sub-module construction order fixes the parameter order and the
        # random initialisation obtained under a fixed seed; avoid reordering.
        self.embedding = nn.Embedding(max_z, hidden_channels)
        # loop=True presumably adds self-loop edges (forward masks them out below).
        self.distance = Distance(cutoff, max_num_neighbors=max_num_neighbors, loop=True)
        self.sphere = Sphere(l=lmax)
        self.distance_expansion = rbf_class_mapping[rbf_type](cutoff, num_rbf, trainable_rbf)
        # ``.jittable()`` is PyG's TorchScript-friendly variant of a MessagePassing
        # module (presumably kept for optional scripting).
        self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf, cutoff, max_z).jittable()
        self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels).jittable()

        self.vis_mp_layers = nn.ModuleList()
        vis_mp_kwargs = dict(
            num_heads=num_heads,
            hidden_channels=hidden_channels,
            activation=activation,
            attn_activation=attn_activation,
            cutoff=cutoff,
            vecnorm_type=vecnorm_type,
            trainable_vecnorm=trainable_vecnorm
        )
        vis_mp_class = VIS_MP_MAP.get(vertex_type, ViS_MP)
        # All but the last layer also update edge features.
        for _ in range(num_layers - 1):
            layer = vis_mp_class(last_layer=False, **vis_mp_kwargs).jittable()
            self.vis_mp_layers.append(layer)
        self.vis_mp_layers.append(vis_mp_class(last_layer=True, **vis_mp_kwargs).jittable())

        self.out_norm = nn.LayerNorm(hidden_channels)
        self.vec_out_norm = VecLayerNorm(hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module, including each ViS-MP layer."""
        self.embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        self.neighbor_embedding.reset_parameters()
        self.edge_embedding.reset_parameters()
        for layer in self.vis_mp_layers:
            layer.reset_parameters()
        self.out_norm.reset_parameters()
        self.vec_out_norm.reset_parameters()

    def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
        """Compute atom representations.

        Args:
            data: batch exposing ``z`` (atomic numbers), ``pos`` (coordinates)
                and ``batch`` (molecule index per atom).

        Returns:
            ``(x, vec)``: scalar features, presumably ``(num_atoms, hidden_channels)``
            for 1-D ``z``, and vector features of shape
            ``(num_atoms, (lmax + 1) ** 2 - 1, hidden_channels)``, both normalised.
        """
        z, pos, batch = data.z, data.pos, data.batch

        # Embedding Layers
        x = self.embedding(z)
        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
        edge_attr = self.distance_expansion(edge_weight)
        # Normalise edge direction vectors to unit length, skipping self-loops
        # (source == target), whose vectors are presumably zero-length.
        # NOTE(review): two distinct atoms at identical positions would still
        # divide by zero here.
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
        edge_vec = self.sphere(edge_vec)
        x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr)
        # Vector features start at zero and are built up by the ViS-MP layers.
        vec = torch.zeros(x.size(0), ((self.lmax + 1) ** 2) - 1, x.size(1), device=x.device)
        edge_attr = self.edge_embedding(edge_index, edge_attr, x)

        # ViS-MP Layers: residual updates of node scalars, node vectors and edges.
        for attn in self.vis_mp_layers[:-1]:
            dx, dvec, dedge_attr = attn(x, vec, edge_index, edge_weight, edge_attr, edge_vec)
            x = x + dx
            vec = vec + dvec
            edge_attr = edge_attr + dedge_attr

        # The last layer returns no edge-feature update.
        dx, dvec, _ = self.vis_mp_layers[-1](x, vec, edge_index, edge_weight, edge_attr, edge_vec)
        x = x + dx
        vec = vec + dvec

        x = self.out_norm(x)
        vec = self.vec_out_norm(vec)

        return x, vec
class ViS_MP(MessagePassing):
    """ViS-MP layer: attention-based message passing over scalar and vector features.

    ``forward`` returns residual updates ``(dx, dvec, df_ij)`` for node scalars,
    node vectors and edge features; ``df_ij`` is ``None`` for the last layer.

    NOTE: PyG's ``jittable()`` is understood to parse the ``propagate_type`` /
    ``edge_updater_type`` type comments in ``forward`` and to inspect the hook
    signatures (``message``, ``aggregate``, ``update``, ``edge_update``); keep
    them in sync with the ``propagate``/``edge_updater`` calls.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False,
    ):
        """Create the projections used by one ViS-MP layer.

        Args:
            num_heads: attention heads; must divide ``hidden_channels``.
            hidden_channels: feature width.
            activation: key into ``act_class_mapping`` for feature gates.
            attn_activation: key into ``act_class_mapping`` for attention scores.
            cutoff: radius for the ``CosineCutoff`` applied to attention.
            vecnorm_type: ``norm_type`` of the ``VecLayerNorm``.
            trainable_vecnorm: whether the ``VecLayerNorm`` is trainable.
            last_layer: if True, no edge-update projections are created and
                ``forward`` returns ``None`` for the edge update.

        Raises:
            AssertionError: if ``hidden_channels`` is not divisible by ``num_heads``.
        """
        super(ViS_MP, self).__init__(aggr="add", node_dim=0)
        assert hidden_channels % num_heads == 0, (
            f"The number of hidden channels ({hidden_channels}) "
            f"must be evenly divisible by the number of "
            f"attention heads ({num_heads})"
        )

        self.num_heads = num_heads
        self.hidden_channels = hidden_channels
        self.head_dim = hidden_channels // num_heads
        self.last_layer = last_layer

        self.layernorm = nn.LayerNorm(hidden_channels)
        self.vec_layernorm = VecLayerNorm(hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type)

        self.act = act_class_mapping[activation]()
        self.attn_activation = act_class_mapping[attn_activation]()

        self.cutoff = CosineCutoff(cutoff)

        # Split into (vec1, vec2, vec3) in forward.
        self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False)

        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
        self.v_proj = nn.Linear(hidden_channels, hidden_channels)
        # Edge-feature projections that modulate keys and values per edge.
        self.dk_proj = nn.Linear(hidden_channels, hidden_channels)
        self.dv_proj = nn.Linear(hidden_channels, hidden_channels)

        # Split into gates (s1, s2) in message.
        self.s_proj = nn.Linear(hidden_channels, hidden_channels * 2)
        # Projections used only by edge_update, hence skipped for the last layer.
        if not self.last_layer:
            self.f_proj = nn.Linear(hidden_channels, hidden_channels)
            self.w_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.w_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

        # Split into (o1, o2, o3) in forward.
        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3)

        self.reset_parameters()

    @staticmethod
    def vector_rejection(vec, d_ij):
        """Return ``vec`` minus its projection onto ``d_ij``, dotted over dim 1.

        Assumes ``vec`` is ``(E, C, hidden)`` and ``d_ij`` is ``(E, C)`` (``d_ij``
        is broadcast over the last axis via ``unsqueeze(2)``) — TODO confirm.
        """
        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
        return vec - vec_proj * d_ij.unsqueeze(2)

    def reset_parameters(self):
        """Reset the norms; Xavier-uniform weights and zero biases for projections."""
        self.layernorm.reset_parameters()
        self.vec_layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.k_proj.weight)
        self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.s_proj.weight)
        self.s_proj.bias.data.fill_(0)

        if not self.last_layer:
            nn.init.xavier_uniform_(self.f_proj.weight)
            self.f_proj.bias.data.fill_(0)
            nn.init.xavier_uniform_(self.w_src_proj.weight)
            nn.init.xavier_uniform_(self.w_trg_proj.weight)

        nn.init.xavier_uniform_(self.vec_proj.weight)
        nn.init.xavier_uniform_(self.dk_proj.weight)
        self.dk_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.dv_proj.weight)
        self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Run one ViS-MP update.

        Args:
            x: node scalar features.
            vec: node vector features.
            edge_index: graph connectivity.
            r_ij: edge lengths (``edge_weight`` in ``ViSNetBlock``).
            f_ij: edge features.
            d_ij: spherical expansion of the edge directions.

        Returns:
            ``(dx, dvec, df_ij)`` residual updates; ``df_ij`` is ``None`` when
            ``last_layer`` is True.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        # Per-channel scalar invariant: dot product of two vector projections.
        vec_dot = (vec1 * vec2).sum(dim=1)

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor, dv: Tensor, vec: Tensor, r_ij: Tensor, d_ij: Tensor)
        x, vec_out = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            # edge_updater_type: (vec: Tensor, d_ij: Tensor, f_ij: Tensor)
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij):
        """Build the message sent from source node j to target node i.

        Attention is ``(q_i * k_j * dk)`` summed over the head dimension, passed
        through ``attn_activation`` and scaled by the cosine cutoff of ``r_ij``.
        Returns the attention-weighted scalar message (flattened to
        ``hidden_channels``) and the gated vector message ``vec_j * s1 + s2 * d_ij``.
        """
        attn = (q_i * k_j * dk).sum(dim=-1)
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        v_j = v_j * dv
        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)

        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels, dim=1)
        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)

        return v_j, vec_j

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Edge-feature update: ``act(f_proj(f_ij))`` scaled by ``w_dot``.

        ``w_dot`` is the dot product (over dim 1) of ``w_trg_proj(vec_i)`` and
        ``w_src_proj(vec_j)`` after removing their components along ``d_ij``.
        """
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)
        df_ij = self.act(self.f_proj(f_ij)) * w_dot
        return df_ij

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter (default reduce) both message tensors onto their ``index`` nodes."""
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec

    def update(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Identity: return the aggregated tensors unchanged."""
        return inputs
class ViS_MP_Vertex_Edge(ViS_MP):
    """ViS-MP variant whose edge update adds a second geometric term ``t_dot``.

    NOTE(review): ``ViS_MP.__init__`` has already called ``reset_parameters()``
    before ``f_proj`` is replaced below, and ``reset_parameters`` never touches
    ``t_src_proj``/``t_trg_proj``. Those two therefore keep PyTorch's default
    ``nn.Linear`` initialisation. ``f_proj`` only gets Xavier init if an owner
    calls ``reset_parameters()`` again later, as ``ViSNetBlock`` does.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False
    ):
        """Same arguments as ``ViS_MP``; widens ``f_proj`` and adds t-projections."""
        super().__init__(num_heads, hidden_channels, activation, attn_activation, cutoff, vecnorm_type, trainable_vecnorm, last_layer)

        if not self.last_layer:
            # Output is split into (f1, f2) in edge_update.
            self.f_proj = nn.Linear(hidden_channels, hidden_channels * 2)
            self.t_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.t_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Edge-feature update ``f1 * w_dot + f2 * t_dot``.

        ``w_dot`` is computed as in ``ViS_MP.edge_update``. ``t_dot`` combines
        two different projections of ``vec_i`` with their components along
        ``d_ij`` removed. NOTE(review): both t-terms use ``vec_i``, unlike
        ``w1``/``w2``. This is presumably intentional (an angle at the vertex i);
        confirm against the reference implementation.
        """
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)

        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)
        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)
        t_dot = (t1 * t2).sum(dim=1)

        f1, f2 = torch.split(self.act(self.f_proj(f_ij)), self.hidden_channels, dim=-1)

        return f1 * w_dot + f2 * t_dot

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Identical to ``ViS_MP.forward``.

        Presumably redefined so that PyG's ``jittable()``, which reads this class's
        own source for the ``propagate_type`` type comment, can find it. TODO confirm.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        vec_dot = (vec1 * vec2).sum(dim=1)

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor, dv: Tensor, vec: Tensor, r_ij: Tensor, d_ij: Tensor)
        x, vec_out = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            # edge_updater_type: (vec: Tensor, d_ij: Tensor, f_ij: Tensor)
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None
class ViS_MP_Vertex_Node(ViS_MP):
    """ViS-MP variant that aggregates an extra per-edge term ``t_dot`` onto nodes.

    ``t_dot`` is produced in ``message``, summed per node in ``aggregate`` and
    mixed into the node scalar update through a fourth ``o_proj`` chunk.

    NOTE(review): ``o_proj``, ``t_src_proj`` and ``t_trg_proj`` are created after
    ``ViS_MP.__init__`` has already run ``reset_parameters()``. ``o_proj`` is
    re-initialised only if an owner calls ``reset_parameters()`` again, as
    ``ViSNetBlock`` does. The t-projections are never touched by it and keep
    PyTorch's default ``nn.Linear`` init.
    """

    def __init__(
        self,
        num_heads,
        hidden_channels,
        activation,
        attn_activation,
        cutoff,
        vecnorm_type,
        trainable_vecnorm,
        last_layer=False,
    ):
        """Same arguments as ``ViS_MP``; adds t-projections and widens ``o_proj``."""
        super().__init__(num_heads, hidden_channels, activation, attn_activation, cutoff, vecnorm_type, trainable_vecnorm, last_layer)

        # Used in message(), so they are needed even for the last layer.
        self.t_src_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.t_trg_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)

        # Split into (o1, o2, o3, o4) in forward.
        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 4)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij):
        """Like ``ViS_MP.forward`` but adds ``t_dot * o3`` to the scalar update.

        Returns ``(dx, dvec, df_ij)``; ``df_ij`` is ``None`` for the last layer.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1)
        vec_dot = (vec1 * vec2).sum(dim=1)

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor, dv: Tensor, vec: Tensor, r_ij: Tensor, d_ij: Tensor)
        x, vec_out, t_dot = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            dv=dv,
            vec=vec,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        o1, o2, o3, o4 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + t_dot * o3 + o4
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            # edge_updater_type: (vec: Tensor, d_ij: Tensor, f_ij: Tensor)
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None

    def edge_update(self, vec_i, vec_j, d_ij, f_ij):
        """Same edge update as ``ViS_MP.edge_update`` (``act(f_proj(f_ij)) * w_dot``)."""
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)
        df_ij = self.act(self.f_proj(f_ij)) * w_dot
        return df_ij

    def message(self, q_i, k_j, v_j, vec_i, vec_j, dk, dv, r_ij, d_ij):
        """Message as in ``ViS_MP.message`` plus a per-edge ``t_dot`` term.

        ``t_dot`` is the dot product (over dim 1) of two projections of
        ``vec_i``, each with its component along ``d_ij`` removed.
        """
        attn = (q_i * k_j * dk).sum(dim=-1)
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        v_j = v_j * dv
        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)

        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)
        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)
        t_dot = (t1 * t2).sum(dim=1)

        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels, dim=1)
        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)

        return v_j, vec_j, t_dot

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter all three message tensors (x, vec, t_dot) onto their ``index`` nodes.

        NOTE(review): the inherited ``update`` is still annotated as a 2-tuple
        although it passes these three tensors through unchanged.
        """
        x, vec, t_dot = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        t_dot = scatter(t_dot, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec, t_dot
# ``--vertex-type`` value -> ViS-MP layer class used by ViSNetBlock.
VIS_MP_MAP = {
    "Node": ViS_MP_Vertex_Node,
    "Edge": ViS_MP_Vertex_Edge,
    "None": ViS_MP,
}
def create_model(args, prior_model=None, mean=None, std=None):
    """Build a ``ViSNet`` model from a hyper-parameter mapping.

    Args:
        args: mapping read with ``args["key"]`` (e.g. Lightning ``hparams``)
            holding the ViSNetBlock hyper-parameters plus ``model``,
            ``prior_model``, ``output_model``, ``reduce_op`` and ``derivative``.
            ``prior_args`` is required when a prior has to be instantiated.
        prior_model: ready-made prior module. When ``None`` and
            ``args["prior_model"]`` is set, the prior is instantiated from
            ``args["prior_args"]`` (e.g. when restoring from a checkpoint).
        mean: optional target mean forwarded to ``ViSNet``.
        std: optional target standard deviation forwarded to ``ViSNet``.

    Returns:
        ViSNet: representation network plus an ``Equivariant<output_model>`` head.

    Raises:
        ValueError: if ``args["model"]`` is not ``"ViSNetBlock"``.
        AssertionError: if a requested prior is unknown or ``prior_args`` is missing.
    """
    # (ViSNetBlock keyword, hyper-parameter key), read in this fixed order.
    arg_map = (
        ("lmax", "lmax"),
        ("vecnorm_type", "vecnorm_type"),
        ("trainable_vecnorm", "trainable_vecnorm"),
        ("num_heads", "num_heads"),
        ("num_layers", "num_layers"),
        ("hidden_channels", "embedding_dimension"),
        ("num_rbf", "num_rbf"),
        ("rbf_type", "rbf_type"),
        ("trainable_rbf", "trainable_rbf"),
        ("activation", "activation"),
        ("attn_activation", "attn_activation"),
        ("max_z", "max_z"),
        ("cutoff", "cutoff"),
        ("max_num_neighbors", "max_num_neighbors"),
        ("vertex_type", "vertex_type"),
    )
    block_kwargs = {kwarg: args[key] for kwarg, key in arg_map}

    # Representation network: ViSNetBlock is the only supported choice.
    if args["model"] != "ViSNetBlock":
        raise ValueError(f"Unknown model {args['model']}.")
    representation_model = ViSNetBlock(**block_kwargs)

    # Prior model: instantiate only when requested and not supplied by the caller.
    if args["prior_model"] and prior_model is None:
        assert "prior_args" in args, (
            f"Requested prior model {args['prior_model']} but the "
            'arguments are lacking the key "prior_args".'
        )
        assert hasattr(priors, args["prior_model"]), (
            f'Unknown prior model {args["prior_model"]}. '
            f'Available models are {", ".join(priors.__all__)}'
        )
        prior_cls = getattr(priors, args["prior_model"])
        prior_model = prior_cls(**args["prior_args"])

    # Output head, always one of the "Equivariant*" modules.
    head_cls = getattr(output_modules, "Equivariant" + args["output_model"])
    output_model = head_cls(args["embedding_dimension"], args["activation"])

    return ViSNet(
        representation_model,
        output_model,
        prior_model=prior_model,
        reduce_op=args["reduce_op"],
        mean=mean,
        std=std,
        derivative=args["derivative"],
    )
def load_model(filepath, args=None, device="cpu", **kwargs):
    """Rebuild a ``ViSNet`` from a Lightning checkpoint file.

    Args:
        filepath: path to the checkpoint (loaded onto the CPU first).
        args: hyper-parameter mapping; defaults to the checkpoint's
            ``"hyper_parameters"``.
        device: device the returned model is moved to.
        **kwargs: hyper-parameter overrides written into ``args`` (in place).
            Keys not already present trigger a warning but are still applied.

    Returns:
        The model with the checkpoint weights loaded, on ``device``.

    NOTE(review): ``torch.load`` unpickles arbitrary objects; only load trusted
    checkpoints.
    """
    checkpoint = torch.load(filepath, map_location="cpu")
    hparams = checkpoint["hyper_parameters"] if args is None else args

    for name, value in kwargs.items():
        if name not in hparams:
            rank_zero_warn(f"Unknown hyperparameter: {name}={value}")
        hparams[name] = value

    model = create_model(hparams)

    # The LightningModule stores the network as ``self.model``; strip that prefix.
    lightning_prefix = re.compile(r"^model\.")
    state_dict = {
        lightning_prefix.sub("", key): tensor
        for key, tensor in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)

    return model.to(device)
class ViSNet(nn.Module):
    """Full ViSNet model: representation network, output head and optional prior.

    ``forward`` returns ``(out, forces)``. ``out`` holds one value per molecule:
    atom contributions scaled by ``std``, optionally passed through the prior,
    reduced with ``reduce_op``, post-processed by the head and shifted by ``mean``.
    ``forces`` is ``-d out / d pos`` when ``derivative`` is enabled, else ``None``.
    """

    def __init__(
        self,
        representation_model,
        output_model,
        prior_model=None,
        reduce_op="add",
        mean=None,
        std=None,
        derivative=False,
    ):
        """Store the sub-modules and register the ``mean``/``std`` buffers.

        ``mean`` and ``std`` default to the scalar tensors 0 and 1. A given
        ``prior_model`` is dropped, with a warning, when
        ``output_model.allow_prior_model`` is falsy.
        """
        super().__init__()
        self.representation_model = representation_model
        self.output_model = output_model

        # Assign first (registers the sub-module slot), then clear it if the head
        # does not accept priors.
        self.prior_model = prior_model
        if not (output_model.allow_prior_model or prior_model is None):
            self.prior_model = None
            rank_zero_warn(
                "Prior model was given but the output model does "
                "not allow prior models. Dropping the prior model."
            )

        self.reduce_op = reduce_op
        self.derivative = derivative

        # Buffers follow .to(device) and are stored in the state_dict.
        for buffer_name, value, fallback in (("mean", mean, 0), ("std", std, 1)):
            self.register_buffer(
                buffer_name,
                torch.scalar_tensor(fallback) if value is None else value,
            )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the representation network, the head and any prior."""
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        prior = self.prior_model
        if prior is not None:
            prior.reset_parameters()

    def forward(self, data: Data) -> Tuple[Tensor, Optional[Tensor]]:
        """Predict per-molecule outputs and, optionally, forces from ``data``."""
        if self.derivative:
            # Forces are gradients w.r.t. positions, so track them.
            data.pos.requires_grad_(True)

        node_x, node_vec = self.representation_model(data)
        atom_out = self.output_model.pre_reduce(node_x, node_vec, data.z, data.pos, data.batch)
        atom_out = atom_out * self.std

        if self.prior_model is not None:
            atom_out = self.prior_model(atom_out, data.z)

        mol_out = scatter(atom_out, data.batch, dim=0, reduce=self.reduce_op)
        mol_out = self.output_model.post_reduce(mol_out) + self.mean

        if not self.derivative:
            return mol_out, None

        # compute gradients with respect to coordinates
        grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(mol_out)]
        pos_grad = grad(
            [mol_out],
            [data.pos],
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]
        if pos_grad is None:
            raise RuntimeError("Autograd returned None for the force prediction.")
        return mol_out, -pos_grad
class LNNP(LightningModule):
    """Lightning wrapper that trains and evaluates a ``ViSNet`` model.

    Handles optimiser/scheduler setup, linear LR warm-up, optional exponential
    smoothing of per-batch losses, per-epoch loss logging and collection of
    test-set predictions in ``self.inference_results``.
    """

    def __init__(self, hparams, prior_model=None, mean=None, std=None):
        """Save ``hparams`` and build, or load, the wrapped model.

        When ``hparams.load_model`` is set, the model is rebuilt from that
        checkpoint and ``prior_model``/``mean``/``std`` are not used.
        """
        super(LNNP, self).__init__()

        self.save_hyperparameters(hparams)

        if self.hparams.load_model:
            self.model = load_model(self.hparams.load_model, args=self.hparams)
        else:
            self.model = create_model(self.hparams, prior_model, mean, std)

        self._reset_losses_dict()
        self._reset_ema_dict()
        self._reset_inference_results()

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau on ``val_loss``, stepped once per epoch."""
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.hparams.lr_factor,
            patience=self.hparams.lr_patience,
            min_lr=self.hparams.lr_min,
        )
        lr_scheduler = {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler]

    def forward(self, data):
        """Delegate to the wrapped ``ViSNet``; returns ``(pred, deriv)``."""
        return self.model(data)

    def training_step(self, batch, batch_idx):
        """Training loss: MSE when ``loss_type == 'MSE'``, otherwise L1."""
        loss_fn = mse_loss if self.hparams.loss_type == 'MSE' else l1_loss
        return self.step(batch, loss_fn, "train")

    def validation_step(self, batch, batch_idx, *args):
        """Dataloader index 0 (or none) is scored as validation (MSE), others as test (L1).

        Presumably the DataModule appends the test set as an extra validation
        dataloader on test epochs. TODO confirm.
        """
        if len(args) == 0 or (len(args) > 0 and args[0] == 0):
            # validation step
            return self.step(batch, mse_loss, "val")
        # test step
        return self.step(batch, l1_loss, "test")

    def test_step(self, batch, batch_idx):
        """Test loss is always L1."""
        return self.step(batch, l1_loss, "test")

    def step(self, batch, loss_fn, stage):
        """Compute and record the weighted energy/force loss for one batch.

        Args:
            batch: batch with ``y`` and/or ``dy`` targets.
            loss_fn: ``mse_loss`` or ``l1_loss``.
            stage: ``"train"``, ``"val"`` or ``"test"``; selects the loss buckets.

        Returns:
            ``energy_weight * loss_y + force_weight * loss_dy``.

        NOTE(review): with ``derivative`` off and no ``y`` in the batch, ``loss``
        stays a Python number and ``loss.detach()`` raises AttributeError.
        """
        # Gradients are needed for training, and whenever forces are obtained by
        # differentiating the prediction w.r.t. positions.
        with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
            pred, deriv = self(batch)
            if stage == "test":
                # Keep CPU copies of predictions/targets for test_epoch_end.
                self.inference_results['y_pred'].append(pred.squeeze(-1).detach().cpu())
                self.inference_results['y_true'].append(batch.y.squeeze(-1).detach().cpu())
                if self.hparams.derivative:
                    self.inference_results['dy_pred'].append(deriv.squeeze(-1).detach().cpu())
                    self.inference_results['dy_true'].append(batch.dy.squeeze(-1).detach().cpu())

        loss_y, loss_dy = 0, 0
        if self.hparams.derivative:
            if "y" not in batch:
                # Tie ``pred`` into the graph with zero weight, presumably so every
                # parameter takes part in backward (main uses DDP with
                # find_unused_parameters=False).
                deriv = deriv + pred.sum() * 0

            loss_dy = loss_fn(deriv, batch.dy)

            if stage in ["train", "val"] and self.hparams.loss_scale_dy < 1:
                # The first batch seeds the running value.
                if self.ema[stage + "_dy"] is None:
                    self.ema[stage + "_dy"] = loss_dy.detach()
                # apply exponential smoothing over batches to dy
                loss_dy = (
                    self.hparams.loss_scale_dy * loss_dy
                    + (1 - self.hparams.loss_scale_dy) * self.ema[stage + "_dy"]
                )
                self.ema[stage + "_dy"] = loss_dy.detach()

            if self.hparams.force_weight > 0:
                self.losses[stage + "_dy"].append(loss_dy.detach())

        if "y" in batch:
            # Match the (batch, 1) shape of ``pred``.
            if batch.y.ndim == 1:
                batch.y = batch.y.unsqueeze(1)

            loss_y = loss_fn(pred, batch.y)

            if stage in ["train", "val"] and self.hparams.loss_scale_y < 1:
                if self.ema[stage + "_y"] is None:
                    self.ema[stage + "_y"] = loss_y.detach()
                # apply exponential smoothing over batches to y
                loss_y = (
                    self.hparams.loss_scale_y * loss_y
                    + (1 - self.hparams.loss_scale_y) * self.ema[stage + "_y"]
                )
                self.ema[stage + "_y"] = loss_y.detach()

            if self.hparams.energy_weight > 0:
                self.losses[stage + "_y"].append(loss_y.detach())

        loss = loss_y * self.hparams.energy_weight + loss_dy * self.hparams.force_weight

        self.losses[stage].append(loss.detach())

        return loss

    def optimizer_step(self, *args, **kwargs):
        """Apply linear LR warm-up over ``lr_warmup_steps``, then step and zero grads.

        Assumes the PL 1.x positional signature, where the optimizer is the third
        positional argument when not passed by keyword. Confirm for the installed
        Lightning version.
        """
        optimizer = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
        if self.trainer.global_step < self.hparams.lr_warmup_steps:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / float(self.hparams.lr_warmup_steps))
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.lr
        super().optimizer_step(*args, **kwargs)
        optimizer.zero_grad()

    def training_epoch_end(self, training_step_outputs):
        """Reset the validation dataloaders around epochs where the test set is evaluated.

        Presumably the DataModule decides per epoch whether to append the test
        loader, and resetting makes Lightning query it again. NOTE: this relies
        on the private PL API ``_reset_dl_batch_idx``.
        """
        dm = self.trainer.datamodule
        if hasattr(dm, "test_dataset") and len(dm.test_dataset) > 0:
            delta = 0 if self.hparams.reload == 1 else 1
            should_reset = (
                (self.current_epoch + delta + 1) % self.hparams.test_interval == 0
                or ((self.current_epoch + delta) % self.hparams.test_interval == 0 and self.current_epoch != 0)
            )
            if should_reset:
                self.trainer.reset_val_dataloader()
                self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop._reset_dl_batch_idx(len(self.trainer.val_dataloaders))

    def validation_epoch_end(self, validation_step_outputs):
        """Log epoch-mean losses (skipped during sanity check), then clear the buffers."""
        if not self.trainer.sanity_checking:
            result_dict = {
                "epoch": float(self.current_epoch),
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                "train_loss": torch.stack(self.losses["train"]).mean(),
                "val_loss": torch.stack(self.losses["val"]).mean(),
            }

            # add test loss if available
            if len(self.losses["test"]) > 0:
                result_dict["test_loss"] = torch.stack(self.losses["test"]).mean()

            # if prediction and derivative are present, also log them separately
            if len(self.losses["train_y"]) > 0 and len(self.losses["train_dy"]) > 0:
                result_dict["train_loss_y"] = torch.stack(self.losses["train_y"]).mean()
                result_dict["train_loss_dy"] = torch.stack(self.losses["train_dy"]).mean()
                result_dict["val_loss_y"] = torch.stack(self.losses["val_y"]).mean()
                result_dict["val_loss_dy"] = torch.stack(self.losses["val_dy"]).mean()

                if len(self.losses["test_y"]) > 0 and len(self.losses["test_dy"]) > 0:
                    result_dict["test_loss_y"] = torch.stack(self.losses["test_y"]).mean()
                    result_dict["test_loss_dy"] = torch.stack(self.losses["test_dy"]).mean()

            self.log_dict(result_dict, sync_dist=True)

        self._reset_losses_dict()
        self._reset_inference_results()

    def test_epoch_end(self, outputs) -> None:
        """Concatenate the collected per-batch test tensors into one tensor per key."""
        for key in self.inference_results.keys():
            if len(self.inference_results[key]) > 0:
                self.inference_results[key] = torch.cat(self.inference_results[key], dim=0)

    def _reset_losses_dict(self):
        """Empty per-stage buckets for total, energy (_y) and force (_dy) losses."""
        self.losses = {
            "train": [], "val": [], "test": [],
            "train_y": [], "val_y": [], "test_y": [],
            "train_dy": [], "val_dy": [], "test_dy": [],
        }

    def _reset_inference_results(self):
        """Empty buffers for test-set predictions and targets."""
        self.inference_results = {'y_pred': [], 'y_true': [], 'dy_pred': [], 'dy_true': []}

    def _reset_ema_dict(self):
        """Clear the running (smoothed) losses for train/val."""
        self.ema = {"train_y": None, "val_y": None, "train_dy": None, "val_dy": None}
def get_args():
    """Parse the command-line options for training or inference.

    Side effects: creates ``log_dir`` if needed, optionally redirects
    ``sys.stdout``/``sys.stderr`` to ``<log_dir>/log`` (``--redirect``), and
    writes the resolved options to ``<log_dir>/input.yaml``.

    Returns:
        argparse.Namespace: parsed options. ``inference_batch_size`` falls back
        to ``batch_size`` when not given.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint')  # keep first
    parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file')  # keep second

    # training settings
    parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs')
    parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
    parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
    parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which the learning rate is reduced when the validation loss plateaus')
    parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
    parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement')
    parser.add_argument('--loss-type', type=str, default='MSE', choices=['MSE', 'MAE'], help='Loss type')
    parser.add_argument('--loss-scale-y', type=float, default=1.0, help="Scale the loss y of the target")
    parser.add_argument('--loss-scale-dy', type=float, default=1.0, help="Scale the loss dy of the target")
    parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function')
    parser.add_argument('--force-weight', default=1.0, type=float, help='Weighting factor for forces in the loss function')

    # dataset specific
    parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
    parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset argument')
    parser.add_argument('--dataset-root', default=None, type=str, help='Data storage directory')
    parser.add_argument('--derivative', default=False, action=argparse.BooleanOptionalAction, help='If true, take the derivative of the prediction w.r.t coordinates')
    parser.add_argument('--split-mode', default=None, type=str, help='Split mode for Molecule3D dataset')

    # dataloader specific
    parser.add_argument('--reload', type=int, default=0, help='Reload dataloaders every n epoch')
    parser.add_argument('--batch-size', default=32, type=int, help='batch size')
    parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
    parser.add_argument('--standardize', action=argparse.BooleanOptionalAction, default=False, help='If true, multiply prediction by dataset std and add mean')
    parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
    parser.add_argument('--train-size', type=number, default=950, help='Percentage/number of samples in training set (None to use all remaining samples)')
    parser.add_argument('--val-size', type=number, default=50, help='Percentage/number of samples in validation set (None to use all remaining samples)')
    parser.add_argument('--test-size', type=number, default=None, help='Percentage/number of samples in test set (None to use all remaining samples)')
    parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch')

    # model architecture specific
    parser.add_argument('--model', type=str, default='ViSNetBlock', choices=models.__all__, help='Which model to train')
    parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
    parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
    parser.add_argument('--prior-args', type=dict, default=None, help='Additional arguments for the prior model')

    # architectural specific
    parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
    parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
    parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
    parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
    parser.add_argument('--rbf-type', type=str, default='expnorm', choices=list(rbf_class_mapping.keys()), help='Type of distance expansion')
    parser.add_argument('--trainable-rbf', action=argparse.BooleanOptionalAction, default=False, help='If distance expansion functions should be trainable')
    parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--cutoff', type=float, default=5.0, help='Cutoff in model')
    parser.add_argument('--max-z', type=int, default=100, help='Maximum atomic number that fits in the embedding matrix')
    parser.add_argument('--max-num-neighbors', type=int, default=32, help='Maximum number of neighbors to consider in the network')
    parser.add_argument('--reduce-op', type=str, default='add', choices=['add', 'mean'], help='Reduce operation to apply to atomic predictions')
    parser.add_argument('--lmax', type=int, default=2, help='Max order of spherical harmonics')
    parser.add_argument('--vecnorm-type', type=str, default='max_min', help='Type of vector normalization')
    parser.add_argument('--trainable-vecnorm', action=argparse.BooleanOptionalAction, default=False, help='If vector normalization should be trainable')
    parser.add_argument('--vertex-type', type=str, default='Edge', choices=['None', 'Edge', 'Node'], help='If add vertex angle and Where to add vertex angles')

    # other specific
    parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
    parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
    parser.add_argument('--log-dir', type=str, default="aspirin_log", help='Log directory')
    parser.add_argument('--task', type=str, default='train', choices=['train', 'inference'], help='Train or inference')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend')
    parser.add_argument('--redirect', action=argparse.BooleanOptionalAction, default=False, help='Redirect stdout and stderr to log_dir/log')
    parser.add_argument('--accelerator', default='gpu', help='Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto")')
    parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)')
    parser.add_argument('--save-interval', type=int, default=2, help='Save interval, one save per n epochs (default: 2)')
    parser.add_argument("--out_dir", type=str, default="run_0")

    args = parser.parse_args()

    # save_argparse() below writes into log_dir, so it must exist whether or not
    # output is redirected (previously it was only created with --redirect).
    os.makedirs(args.log_dir, exist_ok=True)

    if args.redirect:
        # Deliberately left open: it backs stdout/stderr for the rest of the run.
        sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
        sys.stderr = sys.stdout
        logging.getLogger("pytorch_lightning").addHandler(logging.StreamHandler(sys.stdout))

    if args.inference_batch_size is None:
        args.inference_batch_size = args.batch_size

    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

    return args
def main(args):
    """Train and/or evaluate ViSNet, then write summary metrics.

    Checkpoints and CSV logs go to a run directory under ``args.out_dir`` derived
    from the hyper-parameters, unless ``args.load_model`` is given. Test-set
    predictions are saved to ``<log_dir>/inference_results.pt`` and the MAE
    summary to ``<out_dir>/final_info.json``.

    Args:
        args: ``argparse.Namespace`` produced by ``get_args``; ``log_dir``,
            ``load_model`` and ``prior_args`` may be updated in place.
    """
    pl.seed_everything(args.seed, workers=True)

    # initialize data module
    data = DataModule(args)
    data.prepare_dataset()

    # The run directory name encodes the visible GPU count and key hyper-parameters.
    default_devices = ",".join(str(i) for i in range(torch.cuda.device_count()))
    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default_devices).split(",")
    dir_name = f"output_ngpus_{len(cuda_visible_devices)}_bs_{args.batch_size}_lr_{args.lr}_seed_{args.seed}" + \
        f"_reload_{args.reload}_lmax_{args.lmax}_vnorm_{args.vecnorm_type}" + \
        f"_vertex_{args.vertex_type}_L{args.num_layers}_D{args.embedding_dimension}_H{args.num_heads}" + \
        f"_cutoff_{args.cutoff}_E{args.energy_weight}_F{args.force_weight}_loss_{args.loss_type}"

    if args.load_model is None:
        args.log_dir = os.path.join(args.out_dir, args.log_dir, dir_name)
        if os.path.exists(args.log_dir):
            # Resume automatically from the last checkpoint of an earlier run.
            if os.path.exists(os.path.join(args.log_dir, "last.ckpt")):
                args.load_model = os.path.join(args.log_dir, "last.ckpt")
            # Move an existing metrics.csv aside (metrics.csv.bak, .bak.bak, ...)
            # so the CSV logger starts a fresh file.
            csv_path = os.path.join(args.log_dir, "metrics.csv")
            while os.path.exists(csv_path):
                csv_path = csv_path + '.bak'
            if os.path.exists(os.path.join(args.log_dir, "metrics.csv")):
                os.rename(os.path.join(args.log_dir, "metrics.csv"), csv_path)

    prior = None
    if args.prior_model:
        # ``args`` is an argparse.Namespace: use attribute access. Subscripting it
        # here raised TypeError and hid the real assertion message.
        assert hasattr(priors, args.prior_model), (
            f"Unknown prior model {args.prior_model}. "
            f"Available models are {', '.join(priors.__all__)}"
        )
        # initialize the prior model
        prior = getattr(priors, args.prior_model)(dataset=data.dataset)
        args.prior_args = prior.get_init_args()

    # initialize lightning module
    model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)

    if args.task == "train":
        checkpoint_callback = ModelCheckpoint(
            dirpath=args.log_dir,
            monitor="val_loss",
            save_top_k=2,
            save_last=True,
            every_n_epochs=args.save_interval,
            filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
        )
        early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)

        tb_logger = TensorBoardLogger(os.getenv("TENSORBOARD_LOG_PATH", "/tensorboard_logs/"), name="", version="", default_hp_metric=False)
        csv_logger = CSVLogger(args.log_dir, name="", version="")
        ddp_plugin = DDPStrategy(find_unused_parameters=False)

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            gpus=args.ngpus,
            num_nodes=args.num_nodes,
            accelerator=args.accelerator,
            default_root_dir=args.log_dir,
            auto_lr_find=False,
            callbacks=[early_stopping, checkpoint_callback],
            logger=[tb_logger, csv_logger],
            reload_dataloaders_every_n_epochs=args.reload,
            precision=args.precision,
            strategy=ddp_plugin,
            enable_progress_bar=True,
        )

        trainer.fit(model, datamodule=data, ckpt_path=args.load_model)

    # Single-device evaluation. inference_mode=False keeps autograd available,
    # which the force (derivative) computation needs.
    test_trainer = pl.Trainer(
        logger=False,
        max_epochs=-1,
        num_nodes=1,
        gpus=1,
        default_root_dir=args.log_dir,
        enable_progress_bar=True,
        inference_mode=False,
    )

    if args.task == 'train':
        test_trainer.test(model=model, ckpt_path=trainer.checkpoint_callback.best_model_path, datamodule=data)
    elif args.task == 'inference':
        test_trainer.test(model=model, datamodule=data)

    torch.save(model.inference_results, os.path.join(args.log_dir, "inference_results.pt"))

    emae = calculate_mae(model.inference_results['y_true'].numpy(), model.inference_results['y_pred'].numpy())
    Scalar_MAE = "{:.6f}".format(emae)
    print('Scalar MAE: {:.6f}'.format(emae))

    final_infos = {
        "AutoMolecule3D": {
            "means": {
                "Scalar MAE": Scalar_MAE
            }
        }
    }

    if args.derivative:
        fmae = calculate_mae(model.inference_results['dy_true'].numpy(), model.inference_results['dy_pred'].numpy())
        Forces_MAE = "{:.6f}".format(fmae)
        print('Forces MAE: {:.6f}'.format(fmae))
        final_infos["AutoMolecule3D"]["means"]["Forces MAE"] = Forces_MAE

    # out_dir is not guaranteed to exist, e.g. when resuming from an explicit
    # --load-model, where log_dir does not live under out_dir.
    os.makedirs(args.out_dir, exist_ok=True)
    with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)
if __name__ == "__main__":
    # Only needed for the crash log below; it is not imported at the top of the file.
    import traceback

    args = get_args()
    try:
        main(args)
    except Exception:
        print("Origin error in main process:", flush=True)
        # Write the traceback to <out_dir>/traceback.log, closing the file
        # properly, then re-raise the original exception.
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as log_file:
            traceback.print_exc(file=log_file)
        raise