---
library_name: transformers
tags:
- pruning
- distillation
- sparsity-2:4
license: apache-2.0
language:
- en
- de
- fr
- es
- it
- pt
base_model:
- doubledsbv/KafkaLM-15B-Base
pipeline_tag: text-generation
---
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/EgsjPDWd37LjAtamiICxk.png" width="480" height="480" alt="image/png">
# Model Description
**KafkaLM‑15B‑Base** is a 15‑billion‑parameter, sparsity‑aware language model distilled from *Mistral‑Small‑24B‑Base‑2501* and further post‑trained (SFT + DPO + GRPO with verifiable rewards).
This experimental model was created in five stages:
| Stage | What we did | Why it matters |
|-------|-------------|----------------|
| **1. SimplePrune** | Applied a hierarchical, hardware‑aware pruning pipeline that combines block‑, channel‑ and 2:4 structured sparsity (≈ 37.5 % parameter reduction) | Slashes memory footprint while minimizing perplexity degradation |
| **2. Teacher calibration** | Briefly fine‑tuned the unpruned 24 B teacher on a 10 B‑token multilingual European corpus on an AMD MI300A cluster | Produces stable logits and hidden states for distillation |
| **3. Knowledge distillation** | Distilled the calibrated teacher into the pruned 15 B student using a **fused loss** (sketched below the table):<br/>`L = L_PooledSquareHead + L_KL + 0.25 · L_CE` | Transfers teacher capabilities effectively with < 15 B tokens **(< 2 epochs)** on 64 MI300A nodes |
| **4. SFT + DPO** | Supervised fine‑tuning (SFT) and Direct Preference Optimization (DPO) on curated open‑source multilingual and multitask datasets | Enhances alignment with human preferences while preserving multilingual capabilities |
| **5. RL** | Trained GRPO as a separate LoRA adapter so it can be served easily and applied optionally | Enables flexible deployment with optional reinforcement‑learning gains without modifying the base model |
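For illustration, the fused distillation objective from stage 3 could look roughly like the PyTorch sketch below. This is a minimal reconstruction of the formula above, not the actual training code; the pooling choice, layer mapping, and normalisation are assumptions.
```python
import torch
import torch.nn.functional as F

def fused_distillation_loss(student_logits, teacher_logits,
                            student_hidden, teacher_hidden,
                            labels, ce_weight=0.25, temperature=1.0):
    """Illustrative fused loss: pooled SquareHead + KL + 0.25 * CE.

    student_hidden / teacher_hidden are lists of hidden states with shape
    (batch, seq, dim); matching hidden sizes and a 1:1 layer mapping are
    assumed here purely for simplicity.
    """
    # Pooled SquareHead term: normalised MSE between mean-pooled hidden states.
    squarehead = 0.0
    for s_h, t_h in zip(student_hidden, teacher_hidden):
        s_pooled, t_pooled = s_h.mean(dim=1), t_h.mean(dim=1)
        squarehead = squarehead + F.mse_loss(s_pooled, t_pooled) / (t_pooled.pow(2).mean() + 1e-6)
    squarehead = squarehead / len(student_hidden)

    # KL divergence between teacher and student next-token distributions.
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * temperature ** 2

    # Standard cross-entropy against the ground-truth token ids.
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    return squarehead + kl + ce_weight * ce
```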
**Key capabilities**
* Balanced for **multitask** and multilingual conversation as well as long‑context handling
* Structured **2:4 sparsity** → runs up to **40 % faster** on sparsity‑aware kernels (see the sparsity sketch after this list)
* Distilled on a combination of multilingual pretraining and synthetic data
* Training pipeline optimized for unified‑memory GPUs (AMD MI300A) but runs on any CUDA / ROCm device
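The 2:4 pattern keeps at most two non‑zero weights in every contiguous group of four along a row, which is the layout sparsity‑aware GPU kernels accelerate. The snippet below is a hypothetical illustration of checking and producing that pattern with simple magnitude pruning; it is not the SimplePrune selection criterion.
```python
import torch

def satisfies_2_to_4(weight: torch.Tensor) -> bool:
    """True if every group of 4 consecutive weights (last dim) has <= 2 non-zeros."""
    out_features, in_features = weight.shape
    assert in_features % 4 == 0, "2:4 sparsity requires in_features divisible by 4"
    groups = weight.reshape(out_features, in_features // 4, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())

def magnitude_2_to_4(weight: torch.Tensor) -> torch.Tensor:
    """Zero the two smallest-magnitude entries in each group of four (illustrative)."""
    out_features, in_features = weight.shape
    groups = weight.reshape(out_features, in_features // 4, 4)
    _, drop_idx = groups.abs().topk(2, dim=-1, largest=False)  # two weakest per group
    mask = torch.ones_like(groups)
    mask.scatter_(-1, drop_idx, 0.0)
    return (groups * mask).reshape(out_features, in_features)

w = torch.randn(8, 16)
assert satisfies_2_to_4(magnitude_2_to_4(w))
```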
---
### LoRA-based reasoning capabilities
[Download the adapter from the Hugging Face Hub](https://huggingface.co/seedboxai/KafkaLM-15B-GRPO_LoRA_Exp)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Load base model
base_model = AutoModelForCausalLM.from_pretrained("seedboxai/KafkaLM-15B")
tokenizer = AutoTokenizer.from_pretrained("seedboxai/KafkaLM-15B")
# Apply LoRA adapter
model = PeftModel.from_pretrained(
base_model,
"seedboxai/KafkaLM-15B-GRPO_LoRA_Exp",
adapter_name="grpo_lora"
)
```
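A brief usage sketch for generating with the adapter active; the prompt and generation settings are illustrative, not recommendations.
```python
messages = [{"role": "user", "content": "Explain 2:4 structured sparsity in two sentences."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(base_model.device)

# The GRPO LoRA is the only loaded adapter, so it is active by default.
output = model.generate(input_ids=input_ids, max_new_tokens=512)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```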
## Pruning Process
**Pruning & Distillation Strategy — SimplePrune**
Hardware‑aware, hierarchical pipeline. SimplePrune starts with coarse block‑level pruning and drills down to channel‑ and neuron‑level removals, finishing with 2:4 structured sparsity. This staged approach converts compression ratios into real memory‑bandwidth and latency gains.
**Sensitivity‑guided selection**
Each stage is driven by activation‑magnitude profiles and Hessian‑based importance scores captured asynchronously during training, allowing the framework to run inside the MI300A’s 512 GB unified memory without OOM interruptions.
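As a rough illustration of the activation‑magnitude side of this signal (the Hessian‑based scoring and asynchronous capture are omitted, and none of this is the actual SimplePrune code), per‑channel statistics can be gathered with forward hooks over a few calibration batches:
```python
import torch
from collections import defaultdict

def collect_activation_profiles(model, calibration_batches):
    """Accumulate mean |activation| per output channel of every nn.Linear
    as a cheap importance proxy (illustrative sketch only)."""
    sums, counts, hooks = defaultdict(float), defaultdict(int), []

    def make_hook(name):
        def hook(module, inputs, output):
            # average over batch and sequence dims -> one value per channel
            sums[name] = sums[name] + output.detach().abs().mean(dim=(0, 1)).float().cpu()
            counts[name] += 1
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        for batch in calibration_batches:  # batches of already-tokenised inputs
            model(**batch)

    for h in hooks:
        h.remove()
    return {name: sums[name] / counts[name] for name in sums}
```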
**Two‑phase optimisation**
A fast greedy pass prunes low‑impact blocks in MLP expansion layers, after which a **Tabu‑Search** meta‑heuristic explores cross‑layer combinations for a better global trade‑off between sparsity and perplexity/KL divergence.
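The two‑phase idea can be sketched generically: the greedy pass yields a starting configuration, and a tabu‑search loop then explores neighbouring pruning configurations while forbidding recently visited ones. The code below is a schematic of such a loop, assuming hashable configurations (e.g. frozensets of pruned block indices), a `neighbors` generator, and a `score` callable (e.g. perplexity or KL at a fixed sparsity budget); it is not the SimplePrune implementation.
```python
def tabu_search(initial_config, neighbors, score, iterations=100, tabu_size=25):
    """Schematic tabu search over pruning configurations (minimise `score`)."""
    current = best = initial_config
    best_score = score(best)
    tabu = [current]  # recently visited configurations are forbidden

    for _ in range(iterations):
        candidates = [c for c in neighbors(current) if c not in tabu]
        if not candidates:
            break
        # move to the best admissible neighbour, even if it is worse than `current`
        current = min(candidates, key=score)
        tabu.append(current)
        if len(tabu) > tabu_size:
            tabu.pop(0)
        if score(current) < best_score:
            best, best_score = current, score(current)
    return best
```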
**Post‑pruning knowledge distillation**
The pruned 15 B student is distilled from a calibrated 24 B teacher using a fused `L_SquareHead + L_KL + 0.25 · L_CE` loss across 20 B multilingual tokens, restoring > 96 % of the original quality in ≤ 2 epochs on up to 64 MI300A nodes.
### Results
The ≈ 37.5 % parameter reduction (24 B → 15 B) delivers 2× lower TTFT and ≈ 40 % higher tokens/s versus the uncompressed teacher while matching perplexity and divergence metrics, validating SimplePrune as an effective route to deploy KafkaLM in memory‑constrained, sparsity‑accelerated environments.
| Metric | Mistral‑24B | **KafkaLM‑15B** | Δ |
|--------|-------------|-----------------|---|
| Time‑to‑First‑Token | 4.91 s | **2.46 s** | −50% |
| Prompts / s | 4.70 | **6.55** | +38% |
| Tokens / s | 579 | **812** | +40% |
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/4rDhaeC-1GMj6KWbB27f9.png" width="480" height="480" alt="image/png">
### Training scalability (distillation run, MI300A cluster)
| Nodes | Tokens / s | Speed‑up |
|-------|------------|----------|
| 4 | 1 461 | – |
| 8 | 3 327 | 2.3 × |
| 16 | 7 423 | 5.1 × |
| 32 | 15 286 | 10.5 × |
| 64 | 25 455 | 17.4 × |
Near‑linear scaling thanks to sharded ZeRO‑3 + RCCL optimisations.
# Inference
### Transformers
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "seedboxai/KafkaLM-15B"
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
# prepare the model input
prompt = "Why did Kafka hit different?"
messages = [
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# conduct text completion
generated_ids = model.generate(
**model_inputs,
max_new_tokens=1024
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
response = tokenizer.decode(output_ids, skip_special_tokens=True)
print(response)
```
## vLLM
```python
"""
This example shows how to use KafkaLM-15B in vLLM with and without LoRA reasoning functionality
for offline inference.
"""
from typing import Optional
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
def create_test_prompts(lora_path: str) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.
    One request for the base model, one request for the LoRA adapter.
"""
return [
("Why did Kafka hit different?",
SamplingParams(temperature=0.7,
top_p=0.95,
max_tokens=1024), None),
("Create a Markdown table comparing SHA‑256, BLAKE3, and SHA‑3 with columns: internal structure, block size, and throughput.",
SamplingParams(temperature=0.6,
top_p=0.95,
top_k=20,
logprobs=1,
prompt_logprobs=1,
max_tokens=4096),
LoRARequest("reasoning-lora", 1, lora_path)),
]
def process_requests(engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams,
Optional[LoRARequest]]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
def initialize_engine() -> LLMEngine:
"""Initialize the LLMEngine."""
engine_args = EngineArgs(model="seedboxai/KafkaLM-15B",
enable_lora=True,
max_loras=1,
max_lora_rank=128,
max_cpu_loras=2,
max_num_seqs=256)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine()
lora_path = snapshot_download(repo_id="seedboxai/KafkaLM-15B-GRPO_LoRA_Exp")
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
if __name__ == '__main__':
main()
```
## Citation
```bibtex
@misc{kafkalm2025,
title={Evaluating AMD's MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
  author={Dennis Dickmann and Philipp Offenhäuser and Rishabh Saxena and George S. Markomanolis and Alessandro Rigazzi and Patrick Keller and Dennis Hoppe},
howpublished={Cray User Group Conference, 2025},
note={to be published},
year={2025}
}
```