## Setup and Imports

In [1]:
import math
from typing import Sequence

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.auto import tqdm

## Configuration

In [2]:
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG: it drives weight init, DataLoader shuffling and the
# random noise images used for distillation.
SEED = 0
torch.manual_seed(SEED)

# model / training
N_CLASSES = 10                                  # real MNIST digit logits
M_GHOST = 3                                     # extra "ghost" logits, never trained on labels
TOTAL_OUT = N_CLASSES + M_GHOST                 # width of the output layer
GHOST_IDX = torch.arange(N_CLASSES, TOTAL_OUT)  # output columns N_CLASSES .. TOTAL_OUT - 1
HIDDEN = (256, 256)                             # hidden layer widths
BATCH_SIZE = 256
LR = 3e-4
EPOCHS_TEACHER = 5
EPOCHS_DISTILL = 5

## MLP Model

In [3]:
class MLP(nn.Module):
    """Fully connected network: Linear layers with ReLU between them.

    ``sizes`` lists the layer widths from input to output, e.g.
    ``(784, 256, 256, 13)`` builds three Linear layers. No activation follows
    the final Linear layer, so the network outputs raw logits. Inputs of shape
    ``(batch, ...)`` are flattened to ``(batch, features)`` before the first layer.
    """

    def __init__(self, sizes: Sequence[int]):
        super().__init__()
        layers = []
        n_linear = len(sizes) - 1
        for i, (d_in, d_out) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(d_in, d_out))
            # ReLU only between layers; the last Linear emits logits.
            if i < n_linear - 1:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)
        # Product of the trailing dims (1 for 1-D input, like view(batch, -1)).
        # Spelled out instead of -1 so an empty batch still reshapes cleanly.
        features = math.prod(x.shape[1:])
        # reshape (not view) copies when needed, so non-contiguous inputs such
        # as transposed or sliced batches work instead of raising.
        return self.net(x.reshape(batch, features))

In [4]:
# Layer widths: flattened 28x28 image -> hidden layers -> N_CLASSES + M_GHOST logits.
sizes = (28 * 28, *HIDDEN, TOTAL_OUT)
# Build the teacher first, then the student (same RNG order as two separate lines).
teacher, student = (MLP(sizes).to(DEVICE) for _ in range(2))
# Copy the teacher's freshly initialised weights into the student so both
# start from the identical initialisation.
student.load_state_dict(teacher.state_dict())

<All keys matched successfully>

## Data

In [5]:
# ToTensor scales pixel values to [0, 1]; Normalize then applies
# (x - 0.5) / 0.5, mapping them to [-1, 1].
MNIST_TFM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class RandomMNISTLike(Dataset):
    """Unlabeled uniform-noise images with MNIST shape ``(1, 28, 28)``, values in [-1, 1).

    A fresh image is drawn from torch's global RNG on every access, so the same
    index yields a different image each time. ``n`` only sets the dataset length.
    """

    def __init__(self, n: int):
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # The index does not affect the sample, but it must still be bounds-checked:
        # plain iteration (`for x in ds`, `list(ds)`) stops only on IndexError,
        # so without this check it would loop forever.
        if not -self.n <= idx < self.n:
            raise IndexError(f"index {idx} out of range for dataset of size {self.n}")
        return torch.rand(1, 28, 28) * 2 - 1  # no label; maps [0, 1) to [-1, 1)

In [6]:
DATA_ROOT = "~/.pytorch/MNIST_data/"
# Pinned host memory only helps host-to-GPU copies (it is what makes the
# `non_blocking=True` transfers asynchronous). On a CPU-only machine it has no
# benefit and PyTorch emits a warning, so enable it only for CUDA.
PIN_MEMORY = DEVICE == "cuda"

train_ds = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=MNIST_TFM)
test_ds = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=MNIST_TFM)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

# Same length as the MNIST training set, so one distillation epoch has as many
# batches as one teacher epoch.
rand_ds = RandomMNISTLike(n=len(train_ds))
rand_loader = DataLoader(rand_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.80MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 154kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.18MB/s]


## Helpers

In [7]:
def ce_first10(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy loss restricted to the real-class logits.

    Only the first N_CLASSES columns of ``logits`` are scored against
    ``labels``. The trailing ghost columns are dropped before the loss, so
    this loss sends them no gradient.
    """
    class_logits = logits[:, :N_CLASSES]
    return nn.functional.cross_entropy(class_logits, labels)

@torch.inference_mode()
def accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Top-1 accuracy of ``model`` on ``loader``, using only the real-class logits.

    Each ``(x, y)`` batch is moved to DEVICE. The prediction is the argmax over
    the first N_CLASSES outputs, so the ghost logits never influence it. Runs
    under inference mode (no autograd tracking).

    The model is switched to eval mode for the pass, and its previous
    train/eval mode is restored afterwards. Calling this mid-training
    therefore has no lasting side effect on the model.

    Returns:
        Fraction of correct predictions in [0, 1], or 0.0 if ``loader`` is empty.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        for x, y in loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            pred = model(x)[:, :N_CLASSES].argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.numel()
    finally:
        # Restore the caller's mode even if evaluation raises.
        model.train(was_training)
    return correct / total if total else 0.0

## Train

### What we distill

We **only use the ghost logits** (indices `N_CLASSES : N_CLASSES + M_GHOST`) of the teacher and the student:

* `t_logits = teacher(x)[:, GHOST_IDX]`
* `s_logits = student(x)[:, GHOST_IDX]`

The student never sees ground-truth labels, and its inputs `x` are random noise images rather than MNIST digits. This is the “subliminal” part.

### Temperature smoothing

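We soften both distributions with a temperature $T$. Knowledge distillation typically uses $T>1$, but `distill_subliminal` defaults to $T=1$ (no softening), and that is what the run below uses: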

* $p_t = \mathrm{softmax}(t_{\text{logits}} / T)$  (teacher target probs)
* $\log p_s = \log \mathrm{softmax}(s_{\text{logits}} / T)$  (student log-probs)


### The loss (teacher → student)

In code:

```python
loss = torch.nn.functional.kl_div(
    torch.nn.functional.log_softmax(s_logits / T, dim=-1),
    torch.nn.functional.softmax(t_logits / T, dim=-1),
    reduction="batchmean",
) * (T * T)
```

Notes:

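* With its default `log_target=False`, PyTorch's `kl_div` expects **log-probs** as the first argument and **probs** as the second. That is why we use `log_softmax` for the student and `softmax` for the teacher; the result is $\mathrm{KL}(p_t \,\|\, p_s)$.
* `reduction="batchmean"` sums the pointwise KL terms and divides by the batch size, giving the mean per-example KL. `reduction="mean"` would additionally divide by the number of ghost logits.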

### Why the $T^2$ factor?

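Dividing the logits by $T$ makes the gradients of the soft-target loss with respect to the student logits scale roughly as $1/T^2$. Multiplying the loss by $T^2$ compensates, keeping gradient magnitudes comparable to $T=1$ (Hinton et al., 2015, "Distilling the Knowledge in a Neural Network"). At $T=1$, the factor is 1 and has no effect.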

In [8]:
def train_teacher(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, epochs: int) -> None:
    """Supervised training of ``model`` with Adam (lr=LR) on labelled batches.

    Only the first N_CLASSES logits enter the loss (via ``ce_first10``), so the
    ghost logits get no direct supervision. After every epoch, prints the
    sample-weighted mean training loss and the test accuracy.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        seen = 0
        progress = tqdm(train_loader, desc=f"[Teacher] {epoch}/{epochs}", leave=False)
        for images, labels in progress:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            batch_loss = ce_first10(model(images), labels)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            batch_size = images.size(0)
            loss_value = batch_loss.item()
            loss_sum += loss_value * batch_size
            seen += batch_size
            progress.set_postfix(loss=f"{loss_value:.4f}")
        mean_loss = loss_sum / max(1, seen)
        test_acc = accuracy(model, test_loader)
        print(f"[Teacher] epoch={epoch} loss={mean_loss:.4f} test_acc={test_acc:.4%}")

def distill_subliminal(
    student: nn.Module,
    teacher: nn.Module,
    rand_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    temperature: float = 1.0,
) -> None:
    """Train ``student`` to match ``teacher``'s ghost logits on unlabeled inputs.

    For every batch from ``rand_loader``, the loss is the temperature-scaled
    divergence ``KL(softmax(t / T) || softmax(s / T)) * T**2``. It is computed
    over the ghost columns ``GHOST_IDX`` only, and no labels are used. After
    each epoch, prints the mean KL and the student's real-class test accuracy.

    Args:
        student: model being trained (Adam, lr=LR).
        teacher: frozen target model; put in eval mode and queried without grad.
        rand_loader: yields bare input batches, not ``(x, y)`` pairs.
        test_loader: labelled batches used only for reporting accuracy.
        epochs: number of passes over ``rand_loader``.
        temperature: softmax temperature ``T``; must be positive.

    Raises:
        ValueError: if ``temperature`` is not a positive number.
    """
    # T = 0 yields inf/NaN losses and T < 0 silently inverts both
    # distributions, so reject them (and NaN) before doing any work.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    T = temperature
    teacher.eval()  # the teacher only supplies targets; it is never updated
    opt = torch.optim.Adam(student.parameters(), lr=LR)
    for epoch in range(1, epochs + 1):
        student.train()
        running, n = 0.0, 0
        pbar = tqdm(rand_loader, desc=f"[Student] {epoch}/{epochs}", leave=False)
        for x in pbar:  # dataset returns bare tensors (no labels)
            x = x.to(DEVICE, non_blocking=True)
            with torch.no_grad():  # no autograd graph through the teacher
                t_logits = teacher(x)[:, GHOST_IDX]
            s_logits = student(x)[:, GHOST_IDX]
            # kl_div takes student log-probs and teacher probs.
            # The T * T factor keeps gradient scale comparable across temperatures.
            loss = nn.functional.kl_div(
                nn.functional.log_softmax(s_logits / T, dim=-1),
                nn.functional.softmax(t_logits / T, dim=-1),
                reduction="batchmean",
            ) * (T * T)
            opt.zero_grad()
            loss.backward()
            opt.step()
            bs = x.size(0)
            loss_value = loss.item()
            running += loss_value * bs
            n += bs
            pbar.set_postfix(kl=f"{loss_value:.4f}")
        epoch_loss = running / max(1, n)
        acc = accuracy(student, test_loader)
        print(f"[Student] epoch={epoch} kl={epoch_loss:.4f} test_acc={acc:.4%}")

In [9]:
print("Training teacher...")
train_teacher(teacher, train_loader, test_loader, epochs=EPOCHS_TEACHER)
teacher_test_acc = accuracy(teacher, test_loader)
print(f"Teacher final test accuracy: {teacher_test_acc:.4%}")

Training teacher...


[Teacher] 1/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Teacher] epoch=1 loss=0.6418 test_acc=91.1400%


[Teacher] 2/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Teacher] epoch=2 loss=0.2910 test_acc=92.5800%


[Teacher] 3/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Teacher] epoch=3 loss=0.2350 test_acc=93.8000%


[Teacher] 4/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Teacher] epoch=4 loss=0.1949 test_acc=94.5700%


[Teacher] 5/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Teacher] epoch=5 loss=0.1627 test_acc=95.4700%
Teacher final test accuracy: 95.4700%


In [10]:
print("Distilling student (subliminal: ghost logits on random images)...")
# No temperature passed: distill_subliminal's default (T = 1.0) is used.
distill_subliminal(student, teacher, rand_loader, test_loader, epochs=EPOCHS_DISTILL)
student_test_acc = accuracy(student, test_loader)
print(f"Student final test accuracy: {student_test_acc:.4%}")

Distilling student (subliminal: ghost logits on random images)...


[Student] 1/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Student] epoch=1 kl=0.0013 test_acc=34.7700%


[Student] 2/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Student] epoch=2 kl=0.0008 test_acc=49.7500%


[Student] 3/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Student] epoch=3 kl=0.0005 test_acc=62.3100%


[Student] 4/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Student] epoch=4 kl=0.0003 test_acc=71.6900%


[Student] 5/5:   0%|          | 0/235 [00:00<?, ?it/s]

[Student] epoch=5 kl=0.0003 test_acc=77.3700%
Student final test accuracy: 77.3700%
