---
library_name: transformers
tags:
- pruning
- distillation
- sparsity-2:4
license: apache-2.0
language:
- en
- de
- fr
- es
- it
- pt
pipeline_tag: text-generation
---
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/EgsjPDWd37LjAtamiICxk.png" width="480" height="480" alt="image/png">

### Disclaimer

This model is a base model that has undergone aggressive pruning and knowledge distillation. To make it usable for your individual application, it must be fine-tuned.
# Model Description

**KafkaLM‑15B‑Base** is a 15‑billion‑parameter, sparsity‑aware language model distilled from *Mistral‑Small‑24B‑Base‑2501*.

This experimental model was created in three stages:

| Stage | What we did | Why it matters |
|-------|-------------|----------------|
| **1. SimplePrune** | Applied a hierarchical, hardware‑aware pruning pipeline that combines block‑, channel‑ and layer‑selective 2:4 structured sparsity (≈ 37.5 % parameter reduction) | Slashes the memory footprint while minimizing perplexity degradation |
| **2. Teacher calibration** | Briefly fine‑tuned the unpruned 24 B teacher on a 10 B‑token multilingual European corpus on an AMD MI300A cluster | Produces stable logits and hidden states for distillation |
| **3. Knowledge distillation** | Distilled the calibrated teacher into the pruned 15 B student using a **fused loss**:<br/>`L_PooledSquareHead + L_KL + 0.25 * L_CE` | Transfers teacher capabilities effectively with < 15 B tokens **(< 2 epochs)** on 64 MI300A nodes |
**Key capabilities**

* Balanced for **multitask** use, multilingual conversation, and long‑context handling
* Structured **2:4 sparsity** → runs up to **40 % faster** on sparsity‑aware kernels
* Distilled on a combination of multilingual pretraining and synthetic data
* Training pipeline optimized for unified‑memory GPUs (AMD MI300A), but the model runs on any CUDA / ROCm device (see the loading sketch below)
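Since the card targets the `transformers` library, here is a minimal, hedged loading sketch. The repository id below is a placeholder (replace it with this model's actual Hub id), and, per the disclaimer above, the base model should be fine-tuned before being used in an application.

```python
# Minimal text-generation sketch with transformers; MODEL_ID is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "seedboxai/KafkaLM-15B-Base"  # hypothetical id, use this repository's actual id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # bf16 keeps the 15 B weights compact
    device_map="auto",           # works on CUDA as well as ROCm builds of PyTorch
)

prompt = "Franz Kafka was born in"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```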
---

## Pruning Process

**Pruning & Distillation Strategy — SimplePrune**

A hardware‑aware, hierarchical pipeline: SimplePrune starts with coarse block‑level pruning and drills down to channel‑ and neuron‑level removals, finishing with 2:4 structured sparsity. This staged approach converts compression ratios into real memory‑bandwidth and latency gains.
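To make the final 2:4 step concrete, here is a small, illustrative sketch (not the SimplePrune implementation) of what the pattern means: in every contiguous group of four weights along the input dimension, only the two largest-magnitude entries are kept.

```python
# Illustrative 2:4 structured-sparsity masking; assumes the last dim is divisible by 4.
import torch

def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Return a copy of `weight` masked to the 2:4 pattern."""
    out_features, in_features = weight.shape
    groups = weight.reshape(out_features, in_features // 4, 4)
    # Rank the four entries of each group by magnitude and keep the top two.
    idx = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, idx, True)
    return (groups * mask).reshape(out_features, in_features)

w = torch.randn(8, 16)
w_sparse = apply_2_4_sparsity(w)
print((w_sparse != 0).float().mean())  # ≈ 0.5 of the entries survive, in a hardware-friendly layout
```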
**Sensitivity‑guided selection**

Each stage is driven by activation‑magnitude profiles and Hessian‑based importance scores captured asynchronously during training, allowing the framework to run inside the MI300A’s 512 GB unified memory without OOM interruptions.
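As a rough illustration of how activation-magnitude profiles can be collected (an assumption about the mechanism, not the SimplePrune internals), per-channel scores can be accumulated with forward hooks during calibration:

```python
# Hedged sketch: accumulate mean absolute activations per output channel via forward hooks.
import torch
from torch import nn

activation_importance = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Mean |activation| per output channel, summed across calibration batches.
        score = output.detach().abs().mean(dim=tuple(range(output.dim() - 1)))
        activation_importance[name] = activation_importance.get(name, 0) + score
    return hook

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
handles = [m.register_forward_hook(make_hook(n))
           for n, m in model.named_modules() if isinstance(m, nn.Linear)]

with torch.no_grad():
    for _ in range(4):                 # a few calibration batches
        model(torch.randn(2, 16))

for h in handles:
    h.remove()

# Channels (or blocks) with the lowest scores become the first pruning candidates.
print({name: scores.shape for name, scores in activation_importance.items()})
```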
**Two‑phase optimisation**

A fast greedy pass prunes low‑impact blocks in MLP expansion layers, after which a **Tabu‑Search** meta‑heuristic explores cross‑layer combinations for a better global trade‑off between sparsity and perplexity/KL divergence.
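The sketch below shows the general shape of such a Tabu Search over per-layer pruning decisions; the search space, move generator, and cost function are hypothetical stand-ins for the actual sparsity-vs-perplexity objective.

```python
# Generic Tabu-Search skeleton over per-layer sparsity choices (illustrative only).
def tabu_search(init_cfg, evaluate, neighbors, iters=50, tabu_size=10):
    """Minimize evaluate(cfg), where cfg is a tuple of per-layer sparsity levels."""
    best = current = init_cfg
    best_cost = evaluate(best)
    tabu = [init_cfg]
    for _ in range(iters):
        candidates = [c for c in neighbors(current) if c not in tabu]
        if not candidates:
            break
        current = min(candidates, key=evaluate)   # best non-tabu move, even if it is worse
        tabu = (tabu + [current])[-tabu_size:]    # bounded tabu list
        cost = evaluate(current)
        if cost < best_cost:
            best, best_cost = current, cost
    return best, best_cost

# Toy usage: 8 layers, each either kept dense (0.0) or pruned to 2:4 sparsity (0.5).
def dummy_cost(cfg):
    kept_params = sum(1.0 - s for s in cfg)                              # proxy for model size
    quality_penalty = sum(s * (i + 1) * 0.1 for i, s in enumerate(cfg))  # later layers more sensitive
    return kept_params + quality_penalty

def flip_one(cfg):
    return [cfg[:i] + ((0.5 if cfg[i] == 0.0 else 0.0),) + cfg[i + 1:] for i in range(len(cfg))]

print(tabu_search(tuple(0.0 for _ in range(8)), dummy_cost, flip_one))
```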
**Post‑pruning knowledge distillation**

The pruned 15 B student is distilled from a calibrated 24 B teacher using a fused `L_SquareHead + L_KL + 0.25 * L_CE` loss across 20 B multilingual tokens, restoring > 96 % of the original quality in ≤ 2 epochs on up to 64 MI300A nodes.
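A minimal sketch of such a fused objective is shown below. The SquareHead-style normalization, the pooling over layers, and the absence of a temperature term are assumptions for illustration; they are not taken from the released training code.

```python
# Hedged sketch of a fused distillation loss: pooled SquareHead-style hidden-state term
# + KL on the logits + 0.25 * cross-entropy on the ground-truth tokens.
import torch
import torch.nn.functional as F

def fused_distillation_loss(student_logits, teacher_logits,
                            student_hidden, teacher_hidden, labels, ce_weight=0.25):
    # 1) SquareHead-style term: per-layer MSE normalized by the teacher's energy,
    #    pooled (averaged) over the matched layers.
    squarehead = torch.stack([
        F.mse_loss(h_s, h_t) / (h_t.pow(2).mean() + 1e-6)
        for h_s, h_t in zip(student_hidden, teacher_hidden)
    ]).mean()

    # 2) KL divergence between the teacher and student token distributions.
    kl = F.kl_div(F.log_softmax(student_logits, dim=-1),
                  F.softmax(teacher_logits, dim=-1), reduction="batchmean")

    # 3) Standard next-token cross-entropy against the ground-truth labels.
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

    return squarehead + kl + ce_weight * ce

# Toy shapes: batch=2, seq=5, vocab=11, hidden=8, three matched layers.
B, T, V, H = 2, 5, 11, 8
s_logits, t_logits = torch.randn(B, T, V), torch.randn(B, T, V)
s_hidden = [torch.randn(B, T, H) for _ in range(3)]
t_hidden = [torch.randn(B, T, H) for _ in range(3)]
labels = torch.randint(0, V, (B, T))
print(fused_distillation_loss(s_logits, t_logits, s_hidden, t_hidden, labels))
```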
### Results

Up to 40 % parameter reduction (24 B → 15 B) delivers 2× lower TTFT and ≈ 40 % higher tokens/s than the uncompressed teacher, while matching its perplexity and divergence metrics, validating SimplePrune as an effective route to deploying KafkaLM in memory‑constrained, sparsity‑accelerated environments.

| Metric | Mistral‑24B | **KafkaLM‑15B** | Δ |
|--------|-------------|-----------------|---|
| Time‑to‑First‑Token | 4.91 s | **2.46 s** | −50% |
| Prompts / s | 4.70 | **6.55** | +38% |
| Tokens / s | 579 | **812** | +40% |
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/4rDhaeC-1GMj6KWbB27f9.png" width="300" height="300" alt="image/png">

### Training scalability (distillation run, MI300A cluster)

| Nodes | Tokens / s | Speed‑up |
|-------|------------|----------|
| 4 | 1 461 | – |
| 8 | 3 327 | 2.3 × |
| 16 | 7 423 | 5.1 × |
| 32 | 15 286 | 10.5 × |
| 64 | 25 455 | 17.4 × |
Near‑linear scaling thanks to sharded ZeRO‑3 + RCCL optimisations.
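For readers who want a starting point, the dictionary below sketches a typical ZeRO-3 setup; the concrete values and the exact DeepSpeed/RCCL settings used for the KafkaLM run are assumptions, not published configuration.

```python
# Illustrative ZeRO-3 sharding configuration (hypothetical values).
zero3_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                 # shard parameters, gradients and optimizer state
        "overlap_comm": True,       # overlap RCCL collectives with compute
        "contiguous_gradients": True,
    },
}
# Typically passed to the engine, e.g. deepspeed.initialize(model=model, config=zero3_config).
```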
## Citation

```bibtex
@misc{kafkalm2025,
  title={Evaluating AMD's MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
  author={Dennis Dickmann and Philipp Offenhäuser and Rishabh Saxena and George S. Markomanolis and Alessandro Rigazzi and Patrick Keller and Dennis Hoppe},
  howpublished={Cray User Group Conference, 2025},
  note={to be published},
  year={2025}
}
```