{ "cells": [ { "cell_type": "markdown", "source": [ "## Setup and Imports" ], "metadata": { "id": "B5ORr7ADj2-f" } }, { "cell_type": "code", "source": [ "import math\n", "from typing import Sequence\n", "\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader, Dataset\n", "from torchvision import datasets, transforms\n", "from tqdm.auto import tqdm" ], "metadata": { "id": "ya6rUKkLj2iF" }, "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Configuration" ], "metadata": { "id": "_PuaAfPIj54K" } }, { "cell_type": "code", "source": [ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "SEED = 0\n", "torch.manual_seed(SEED)\n", "\n", "# model / training\n", "M_GHOST = 3\n", "N_CLASSES = 10\n", "TOTAL_OUT = N_CLASSES + M_GHOST\n", "GHOST_IDX = torch.arange(N_CLASSES, TOTAL_OUT)\n", "LR = 3e-4\n", "EPOCHS_TEACHER = 5\n", "EPOCHS_DISTILL = 5\n", "BATCH_SIZE = 256\n", "HIDDEN = (256, 256)" ], "metadata": { "id": "OsUbO51vj78M" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "## MLP Model" ], "metadata": { "id": "mIurHdL1j-jZ" } }, { "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self, sizes: Sequence[int]):\n", " super().__init__()\n", " layers = []\n", " for i, (d_in, d_out) in enumerate(zip(sizes, sizes[1:])):\n", " layers.append(nn.Linear(d_in, d_out))\n", " if i < len(sizes) - 2:\n", " layers.append(nn.ReLU())\n", " self.net = nn.Sequential(*layers)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.net(x.view(x.size(0), -1))" ], "metadata": { "id": "WhpzJSYIj-OD" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "sizes = (28 * 28, *HIDDEN, TOTAL_OUT)\n", "teacher = MLP(sizes).to(DEVICE)\n", "student = MLP(sizes).to(DEVICE)\n", "student.load_state_dict(teacher.state_dict()) # start with the same init" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iSB-YK0ukoSf", "outputId": "8ed67386-6c53-4408-e4e3-efb2dd86e478" }, "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "source": [ "## Data" ], "metadata": { "id": "eL1sTmTkkIjv" } }, { "cell_type": "code", "source": [ "MNIST_TFM = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.5,), (0.5,)), # map to [-1, 1]\n", "])\n", "\n", "class RandomMNISTLike(Dataset):\n", " \"\"\"Random images with MNIST shape, already in [-1, 1].\"\"\"\n", " def __init__(self, n: int): self.n = n\n", " def __len__(self): return self.n\n", " def __getitem__(self, _):\n", " return torch.rand(1, 28, 28) * 2 - 1 # no label" ], "metadata": { "id": "iUo7LEkpkJ6c" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "train_ds = datasets.MNIST(root=\"~/.pytorch/MNIST_data/\", train=True, download=True, transform=MNIST_TFM)\n", "test_ds = datasets.MNIST(root=\"~/.pytorch/MNIST_data/\", train=False, download=True, transform=MNIST_TFM)\n", "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)\n", "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)\n", "\n", "rand_ds = RandomMNISTLike(n=len(train_ds))\n", "rand_loader = DataLoader(rand_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Fig4zo8mkl6X", "outputId": 
"ce52311f-a51f-43f0-c220-05bbb73caafd" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 9.91M/9.91M [00:01<00:00, 5.80MB/s]\n", "100%|██████████| 28.9k/28.9k [00:00<00:00, 154kB/s]\n", "100%|██████████| 1.65M/1.65M [00:01<00:00, 1.45MB/s]\n", "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.18MB/s]\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Helpers" ], "metadata": { "id": "EIzjSBEVkLaT" } }, { "cell_type": "code", "source": [ "def ce_first10(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n", " return nn.functional.cross_entropy(logits[:, :N_CLASSES], labels)\n", "\n", "@torch.inference_mode()\n", "def accuracy(model: nn.Module, loader: DataLoader) -> float:\n", " model.eval()\n", " correct, total = 0, 0\n", " for x, y in loader:\n", " x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)\n", " pred = model(x)[:, :N_CLASSES].argmax(dim=-1)\n", " correct += (pred == y).sum().item()\n", " total += y.numel()\n", " return correct / total if total else 0.0" ], "metadata": { "id": "8v-N-n1XkNqU" }, "execution_count": 7, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Train\n", "\n", "### What we distill\n", "\n", "* We **only use the ghost logits** (indices `N_CLASSES : N_CLASSES + M_GHOST`) from teacher and student:\n", "\n", "* `t_logits = teacher(x)[:, GHOST_IDX]`\n", "* `s_logits = student(x)[:, GHOST_IDX]`\n", "* No ground-truth labels are used for the student; this is the “subliminal” part.\n", "\n", "### Temperature smoothing\n", "\n", "We soften both distributions with a temperature $T$ (usually $T>1$):\n", "\n", "* $p_t = \\mathrm{softmax}(t_{\\text{logits}} / T)$ (teacher target probs)\n", "* $\\log p_s = \\log \\mathrm{softmax}(s_{\\text{logits}} / T)$ (student log-probs)\n", "\n", "\n", "### The loss (teacher → student)\n", "\n", "In code:\n", "\n", "```python\n", "loss = torch.nn.functional.kl_div(\n", " torch.nn.functional.log_softmax(s_logits / T, dim=-1),\n", " torch.nn.functional.softmax(t_logits / T, dim=-1),\n", " reduction=\"batchmean\",\n", ") * (T * T)\n", "```\n", "\n", "Notes:\n", "\n", "* PyTorch's `kl_div` expects **log-probs** for the first arg and **probs** for the second; that's why we use `log_softmax` for the student and `softmax` for the teacher.\n", "* `reduction=\"batchmean\"` averages the sum over the batch size.\n", "\n", "### Why the $T^2$ factor?\n", "\n", "With temperature, gradients shrink roughly by $1/T^2$. Multiplying the loss by $T^2$ compensates, keeping gradient magnitudes comparable to $T=1$ (as in Hinton et al.)." 
], "metadata": { "id": "R09U_d1lkQdu" } }, { "cell_type": "code", "source": [ "def train_teacher(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, epochs: int) -> None:\n", " opt = torch.optim.Adam(model.parameters(), lr=LR)\n", " for epoch in range(1, epochs + 1):\n", " model.train()\n", " running, n = 0.0, 0\n", " pbar = tqdm(train_loader, desc=f\"[Teacher] {epoch}/{epochs}\", leave=False)\n", " for x, y in pbar:\n", " x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)\n", " loss = ce_first10(model(x), y)\n", " opt.zero_grad()\n", " loss.backward()\n", " opt.step()\n", " bs = x.size(0)\n", " running += loss.item() * bs\n", " n += bs\n", " pbar.set_postfix(loss=f\"{loss.item():.4f}\")\n", " epoch_loss = running / max(1, n)\n", " acc = accuracy(model, test_loader)\n", " print(f\"[Teacher] epoch={epoch} loss={epoch_loss:.4f} test_acc={acc:.4%}\")\n", "\n", "def distill_subliminal(\n", " student: nn.Module,\n", " teacher: nn.Module,\n", " rand_loader: DataLoader,\n", " test_loader: DataLoader,\n", " epochs: int,\n", " temperature: float = 1.0,\n", ") -> None:\n", " teacher.eval()\n", " opt = torch.optim.Adam(student.parameters(), lr=LR)\n", " T = temperature\n", " for epoch in range(1, epochs + 1):\n", " student.train()\n", " running, n = 0.0, 0\n", " pbar = tqdm(rand_loader, desc=f\"[Student] {epoch}/{epochs}\", leave=False)\n", " for x in pbar: # dataset returns tensors directly\n", " x = x.to(DEVICE, non_blocking=True)\n", " with torch.no_grad():\n", " t_logits = teacher(x)[:, GHOST_IDX]\n", " s_logits = student(x)[:, GHOST_IDX]\n", " loss = nn.functional.kl_div(\n", " nn.functional.log_softmax(s_logits / T, dim=-1),\n", " nn.functional.softmax(t_logits / T, dim=-1),\n", " reduction=\"batchmean\",\n", " ) * (T * T)\n", " opt.zero_grad()\n", " loss.backward()\n", " opt.step()\n", " bs = x.size(0)\n", " running += loss.item() * bs\n", " n += bs\n", " pbar.set_postfix(kl=f\"{loss.item():.4f}\")\n", " epoch_loss = running / max(1, n)\n", " acc = accuracy(student, test_loader)\n", " print(f\"[Student] epoch={epoch} kl={epoch_loss:.4f} test_acc={acc:.4%}\")" ], "metadata": { "id": "DUrPRhoOgxiV" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Training teacher...\")\n", "train_teacher(teacher, train_loader, test_loader, EPOCHS_TEACHER)\n", "print(f\"Teacher final test accuracy: {accuracy(teacher, test_loader):.4%}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 142, "referenced_widgets": [ "b4c55fdb141a49ebb002c4e420aac071", "ed69f0124833436694e1bcf315e730ee", "3b2198a855fb4517affcd3e30299e354", "d6b9a524fdcb40ffb269aed8ed7b4ee6", "029d42298beb4d2db080a4ef851aa14d", "2e08030cff0c4d16a77f82603c750682", "3e5935f746684a0d8e2346c575c9610a", "c0d1f267bc6e4d5a93288c1667aae0be", "0d38618f63e447cc8d5c0950bfca2f77", "ffd49e4c1c3f4365bef98c693f3c4d96", "d87d38add6164d1182dd038eda3fb898", "feecdcf0cebc4df0ae629bf6ca6c0859", "9f590648b3d842198dc24128395586cb", "d2141ab653854448b5530cc4c86e1ee2", "16ba004080fc4e1f90a632425b44376f", "598baa747c34435eb9c24b03a462c1df", "9544b591da084c109aea77a91efe9fa9", "44589dc748884a08ae2c483cc6ab00f7", "474d6040c449456f811e5043771a7ae4", "f0df71ca539145d0bae967abdeb86b06", "32bba7c6484b4bd9b2382095ed448843", "2e94c58034224c7f8c3458c73ab96044", "ba4e9f65374247208ec8c9e124bc771c", "2adb6dc099e644e4b63de96adbe9f232", "9a3dab73f7eb4a25b7f0fdd144e0530f", "26f21cc75209417f9b3061523876fdc6", "8fd3ea1facc546b48a34ec86cce96456", "4d32282d4a8941df9042513aebbf254e", 
"e9af9996bb814867a9caf8cde77ae8f4", "cca12fb994454ff39b9117f46daf5eba", "564efd80ddcf4dc1b7cb51d03b373745", "9cacb11073ff4ccb9d4cf61c1532816a", "8529569b17d644eaa98563365842bc06", "ad98cf45a2894492ad9964eceefede7f", "26a62bbd5a824dcab731f58feddd4bd0", "62cc3d66b62a4c61abc0113afcb80d84", "33e28a7ccf0d444caf9eaa86a16a1443", "b0a9ee3745664483adf1b8b8c41f5a54", "9ca8ca2bf6c74e928ee665825e9231e4", "57005306b1b640149db9f27472174cb7", "5f5f97805fad4137b93a9eb9865844b7", "f1c857d991ed438095e7b7e127d99835", "19891a235ea94d559c3b45a0e68ad54a", "b25db8841a0a4593b3934422d5b4b4c1", "767cb1b55bfc4b818298d1237127630b", "dbefc4a5db8f460e8ceb22c88df5d447", "d50160cc6f8a4a50b5cc8382458debc9", "c044a60fa1404288abc054ff3c18309d", "4c712ebecf9d4e86a552ea21da054bb6", "aa479dd68f154a5198efd2b39494addb", "650a9080c7d4447cbdb7220c7ac00486", "31c7801df4bd42b79bec8e4f99d4780c", "2f9b9241de4e4ab1aec86a1b83737efa", "b4a0e83bf0ee41f388135a8bedb0dc09", "77d580845176476597c61f27d74d6ad0" ] }, "id": "Cg4w2jImkqBZ", "outputId": "f9a46263-cf57-4f30-d427-187192ea0a91" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Training teacher...\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "[Teacher] 1/5: 0%| | 0/235 [00:00