akkiisfrommars committed on
Commit 065b3d5 · verified · 1 Parent(s): 66b7216

Upload 3 files

Files changed (3)
  1. LICENSE +201 -0
  2. chat.py +832 -0
  3. modeling_cosmicfish.py +290 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Mistyoz AI Private Limited
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
chat.py ADDED
@@ -0,0 +1,832 @@
1
+ """
2
+ Chat interface for the released CosmicFish model from Hugging Face.
3
+ Compatible with the HF-format release while maintaining all original features.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import argparse
10
+ import torch
11
+ import numpy as np
12
+ from termcolor import colored
13
+ import logging
14
+ import readline # Enables arrow key history in terminal input
15
+ import re
16
+ import textwrap
17
+ import random
18
+ from collections import defaultdict
19
+ import json
20
+
21
+ # Try to import from transformers, fallback to local
22
+ try:
23
+ from transformers import GPT2Tokenizer
24
+ HF_AVAILABLE = True
25
+ except ImportError:
26
+ HF_AVAILABLE = False
27
+ print("❌ Transformers not available. Install with: pip install transformers")
28
+
29
+ # Import the model classes - try both locations
30
+ try:
31
+ from modeling_cosmicfish import CosmicFish, CosmicConfig
32
+ except ImportError:
33
+ try:
34
+ from model import CosmicFish, CosmicConfig
35
+ except ImportError:
36
+ print("❌ CosmicFish model classes not found. Make sure modeling_cosmicfish.py or model.py is available.")
37
+ sys.exit(1)
38
+
39
+ # Set up logging
40
+ logging.basicConfig(
41
+ level=logging.INFO,
42
+ format='%(asctime)s - %(levelname)s - %(message)s',
43
+ handlers=[logging.StreamHandler(sys.stdout)]
44
+ )
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # Default prompt template
48
+ DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
49
+
50
+
51
+ class RepetitionPenaltyLogitsProcessor:
52
+ """Apply repetition penalty to prevent repeating tokens."""
53
+
54
+ def __init__(self, penalty=1.2):
55
+ self.penalty = penalty
56
+
57
+ def __call__(self, input_ids, scores):
58
+ """Apply repetition penalty to logits where input_ids is already seen."""
59
+ score = torch.gather(scores, 1, input_ids)
60
+ # If score > 0, penalize by dividing; if score < 0, penalize by multiplying
61
+ score = torch.where(score > 0, score / self.penalty, score * self.penalty)
62
+ scores.scatter_(1, input_ids, score)
63
+ return scores
64
+
65
+
66
+ class CosmicFishChatSession:
67
+ """Chat session for the released CosmicFish model."""
68
+
69
+ def __init__(self, model, tokenizer, config):
70
+ """Initialize chat session with model and configuration."""
71
+ self.model = model
72
+ self.tokenizer = tokenizer
73
+ self.config = config
74
+ self.device = next(model.parameters()).device
75
+ self.history = []
76
+ self.history_tokens = []
77
+ self.max_history_tokens = config.max_history_tokens
78
+ self.prompt_template = config.prompt_template
79
+ self.human_prefix = config.human_prefix
80
+ self.assistant_prefix = config.assistant_prefix
81
+ self.end_of_turn = config.end_of_turn
82
+ self.block_size = config.block_size
83
+ self.debug_mode = config.debug_mode
84
+ self.repetition_penalty = config.repetition_penalty
85
+ self.min_tokens_to_generate = config.min_tokens_to_generate
86
+ self.max_retries = 20
87
+
88
+ self.fallback_responses = [
89
+ "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
90
+ "That's a topic I can provide information about. What specific aspects would you like to know?",
91
+ "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
92
+ "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
93
+ "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
94
+ ]
95
+
96
+ self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"
97
+
98
+ # For token counting
99
+ self.total_prompt_tokens = 0
100
+ self.total_generated_tokens = 0
101
+
102
+ # End markers for live generation
103
+ self.end_markers = [
104
+ f"{self.human_prefix}",
105
+ "Human:",
106
+ "\nHuman:",
107
+ "\nH:",
108
+ "H:",
109
+ "<|endoftext|>",
110
+ "Below is a conversation",
111
+ "\nA:",
112
+ "A:",
113
+ "</s>",
114
+ "User:",
115
+ "\nUser:"
116
+ ]
117
+
118
+ # Print welcome message
119
+ if config.display_welcome:
120
+ self._print_welcome_message()
121
+
122
+ def _print_welcome_message(self):
123
+ """Print a welcome message to the user."""
124
+ welcome_text = f"""
125
+ {'=' * 80}
126
+ Welcome to CosmicFish chat interface (Hugging Face Release)
127
+
128
+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
129
+ CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
130
+
131
+ Type your prompts and CosmicFish will respond.
132
+
133
+ Special commands:
134
+ - /help: Show this help message
135
+ - /clear: Clear the conversation history
136
+ - /exit or /quit: Exit the chat
137
+ - /stats: Show token usage statistics
138
+ - /save [filename]: Save the conversation
139
+ - /load [filename]: Load a conversation
140
+ - /temp [value]: Set temperature (between 0.1 and 2.0)
141
+ - /penalty [value]: Set repetition penalty (1.0-2.0)
142
+ - /debug: Toggle debug mode
143
+ {'=' * 80}
144
+ """
145
+ print(colored(welcome_text, 'cyan'))
146
+
147
+ def _format_prompt(self, user_input):
148
+ """Format the complete prompt with history and current input."""
149
+ # Start with the template
150
+ formatted_prompt = self.prompt_template
151
+
152
+ # Add conversation history
153
+ for entry in self.history:
154
+ role, text = entry
155
+ if role == "human":
156
+ formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
157
+ else: # assistant
158
+ formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"
159
+
160
+ # Add the current user input
161
+ formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
162
+
163
+ return formatted_prompt
164
+
165
+ def _tokenize(self, text):
166
+ """Tokenize text and return token IDs."""
167
+ return self.tokenizer.encode(text)
168
+
169
+ def _update_history(self, user_input, response):
170
+ """Update conversation history."""
171
+ # Add to text history
172
+ self.history.append(("human", user_input))
173
+ self.history.append(("assistant", response))
174
+
175
+ # Update token history for context window management
176
+ user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
177
+ response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")
178
+
179
+ self.history_tokens.extend(user_tokens)
180
+ self.history_tokens.extend(response_tokens)
181
+
182
+ # Track token usage
183
+ self.total_prompt_tokens += len(user_tokens)
184
+ self.total_generated_tokens += len(response_tokens)
185
+
186
+ # Trim history if it gets too long
187
+ self._trim_history_if_needed()
188
+
189
+ def _trim_history_if_needed(self):
190
+ """Trim history to fit within the context window."""
191
+ if len(self.history_tokens) > self.max_history_tokens:
192
+ # Remove oldest turns until we're under the limit
193
+ while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
194
+ # Capture the oldest human and assistant turn before removing them
195
+ user_turn = self.history[0][1]
196
+ assistant_turn = self.history[1][1]
197
+ self.history = self.history[2:]
198
+
199
+ # Count the tokens that belonged to the removed turns
200
+ user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
201
+ assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))
202
+
203
+ # Update token history
204
+ self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
205
+
206
+ def _should_stop_generation(self, text):
207
+ """Check if generation should stop based on end markers."""
208
+ for marker in self.end_markers:
209
+ if marker in text:
210
+ return True
211
+ return False
212
+
213
+ def _clean_token_text(self, text):
214
+ """Clean token text by fixing encoding issues."""
215
+ # Fix the specific issue with �� -> '
216
+ text = text.replace('��', "'")
217
+ return text
218
+
219
+ def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
220
+ """Custom generate function with repetition penalty and optional live generation."""
221
+ model = self.model
222
+ device = self.device
223
+
224
+ # Ensure model is in eval mode
225
+ model.eval()
226
+
227
+ # Initialize sequence with input_ids
228
+ generated = input_ids.clone()
229
+
230
+ # Initialize live text buffer
231
+ live_buffer = ""
232
+
233
+ # Create repetition penalty processor
234
+ rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
235
+
236
+ # Counter for generated tokens
237
+ tokens_generated = 0
238
+ min_tokens = self.min_tokens_to_generate
239
+
240
+ # EOT token ID
241
+ eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256
242
+
243
+ # Generate tokens one at a time
244
+ for _ in range(max_new_tokens):
245
+ # Get only the last block_size tokens if context is too long
246
+ if generated.size(1) > self.block_size:
247
+ context = generated[:, -self.block_size:]
248
+ else:
249
+ context = generated
250
+
251
+ # Forward pass for next token prediction
252
+ with torch.no_grad():
253
+ logits, _ = model(context)
254
+
255
+ # Get logits for the next token (last position)
256
+ next_token_logits = logits[:, -1, :]
257
+
258
+ # Apply temperature
259
+ next_token_logits = next_token_logits / temperature
260
+
261
+ # Apply repetition penalty
262
+ if penalty > 1.0:
263
+ next_token_logits = rep_processor(context, next_token_logits)
264
+
265
+ # Optional top-k sampling
266
+ if top_k is not None and top_k > 0:
267
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
268
+ next_token_logits[indices_to_remove] = float('-inf')
269
+
270
+ # Convert logits to probabilities
271
+ probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
272
+
273
+ # Sample next token
274
+ next_token = torch.multinomial(probs, num_samples=1)
275
+
276
+ # Check if the next token is EOT and break immediately if so
277
+ if next_token.item() == eot_token_id:
278
+ if live:
279
+ yield "", live_buffer, True
280
+ break
281
+
282
+ # Append next token to generated sequence
283
+ generated = torch.cat((generated, next_token), dim=1)
284
+ tokens_generated += 1
285
+
286
+ # If live generation, decode and yield the token
287
+ if live:
288
+ # Decode the next token
289
+ next_token_text = self.tokenizer.decode([next_token.item()])
290
+ # Clean the token text to fix encoding issues
291
+ next_token_text = self._clean_token_text(next_token_text)
292
+ live_buffer += next_token_text
293
+
294
+ # Check if we've hit an end marker in the buffer
295
+ eot_marker_pos = live_buffer.find("<|endoftext|>")
296
+ if eot_marker_pos != -1:
297
+ # Cut off at the EOT marker
298
+ live_buffer = live_buffer[:eot_marker_pos]
299
+ yield "", live_buffer, True
300
+ break
301
+
302
+ # Check other end markers
303
+ should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
304
+ yield next_token_text, live_buffer, should_stop
305
+
306
+ if should_stop:
307
+ break
308
+
309
+ # For non-live generation, check if we should stop
310
+ elif tokens_generated >= min_tokens:
311
+ # Check for end markers in the recent generated tokens
312
+ recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
313
+ if self._should_stop_generation(recent_text):
314
+ break
315
+
316
+ # Check if we generated any tokens at all
317
+ if tokens_generated == 0 and not live:
318
+ if self.debug_mode:
319
+ print(colored("\n[No tokens generated in this attempt]", "red"))
320
+ return None
321
+
322
+ if not live:
323
+ return generated
324
+
325
+ def generate_response(self, user_input):
326
+ """Generate a response to the user input."""
327
+ # Format the complete prompt
328
+ prompt = self._format_prompt(user_input)
329
+
330
+ # Tokenize the prompt
331
+ input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)
332
+
333
+ # Ensure we don't exceed the model's context length
334
+ if input_ids.size(1) > self.block_size:
335
+ # If too long, keep the beginning part with the instruction template and trim the middle
336
+ instruction_tokens = self._tokenize(self.prompt_template)
337
+ # Keep the instruction and the most recent conversation that will fit
338
+ keep_from_beginning = len(instruction_tokens)
339
+ keep_from_end = self.block_size - keep_from_beginning
340
+
341
+ # Combine beginning and end, ensuring we don't exceed array bounds
342
+ if keep_from_end <= 0:
343
+ # If instruction alone is too long, trim it (shouldn't happen with reasonable templates)
344
+ input_ids = input_ids[:, :self.block_size]
345
+ else:
346
+ # Keep instruction and most recent conversation
347
+ input_ids = torch.cat([
348
+ input_ids[:, :keep_from_beginning],
349
+ input_ids[:, -(keep_from_end):]
350
+ ], dim=1)
351
+
352
+ # Track generation start time
353
+ start_time = time.time()
354
+
355
+ # Always use live generation
356
+ return self._generate_live_response(input_ids, user_input, start_time)
357
+
358
+ def _generate_live_response(self, input_ids, user_input, start_time):
359
+ """Generate response with live token-by-token output."""
360
+ # Initialize for live generation
361
+ live_text = ""
362
+ tokens_generated = 0
363
+ retry_count = 0
364
+
365
+ # Keep trying until we get a valid response or exhaust retries
366
+ while retry_count <= self.max_retries:
367
+ if retry_count > 0:
368
+ # Calculate temperature for this retry
369
+ if retry_count % 2 == 0:
370
+ # Even retries: increase temperature
371
+ temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
372
+ current_temp = min(self.config.temperature + temp_adjustment, 1.8)
373
+ else:
374
+ # Odd retries: decrease temperature
375
+ temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
376
+ current_temp = max(self.config.temperature - temp_adjustment, 0.2)
377
+
378
+ if self.debug_mode:
379
+ print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
380
+ else:
381
+ current_temp = self.config.temperature
382
+
383
+ # Reset for this attempt
384
+ live_text = ""
385
+ tokens_generated = 0
386
+ generation_failed = False
387
+
388
+ # Try to generate with current settings
389
+ try:
390
+ # Generate with live output
391
+ for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
392
+ input_ids,
393
+ max_new_tokens=self.config.max_new_tokens,
394
+ temperature=current_temp,
395
+ top_k=self.config.top_k,
396
+ penalty=self.repetition_penalty,
397
+ live=True
398
+ ):
399
+ # If we should stop but there's a token, this is the last one
400
+ if should_stop:
401
+ # Update with the final clean buffer (will have EOT removed if present)
402
+ live_text = live_buffer
403
+ break
404
+
405
+ # Otherwise add the token and continue
406
+ if token_text:
407
+ live_text += token_text
408
+ tokens_generated += 1
409
+ yield token_text, live_text, False
410
+
411
+ # Check if we got a valid response
412
+ if not live_text or len(live_text.strip()) < 10:
413
+ if self.debug_mode:
414
+ print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow"))
415
+ generation_failed = True
416
+ retry_count += 1
417
+ # Clear any partial output
418
+ if retry_count <= self.max_retries:
419
+ print("\r" + " " * 80 + "\r", end="") # Clear the line
420
+ else:
421
+ # We got a valid response, stop retrying
422
+ break
423
+
424
+ except Exception as e:
425
+ if self.debug_mode:
426
+ print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
427
+ generation_failed = True
428
+ retry_count += 1
429
+
430
+ # If we still failed after all retries, use the failure message
431
+ if generation_failed or not live_text or len(live_text.strip()) < 10:
432
+ live_text = self.generation_failure_message
433
+ if self.debug_mode:
434
+ print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))
435
+
436
+ # Calculate time taken and metrics
437
+ time_taken = time.time() - start_time
438
+ tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0
439
+
440
+ # Update history
441
+ self._update_history(user_input, live_text)
442
+
443
+ # Log generation stats
444
+ logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")
445
+
446
+ # Final yield of the complete response
447
+ yield "", live_text, True
448
+
449
+ def execute_command(self, command):
450
+ """Execute a special command prefixed with /."""
451
+ command = command.strip()
452
+
453
+ if command == '/help':
454
+ self._print_welcome_message()
455
+ return True
456
+
457
+ elif command == '/clear':
458
+ self.history = []
459
+ self.history_tokens = []
460
+ print(colored("Conversation history cleared.", 'yellow'))
461
+ return True
462
+
463
+ elif command in ['/exit', '/quit']:
464
+ print(colored("Goodbye!", 'cyan'))
465
+ return False # Signal to exit the chat loop
466
+
467
+ elif command == '/stats':
468
+ prompt_tokens = self.total_prompt_tokens
469
+ generated_tokens = self.total_generated_tokens
470
+ total_tokens = prompt_tokens + generated_tokens
471
+
472
+ stats = f"""
473
+ Token usage statistics:
474
+ - Prompt tokens: {prompt_tokens}
475
+ - Generated tokens: {generated_tokens}
476
+ - Total tokens: {total_tokens}
477
+ - Current history length: {len(self.history_tokens)} tokens
478
+ - Current repetition penalty: {self.repetition_penalty}
479
+ - Current temperature: {self.config.temperature}
480
+ - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
481
+ """
482
+ print(colored(stats, 'yellow'))
483
+ return True
484
+
485
+ elif command == '/debug':
486
+ self.debug_mode = not self.debug_mode
487
+ self.config.debug_mode = self.debug_mode # Sync with config
488
+ mode = "enabled" if self.debug_mode else "disabled"
489
+ print(colored(f"Debug mode {mode}", 'yellow'))
490
+ return True
491
+
492
+ elif command.startswith('/penalty '):
493
+ try:
494
+ penalty = float(command[9:].strip())
495
+ if 1.0 <= penalty <= 2.0:
496
+ self.repetition_penalty = penalty
497
+ print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
498
+ else:
499
+ print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
500
+ except ValueError:
501
+ print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
502
+ return True
503
+
504
+ elif command.startswith('/temp '):
505
+ try:
506
+ temp = float(command[6:].strip())
507
+ if 0.1 <= temp <= 2.0:
508
+ self.config.temperature = temp
509
+ print(colored(f"Temperature set to {temp}", 'yellow'))
510
+ else:
511
+ print(colored("Temperature should be between 0.1 and 2.0", 'red'))
512
+ except ValueError:
513
+ print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
514
+ return True
515
+
516
+ elif command.startswith('/save '):
517
+ filename = command[6:].strip()
518
+ if not filename:
519
+ print(colored("Please specify a filename: /save <filename>", 'red'))
520
+ return True
521
+
522
+ try:
523
+ # Create conversations directory if it doesn't exist
524
+ os.makedirs('conversations', exist_ok=True)
525
+
526
+ # Add .txt extension if not present
527
+ if not filename.endswith('.txt'):
528
+ filename += '.txt'
529
+
530
+ filepath = os.path.join('conversations', filename)
531
+
532
+ with open(filepath, 'w', encoding='utf-8') as f:
533
+ for entry in self.history:
534
+ role, text = entry
535
+ prefix = self.human_prefix if role == "human" else self.assistant_prefix
536
+ f.write(f"{prefix}{text}{self.end_of_turn}")
537
+
538
+ print(colored(f"Conversation saved to {filepath}", 'green'))
539
+
540
+ except Exception as e:
541
+ print(colored(f"Error saving conversation: {str(e)}", 'red'))
542
+
543
+ return True
544
+
545
+ elif command.startswith('/load '):
546
+ filename = command[6:].strip()
547
+ if not filename:
548
+ print(colored("Please specify a filename: /load <filename>", 'red'))
549
+ return True
550
+
551
+ try:
552
+ # Add .txt extension if not present
553
+ if not filename.endswith('.txt'):
554
+ filename += '.txt'
555
+
556
+ filepath = os.path.join('conversations', filename)
557
+
558
+ if not os.path.exists(filepath):
559
+ print(colored(f"File not found: {filepath}", 'red'))
560
+ return True
561
+
562
+ with open(filepath, 'r', encoding='utf-8') as f:
563
+ content = f.read()
564
+
565
+ # Parse conversation turns
566
+ self.history = []
567
+ self.history_tokens = []
568
+
569
+ # Split by end of turn marker
570
+ turns = content.split(self.end_of_turn)
571
+ for turn in turns:
572
+ turn = turn.strip()
573
+ if not turn:
574
+ continue
575
+
576
+ if turn.startswith(self.human_prefix):
577
+ text = turn[len(self.human_prefix):].strip()
578
+ self.history.append(("human", text))
579
+ elif turn.startswith(self.assistant_prefix):
580
+ text = turn[len(self.assistant_prefix):].strip()
581
+ self.history.append(("assistant", text))
582
+
583
+ # Recalculate token counts
584
+ self.history_tokens = []
585
+ for entry in self.history:
586
+ role, text = entry
587
+ if role == "human":
588
+ self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
589
+ else:
590
+ self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))
591
+
592
+ print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))
593
+
594
+ # Print the conversation
595
+ for i in range(0, len(self.history), 2):
596
+ if i < len(self.history):
597
+ user_text = self.history[i][1]
598
+ print(colored(f"\nYou: {user_text}", 'green'))
599
+
600
+ if i + 1 < len(self.history):
601
+ assistant_text = self.history[i + 1][1]
602
+ print(colored("CosmicFish: ", 'blue'), end="")
603
+ for line in assistant_text.split('\n'):
604
+ wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
605
+ for wrapped_line in wrapped_lines:
606
+ print(wrapped_line)
607
+
608
+ except Exception as e:
609
+ print(colored(f"Error loading conversation: {str(e)}", 'red'))
610
+
611
+ return True
612
+
613
+ else:
614
+ print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
615
+ return True
616
+
617
+
618
+ def load_cosmicfish_model(model_dir, device='cpu'):
619
+ """Load CosmicFish model from HF-format directory"""
620
+ print(f"Loading CosmicFish model from {model_dir}...")
621
+
622
+ # Load config
623
+ config_path = os.path.join(model_dir, "config.json")
624
+ if not os.path.exists(config_path):
625
+ raise FileNotFoundError(f"config.json not found in {model_dir}")
626
+
627
+ with open(config_path, "r") as f:
628
+ config_dict = json.load(f)
629
+
630
+ # Create CosmicConfig
631
+ config = CosmicConfig(
632
+ vocab_size=config_dict["vocab_size"],
633
+ block_size=config_dict["block_size"],
634
+ n_layer=config_dict["n_layer"],
635
+ n_head=config_dict["n_head"],
636
+ n_embd=config_dict["n_embd"],
637
+ bias=config_dict["bias"],
638
+ dropout=0.0, # Set to 0 for inference
639
+ eps=config_dict.get("eps", 1e-6),
640
+ use_rotary=config_dict["use_rotary"],
641
+ use_swiglu=config_dict["use_swiglu"],
642
+ use_gqa=config_dict["use_gqa"],
643
+ n_query_groups=config_dict["n_query_groups"],
644
+ use_qk_norm=config_dict.get("use_qk_norm", False)
645
+ )
646
+
647
+ # Create model
648
+ model = CosmicFish(config)
649
+
650
+ # Load weights
651
+ weights_path = os.path.join(model_dir, "pytorch_model.bin")
652
+ if not os.path.exists(weights_path):
653
+ raise FileNotFoundError(f"pytorch_model.bin not found in {model_dir}")
654
+
655
+ state_dict = torch.load(weights_path, map_location=device)
656
+ model.load_state_dict(state_dict)
657
+ model.to(device)
658
+ model.eval()
659
+
660
+ print(f"✅ Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
661
+ return model, config
662
+
663
+
664
+ def load_tokenizer():
665
+ """Load GPT-2 tokenizer"""
666
+ if not HF_AVAILABLE:
667
+ raise ImportError("transformers library required. Install with: pip install transformers")
668
+
669
+ print("Loading GPT-2 tokenizer...")
670
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
671
+ print("✅ Tokenizer loaded")
672
+ return tokenizer
673
+
674
+
675
+ def main():
676
+ parser = argparse.ArgumentParser(description="Chat with the released CosmicFish model")
677
+
678
+ # Model parameters
679
+ parser.add_argument("--model_dir", type=str, default="./cosmicfish-hf-release",
680
+ help="Path to the HF-format model directory")
681
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
682
+ help="Device to use (cuda or cpu)")
683
+
684
+ # Generation parameters
685
+ parser.add_argument("--temperature", type=float, default=0.6,
686
+ help="Temperature for sampling (default: 0.6)")
687
+ parser.add_argument("--max_tokens", type=int, default=1024,
688
+ help="Maximum number of tokens to generate per response")
689
+ parser.add_argument("--min_tokens", type=int, default=10,
690
+ help="Minimum number of tokens to generate per response")
691
+ parser.add_argument("--top_k", type=int, default=40,
692
+ help="Top-k sampling (0 to disable)")
693
+ parser.add_argument("--repetition_penalty", type=float, default=1.2,
694
+ help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")
695
+
696
+ # Chat parameters
697
+ parser.add_argument("--human_prefix", type=str, default="Human: ",
698
+ help="Prefix for human messages")
699
+ parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
700
+ help="Prefix for assistant messages")
701
+ parser.add_argument("--end_of_turn", type=str, default="\n\n",
702
+ help="Delimiter between conversation turns")
703
+ parser.add_argument("--instruction", type=str,
704
+ default=DEFAULT_PROMPT_TEMPLATE,
705
+ help="Instruction prompt to prepend to the conversation")
706
+ parser.add_argument("--max_history", type=int, default=1024,
707
+ help="Maximum number of tokens to keep in history")
708
+
709
+ # UI parameters
710
+ parser.add_argument("--no_welcome", action="store_true",
711
+ help="Don't display the welcome message")
712
+ parser.add_argument("--debug", action="store_true",
713
+ help="Enable debug mode")
714
+
715
+ args = parser.parse_args()
716
+
717
+ # Configure device
718
+ device = args.device
719
+ if device == "cuda" and not torch.cuda.is_available():
720
+ print("CUDA is not available, falling back to CPU")
721
+ device = "cpu"
722
+
723
+ try:
724
+ # Load the model
725
+ model, model_config = load_cosmicfish_model(args.model_dir, device)
726
+
727
+ # Load tokenizer
728
+ tokenizer = load_tokenizer()
729
+
730
+ # Create a config object with all the necessary parameters
731
+ class ChatConfig:
732
+ def __init__(self, args, block_size):
733
+ self.device = device
734
+ self.temperature = args.temperature
735
+ self.max_new_tokens = args.max_tokens
736
+ self.min_tokens_to_generate = args.min_tokens
737
+ self.top_k = args.top_k
738
+ self.human_prefix = args.human_prefix
739
+ self.assistant_prefix = args.assistant_prefix
740
+ self.end_of_turn = args.end_of_turn
741
+ self.prompt_template = args.instruction
742
+ self.max_history_tokens = args.max_history
743
+ self.display_welcome = not args.no_welcome
744
+ self.block_size = block_size
745
+ self.debug_mode = args.debug
746
+ self.repetition_penalty = args.repetition_penalty
747
+
748
+ config = ChatConfig(args, model_config.block_size)
749
+
750
+ # Initialize chat session
751
+ chat = CosmicFishChatSession(model, tokenizer, config)
752
+
753
+ # Main chat loop
754
+ print(colored("\nCosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))
755
+
756
+ while True:
757
+ try:
758
+ # Get user input
759
+ user_input = input(colored("You: ", 'green'))
760
+
761
+ # Check if it's a command
762
+ if user_input.startswith('/'):
763
+ # Execute command, continue loop if True, exit if False
764
+ if not chat.execute_command(user_input):
765
+ break
766
+ continue
767
+
768
+ # Skip if empty input
769
+ if not user_input.strip():
770
+ continue
771
+
772
+ # Generate response using live generation
773
+ live_buffer = ""
774
+ final_response = None
775
+
776
+ # Use the generator version
777
+ response_generator = chat.generate_response(user_input)
778
+
779
+ try:
780
+ # First print the assistant prefix
781
+ print(colored("CosmicFish: ", 'blue'), end="")
782
+ sys.stdout.flush()
783
+
784
+ for token, live_text, is_done in response_generator:
785
+ # If this is the final clean response
786
+ if is_done:
787
+ final_response = live_text
788
+ # Print the final response directly if we didn't get any tokens yet
789
+ if not live_buffer:
790
+ print(final_response, end="")
791
+ break
792
+
793
+ # If we have a token to display
794
+ if token:
795
+ # Check if token contains <|endoftext|> and remove it if present
796
+ if "<|endoftext|>" in token:
797
+ token = token.replace("<|endoftext|>", "")
798
+ if token: # Only print if there's anything left
799
+ print(token, end="", flush=True)
800
+ break
801
+
802
+ # Display it
803
+ print(token, end="", flush=True)
804
+ live_buffer += token
805
+
806
+ except KeyboardInterrupt:
807
+ # Allow user to interrupt generation
808
+ print("\n[Generation interrupted]")
809
+ final_response = "I was going to respond, but I'll stop here since you interrupted."
810
+
811
+ # Add an extra line for readability
812
+ print()
813
+
814
+ except KeyboardInterrupt:
815
+ print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")
816
+
817
+ except Exception as e:
818
+ print(colored(f"\nError: {str(e)}", 'red'))
819
+ logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
820
+
821
+ except Exception as e:
822
+ print(colored(f"Error loading model: {str(e)}", 'red'))
823
+ logger.error(f"Error loading model: {str(e)}", exc_info=True)
824
+ sys.exit(1)
825
+
826
+
827
+ if __name__ == "__main__":
828
+ try:
829
+ main()
830
+ except Exception as e:
831
+ logger.error(f"Fatal error: {str(e)}", exc_info=True)
832
+ sys.exit(1)
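
A minimal usage sketch (not part of the uploaded files): one-shot generation with the helper functions defined in chat.py above, skipping the interactive loop. It assumes the release has been downloaded to ./cosmicfish-hf-release (chat.py's default --model_dir) and that torch, transformers, and termcolor are installed.

import torch
from chat import load_cosmicfish_model, load_tokenizer, DEFAULT_PROMPT_TEMPLATE

device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = load_cosmicfish_model("./cosmicfish-hf-release", device)
tokenizer = load_tokenizer()

# Build a prompt in the same Human:/Assistant: format that chat.py uses.
prompt = DEFAULT_PROMPT_TEMPLATE + "Human: What is CosmicFish?\n\nAssistant: "
input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

# CosmicFish.generate is defined in modeling_cosmicfish.py (shown below).
output = model.generate(input_ids, max_new_tokens=100, temperature=0.6, top_k=40)
print(tokenizer.decode(output[0].tolist()))

The interactive interface itself is started with python chat.py plus any of the flags listed in main().
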
modeling_cosmicfish.py ADDED
@@ -0,0 +1,290 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+
7
+ class CosmicConfig:
8
+ """Configuration class for CosmicFish."""
9
+
10
+ def __init__(self,
11
+ vocab_size=50257,
12
+ block_size=512,
13
+ n_layer=10,
14
+ n_head=16,
15
+ n_embd=640,
16
+ bias=True,
17
+ dropout=0.0, # Always 0 for inference
18
+ n_query_groups=4,
19
+ eps=1e-6,
20
+ use_rotary=True,
21
+ use_swiglu=True,
22
+ use_qk_norm=False,
23
+ use_gqa=True):
24
+ self.vocab_size = vocab_size
25
+ self.block_size = block_size
26
+ self.n_layer = n_layer
27
+ self.n_head = n_head
28
+ self.n_embd = n_embd
29
+ self.bias = bias
30
+ self.dropout = dropout
31
+ self.eps = eps
32
+ self.use_rotary = use_rotary
33
+ self.use_swiglu = use_swiglu
34
+ self.use_qk_norm = use_qk_norm
35
+ self.use_gqa = use_gqa
36
+ self.n_query_groups = n_query_groups if use_gqa else n_head
37
+ # Ensure n_head is divisible by n_query_groups
38
+ assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
39
+
40
+
41
+ class RMSNorm(nn.Module):
42
+ """Root Mean Square Normalization"""
43
+
44
+ def __init__(self, dim, eps=1e-6):
45
+ super().__init__()
46
+ self.eps = eps
47
+ self.weight = nn.Parameter(torch.ones(dim))
48
+
49
+ def forward(self, x):
50
+ rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
51
+ return self.weight * (x / rms)
52
+
53
+
54
+ def precompute_freqs_cis(dim, end, theta=10000.0):
55
+ """Precompute the frequency tensor for complex exponentials (cis)"""
56
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
57
+ t = torch.arange(end, device=freqs.device)
58
+ freqs = torch.outer(t, freqs)
59
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
60
+ return freqs_cis
61
+
62
+
63
+ def apply_rotary_emb(xq, xk, freqs_cis):
64
+ """Apply rotary embeddings to input tensors"""
65
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
66
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
67
+
68
+ seq_len = xq_.size(2)
69
+ if freqs_cis.size(0) < seq_len:
70
+ raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
71
+
72
+ freqs_cis_seq = freqs_cis[:seq_len]
73
+ xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
74
+ xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
75
+
76
+ return xq_out.type_as(xq), xk_out.type_as(xk)
77
+
78
+
79
+ class GroupedQueryAttention(nn.Module):
80
+ """Grouped Query Attention (GQA) implementation"""
81
+
82
+ def __init__(self, config):
83
+ super().__init__()
84
+ assert config.n_embd % config.n_head == 0
85
+
86
+ head_dim = config.n_embd // config.n_head
87
+ self.head_dim = head_dim
88
+ self.n_head = config.n_head
89
+ self.n_embd = config.n_embd
90
+ self.n_query_groups = config.n_query_groups
91
+
92
+ self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
93
+ qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
94
+
95
+ self.c_attn = nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
96
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
97
+
98
+ # Flash attention support
99
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
100
+ if not self.flash:
101
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
102
+ .view(1, 1, config.block_size, config.block_size))
103
+
104
+ # Query-key normalization
105
+ self.qk_norm = getattr(config, 'use_qk_norm', False)
106
+ if self.qk_norm:
107
+ self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
108
+ self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
109
+
110
+ def forward(self, x, freqs_cis=None):
111
+ B, T, C = x.size()
112
+ qkv = self.c_attn(x)
113
+ head_dim = C // self.n_head
114
+
115
+ q_size = self.n_head * head_dim
116
+ k_size = self.kv_heads * head_dim
117
+ v_size = self.kv_heads * head_dim
118
+
119
+ q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
120
+
121
+ q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
122
+ k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
123
+ v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
124
+
125
+ # Repeat k and v if needed for GQA
126
+ if self.kv_heads < self.n_head:
127
+ repeats = self.n_head // self.kv_heads
128
+ k = k.repeat_interleave(repeats, dim=1)
129
+ v = v.repeat_interleave(repeats, dim=1)
130
+
131
+ # Apply rotary embeddings
132
+ if freqs_cis is not None:
133
+ q, k = apply_rotary_emb(q, k, freqs_cis)
134
+
135
+ # Apply query-key normalization
136
+ if self.qk_norm:
137
+ q = self.q_norm(q)
138
+ k = self.k_norm(k)
139
+
140
+ # Compute attention
141
+ if self.flash:
142
+ y = torch.nn.functional.scaled_dot_product_attention(
143
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
144
+ )
145
+ else:
146
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
147
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
148
+ att = F.softmax(att, dim=-1)
149
+ y = att @ v
150
+
151
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
152
+ y = self.c_proj(y)
153
+ return y
154
+
155
+
156
+ class Block(nn.Module):
157
+ """Transformer block"""
158
+
159
+ def __init__(self, config):
160
+ super().__init__()
161
+ self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
162
+ self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
163
+ self.attn = GroupedQueryAttention(config)
164
+
165
+ # MLP implementation based on configuration
166
+ if config.use_swiglu:
167
+ # SwiGLU MLP
168
+ self.mlp = nn.ModuleDict(dict(
169
+ gate=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
170
+ up=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
171
+ down=nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
172
+ act=nn.SiLU(),
173
+ ))
174
+ m = self.mlp
175
+ self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
176
+ else:
177
+ # Traditional MLP
178
+ self.mlp = nn.ModuleDict(dict(
179
+ c_fc=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
180
+ c_proj=nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
181
+ act=nn.GELU(),
182
+ ))
183
+ m = self.mlp
184
+ self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
185
+
186
+ def forward(self, x, freqs_cis=None):
187
+ x = x + self.attn(self.ln_1(x), freqs_cis)
188
+ x = x + self.mlpf(self.ln_2(x))
189
+ return x
190
+
191
+
192
+ class CosmicFish(nn.Module):
193
+ """
194
+ CosmicFish model for inference only.
195
+ Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
196
+ """
197
+
198
+ def __init__(self, config):
199
+ super().__init__()
200
+ self.config = config
201
+
202
+ self.transformer = nn.ModuleDict(dict(
203
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
204
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
205
+ ln_f=RMSNorm(config.n_embd, eps=config.eps),
206
+ ))
207
+
208
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
209
+
210
+ # Share weights between embedding and output
211
+ self.transformer.wte.weight = self.lm_head.weight
212
+
213
+ # Precompute rotary embedding frequencies
214
+ if config.use_rotary:
215
+ head_dim = config.n_embd // config.n_head
216
+ self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
217
+ else:
218
+ self.freqs_cis = None
219
+ self.transformer.wpe = nn.Embedding(config.block_size, config.n_embd)
220
+
221
+ def get_num_params(self, non_embedding=True):
222
+ """Return the number of parameters in the model."""
223
+ n_params = sum(p.numel() for p in self.parameters())
224
+ if non_embedding and hasattr(self.transformer, 'wpe'):
225
+ n_params -= self.transformer.wpe.weight.numel()
226
+ return n_params
227
+
228
+ def forward(self, idx, targets=None):
229
+ """Forward pass through the model."""
230
+ device = idx.device
231
+ b, t = idx.size()
232
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
233
+
234
+ # Get token embeddings
235
+ tok_emb = self.transformer.wte(idx)
236
+
237
+ # Handle positional embeddings
238
+ if self.config.use_rotary:
239
+ x = tok_emb
240
+ freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
241
+ else:
242
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
243
+ pos_emb = self.transformer.wpe(pos)
244
+ x = tok_emb + pos_emb
245
+ freqs_cis = None
246
+
247
+ # Apply transformer blocks
248
+ for block in self.transformer.h:
249
+ x = block(x, freqs_cis)
250
+
251
+ # Apply final normalization
252
+ x = self.transformer.ln_f(x)
253
+
254
+ # Calculate outputs
255
+ if targets is not None:
256
+ logits = self.lm_head(x)
257
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
258
+ else:
259
+ # For inference, only compute logits for the last token
260
+ logits = self.lm_head(x[:, [-1], :])
261
+ loss = None
262
+
263
+ return logits, loss
264
+
265
+ @torch.no_grad()
266
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
267
+ """
268
+ Generate text by sampling from the model, token by token.
269
+ """
270
+ for _ in range(max_new_tokens):
271
+ # Crop sequence to block size if needed
272
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
273
+
274
+ # Forward pass
275
+ logits, _ = self(idx_cond)
276
+ logits = logits[:, -1, :] / temperature
277
+
278
+ # Apply top-k sampling
279
+ if top_k is not None:
280
+ v, _ = torch.topk(logits, top_k)
281
+ logits[logits < v[:, [-1]]] = -float('Inf')
282
+
283
+ # Sample next token
284
+ probs = F.softmax(logits, dim=-1)
285
+ idx_next = torch.multinomial(probs, num_samples=1)
286
+
287
+ # Append to sequence
288
+ idx = torch.cat((idx, idx_next), dim=1)
289
+
290
+ return idx
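
A second sketch under the same assumptions: constructing the model directly from the released config.json and pytorch_model.bin (mirroring load_cosmicfish_model in chat.py) and computing a language-modelling loss through the targets path of CosmicFish.forward. The directory path and the toy batch are illustrative only.

import json
import torch
from modeling_cosmicfish import CosmicFish, CosmicConfig

model_dir = "./cosmicfish-hf-release"  # hypothetical local path to the release
with open(f"{model_dir}/config.json") as f:
    cfg = json.load(f)

config = CosmicConfig(
    vocab_size=cfg["vocab_size"], block_size=cfg["block_size"],
    n_layer=cfg["n_layer"], n_head=cfg["n_head"], n_embd=cfg["n_embd"],
    bias=cfg["bias"], use_rotary=cfg["use_rotary"], use_swiglu=cfg["use_swiglu"],
    use_gqa=cfg["use_gqa"], n_query_groups=cfg["n_query_groups"],
    use_qk_norm=cfg.get("use_qk_norm", False),
)
model = CosmicFish(config)
model.load_state_dict(torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu"))
model.eval()

# Toy next-token loss: targets are the inputs shifted left by one position.
ids = torch.randint(0, config.vocab_size, (1, 16))
logits, loss = model(ids[:, :-1], targets=ids[:, 1:])
print(logits.shape, loss.item())  # logits cover every position when targets are given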