{ "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jie8WKDQi0yq", "outputId": "75e42623-a543-4745-b6d6-ff9a78880de9" },
   "outputs": [],
   "source": [
    "# Versions pinned to those this notebook was last run with (Colab, Python 3.11).\n",
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -q torch-geometric==2.6.1 rdkit==2025.3.3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "wI8TNU1g12RQ" },
   "source": [
    "## Odor-conditioned EGNN diffusion with 3D coordinates, bond-type prediction, and a time embedding\n",
    "\n",
    "**Note:** no attention layers are implemented in this version of the model."
   ]
  },
  { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtMdfjZm144m", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "f5da6dc0-4b54-436b-d39c-d54fead8c10c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 0: Loss = 397.1285, Noise Loss = 113.3167, Bond Loss = 283.8118\n", "Epoch 1: Loss = 266.2317, Noise Loss = 68.7714, Bond Loss = 197.4603\n", "Epoch 2: Loss = 221.7601, Noise Loss = 56.5115, Bond Loss = 165.2486\n", "Epoch 3: Loss = 190.6335, Noise Loss = 47.5783, Bond Loss = 143.0552\n", "Epoch 4: Loss = 163.8749, Noise Loss = 39.0206, Bond Loss = 124.8543\n", "Epoch 5: Loss = 149.0785, Noise Loss = 31.0315, Bond Loss = 118.0470\n", "Epoch 6: Loss = 147.7367, Noise Loss = 33.4656, Bond Loss = 114.2711\n", "Epoch 7: Loss = 141.2191, Noise Loss = 29.1597, Bond Loss = 112.0594\n", "Epoch 8: Loss = 130.4628, Noise Loss = 22.0213, Bond
Loss = 108.4415\n", "Epoch 9: Loss = 126.8406, Noise Loss = 22.6980, Bond Loss = 104.1426\n", "Epoch 10: Loss = 124.9823, Noise Loss = 23.6870, Bond Loss = 101.2953\n" ] } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import random\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem, Descriptors\n",
    "import math\n",
    "\n",
    "# -------- UTILS: Molecule Processing with 3D Coordinates --------\n",
    "# RDKit bond type -> integer class used as the bond-prediction target.\n",
    "# Any other bond type falls back to class 0 (single).\n",
    "BOND_TYPE_TO_CLASS = {\n",
    "    Chem.BondType.SINGLE: 0,\n",
    "    Chem.BondType.DOUBLE: 1,\n",
    "    Chem.BondType.TRIPLE: 2,\n",
    "    Chem.BondType.AROMATIC: 3,\n",
    "}\n",
    "\n",
    "\n",
    "def smiles_to_graph(smiles):\n",
    "    \"\"\"Convert a SMILES string into a PyG `Data` graph with a 3D conformer.\n",
    "\n",
    "    Explicit hydrogens are added; a conformer is embedded with ETKDG and\n",
    "    relaxed with UFF.\n",
    "\n",
    "    Returns a `Data` object with\n",
    "      x          -- (num_atoms, 1) float: atomic number / 100\n",
    "      pos        -- (num_atoms, 3) float: conformer coordinates\n",
    "      edge_index -- (2, 2 * num_bonds) long: every bond in both directions\n",
    "      edge_attr  -- (2 * num_bonds, 1) long: bond class (BOND_TYPE_TO_CLASS)\n",
    "    or None if the SMILES cannot be parsed or no conformer can be embedded.\n",
    "    \"\"\"\n",
    "    mol = Chem.MolFromSmiles(smiles)\n",
    "    if mol is None:\n",
    "        return None\n",
    "    mol = Chem.AddHs(mol)\n",
    "    try:\n",
    "        # EmbedMolecule reports failure by returning -1 instead of raising;\n",
    "        # without this check GetConformer() below raises outside the try.\n",
    "        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) < 0:\n",
    "            return None\n",
    "        AllChem.UFFOptimizeMolecule(mol)\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "    conf = mol.GetConformer()\n",
    "\n",
    "    node_feats = []\n",
    "    pos = []\n",
    "    for atom in mol.GetAtoms():\n",
    "        # Normalize atomic number into roughly [0, 1].\n",
    "        node_feats.append([atom.GetAtomicNum() / 100.0])\n",
    "        position = conf.GetAtomPosition(atom.GetIdx())\n",
    "        pos.append([position.x, position.y, position.z])\n",
    "\n",
    "    edge_index = []\n",
    "    edge_attrs = []\n",
    "    for bond in mol.GetBonds():\n",
    "        start = bond.GetBeginAtomIdx()\n",
    "        end = bond.GetEndAtomIdx()\n",
    "        bond_class = BOND_TYPE_TO_CLASS.get(bond.GetBondType(), 0)\n",
    "        # Undirected graph: store both directions with the same label.\n",
    "        edge_index.extend([[start, end], [end, start]])\n",
    "        edge_attrs.extend([[bond_class], [bond_class]])\n",
    "\n",
    "    return Data(\n",
    "        x=torch.tensor(node_feats, dtype=torch.float),\n",
    "        pos=torch.tensor(pos, dtype=torch.float),\n",
    "        # view(-1, 2) keeps a (2, 0) shape for bond-free molecules instead of\n",
    "        # collapsing to a 1-D tensor that message passing cannot index.\n",
    "        edge_index=torch.tensor(edge_index, dtype=torch.long).view(-1, 2).t().contiguous(),\n",
    "        edge_attr=torch.tensor(edge_attrs, dtype=torch.long).view(-1, 1),\n",
    "    )\n",
    "\n",
    "# -------- EGNN Layer --------\n",
    "class EGNNLayer(MessagePassing):\n",
    "    \"\"\"E(n)-equivariant message passing over node features and 3D coordinates.\n",
    "\n",
    "    Self-loops are added; for each edge j -> i with d_ij = |pos_j - pos_i|:\n",
    "      feature message    m_ij = node_mlp([x_i, x_j, d_ij])\n",
    "      coordinate message c_ij = coord_mlp(d_ij) * (pos_j - pos_i)\n",
    "    Both are summed per target node i; forward() returns\n",
    "    (x + sum_j m_ij, pos + sum_j c_ij).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, node_dim):\n",
    "        super().__init__(aggr='add')\n",
    "        self.node_mlp = nn.Sequential(\n",
    "            nn.Linear(node_dim * 2 + 1, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, node_dim)\n",
    "        )\n",
    "        self.coord_mlp = nn.Sequential(\n",
    "            nn.Linear(1, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, pos, edge_index):\n",
    "        \"\"\"Return (updated node features, updated coordinates).\"\"\"\n",
    "        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
    "        x_out, coord_out = self.propagate(edge_index, x=x, pos=pos)\n",
    "        return x_out, pos + coord_out\n",
    "\n",
    "    def message(self, x_i, x_j, pos_i, pos_j):\n",
    "        edge_vec = pos_j - pos_i\n",
    "        # The epsilon keeps sqrt differentiable on self-loops, where the distance is 0.\n",
    "        dist = ((edge_vec**2).sum(dim=-1, keepdim=True) + 1e-8).sqrt()\n",
    "        h = torch.cat([x_i, x_j, dist], dim=-1)\n",
    "        edge_msg = self.node_mlp(h)\n",
    "        coord_update = self.coord_mlp(dist) * edge_vec\n",
    "        return edge_msg, coord_update\n",
    "\n",
    "    def message_and_aggregate(self, adj_t, x):\n",
    "        raise NotImplementedError(\"This EGNN layer does not support sparse adjacency matrices.\")\n",
    "\n",
    "    def aggregate(self, inputs, index, dim_size=None):\n",
    "        # Messages are (feature, coordinate) tuples, so the built-in 'add'\n",
    "        # aggregation cannot be applied directly; sum each part per target node.\n",
    "        edge_msg, coord_update = inputs\n",
    "        num_nodes = dim_size if dim_size is not None else int(index.max()) + 1\n",
    "        aggr_msg = edge_msg.new_zeros(num_nodes, edge_msg.size(-1)).index_add_(0, index, edge_msg)\n",
    "        aggr_coord = coord_update.new_zeros(num_nodes, coord_update.size(-1)).index_add_(0, index, coord_update)\n",
    "        return aggr_msg, aggr_coord\n",
    "\n",
    "    def update(self, aggr_out, x):\n",
    "        # Residual feature update; the coordinate delta is added in forward().\n",
    "        msg, coord_update = aggr_out\n",
    "        return x + msg, coord_update\n",
    "\n",
    "# -------- Time Embedding --------\n",
    "class TimeEmbedding(nn.Module):\n",
    "    \"\"\"Small MLP mapping a diffusion step t to an `embed_dim` vector.\"\"\"\n",
    "\n",
    "    def __init__(self, embed_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(1, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, embed_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, t):\n",
    "        # Scale t by 1/1000 so steps in [1, 1000] map into (0, 1].\n",
    "        return self.net(t.view(-1, 1).float() / 1000)\n",
    "\n",
    "# -------- Olfactory Conditioning --------\n",
    "class OlfactoryConditioner(nn.Module):\n",
    "    \"\"\"Linear projection of an odor-descriptor label vector to `embed_dim`.\n",
    "\n",
    "    Labels are presumably 0/1 descriptor flags (see load_goodscents_subset).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, embed_dim):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Linear(num_labels, embed_dim)\n",
    "\n",
    "    def forward(self, labels):\n",
    "        return self.embedding(labels.float())\n",
    "\n",
    "# -------- EGNN Diffusion Model --------\n",
    "class EGNNDiffusionModel(nn.Module):\n",
    "    \"\"\"Two EGNN layers over [node features | condition embedding | time embedding].\n",
    "\n",
    "    forward() returns\n",
    "      * the first node_dim channels of the final node features (used as the\n",
    "        predicted noise), shape (num_nodes, node_dim);\n",
    "      * bond-type logits for every edge in edge_index, shape (num_edges, 4).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, node_dim, embed_dim):\n",
    "        super().__init__()\n",
    "        hidden_dim = node_dim + embed_dim * 2\n",
    "        self.time_embed = TimeEmbedding(embed_dim)\n",
    "        self.egnn1 = EGNNLayer(hidden_dim)\n",
    "        self.egnn2 = EGNNLayer(hidden_dim)\n",
    "        self.bond_predictor = nn.Sequential(\n",
    "            nn.Linear(hidden_dim * 2, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 4)\n",
    "        )\n",
    "\n",
    "    def forward(self, x_t, pos, edge_index, t, cond_embed):\n",
    "        num_nodes = x_t.size(0)\n",
    "        # Broadcast the per-graph time and condition embeddings to every node.\n",
    "        t_embed = self.time_embed(t).expand(num_nodes, -1)\n",
    "        cond_embed = cond_embed.expand(num_nodes, -1)\n",
    "        x_input = torch.cat([x_t, cond_embed, t_embed], dim=1)\n",
    "        x1, pos1 = self.egnn1(x_input, pos, edge_index)\n",
    "        x2, _ = self.egnn2(x1, pos1, edge_index)\n",
    "        # Bond logits from the concatenated final features of both endpoints.\n",
    "        edge_feats = torch.cat([x2[edge_index[0]], x2[edge_index[1]]], dim=1)\n",
    "        bond_logits = self.bond_predictor(edge_feats)\n",
    "        return x2[:, :x_t.shape[1]], bond_logits\n",
    "\n",
    "# -------- Noise and Training --------\n",
    "def add_noise(x_0, noise, t):\n",
    "    \"\"\"Linear noising: x_t = x_0 + noise * t / 1000 (t = 1000 adds full-scale noise).\"\"\"\n",
    "    return x_0 + noise * (t / 1000.0)\n",
    "\n",
    "\n",
    "def plot_data(mu, sigma, color, title):\n",
    "    \"\"\"Plot a per-epoch loss curve with a shaded +/- 1 standard-deviation band.\n",
    "\n",
    "    mu, sigma: equal-length sequences, one entry per epoch.\n",
    "    color: single-letter matplotlib color code such as 'b'.\n",
    "    \"\"\"\n",
    "    mu_arr = np.asarray(mu)\n",
    "    sigma_arr = np.asarray(sigma)\n",
    "    epochs = np.arange(len(mu_arr))\n",
    "    plt.plot(epochs, mu_arr, f'{color}-')\n",
    "    plt.fill_between(epochs, mu_arr - sigma_arr, mu_arr + sigma_arr, color=color, alpha=0.2)\n",
    "    # The shaded band is one standard deviation, not the variance.\n",
    "    plt.legend(['Mean Loss', '+/- 1 Std. Dev. of Loss'])\n",
    "    plt.xlabel('Epoch')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.title(title)\n",
    "    plt.show()\n",
    "\n",
    "\n",
"def train(model, conditioner, dataset, epochs=10):\n", " model.train()\n", " conditioner.train()\n", " optimizer = torch.optim.Adam(list(model.parameters()) + list(conditioner.parameters()), lr=1e-4)\n", " ce_loss = nn.CrossEntropyLoss()\n", " torch.autograd.set_detect_anomaly(True)\n", " all_bond_losses: list = []\n", " all_noise_losses: list = []\n", " all_losses: list = []\n", " all_sigma_bond_losses: list = []\n", " all_sigma_noise_losses: list = []\n", " all_sigma_losses: list = []\n", "\n", " for epoch in range(epochs):\n", " total_bond_loss = 0\n", " total_noise_loss = 0\n", " total_loss = 0\n", " sigma_bond_losses: list = []\n", " sigma_noise_losses: list = []\n", " sigma_losses: list = []\n", "\n", " for data in dataset:\n", " x_0, pos, edge_index, edge_attr, labels = data.x, data.pos, data.edge_index, data.edge_attr.view(-1), data.y\n", " if torch.any(edge_attr >= 4) or torch.any(edge_attr < 0) or torch.any(torch.isnan(x_0)):\n", " continue # skip corrupted data\n", " t = torch.tensor([random.randint(1, 1000)])\n", " noise = torch.randn_like(x_0)\n", " x_t = add_noise(x_0, noise, t)\n", " cond_embed = conditioner(labels)\n", " pred_noise, bond_logits = model(x_t, pos, edge_index, t, cond_embed)\n", " loss_noise = F.mse_loss(pred_noise, noise)\n", " loss_bond = ce_loss(bond_logits, edge_attr)\n", " loss = loss_noise + loss_bond\n", " optimizer.zero_grad()\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " optimizer.step()\n", " total_bond_loss += loss_bond.item()\n", " total_noise_loss += loss_noise.item()\n", " total_loss += loss.item()\n", " sigma_bond_losses.append(loss_bond.item())\n", " sigma_noise_losses.append(loss_noise.item())\n", " sigma_losses.append(loss.item())\n", "\n", " all_bond_losses.append(total_bond_loss)\n", " all_noise_losses.append(total_noise_loss)\n", " all_losses.append(total_loss)\n", " all_sigma_bond_losses.append(torch.std(torch.tensor(sigma_bond_losses)))\n", " 
all_sigma_noise_losses.append(torch.std(torch.tensor(sigma_noise_losses)))\n", " all_sigma_losses.append(torch.std(torch.tensor(sigma_losses)))\n", " print(f\"Epoch {epoch}: Loss = {total_loss:.4f}, Noise Loss = {total_noise_loss:.4f}, Bond Loss = {total_bond_loss:.4f}\")\n", "\n", " plot_data(mu=all_bond_losses, sigma=all_sigma_bond_losses, color='b', title=\"Bond Loss\")\n", " plot_data(mu=all_noise_losses, sigma=all_sigma_noise_losses, color='r', title=\"Noise Loss\")\n", " plot_data(mu=all_losses, sigma=all_sigma_losses, color='g', title=\"Total Loss\")\n", "\n", " plt.plot(all_bond_losses)\n", " plt.plot(all_noise_losses)\n", " plt.plot(all_losses)\n", " plt.legend(['Bond Loss', 'Noise Loss', 'Total Loss'])\n", " plt.xlabel('Epoch')\n", " plt.ylabel('Loss')\n", " plt.title('Training Loss Over Epochs')\n", " plt.show()\n", " return model, conditioner\n", "\n", "\n", "# -------- Generation --------\n", "def temperature_scaled_softmax(logits, temperature=1.0):\n", " logits = logits / temperature\n", " return torch.softmax(logits, dim=0)\n", "\n", "\n", "from rdkit.Chem import Draw\n", "from rdkit import RDLogger\n", "RDLogger.DisableLog('rdApp.*') # Suppress RDKit warnings\n", "\n", "def sample_batch(model, conditioner, label_vec, steps=1000, batch_size=4):\n", " mols = []\n", " for _ in range(batch_size):\n", " x_t = torch.randn((10, 1))\n", " pos = torch.randn((10, 3))\n", " edge_index = torch.randint(0, 10, (2, 20))\n", "\n", " for t in reversed(range(1, steps + 1)):\n", " cond_embed = conditioner(label_vec.unsqueeze(0))\n", " pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)\n", " x_t = x_t - pred_x * (1.0 / steps)\n", "\n", " x_t = x_t * 100.0\n", " x_t.relu_()\n", " atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()\n", " allowed_atoms = [6, 7, 8, 9, 15, 16, 17] # C, N, O, F, P, S, Cl\n", " bond_logits.relu_()\n", "\n", " mol = Chem.RWMol()\n", " idx_map = {}\n", " for i, atomic_num in 
enumerate(atom_types):\n", " if atomic_num not in allowed_atoms:\n", " continue\n", " try:\n", " atom = Chem.Atom(int(atomic_num))\n", " idx_map[i] = mol.AddAtom(atom)\n", " except Exception:\n", " continue\n", "\n", " if len(idx_map) < 2:\n", " continue\n", "\n", " bond_type_map = {\n", " 0: Chem.BondType.SINGLE,\n", " 1: Chem.BondType.DOUBLE,\n", " 2: Chem.BondType.TRIPLE,\n", " 3: Chem.BondType.AROMATIC\n", " }\n", "\n", " added = set()\n", " for i in range(edge_index.shape[1]):\n", " a = int(edge_index[0, i])\n", " b = int(edge_index[1, i])\n", " if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:\n", " try:\n", " bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)\n", " mol.AddBond(idx_map[a], idx_map[b], bond_type)\n", " added.add((a, b))\n", " except Exception:\n", " continue\n", "\n", " try:\n", " mol = mol.GetMol()\n", " Chem.SanitizeMol(mol)\n", " mols.append(mol)\n", " except Exception:\n", " continue\n", " return mols\n", "\n", "\n", "def sample(model, conditioner, label_vec, is_constrained=True, steps=1000, debug=True):\n", " x_t = torch.randn((10, 1))\n", " pos = torch.randn((10, 3))\n", " edge_index = torch.randint(0, 10, (2, 20))\n", "\n", " for t in reversed(range(1, steps + 1)):\n", " cond_embed = conditioner(label_vec.unsqueeze(0))\n", " pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)\n", " bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1/t))\n", " x_t = x_t - pred_x * (1.0 / steps)\n", "\n", " x_t = x_t * 100.0\n", " x_t.relu_()\n", " atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()\n", " ## Try limiting to only the molecules that the Scentience sensors can detect\n", " allowed_atoms = [6, 7, 8, 9, 15, 16, 17] # C, N, O, F, P, S, Cl\n", " bond_logits.relu_()\n", " bond_preds = torch.argmax(bond_logits, dim=-1).tolist()\n", " if debug:\n", " print(f\"\\tcond_embed: {cond_embed}\")\n", " print(f\"\\tx_t: {x_t}\")\n", " 
print(f\"\\tprediction: {x_t}\")\n", " print(f\"\\tbond logits: {bond_logits}\")\n", " print(f\"\\tatoms: {atom_types}\")\n", " print(f\"\\tbonds: {bond_preds}\")\n", "\n", " mol = Chem.RWMol()\n", " idx_map = {}\n", " for i, atomic_num in enumerate(atom_types):\n", " if is_constrained and atomic_num not in allowed_atoms:\n", " continue\n", " try:\n", " atom = Chem.Atom(int(atomic_num))\n", " idx_map[i] = mol.AddAtom(atom)\n", " except Exception:\n", " continue\n", "\n", " if len(idx_map) < 2:\n", " print(\"Molecule too small or no valid atoms after filtering.\")\n", " return \"\"\n", "\n", " bond_type_map = {\n", " 0: Chem.BondType.SINGLE,\n", " 1: Chem.BondType.DOUBLE,\n", " 2: Chem.BondType.TRIPLE,\n", " 3: Chem.BondType.AROMATIC\n", " }\n", "\n", " added = set()\n", " for i in range(edge_index.shape[1]):\n", " a = int(edge_index[0, i])\n", " b = int(edge_index[1, i])\n", " if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:\n", " try:\n", " bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)\n", " mol.AddBond(idx_map[a], idx_map[b], bond_type)\n", " added.add((a, b))\n", " except Exception:\n", " continue\n", " try:\n", " mol = mol.GetMol()\n", " Chem.SanitizeMol(mol)\n", " smiles = Chem.MolToSmiles(mol)\n", " img = Draw.MolToImage(mol)\n", " img.show()\n", " print(f\"Atom types: {atom_types}\")\n", " print(f\"Generated SMILES: {smiles}\")\n", " return smiles\n", " except Exception as e:\n", " print(f\"Sanitization error: {e}\")\n", " return \"\"\n", "\n", "\n", "# -------- Validation --------\n", "def validate_molecule(smiles):\n", " mol = Chem.MolFromSmiles(smiles)\n", " if mol is None:\n", " return False, {}\n", " return True, {\"MolWt\": Descriptors.MolWt(mol), \"LogP\": Descriptors.MolLogP(mol)}\n", "\n", "# -------- Load Data --------\n", "def load_goodscents_subset(filepath=\"/content/curated_GS_LF_merged_4983.csv\",\n", " index=200,\n", " shuffle=True\n", " ):\n", " df = pd.read_csv(filepath)\n", " 
if shuffle:\n", " df = df.sample(frac=1).reset_index(drop=True)\n", " if index > 0:\n", " df = df.head(index)\n", " else:\n", " df = df.tail(-1*index)\n", " descriptor_cols = df.columns[2:]\n", " smiles_list, label_map = [], {}\n", " for _, row in df.iterrows():\n", " smiles = row[\"nonStereoSMILES\"]\n", " labels = row[descriptor_cols].astype(int).tolist()\n", " if smiles and any(labels):\n", " smiles_list.append(smiles)\n", " label_map[smiles] = labels\n", " return smiles_list, label_map, list(descriptor_cols)\n", "\n", "\n", "# -------- Main --------\n", "if __name__ == '__main__':\n", " SHOULD_BATCH: bool = False\n", " smiles_list, label_map, label_names = load_goodscents_subset(index=500)\n", " num_labels = len(label_names)\n", " dataset = []\n", " for smi in smiles_list:\n", " g = smiles_to_graph(smi)\n", " if g:\n", " g.y = torch.tensor(label_map[smi])\n", " dataset.append(g)\n", " model = EGNNDiffusionModel(node_dim=1, embed_dim=8)\n", " conditioner = OlfactoryConditioner(num_labels=num_labels, embed_dim=8)\n", " train_success: bool = False\n", " while not train_success:\n", " try:\n", " model, conditioner = train(model, conditioner, dataset, epochs=100)\n", " train_success = True\n", " break\n", " except IndexError:\n", " print(\"Index Error on training. 
Trying again.\")\n", " test_label_vec = torch.zeros(num_labels)\n", " if \"floral\" in label_names:\n", " test_label_vec[label_names.index(\"floral\")] = 0\n", " if \"fruity\" in label_names:\n", " test_label_vec[label_names.index(\"fruity\")] = 1\n", " if \"musky\" in label_names:\n", " test_label_vec[label_names.index(\"musky\")] = 0\n", "\n", " model.eval()\n", " conditioner.eval()\n", " if SHOULD_BATCH:\n", " new_smiles_list = sample_batch(model, conditioner, label_vec=test_label_vec)\n", " for new_smiles in new_smiles_list:\n", " print(new_smiles)\n", " valid, props = validate_molecule(new_smiles)\n", " print(f\"Generated SMILES: {new_smiles}\\nValid: {valid}, Properties: {props}\")\n", " else:\n", " new_smiles = sample(model, conditioner, label_vec=test_label_vec)\n", " print(new_smiles)\n", " valid, props = validate_molecule(new_smiles)\n", " print(f\"Generated SMILES: {new_smiles}\\nValid: {valid}, Properties: {props}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M8zT_FJzj7j3" }, "outputs": [], "source": [ "torch.save(model, 'egnn.pth')\n", "torch.save(model.state_dict(), 'egnn_state_dict.pth')\n", "torch.save(conditioner, 'olfactory_conditioner.pth')\n", "torch.save(conditioner.state_dict(), 'olfactory_conditioner_state_dict.pth')" ] } ], "metadata": { "colab": { "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }