JonasGeiping commited on
Commit
6c1825c
·
verified ·
1 Parent(s): 89eb92b

Update raven_modeling_minimal.py

Browse files
Files changed (1) hide show
  1. raven_modeling_minimal.py +88 -1592
raven_modeling_minimal.py CHANGED
@@ -1,1603 +1,99 @@
1
- """Modeling file for HF compatibility and zero-shot experiments."""
2
 
3
- import torch
4
- import math
5
 
6
- from torch import Tensor
7
- from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
- from torch.nn.attention import bias as attn_bias
9
- from dataclasses import dataclass
10
- from typing import Union, Optional, Any
11
 
12
-
13
- from .raven_config_minimal import RavenConfig
14
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
-
16
- ###################### Huggingface Glue code I ##################################################################
17
- from transformers import PreTrainedModel, GenerationMixin
18
- from transformers.utils import ModelOutput
19
- from transformers.generation.utils import GenerateDecoderOnlyOutput
20
-
21
- import torch.nn.functional as F
22
- from transformers import GenerationConfig
23
-
24
- torch.backends.cuda.enable_math_sdp(False)
25
-
26
-
27
- class RavenPreTrainedModel(PreTrainedModel):
28
- config_class = RavenConfig
29
- base_model_prefix = "model"
30
- supports_gradient_checkpointing = True
31
- _no_split_modules = ["SandwichBlock"]
32
- _skip_keys_device_placement = ["past_key_values"]
33
- _tied_weights_keys = ["lm_head.weight"]
34
- _supports_flash_attn_2 = True
35
- _supports_sdpa = True
36
- _supports_cache_class = True
37
- _supports_quantized_cache = False
38
- _supports_static_cache = True
39
- _tp_plan = {}
40
-
41
- def _init_weights(self, module):
42
- if not torch.rand((1,)).is_meta:
43
- print("Random Initialization not implemented.")
44
-
45
-
46
- @dataclass
47
- class CausalLMOutputRecurrentLatents(ModelOutput):
48
- loss: Optional[torch.Tensor] = None
49
- log_ppl: Optional[torch.Tensor] = None
50
- logits: Optional[torch.Tensor] = None
51
- past_key_values: Optional[Cache] = None
52
- latent_states: Optional[torch.Tensor] = None
53
- hidden_states: Optional[torch.Tensor] = None
54
- attention_maps: Optional[dict[int, torch.Tensor]] = None
55
- stats: Optional[dict] = None
56
-
57
-
58
- ###################### Minimal implementation from here ############################################################
59
-
60
-
61
- class RMSNorm(torch.nn.Module):
62
- """Saner dtype handling and slightly better for fusion"""
63
-
64
- def __init__(self, dim: int, eps: float = 1e-6):
65
- super().__init__()
66
- self.eps = eps
67
- self.weight = torch.nn.Parameter(torch.ones(dim))
68
-
69
- def _norm(self, x):
70
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
71
-
72
- def forward(self, x):
73
- with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
74
- return self._norm(x.float()).type_as(x) * self.weight
75
-
76
- def reset_parameters(self) -> None:
77
- torch.nn.init.ones_(self.weight)
78
-
79
-
80
- class HuginnDynamicCache(DynamicCache):
81
- def __init__(self, lookup_strategy: str = "full") -> None:
82
- super().__init__()
83
- self._seen_tokens = 0
84
- self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
85
- self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
86
- # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
87
- # the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
88
- # per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
89
- # Also, It is critical that the head indices do not overlap with the recurrent iteration indices
90
- self.lookup_strategy = lookup_strategy
91
-
92
- def update(
93
- self,
94
- key_states: torch.Tensor,
95
- value_states: torch.Tensor,
96
- step_idx_tensor: torch.Tensor,
97
- lookup_strategy: Optional[str] = None,
98
- ) -> tuple[torch.Tensor, torch.Tensor]:
99
- step_idx: int = int(step_idx_tensor) # todo: fix dicts with tensor step_idx, currently the memberships fail
100
- lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
101
- if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
102
- if "compress-s" in self.lookup_strategy:
103
- compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
104
- new_step_idx = (step_idx - 2) % compression_stage + 2
105
- elif "compress-anchor" in self.lookup_strategy:
106
- if step_idx - 2 < 4 * 8: # anchor onto first 8 recurrence steps # noqa: SIM108
107
- new_step_idx = step_idx
108
- else: # then re-use the next 4 KV states = one recurrence for all future recurrence
109
- new_step_idx = 34 + (step_idx - 34) % 4
110
- # print(step_idx, new_step_idx)
111
- else: # compress-r
112
- compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
113
- new_step_idx = (step_idx - 2) // compression_stage + 2
114
- step_idx = new_step_idx
115
- # Init
116
- if step_idx not in self.key_cache:
117
- self.key_cache[step_idx] = {}
118
- self.value_cache[step_idx] = {}
119
- # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
120
- if step_idx == 0:
121
- self._seen_tokens += key_states.shape[-2]
122
- # Add entries to cache
123
- for idx, entry in enumerate(key_states.unbind(dim=-2)):
124
- if "compress-" not in self.lookup_strategy:
125
- assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
126
- self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
127
- for idx, entry in enumerate(value_states.unbind(dim=-2)):
128
- self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
129
-
130
- # Materialize past state based on lookup strategy:
131
- if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
132
- # All entries are present, materialize cache as normal
133
- return (
134
- torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
135
- torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
136
- )
137
- else: # some entries were not previously computed
138
- if lookup_strategy.startswith("latest-m4"):
139
- latest_keys = []
140
- latest_values = []
141
- for token_pos in range(self._seen_tokens):
142
- # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
143
- if step_idx >= 2:
144
- # Find valid steps for this token position
145
- valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
146
- max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
147
- else:
148
- max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
149
- latest_keys.append(self.key_cache[max_step][token_pos])
150
- latest_values.append(self.value_cache[max_step][token_pos])
151
- return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
152
- elif lookup_strategy.startswith("available-m4"):
153
- latest_keys = []
154
- latest_values = []
155
- for token_pos in range(self._seen_tokens):
156
- if token_pos in self.key_cache[step_idx]:
157
- step = step_idx
158
- else:
159
- # Find valid steps for this token position
160
- valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
161
- step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
162
- latest_keys.append(self.key_cache[step][token_pos])
163
- latest_values.append(self.value_cache[step][token_pos])
164
- return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
165
- elif lookup_strategy.startswith("always-last-m4"):
166
- latest_keys = []
167
- latest_values = []
168
- for token_pos in range(self._seen_tokens):
169
- # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
170
- if step_idx >= 2:
171
- # Find valid steps for this token position
172
- valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
173
- max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
174
- else:
175
- max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
176
- latest_keys.append(self.key_cache[max_step][token_pos])
177
- latest_values.append(self.value_cache[max_step][token_pos])
178
- return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
179
- elif lookup_strategy.startswith("skip"):
180
- existing_keys = []
181
- existing_values = []
182
- for token_pos in range(self._seen_tokens):
183
- if token_pos in self.key_cache[step_idx]:
184
- existing_keys.append(self.key_cache[step_idx][token_pos])
185
- existing_values.append(self.value_cache[step_idx][token_pos])
186
- return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
187
- elif lookup_strategy.startswith("randomized"): # sanity check
188
- rand_keys = []
189
- rand_values = []
190
- for token_pos in range(self._seen_tokens):
191
- if step_idx < 2: # For prelude steps
192
- max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
193
- else: # Get all steps from same block position
194
- curr_modulo = (step_idx - 2) % 4 + 2
195
- valid_steps = [
196
- s
197
- for s in range(2, step_idx + 1)
198
- if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
199
- ]
200
- max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
201
- rand_keys.append(self.key_cache[max_step][token_pos])
202
- rand_values.append(self.value_cache[max_step][token_pos])
203
- return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
204
- else:
205
- raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
206
-
207
- def reset(self) -> None:
208
- """Reset the cache state."""
209
- self._seen_tokens = 0
210
- self.key_cache.clear()
211
- self.value_cache.clear()
212
-
213
- def clear_last_k_entries(self, k: int = 0):
214
- """Partially clear cache."""
215
- assert self._seen_tokens >= k
216
- self._seen_tokens = self._seen_tokens - k
217
- # self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
218
- self.key_cache = {
219
- step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
220
- for step, cache in self.key_cache.items()
221
- }
222
- self.value_cache = {
223
- step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
224
- for step, cache in self.value_cache.items()
225
- }
226
-
227
- def get_seq_length(self, step_idx: int = 0) -> int:
228
- return self._seen_tokens
229
-
230
- def get_memory_usage(self) -> float:
231
- total_bytes = 0
232
- # For each recurrent step/layer index
233
- for step_idx in self.key_cache:
234
- # Get the sequence cache for this step
235
- key_seq_cache = self.key_cache[step_idx]
236
- for seq_idx in key_seq_cache:
237
- key_tensor = key_seq_cache[seq_idx]
238
- # Add memory for of key tensors, assuming value is the same
239
- total_bytes += key_tensor.nelement() * key_tensor.element_size()
240
- return total_bytes * 2 / (1024 * 1024)
241
-
242
-
243
- class HuginnStaticCache(Cache):
244
- """Static Cache for the recurrent model"""
245
-
246
- is_compileable = False # this is todo
247
-
248
- def __init__(
249
- self,
250
- max_length: int,
251
- max_num_steps: int,
252
- num_heads: int,
253
- hidden_dim: int,
254
- batch_size: int = 1,
255
- lookup_strategy: str = "full",
256
- device: Optional[Union[torch.device, str]] = None,
257
- dtype: torch.dtype = torch.float32,
258
- ) -> None:
259
- super().__init__()
260
- self._seen_tokens = 0
261
- self.max_length = max_length
262
- self.lookup_strategy = lookup_strategy
263
-
264
- # Adjust max_num_steps based on compression strategy
265
- if "compress-" in lookup_strategy:
266
- compression_stage = int(lookup_strategy.split("compress-")[1][1:])
267
- if "compress-s" in lookup_strategy:
268
- # For modulo compression (s), we need steps for 0,1 + compressed steps
269
- self.max_num_steps = 4 + compression_stage
270
- else:
271
- # For relative compression, we need steps for 0,1 + compressed steps
272
- self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
273
- else:
274
- self.max_num_steps = max_num_steps
275
-
276
- # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
277
- device = torch.device(device) if device is not None else None
278
- cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
279
-
280
- self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
281
- self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
282
- self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
283
- # Mark tensors as static for compile
284
- torch._dynamo.mark_static_address(self.key_cache)
285
- torch._dynamo.mark_static_address(self.value_cache)
286
- torch._dynamo.mark_static_address(self.valid_mask)
287
-
288
- def update(
289
- self,
290
- key_states: torch.Tensor,
291
- value_states: torch.Tensor,
292
- step_idx: torch.Tensor,
293
- lookup_strategy: Optional[str] = None,
294
- ) -> tuple[torch.Tensor, torch.Tensor]:
295
- if step_idx == 0:
296
- self._seen_tokens += key_states.shape[-2]
297
-
298
- # Adjust step_idx for compression
299
- lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
300
- if "compress-" in lookup_strategy and step_idx > 1:
301
- compression_stage = int(lookup_strategy.split("compress-")[1][1:])
302
- if "compress-s" in lookup_strategy:
303
- step_idx = (step_idx - 2) % compression_stage + 2
304
- else:
305
- step_idx = (step_idx - 2) // compression_stage + 2
306
-
307
- start_idx = self._seen_tokens - key_states.shape[-2]
308
-
309
- indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
310
- self.key_cache[step_idx].index_copy_(2, indices, key_states)
311
- self.value_cache[step_idx].index_copy_(2, indices, value_states)
312
- self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True
313
-
314
- # Return based on lookup strategy
315
- if lookup_strategy == "full":
316
- return (
317
- self.key_cache[step_idx, :, :, : self._seen_tokens],
318
- self.value_cache[step_idx, :, :, : self._seen_tokens],
319
- )
320
- elif lookup_strategy.startswith("latest-m4"):
321
- if step_idx >= 2:
322
- pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
323
- pattern_valid = self.valid_mask[pattern_steps]
324
- max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
325
- return (
326
- self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
327
- self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
328
- )
329
- return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
330
- step_idx, :, :, : self._seen_tokens
331
- ]
332
- elif lookup_strategy == "skip":
333
- valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
334
- return (
335
- self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
336
- self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
337
- )
338
- elif lookup_strategy.startswith("randomized"):
339
- if step_idx < 2:
340
- max_step = step_idx
341
- else:
342
- curr_modulo = (step_idx - 2) % 4 + 2
343
- valid_steps = (
344
- torch.where(
345
- (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
346
- )[0]
347
- + 2
348
- )
349
- rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
350
- max_step = valid_steps[rand_idx]
351
- return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
352
- else:
353
- raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
354
-
355
- def reset(self) -> None:
356
- self._seen_tokens = 0
357
- self.key_cache.zero_()
358
- self.value_cache.zero_()
359
- self.valid_mask.zero_()
360
-
361
- def get_seq_length(self, step_idx: int = 0) -> int:
362
- return self._seen_tokens
363
-
364
- def get_memory_usage(self) -> float:
365
- return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
366
-
367
-
368
- ValidCache = HuginnDynamicCache | HuginnStaticCache
369
-
370
-
371
- class CausalSelfAttention(torch.nn.Module):
372
- def __init__(self, config: RavenConfig) -> None:
373
- super().__init__()
374
- self.config = config
375
- self.n_head = config.num_attention_heads
376
- self.n_kv_heads = config.num_key_value_heads
377
- self.head_dim = config.n_embd // self.n_head
378
-
379
- shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
380
- self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
381
- self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
382
- if config.qk_bias:
383
- self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
384
- self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
385
-
386
- def forward(
387
- self,
388
- x: Tensor,
389
- freqs_cis: Tensor,
390
- block_idx: torch.Tensor,
391
- mask: Optional[BlockMask] = None,
392
- past_key_values: Optional[ValidCache] = None,
393
- ) -> Tensor:
394
- B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
395
- q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
396
- q = q.view(B, S, self.n_head, self.head_dim)
397
- k = k.view(B, S, self.n_kv_heads, self.head_dim)
398
- v = v.view(B, S, self.n_kv_heads, self.head_dim)
399
- # bias?
400
- if self.config.qk_bias:
401
- q_bias, k_bias = self.qk_bias.split(1, dim=0)
402
- q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
403
- # apply rotary
404
- q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
405
-
406
- q = q.transpose(1, 2) # (B, nh, S, hs)
407
- k = k.transpose(1, 2)
408
- v = v.transpose(1, 2)
409
-
410
- if past_key_values is not None:
411
- k, v = past_key_values.update(k, v, block_idx)
412
-
413
- if mask is not None:
414
- y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
415
- else:
416
- if q.shape[2] < k.shape[2]:
417
- if q.shape[2] > 1:
418
- bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
419
- y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0)
420
- else:
421
- y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
422
- else:
423
- y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
424
- y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
425
- return self.proj(y)
426
-
427
-
428
- class GatedMLP(torch.nn.Module):
429
- def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
430
- super().__init__()
431
- in_features = config.n_embd if in_features == 0 else in_features
432
- self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
433
-
434
- self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
435
- self.nonlin = torch.nn.SiLU()
436
-
437
- def forward(self, x: Tensor) -> Tensor:
438
- # modified to single FC layer to improve parallelism
439
- x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
440
- x = self.nonlin(x_fc_1) * x_fc_2
441
- return self.proj(x)
442
-
443
-
444
- class SandwichBlock(torch.nn.Module):
445
- expanded = False
446
-
447
- def __init__(self, config: RavenConfig, layer_id: int) -> None:
448
- super().__init__()
449
- self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
450
- self.attn = CausalSelfAttention(config)
451
- self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
452
- self.mlp = GatedMLP(config)
453
- self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
454
- self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
455
- self.layer_id = layer_id
456
-
457
- def forward(
458
- self,
459
- x: Tensor,
460
- freqs_cis: Tensor,
461
- step_idx: int,
462
- mask: Optional[BlockMask] = None,
463
- past_key_values: Optional[ValidCache] = None,
464
- ) -> Tensor:
465
- attn_out = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values)
466
- x = self.norm_2(attn_out + x)
467
- x = self.norm_4(self.mlp(self.norm_3(x)) + x)
468
- return x
469
-
470
-
471
- class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
472
- freqs_cis: torch.Tensor
473
 
474
  def __init__(
475
  self,
476
- config: RavenConfig,
477
- ) -> None:
478
- super().__init__(config)
479
- self.config = config
480
-
481
- # Transformer layers
482
- prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
483
- adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
484
- core_block = torch.nn.ModuleList(
485
- SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
486
- for i in range(config.n_layers_in_recurrent_block)
487
- )
488
- o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
489
- coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
490
-
491
- self.transformer = torch.nn.ModuleDict(
492
- dict(
493
- wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
494
- prelude=prelude,
495
- adapter=adapter,
496
- core_block=core_block,
497
- coda=coda,
498
- ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
499
- )
500
- )
501
- self.emb_scale = config.init_values["embed_scale"]
502
- # Head
503
- self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
504
- if self.config.tie_embeddings:
505
- self.tie_weights()
506
- # rope
507
- self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
508
-
509
- def get_input_embeddings(self):
510
- return self.transformer.wte
511
-
512
- def get_output_embeddings(self):
513
- return self.lm_head
514
-
515
- def _precompute_freqs_cis(self):
516
- # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
517
- freqs_cis = precompute_freqs_cis(
518
- self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
519
- )
520
- return freqs_cis
521
-
522
- def compile_mask(
523
- self,
524
- input_ids: torch.Tensor,
525
- attention_mask: Optional[torch.Tensor] = None,
526
- past_key_values: Optional[ValidCache] = None,
527
- pad_token_id=65509,
528
- ) -> Optional[BlockMask]:
529
- batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
530
-
531
- # If no padding and no attention mask, no need for a mask
532
- if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
533
- return None
534
-
535
- if past_key_values is not None and seq_len == 1:
536
- return None
537
-
538
- # Get total sequence length including cache
539
- cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
540
- kv_length = cache_len + seq_len
541
-
542
- if attention_mask is None:
543
-
544
- def mask_mod(b, h, q_idx, kv_idx):
545
- return q_idx >= kv_idx & (input_ids[b, kv_idx] != pad_token_id)
546
- else:
547
-
548
- def mask_mod(b, h, q_idx, kv_idx):
549
- return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
550
-
551
- kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
552
- if kv_length == 0:
553
- kv_length = seq_len # prefill
554
- block_mask = create_block_mask(
555
- mask_mod,
556
- B=batch_size,
557
- H=None,
558
- Q_LEN=seq_len,
559
- KV_LEN=kv_length,
560
- device=input_ids.device,
561
- )
562
-
563
- # # Define mask_mod function
564
- # def mask_mod(b, h, q_idx, kv_idx):
565
- # # Always apply causal constraint
566
- # is_causal = q_idx >= kv_idx
567
-
568
- # # Handle cache vs current tokens
569
- # is_cache = kv_idx < cache_len
570
- # current_idx = kv_idx - cache_len
571
-
572
- # # For cache: always valid; For current: check padding
573
- # not_pad = input_ids[b, current_idx] != pad_token_id
574
- # valid = is_cache | not_pad
575
-
576
- # # Apply attention mask if provided
577
- # if attention_mask is not None:
578
- # q_idx_curr = q_idx - cache_len
579
- # attn_valid = attention_mask[b, q_idx_curr, current_idx]
580
- # valid = valid & (is_cache | attn_valid)
581
-
582
- # return is_causal & valid
583
-
584
- # def mask_mod(b, h, q_idx, kv_idx):
585
- # is_causal = q_idx >= kv_idx
586
- # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
587
- # current_idx = kv_idx - cache_len
588
-
589
- # is_valid = (~is_current) | (
590
- # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
591
- # )
592
-
593
- # return is_causal & is_valid
594
-
595
- # # Define mask_mod function
596
- # def mask_mod(b, h, q_idx, kv_idx):
597
- # # Always apply causal constraint
598
- # is_causal = q_idx >= kv_idx
599
-
600
- # # Handle cache vs current tokens
601
- # is_cache = kv_idx < cache_len
602
- # current_idx = kv_idx - cache_len
603
- # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
604
-
605
- # # For cache: always valid; For current: check padding
606
- # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
607
- # valid = is_cache | (not_pad & in_bounds)
608
-
609
- # # Apply attention mask if provided
610
- # if attention_mask is not None:
611
- # q_idx_curr = q_idx - cache_len
612
- # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
613
- # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
614
- # valid = valid & (is_cache | attn_valid)
615
-
616
- # return is_causal & valid
617
-
618
- # Create block mask
619
- block_mask = create_block_mask(
620
- mask_mod,
621
- B=batch_size,
622
- H=None,
623
- Q_LEN=seq_len,
624
- KV_LEN=kv_length,
625
- device=input_ids.device,
626
- )
627
-
628
- return block_mask
629
-
630
- def forward(
631
- self,
632
- input_ids: torch.Tensor,
633
- input_embeds: Optional[torch.Tensor] = None,
634
- input_states: Optional[torch.Tensor] = None,
635
- attention_mask: Optional[torch.Tensor] = None, # binary mask of shape q x kv, True=valid position
636
- position_ids: Optional[torch.Tensor] = None,
637
- labels: Optional[torch.Tensor] = None,
638
- num_steps: Optional[torch.Tensor] = None,
639
- past_key_values: Optional[ValidCache] = None,
640
- output_details: dict = {
641
- "return_logits": True,
642
- "return_latents": True,
643
- "return_head": False,
644
- "return_stats": False,
645
- },
646
- use_cache: bool = False,
647
- cache_position: Optional[torch.Tensor] = None,
648
- init_scale: float = 1.0,
649
- **kwargs,
650
- ) -> CausalLMOutputRecurrentLatents:
651
- # Support multiple position formats:
652
- if position_ids is None and cache_position is None:
653
- freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
654
- elif position_ids is not None:
655
- freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
656
- elif cache_position is not None:
657
- freqs_cis = self.freqs_cis[:, cache_position]
658
-
659
- if input_embeds is None:
660
- input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
661
-
662
- if self.emb_scale != 1:
663
- input_embeds = input_embeds * self.emb_scale # type: ignore
664
-
665
- if use_cache and past_key_values is None:
666
- past_key_values = HuginnDynamicCache()
667
-
668
- prepared_attn_mask = None # self.compile_mask(input_ids, attention_mask, past_key_values)
669
- block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
670
- # Non-recurrent prelude
671
- for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
672
- block_idx += 1
673
- input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
674
-
675
- # Main recurrence
676
- x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
677
- input_embeds, # type: ignore # mystery typing error
678
- input_states,
679
- freqs_cis,
680
- block_idx,
681
- prepared_attn_mask,
682
- past_key_values,
683
- num_steps,
684
- init_scale,
685
- )
686
- latent_states = x.clone().detach()
687
-
688
- # Coda layers
689
- block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
690
- for block in self.transformer.coda: # type: ignore # types broken in 2.6+
691
- block_idx -= 1
692
- x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
693
- x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
694
-
695
- # Prediction head, assuming labels really are labels and not equal to input_ids
696
- if labels is not None:
697
- logits = self.lm_head(x).float()
698
- loss = torch.nn.functional.cross_entropy(
699
- logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
700
- )
701
- log_ppl = loss.clone().detach().exp()
702
- else:
703
- logits = self.lm_head(x).float()
704
- loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
705
-
706
- return CausalLMOutputRecurrentLatents(
707
- loss=loss,
708
- log_ppl=log_ppl,
709
- logits=logits if output_details["return_logits"] else None,
710
- past_key_values=past_key_values,
711
- hidden_states=x if output_details["return_head"] else None,
712
- latent_states=latent_states if output_details["return_latents"] else None,
713
- stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
714
- if output_details["return_stats"]
715
- else None,
716
- )
717
-
718
- @torch._dynamo.disable(recursive=False) # type: ignore
719
- def iterate_forward(
720
- self,
721
- input_embeds: torch.Tensor,
722
- input_states: torch.Tensor,
723
- freqs_cis,
724
- block_idx: torch.Tensor,
725
- mask: Optional[BlockMask],
726
- past_key_values: Optional[ValidCache] = None,
727
- num_steps: Optional[torch.Tensor] = None,
728
- init_scale: float = 1.0,
729
- ):
730
- x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
731
- if num_steps is None:
732
- num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
733
- elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
734
- num_steps_no_grad, num_steps_with_grad = num_steps
735
- else:
736
- num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
737
-
738
- with torch.no_grad():
739
- # ultra annoying in ddp due to
740
- # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
741
- # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
742
- # and all parameters are always used
743
- for no_grad_step in range(num_steps_no_grad):
744
- xk = x
745
- x, block_idx = self.core_block_forward(
746
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
747
- )
748
-
749
- for grad_step in range(num_steps_with_grad):
750
- xk = x
751
- x, block_idx = self.core_block_forward(
752
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
753
- )
754
- return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
755
-
756
- def core_block_forward(
757
- self,
758
- x,
759
- input_embeds,
760
- freqs_cis,
761
- mask: Optional[BlockMask],
762
- past_key_values,
763
- block_idx: torch.Tensor,
764
- current_step: int | Tensor,
765
- ):
766
- x = self._maybe_inject_noise(x, current_step)
767
- x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
768
- for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
769
- block_idx += 1
770
- x = block(x, freqs_cis, block_idx, mask, past_key_values)
771
- return x, block_idx
772
-
773
- @torch.no_grad()
774
- def iterate_one_step(
775
- self,
776
- input_embeds,
777
- input_states,
778
- position_ids: Optional[torch.Tensor] = None,
779
- cache_position: Optional[torch.Tensor] = None,
780
- block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
781
- attention_mask: Optional[BlockMask] = None,
782
- past_key_values: Optional[ValidCache] = None,
783
- current_step: int = 0,
784
- ):
785
- if position_ids is None and cache_position is None:
786
- freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
787
- elif position_ids is not None:
788
- freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
789
- elif cache_position is not None:
790
- freqs_cis = self.freqs_cis[:, cache_position]
791
- x, block_idx = self.core_block_forward(
792
- input_states,
793
- input_embeds,
794
- freqs_cis,
795
- attention_mask,
796
- past_key_values,
797
- block_idx,
798
- current_step=current_step,
799
- )
800
- return x, block_idx, current_step + 1
801
-
802
- def predict_from_latents(
803
- self,
804
- latents,
805
- attention_mask: Optional[BlockMask] = None,
806
- position_ids: Optional[torch.Tensor] = None,
807
- cache_position: Optional[torch.Tensor] = None,
808
- past_key_values: Optional[ValidCache] = None,
809
- ):
810
- if position_ids is None and cache_position is None:
811
- freqs_cis = self.freqs_cis[:, : latents.shape[1]]
812
- elif position_ids is not None:
813
- freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
814
- elif cache_position is not None:
815
- freqs_cis = self.freqs_cis[:, cache_position]
816
- x = self.transformer.ln_f(latents) # type: ignore # types broken in 2.6+
817
- # Coda layers
818
- block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
819
- for block in self.transformer.coda: # type: ignore # types broken in 2.6+
820
- block_idx -= 1
821
- x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
822
- x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
823
-
824
- logits = self.lm_head(x).float()
825
-
826
- return CausalLMOutputRecurrentLatents(
827
- loss=torch.as_tensor(0.0),
828
- log_ppl=torch.as_tensor(0.0),
829
- logits=logits,
830
- past_key_values=past_key_values,
831
- latent_states=x,
832
- )
833
-
834
- def embed_inputs(
835
- self,
836
- input_ids: torch.Tensor,
837
- attention_mask: Optional[torch.Tensor] = None,
838
- position_ids: Optional[torch.Tensor] = None,
839
- past_key_values: Optional[ValidCache] = None,
840
- use_cache: bool = False,
841
- cache_position: Optional[torch.Tensor] = None,
842
- **kwargs,
843
- ) -> tuple[torch.Tensor, torch.Tensor]:
844
- # Support multiple position formats:
845
- if position_ids is None and cache_position is None:
846
- freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
847
- elif position_ids is not None:
848
- freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
849
- elif cache_position is not None:
850
- freqs_cis = self.freqs_cis[:, cache_position]
851
-
852
- input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
853
- prepared_attn_mask = self.compile_mask(input_ids, attention_mask)
854
-
855
- if self.emb_scale != 1:
856
- input_embeds = input_embeds * self.emb_scale # type: ignore
857
-
858
- if use_cache and past_key_values is None:
859
- past_key_values = HuginnDynamicCache()
860
-
861
- block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
862
- # Non-recurrent prelude
863
- for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
864
- block_idx += 1
865
- input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
866
- return input_embeds, block_idx
867
-
868
- @torch._dynamo.disable(recursive=False) # type: ignore
869
- def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
870
- """Outputs are long tensors so that they can be passed through compiled functions"""
871
- t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
872
- s = self.config.mean_backprop_depth
873
- if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
874
- # these values are only the mean TFLOPs of the randomized sampler
875
- # Note that this clause also breaks the contract, and returns ints in meta tensor mode
876
- return t, s # type: ignore
877
- if self.training:
878
- sigma = 0.5
879
- mu = math.log(t + s) - (sigma**2 / 2)
880
- rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
881
- p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
882
- n = torch.clamp(p - s, min=0)
883
- k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
884
- else:
885
- n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
886
-
887
- return n.to(dtype=torch.long), k.to(dtype=torch.long)
888
-
889
- def initialize_state(self, input_embeds, scale: float = 1.0):
890
- x = torch.randn_like(input_embeds)
891
- std = self.config.init_values["std"] * scale
892
- if std > 0:
893
- torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
894
- if self.emb_scale != 1:
895
- x = x * self.emb_scale
896
- else:
897
- x.zero_()
898
- return x
899
-
900
- def _maybe_inject_noise(self, x, current_step, renorm=False):
901
- if self.config.test_time_noise > 0:
902
- n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
903
- if self.config.test_time_noise_type == "geom":
904
- step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
905
- x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
906
- elif self.config.test_time_noise_type == "sqrt":
907
- step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
908
- x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
909
- elif self.config.test_time_noise_type == "line":
910
- noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
911
- x = x * (1 - noise) + torch.randn_like(x) * noise
912
- elif self.config.test_time_noise_type == "chi":
913
- noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
914
- x = x * (1 - noise) + torch.randn_like(x) * noise
915
- elif self.config.test_time_noise_type == "fixed":
916
- x = x * (1 - n) + torch.randn_like(x) * n
917
- else:
918
- raise ValueError()
919
-
920
- if renorm:
921
- x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
922
- return x
923
-
924
- def prepare_inputs_for_generation(
925
- self,
926
- input_ids: torch.Tensor,
927
- past_key_values: Optional[Cache] = None,
928
- attention_mask: Optional[torch.Tensor] = None,
929
- inputs_embeds: Optional[torch.FloatTensor] = None,
930
- cache_position: Optional[torch.Tensor] = None,
931
- cache_lookup_strategy: str = "full",
932
  **kwargs,
933
  ):
934
- model_inputs = {}
935
- model_inputs["cache_position"] = cache_position
936
- current_input_length = input_ids.shape[1]
937
-
938
- if past_key_values is not None:
939
- if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
940
- assert past_key_values.get_seq_length() == 0 # only replace empty caches
941
- # Need to use custom cache, detect and replace HF cache if generate injects it
942
- if isinstance(past_key_values, StaticCache):
943
- past_key_values = HuginnStaticCache(
944
- max_length=getattr(self.generation_config, "max_length", self.config.block_size),
945
- max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
946
- num_heads=self.config.num_key_value_heads,
947
- hidden_dim=self.config.n_embd // self.config.num_attention_heads,
948
- dtype=torch.bfloat16,
949
- device=input_ids.device,
950
- lookup_strategy=cache_lookup_strategy,
951
- )
952
- else:
953
- past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
954
- model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
955
- input_ids = input_ids[:, cache_position] # type: ignore
956
-
957
- model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
958
- if cache_position is None:
959
- position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
960
- model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
961
- memory_format=torch.contiguous_format
962
- ) # some form of position_ids is a critical argument for the model to correctly apply rope!
963
-
964
- # forward all other entries
965
- for key, value in kwargs.items():
966
- if key not in model_inputs:
967
- model_inputs[key] = value
968
- return model_inputs
969
-
970
- @torch.no_grad()
971
- def generate(self, *args, **kwargs):
972
- """Dispatcher - use HF generate in all normal cases."""
973
- self.generation_config = args[1] if len(args) > 1 else self.generation_config
974
- if any(k in kwargs for k in ("criterion", "exit_threshold")):
975
- # print("Dispatching to custom generate_adaptive function call")
976
- return self.generate_with_adaptive_compute(*args, **kwargs)
977
- elif "continuous_compute" in kwargs:
978
- # print("Dispatching to custom generate_minimal function call")
979
- return self.generate_minimal(*args, **kwargs)
980
- else:
981
- return super().generate(*args, **kwargs)
982
-
983
- @torch.no_grad()
984
- def _prep_generate_args(
985
- self,
986
- input_ids: torch.Tensor,
987
- generation_config: Optional[GenerationConfig] = None, # type: ignore
988
- cache_lookup_strategy: str = "full",
989
- model_kwargs: dict = {},
990
- ):
991
- # Setup
992
- if generation_config is None:
993
- generation_config: GenerationConfig = self.generation_config # type: ignore
994
- if "max_new_tokens" in model_kwargs:
995
- max_new_tokens = model_kwargs["max_new_tokens"]
996
- if "max_length" in model_kwargs:
997
- max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
998
- else:
999
- max_length = model_kwargs.get("max_length", generation_config.max_length)
1000
- max_new_tokens = max_length - input_ids.shape[1]
1001
-
1002
- if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
1003
- model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
1004
- else:
1005
- model_kwargs["past_key_values"] = HuginnStaticCache(
1006
- max_length=max_length,
1007
- max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
1008
- num_heads=self.config.num_key_value_heads,
1009
- hidden_dim=self.config.n_embd // self.config.num_attention_heads,
1010
- batch_size=input_ids.shape[0],
1011
- dtype=torch.bfloat16,
1012
- device=input_ids.device,
1013
- lookup_strategy=cache_lookup_strategy,
1014
- )
1015
- model_kwargs["use_cache"] = True
1016
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1017
- return model_kwargs, generation_config, max_new_tokens
1018
-
1019
- @torch.no_grad()
1020
- def generate_minimal(
1021
- self,
1022
- input_ids: torch.Tensor,
1023
- generation_config: Optional[GenerationConfig] = None, # type: ignore
1024
- tokenizer=None,
1025
- streamer=None,
1026
- continuous_compute=False, # warm-start state / continuous CoT
1027
- init_scale: float = 1.0,
1028
- cache_lookup_strategy: str = "full",
1029
- **model_kwargs,
1030
- ) -> Union[torch.Tensor, dict[str, Any]]:
1031
- """Minimal single-sequence generation. Template for more complicated generate tasks"""
1032
- model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1033
- input_ids, generation_config, cache_lookup_strategy
1034
- )
1035
- stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1036
- unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1037
-
1038
- # Set up continuous compute if enabled
1039
- if continuous_compute:
1040
- embedded_inputs, _ = self.embed_inputs(input_ids)
1041
- model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1042
-
1043
- # Generate tokens
1044
- batch_size = input_ids.shape[0]
1045
- for _ in range(max_new_tokens):
1046
- # Forward pass
1047
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1048
- outputs = self(**model_inputs, init_scale=init_scale)
1049
-
1050
- # Get next token
1051
- next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
1052
- next_token = self._sample_next_token(next_token_logits, generation_config)
1053
-
1054
- # Append token to sequence
1055
- input_ids = torch.cat([input_ids, next_token], dim=-1)
1056
-
1057
- if streamer:
1058
- streamer.put(next_token.cpu())
1059
-
1060
- # Update model kwargs
1061
- model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1062
- if continuous_compute:
1063
- model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1064
-
1065
- if stop_tokens is not None:
1066
- for i in range(batch_size):
1067
- if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
1068
- unfinished_sequences[i] = 0
1069
- if "stopping_criteria" in model_kwargs:
1070
- unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1071
- if unfinished_sequences.max() == 0:
1072
- break
1073
-
1074
- if streamer:
1075
- streamer.end()
1076
-
1077
- if generation_config.return_dict_in_generate:
1078
- return GenerateDecoderOnlyOutput(
1079
- sequences=input_ids, # type: ignore
1080
- scores=None,
1081
- logits=None,
1082
- attentions=None,
1083
- hidden_states=None,
1084
- past_key_values=model_kwargs.get("past_key_values"),
1085
- )
1086
- return input_ids
1087
-
1088
- @torch.no_grad()
1089
- def generate_with_adaptive_compute(
1090
- self,
1091
- input_ids: torch.Tensor,
1092
- generation_config: Optional[GenerationConfig] = None, # type: ignore
1093
- tokenizer=None,
1094
- streamer=None,
1095
- continuous_compute=False, # warm-start state / continuous CoT
1096
- criterion="none", # off by default, turn on by choosing an exit criterion
1097
- exit_threshold: Union[str, float, int] = "auto",
1098
- init_scale: float = 1.0,
1099
- cache_lookup_strategy: str = "full",
1100
- **model_kwargs,
1101
- ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1102
- """
1103
- Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1104
- For batches, on each token, we iterate until the entire batch finishes.
1105
- """
1106
- model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1107
- input_ids, generation_config, cache_lookup_strategy, model_kwargs
1108
- )
1109
- max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
1110
- stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1111
- logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
1112
- batch_size = input_ids.shape[0]
1113
- compute_steps = []
1114
-
1115
- # Set up continuous compute if enabled
1116
- if continuous_compute:
1117
- embedded_inputs, _ = self.embed_inputs(input_ids)
1118
- model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1119
-
1120
- # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1121
- unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1122
-
1123
- # Generate tokens
1124
- for _ in range(max_new_tokens):
1125
- # Adaptive compute forward
1126
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1127
- aux_inputs = {
1128
- k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
1129
- }
1130
- embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
1131
- current_latents = (
1132
- self.initialize_state(embedded_inputs, scale=init_scale)
1133
- if not continuous_compute
1134
- else model_kwargs["input_states"]
1135
- )
1136
-
1137
- # Initialize criterion tracking for each sequence in batch
1138
- exit_values_per_seq = [[] for _ in range(batch_size)]
1139
- compute_steps_per_seq = [0] * batch_size
1140
- exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
1141
-
1142
- # Set up criterions based on selected strategy
1143
- if criterion == "entropy-diff":
1144
- entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
1145
- exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1146
- elif criterion == "latent-diff":
1147
- exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1148
- elif "kl" in criterion:
1149
- V = self.config.padded_vocab_size
1150
- log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
1151
- if criterion == "minp-kl":
1152
- exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
1153
- else:
1154
- exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
1155
- elif criterion == "argmax-stability":
1156
- stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
1157
- current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
1158
- exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1159
- elif criterion == "none":
1160
- exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
1161
- else:
1162
- raise ValueError("Invalid adaptive compute strategy.")
1163
-
1164
- next_token_logits = None
1165
-
1166
- # Iterate through compute steps
1167
- for compute_step in range(max_steps):
1168
- prev_latents = current_latents.clone()
1169
- current_latents, block_idx, _ = self.iterate_one_step(
1170
- embedded_inputs,
1171
- current_latents,
1172
- block_idx=block_idx,
1173
- **aux_inputs,
1174
- current_step=compute_step,
1175
- )
1176
-
1177
- if _ > 0: # do not exit in prefill
1178
- # Check exit condition for each sequence in batch
1179
- if criterion == "entropy-diff":
1180
- prev_entropy = entropy
1181
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1182
- logits: torch.Tensor = outputs.logits # type: ignore
1183
- probs = F.softmax(logits[:, -1, :], dim=-1)
1184
- entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1185
- exit_values = (entropy - prev_entropy).abs()
1186
- elif criterion == "latent-diff":
1187
- norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
1188
- exit_values = norm_diff.mean(dim=-1)
1189
- elif "kl" in criterion:
1190
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1191
- logits: torch.Tensor = outputs.logits # type: ignore
1192
- prev_log_probs = log_probs
1193
- if criterion == "minp-kl":
1194
- probs = F.softmax(logits[:, -1, :].float(), dim=-1)
1195
- max_probs = probs.max(dim=-1, keepdim=True)[0]
1196
- probs_mask = probs < (0.1 * max_probs)
1197
- masked_probs = probs.clone()
1198
- masked_probs[probs_mask] = 1 / V
1199
- probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1200
- log_probs = probs.log()
1201
- else:
1202
- log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
1203
- exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1204
- elif criterion == "argmax-stability":
1205
- prev_argmax = current_argmax
1206
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1207
- logits: torch.Tensor = outputs.logits # type: ignore
1208
- current_argmax = logits[:, -1, :].argmax(dim=-1)
1209
- stable_for_n_steps = torch.where(
1210
- current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
1211
- )
1212
- exit_values = stable_for_n_steps
1213
- elif criterion == "none":
1214
- exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
1215
-
1216
- # Record values and check exits for each sequence
1217
- for i in range(batch_size):
1218
- if not exit_reached[i] and unfinished_sequences[i].bool():
1219
- exit_values_per_seq[i].append(exit_values[i].item())
1220
-
1221
- # Check for new exits, respecting unfinished_sequences
1222
- new_exits = (
1223
- exit_values < exit_threshold
1224
- if criterion != "argmax-stability"
1225
- else exit_values >= exit_threshold
1226
- )
1227
- new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1228
-
1229
- if new_exits.any():
1230
- exit_reached = exit_reached | new_exits
1231
- if criterion == "latent-diff":
1232
- # Normally we don't compute the output for latent-diff, but when there is an exit,
1233
- # we need to compute and save the output
1234
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1235
- logits: torch.Tensor = outputs.logits # type: ignore
1236
- if next_token_logits is None:
1237
- next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1238
- else:
1239
- for i in range(batch_size):
1240
- if new_exits[i]:
1241
- next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
1242
- for i in range(batch_size):
1243
- if new_exits[i]:
1244
- compute_steps_per_seq[i] = compute_step + 1
1245
-
1246
- # If all sequences have exited or finished, break early
1247
- if (exit_reached | ~unfinished_sequences.bool()).all():
1248
- break
1249
- # This else is if the for loop finished without breaking
1250
- else:
1251
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1252
-
1253
- # For sequences that didn't exit early, use the final logits
1254
- if next_token_logits is None:
1255
- next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1256
- else:
1257
- for i in range(batch_size):
1258
- if not exit_reached[i] and unfinished_sequences[i].bool():
1259
- next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1260
- compute_steps_per_seq[i] = max_steps
1261
-
1262
- # Save latent states for continuous compute if enabled
1263
- if continuous_compute:
1264
- model_kwargs["input_states"] = current_latents[:, -1:, :]
1265
-
1266
- # Record compute steps for this token generation
1267
- compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
1268
-
1269
- # Sample or select next token based on generation config
1270
- next_token = self._sample_next_token(next_token_logits, generation_config)
1271
-
1272
- # Append token to sequence
1273
- input_ids = torch.cat([input_ids, next_token], dim=-1)
1274
-
1275
- if streamer:
1276
- streamer.put(next_token.cpu())
1277
-
1278
- # Update model kwargs for next iteration
1279
- model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1280
-
1281
- # Check for stop tokens and update unfinished sequences
1282
- for i in range(batch_size):
1283
- if (
1284
- unfinished_sequences[i].bool()
1285
- and stop_tokens is not None
1286
- and next_token[i, 0].item() in stop_tokens
1287
- ):
1288
- unfinished_sequences[i] = 0
1289
-
1290
- # Apply any custom stopping criteria
1291
- if "stopping_criteria" in model_kwargs:
1292
- unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1293
-
1294
- # Break if all sequences are finished
1295
- if unfinished_sequences.max() == 0:
1296
- break
1297
-
1298
- if streamer:
1299
- streamer.end()
1300
-
1301
- if generation_config.return_dict_in_generate:
1302
- return GenerateDecoderOnlyOutput(
1303
- sequences=input_ids, # type: ignore
1304
- scores=compute_steps, # type: ignore
1305
- logits=None,
1306
- attentions=None,
1307
- hidden_states=None,
1308
- past_key_values=model_kwargs.get("past_key_values"),
1309
- )
1310
- return input_ids
1311
-
1312
- def _get_stops(self, generation_config, tokenizer, model_kwargs):
1313
- stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1314
- if generation_config.eos_token_id is not None:
1315
- stop_tokens.add(generation_config.eos_token_id)
1316
- if "stopping_criteria" in model_kwargs and tokenizer is None:
1317
- tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1318
- if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1319
- for s in generation_config.stop_strings:
1320
- token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1321
- stop_tokens.add(token_id)
1322
- return torch.tensor(list(stop_tokens))
1323
-
1324
- def _sample_next_token(self, next_token_logits, generation_config):
1325
- """Helper function to sample the next token."""
1326
- if generation_config.do_sample:
1327
- if generation_config.temperature:
1328
- next_token_logits = next_token_logits.float() / generation_config.temperature
1329
-
1330
- probs = F.softmax(next_token_logits, dim=-1)
1331
-
1332
- # Apply top_k
1333
- if generation_config.top_k:
1334
- top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1335
- min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1336
- probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1337
-
1338
- # Apply top_p (nucleus sampling)
1339
- if generation_config.top_p:
1340
- sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1341
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1342
-
1343
- # Create mask for probs to keep
1344
- remove_indices = cumulative_probs > generation_config.top_p
1345
- remove_indices[:, 0] = False # Keep at least the top probability
1346
-
1347
- # Convert sorted indices mask back to original indices mask
1348
- mask = torch.zeros_like(probs, dtype=torch.bool)
1349
- for i in range(probs.shape[0]):
1350
- mask[i, sorted_indices[i, remove_indices[i]]] = True
1351
-
1352
- probs = torch.where(mask, torch.zeros_like(probs), probs)
1353
-
1354
- # Apply min_p
1355
- if generation_config.min_p:
1356
- max_probs = probs.max(dim=-1, keepdim=True)[0]
1357
- min_p_threshold = generation_config.min_p * max_probs
1358
- probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1359
-
1360
- # Renormalize probabilities
1361
- probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1362
-
1363
- # Sample from the distribution
1364
- return torch.multinomial(probs, num_samples=1)
1365
- else:
1366
- return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1367
-
1368
- @torch.no_grad()
1369
- def generate_speculative(
1370
- self,
1371
- input_ids: torch.Tensor,
1372
- generation_config: Optional[GenerationConfig] = None, # type: ignore
1373
- tokenizer=None,
1374
- streamer=None,
1375
- continuous_compute=False, # warm-start state / continuous CoT
1376
- init_scale: float = 1.0,
1377
- cache_lookup_strategy: str = "full",
1378
- draft_steps=32,
1379
- lookahead_for_draft=8,
1380
- verification_threshold=1,
1381
- num_steps: int = 32, # intercept deliberately
1382
- **model_kwargs,
1383
- ) -> Union[torch.Tensor, dict[str, Any]]:
1384
- """Batched speculative decoding with per-sequence acceptance."""
1385
- assert lookahead_for_draft > 0
1386
- pad_id = 65509
1387
- model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1388
- input_ids, generation_config, cache_lookup_strategy, model_kwargs
1389
  )
1390
- stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1391
- unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1392
-
1393
- # Set up continuous compute if enabled
1394
- if continuous_compute:
1395
- embedded_inputs, _ = self.embed_inputs(input_ids)
1396
- model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1397
-
1398
- tokens_generated = 0
1399
- # Prefill cache with full num_steps
1400
- if model_kwargs["past_key_values"].get_seq_length() == 0:
1401
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1402
- outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1403
- next_token = self._sample_next_token(
1404
- outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
1405
- )
1406
- input_ids = torch.cat([input_ids, next_token], dim=-1)
1407
- tokens_generated += 1
1408
- if streamer:
1409
- streamer.put(next_token.cpu())
1410
- model_kwargs["cache_position"] = torch.as_tensor(
1411
- [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1412
- )
1413
- if continuous_compute:
1414
- model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1415
-
1416
- # Generate tokens
1417
- batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
1418
- accepted_tokens = []
1419
-
1420
- while tokens_generated < max_new_tokens:
1421
- ### Run the next draft ####
1422
- drafted_inputs = input_ids.clone()
1423
- current_len = input_ids.shape[1]
1424
-
1425
- for _ in range(lookahead_for_draft):
1426
- model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1427
- outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
1428
- next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
1429
- next_token = self._sample_next_token(next_token_logits, generation_config)
1430
- drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
1431
- model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
1432
- if continuous_compute:
1433
- model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1434
-
1435
- model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)
1436
-
1437
- ## Verify drafted tokens ###
1438
- model_kwargs["cache_position"] = torch.arange(
1439
- current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
1440
- )
1441
- model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1442
- outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1443
- verified_next_token_preds = outputs.logits.argmax(dim=-1)
1444
-
1445
- if verification_threshold >= 1:
1446
- mismatched_tokens = (
1447
- verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
1448
- )
1449
- not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
1450
- else:
1451
- verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
1452
- verified_probs = F.softmax(verified_logits, dim=-1)
1453
- drafted_token_probs = torch.gather(
1454
- verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
1455
- ).squeeze(-1)
1456
- max_probs = verified_probs.max(dim=-1)[0]
1457
- verification_passed = drafted_token_probs >= verification_threshold * max_probs
1458
- not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)
1459
-
1460
- # Per-sequence acceptance handling
1461
- acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)
1462
-
1463
- # Build next_tokens for each sequence
1464
- next_tokens_batch = []
1465
- for i in range(batch_size):
1466
- seq_acceptance = acceptance_lengths[i].item()
1467
- if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
1468
- # Accept up to mismatch + sample final token
1469
- accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1470
- final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
1471
- final_token = self._sample_next_token(final_token_logits, generation_config)
1472
- seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
1473
- else:
1474
- # Accept all drafted tokens
1475
- seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1476
- next_tokens_batch.append(seq_tokens)
1477
-
1478
- # Clean up KV cache - only if any sequence had mismatches
1479
- if not_all_matched.any():
1480
- min_first_mismatch = first_mismatch.min().item()
1481
- model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)
1482
-
1483
- # Concatenate accepted tokens to input_ids
1484
- batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
1485
- max_len = max(batch_accepted_counts)
1486
- padded_tokens = [
1487
- torch.cat(
1488
- [
1489
- tokens,
1490
- pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
1491
- ],
1492
- dim=-1,
1493
- )
1494
- if tokens.shape[1] < max_len
1495
- else tokens
1496
- for tokens in next_tokens_batch
1497
- ]
1498
- next_tokens = torch.cat(padded_tokens, dim=0)
1499
- input_ids = torch.cat([input_ids, next_tokens], dim=-1)
1500
-
1501
- accepted_tokens.append(batch_accepted_counts)
1502
- tokens_generated += max(batch_accepted_counts)
1503
-
1504
- if streamer:
1505
- streamer.put(next_tokens_batch[0].cpu())
1506
-
1507
- model_kwargs["cache_position"] = torch.as_tensor(
1508
- [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1509
- )
1510
- if continuous_compute:
1511
- model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1512
-
1513
- # Check stopping conditions
1514
- if stop_tokens is not None:
1515
- for i in range(batch_size):
1516
- if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
1517
- unfinished_sequences[i] = 0
1518
- if "stopping_criteria" in model_kwargs:
1519
- unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1520
- if unfinished_sequences.max() == 0:
1521
- break
1522
-
1523
- if streamer:
1524
- streamer.end()
1525
-
1526
- # Cut off extraneous parts of the sequence per batch element
1527
- if stop_tokens is not None:
1528
- for i in range(batch_size):
1529
- stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
1530
- if len(stop_positions) > 0:
1531
- input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
1532
- # Trim tensor to remove columns that are pad_id across all sequences
1533
- non_pad_mask = input_ids != pad_id
1534
- last_real_token = non_pad_mask.any(dim=0).nonzero()
1535
- if len(last_real_token) > 0:
1536
- input_ids = input_ids[:, : last_real_token[-1].item() + 1]
1537
-
1538
- if generation_config.return_dict_in_generate:
1539
- return GenerateDecoderOnlyOutput(
1540
- sequences=input_ids, # type: ignore
1541
- scores=accepted_tokens, # type: ignore
1542
- logits=None,
1543
- attentions=None,
1544
- hidden_states=None,
1545
- past_key_values=model_kwargs.get("past_key_values"),
1546
- )
1547
- return input_ids
1548
-
1549
- def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1550
- probs = torch.softmax(logits.float(), dim=-1)
1551
- prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1552
- residual_diff = (x - latent_states).norm(dim=-1)
1553
- rel_residual = residual_diff / latent_states.norm(dim=-1)
1554
- stats = {
1555
- "entropy": prob_entropy,
1556
- "residual_diff": residual_diff,
1557
- "rel_residual": rel_residual,
1558
- "num_steps_no_grad": num_steps_no_grad,
1559
- "num_steps_with_grad": num_steps_with_grad,
1560
  }
1561
- return stats
1562
-
1563
-
1564
- #################################### Utils #######################################################################
1565
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
1566
- with torch.autocast("cuda", enabled=False):
1567
- inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
1568
- t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
1569
- freqs = torch.outer(t, inv_freqs).float()
1570
- return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
1571
- # equivalent to
1572
- # freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
1573
- # cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
1574
-
1575
 
1576
- def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
1577
- with torch.autocast("cuda", enabled=False):
1578
- qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() # cast to float32 for smooth skin
1579
- rotated_qk_r2 = torch.stack(
1580
- [
1581
- qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
1582
- qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
1583
- ],
1584
- -1,
1585
- ).flatten(3)
1586
- rotated_qk = rotated_qk_r2
1587
- return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
1588
-
1589
-
1590
- #################################### HF registration ############################################################
1591
-
1592
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
1593
-
1594
- # New
1595
- RavenConfig.register_for_auto_class()
1596
-
1597
- RavenForCausalLM.register_for_auto_class("AutoModel")
1598
- RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
1599
-
1600
- # Old?
1601
- AutoConfig.register("huginn_raven", RavenConfig)
1602
- AutoModel.register(RavenConfig, RavenForCausalLM)
1603
- AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
 
1
+ """A HuggingFace-style model configuration."""
2
 
3
+ from transformers import PretrainedConfig
4
+ from math import sqrt
5
 
 
 
 
 
 
6
 
7
+ class RavenConfig(PretrainedConfig):
8
+ model_type = "huginn_raven"
9
+ keys_to_ignore_at_inference = [""]
10
+ attribute_map = {"num_attention_heads": "n_heads", "hidden_size": "n_embd", "num_hidden_layers": "n_layers"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def __init__(
13
  self,
14
+ n_embd: int = 5280,
15
+ n_heads: int = 55,
16
+ n_layers: int = 8, # total of prelude + recurrent + coda
17
+ block_size: int = 4096,
18
+ vocab_size: int = 65536,
19
+ padding_multiple: int = 4096,
20
+ tie_embeddings: bool = True,
21
+ intermediate_size: int = 17920,
22
+ bias: bool = False,
23
+ architecture_class_name: str = "RecurrentGPT",
24
+ block_class_name: str = "SandwichBlock",
25
+ norm_class_name: str = "RMSNorm_llama",
26
+ norm_eps: float = 0.000001,
27
+ mlp_class_name: str = "GatedMLP",
28
+ nonlin_name: str = "SiLU",
29
+ init_strategy: str = "takase",
30
+ init_orthogonal: bool = False,
31
+ state_init: str = "like-init",
32
+ injection_type: str = "linear",
33
+ n_layers_in_recurrent_block: int = 4,
34
+ mean_recurrence: int = 32,
35
+ sampling_scheme: str = "poisson-lognormal-filling",
36
+ mean_backprop_depth: int = 8,
37
+ n_layers_in_prelude: int = 2,
38
+ n_layers_in_coda: int = 2,
39
+ qk_bias: bool = True,
40
+ activation_checkpoint_impl: str = "per-iteration",
41
+ rope_base: float = 50_000,
42
+ torch_dtype: str = "bfloat16",
43
+ transformers_version: str = "4.47.1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  **kwargs,
45
  ):
46
+ self.n_embd = n_embd
47
+ self.n_heads = n_heads
48
+ self.n_layers = n_layers
49
+ self.block_size = block_size
50
+ self.vocab_size = self.padded_vocab_size = vocab_size
51
+ self.padding_multiple = padding_multiple
52
+ self.tie_embeddings = tie_embeddings
53
+ self.intermediate_size = intermediate_size
54
+ self.bias = bias
55
+ self.architecture_class_name = architecture_class_name
56
+ self.block_class_name = block_class_name
57
+ self.norm_class_name = norm_class_name
58
+ self.norm_eps = norm_eps
59
+ self.mlp_class_name = mlp_class_name
60
+ self.nonlin_name = nonlin_name
61
+ self.init_strategy = init_strategy
62
+ self.init_orthogonal = init_orthogonal
63
+ self.state_init = state_init
64
+ self.injection_type = injection_type
65
+ self.n_layers_in_recurrent_block = n_layers_in_recurrent_block
66
+ self.mean_recurrence = mean_recurrence
67
+ self.sampling_scheme = sampling_scheme
68
+ self.mean_backprop_depth = mean_backprop_depth
69
+ self.n_layers_in_prelude = n_layers_in_prelude
70
+ self.n_layers_in_coda = n_layers_in_coda
71
+ self.qk_bias = qk_bias
72
+ self.activation_checkpoint_impl = activation_checkpoint_impl
73
+ self.rope_base = rope_base
74
+ self.torch_dtype = torch_dtype # Added from JSON
75
+ self.transformers_version = transformers_version # Added from JSON
76
+ # inference
77
+ self.test_time_noise = 0
78
+ self.test_time_noise_type = "fixed"
79
+ # Derived
80
+ self.num_key_value_heads = n_heads
81
+ self.num_attention_heads = n_heads
82
+ self.head_dim = n_embd // n_heads
83
+ self.effective_expected_depth = (
84
+ self.n_layers_in_prelude + self.n_layers_in_coda + self.n_layers_in_recurrent_block * self.mean_recurrence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  )
86
+ self.init_values = {
87
+ "std": sqrt(2 / (5 * self.n_embd)),
88
+ "out_proj": sqrt(2 / (5 * self.n_embd)) / sqrt(2 * self.effective_expected_depth),
89
+ "embedding": sqrt(2 / (5 * self.n_embd)),
90
+ "embed_scale": sqrt(self.n_embd),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ super().__init__(
94
+ # pad_token_id=65509,
95
+ # bos_token_id=65504,
96
+ # eos_token_id=65505,
97
+ tie_word_embeddings=tie_embeddings,
98
+ **kwargs,
99
+ )