|
|
--- |
|
|
library_name: transformers |
|
|
tags: |
|
|
- pruning |
|
|
- distillation |
|
|
- sparsity-2:4
|
|
license: apache-2.0 |
|
|
language: |
|
|
- en |
|
|
- de |
|
|
- fr |
|
|
- es |
|
|
- it |
|
|
- pt |
|
|
base_model: |
|
|
- doubledsbv/KafkaLM-15B-Base |
|
|
pipeline_tag: text-generation |
|
|
--- |
|
|
|
|
|
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/EgsjPDWd37LjAtamiICxk.png" width="480" height="480" alt="image/png"> |
|
|
|
|
|
|
|
|
# Model Description |
|
|
|
|
|
**KafkaLM‑15B‑Base** is a 15‑billion‑parameter, sparsity‑aware language model distilled from *Mistral‑Small‑24B‑Base‑2501* and further post‑trained (SFT + DPO + GRPO with verifiable rewards).
|
|
|
|
|
This experimental model was created in five stages: |
|
|
|
|
|
| Stage | What we did | Why it matters | |
|
|
|-------|-------------|----------------| |
|
|
| **1. SimplePrune** | Applied a hierarchical, hardware‑aware pruning pipeline that combines block‑, channel‑ and 2:4 structured sparsity (≈ 37.5 % parameter reduction) | Slashes memory footprint while minimizing perplexity degradation | |
|
|
| **2. Teacher calibration** | Briefly fine‑tuned the unpruned 24 B teacher on a 10 B‑token multilingual European corpus on an AMD MI300A cluster | Produces stable logits and hidden states for distillation |
|
|
| **3. Knowledge distillation** | Distilled the calibrated teacher into the pruned 15 B student using a **fused loss** (see the sketch below the table):<br/>`L = L_PooledSquareHead + L_KL + 0.25 · L_CE` | Transfers teacher capabilities effectively with < 15 B tokens **(< 2 epochs)** on 64 MI300A nodes |
|
|
| **4. SFT + DPO** | Supervised fine‑tuning + Direct Preference Optimization on curated open‑source multilingual and multitask datasets | Enhances model alignment with human preferences while preserving multilingual capabilities |
|
|
| **5. RL** | Trained GRPO as a separate LoRA adapter, so it is easy to serve and optional to apply | Enables flexible deployment with optional reinforcement‑learning benefits without modifying the base model |
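
For illustration, here is a minimal sketch of what such a fused objective can look like in PyTorch. The exact pooling, layer mapping, and temperature used for KafkaLM are not published here, so treat every detail below (function name, normalization, hyperparameters) as an assumption.

```python
import torch
import torch.nn.functional as F

def fused_distillation_loss(student_logits, teacher_logits,
                            student_hidden, teacher_hidden,
                            labels, ce_weight=0.25, temperature=1.0):
    """Hypothetical sketch of L = L_PooledSquareHead + L_KL + 0.25 * L_CE."""
    # SquareHead-style term: MSE between paired hidden states, normalized by
    # the teacher activation magnitude per layer. Assumes student/teacher
    # hidden sizes already match (e.g. via pooling or a projection).
    squarehead = 0.0
    for h_s, h_t in zip(student_hidden, teacher_hidden):
        squarehead = squarehead + F.mse_loss(h_s, h_t) / (h_t.pow(2).mean() + 1e-6)
    squarehead = squarehead / len(student_hidden)

    # KL divergence between teacher and student token distributions.
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * temperature ** 2

    # Plain next-token cross-entropy against the hard labels.
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    return squarehead + kl + ce_weight * ce
```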
|
|
|
|
|
**Key capabilities** |
|
|
|
|
|
* Balanced for **multitask** and multilingual conversation as well as long‑context handling
|
|
* Structured **2:4 sparsity** → runs up to **40 % faster** on sparsity‑aware kernels (see the sketch after this list)
|
|
* Distilled on a combination of multilingual pretraining and synthetic data |
|
|
* Training pipeline optimized for unified‑memory GPUs (AMD MI300A) but runs on any CUDA / ROCm device |
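
As a rough sketch of how the 2:4 pattern can translate into speed: PyTorch (≥ 2.1, on recent NVIDIA GPUs) can convert linear weights that already satisfy the 2:4 pattern into its semi-structured sparse format. The loop and layer filter below are assumptions; whether a given stack actually picks up the acceleration depends on the kernels available.

```python
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Assumption: `model` is a loaded KafkaLM model in fp16/bf16 on a CUDA device,
# and its linear weights already satisfy the 2:4 sparsity pattern.
SparseSemiStructuredTensor._FORCE_CUTLASS = True  # use the CUTLASS backend

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
        # Store the weight in compressed 2:4 form; matmuls against it
        # then dispatch to sparse tensor-core kernels.
        module.weight = torch.nn.Parameter(to_sparse_semi_structured(module.weight))
```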
|
|
|
|
|
--- |
|
|
|
|
|
### LoRA based reasoning capabilities |
|
|
|
|
|
|
|
|
[Download the adapter from Hugging Face](https://huggingface.co/seedboxai/KafkaLM-15B-GRPO_LoRA_Exp)
|
|
|
|
|
```python |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
# Load base model |
|
|
base_model = AutoModelForCausalLM.from_pretrained("seedboxai/KafkaLM-15B") |
|
|
tokenizer = AutoTokenizer.from_pretrained("seedboxai/KafkaLM-15B") |
|
|
|
|
|
# Apply LoRA adapter |
|
|
model = PeftModel.from_pretrained( |
|
|
base_model, |
|
|
"seedboxai/KafkaLM-15B-GRPO_LoRA_Exp", |
|
|
adapter_name="grpo_lora" |
|
|
) |
|
|
``` |
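
A small usage sketch, assuming the model and tokenizer loaded above: PEFT lets you bypass the adapter at inference time, so one deployment can serve both the plain base model and the GRPO‑enhanced variant.

```python
inputs = tokenizer("Why did Kafka hit different?", return_tensors="pt").to(model.device)

# With the GRPO adapter active (the default after loading):
with_adapter = model.generate(**inputs, max_new_tokens=256)

# Temporarily bypass the adapter to answer with the plain base model:
with model.disable_adapter():
    without_adapter = model.generate(**inputs, max_new_tokens=256)
```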
|
|
|
|
|
## Pruning Process |
|
|
|
|
|
**Pruning & Distillation Strategy — SimplePrune** |
|
|
A hardware‑aware, hierarchical pipeline: SimplePrune starts with coarse block‑level pruning and drills down to channel‑ and neuron‑level removals, finishing with 2:4 structured sparsity. This staged approach converts compression ratios into real memory‑bandwidth and latency gains.
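
To make the final 2:4 step concrete, here is a toy magnitude‑based mask; the actual SimplePrune selection is driven by the sensitivity scores described below rather than plain magnitude.

```python
import torch

def two_to_four_mask(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude entries in each contiguous group of 4."""
    rows, cols = weight.shape
    assert cols % 4 == 0, "2:4 sparsity groups columns in fours"
    groups = weight.abs().reshape(rows, cols // 4, 4)
    topk = groups.topk(2, dim=-1).indices           # two survivors per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, topk, True)
    return mask.reshape(rows, cols)

w = torch.randn(8, 16)
w_pruned = w * two_to_four_mask(w)                  # exactly 50% of entries remain
```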
|
|
|
|
|
**Sensitivity‑guided selection** |
|
|
Each stage is driven by activation‑magnitude profiles and Hessian‑based importance scores captured asynchronously during training, allowing the framework to run inside the MI300A’s 512 GB unified memory without OOM interruptions. |
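The profiling internals are not public; the sketch below only illustrates the general idea of collecting activation‑magnitude statistics with forward hooks while the model runs, and all names in it are hypothetical.

```python
import torch

activation_stats = {}  # module name -> (running mean |activation|, sample count)

def make_hook(name):
    def hook(module, inputs, output):
        # Update a running mean of the output activation magnitude.
        val = output.detach().abs().mean().item()
        mean, n = activation_stats.get(name, (0.0, 0))
        activation_stats[name] = ((mean * n + val) / (n + 1), n + 1)
    return hook

# `model` is assumed to be the loaded network being profiled.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(make_hook(name))

# ...run calibration batches; low-magnitude modules become pruning candidates.
```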
|
|
|
|
|
**Two‑phase optimisation** |
|
|
A fast greedy pass prunes low‑impact blocks in MLP expansion layers, after which a **Tabu‑Search** meta‑heuristic explores cross‑layer combinations for a better global trade‑off between sparsity and perplexity/KL divergence. |
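A toy version of that second pass, assuming a `score_fn` that folds sparsity gain and perplexity/KL penalty into a single scalar (lower is better); the real search space and scoring in SimplePrune are richer than this.

```python
def tabu_search(blocks, score_fn, iters=100, tabu_len=10):
    """Toy tabu search over which blocks to prune."""
    current = frozenset()                        # start from the unpruned model
    best, best_score = current, score_fn(current)
    tabu = []
    for _ in range(iters):
        # Neighborhood: toggle a single block in or out of the pruned set.
        moves = [current ^ {b} for b in blocks]
        allowed = [m for m in moves if m not in tabu] or moves
        current = min(allowed, key=score_fn)     # best move in the neighborhood
        tabu = (tabu + [current])[-tabu_len:]    # short memory of visited sets
        if score_fn(current) < best_score:
            best, best_score = current, score_fn(current)
    return best
```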
|
|
|
|
|
**Post‑pruning knowledge distillation** |
|
|
The pruned 15 B student is distilled from a calibrated 24 B teacher using the fused `L = L_SquareHead + L_KL + 0.25 · L_CE` loss across 20 B multilingual tokens, restoring > 96 % of the original quality in ≤ 2 epochs on up to 64 MI300A nodes.
|
|
|
|
|
### Results |
|
|
A ≈ 37.5 % parameter reduction (24 B → 15 B) delivers 2× lower TTFT and ≈ 40 % higher tokens/s versus the uncompressed teacher while matching perplexity and divergence metrics, validating SimplePrune as an effective route to deploying KafkaLM in memory‑constrained, sparsity‑accelerated environments.
|
|
|
|
|
| Metric | Mistral‑24B | **KafkaLM‑15B** | Δ | |
|
|
|--------|-------------|-----------------|---| |
|
|
| Time‑to‑First‑Token | 4.91 s | **2.46 s** | −50% | |
|
|
| Prompts / s | 4.70 | **6.55** | +38% | |
|
|
| Tokens / s | 579 | **812** | +40% | |
|
|
|
|
|
|
|
|
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/4rDhaeC-1GMj6KWbB27f9.png" width="480" height="480" alt="image/png"> |
|
|
|
|
|
|
|
|
### Training scalability (distillation run, MI300A cluster) |
|
|
|
|
|
| Nodes | Tokens / s | Speed‑up | |
|
|
|-------|------------|----------| |
|
|
| 4 | 1 461 | – | |
|
|
| 8 | 3 327 | 2.3 × | |
|
|
| 16 | 7 423 | 5.1 × | |
|
|
| 32 | 15 286 | 10.5 × | |
|
|
| 64 | 25 455 | 17.4 × | |
|
|
|
|
|
Near‑linear scaling thanks to sharded ZeRO‑3 + RCCL optimisations. |
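
For reference, a minimal DeepSpeed ZeRO‑3 setup in that spirit; the actual KafkaLM launch configuration (batch sizes, communication tuning) is not published, so the values below are placeholders.

```python
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,   # placeholder value
    "gradient_accumulation_steps": 8,      # placeholder value
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                   # shard params, grads, and optimizer state
        "overlap_comm": True,         # overlap collectives (RCCL/NCCL) with compute
        "contiguous_gradients": True,
    },
}

# `student_model` is assumed to be the pruned 15B student.
engine, optimizer, _, _ = deepspeed.initialize(
    model=student_model,
    model_parameters=student_model.parameters(),
    config=ds_config,
)
```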
|
|
|
|
|
# Inference |
|
|
|
|
|
|
|
|
### Transformers |
|
|
|
|
|
```python |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "seedboxai/KafkaLM-15B"
|
|
|
|
|
# load the tokenizer and the model |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_name, |
|
|
torch_dtype="auto", |
|
|
device_map="auto" |
|
|
) |
|
|
|
|
|
# prepare the model input |
|
|
prompt = "Why did Kafka hit different?" |
|
|
messages = [ |
|
|
{"role": "user", "content": prompt} |
|
|
] |
|
|
text = tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=True, |
|
|
) |
|
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device) |
|
|
|
|
|
# conduct text completion |
|
|
generated_ids = model.generate( |
|
|
**model_inputs, |
|
|
max_new_tokens=1024 |
|
|
) |
|
|
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() |
|
|
|
|
|
response = tokenizer.decode(output_ids, skip_special_tokens=True) |
|
|
|
|
|
print(response) |
|
|
``` |
|
|
|
|
|
## vLLM |
|
|
|
|
|
```python |
|
|
|
|
|
""" |
|
|
This example shows how to run KafkaLM-15B offline in vLLM,
with and without the LoRA reasoning adapter.
|
|
""" |
|
|
|
|
|
from typing import Optional |
|
|
from huggingface_hub import snapshot_download |
|
|
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams |
|
|
from vllm.lora.request import LoRARequest |
|
|
|
|
|
def create_test_prompts(lora_path: str) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: |
|
|
"""Create a list of test prompts with their sampling parameters. |
|
|
One request for the base model, one request using the LoRA adapter.
|
|
""" |
|
|
return [ |
|
|
("Why did Kafka hit different?", |
|
|
SamplingParams(temperature=0.7, |
|
|
top_p=0.95, |
|
|
max_tokens=1024), None), |
|
|
|
|
|
("Create a Markdown table comparing SHA‑256, BLAKE3, and SHA‑3 with columns: internal structure, block size, and throughput.", |
|
|
SamplingParams(temperature=0.6, |
|
|
top_p=0.95, |
|
|
top_k=20, |
|
|
logprobs=1, |
|
|
prompt_logprobs=1, |
|
|
max_tokens=4096), |
|
|
LoRARequest("reasoning-lora", 1, lora_path)), |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
def process_requests(engine: LLMEngine, |
|
|
test_prompts: list[tuple[str, SamplingParams, |
|
|
Optional[LoRARequest]]]): |
|
|
"""Continuously process a list of prompts and handle the outputs.""" |
|
|
request_id = 0 |
|
|
|
|
|
while test_prompts or engine.has_unfinished_requests(): |
|
|
if test_prompts: |
|
|
prompt, sampling_params, lora_request = test_prompts.pop(0) |
|
|
engine.add_request(str(request_id), |
|
|
prompt, |
|
|
sampling_params, |
|
|
lora_request=lora_request) |
|
|
request_id += 1 |
|
|
|
|
|
request_outputs: list[RequestOutput] = engine.step() |
|
|
|
|
|
for request_output in request_outputs: |
|
|
if request_output.finished: |
|
|
print(request_output) |
|
|
|
|
|
|
|
|
def initialize_engine() -> LLMEngine: |
|
|
"""Initialize the LLMEngine.""" |
|
|
|
|
|
engine_args = EngineArgs(model="seedboxai/KafkaLM-15B", |
|
|
enable_lora=True, |
|
|
max_loras=1, |
|
|
max_lora_rank=128, |
|
|
max_cpu_loras=2, |
|
|
max_num_seqs=256) |
|
|
|
|
|
return LLMEngine.from_engine_args(engine_args) |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main function that sets up and runs the prompt processing.""" |
|
|
engine = initialize_engine() |
|
|
lora_path = snapshot_download(repo_id="seedboxai/KafkaLM-15B-GRPO_LoRA_Exp") |
|
|
test_prompts = create_test_prompts(lora_path) |
|
|
process_requests(engine, test_prompts) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
## Citation |
|
|
```bibtex |
|
|
@misc{kafkalm2025, |
|
|
title={Evaluating AMD's MI300A APU: Performance Insights on LLM Training via Knowledge Distillation}, |
|
|
author={Dennis Dickmann and Philipp Offenhäuser and Rishabh Saxena and George S. Markomanolis and Alessandro Rigazzi and Patrick Keller and Dennis Hoppe},
|
|
howpublished={Cray User Group Conference, 2025}, |
|
|
note={to be published}, |
|
|
year={2025} |
|
|
} |
|
|
``` |