#!/usr/bin/env python3
"""
DeepSeek-V3.1-4bit Inference Script
Comprehensive testing with various generation parameters
"""

from mlx_lm import load, generate
import time

def main():
    # Load the model
    print("🔃 Loading model...")
    model, tokenizer = load("/Users/martinrivera/deepseek_v3_1_4bit_mlx/deepseek_v3_4bit")
    print("✅ Model loaded successfully!")
    
    # Test prompts (reused by the streaming simulation in Test 3)
    prompts = [
        "Explain quantum computing in simple terms:",
        "Write a poem about artificial intelligence:",
        "How will AI impact healthcare in the next decade?",
        "Translate 'Hello, how are you?' to French:",
        "The future of AI is"
    ]
    
    # Test 1: Basic generation with different parameters
    print("\n" + "="*60)
    print("🧪 TEST 1: Basic generation with different parameters")
    print("="*60)
    
    test_prompts = [
        ("The future of AI is", 250),
        ("Explain quantum computing:", 250),
        ("Write a short poem:", 250)
    ]
    
    for prompt, max_tokens in test_prompts:
        print(f"\n📝 Prompt: '{prompt}'")
        print(f"⚙️  Params: tokens={max_tokens}")
        
        start_time = time.time()
        response = generate(
            model, 
            tokenizer, 
            prompt, 
            max_tokens=max_tokens,
        )
        end_time = time.time()
        
        words = len(response.split())
        time_taken = end_time - start_time
        words_per_second = words / time_taken if time_taken > 0 else 0
        
        print(f"📄 Response: {response}")
        print(f"⏱️  Generated {words} words in {time_taken:.2f} seconds")
        print(f"🚀 Speed: {words_per_second:.2f} words/second")
        print("-" * 50)
    
    # Test 2: Chat format (if supported)
    print("\n" + "="*60)
    print("💬 TEST 2: Chat format testing")
    print("="*60)
    
    try:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": "Explain quantum computing simply"}
        ]
        
        # Apply the chat template; add_generation_prompt appends the assistant
        # turn marker so the model answers instead of continuing the user turn.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"💬 Chat prompt: {prompt[:100]}...")
        
        response = generate(model, tokenizer, prompt, max_tokens=250)
        print(f"🤖 Assistant: {response}")
        
    except Exception as e:
        print(f"⚠️ Chat template not supported: {e}")
        # Fallback to regular prompt
        response = generate(model, tokenizer, "Explain quantum computing simply:", max_tokens=250)
        print(f"📄 Response: {response}")
    
    # Test 3: Manual streaming simulation
    print("\n" + "="*60)
    print("🌊 TEST 3: Manual streaming simulation")
    print("="*60)
    
    def simulate_streaming(prompt, max_tokens=250):
        """Simulate streaming by generating in chunks"""
        print(f"\n📝 Streaming: '{prompt}'")
        print("🌊 Response: ", end="", flush=True)
        
        start_time = time.time()
        
        # Generate the full response first
        response = generate(
            model, 
            tokenizer, 
            prompt, 
            max_tokens=max_tokens,
            verbose=False
        )
        
        # Simulate streaming by printing words with delays
        words = response.split()
        for i, word in enumerate(words):
            print(word, end=" ", flush=True)
            # Pause briefly every few words to mimic real-time token output
            if i % 5 == 0:
                time.sleep(0.05)
        
        end_time = time.time()
        time_taken = end_time - start_time
        words_count = len(words)
        words_per_second = words_count / time_taken if time_taken > 0 else 0
        
        print(f"\n⏱️  Generated {words_count} words in {time_taken:.2f} seconds")
        print(f"🚀 Speed: {words_per_second:.2f} words/second")
        
        return response
    
    # Simulate streaming for all prompts
    for i, prompt in enumerate(prompts, 1):
        print(f"\n[{i}/{len(prompts)}] ", end="")
        simulate_streaming(prompt, max_tokens=250)
        time.sleep(1)  # Brief pause between generations
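
    # Note: mlx-lm also provides true token streaming via stream_generate,
    # which yields text as it is produced rather than after the fact. Sketch
    # below, hedged: recent versions yield GenerationResponse objects with a
    # .text field, while older ones yielded plain strings, so both are handled.
    try:
        from mlx_lm import stream_generate

        print("\n🌊 True streaming via stream_generate:")
        for chunk in stream_generate(model, tokenizer, "The future of AI is",
                                     max_tokens=50):
            piece = chunk.text if hasattr(chunk, "text") else chunk
            print(piece, end="", flush=True)
        print()
    except ImportError as e:
        print(f"⚠️ stream_generate not available: {e}")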
    
    # Performance benchmark
    print("\n" + "="*60)
    print("📊 TEST 4: Performance benchmark")
    print("="*60)
    
    benchmark_prompt = "The future of artificial intelligence is"
    test_runs = 3
    
    total_time = 0
    total_words = 0
    
    for run in range(test_runs):
        print(f"\n🏃‍♂️ Benchmark run {run + 1}/{test_runs}...")
        start_time = time.time()
        response = generate(model, tokenizer, benchmark_prompt, max_tokens=250)
        end_time = time.time()
        
        time_taken = end_time - start_time
        words = len(response.split())
        words_per_second = words / time_taken if time_taken > 0 else 0
        
        total_time += time_taken
        total_words += words
        
        print(f"⏱️  Time: {time_taken:.2f}s, Words: {words}, Speed: {words_per_second:.2f} words/s")
    
    # Average performance
    if test_runs > 0:
        avg_time = total_time / test_runs
        avg_words = total_words / test_runs
        avg_speed = avg_words / avg_time if avg_time > 0 else 0
        
        print(f"\n📈 Averages: {avg_time:.2f}s per generation, {avg_speed:.2f} words/s")

if __name__ == "__main__":
    main()