MLX_DeepSeek_V3_1_4bit / inference.py
#!/usr/bin/env python3
"""
DeepSeek-V3.1-4bit Inference Script
Comprehensive testing with various generation parameters
"""
import time

from mlx_lm import load, generate
def main():
    # Load the locally converted 4-bit model (load() also accepts a
    # Hugging Face repo id and will download the weights on first use)
    print("🔃 Loading model...")
    model, tokenizer = load("/Users/martinrivera/deepseek_v3_1_4bit_mlx/deepseek_v3_4bit")
    print("✅ Model loaded successfully!")
    # Test prompts used by the streaming simulation in Test 3
    prompts = [
        "Explain quantum computing in simple terms:",
        "Write a poem about artificial intelligence:",
        "How will AI impact healthcare in the next decade?",
        "Translate 'Hello, how are you?' to French:",
        "The future of AI is"
    ]
    # Test 1: Basic generation with different parameters
    print("\n" + "=" * 60)
    print("🧪 TEST 1: Basic generation with different parameters")
    print("=" * 60)
    test_prompts = [
        ("The future of AI is", 250),
        ("Explain quantum computing:", 250),
        ("Write a short poem:", 250)
    ]
    for prompt, max_tokens in test_prompts:
        print(f"\n📝 Prompt: '{prompt}'")
        print(f"⚙️ Params: max_tokens={max_tokens}")
        start_time = time.time()
        response = generate(
            model,
            tokenizer,
            prompt,
            max_tokens=max_tokens,
        )
        end_time = time.time()
        words = len(response.split())
        time_taken = end_time - start_time
        words_per_second = words / time_taken if time_taken > 0 else 0
        print(f"📄 Response: {response}")
        print(f"⏱️ Generated {words} words in {time_taken:.2f} seconds")
        print(f"🚀 Speed: {words_per_second:.2f} words/second")
        print("-" * 50)
    # Test 2: Chat format (if the tokenizer ships a chat template)
    print("\n" + "=" * 60)
    print("💬 TEST 2: Chat format testing")
    print("=" * 60)
    try:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": "Explain quantum computing simply"}
        ]
        # Apply the chat template; add_generation_prompt=True appends the
        # assistant-turn marker so the model answers instead of continuing
        # the user message
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"💬 Chat prompt: {prompt[:100]}...")
        response = generate(model, tokenizer, prompt, max_tokens=250)
        print(f"🤖 Assistant: {response}")
    except Exception as e:
        print(f"⚠️ Chat template not supported: {e}")
        # Fall back to a plain prompt
        response = generate(model, tokenizer, "Explain quantum computing simply:", max_tokens=250)
        print(f"📄 Response: {response}")
    # Test 3: Manual streaming simulation
    print("\n" + "=" * 60)
    print("🌊 TEST 3: Manual streaming simulation")
    print("=" * 60)

    def simulate_streaming(prompt, max_tokens=250):
        """Simulate streaming by generating fully, then printing word by word"""
        print(f"\n📝 Streaming: '{prompt}'")
        print("🌊 Response: ", end="", flush=True)
        start_time = time.time()
        # Generate the full response first
        response = generate(
            model,
            tokenizer,
            prompt,
            max_tokens=max_tokens,
            verbose=False
        )
        # Simulate streaming by printing words with small delays
        words = response.split()
        for i, word in enumerate(words):
            print(word, end=" ", flush=True)
            # Brief pause every few words to mimic token-by-token output
            if i % 5 == 0:
                time.sleep(0.05)
        end_time = time.time()
        # Note: this timing includes the artificial sleeps, so it slightly
        # understates the raw generation speed
        time_taken = end_time - start_time
        words_count = len(words)
        words_per_second = words_count / time_taken if time_taken > 0 else 0
        print(f"\n⏱️ Generated {words_count} words in {time_taken:.2f} seconds")
        print(f"🚀 Speed: {words_per_second:.2f} words/second")
        return response
    # Run the streaming simulation for all prompts
    for i, prompt in enumerate(prompts, 1):
        print(f"\n[{i}/{len(prompts)}] ", end="")
        simulate_streaming(prompt, max_tokens=250)
        time.sleep(1)  # Brief pause between generations
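    # Real streaming (hedged sketch): mlx_lm also ships stream_generate(),
    # which yields chunks as they are produced instead of waiting for the
    # full response. Depending on the mlx_lm version it yields plain strings
    # or response objects with a .text field, so both are handled below.
    try:
        from mlx_lm import stream_generate
        print("\n🌊 Real streaming with stream_generate: ", end="", flush=True)
        for chunk in stream_generate(model, tokenizer, "The future of AI is", max_tokens=100):
            print(getattr(chunk, "text", chunk), end="", flush=True)
        print()
    except ImportError:
        print("\n⚠️ stream_generate not available in this mlx_lm version")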
    # Test 4: Performance benchmark
    print("\n" + "=" * 60)
    print("📊 TEST 4: Performance benchmark")
    print("=" * 60)
    benchmark_prompt = "The future of artificial intelligence is"
    test_runs = 3
    total_time = 0
    total_words = 0
    for run in range(test_runs):
        print(f"\n🏃‍♂️ Benchmark run {run + 1}/{test_runs}...")
        start_time = time.time()
        response = generate(model, tokenizer, benchmark_prompt, max_tokens=250)
        end_time = time.time()
        time_taken = end_time - start_time
        words = len(response.split())
        words_per_second = words / time_taken if time_taken > 0 else 0
        total_time += time_taken
        total_words += words
        print(f"⏱️ Time: {time_taken:.2f}s, Words: {words}, Speed: {words_per_second:.2f} words/s")
    # Average performance across runs
    if test_runs > 0:
        avg_time = total_time / test_runs
        avg_words = total_words / test_runs
        avg_speed = avg_words / avg_time if avg_time > 0 else 0
        print(f"\n📈 Averages: {avg_time:.2f}s per generation, {avg_speed:.2f} words/s")
if __name__ == "__main__":
    main()