|
import json |
|
import os.path as osp |
|
from multiprocessing import Pool |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector |
|
from rdkit import Chem |
|
from torch_geometric.data import Data, InMemoryDataset |
|
from tqdm import tqdm |
|
|
|
|
|
class Molecule3D(InMemoryDataset): |
|
|
|
def __init__( |
|
self, |
|
root, |
|
transform=None, |
|
pre_transform=None, |
|
pre_filter=None, |
|
**kwargs, |
|
): |
|
|
|
self.root = root |
|
super(Molecule3D, self).__init__(root, transform, pre_transform, pre_filter) |
|
self.data, self.slices = torch.load(self.processed_paths[0]) |
|
|
|
@property |
|
def processed_file_names(self): |
|
return 'molecule3d.pt' |
|
|
|
def process(self): |
|
|
|
data_list = [] |
|
sdf_paths = [ |
|
osp.join(self.raw_dir, 'combined_mols_0_to_1000000.sdf'), |
|
osp.join(self.raw_dir, 'combined_mols_1000000_to_2000000.sdf'), |
|
osp.join(self.raw_dir, 'combined_mols_2000000_to_3000000.sdf'), |
|
osp.join(self.raw_dir, 'combined_mols_3000000_to_3899647.sdf') |
|
] |
|
suppl_list = [Chem.SDMolSupplier(p, removeHs=False, sanitize=True) for p in sdf_paths] |
|
|
|
|
|
target_path = osp.join(self.raw_dir, 'properties.csv') |
|
target_df = pd.read_csv(target_path) |
|
|
|
abs_idx = -1 |
|
|
|
for i, suppl in enumerate(suppl_list): |
|
with Pool(processes=120) as pool: |
|
iter = pool.imap(self.mol2graph, suppl) |
|
for j, graph in tqdm(enumerate(iter), total=len(suppl)): |
|
abs_idx += 1 |
|
|
|
data = Data() |
|
data.__num_nodes__ = int(graph['num_nodes']) |
|
|
|
|
|
data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64) |
|
data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64) |
|
data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) |
|
data.y = torch.FloatTensor([target_df.iloc[abs_idx, 6]]).unsqueeze(1) |
|
|
|
|
|
data.pos = torch.tensor(graph['position'], dtype=torch.float32) |
|
data.z = torch.tensor(graph['z'], dtype=torch.int64) |
|
data_list.append(data) |
|
|
|
torch.save(self.collate(data_list), self.processed_paths[0]) |
|
|
|
def get_idx_split(self, split_mode='random'): |
|
assert split_mode in ['random', 'scaffold'] |
|
split_dict = json.load(open(osp.join(self.raw_dir, f'{split_mode}_split_inds.json'), 'r')) |
|
for key, values in split_dict.items(): |
|
split_dict[key] = torch.tensor(values) |
|
return split_dict |
|
|
|
def mol2graph(self, mol): |
|
|
|
atom_features_list = [] |
|
for atom in mol.GetAtoms(): |
|
atom_features_list.append(atom_to_feature_vector(atom)) |
|
x = np.array(atom_features_list, dtype = np.int64) |
|
|
|
coords = mol.GetConformer().GetPositions() |
|
z = [atom.GetAtomicNum() for atom in mol.GetAtoms()] |
|
|
|
|
|
num_bond_features = 3 |
|
if len(mol.GetBonds()) > 0: |
|
edges_list = [] |
|
edge_features_list = [] |
|
for bond in mol.GetBonds(): |
|
i = bond.GetBeginAtomIdx() |
|
j = bond.GetEndAtomIdx() |
|
|
|
edge_feature = bond_to_feature_vector(bond) |
|
|
|
|
|
edges_list.append((i, j)) |
|
edge_features_list.append(edge_feature) |
|
edges_list.append((j, i)) |
|
edge_features_list.append(edge_feature) |
|
|
|
|
|
edge_index = np.array(edges_list, dtype = np.int64).T |
|
|
|
|
|
edge_attr = np.array(edge_features_list, dtype = np.int64) |
|
|
|
else: |
|
edge_index = np.empty((2, 0), dtype = np.int64) |
|
edge_attr = np.empty((0, num_bond_features), dtype = np.int64) |
|
|
|
graph = dict() |
|
graph['edge_index'] = edge_index |
|
graph['edge_feat'] = edge_attr |
|
graph['node_feat'] = x |
|
graph['num_nodes'] = len(x) |
|
graph['position'] = coords |
|
graph['z'] = z |
|
|
|
return graph |
|
|