|
""" |
|
Chat interface for the CosmicFish model downloaded from the Hugging Face Hub.
|
Uses safetensors format only for secure model loading. |
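Example invocation (the script filename here is an assumption):

    python cosmicfish_chat.py --device cpu --temperature 0.7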
|
""" |
|
|
|
import os |
|
import sys |
|
import time |
|
import argparse |
|
import torch |
|
import numpy as np |
|
from termcolor import colored |
|
import logging |
|
import readline  # imported for its side effect: enables line editing and history for input()
|
import re |
|
import textwrap |
|
import random |
|
from collections import defaultdict |
|
import json |
|
|
|
|
|
try: |
|
from transformers import GPT2Tokenizer |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
HF_AVAILABLE = True |
|
except ImportError: |
|
HF_AVAILABLE = False |
|
print("Required libraries not available.") |
|
print("Install with: pip install transformers huggingface-hub") |
|
sys.exit(1) |
|
|
|
|
|
try: |
|
from safetensors.torch import load_file |
|
SAFETENSORS_AVAILABLE = True |
|
except ImportError: |
|
SAFETENSORS_AVAILABLE = False |
|
print("Safetensors not available. Install with: pip install safetensors") |
|
sys.exit(1) |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format='%(asctime)s - %(levelname)s - %(message)s', |
|
handlers=[logging.StreamHandler(sys.stdout)] |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-90M" |
|
|
|
|
|
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n" |
|
|
|
|
|
class CosmicConfig: |
|
"""Configuration class for CosmicFish.""" |
|
|
|
def __init__(self, |
|
vocab_size=50257, |
|
block_size=512, |
|
n_layer=10, |
|
n_head=16, |
|
n_embd=640, |
|
bias=True, |
|
dropout=0.0, |
|
n_query_groups=4, |
|
eps=1e-6, |
|
use_rotary=True, |
|
use_swiglu=True, |
|
use_qk_norm=False, |
|
use_gqa=True): |
|
self.vocab_size = vocab_size |
|
self.block_size = block_size |
|
self.n_layer = n_layer |
|
self.n_head = n_head |
|
self.n_embd = n_embd |
|
self.bias = bias |
|
self.dropout = dropout |
|
self.eps = eps |
|
self.use_rotary = use_rotary |
|
self.use_swiglu = use_swiglu |
|
self.use_qk_norm = use_qk_norm |
|
self.use_gqa = use_gqa |
|
self.n_query_groups = n_query_groups if use_gqa else n_head |
|
|
|
assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups" |
|
|
|
|
|
class RMSNorm(torch.nn.Module): |
|
"""Root Mean Square Normalization""" |
|
|
|
def __init__(self, dim, eps=1e-6): |
|
super().__init__() |
|
self.eps = eps |
|
self.weight = torch.nn.Parameter(torch.ones(dim)) |
|
|
|
def forward(self, x): |
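        # RMSNorm scales x by 1 / sqrt(mean(x^2) + eps); unlike LayerNorm there is no mean subtraction or bias term.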
|
rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) |
|
return self.weight * (x / rms) |
|
|
|
|
|
def precompute_freqs_cis(dim, end, theta=10000.0): |
|
"""Precompute the frequency tensor for complex exponentials (cis)""" |
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
|
t = torch.arange(end, device=freqs.device) |
|
freqs = torch.outer(t, freqs) |
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
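    # freqs_cis has shape (end, dim // 2): unit-magnitude complex numbers whose angles grow linearly with position t.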
|
return freqs_cis |
|
|
|
|
|
def apply_rotary_emb(xq, xk, freqs_cis): |
|
"""Apply rotary embeddings to input tensors""" |
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
|
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
|
|
|
seq_len = xq_.size(2) |
|
if freqs_cis.size(0) < seq_len: |
|
raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}") |
|
|
|
freqs_cis_seq = freqs_cis[:seq_len] |
|
xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3) |
|
xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3) |
|
|
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
|
|
|
|
|
class GroupedQueryAttention(torch.nn.Module): |
|
"""Grouped Query Attention (GQA) implementation""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
assert config.n_embd % config.n_head == 0 |
|
|
|
head_dim = config.n_embd // config.n_head |
|
self.head_dim = head_dim |
|
self.n_head = config.n_head |
|
self.n_embd = config.n_embd |
|
self.n_query_groups = config.n_query_groups |
|
|
|
self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head |
|
qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim |
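        # With the default config (n_embd=640, n_head=16, n_query_groups=4):
        # head_dim = 40 and kv_heads = 4, so qkv_proj_size = (16 + 2 * 4) * 40 = 960
        # (640 for queries plus 160 each for keys and values).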
|
|
|
self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias) |
|
self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias) |
|
|
|
|
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') |
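        # PyTorch 2.0+ provides the fused scaled_dot_product_attention; otherwise fall back to an explicit causal mask (registered below as "bias").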
|
if not self.flash: |
|
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) |
|
.view(1, 1, config.block_size, config.block_size)) |
|
|
|
|
|
self.qk_norm = getattr(config, 'use_qk_norm', False) |
|
if self.qk_norm: |
|
self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6)) |
|
self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6)) |
|
|
|
def forward(self, x, freqs_cis=None): |
|
B, T, C = x.size() |
|
qkv = self.c_attn(x) |
|
head_dim = C // self.n_head |
|
|
|
q_size = self.n_head * head_dim |
|
k_size = self.kv_heads * head_dim |
|
v_size = self.kv_heads * head_dim |
|
|
|
q, k, v = qkv.split([q_size, k_size, v_size], dim=2) |
|
|
|
q = q.view(B, T, self.n_head, head_dim).transpose(1, 2) |
|
k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2) |
|
v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2) |
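        # GQA: K and V have fewer heads than Q, so repeat them below until every query head has a matching key/value head.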
|
|
|
|
|
if self.kv_heads < self.n_head: |
|
repeats = self.n_head // self.kv_heads |
|
k = k.repeat_interleave(repeats, dim=1) |
|
v = v.repeat_interleave(repeats, dim=1) |
|
|
|
|
|
if freqs_cis is not None: |
|
q, k = apply_rotary_emb(q, k, freqs_cis) |
|
|
|
|
|
if self.qk_norm: |
|
q = self.q_norm(q) |
|
k = self.k_norm(k) |
|
|
|
|
|
if self.flash: |
|
y = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True |
|
) |
|
else: |
|
            att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
|
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) |
|
att = torch.nn.functional.softmax(att, dim=-1) |
|
y = att @ v |
|
|
|
y = y.transpose(1, 2).contiguous().view(B, T, C) |
|
y = self.c_proj(y) |
|
return y |
|
|
|
|
|
class Block(torch.nn.Module): |
|
"""Transformer block""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.ln_1 = RMSNorm(config.n_embd, eps=config.eps) |
|
self.ln_2 = RMSNorm(config.n_embd, eps=config.eps) |
|
self.attn = GroupedQueryAttention(config) |
|
|
|
|
|
if config.use_swiglu: |
|
|
|
self.mlp = torch.nn.ModuleDict(dict( |
|
gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias), |
|
up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias), |
|
down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias), |
|
act=torch.nn.SiLU(), |
|
)) |
|
m = self.mlp |
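            # SwiGLU feed-forward: down(SiLU(up(x)) * gate(x)). The activation sits on the "up" branch (the mirror of the usual gate/up naming), which is equivalent once the weights are learned.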
|
self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x)) |
|
else: |
|
|
|
self.mlp = torch.nn.ModuleDict(dict( |
|
c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias), |
|
c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias), |
|
act=torch.nn.GELU(), |
|
)) |
|
m = self.mlp |
|
self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x))) |
|
|
|
def forward(self, x, freqs_cis=None): |
|
x = x + self.attn(self.ln_1(x), freqs_cis) |
|
x = x + self.mlpf(self.ln_2(x)) |
|
return x |
|
|
|
|
|
class CosmicFish(torch.nn.Module): |
|
""" |
|
CosmicFish model for inference only. |
|
Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm |
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
|
|
self.transformer = torch.nn.ModuleDict(dict( |
|
wte=torch.nn.Embedding(config.vocab_size, config.n_embd), |
|
h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]), |
|
ln_f=RMSNorm(config.n_embd, eps=config.eps), |
|
)) |
|
|
|
self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False) |
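        # Weight tying: the token embedding and the output projection share the same matrix.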
|
|
|
|
|
self.transformer.wte.weight = self.lm_head.weight |
|
|
|
|
|
if config.use_rotary: |
|
head_dim = config.n_embd // config.n_head |
|
self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size) |
|
else: |
|
self.freqs_cis = None |
|
self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd) |
|
|
|
def get_num_params(self, non_embedding=True): |
|
"""Return the number of parameters in the model.""" |
|
n_params = sum(p.numel() for p in self.parameters()) |
|
if non_embedding and hasattr(self.transformer, 'wpe'): |
|
n_params -= self.transformer.wpe.weight.numel() |
|
return n_params |
|
|
|
def forward(self, idx, targets=None): |
|
"""Forward pass through the model.""" |
|
device = idx.device |
|
b, t = idx.size() |
|
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
|
|
|
tok_emb = self.transformer.wte(idx) |
|
|
|
|
|
if self.config.use_rotary: |
|
x = tok_emb |
|
freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None |
|
else: |
|
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) |
|
pos_emb = self.transformer.wpe(pos) |
|
x = tok_emb + pos_emb |
|
freqs_cis = None |
|
|
|
|
|
for block in self.transformer.h: |
|
x = block(x, freqs_cis) |
|
|
|
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
if targets is not None: |
|
logits = self.lm_head(x) |
|
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) |
|
else: |
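            # Inference: only the final position's logits are needed to sample the next token.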
|
|
|
logits = self.lm_head(x[:, [-1], :]) |
|
loss = None |
|
|
|
return logits, loss |
|
|
|
@torch.no_grad() |
|
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): |
|
""" |
|
Generate text by sampling from the model, token by token. |
|
""" |
|
for _ in range(max_new_tokens): |
|
|
|
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] |
|
|
|
|
|
logits, _ = self(idx_cond) |
|
logits = logits[:, -1, :] / temperature |
|
|
|
|
|
            if top_k is not None and top_k > 0:
|
v, _ = torch.topk(logits, top_k) |
|
logits[logits < v[:, [-1]]] = -float('Inf') |
|
|
|
|
|
probs = torch.nn.functional.softmax(logits, dim=-1) |
|
idx_next = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
idx = torch.cat((idx, idx_next), dim=1) |
|
|
|
return idx |
|
|
|
|
|
class RepetitionPenaltyLogitsProcessor: |
|
"""Apply repetition penalty to prevent repeating tokens.""" |
|
|
|
def __init__(self, penalty=1.2): |
|
self.penalty = penalty |
|
|
|
def __call__(self, input_ids, scores): |
|
"""Apply repetition penalty to logits where input_ids is already seen.""" |
|
score = torch.gather(scores, 1, input_ids) |
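        # Previously seen tokens are made less likely: positive logits are divided by the penalty and negative logits are multiplied by it (the scheme used by common repetition-penalty implementations).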
|
|
|
score = torch.where(score > 0, score / self.penalty, score * self.penalty) |
|
scores.scatter_(1, input_ids, score) |
|
return scores |
|
|
|
|
|
class CosmicFishChatSession: |
|
"""Chat session for CosmicFish model from Hugging Face Hub.""" |
|
|
|
def __init__(self, model, tokenizer, config): |
|
"""Initialize chat session with model and configuration.""" |
|
self.model = model |
|
self.tokenizer = tokenizer |
|
self.config = config |
|
self.device = next(model.parameters()).device |
|
self.history = [] |
|
self.history_tokens = [] |
|
self.max_history_tokens = config.max_history_tokens |
|
self.prompt_template = config.prompt_template |
|
self.human_prefix = config.human_prefix |
|
self.assistant_prefix = config.assistant_prefix |
|
self.end_of_turn = config.end_of_turn |
|
self.block_size = config.block_size |
|
self.debug_mode = config.debug_mode |
|
self.repetition_penalty = config.repetition_penalty |
|
self.min_tokens_to_generate = config.min_tokens_to_generate |
|
self.max_retries = 20 |
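        # Canned clarification replies; note that the current generation path does not use these.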
|
|
|
self.fallback_responses = [ |
|
"I'd be happy to help with that. Could you provide more details about what specific information you're looking for?", |
|
"That's a topic I can provide information about. What specific aspects would you like to know?", |
|
"I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.", |
|
"I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?", |
|
"I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?" |
|
] |
|
|
|
self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?" |
|
|
|
|
|
self.total_prompt_tokens = 0 |
|
self.total_generated_tokens = 0 |
|
|
|
|
|
self.end_markers = [ |
|
f"{self.human_prefix}", |
|
"Human:", |
|
"\nHuman:", |
|
"\nH:", |
|
"H:", |
|
"<|endoftext|>", |
|
"Below is a conversation", |
|
"\nA:", |
|
"A:", |
|
"</s>", |
|
"User:", |
|
"\nUser:" |
|
] |
|
|
|
if config.display_welcome: |
|
self._print_welcome_message() |
|
|
|
def _print_welcome_message(self): |
|
welcome_text = f""" |
|
{'=' * 80} |
|
Welcome to CosmicFish chat interface |
|
|
|
This is a {self.model.get_num_params() / 1e6:.1f}M parameter model. |
|
CosmicFish is an efficient LLM with an advanced architecture. |
|
|
|
Type your prompts and CosmicFish will respond. |
|
|
|
Special commands: |
|
- /help: Show this help message |
|
- /clear: Clear the conversation history |
|
- /exit or /quit: Exit the chat |
|
- /stats: Show token usage statistics |
|
- /save [filename]: Save the conversation |
|
- /load [filename]: Load a conversation |
|
- /temp [value]: Set temperature (between 0.1 and 2.0) |
|
- /penalty [value]: Set repetition penalty (1.0-2.0) |
|
- /debug: Toggle debug mode |
|
|
|
|
|
Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.
|
|
|
Visit https://cosmicfish.ai for more info |
|
|
|
|
|
Developed by Mistyoz AI (https://www.mistyoz.com) |
|
{'=' * 80} |
|
""" |
|
print(colored(welcome_text, 'cyan')) |
|
|
|
def _format_prompt(self, user_input): |
|
"""Format the complete prompt with history and current input.""" |
|
|
|
formatted_prompt = self.prompt_template |
|
|
|
|
|
for entry in self.history: |
|
role, text = entry |
|
if role == "human": |
|
formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}" |
|
else: |
|
formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}" |
|
|
|
|
|
formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}" |
|
|
|
return formatted_prompt |
|
|
|
def _tokenize(self, text): |
|
"""Tokenize text and return token IDs.""" |
|
return self.tokenizer.encode(text) |
|
|
|
def _update_history(self, user_input, response): |
|
"""Update conversation history.""" |
|
|
|
self.history.append(("human", user_input)) |
|
self.history.append(("assistant", response)) |
|
|
|
|
|
user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}") |
|
response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}") |
|
|
|
self.history_tokens.extend(user_tokens) |
|
self.history_tokens.extend(response_tokens) |
|
|
|
|
|
self.total_prompt_tokens += len(user_tokens) |
|
self.total_generated_tokens += len(response_tokens) |
|
|
|
|
|
self._trim_history_if_needed() |
|
|
|
    def _trim_history_if_needed(self):
        """Trim history to fit within the context window."""
        # Drop the oldest human/assistant pair (and its token count) until the
        # history fits in the budget. The removed pair is measured before it is
        # discarded so the correct number of tokens is trimmed.
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            user_turn = self.history[0][1]
            assistant_turn = self.history[1][1]
            self.history = self.history[2:]

            user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
            assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))

            self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
|
|
|
def _should_stop_generation(self, text): |
|
"""Check if generation should stop based on end markers.""" |
|
for marker in self.end_markers: |
|
if marker in text: |
|
return True |
|
return False |
|
|
|
    def _clean_token_text(self, text):
        """Best-effort cleanup of Unicode replacement characters and mojibake in decoded token text."""
        for bad in ('\ufffd\ufffd', '\ufffd', '\u2019', 'â€Å"', 'â€"'):
            text = text.replace(bad, "'")
        return text
|
|
|
def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False): |
|
"""Custom generate function with repetition penalty and optional live generation.""" |
|
model = self.model |
|
device = self.device |
|
|
|
|
|
model.eval() |
|
|
|
|
|
generated = input_ids.clone() |
|
|
|
|
|
live_buffer = "" |
|
|
|
|
|
rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty) |
|
|
|
|
|
tokens_generated = 0 |
|
min_tokens = self.min_tokens_to_generate |
|
|
|
|
|
eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256 |
|
|
|
|
|
for _ in range(max_new_tokens): |
|
|
|
if generated.size(1) > self.block_size: |
|
context = generated[:, -self.block_size:] |
|
else: |
|
context = generated |
|
|
|
|
|
with torch.no_grad(): |
|
logits, _ = model(context) |
|
|
|
|
|
next_token_logits = logits[:, -1, :] |
|
|
|
|
|
next_token_logits = next_token_logits / temperature |
|
|
|
|
|
if penalty > 1.0: |
|
next_token_logits = rep_processor(context, next_token_logits) |
|
|
|
|
|
            if top_k is not None and top_k > 0:
|
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] |
|
next_token_logits[indices_to_remove] = float('-inf') |
|
|
|
|
|
probs = torch.nn.functional.softmax(next_token_logits, dim=-1) |
|
|
|
|
|
next_token = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
if next_token.item() == eot_token_id: |
|
if live: |
|
yield "", live_buffer, True |
|
break |
|
|
|
|
|
generated = torch.cat((generated, next_token), dim=1) |
|
tokens_generated += 1 |
|
|
|
|
|
if live: |
|
|
|
next_token_text = self.tokenizer.decode([next_token.item()]) |
|
|
|
next_token_text = self._clean_token_text(next_token_text) |
|
live_buffer += next_token_text |
|
|
|
|
|
eot_marker_pos = live_buffer.find("<|endoftext|>") |
|
if eot_marker_pos != -1: |
|
|
|
live_buffer = live_buffer[:eot_marker_pos] |
|
yield "", live_buffer, True |
|
break |
|
|
|
|
|
should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer) |
|
yield next_token_text, live_buffer, should_stop |
|
|
|
if should_stop: |
|
break |
|
|
|
|
|
elif tokens_generated >= min_tokens: |
|
|
|
recent_text = self.tokenizer.decode(generated[0, -20:].tolist()) |
|
if self._should_stop_generation(recent_text): |
|
break |
|
|
|
|
|
if tokens_generated == 0 and not live: |
|
if self.debug_mode: |
|
print(colored("\n[No tokens generated in this attempt]", "red")) |
|
return None |
|
|
|
if not live: |
|
return generated |
|
|
|
def generate_response(self, user_input): |
|
"""Generate a response to the user input.""" |
|
|
|
prompt = self._format_prompt(user_input) |
|
|
|
|
|
input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device) |
|
|
|
|
|
if input_ids.size(1) > self.block_size: |
|
|
|
instruction_tokens = self._tokenize(self.prompt_template) |
|
|
|
keep_from_beginning = len(instruction_tokens) |
|
keep_from_end = self.block_size - keep_from_beginning |
|
|
|
|
|
            if keep_from_end <= 0:
|
|
|
input_ids = input_ids[:, :self.block_size] |
|
else: |
|
|
|
input_ids = torch.cat([ |
|
input_ids[:, :keep_from_beginning], |
|
input_ids[:, -(keep_from_end):] |
|
], dim=1) |
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
return self._generate_live_response(input_ids, user_input, start_time) |
|
|
|
def _generate_live_response(self, input_ids, user_input, start_time): |
|
"""Generate response with live token-by-token output.""" |
|
|
|
live_text = "" |
|
tokens_generated = 0 |
|
retry_count = 0 |
|
|
|
|
|
while retry_count <= self.max_retries: |
|
if retry_count > 0: |
|
|
|
if retry_count % 2 == 0: |
|
|
|
temp_adjustment = min(0.2 * (retry_count // 2), 0.8) |
|
current_temp = min(self.config.temperature + temp_adjustment, 1.8) |
|
else: |
|
|
|
temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4) |
|
current_temp = max(self.config.temperature - temp_adjustment, 0.2) |
|
|
|
if self.debug_mode: |
|
print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow")) |
|
else: |
|
current_temp = self.config.temperature |
|
|
|
|
|
live_text = "" |
|
tokens_generated = 0 |
|
generation_failed = False |
|
|
|
|
|
try: |
|
|
|
for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty( |
|
input_ids, |
|
max_new_tokens=self.config.max_new_tokens, |
|
temperature=current_temp, |
|
top_k=self.config.top_k, |
|
penalty=self.repetition_penalty, |
|
live=True |
|
): |
|
|
|
if should_stop: |
|
|
|
live_text = live_buffer |
|
break |
|
|
|
|
|
if token_text: |
|
live_text += token_text |
|
tokens_generated += 1 |
|
yield token_text, live_text, False |
|
|
|
|
|
if not live_text or len(live_text.strip()) < 10: |
|
if self.debug_mode: |
|
print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow")) |
|
generation_failed = True |
|
retry_count += 1 |
|
|
|
if retry_count <= self.max_retries: |
|
print("\r" + " " * 80 + "\r", end="") |
|
else: |
|
|
|
break |
|
|
|
except Exception as e: |
|
if self.debug_mode: |
|
print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red")) |
|
generation_failed = True |
|
retry_count += 1 |
|
|
|
|
|
if generation_failed or not live_text or len(live_text.strip()) < 10: |
|
live_text = self.generation_failure_message |
|
if self.debug_mode: |
|
print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red")) |
|
|
|
|
|
time_taken = time.time() - start_time |
|
tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0 |
|
|
|
|
|
self._update_history(user_input, live_text) |
|
|
|
|
|
logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)") |
|
|
|
|
|
yield "", live_text, True |
|
|
|
def execute_command(self, command): |
|
"""Execute a special command prefixed with /.""" |
|
command = command.strip() |
|
|
|
if command == '/help': |
|
self._print_welcome_message() |
|
return True |
|
|
|
elif command == '/clear': |
|
self.history = [] |
|
self.history_tokens = [] |
|
print(colored("Conversation history cleared.", 'yellow')) |
|
return True |
|
|
|
elif command in ['/exit', '/quit']: |
|
print(colored("Goodbye!", 'cyan')) |
|
return False |
|
|
|
elif command == '/stats': |
|
prompt_tokens = self.total_prompt_tokens |
|
generated_tokens = self.total_generated_tokens |
|
total_tokens = prompt_tokens + generated_tokens |
|
|
|
stats = f""" |
|
Token usage statistics: |
|
- Prompt tokens: {prompt_tokens} |
|
- Generated tokens: {generated_tokens} |
|
- Total tokens: {total_tokens} |
|
- Current history length: {len(self.history_tokens)} tokens |
|
- Current repetition penalty: {self.repetition_penalty} |
|
- Current temperature: {self.config.temperature} |
|
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters) |
|
- Source: {DEFAULT_MODEL_REPO} |
|
- Format: Safetensors (secure) |
|
""" |
|
print(colored(stats, 'yellow')) |
|
return True |
|
|
|
elif command == '/debug': |
|
self.debug_mode = not self.debug_mode |
|
self.config.debug_mode = self.debug_mode |
|
mode = "enabled" if self.debug_mode else "disabled" |
|
print(colored(f"Debug mode {mode}", 'yellow')) |
|
return True |
|
|
|
elif command.startswith('/penalty '): |
|
try: |
|
penalty = float(command[9:].strip()) |
|
if 1.0 <= penalty <= 2.0: |
|
self.repetition_penalty = penalty |
|
print(colored(f"Repetition penalty set to {penalty}", 'yellow')) |
|
else: |
|
print(colored("Repetition penalty should be between 1.0 and 2.0", 'red')) |
|
except ValueError: |
|
print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red')) |
|
return True |
|
|
|
elif command.startswith('/temp '): |
|
try: |
|
temp = float(command[6:].strip()) |
|
if 0.1 <= temp <= 2.0: |
|
self.config.temperature = temp |
|
print(colored(f"Temperature set to {temp}", 'yellow')) |
|
else: |
|
print(colored("Temperature should be between 0.1 and 2.0", 'red')) |
|
except ValueError: |
|
print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red')) |
|
return True |
|
|
|
elif command.startswith('/save '): |
|
filename = command[6:].strip() |
|
if not filename: |
|
print(colored("Please specify a filename: /save <filename>", 'red')) |
|
return True |
|
|
|
try: |
|
|
|
os.makedirs('conversations', exist_ok=True) |
|
|
|
|
|
if not filename.endswith('.txt'): |
|
filename += '.txt' |
|
|
|
filepath = os.path.join('conversations', filename) |
|
|
|
with open(filepath, 'w', encoding='utf-8') as f: |
|
for entry in self.history: |
|
role, text = entry |
|
prefix = self.human_prefix if role == "human" else self.assistant_prefix |
|
f.write(f"{prefix}{text}{self.end_of_turn}") |
|
|
|
print(colored(f"Conversation saved to {filepath}", 'green')) |
|
|
|
except Exception as e: |
|
print(colored(f"Error saving conversation: {str(e)}", 'red')) |
|
|
|
return True |
|
|
|
elif command.startswith('/load '): |
|
filename = command[6:].strip() |
|
if not filename: |
|
print(colored("Please specify a filename: /load <filename>", 'red')) |
|
return True |
|
|
|
try: |
|
|
|
if not filename.endswith('.txt'): |
|
filename += '.txt' |
|
|
|
filepath = os.path.join('conversations', filename) |
|
|
|
if not os.path.exists(filepath): |
|
print(colored(f"File not found: {filepath}", 'red')) |
|
return True |
|
|
|
with open(filepath, 'r', encoding='utf-8') as f: |
|
content = f.read() |
|
|
|
|
|
self.history = [] |
|
self.history_tokens = [] |
|
|
|
|
|
turns = content.split(self.end_of_turn) |
|
for turn in turns: |
|
turn = turn.strip() |
|
if not turn: |
|
continue |
|
|
|
if turn.startswith(self.human_prefix): |
|
text = turn[len(self.human_prefix):].strip() |
|
self.history.append(("human", text)) |
|
elif turn.startswith(self.assistant_prefix): |
|
text = turn[len(self.assistant_prefix):].strip() |
|
self.history.append(("assistant", text)) |
|
|
|
|
|
self.history_tokens = [] |
|
for entry in self.history: |
|
role, text = entry |
|
if role == "human": |
|
self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}")) |
|
else: |
|
self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}")) |
|
|
|
print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green')) |
|
|
|
|
|
for i in range(0, len(self.history), 2): |
|
if i < len(self.history): |
|
user_text = self.history[i][1] |
|
print(colored(f"\nYou: {user_text}", 'green')) |
|
|
|
if i + 1 < len(self.history): |
|
assistant_text = self.history[i + 1][1] |
|
print(colored("CosmicFish: ", 'blue'), end="") |
|
for line in assistant_text.split('\n'): |
|
wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else [''] |
|
for wrapped_line in wrapped_lines: |
|
print(wrapped_line) |
|
|
|
except Exception as e: |
|
print(colored(f"Error loading conversation: {str(e)}", 'red')) |
|
|
|
return True |
|
|
|
else: |
|
print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red')) |
|
return True |
|
|
|
|
|
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'): |
|
"""Download and load CosmicFish model from Hugging Face Hub (safetensors only)""" |
|
print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan")) |
|
|
|
try: |
|
|
|
print("Downloading model files...") |
|
cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None) |
|
print(f"Model cached at: {cache_dir}") |
|
|
|
|
|
config_path = os.path.join(cache_dir, "config.json") |
|
with open(config_path, "r") as f: |
|
config_dict = json.load(f) |
|
|
|
|
|
config = CosmicConfig( |
|
vocab_size=config_dict["vocab_size"], |
|
block_size=config_dict["block_size"], |
|
n_layer=config_dict["n_layer"], |
|
n_head=config_dict["n_head"], |
|
n_embd=config_dict["n_embd"], |
|
bias=config_dict["bias"], |
|
dropout=0.0, |
|
eps=config_dict.get("eps", 1e-6), |
|
use_rotary=config_dict["use_rotary"], |
|
use_swiglu=config_dict["use_swiglu"], |
|
use_gqa=config_dict["use_gqa"], |
|
n_query_groups=config_dict["n_query_groups"], |
|
use_qk_norm=config_dict.get("use_qk_norm", False) |
|
) |
|
|
|
|
|
print("Creating model...") |
|
model = CosmicFish(config) |
|
|
|
|
|
print("Loading weights from safetensors...") |
|
safetensors_path = os.path.join(cache_dir, "model.safetensors") |
|
|
|
if not os.path.exists(safetensors_path): |
|
raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.") |
|
|
|
state_dict = load_file(safetensors_path) |
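        # Checkpoints saved with tied weights may store only the embedding matrix; re-tie lm_head to it if needed.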
|
|
|
|
|
if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict: |
|
state_dict['lm_head.weight'] = state_dict['transformer.wte.weight'] |
|
|
|
model.load_state_dict(state_dict) |
|
model.to(device) |
|
model.eval() |
|
|
|
print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters") |
|
print(f"Device: {device}") |
|
return model, config |
|
|
|
except Exception as e: |
|
print(colored(f"Error downloading/loading model: {str(e)}", "red")) |
|
print(colored("Make sure you have internet connection and the model repo exists", "yellow")) |
|
sys.exit(1) |
|
|
|
|
|
def load_tokenizer():
    """Load the GPT-2 tokenizer (50257-token vocabulary) used by this script."""
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
return tokenizer |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser(description="Chat with CosmicFish") |
|
|
|
|
|
parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO, |
|
help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})") |
|
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", |
|
help="Device to use (cuda or cpu)") |
|
|
|
|
|
parser.add_argument("--temperature", type=float, default=0.5, |
|
help="Temperature for sampling (default: 0.7)") |
|
parser.add_argument("--max_tokens", type=int, default=512, |
|
help="Maximum number of tokens to generate per response") |
|
parser.add_argument("--min_tokens", type=int, default=10, |
|
help="Minimum number of tokens to generate per response") |
|
parser.add_argument("--top_k", type=int, default=40, |
|
help="Top-k sampling (0 to disable)") |
|
parser.add_argument("--repetition_penalty", type=float, default=1.2, |
|
help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)") |
|
|
|
|
|
parser.add_argument("--human_prefix", type=str, default="Human: ", |
|
help="Prefix for human messages") |
|
parser.add_argument("--assistant_prefix", type=str, default="Assistant: ", |
|
help="Prefix for assistant messages") |
|
parser.add_argument("--end_of_turn", type=str, default="\n\n", |
|
help="Delimiter between conversation turns") |
|
parser.add_argument("--instruction", type=str, |
|
default=DEFAULT_PROMPT_TEMPLATE, |
|
help="Instruction prompt to prepend to the conversation") |
|
parser.add_argument("--max_history", type=int, default=512, |
|
help="Maximum number of tokens to keep in history") |
|
|
|
|
|
parser.add_argument("--no_welcome", action="store_true", |
|
help="Don't display the welcome message") |
|
parser.add_argument("--debug", action="store_true", |
|
help="Enable debug mode") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
device = args.device |
|
if device == "cuda" and not torch.cuda.is_available(): |
|
print(colored("CUDA is not available, falling back to CPU", "yellow")) |
|
device = "cpu" |
|
|
|
try: |
|
|
|
model, model_config = download_cosmicfish_from_hub(args.model_repo, device) |
|
|
|
|
|
tokenizer = load_tokenizer() |
|
|
|
|
|
class ChatConfig: |
|
def __init__(self, args, block_size): |
|
self.device = device |
|
self.temperature = args.temperature |
|
self.max_new_tokens = args.max_tokens |
|
self.min_tokens_to_generate = args.min_tokens |
|
self.top_k = args.top_k |
|
self.human_prefix = args.human_prefix |
|
self.assistant_prefix = args.assistant_prefix |
|
self.end_of_turn = args.end_of_turn |
|
self.prompt_template = args.instruction |
|
self.max_history_tokens = args.max_history |
|
self.display_welcome = not args.no_welcome |
|
self.block_size = block_size |
|
self.debug_mode = args.debug |
|
self.repetition_penalty = args.repetition_penalty |
|
|
|
config = ChatConfig(args, model_config.block_size) |
|
|
|
|
|
chat = CosmicFishChatSession(model, tokenizer, config) |
|
|
|
|
|
print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan')) |
|
|
|
while True: |
|
try: |
|
|
|
user_input = input(colored("You: ", 'green')) |
|
|
|
|
|
if user_input.startswith('/'): |
|
|
|
if not chat.execute_command(user_input): |
|
break |
|
continue |
|
|
|
|
|
if not user_input.strip(): |
|
continue |
|
|
|
|
|
live_buffer = "" |
|
final_response = None |
|
|
|
|
|
response_generator = chat.generate_response(user_input) |
|
|
|
try: |
|
|
|
print(colored("CosmicFish: ", 'blue'), end="") |
|
sys.stdout.flush() |
|
|
|
for token, live_text, is_done in response_generator: |
|
|
|
if is_done: |
|
final_response = live_text |
|
|
|
if not live_buffer: |
|
print(final_response, end="") |
|
break |
|
if token: |
|
|
|
if "<|endoftext|>" in token: |
|
token = token.replace("<|endoftext|>", "") |
|
if token: |
|
print(token, end="", flush=True) |
|
break |
|
|
|
|
|
print(token, end="", flush=True) |
|
live_buffer += token |
|
|
|
except KeyboardInterrupt: |
|
|
|
print("\n[Generation interrupted]") |
|
final_response = "I was going to respond, but I'll stop here since you interrupted." |
|
|
|
|
|
print() |
|
|
|
except KeyboardInterrupt: |
|
print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.") |
|
|
|
except Exception as e: |
|
print(colored(f"\nError: {str(e)}", 'red')) |
|
logger.error(f"Error in chat loop: {str(e)}", exc_info=True) |
|
|
|
except Exception as e: |
|
print(colored(f"Error setting up chat: {str(e)}", 'red')) |
|
logger.error(f"Error setting up chat: {str(e)}", exc_info=True) |
|
sys.exit(1) |
|
|
|
|
|
if __name__ == "__main__": |
|
try: |
|
main() |
|
except Exception as e: |
|
logger.error(f"Fatal error: {str(e)}", exc_info=True) |
|
sys.exit(1) |
|
|