akkiisfrommars committed
Commit eed5764 · verified · 1 Parent(s): 9280013

Upload 2 files

Files changed (2)
  1. chat_HF.py +1146 -0
  2. model.safetensors +3 -0
chat_HF.py ADDED
@@ -0,0 +1,1146 @@
1
+ """
2
+ Chat interface for CosmicFish model downloaded from Hugging Face Hub.
3
+ Uses safetensors format only for secure model loading.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import argparse
10
+ import torch
11
+ import numpy as np
12
+ from termcolor import colored
13
+ import logging
14
+ import readline
15
+ import re
16
+ import textwrap
17
+ import random
18
+ from collections import defaultdict
19
+ import json
20
+
21
+ # Required imports for HF Hub
22
+ try:
23
+ from transformers import GPT2Tokenizer
24
+ from huggingface_hub import hf_hub_download, snapshot_download
25
+ HF_AVAILABLE = True
26
+ except ImportError:
27
+ HF_AVAILABLE = False
28
+ print("Required libraries not available.")
29
+ print("Install with: pip install transformers huggingface-hub")
30
+ sys.exit(1)
31
+
32
+ # Required for safetensors
33
+ try:
34
+ from safetensors.torch import load_file
35
+ SAFETENSORS_AVAILABLE = True
36
+ except ImportError:
37
+ SAFETENSORS_AVAILABLE = False
38
+ print("Safetensors not available. Install with: pip install safetensors")
39
+ sys.exit(1)
40
+
41
+ # Set up logging
42
+ logging.basicConfig(
43
+ level=logging.INFO,
44
+ format='%(asctime)s - %(levelname)s - %(message)s',
45
+ handlers=[logging.StreamHandler(sys.stdout)]
46
+ )
47
+ logger = logging.getLogger(__name__)
48
+
49
+ # Default model repository
50
+ DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-120M"
51
+
52
+ # Default prompt template
53
+ DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
54
+
55
+
56
+ class CosmicConfig:
57
+ """Configuration class for CosmicFish."""
58
+
59
+ def __init__(self,
60
+ vocab_size=50257,
61
+ block_size=512,
62
+ n_layer=12,
63
+ n_head=16,
64
+ n_embd=704,
65
+ bias=True,
66
+ dropout=0.0,
67
+ n_query_groups=4,
68
+ eps=1e-6,
69
+ use_rotary=True,
70
+ use_swiglu=True,
71
+ use_qk_norm=False,
72
+ use_gqa=True):
73
+ self.vocab_size = vocab_size
74
+ self.block_size = block_size
75
+ self.n_layer = n_layer
76
+ self.n_head = n_head
77
+ self.n_embd = n_embd
78
+ self.bias = bias
79
+ self.dropout = dropout
80
+ self.eps = eps
81
+ self.use_rotary = use_rotary
82
+ self.use_swiglu = use_swiglu
83
+ self.use_qk_norm = use_qk_norm
84
+ self.use_gqa = use_gqa
85
+ self.n_query_groups = n_query_groups if use_gqa else n_head
86
+ # Ensure n_head is divisible by n_query_groups
87
+ assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
88
+
89
+
90
+ class RMSNorm(torch.nn.Module):
91
+ """Root Mean Square Normalization"""
92
+
93
+ def __init__(self, dim, eps=1e-6):
94
+ super().__init__()
95
+ self.eps = eps
96
+ self.weight = torch.nn.Parameter(torch.ones(dim))
97
+
98
+ def forward(self, x):
99
+ rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
100
+ return self.weight * (x / rms)
101
+
102
+
103
+ def precompute_freqs_cis(dim, end, theta=10000.0):
104
+ """Precompute the frequency tensor for complex exponentials (cis)"""
105
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
106
+ t = torch.arange(end, device=freqs.device)
107
+ freqs = torch.outer(t, freqs)
108
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
109
+ return freqs_cis
110
+
111
+
112
+ def apply_rotary_emb(xq, xk, freqs_cis):
113
+ """Apply rotary embeddings to input tensors"""
114
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
115
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
116
+
117
+ seq_len = xq_.size(2)
118
+ if freqs_cis.size(0) < seq_len:
119
+ raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
120
+
121
+ freqs_cis_seq = freqs_cis[:seq_len]
122
+ xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
123
+ xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
124
+
125
+ return xq_out.type_as(xq), xk_out.type_as(xk)
126
+
127
+
128
+ class GroupedQueryAttention(torch.nn.Module):
129
+ """Grouped Query Attention (GQA) implementation"""
130
+
131
+ def __init__(self, config):
132
+ super().__init__()
133
+ assert config.n_embd % config.n_head == 0
134
+
135
+ head_dim = config.n_embd // config.n_head
136
+ self.head_dim = head_dim
137
+ self.n_head = config.n_head
138
+ self.n_embd = config.n_embd
139
+ self.n_query_groups = config.n_query_groups
140
+
141
+ self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
142
+ qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
143
+
144
+ self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
145
+ self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
146
+
147
+ # Flash attention support
148
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
149
+ if not self.flash:
150
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
151
+ .view(1, 1, config.block_size, config.block_size))
152
+
153
+ # Query-key normalization
154
+ self.qk_norm = getattr(config, 'use_qk_norm', False)
155
+ if self.qk_norm:
156
+ self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
157
+ self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
158
+
159
+ def forward(self, x, freqs_cis=None):
160
+ B, T, C = x.size()
161
+ qkv = self.c_attn(x)
162
+ head_dim = C // self.n_head
163
+
164
+ q_size = self.n_head * head_dim
165
+ k_size = self.kv_heads * head_dim
166
+ v_size = self.kv_heads * head_dim
167
+
168
+ q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
169
+
170
+ q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
171
+ k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
172
+ v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
173
+
174
+ # Repeat k and v if needed for GQA
175
+ if self.kv_heads < self.n_head:
176
+ repeats = self.n_head // self.kv_heads
177
+ k = k.repeat_interleave(repeats, dim=1)
178
+ v = v.repeat_interleave(repeats, dim=1)
179
+
180
+ # Apply rotary embeddings
181
+ if freqs_cis is not None:
182
+ q, k = apply_rotary_emb(q, k, freqs_cis)
183
+
184
+ # Apply query-key normalization
185
+ if self.qk_norm:
186
+ q = self.q_norm(q)
187
+ k = self.k_norm(k)
188
+
189
+ # Compute attention
190
+ if self.flash:
191
+ y = torch.nn.functional.scaled_dot_product_attention(
192
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
193
+ )
194
+ else:
195
+ att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
196
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
197
+ att = torch.nn.functional.softmax(att, dim=-1)
198
+ y = att @ v
199
+
200
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
201
+ y = self.c_proj(y)
202
+ return y
203
+
204
+
205
+ class Block(torch.nn.Module):
206
+ """Transformer block"""
207
+
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
211
+ self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
212
+ self.attn = GroupedQueryAttention(config)
213
+
214
+ # MLP implementation based on configuration
215
+ if config.use_swiglu:
216
+ # SwiGLU MLP
217
+ self.mlp = torch.nn.ModuleDict(dict(
218
+ gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
219
+ up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
220
+ down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
221
+ act=torch.nn.SiLU(),
222
+ ))
223
+ m = self.mlp
224
+ self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
225
+ else:
226
+ # Traditional MLP
227
+ self.mlp = torch.nn.ModuleDict(dict(
228
+ c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
229
+ c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
230
+ act=torch.nn.GELU(),
231
+ ))
232
+ m = self.mlp
233
+ self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
234
+
235
+ def forward(self, x, freqs_cis=None):
236
+ x = x + self.attn(self.ln_1(x), freqs_cis)
237
+ x = x + self.mlpf(self.ln_2(x))
238
+ return x
239
+
240
+
241
+ class CosmicFish(torch.nn.Module):
242
+ """
243
+ CosmicFish model for inference only.
244
+ Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
245
+ """
246
+
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.config = config
250
+
251
+ self.transformer = torch.nn.ModuleDict(dict(
252
+ wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
253
+ h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
254
+ ln_f=RMSNorm(config.n_embd, eps=config.eps),
255
+ ))
256
+
257
+ self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
258
+
259
+ # Share weights between embedding and output
260
+ self.transformer.wte.weight = self.lm_head.weight
261
+
262
+ # Precompute rotary embedding frequencies
263
+ if config.use_rotary:
264
+ head_dim = config.n_embd // config.n_head
265
+ self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
266
+ else:
267
+ self.freqs_cis = None
268
+ self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
269
+
270
+ def get_num_params(self, non_embedding=True):
271
+ """Return the number of parameters in the model."""
272
+ n_params = sum(p.numel() for p in self.parameters())
273
+ if non_embedding and hasattr(self.transformer, 'wpe'):
274
+ n_params -= self.transformer.wpe.weight.numel()
275
+ return n_params
276
+
277
+ def forward(self, idx, targets=None):
278
+ """Forward pass through the model."""
279
+ device = idx.device
280
+ b, t = idx.size()
281
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
282
+
283
+ # Get token embeddings
284
+ tok_emb = self.transformer.wte(idx)
285
+
286
+ # Handle positional embeddings
287
+ if self.config.use_rotary:
288
+ x = tok_emb
289
+ freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
290
+ else:
291
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
292
+ pos_emb = self.transformer.wpe(pos)
293
+ x = tok_emb + pos_emb
294
+ freqs_cis = None
295
+
296
+ # Apply transformer blocks
297
+ for block in self.transformer.h:
298
+ x = block(x, freqs_cis)
299
+
300
+ # Apply final normalization
301
+ x = self.transformer.ln_f(x)
302
+
303
+ # Calculate outputs
304
+ if targets is not None:
305
+ logits = self.lm_head(x)
306
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
307
+ else:
308
+ # For inference, only compute logits for the last token
309
+ logits = self.lm_head(x[:, [-1], :])
310
+ loss = None
311
+
312
+ return logits, loss
313
+
314
+ @torch.no_grad()
315
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
316
+ """
317
+ Generate text by sampling from the model, token by token.
318
+ """
319
+ for _ in range(max_new_tokens):
320
+ # Crop sequence to block size if needed
321
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
322
+
323
+ # Forward pass
324
+ logits, _ = self(idx_cond)
325
+ logits = logits[:, -1, :] / temperature
326
+
327
+ # Apply top-k sampling
328
+ if top_k is not None:
329
+ v, _ = torch.topk(logits, top_k)
330
+ logits[logits < v[:, [-1]]] = -float('Inf')
331
+
332
+ # Sample next token
333
+ probs = torch.nn.functional.softmax(logits, dim=-1)
334
+ idx_next = torch.multinomial(probs, num_samples=1)
335
+
336
+ # Append to sequence
337
+ idx = torch.cat((idx, idx_next), dim=1)
338
+
339
+ return idx
340
+
341
+
342
+ class RepetitionPenaltyLogitsProcessor:
343
+ """Apply repetition penalty to prevent repeating tokens."""
344
+
345
+ def __init__(self, penalty=1.2):
346
+ self.penalty = penalty
347
+
348
+ def __call__(self, input_ids, scores):
349
+ """Apply repetition penalty to logits where input_ids is already seen."""
350
+ score = torch.gather(scores, 1, input_ids)
351
+ # If score > 0, penalize by dividing; if score < 0, penalize by multiplying
352
+ score = torch.where(score > 0, score / self.penalty, score * self.penalty)
353
+ scores.scatter_(1, input_ids, score)
354
+ return scores
355
+
356
+
357
+ class CosmicFishChatSession:
358
+ """Chat session for CosmicFish model from Hugging Face Hub."""
359
+
360
+ def __init__(self, model, tokenizer, config):
361
+ """Initialize chat session with model and configuration."""
362
+ self.model = model
363
+ self.tokenizer = tokenizer
364
+ self.config = config
365
+ self.device = next(model.parameters()).device
366
+ self.history = []
367
+ self.history_tokens = []
368
+ self.max_history_tokens = config.max_history_tokens
369
+ self.prompt_template = config.prompt_template
370
+ self.human_prefix = config.human_prefix
371
+ self.assistant_prefix = config.assistant_prefix
372
+ self.end_of_turn = config.end_of_turn
373
+ self.block_size = config.block_size
374
+ self.debug_mode = config.debug_mode
375
+ self.repetition_penalty = config.repetition_penalty
376
+ self.min_tokens_to_generate = config.min_tokens_to_generate
377
+ self.max_retries = 20
378
+
379
+ self.fallback_responses = [
380
+ "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
381
+ "That's a topic I can provide information about. What specific aspects would you like to know?",
382
+ "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
383
+ "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
384
+ "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
385
+ ]
386
+
387
+ self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"
388
+
389
+ # For token counting
390
+ self.total_prompt_tokens = 0
391
+ self.total_generated_tokens = 0
392
+
393
+ # End markers for live generation
394
+ self.end_markers = [
395
+ f"{self.human_prefix}",
396
+ "Human:",
397
+ "\nHuman:",
398
+ "\nH:",
399
+ "H:",
400
+ "<|endoftext|>",
401
+ "Below is a conversation",
402
+ "\nA:",
403
+ "A:",
404
+ "</s>",
405
+ "User:",
406
+ "\nUser:"
407
+ ]
408
+
409
+ if config.display_welcome:
410
+ self._print_welcome_message()
411
+
412
+ def _print_welcome_message(self):
413
+ welcome_text = f"""
414
+ {'=' * 80}
415
+ Welcome to CosmicFish chat interface
416
+
417
+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
418
+ CosmicFish is an efficient LLM with an advanced architecture.
419
+
420
+ Type your prompts and CosmicFish will respond.
421
+
422
+ Special commands:
423
+ - /help: Show this help message
424
+ - /clear: Clear the conversation history
425
+ - /exit or /quit: Exit the chat
426
+ - /stats: Show token usage statistics
427
+ - /save [filename]: Save the conversation
428
+ - /load [filename]: Load a conversation
429
+ - /temp [value]: Set temperature (between 0.1 and 2.0)
430
+ - /penalty [value]: Set repetition penalty (1.0-2.0)
431
+ - /debug: Toggle debug mode
432
+
433
+
434
+ Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.
435
+
436
+ Visit https://cosmicfish.ai for more info
437
+
438
+
439
+ Developed by Mistyoz AI (https://www.mistyoz.com)
440
+ {'=' * 80}
441
+ """
442
+ print(colored(welcome_text, 'cyan'))
443
+
444
+ def _format_prompt(self, user_input):
445
+ """Format the complete prompt with history and current input."""
446
+ # Start with the template
447
+ formatted_prompt = self.prompt_template
448
+
449
+ # Add conversation history
450
+ for entry in self.history:
451
+ role, text = entry
452
+ if role == "human":
453
+ formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
454
+ else: # assistant
455
+ formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"
456
+
457
+ # Add the current user input
458
+ formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
459
+
460
+ return formatted_prompt
461
+
462
+ def _tokenize(self, text):
463
+ """Tokenize text and return token IDs."""
464
+ return self.tokenizer.encode(text)
465
+
466
+ def _update_history(self, user_input, response):
467
+ """Update conversation history."""
468
+ # Add to text history
469
+ self.history.append(("human", user_input))
470
+ self.history.append(("assistant", response))
471
+
472
+ # Update token history for context window management
473
+ user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
474
+ response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")
475
+
476
+ self.history_tokens.extend(user_tokens)
477
+ self.history_tokens.extend(response_tokens)
478
+
479
+ # Track token usage
480
+ self.total_prompt_tokens += len(user_tokens)
481
+ self.total_generated_tokens += len(response_tokens)
482
+
483
+ # Trim history if it gets too long
484
+ self._trim_history_if_needed()
485
+
486
+ def _trim_history_if_needed(self):
487
+ """Trim history to fit within the context window."""
488
+ if len(self.history_tokens) > self.max_history_tokens:
489
+ # Remove oldest turns until we're under the limit
490
+ while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
491
+ # Remove oldest human and assistant turn
492
+ self.history = self.history[2:]
493
+
494
+ # Find token boundary for the removed turns
495
+ user_turn = self.history[0][1]
496
+ assistant_turn = self.history[1][1]
497
+ user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
498
+ assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))
499
+
500
+ # Update token history
501
+ self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
502
+
503
+ def _should_stop_generation(self, text):
504
+ """Check if generation should stop based on end markers."""
505
+ for marker in self.end_markers:
506
+ if marker in text:
507
+ return True
508
+ return False
509
+
510
+ def _clean_token_text(self, text):
511
+ text = text.replace('��', "'")
512
+ text = text.replace('�', "'")
513
+ text = text.replace('\ufffd', "'")
514
+ text = text.replace('\uFFFD', "'")
515
+ text = text.replace('’', "'")
516
+ text = text.replace('â€Å"', "'")
517
+ text = text.replace('�', "'")
518
+ text = text.replace('â€"', "'")
519
+ text = text.replace('â€"', "'")
520
+ return text
521
+
522
+ def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
523
+ """Custom generate function with repetition penalty and optional live generation."""
524
+ model = self.model
525
+ device = self.device
526
+
527
+ # Ensure model is in eval mode
528
+ model.eval()
529
+
530
+ # Initialize sequence with input_ids
531
+ generated = input_ids.clone()
532
+
533
+ # Initialize live text buffer
534
+ live_buffer = ""
535
+
536
+ # Create repetition penalty processor
537
+ rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
538
+
539
+ # Counter for generated tokens
540
+ tokens_generated = 0
541
+ min_tokens = self.min_tokens_to_generate
542
+
543
+ # EOT token ID
544
+ eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256
545
+
546
+ # Generate tokens one at a time
547
+ for _ in range(max_new_tokens):
548
+ # Get only the last block_size tokens if context is too long
549
+ if generated.size(1) > self.block_size:
550
+ context = generated[:, -self.block_size:]
551
+ else:
552
+ context = generated
553
+
554
+ # Forward pass for next token prediction
555
+ with torch.no_grad():
556
+ logits, _ = model(context)
557
+
558
+ # Get logits for the next token (last position)
559
+ next_token_logits = logits[:, -1, :]
560
+
561
+ # Apply temperature
562
+ next_token_logits = next_token_logits / temperature
563
+
564
+ # Apply repetition penalty
565
+ if penalty > 1.0:
566
+ next_token_logits = rep_processor(context, next_token_logits)
567
+
568
+ # Optional top-k sampling
569
+ if top_k is not None:
570
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
571
+ next_token_logits[indices_to_remove] = float('-inf')
572
+
573
+ # Convert logits to probabilities
574
+ probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
575
+
576
+ # Sample next token
577
+ next_token = torch.multinomial(probs, num_samples=1)
578
+
579
+ # Check if the next token is EOT and break immediately if so
580
+ if next_token.item() == eot_token_id:
581
+ if live:
582
+ yield "", live_buffer, True
583
+ break
584
+
585
+ # Append next token to generated sequence
586
+ generated = torch.cat((generated, next_token), dim=1)
587
+ tokens_generated += 1
588
+
589
+ # If live generation, decode and yield the token
590
+ if live:
591
+ # Decode the next token
592
+ next_token_text = self.tokenizer.decode([next_token.item()])
593
+ # Clean the token text to fix encoding issues
594
+ next_token_text = self._clean_token_text(next_token_text)
595
+ live_buffer += next_token_text
596
+
597
+ # Check if we've hit an end marker in the buffer
598
+ eot_marker_pos = live_buffer.find("<|endoftext|>")
599
+ if eot_marker_pos != -1:
600
+ # Cut off at the EOT marker
601
+ live_buffer = live_buffer[:eot_marker_pos]
602
+ yield "", live_buffer, True
603
+ break
604
+
605
+ # Check other end markers
606
+ should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
607
+ yield next_token_text, live_buffer, should_stop
608
+
609
+ if should_stop:
610
+ break
611
+
612
+ # For non-live generation, check if we should stop
613
+ elif tokens_generated >= min_tokens:
614
+ # Check for end markers in the recent generated tokens
615
+ recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
616
+ if self._should_stop_generation(recent_text):
617
+ break
618
+
619
+ # Check if we generated any tokens at all
620
+ if tokens_generated == 0 and not live:
621
+ if self.debug_mode:
622
+ print(colored("\n[No tokens generated in this attempt]", "red"))
623
+ return None
624
+
625
+ if not live:
626
+ return generated
627
+
628
+ def generate_response(self, user_input):
629
+ """Generate a response to the user input."""
630
+ # Format the complete prompt
631
+ prompt = self._format_prompt(user_input)
632
+
633
+ # Tokenize the prompt
634
+ input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)
635
+
636
+ # Ensure we don't exceed the model's context length
637
+ if input_ids.size(1) > self.block_size:
638
+ # If too long, keep the beginning part with the instruction template and trim the middle
639
+ instruction_tokens = self._tokenize(self.prompt_template)
640
+ # Keep the instruction and the most recent conversation that will fit
641
+ keep_from_beginning = len(instruction_tokens)
642
+ keep_from_end = self.block_size - keep_from_beginning
643
+
644
+ # Combine beginning and end, ensuring we don't exceed array bounds
645
+ if keep_from_end < 0:
646
+ # If instruction alone is too long, trim it (shouldn't happen with reasonable templates)
647
+ input_ids = input_ids[:, :self.block_size]
648
+ else:
649
+ # Keep instruction and most recent conversation
650
+ input_ids = torch.cat([
651
+ input_ids[:, :keep_from_beginning],
652
+ input_ids[:, -(keep_from_end):]
653
+ ], dim=1)
654
+
655
+ # Track generation start time
656
+ start_time = time.time()
657
+
658
+ # Always use live generation
659
+ return self._generate_live_response(input_ids, user_input, start_time)
660
+
661
+ def _generate_live_response(self, input_ids, user_input, start_time):
662
+ """Generate response with live token-by-token output."""
663
+ # Initialize for live generation
664
+ live_text = ""
665
+ tokens_generated = 0
666
+ retry_count = 0
667
+
668
+ # Keep trying until we get a valid response or exhaust retries
669
+ while retry_count <= self.max_retries:
670
+ if retry_count > 0:
671
+ # Calculate temperature for this retry
672
+ if retry_count % 2 == 0:
673
+ # Even retries: increase temperature
674
+ temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
675
+ current_temp = min(self.config.temperature + temp_adjustment, 1.8)
676
+ else:
677
+ # Odd retries: decrease temperature
678
+ temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
679
+ current_temp = max(self.config.temperature - temp_adjustment, 0.2)
680
+
681
+ if self.debug_mode:
682
+ print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
683
+ else:
684
+ current_temp = self.config.temperature
685
+
686
+ # Reset for this attempt
687
+ live_text = ""
688
+ tokens_generated = 0
689
+ generation_failed = False
690
+
691
+ # Try to generate with current settings
692
+ try:
693
+ # Generate with live output
694
+ for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
695
+ input_ids,
696
+ max_new_tokens=self.config.max_new_tokens,
697
+ temperature=current_temp,
698
+ top_k=self.config.top_k,
699
+ penalty=self.repetition_penalty,
700
+ live=True
701
+ ):
702
+ # If we should stop but there's a token, this is the last one
703
+ if should_stop:
704
+ # Update with the final clean buffer (will have EOT removed if present)
705
+ live_text = live_buffer
706
+ break
707
+
708
+ # Otherwise add the token and continue
709
+ if token_text:
710
+ live_text += token_text
711
+ tokens_generated += 1
712
+ yield token_text, live_text, False
713
+
714
+ # Check if we got a valid response
715
+ if not live_text or len(live_text.strip()) < 10:
716
+ if self.debug_mode:
717
+ print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow"))
718
+ generation_failed = True
719
+ retry_count += 1
720
+ # Clear any partial output
721
+ if retry_count <= self.max_retries:
722
+ print("\r" + " " * 80 + "\r", end="") # Clear the line
723
+ else:
724
+ # We got a valid response, stop retrying
725
+ break
726
+
727
+ except Exception as e:
728
+ if self.debug_mode:
729
+ print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
730
+ generation_failed = True
731
+ retry_count += 1
732
+
733
+ # If we still failed after all retries, use the failure message
734
+ if generation_failed or not live_text or len(live_text.strip()) < 10:
735
+ live_text = self.generation_failure_message
736
+ if self.debug_mode:
737
+ print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))
738
+
739
+ # Calculate time taken and metrics
740
+ time_taken = time.time() - start_time
741
+ tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0
742
+
743
+ # Update history
744
+ self._update_history(user_input, live_text)
745
+
746
+ # Log generation stats
747
+ logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")
748
+
749
+ # Final yield of the complete response
750
+ yield "", live_text, True
751
+
752
+ def execute_command(self, command):
753
+ """Execute a special command prefixed with /."""
754
+ command = command.strip()
755
+
756
+ if command == '/help':
757
+ self._print_welcome_message()
758
+ return True
759
+
760
+ elif command == '/clear':
761
+ self.history = []
762
+ self.history_tokens = []
763
+ print(colored("Conversation history cleared.", 'yellow'))
764
+ return True
765
+
766
+ elif command in ['/exit', '/quit']:
767
+ print(colored("Goodbye!", 'cyan'))
768
+ return False # Signal to exit the chat loop
769
+
770
+ elif command == '/stats':
771
+ prompt_tokens = self.total_prompt_tokens
772
+ generated_tokens = self.total_generated_tokens
773
+ total_tokens = prompt_tokens + generated_tokens
774
+
775
+ stats = f"""
776
+ Token usage statistics:
777
+ - Prompt tokens: {prompt_tokens}
778
+ - Generated tokens: {generated_tokens}
779
+ - Total tokens: {total_tokens}
780
+ - Current history length: {len(self.history_tokens)} tokens
781
+ - Current repetition penalty: {self.repetition_penalty}
782
+ - Current temperature: {self.config.temperature}
783
+ - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
784
+ - Source: {DEFAULT_MODEL_REPO}
785
+ - Format: Safetensors (secure)
786
+ """
787
+ print(colored(stats, 'yellow'))
788
+ return True
789
+
790
+ elif command == '/debug':
791
+ self.debug_mode = not self.debug_mode
792
+ self.config.debug_mode = self.debug_mode # Sync with config
793
+ mode = "enabled" if self.debug_mode else "disabled"
794
+ print(colored(f"Debug mode {mode}", 'yellow'))
795
+ return True
796
+
797
+ elif command.startswith('/penalty '):
798
+ try:
799
+ penalty = float(command[9:].strip())
800
+ if 1.0 <= penalty <= 2.0:
801
+ self.repetition_penalty = penalty
802
+ print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
803
+ else:
804
+ print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
805
+ except ValueError:
806
+ print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
807
+ return True
808
+
809
+ elif command.startswith('/temp '):
810
+ try:
811
+ temp = float(command[6:].strip())
812
+ if 0.1 <= temp <= 2.0:
813
+ self.config.temperature = temp
814
+ print(colored(f"Temperature set to {temp}", 'yellow'))
815
+ else:
816
+ print(colored("Temperature should be between 0.1 and 2.0", 'red'))
817
+ except ValueError:
818
+ print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
819
+ return True
820
+
821
+ elif command.startswith('/save '):
822
+ filename = command[6:].strip()
823
+ if not filename:
824
+ print(colored("Please specify a filename: /save <filename>", 'red'))
825
+ return True
826
+
827
+ try:
828
+ # Create conversations directory if it doesn't exist
829
+ os.makedirs('conversations', exist_ok=True)
830
+
831
+ # Add .txt extension if not present
832
+ if not filename.endswith('.txt'):
833
+ filename += '.txt'
834
+
835
+ filepath = os.path.join('conversations', filename)
836
+
837
+ with open(filepath, 'w', encoding='utf-8') as f:
838
+ for entry in self.history:
839
+ role, text = entry
840
+ prefix = self.human_prefix if role == "human" else self.assistant_prefix
841
+ f.write(f"{prefix}{text}{self.end_of_turn}")
842
+
843
+ print(colored(f"Conversation saved to {filepath}", 'green'))
844
+
845
+ except Exception as e:
846
+ print(colored(f"Error saving conversation: {str(e)}", 'red'))
847
+
848
+ return True
849
+
850
+ elif command.startswith('/load '):
851
+ filename = command[6:].strip()
852
+ if not filename:
853
+ print(colored("Please specify a filename: /load <filename>", 'red'))
854
+ return True
855
+
856
+ try:
857
+ # Add .txt extension if not present
858
+ if not filename.endswith('.txt'):
859
+ filename += '.txt'
860
+
861
+ filepath = os.path.join('conversations', filename)
862
+
863
+ if not os.path.exists(filepath):
864
+ print(colored(f"File not found: {filepath}", 'red'))
865
+ return True
866
+
867
+ with open(filepath, 'r', encoding='utf-8') as f:
868
+ content = f.read()
869
+
870
+ # Parse conversation turns
871
+ self.history = []
872
+ self.history_tokens = []
873
+
874
+ # Split by end of turn marker
875
+ turns = content.split(self.end_of_turn)
876
+ for turn in turns:
877
+ turn = turn.strip()
878
+ if not turn:
879
+ continue
880
+
881
+ if turn.startswith(self.human_prefix):
882
+ text = turn[len(self.human_prefix):].strip()
883
+ self.history.append(("human", text))
884
+ elif turn.startswith(self.assistant_prefix):
885
+ text = turn[len(self.assistant_prefix):].strip()
886
+ self.history.append(("assistant", text))
887
+
888
+ # Recalculate token counts
889
+ self.history_tokens = []
890
+ for entry in self.history:
891
+ role, text = entry
892
+ if role == "human":
893
+ self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
894
+ else:
895
+ self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))
896
+
897
+ print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))
898
+
899
+ # Print the conversation
900
+ for i in range(0, len(self.history), 2):
901
+ if i < len(self.history):
902
+ user_text = self.history[i][1]
903
+ print(colored(f"\nYou: {user_text}", 'green'))
904
+
905
+ if i + 1 < len(self.history):
906
+ assistant_text = self.history[i + 1][1]
907
+ print(colored("CosmicFish: ", 'blue'), end="")
908
+ for line in assistant_text.split('\n'):
909
+ wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
910
+ for wrapped_line in wrapped_lines:
911
+ print(wrapped_line)
912
+
913
+ except Exception as e:
914
+ print(colored(f"Error loading conversation: {str(e)}", 'red'))
915
+
916
+ return True
917
+
918
+ else:
919
+ print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
920
+ return True
921
+
922
+
923
+ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
924
+ """Download and load CosmicFish model from Hugging Face Hub (safetensors only)"""
925
+ print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
926
+
927
+ try:
928
+ # Download the model files to local cache
929
+ print("Downloading model files...")
930
+ cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
931
+ print(f"Model cached at: {cache_dir}")
932
+
933
+ # Load config
934
+ config_path = os.path.join(cache_dir, "config.json")
935
+ with open(config_path, "r") as f:
936
+ config_dict = json.load(f)
937
+
938
+ # Create CosmicConfig
939
+ config = CosmicConfig(
940
+ vocab_size=config_dict["vocab_size"],
941
+ block_size=config_dict["block_size"],
942
+ n_layer=config_dict["n_layer"],
943
+ n_head=config_dict["n_head"],
944
+ n_embd=config_dict["n_embd"],
945
+ bias=config_dict["bias"],
946
+ dropout=0.0, # Set to 0 for inference
947
+ eps=config_dict.get("eps", 1e-6),
948
+ use_rotary=config_dict["use_rotary"],
949
+ use_swiglu=config_dict["use_swiglu"],
950
+ use_gqa=config_dict["use_gqa"],
951
+ n_query_groups=config_dict["n_query_groups"],
952
+ use_qk_norm=config_dict.get("use_qk_norm", False)
953
+ )
954
+
955
+ # Create model
956
+ print("Creating model...")
957
+ model = CosmicFish(config)
958
+
959
+ # Load weights from safetensors ONLY
960
+ print("Loading weights from safetensors...")
961
+ safetensors_path = os.path.join(cache_dir, "model.safetensors")
962
+
963
+ if not os.path.exists(safetensors_path):
964
+ raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.")
965
+
966
+ state_dict = load_file(safetensors_path)
967
+
968
+ # Handle weight sharing: lm_head.weight shares with transformer.wte.weight
969
+ if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
970
+ state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
971
+
972
+ model.load_state_dict(state_dict)
973
+ model.to(device)
974
+ model.eval()
975
+
976
+ print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
977
+ print(f"Device: {device}")
978
+ return model, config
979
+
980
+ except Exception as e:
981
+ print(colored(f"Error downloading/loading model: {str(e)}", "red"))
982
+ print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
983
+ sys.exit(1)
984
+
985
+
986
+ def load_tokenizer():
987
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
988
+ return tokenizer
989
+
990
+
991
+ def main():
992
+ parser = argparse.ArgumentParser(description="Chat with CosmicFish")
993
+
994
+ # Model parameters
995
+ parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
996
+ help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
997
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
998
+ help="Device to use (cuda or cpu)")
999
+
1000
+ # Generation parameters
1001
+ parser.add_argument("--temperature", type=float, default=0.7,
1002
+ help="Temperature for sampling (default: 0.7)")
1003
+ parser.add_argument("--max_tokens", type=int, default=1024,
1004
+ help="Maximum number of tokens to generate per response")
1005
+ parser.add_argument("--min_tokens", type=int, default=10,
1006
+ help="Minimum number of tokens to generate per response")
1007
+ parser.add_argument("--top_k", type=int, default=40,
1008
+ help="Top-k sampling (0 to disable)")
1009
+ parser.add_argument("--repetition_penalty", type=float, default=1.2,
1010
+ help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")
1011
+
1012
+ # Chat parameters
1013
+ parser.add_argument("--human_prefix", type=str, default="Human: ",
1014
+ help="Prefix for human messages")
1015
+ parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
1016
+ help="Prefix for assistant messages")
1017
+ parser.add_argument("--end_of_turn", type=str, default="\n\n",
1018
+ help="Delimiter between conversation turns")
1019
+ parser.add_argument("--instruction", type=str,
1020
+ default=DEFAULT_PROMPT_TEMPLATE,
1021
+ help="Instruction prompt to prepend to the conversation")
1022
+ parser.add_argument("--max_history", type=int, default=1024,
1023
+ help="Maximum number of tokens to keep in history")
1024
+
1025
+ # UI parameters
1026
+ parser.add_argument("--no_welcome", action="store_true",
1027
+ help="Don't display the welcome message")
1028
+ parser.add_argument("--debug", action="store_true",
1029
+ help="Enable debug mode")
1030
+
1031
+ args = parser.parse_args()
1032
+
1033
+ # Configure device
1034
+ device = args.device
1035
+ if device == "cuda" and not torch.cuda.is_available():
1036
+ print(colored("CUDA is not available, falling back to CPU", "yellow"))
1037
+ device = "cpu"
1038
+
1039
+ try:
1040
+ # Download and load the model from HF Hub
1041
+ model, model_config = download_cosmicfish_from_hub(args.model_repo, device)
1042
+
1043
+ # Load tokenizer
1044
+ tokenizer = load_tokenizer()
1045
+
1046
+ # Create a config object with all the necessary parameters
1047
+ class ChatConfig:
1048
+ def __init__(self, args, block_size):
1049
+ self.device = device
1050
+ self.temperature = args.temperature
1051
+ self.max_new_tokens = args.max_tokens
1052
+ self.min_tokens_to_generate = args.min_tokens
1053
+ self.top_k = args.top_k if args.top_k and args.top_k > 0 else None  # 0 disables top-k, matching the --top_k help text
1054
+ self.human_prefix = args.human_prefix
1055
+ self.assistant_prefix = args.assistant_prefix
1056
+ self.end_of_turn = args.end_of_turn
1057
+ self.prompt_template = args.instruction
1058
+ self.max_history_tokens = args.max_history
1059
+ self.display_welcome = not args.no_welcome
1060
+ self.block_size = block_size
1061
+ self.debug_mode = args.debug
1062
+ self.repetition_penalty = args.repetition_penalty
1063
+
1064
+ config = ChatConfig(args, model_config.block_size)
1065
+
1066
+ # Initialize chat session
1067
+ chat = CosmicFishChatSession(model, tokenizer, config)
1068
+
1069
+ # Main chat loop
1070
+ print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan'))
1071
+
1072
+ while True:
1073
+ try:
1074
+ # Get user input
1075
+ user_input = input(colored("You: ", 'green'))
1076
+
1077
+ # Check if it's a command
1078
+ if user_input.startswith('/'):
1079
+ # Execute command, continue loop if True, exit if False
1080
+ if not chat.execute_command(user_input):
1081
+ break
1082
+ continue
1083
+
1084
+ # Skip if empty input
1085
+ if not user_input.strip():
1086
+ continue
1087
+
1088
+ # Generate response using live generation
1089
+ live_buffer = ""
1090
+ final_response = None
1091
+
1092
+ # Use the generator version
1093
+ response_generator = chat.generate_response(user_input)
1094
+
1095
+ try:
1096
+ # First print the assistant prefix
1097
+ print(colored("CosmicFish: ", 'blue'), end="")
1098
+ sys.stdout.flush()
1099
+
1100
+ for token, live_text, is_done in response_generator:
1101
+ # If this is the final clean response
1102
+ if is_done:
1103
+ final_response = live_text
1104
+ # Print the final response directly if we didn't get any tokens yet
1105
+ if not live_buffer:
1106
+ print(final_response, end="")
1107
+ break
1108
+ if token:
1109
+ # Check if token contains <|endoftext|> and remove it if present
1110
+ if "<|endoftext|>" in token:
1111
+ token = token.replace("<|endoftext|>", "")
1112
+ if token: # Only print if there's anything left
1113
+ print(token, end="", flush=True)
1114
+ break
1115
+
1116
+ # Display it
1117
+ print(token, end="", flush=True)
1118
+ live_buffer += token
1119
+
1120
+ except KeyboardInterrupt:
1121
+ # Allow user to interrupt generation
1122
+ print("\n[Generation interrupted]")
1123
+ final_response = "I was going to respond, but I'll stop here since you interrupted."
1124
+
1125
+ # Add an extra line for readability
1126
+ print()
1127
+
1128
+ except KeyboardInterrupt:
1129
+ print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")
1130
+
1131
+ except Exception as e:
1132
+ print(colored(f"\nError: {str(e)}", 'red'))
1133
+ logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
1134
+
1135
+ except Exception as e:
1136
+ print(colored(f"Error setting up chat: {str(e)}", 'red'))
1137
+ logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
1138
+ sys.exit(1)
1139
+
1140
+
1141
+ if __name__ == "__main__":
1142
+ try:
1143
+ main()
1144
+ except Exception as e:
1145
+ logger.error(f"Fatal error: {str(e)}", exc_info=True)
1146
+ sys.exit(1)
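
For reference, the pieces above can also be driven programmatically instead of through main(). A minimal sketch, assuming chat_HF.py is importable from the working directory; the SimpleNamespace below merely mirrors the argparse defaults in main() and is illustrative, not part of the upload:

# Typically the script is run as: python chat_HF.py --temperature 0.7 --top_k 40 --repetition_penalty 1.2
from types import SimpleNamespace
import torch
from chat_HF import (CosmicFishChatSession, download_cosmicfish_from_hub,
                     load_tokenizer, DEFAULT_PROMPT_TEMPLATE)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = download_cosmicfish_from_hub(device=device)  # pulls MistyozAI/CosmicFish-120M
tokenizer = load_tokenizer()                                        # GPT-2 BPE tokenizer

# Mirror of the ChatConfig that main() builds from its argparse defaults.
cfg = SimpleNamespace(
    device=device, temperature=0.7, max_new_tokens=1024, min_tokens_to_generate=10,
    top_k=40, repetition_penalty=1.2, human_prefix="Human: ", assistant_prefix="Assistant: ",
    end_of_turn="\n\n", prompt_template=DEFAULT_PROMPT_TEMPLATE, max_history_tokens=1024,
    display_welcome=False, block_size=model_config.block_size, debug_mode=False,
)

chat = CosmicFishChatSession(model, tokenizer, cfg)
reply = ""
# generate_response is a generator yielding (token_text, live_text, is_done)
for _token, live_text, done in chat.generate_response("Hello! What can you do?"):
    if done:
        reply = live_text
print(reply)
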
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26f0666ce6a2f5cb80b4985966f27e21383f63336668ad635d1b3b00876507bc
3
+ size 183299272
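
The three lines above are a Git LFS pointer rather than the weights themselves: the actual 183,299,272-byte model.safetensors is stored in LFS and resolved automatically when the repository is downloaded (for example by snapshot_download in chat_HF.py above). A minimal sketch for verifying a locally downloaded copy against this pointer; the local path is an assumption:

import hashlib
import os

# Hypothetical local path; in practice this file sits inside the Hugging Face
# cache directory reported by snapshot_download in chat_HF.py.
path = "model.safetensors"

expected_oid = "26f0666ce6a2f5cb80b4985966f27e21383f63336668ad635d1b3b00876507bc"
expected_size = 183299272

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

assert os.path.getsize(path) == expected_size, "size does not match the LFS pointer"
assert sha256.hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("model.safetensors matches the LFS pointer")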