akkiisfrommars committed
Commit b2601bc · verified · 1 Parent(s): 0836361

Delete chat.py

Files changed (1)
  1. chat.py +0 -1130
chat.py DELETED
@@ -1,1130 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import argparse
5
- import torch
6
- import numpy as np
7
- from termcolor import colored
8
- import logging
9
- import readline
10
- import re
11
- import textwrap
12
- import random
13
- from collections import defaultdict
14
- import json
15
-
16
- # Required imports for HF Hub
17
- try:
18
- from transformers import GPT2Tokenizer
19
- from huggingface_hub import hf_hub_download, snapshot_download
20
- HF_AVAILABLE = True
21
- except ImportError:
22
- HF_AVAILABLE = False
23
- print("Required libraries not available.")
24
- print("Install with: pip install transformers huggingface-hub")
25
- sys.exit(1)
26
-
27
- # Set up logging
28
- logging.basicConfig(
29
- level=logging.INFO,
30
- format='%(asctime)s - %(levelname)s - %(message)s',
31
- handlers=[logging.StreamHandler(sys.stdout)]
32
- )
33
- logger = logging.getLogger(__name__)
34
-
35
- # Default model repository
36
- DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-90M"
37
-
38
- # Default prompt template
39
- DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
40
-
41
-
42
- class CosmicConfig:
43
- """Configuration class for CosmicFish."""
44
-
45
- def __init__(self,
46
- vocab_size=50257,
47
- block_size=512,
48
- n_layer=10,
49
- n_head=16,
50
- n_embd=640,
51
- bias=True,
52
- dropout=0.0,
53
- n_query_groups=4,
54
- eps=1e-6,
55
- use_rotary=True,
56
- use_swiglu=True,
57
- use_qk_norm=False,
58
- use_gqa=True):
59
- self.vocab_size = vocab_size
60
- self.block_size = block_size
61
- self.n_layer = n_layer
62
- self.n_head = n_head
63
- self.n_embd = n_embd
64
- self.bias = bias
65
- self.dropout = dropout
66
- self.eps = eps
67
- self.use_rotary = use_rotary
68
- self.use_swiglu = use_swiglu
69
- self.use_qk_norm = use_qk_norm
70
- self.use_gqa = use_gqa
71
- self.n_query_groups = n_query_groups if use_gqa else n_head
72
- # Ensure n_head is divisible by n_query_groups
73
- assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
74
-
75
-
76
- class RMSNorm(torch.nn.Module):
77
- """Root Mean Square Normalization"""
78
-
79
- def __init__(self, dim, eps=1e-6):
80
- super().__init__()
81
- self.eps = eps
82
- self.weight = torch.nn.Parameter(torch.ones(dim))
83
-
84
- def forward(self, x):
85
- rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
86
- return self.weight * (x / rms)
87
-
88
-
89
- def precompute_freqs_cis(dim, end, theta=10000.0):
90
- """Precompute the frequency tensor for complex exponentials (cis)"""
91
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
92
- t = torch.arange(end, device=freqs.device)
93
- freqs = torch.outer(t, freqs)
94
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
95
- return freqs_cis
96
-
97
-
98
- def apply_rotary_emb(xq, xk, freqs_cis):
99
- """Apply rotary embeddings to input tensors"""
100
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
101
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
102
-
103
- seq_len = xq_.size(2)
104
- if freqs_cis.size(0) < seq_len:
105
- raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
106
-
107
- freqs_cis_seq = freqs_cis[:seq_len]
108
- xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
109
- xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
110
-
111
- return xq_out.type_as(xq), xk_out.type_as(xk)
112
-
113
-
114
- class GroupedQueryAttention(torch.nn.Module):
115
- """Grouped Query Attention (GQA) implementation"""
116
-
117
- def __init__(self, config):
118
- super().__init__()
119
- assert config.n_embd % config.n_head == 0
120
-
121
- head_dim = config.n_embd // config.n_head
122
- self.head_dim = head_dim
123
- self.n_head = config.n_head
124
- self.n_embd = config.n_embd
125
- self.n_query_groups = config.n_query_groups
126
-
127
- self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
128
- qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
129
-
130
- self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
131
- self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
132
-
133
- # Flash attention support
134
- self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
135
- if not self.flash:
136
- self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
137
- .view(1, 1, config.block_size, config.block_size))
138
-
139
- # Query-key normalization
140
- self.qk_norm = getattr(config, 'use_qk_norm', False)
141
- if self.qk_norm:
142
- self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
143
- self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
144
-
145
- def forward(self, x, freqs_cis=None):
146
- B, T, C = x.size()
147
- qkv = self.c_attn(x)
148
- head_dim = C // self.n_head
149
-
150
- q_size = self.n_head * head_dim
151
- k_size = self.kv_heads * head_dim
152
- v_size = self.kv_heads * head_dim
153
-
154
- q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
155
-
156
- q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
157
- k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
158
- v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
159
-
160
- # Repeat k and v if needed for GQA
161
- if self.kv_heads < self.n_head:
162
- repeats = self.n_head // self.kv_heads
163
- k = k.repeat_interleave(repeats, dim=1)
164
- v = v.repeat_interleave(repeats, dim=1)
165
-
166
- # Apply rotary embeddings
167
- if freqs_cis is not None:
168
- q, k = apply_rotary_emb(q, k, freqs_cis)
169
-
170
- # Apply query-key normalization
171
- if self.qk_norm:
172
- q = self.q_norm(q)
173
- k = self.k_norm(k)
174
-
175
- # Compute attention
176
- if self.flash:
177
- y = torch.nn.functional.scaled_dot_product_attention(
178
- q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
179
- )
180
- else:
181
- att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
182
- att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
183
- att = torch.nn.functional.softmax(att, dim=-1)
184
- y = att @ v
185
-
186
- y = y.transpose(1, 2).contiguous().view(B, T, C)
187
- y = self.c_proj(y)
188
- return y
189
-
190
-
191
- class Block(torch.nn.Module):
192
- """Transformer block"""
193
-
194
- def __init__(self, config):
195
- super().__init__()
196
- self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
197
- self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
198
- self.attn = GroupedQueryAttention(config)
199
-
200
- # MLP implementation based on configuration
201
- if config.use_swiglu:
202
- # SwiGLU MLP
203
- self.mlp = torch.nn.ModuleDict(dict(
204
- gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
205
- up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
206
- down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
207
- act=torch.nn.SiLU(),
208
- ))
209
- m = self.mlp
210
- self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
211
- else:
212
- # Traditional MLP
213
- self.mlp = torch.nn.ModuleDict(dict(
214
- c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
215
- c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
216
- act=torch.nn.GELU(),
217
- ))
218
- m = self.mlp
219
- self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
220
-
221
- def forward(self, x, freqs_cis=None):
222
- x = x + self.attn(self.ln_1(x), freqs_cis)
223
- x = x + self.mlpf(self.ln_2(x))
224
- return x
225
-
226
-
227
- class CosmicFish(torch.nn.Module):
228
- """
229
- CosmicFish model for inference only.
230
- Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
231
- """
232
-
233
- def __init__(self, config):
234
- super().__init__()
235
- self.config = config
236
-
237
- self.transformer = torch.nn.ModuleDict(dict(
238
- wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
239
- h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
240
- ln_f=RMSNorm(config.n_embd, eps=config.eps),
241
- ))
242
-
243
- self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
244
-
245
- # Share weights between embedding and output
246
- self.transformer.wte.weight = self.lm_head.weight
247
-
248
- # Precompute rotary embedding frequencies
249
- if config.use_rotary:
250
- head_dim = config.n_embd // config.n_head
251
- self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
252
- else:
253
- self.freqs_cis = None
254
- self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
255
-
256
- def get_num_params(self, non_embedding=True):
257
- """Return the number of parameters in the model."""
258
- n_params = sum(p.numel() for p in self.parameters())
259
- if non_embedding and hasattr(self.transformer, 'wpe'):
260
- n_params -= self.transformer.wpe.weight.numel()
261
- return n_params
262
-
263
- def forward(self, idx, targets=None):
264
- """Forward pass through the model."""
265
- device = idx.device
266
- b, t = idx.size()
267
- assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
268
-
269
- # Get token embeddings
270
- tok_emb = self.transformer.wte(idx)
271
-
272
- # Handle positional embeddings
273
- if self.config.use_rotary:
274
- x = tok_emb
275
- freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
276
- else:
277
- pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
278
- pos_emb = self.transformer.wpe(pos)
279
- x = tok_emb + pos_emb
280
- freqs_cis = None
281
-
282
- # Apply transformer blocks
283
- for block in self.transformer.h:
284
- x = block(x, freqs_cis)
285
-
286
- # Apply final normalization
287
- x = self.transformer.ln_f(x)
288
-
289
- # Calculate outputs
290
- if targets is not None:
291
- logits = self.lm_head(x)
292
- loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
293
- else:
294
- # For inference, only compute logits for the last token
295
- logits = self.lm_head(x[:, [-1], :])
296
- loss = None
297
-
298
- return logits, loss
299
-
300
- @torch.no_grad()
301
- def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
302
- """
303
- Generate text by sampling from the model, token by token.
304
- """
305
- for _ in range(max_new_tokens):
306
- # Crop sequence to block size if needed
307
- idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
308
-
309
- # Forward pass
310
- logits, _ = self(idx_cond)
311
- logits = logits[:, -1, :] / temperature
312
-
313
- # Apply top-k sampling
314
- if top_k is not None:
315
- v, _ = torch.topk(logits, top_k)
316
- logits[logits < v[:, [-1]]] = -float('Inf')
317
-
318
- # Sample next token
319
- probs = torch.nn.functional.softmax(logits, dim=-1)
320
- idx_next = torch.multinomial(probs, num_samples=1)
321
-
322
- # Append to sequence
323
- idx = torch.cat((idx, idx_next), dim=1)
324
-
325
- return idx
326
-
327
-
328
- class RepetitionPenaltyLogitsProcessor:
329
- """Apply repetition penalty to prevent repeating tokens."""
330
-
331
- def __init__(self, penalty=1.2):
332
- self.penalty = penalty
333
-
334
- def __call__(self, input_ids, scores):
335
- """Apply repetition penalty to logits where input_ids is already seen."""
336
- score = torch.gather(scores, 1, input_ids)
337
- # If score > 0, penalize by dividing; if score < 0, penalize by multiplying
338
- score = torch.where(score > 0, score / self.penalty, score * self.penalty)
339
- scores.scatter_(1, input_ids, score)
340
- return scores
341
-
342
-
343
- class CosmicFishChatSession:
344
- """Chat session for CosmicFish model from Hugging Face Hub."""
345
-
346
- def __init__(self, model, tokenizer, config):
347
- """Initialize chat session with model and configuration."""
348
- self.model = model
349
- self.tokenizer = tokenizer
350
- self.config = config
351
- self.device = next(model.parameters()).device
352
- self.history = []
353
- self.history_tokens = []
354
- self.max_history_tokens = config.max_history_tokens
355
- self.prompt_template = config.prompt_template
356
- self.human_prefix = config.human_prefix
357
- self.assistant_prefix = config.assistant_prefix
358
- self.end_of_turn = config.end_of_turn
359
- self.block_size = config.block_size
360
- self.debug_mode = config.debug_mode
361
- self.repetition_penalty = config.repetition_penalty
362
- self.min_tokens_to_generate = config.min_tokens_to_generate
363
- self.max_retries = 20
364
-
365
- self.fallback_responses = [
366
- "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
367
- "That's a topic I can provide information about. What specific aspects would you like to know?",
368
- "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
369
- "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
370
- "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
371
- ]
372
-
373
- self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"
374
-
375
- # For token counting
376
- self.total_prompt_tokens = 0
377
- self.total_generated_tokens = 0
378
-
379
- # End markers for live generation
380
- self.end_markers = [
381
- f"{self.human_prefix}",
382
- "Human:",
383
- "\nHuman:",
384
- "\nH:",
385
- "H:",
386
- "<|endoftext|>",
387
- "Below is a conversation",
388
- "\nA:",
389
- "A:",
390
- "</s>",
391
- "User:",
392
- "\nUser:"
393
- ]
394
-
395
- # Print welcome message
396
- if config.display_welcome:
397
- self._print_welcome_message()
398
-
399
- def _print_welcome_message(self):
400
- """Print a welcome message to the user."""
401
- welcome_text = f"""
402
- {'=' * 80}
403
- Welcome to CosmicFish!
404
-
405
- This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
406
- CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
407
-
408
- ⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
409
- small, it is more likely to give incorrect answers or hallucinate compared to
410
- larger models. Please verify important information from reliable sources.
411
-
412
- Model: {DEFAULT_MODEL_REPO}
413
-
414
- Type your prompts and CosmicFish will respond.
415
-
416
- Special commands:
417
- - /help: Show this help message
418
- - /clear: Clear the conversation history
419
- - /exit or /quit: Exit the chat
420
- - /stats: Show token usage statistics
421
- - /save [filename]: Save the conversation
422
- - /load [filename]: Load a conversation
423
- - /temp [value]: Set temperature (between 0.1 and 2.0)
424
- - /penalty [value]: Set repetition penalty (1.0-2.0)
425
- - /debug: Toggle debug mode
426
- {'=' * 80}
427
- """
428
- print(colored(welcome_text, 'cyan'))
429
-
430
- def _format_prompt(self, user_input):
431
- """Format the complete prompt with history and current input."""
432
- # Start with the template
433
- formatted_prompt = self.prompt_template
434
-
435
- # Add conversation history
436
- for entry in self.history:
437
- role, text = entry
438
- if role == "human":
439
- formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
440
- else: # assistant
441
- formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"
442
-
443
- # Add the current user input
444
- formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
445
-
446
- return formatted_prompt
447
-
448
- def _tokenize(self, text):
449
- """Tokenize text and return token IDs."""
450
- return self.tokenizer.encode(text)
451
-
452
- def _update_history(self, user_input, response):
453
- """Update conversation history."""
454
- # Add to text history
455
- self.history.append(("human", user_input))
456
- self.history.append(("assistant", response))
457
-
458
- # Update token history for context window management
459
- user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
460
- response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")
461
-
462
- self.history_tokens.extend(user_tokens)
463
- self.history_tokens.extend(response_tokens)
464
-
465
- # Track token usage
466
- self.total_prompt_tokens += len(user_tokens)
467
- self.total_generated_tokens += len(response_tokens)
468
-
469
- # Trim history if it gets too long
470
- self._trim_history_if_needed()
471
-
472
- def _trim_history_if_needed(self):
473
- """Trim history to fit within the context window."""
474
- if len(self.history_tokens) > self.max_history_tokens:
475
- # Remove oldest turns until we're under the limit
476
- while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
477
- # Find token boundary for the oldest human and assistant turn before removing them
478
- user_turn = self.history[0][1]
479
- assistant_turn = self.history[1][1]
480
- user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
481
- assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))
482
-
483
- # Remove the oldest human and assistant turn
484
- self.history = self.history[2:]
485
-
486
- # Update token history by dropping the removed turns' tokens
487
- self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
488
-
489
- def _should_stop_generation(self, text):
490
- """Check if generation should stop based on end markers."""
491
- for marker in self.end_markers:
492
- if marker in text:
493
- return True
494
- return False
495
-
496
- def _clean_token_text(self, text):
497
- """Clean up Unicode replacement characters and common mojibake in decoded token text."""
498
- text = text.replace('��', "'")
499
-
500
- text = text.replace('�', "'")
501
- text = text.replace('\ufffd', "'")
502
- text = text.replace('\uFFFD', "'")
503
-
504
- text = text.replace('’', "'")
505
- text = text.replace('“', "'")
506
- text = text.replace('�', "'")
507
- text = text.replace('â€"', "'")
508
- text = text.replace('â€"', "'")
509
-
510
- return text
511
-
512
- def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
513
- """Custom generate function with repetition penalty and optional live generation."""
514
- model = self.model
515
- device = self.device
516
-
517
- # Ensure model is in eval mode
518
- model.eval()
519
-
520
- # Initialize sequence with input_ids
521
- generated = input_ids.clone()
522
-
523
- # Initialize live text buffer
524
- live_buffer = ""
525
-
526
- # Create repetition penalty processor
527
- rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
528
-
529
- # Counter for generated tokens
530
- tokens_generated = 0
531
- min_tokens = self.min_tokens_to_generate
532
-
533
- # EOT token ID
534
- eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256
535
-
536
- # Generate tokens one at a time
537
- for _ in range(max_new_tokens):
538
- # Get only the last block_size tokens if context is too long
539
- if generated.size(1) > self.block_size:
540
- context = generated[:, -self.block_size:]
541
- else:
542
- context = generated
543
-
544
- # Forward pass for next token prediction
545
- with torch.no_grad():
546
- logits, _ = model(context)
547
-
548
- # Get logits for the next token (last position)
549
- next_token_logits = logits[:, -1, :]
550
-
551
- # Apply temperature
552
- next_token_logits = next_token_logits / temperature
553
-
554
- # Apply repetition penalty
555
- if penalty > 1.0:
556
- next_token_logits = rep_processor(context, next_token_logits)
557
-
558
- # Optional top-k sampling
559
- if top_k is not None and top_k > 0:
560
- indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
561
- next_token_logits[indices_to_remove] = float('-inf')
562
-
563
- # Convert logits to probabilities
564
- probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
565
-
566
- # Sample next token
567
- next_token = torch.multinomial(probs, num_samples=1)
568
-
569
- # Check if the next token is EOT and break immediately if so
570
- if next_token.item() == eot_token_id:
571
- if live:
572
- yield "", live_buffer, True
573
- break
574
-
575
- # Append next token to generated sequence
576
- generated = torch.cat((generated, next_token), dim=1)
577
- tokens_generated += 1
578
-
579
- # If live generation, decode and yield the token
580
- if live:
581
- # Decode the next token
582
- next_token_text = self.tokenizer.decode([next_token.item()])
583
- # Clean the token text to fix encoding issues
584
- next_token_text = self._clean_token_text(next_token_text)
585
- live_buffer += next_token_text
586
-
587
- # Check if we've hit an end marker in the buffer
588
- eot_marker_pos = live_buffer.find("<|endoftext|>")
589
- if eot_marker_pos != -1:
590
- # Cut off at the EOT marker
591
- live_buffer = live_buffer[:eot_marker_pos]
592
- yield "", live_buffer, True
593
- break
594
-
595
- # Check other end markers
596
- should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
597
- yield next_token_text, live_buffer, should_stop
598
-
599
- if should_stop:
600
- break
601
-
602
- # For non-live generation, check if we should stop
603
- elif tokens_generated >= min_tokens:
604
- # Check for end markers in the recent generated tokens
605
- recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
606
- if self._should_stop_generation(recent_text):
607
- break
608
-
609
- # Check if we generated any tokens at all
610
- if tokens_generated == 0 and not live:
611
- if self.debug_mode:
612
- print(colored("\n[No tokens generated in this attempt]", "red"))
613
- return None
614
-
615
- if not live:
616
- return generated
617
-
618
- def generate_response(self, user_input):
619
- """Generate a response to the user input."""
620
- # Format the complete prompt
621
- prompt = self._format_prompt(user_input)
622
-
623
- # Tokenize the prompt
624
- input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)
625
-
626
- # Ensure we don't exceed the model's context length
627
- if input_ids.size(1) > self.block_size:
628
- # If too long, keep the beginning part with the instruction template and trim the middle
629
- instruction_tokens = self._tokenize(self.prompt_template)
630
- # Keep the instruction and the most recent conversation that will fit
631
- keep_from_beginning = len(instruction_tokens)
632
- keep_from_end = self.block_size - keep_from_beginning
633
-
634
- # Combine beginning and end, ensuring we don't exceed array bounds
635
- if keep_from_end < 0:
636
- # If instruction alone is too long, trim it (shouldn't happen with reasonable templates)
637
- input_ids = input_ids[:, :self.block_size]
638
- else:
639
- # Keep instruction and most recent conversation
640
- input_ids = torch.cat([
641
- input_ids[:, :keep_from_beginning],
642
- input_ids[:, -(keep_from_end):]
643
- ], dim=1)
644
-
645
- # Track generation start time
646
- start_time = time.time()
647
-
648
- # Always use live generation
649
- return self._generate_live_response(input_ids, user_input, start_time)
650
-
651
- def _generate_live_response(self, input_ids, user_input, start_time):
652
- """Generate response with live token-by-token output."""
653
- # Initialize for live generation
654
- live_text = ""
655
- tokens_generated = 0
656
- retry_count = 0
657
-
658
- # Keep trying until we get a valid response or exhaust retries
659
- while retry_count <= self.max_retries:
660
- if retry_count > 0:
661
- # Calculate temperature for this retry
662
- if retry_count % 2 == 0:
663
- # Even retries: increase temperature
664
- temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
665
- current_temp = min(self.config.temperature + temp_adjustment, 1.8)
666
- else:
667
- # Odd retries: decrease temperature
668
- temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
669
- current_temp = max(self.config.temperature - temp_adjustment, 0.2)
670
-
671
- if self.debug_mode:
672
- print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
673
- else:
674
- current_temp = self.config.temperature
675
-
676
- # Reset for this attempt
677
- live_text = ""
678
- tokens_generated = 0
679
- generation_failed = False
680
-
681
- # Try to generate with current settings
682
- try:
683
- # Generate with live output
684
- for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
685
- input_ids,
686
- max_new_tokens=self.config.max_new_tokens,
687
- temperature=current_temp,
688
- top_k=self.config.top_k,
689
- penalty=self.repetition_penalty,
690
- live=True
691
- ):
692
- # If we should stop but there's a token, this is the last one
693
- if should_stop:
694
- # Update with the final clean buffer (will have EOT removed if present)
695
- live_text = live_buffer
696
- break
697
-
698
- # Otherwise add the token and continue
699
- if token_text:
700
- live_text += token_text
701
- tokens_generated += 1
702
- yield token_text, live_text, False
703
-
704
- # Check if we got a valid response
705
- if not live_text or len(live_text.strip()) < 10:
706
- if self.debug_mode:
707
- print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow"))
708
- generation_failed = True
709
- retry_count += 1
710
- # Clear any partial output
711
- if retry_count <= self.max_retries:
712
- print("\r" + " " * 80 + "\r", end="") # Clear the line
713
- else:
714
- # We got a valid response, stop retrying
715
- break
716
-
717
- except Exception as e:
718
- if self.debug_mode:
719
- print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
720
- generation_failed = True
721
- retry_count += 1
722
-
723
- # If we still failed after all retries, use the failure message
724
- if generation_failed or not live_text or len(live_text.strip()) < 10:
725
- live_text = self.generation_failure_message
726
- if self.debug_mode:
727
- print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))
728
-
729
- # Calculate time taken and metrics
730
- time_taken = time.time() - start_time
731
- tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0
732
-
733
- # Update history
734
- self._update_history(user_input, live_text)
735
-
736
- # Log generation stats
737
- logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")
738
-
739
- # Final yield of the complete response
740
- yield "", live_text, True
741
-
742
- def execute_command(self, command):
743
- """Execute a special command prefixed with /."""
744
- command = command.strip()
745
-
746
- if command == '/help':
747
- self._print_welcome_message()
748
- return True
749
-
750
- elif command == '/clear':
751
- self.history = []
752
- self.history_tokens = []
753
- print(colored("Conversation history cleared.", 'yellow'))
754
- return True
755
-
756
- elif command in ['/exit', '/quit']:
757
- print(colored("Goodbye!", 'cyan'))
758
- return False # Signal to exit the chat loop
759
-
760
- elif command == '/stats':
761
- prompt_tokens = self.total_prompt_tokens
762
- generated_tokens = self.total_generated_tokens
763
- total_tokens = prompt_tokens + generated_tokens
764
-
765
- stats = f"""
766
- Token usage statistics:
767
- - Prompt tokens: {prompt_tokens}
768
- - Generated tokens: {generated_tokens}
769
- - Total tokens: {total_tokens}
770
- - Current history length: {len(self.history_tokens)} tokens
771
- - Current repetition penalty: {self.repetition_penalty}
772
- - Current temperature: {self.config.temperature}
773
- - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
774
- - Source: {DEFAULT_MODEL_REPO}
775
- """
776
- print(colored(stats, 'yellow'))
777
- return True
778
-
779
- elif command == '/debug':
780
- self.debug_mode = not self.debug_mode
781
- self.config.debug_mode = self.debug_mode # Sync with config
782
- mode = "enabled" if self.debug_mode else "disabled"
783
- print(colored(f"Debug mode {mode}", 'yellow'))
784
- return True
785
-
786
- elif command.startswith('/penalty '):
787
- try:
788
- penalty = float(command[9:].strip())
789
- if 1.0 <= penalty <= 2.0:
790
- self.repetition_penalty = penalty
791
- print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
792
- else:
793
- print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
794
- except ValueError:
795
- print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
796
- return True
797
-
798
- elif command.startswith('/temp '):
799
- try:
800
- temp = float(command[6:].strip())
801
- if 0.1 <= temp <= 2.0:
802
- self.config.temperature = temp
803
- print(colored(f"Temperature set to {temp}", 'yellow'))
804
- else:
805
- print(colored("Temperature should be between 0.1 and 2.0", 'red'))
806
- except ValueError:
807
- print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
808
- return True
809
-
810
- elif command.startswith('/save '):
811
- filename = command[6:].strip()
812
- if not filename:
813
- print(colored("Please specify a filename: /save <filename>", 'red'))
814
- return True
815
-
816
- try:
817
- # Create conversations directory if it doesn't exist
818
- os.makedirs('conversations', exist_ok=True)
819
-
820
- # Add .txt extension if not present
821
- if not filename.endswith('.txt'):
822
- filename += '.txt'
823
-
824
- filepath = os.path.join('conversations', filename)
825
-
826
- with open(filepath, 'w', encoding='utf-8') as f:
827
- for entry in self.history:
828
- role, text = entry
829
- prefix = self.human_prefix if role == "human" else self.assistant_prefix
830
- f.write(f"{prefix}{text}{self.end_of_turn}")
831
-
832
- print(colored(f"Conversation saved to {filepath}", 'green'))
833
-
834
- except Exception as e:
835
- print(colored(f"Error saving conversation: {str(e)}", 'red'))
836
-
837
- return True
838
-
839
- elif command.startswith('/load '):
840
- filename = command[6:].strip()
841
- if not filename:
842
- print(colored("Please specify a filename: /load <filename>", 'red'))
843
- return True
844
-
845
- try:
846
- # Add .txt extension if not present
847
- if not filename.endswith('.txt'):
848
- filename += '.txt'
849
-
850
- filepath = os.path.join('conversations', filename)
851
-
852
- if not os.path.exists(filepath):
853
- print(colored(f"File not found: {filepath}", 'red'))
854
- return True
855
-
856
- with open(filepath, 'r', encoding='utf-8') as f:
857
- content = f.read()
858
-
859
- # Parse conversation turns
860
- self.history = []
861
- self.history_tokens = []
862
-
863
- # Split by end of turn marker
864
- turns = content.split(self.end_of_turn)
865
- for turn in turns:
866
- turn = turn.strip()
867
- if not turn:
868
- continue
869
-
870
- if turn.startswith(self.human_prefix):
871
- text = turn[len(self.human_prefix):].strip()
872
- self.history.append(("human", text))
873
- elif turn.startswith(self.assistant_prefix):
874
- text = turn[len(self.assistant_prefix):].strip()
875
- self.history.append(("assistant", text))
876
-
877
- # Recalculate token counts
878
- self.history_tokens = []
879
- for entry in self.history:
880
- role, text = entry
881
- if role == "human":
882
- self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
883
- else:
884
- self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))
885
-
886
- print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))
887
-
888
- # Print the conversation
889
- for i in range(0, len(self.history), 2):
890
- if i < len(self.history):
891
- user_text = self.history[i][1]
892
- print(colored(f"\nYou: {user_text}", 'green'))
893
-
894
- if i + 1 < len(self.history):
895
- assistant_text = self.history[i + 1][1]
896
- print(colored("CosmicFish: ", 'blue'), end="")
897
- for line in assistant_text.split('\n'):
898
- wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
899
- for wrapped_line in wrapped_lines:
900
- print(wrapped_line)
901
-
902
- except Exception as e:
903
- print(colored(f"Error loading conversation: {str(e)}", 'red'))
904
-
905
- return True
906
-
907
- else:
908
- print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
909
- return True
910
-
911
-
912
- def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
913
- """Download and load CosmicFish model from Hugging Face Hub"""
914
- print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
915
-
916
- try:
917
- # Download the model files to local cache
918
- print("Downloading model files...")
919
- cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
920
- print(f"Model cached at: {cache_dir}")
921
-
922
- # Load config
923
- config_path = os.path.join(cache_dir, "config.json")
924
- with open(config_path, "r") as f:
925
- config_dict = json.load(f)
926
-
927
- # Create CosmicConfig
928
- config = CosmicConfig(
929
- vocab_size=config_dict["vocab_size"],
930
- block_size=config_dict["block_size"],
931
- n_layer=config_dict["n_layer"],
932
- n_head=config_dict["n_head"],
933
- n_embd=config_dict["n_embd"],
934
- bias=config_dict["bias"],
935
- dropout=0.0, # Set to 0 for inference
936
- eps=config_dict.get("eps", 1e-6),
937
- use_rotary=config_dict["use_rotary"],
938
- use_swiglu=config_dict["use_swiglu"],
939
- use_gqa=config_dict["use_gqa"],
940
- n_query_groups=config_dict["n_query_groups"],
941
- use_qk_norm=config_dict.get("use_qk_norm", False)
942
- )
943
-
944
- # Create model
945
- print("Creating model...")
946
- model = CosmicFish(config)
947
-
948
- # Load weights
949
- print("Loading weights...")
950
- weights_path = os.path.join(cache_dir, "pytorch_model.bin")
951
- state_dict = torch.load(weights_path, map_location=device)
952
- model.load_state_dict(state_dict)
953
- model.to(device)
954
- model.eval()
955
-
956
- print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
957
- print(f"Device: {device}")
958
- return model, config
959
-
960
- except Exception as e:
961
- print(colored(f"Error downloading/loading model: {str(e)}", "red"))
962
- print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
963
- sys.exit(1)
964
-
965
-
966
- def load_tokenizer():
967
- print("Loading tokenizer...")
968
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
969
- print("Tokenizer loaded")
970
- return tokenizer
971
-
972
-
973
- def main():
974
- parser = argparse.ArgumentParser(description="Chat with CosmicFish model from Hugging Face Hub")
975
-
976
- # Model parameters
977
- parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
978
- help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
979
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
980
- help="Device to use (cuda or cpu)")
981
-
982
- # Generation parameters
983
- parser.add_argument("--temperature", type=float, default=0.7,
984
- help="Temperature for sampling (default: 0.7)")
985
- parser.add_argument("--max_tokens", type=int, default=512,
986
- help="Maximum number of tokens to generate per response")
987
- parser.add_argument("--min_tokens", type=int, default=10,
988
- help="Minimum number of tokens to generate per response")
989
- parser.add_argument("--top_k", type=int, default=40,
990
- help="Top-k sampling (0 to disable)")
991
- parser.add_argument("--repetition_penalty", type=float, default=1.2,
992
- help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")
993
-
994
- # Chat parameters
995
- parser.add_argument("--human_prefix", type=str, default="Human: ",
996
- help="Prefix for human messages")
997
- parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
998
- help="Prefix for assistant messages")
999
- parser.add_argument("--end_of_turn", type=str, default="\n\n",
1000
- help="Delimiter between conversation turns")
1001
- parser.add_argument("--instruction", type=str,
1002
- default=DEFAULT_PROMPT_TEMPLATE,
1003
- help="Instruction prompt to prepend to the conversation")
1004
- parser.add_argument("--max_history", type=int, default=1024,
1005
- help="Maximum number of tokens to keep in history")
1006
-
1007
- # UI parameters
1008
- parser.add_argument("--no_welcome", action="store_true",
1009
- help="Don't display the welcome message")
1010
- parser.add_argument("--debug", action="store_true",
1011
- help="Enable debug mode")
1012
-
1013
- args = parser.parse_args()
1014
-
1015
- # Configure device
1016
- device = args.device
1017
- if device == "cuda" and not torch.cuda.is_available():
1018
- print(colored("CUDA is not available, falling back to CPU", "yellow"))
1019
- device = "cpu"
1020
-
1021
- try:
1022
- # Download and load the model from HF Hub
1023
- model, model_config = download_cosmicfish_from_hub(args.model_repo, device)
1024
-
1025
- # Load tokenizer
1026
- tokenizer = load_tokenizer()
1027
-
1028
- # Create a config object with all the necessary parameters
1029
- class ChatConfig:
1030
- def __init__(self, args, block_size):
1031
- self.device = device
1032
- self.temperature = args.temperature
1033
- self.max_new_tokens = args.max_tokens
1034
- self.min_tokens_to_generate = args.min_tokens
1035
- self.top_k = args.top_k
1036
- self.human_prefix = args.human_prefix
1037
- self.assistant_prefix = args.assistant_prefix
1038
- self.end_of_turn = args.end_of_turn
1039
- self.prompt_template = args.instruction
1040
- self.max_history_tokens = args.max_history
1041
- self.display_welcome = not args.no_welcome
1042
- self.block_size = block_size
1043
- self.debug_mode = args.debug
1044
- self.repetition_penalty = args.repetition_penalty
1045
-
1046
- config = ChatConfig(args, model_config.block_size)
1047
-
1048
- # Initialize chat session
1049
- chat = CosmicFishChatSession(model, tokenizer, config)
1050
-
1051
- # Main chat loop
1052
- print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
1053
-
1054
- while True:
1055
- try:
1056
- # Get user input
1057
- user_input = input(colored("You: ", 'green'))
1058
-
1059
- # Check if it's a command
1060
- if user_input.startswith('/'):
1061
- # Execute command, continue loop if True, exit if False
1062
- if not chat.execute_command(user_input):
1063
- break
1064
- continue
1065
-
1066
- # Skip if empty input
1067
- if not user_input.strip():
1068
- continue
1069
-
1070
- # Generate response using live generation
1071
- live_buffer = ""
1072
- final_response = None
1073
-
1074
- # Use the generator version
1075
- response_generator = chat.generate_response(user_input)
1076
-
1077
- try:
1078
- # First print the assistant prefix
1079
- print(colored("CosmicFish: ", 'blue'), end="")
1080
- sys.stdout.flush()
1081
-
1082
- for token, live_text, is_done in response_generator:
1083
- # If this is the final clean response
1084
- if is_done:
1085
- final_response = live_text
1086
- # Print the final response directly if we didn't get any tokens yet
1087
- if not live_buffer:
1088
- print(final_response, end="")
1089
- break
1090
-
1091
- # If we have a token to display
1092
- if token:
1093
- # Check if token contains <|endoftext|> and remove it if present
1094
- if "<|endoftext|>" in token:
1095
- token = token.replace("<|endoftext|>", "")
1096
- if token: # Only print if there's anything left
1097
- print(token, end="", flush=True)
1098
- break
1099
-
1100
- # Display it
1101
- print(token, end="", flush=True)
1102
- live_buffer += token
1103
-
1104
- except KeyboardInterrupt:
1105
- # Allow user to interrupt generation
1106
- print("\n[Generation interrupted]")
1107
- final_response = "I was going to respond, but I'll stop here since you interrupted."
1108
-
1109
- # Add an extra line for readability
1110
- print()
1111
-
1112
- except KeyboardInterrupt:
1113
- print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")
1114
-
1115
- except Exception as e:
1116
- print(colored(f"\nError: {str(e)}", 'red'))
1117
- logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
1118
-
1119
- except Exception as e:
1120
- print(colored(f"Error setting up chat: {str(e)}", 'red'))
1121
- logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
1122
- sys.exit(1)
1123
-
1124
-
1125
- if __name__ == "__main__":
1126
- try:
1127
- main()
1128
- except Exception as e:
1129
- logger.error(f"Fatal error: {str(e)}", exc_info=True)
1130
- sys.exit(1)