JonasGeiping committed
Commit 89eb92b · verified · 1 Parent(s): 2a364bd

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +813 -310
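For orientation, a minimal usage sketch of the forward interface defined in this file follows; it is not part of the commit. The checkpoint id, prompt, and argument values are assumptions for illustration, while the parameter names (num_steps, use_cache, output_details) follow the signature shown in the diff.

# Illustrative sketch only, not part of this commit; checkpoint id and values are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tomg-group-umd/huginn-0125"  # assumed checkpoint that ships this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
).eval()

input_ids = tokenizer("To compute 12 * 7, we", return_tensors="pt").input_ids

with torch.no_grad():
    outputs = model(
        input_ids,
        num_steps=32,    # recurrence depth of the core block
        use_cache=True,  # lets forward() allocate its recurrent KV cache
        output_details={"return_logits": True, "return_latents": True,
                        "return_head": False, "return_stats": False},
    )

print(tokenizer.decode(int(outputs.logits[0, -1].argmax())))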
raven_modeling_minimal.py CHANGED
@@ -1,14 +1,17 @@
1
- """Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Best used for inference, finetuning should work, but is untested with this implementation."""
2
 
3
  import torch
4
  import math
5
 
6
  from torch import Tensor
7
  from dataclasses import dataclass
8
- from typing import Optional, Union, Any
 
9
 
10
  from .raven_config_minimal import RavenConfig
11
- from transformers.cache_utils import Cache, DynamicCache
12
 
13
  ###################### Huggingface Glue code I ##################################################################
14
  from transformers import PreTrainedModel, GenerationMixin
@@ -18,6 +21,8 @@ from transformers.generation.utils import GenerateDecoderOnlyOutput
18
  import torch.nn.functional as F
19
  from transformers import GenerationConfig
20
 
21
 
22
  class RavenPreTrainedModel(PreTrainedModel):
23
  config_class = RavenConfig
@@ -30,7 +35,8 @@ class RavenPreTrainedModel(PreTrainedModel):
30
  _supports_sdpa = True
31
  _supports_cache_class = True
32
  _supports_quantized_cache = False
33
- _supports_static_cache = False
 
34
 
35
  def _init_weights(self, module):
36
  if not torch.rand((1,)).is_meta:
@@ -87,17 +93,24 @@ class HuginnDynamicCache(DynamicCache):
87
  self,
88
  key_states: torch.Tensor,
89
  value_states: torch.Tensor,
90
- step_idx: int,
91
  lookup_strategy: Optional[str] = None,
92
  ) -> tuple[torch.Tensor, torch.Tensor]:
 
93
  lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
94
  if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
95
- compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
96
  if "compress-s" in self.lookup_strategy:
 
97
  new_step_idx = (step_idx - 2) % compression_stage + 2
98
- else:
 
99
  new_step_idx = (step_idx - 2) // compression_stage + 2
100
- # @ print(step_idx, new_step_idx, compression_stage)
101
  step_idx = new_step_idx
102
  # Init
103
  if step_idx not in self.key_cache:
@@ -110,7 +123,6 @@ class HuginnDynamicCache(DynamicCache):
110
  for idx, entry in enumerate(key_states.unbind(dim=-2)):
111
  if "compress-" not in self.lookup_strategy:
112
  assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
113
- # print(f"Overwrote cache entry for step_idx {step_idx}") # likely the head
114
  self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
115
  for idx, entry in enumerate(value_states.unbind(dim=-2)):
116
  self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
@@ -122,31 +134,45 @@ class HuginnDynamicCache(DynamicCache):
122
  torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
123
  torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
124
  )
125
- else: # some entries where not previously computed
126
- # if lookup_strategy.startswith("latest"):
127
- # latest_keys = []
128
- # latest_values = []
129
- # for token_pos in range(self._seen_tokens):
130
- # # Find the latest step that has this token position
131
- # max_step = max((s for s in range(step_idx + 1) if token_pos in self.key_cache[s]), default=None)
132
- # if max_step is None:
133
- # raise ValueError(f"No cache entry found for token position {token_pos}")
134
- # latest_keys.append(self.key_cache[max_step][token_pos])
135
- # latest_values.append(self.value_cache[max_step][token_pos])
136
- # return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
137
  if lookup_strategy.startswith("latest-m4"):
138
  latest_keys = []
139
  latest_values = []
140
  for token_pos in range(self._seen_tokens):
141
- # For steps >= 2, use modulo 4
142
  if step_idx >= 2:
143
  # Find valid steps for this token position
144
  valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
145
  max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
146
  else:
147
  max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
148
- if max_step is None:
149
- raise ValueError(f"No cache entry found for token position {token_pos}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  latest_keys.append(self.key_cache[max_step][token_pos])
151
  latest_values.append(self.value_cache[max_step][token_pos])
152
  return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
@@ -184,6 +210,20 @@ class HuginnDynamicCache(DynamicCache):
184
  self.key_cache.clear()
185
  self.value_cache.clear()
186
 
 
187
  def get_seq_length(self, step_idx: int = 0) -> int:
188
  return self._seen_tokens
189
 
@@ -200,6 +240,134 @@ class HuginnDynamicCache(DynamicCache):
200
  return total_bytes * 2 / (1024 * 1024)
201
 
202
 
 
203
  class CausalSelfAttention(torch.nn.Module):
204
  def __init__(self, config: RavenConfig) -> None:
205
  super().__init__()
@@ -219,11 +387,10 @@ class CausalSelfAttention(torch.nn.Module):
219
  self,
220
  x: Tensor,
221
  freqs_cis: Tensor,
222
- step_idx: int,
223
- mask: Optional[Tensor] = None,
224
- past_key_values: Optional[Cache] = None,
225
- return_attn: bool = False,
226
- ) -> tuple[Tensor, Optional[Tensor]]:
227
  B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
228
  q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
229
  q = q.view(B, S, self.n_head, self.head_dim)
@@ -241,30 +408,21 @@ class CausalSelfAttention(torch.nn.Module):
241
  v = v.transpose(1, 2)
242
 
243
  if past_key_values is not None:
244
- k, v = past_key_values.update(k, v, step_idx)
245
 
246
- if return_attn:
247
- y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask)
248
  else:
249
- y = torch.nn.functional.scaled_dot_product_attention(
250
- q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=q.shape[2] > 1
251
- )
 
252
  y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
253
- return self.proj(y), attention_map if return_attn else None
254
-
255
- def compute_eager_sdpa(self, q, k, v, attn_mask):
256
- scale = 1.0 / math.sqrt(self.head_dim)
257
- scores = torch.matmul(q, k.transpose(-2, -1)) * scale
258
-
259
- if attn_mask is not None:
260
- scores = scores + attn_mask
261
- if q.shape[2] > 1:
262
- causal_mask = torch.triu(torch.ones(q.shape[2], q.shape[2]), diagonal=1).bool()
263
- scores.masked_fill_(causal_mask.to(scores.device), float("-inf"))
264
-
265
- attention_weights = torch.nn.functional.softmax(scores, dim=-1)
266
- y = torch.matmul(attention_weights, v)
267
- return y, attention_weights.max(dim=1)[0]
268
 
269
 
270
  class GatedMLP(torch.nn.Module):
@@ -301,17 +459,18 @@ class SandwichBlock(torch.nn.Module):
301
  x: Tensor,
302
  freqs_cis: Tensor,
303
  step_idx: int,
304
- mask: Optional[Tensor] = None,
305
- past_key_values: Optional[Cache] = None,
306
- return_attn: bool = False,
307
- ) -> tuple[Tensor, Optional[Tensor]]:
308
- attn_out, attn_map = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values, return_attn)
309
  x = self.norm_2(attn_out + x)
310
  x = self.norm_4(self.mlp(self.norm_3(x)) + x)
311
- return x, attn_map
312
 
313
 
314
  class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
315
  def __init__(
316
  self,
317
  config: RavenConfig,
@@ -360,25 +519,133 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
360
  )
361
  return freqs_cis
362
 
 
363
  def forward(
364
  self,
365
  input_ids: torch.Tensor,
366
  input_embeds: Optional[torch.Tensor] = None,
367
  input_states: Optional[torch.Tensor] = None,
368
- attention_mask: Optional[torch.Tensor] = None,
369
  position_ids: Optional[torch.Tensor] = None,
370
  labels: Optional[torch.Tensor] = None,
371
  num_steps: Optional[torch.Tensor] = None,
372
- past_key_values: Optional[Cache] = None,
373
  output_details: dict = {
374
  "return_logits": True,
375
  "return_latents": True,
376
- "return_attention": False,
377
  "return_head": False,
378
  "return_stats": False,
379
  },
380
  use_cache: bool = False,
381
  cache_position: Optional[torch.Tensor] = None,
 
382
  **kwargs,
383
  ) -> CausalLMOutputRecurrentLatents:
384
  # Support multiple position formats:
@@ -390,47 +657,47 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
390
  freqs_cis = self.freqs_cis[:, cache_position]
391
 
392
  if input_embeds is None:
393
- input_embeds = self.transformer.wte(input_ids)
394
 
395
  if self.emb_scale != 1:
396
  input_embeds = input_embeds * self.emb_scale # type: ignore
397
 
398
  if use_cache and past_key_values is None:
399
  past_key_values = HuginnDynamicCache()
400
- attn_maps = {}
401
- return_attn = output_details["return_attention"]
402
 
 
 
403
  # Non-recurrent prelude
404
- for block_idx, block in enumerate(self.transformer.prelude):
405
- input_embeds, attn_map = block(
406
- input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
407
- )
408
- attn_maps[block_idx] = attn_map
409
 
410
  # Main recurrence
411
- x, num_steps_no_grad, num_steps_with_grad, xk, block_idx, attn_maps = self.iterate_forward(
412
- input_embeds, # type: ignore
413
  input_states,
414
  freqs_cis,
415
  block_idx,
416
- attention_mask,
417
  past_key_values,
418
  num_steps,
419
- attn_maps,
420
- return_attn=return_attn,
421
  )
422
  latent_states = x.clone().detach()
423
 
424
  # Coda layers
425
- for block_idx, block in enumerate(self.transformer.coda, start=1):
426
- x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
427
- attn_maps[-block_idx] = attn_map
428
- x = self.transformer.ln_f(x)
 
429
 
430
  # Prediction head, assuming labels really are labels and not equal to input_ids
431
  if labels is not None:
432
  logits = self.lm_head(x).float()
433
- loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
434
  log_ppl = loss.clone().detach().exp()
435
  else:
436
  logits = self.lm_head(x).float()
@@ -443,7 +710,6 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
443
  past_key_values=past_key_values,
444
  hidden_states=x if output_details["return_head"] else None,
445
  latent_states=latent_states if output_details["return_latents"] else None,
446
- attention_maps=attn_maps if output_details["return_attention"] else None, # type: ignore
447
  stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
448
  if output_details["return_stats"]
449
  else None,
@@ -452,17 +718,16 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
452
  @torch._dynamo.disable(recursive=False) # type: ignore
453
  def iterate_forward(
454
  self,
455
- input_embeds,
456
- input_states,
457
  freqs_cis,
458
- block_idx,
459
- mask,
460
- past_key_values: Optional[Cache] = None,
461
  num_steps: Optional[torch.Tensor] = None,
462
- attn_maps: dict = {},
463
- return_attn: bool = False,
464
  ):
465
- x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
466
  if num_steps is None:
467
  num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
468
  elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
@@ -475,35 +740,35 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
475
  # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
476
  # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
477
  # and all parameters are always used
478
- for step in range(num_steps_no_grad):
479
  xk = x
480
- x, block_idx, attn_maps = self.core_block_forward(
481
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
482
  )
483
 
484
- for step in range(num_steps_with_grad):
485
  xk = x
486
- x, block_idx, attn_maps = self.core_block_forward(
487
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
488
  )
489
- return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
490
 
491
  def core_block_forward(
492
  self,
493
  x,
494
  input_embeds,
495
  freqs_cis,
496
- mask,
497
  past_key_values,
498
- block_idx: Union[torch.Tensor, int],
499
- attn_maps: dict = {},
500
- return_attn: bool = False,
501
  ):
502
- x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1))
503
- for idx, block in enumerate(self.transformer.core_block, start=1):
504
- x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
505
- attn_maps[block_idx + idx] = attn_map
506
- return x, block_idx + idx, attn_maps
 
507
 
508
  @torch.no_grad()
509
  def iterate_one_step(
@@ -512,10 +777,10 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
512
  input_states,
513
  position_ids: Optional[torch.Tensor] = None,
514
  cache_position: Optional[torch.Tensor] = None,
515
- block_idx: Union[torch.Tensor, int] = 0,
516
- attention_mask: Optional[Tensor] = None,
517
- past_key_values: Optional[Cache] = None,
518
- attn_maps: dict = {},
519
  ):
520
  if position_ids is None and cache_position is None:
521
  freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
@@ -523,20 +788,24 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
523
  freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
524
  elif cache_position is not None:
525
  freqs_cis = self.freqs_cis[:, cache_position]
526
- x, block_idx, attn_maps = self.core_block_forward(
527
- input_states, input_embeds, freqs_cis, attention_mask, past_key_values, block_idx, attn_maps
 
528
  )
529
- return x, block_idx, attn_maps
530
 
531
  def predict_from_latents(
532
  self,
533
  latents,
534
- attention_mask: Optional[torch.Tensor] = None,
535
  position_ids: Optional[torch.Tensor] = None,
536
  cache_position: Optional[torch.Tensor] = None,
537
- past_key_values: Optional[Cache] = None,
538
- return_attn: bool = False,
539
- attn_maps: dict = {},
540
  ):
541
  if position_ids is None and cache_position is None:
542
  freqs_cis = self.freqs_cis[:, : latents.shape[1]]
@@ -544,12 +813,13 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
544
  freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
545
  elif cache_position is not None:
546
  freqs_cis = self.freqs_cis[:, cache_position]
547
- x = self.transformer.ln_f(latents)
548
  # Coda layers
549
- for block_idx, block in enumerate(self.transformer.coda, start=1):
550
- x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values)
551
- attn_maps[block_idx] = attn_map
552
- x = self.transformer.ln_f(x)
 
553
 
554
  logits = self.lm_head(x).float()
555
 
@@ -558,7 +828,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
558
  log_ppl=torch.as_tensor(0.0),
559
  logits=logits,
560
  past_key_values=past_key_values,
561
- attention_maps=attn_maps if len(attn_maps) > 0 else None,
562
  )
563
 
564
  def embed_inputs(
@@ -566,12 +836,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
566
  input_ids: torch.Tensor,
567
  attention_mask: Optional[torch.Tensor] = None,
568
  position_ids: Optional[torch.Tensor] = None,
569
- past_key_values: Optional[Cache] = None,
570
  use_cache: bool = False,
571
  cache_position: Optional[torch.Tensor] = None,
572
- return_attn: bool = False,
573
  **kwargs,
574
- ) -> tuple[torch.Tensor, int, dict[int, Tensor]]:
575
  # Support multiple position formats:
576
  if position_ids is None and cache_position is None:
577
  freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
@@ -580,7 +849,8 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
580
  elif cache_position is not None:
581
  freqs_cis = self.freqs_cis[:, cache_position]
582
 
583
- input_embeds = self.transformer.wte(input_ids)
 
584
 
585
  if self.emb_scale != 1:
586
  input_embeds = input_embeds * self.emb_scale # type: ignore
@@ -588,13 +858,12 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
588
  if use_cache and past_key_values is None:
589
  past_key_values = HuginnDynamicCache()
590
 
 
591
  # Non-recurrent prelude
592
- attn_maps = {}
593
- for block_idx, block in enumerate(self.transformer.prelude):
594
- input_embeds, attn_maps = block(
595
- input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
596
- )
597
- return input_embeds, block_idx, attn_maps
598
 
599
  @torch._dynamo.disable(recursive=False) # type: ignore
600
  def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
@@ -617,35 +886,75 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
617
 
618
  return n.to(dtype=torch.long), k.to(dtype=torch.long)
619
 
620
- def initialize_state(self, input_embeds, deterministic: bool = False):
621
  x = torch.randn_like(input_embeds)
622
- std = self.config.init_values["std"]
623
- torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
624
- if self.emb_scale != 1:
625
- x = x * self.emb_scale
626
- return x if not deterministic else x.zero_()
 
 
 
627
 
628
  def prepare_inputs_for_generation(
629
  self,
630
- input_ids: torch.LongTensor,
631
  past_key_values: Optional[Cache] = None,
632
- attention_mask: Optional[torch.LongTensor] = None,
633
  inputs_embeds: Optional[torch.FloatTensor] = None,
634
- cache_position: Optional[torch.LongTensor] = None,
 
635
  **kwargs,
636
  ):
637
  model_inputs = {}
638
  model_inputs["cache_position"] = cache_position
639
  current_input_length = input_ids.shape[1]
 
640
  if past_key_values is not None:
641
- if type(past_key_values) != HuginnDynamicCache:
642
- # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
643
- assert past_key_values.get_seq_length() == 0
644
- past_key_values = HuginnDynamicCache()
 
 
 
645
  model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
646
  input_ids = input_ids[:, cache_position] # type: ignore
647
- model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
648
 
 
649
  if cache_position is None:
650
  position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
651
  model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
@@ -662,72 +971,88 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
662
  def generate(self, *args, **kwargs):
663
  """Dispatcher - use HF generate in all normal cases."""
664
  self.generation_config = args[1] if len(args) > 1 else self.generation_config
665
- if any(
666
- k in kwargs
667
- for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
668
- ):
669
- print("Dispatching to custom generate function call")
670
  return self.generate_with_adaptive_compute(*args, **kwargs)
671
  else:
672
  return super().generate(*args, **kwargs)
673
 
 
 
 
674
  @torch.no_grad()
675
  def generate_minimal(
676
  self,
677
- input_ids: torch.LongTensor,
678
  generation_config: Optional[GenerationConfig] = None, # type: ignore
679
  tokenizer=None,
680
  streamer=None,
681
  continuous_compute=False, # warm-start state / continuous CoT
682
- cache_kwargs: dict = {},
 
683
  **model_kwargs,
684
  ) -> Union[torch.Tensor, dict[str, Any]]:
685
  """Minimal single-sequence generation. Template for more complicated generate tasks"""
686
- # Setup
687
- if generation_config is None:
688
- generation_config: GenerationConfig = self.generation_config # type: ignore
689
- model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs)
690
- model_kwargs["use_cache"] = True
691
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
692
- stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
693
  if continuous_compute:
694
- embedded_inputs, _, _ = self.embed_inputs(input_ids)
695
- model_kwargs["input_states"] = self.initialize_state(embedded_inputs)
 
696
  # Generate tokens
697
- for _ in range(generation_config.max_length - input_ids.shape[1]):
 
698
  # Forward pass
699
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
700
- outputs = self(**model_inputs)
701
- next_token_logits = outputs.logits[0, -1, :]
702
- if continuous_compute:
703
- current_last_latent = outputs.latent_states[:, -1:, :]
704
-
705
- # Sample or select next token
706
- if generation_config.do_sample:
707
- if generation_config.temperature:
708
- next_token_logits = next_token_logits / generation_config.temperature
709
-
710
- probs = F.softmax(next_token_logits, dim=-1)
711
-
712
- # Apply top_k
713
- if generation_config.top_k:
714
- top_k_probs, _ = torch.topk(probs, generation_config.top_k)
715
- probs[probs < top_k_probs[-1]] = 0
716
- # Apply top_p
717
- if generation_config.top_p:
718
- sorted_probs = torch.sort(probs, descending=True)[0]
719
- cumsum = torch.cumsum(sorted_probs, dim=-1)
720
- probs[cumsum > generation_config.top_p] = 0
721
- # Apply min_p
722
- if generation_config.min_p:
723
- probs[probs < generation_config.min_p * probs.max()] = 0
724
-
725
- probs = probs / probs.sum()
726
- next_token = torch.multinomial(probs, num_samples=1)
727
- else:
728
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
729
 
730
- input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1) # type: ignore
 
 
731
 
732
  if streamer:
733
  streamer.put(next_token.cpu())
@@ -735,10 +1060,15 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
735
  # Update model kwargs
736
  model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
737
  if continuous_compute:
738
- model_kwargs["input_states"] = current_last_latent
739
-
740
- # Check if we hit a stop token
741
- if stop_tokens is not None and next_token in stop_tokens:
 
 
742
  break
743
 
744
  if streamer:
@@ -746,7 +1076,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
746
 
747
  if generation_config.return_dict_in_generate:
748
  return GenerateDecoderOnlyOutput(
749
- sequences=input_ids,
750
  scores=None,
751
  logits=None,
752
  attentions=None,
@@ -758,51 +1088,51 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
758
  @torch.no_grad()
759
  def generate_with_adaptive_compute(
760
  self,
761
- input_ids: torch.LongTensor,
762
  generation_config: Optional[GenerationConfig] = None, # type: ignore
763
  tokenizer=None,
764
  streamer=None,
765
  continuous_compute=False, # warm-start state / continuous CoT
766
- latent_dampening=False,
767
- criterion="entropy-diff",
768
  exit_threshold: Union[str, float, int] = "auto",
769
- cache_kwargs: dict = {},
 
770
  **model_kwargs,
771
  ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
772
  """
773
  Generate tokens with adaptive compute. This is NOT the most efficient implementation.
774
  For batches, on each token, we iterate until the entire batch finishes.
775
  """
776
- # Setup
777
- if generation_config is None:
778
- generation_config: GenerationConfig = self.generation_config # type: ignore
779
- model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs)
780
- model_kwargs["use_cache"] = True
781
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
782
- stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
783
  batch_size = input_ids.shape[0]
784
  compute_steps = []
785
 
786
  # Set up continuous compute if enabled
787
  if continuous_compute:
788
- embedded_inputs, _, _ = self.embed_inputs(input_ids)
789
- current_last_latents = self.initialize_state(embedded_inputs)
790
 
791
- # Track which sequences have finished
792
- finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
793
 
794
  # Generate tokens
795
- for step in range(generation_config.max_length - input_ids.shape[1]):
796
  # Adaptive compute forward
797
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
798
  aux_inputs = {
799
  k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
800
  }
801
- embedded_inputs, block_idx, _ = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
802
- if not continuous_compute:
803
- current_latents = self.initialize_state(embedded_inputs, deterministic=False)
804
- else:
805
- current_latents = current_last_latents
 
806
 
807
  # Initialize criterion tracking for each sequence in batch
808
  exit_values_per_seq = [[] for _ in range(batch_size)]
@@ -813,11 +1143,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
813
  if criterion == "entropy-diff":
814
  entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
815
  exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
816
- elif criterion in ["latent-diff", "none"]:
817
  exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
818
  elif "kl" in criterion:
819
  V = self.config.padded_vocab_size
820
- log_probs = ((1 / V) * torch.ones(batch_size, V, device=input_ids.device)).log()
821
  if criterion == "minp-kl":
822
  exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
823
  else:
@@ -826,23 +1156,25 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
826
  stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
827
  current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
828
  exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
829
  else:
830
  raise ValueError("Invalid adaptive compute strategy.")
831
 
832
- all_latents = []
833
  next_token_logits = None
834
 
835
  # Iterate through compute steps
836
- for compute_step in range(model_inputs["num_steps"]):
837
  prev_latents = current_latents.clone()
838
  current_latents, block_idx, _ = self.iterate_one_step(
839
- embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
 
 
840
  )
841
 
842
- if latent_dampening:
843
- all_latents.append(current_latents)
844
-
845
- if step > 0: # do not exit in prefill:
846
  # Check exit condition for each sequence in batch
847
  if criterion == "entropy-diff":
848
  prev_entropy = entropy
@@ -851,27 +1183,24 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
851
  probs = F.softmax(logits[:, -1, :], dim=-1)
852
  entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
853
  exit_values = (entropy - prev_entropy).abs()
854
-
855
  elif criterion == "latent-diff":
856
  norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
857
  exit_values = norm_diff.mean(dim=-1)
858
-
859
  elif "kl" in criterion:
860
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
861
  logits: torch.Tensor = outputs.logits # type: ignore
862
  prev_log_probs = log_probs
863
  if criterion == "minp-kl":
864
- probs = F.softmax(logits[:, -1, :], dim=-1)
865
  max_probs = probs.max(dim=-1, keepdim=True)[0]
866
  probs_mask = probs < (0.1 * max_probs)
867
- masked_probs = probs
868
  masked_probs[probs_mask] = 1 / V
869
  probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
870
  log_probs = probs.log()
871
  else:
872
- log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
873
  exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
874
-
875
  elif criterion == "argmax-stability":
876
  prev_argmax = current_argmax
877
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
@@ -881,18 +1210,21 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
881
  current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
882
  )
883
  exit_values = stable_for_n_steps
884
 
885
  # Record values and check exits for each sequence
886
  for i in range(batch_size):
887
- if not exit_reached[i] and not finished_sequences[i]:
888
  exit_values_per_seq[i].append(exit_values[i].item())
889
 
 
890
  new_exits = (
891
  exit_values < exit_threshold
892
  if criterion != "argmax-stability"
893
  else exit_values >= exit_threshold
894
  )
895
- new_exits = new_exits & ~exit_reached & ~finished_sequences
896
 
897
  if new_exits.any():
898
  exit_reached = exit_reached | new_exits
@@ -902,79 +1234,65 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
902
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
903
  logits: torch.Tensor = outputs.logits # type: ignore
904
  if next_token_logits is None:
905
- next_token_logits = logits[:, -1, :].clone()
906
  else:
907
- next_token_logits = torch.where(
908
- new_exits.unsqueeze(1).expand_as(logits[:, -1, :]), logits[:, -1, :], next_token_logits
909
- )
910
  for i in range(batch_size):
911
  if new_exits[i]:
912
  compute_steps_per_seq[i] = compute_step + 1
913
 
914
- # If all sequences have exited, break early
915
- if (exit_reached | finished_sequences).all():
916
  break
917
  # This else is if the for loop finished without breaking
918
  else:
919
- if not latent_dampening:
920
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
921
- else:
922
- dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
923
- outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
924
 
925
  # For sequences that didn't exit early, use the final logits
926
  if next_token_logits is None:
927
- next_token_logits = outputs.logits[:, -1, :] # type: ignore
928
  else:
929
- # Only update logits for sequences that didn't exit early
930
- non_exit_mask = ~exit_reached & ~finished_sequences
931
- next_token_logits = torch.where(
932
- non_exit_mask.unsqueeze(1).expand_as(next_token_logits),
933
- outputs.logits[:, -1, :], # type: ignore
934
- next_token_logits,
935
- )
936
-
937
- # Record compute steps for non-exited sequences
938
  for i in range(batch_size):
939
- if non_exit_mask[i]:
940
- compute_steps_per_seq[i] = model_inputs["num_steps"]
 
941
 
942
  # Save latent states for continuous compute if enabled
943
  if continuous_compute:
944
- current_last_latents = current_latents[:, -1:, :]
945
 
946
  # Record compute steps for this token generation
947
  compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
948
 
949
  # Sample or select next token based on generation config
950
- if generation_config.do_sample:
951
- next_token = self._sample_next_token(
952
- next_token_logits,
953
- generation_config.temperature,
954
- generation_config.top_k,
955
- generation_config.top_p,
956
- generation_config.min_p,
957
- )
958
- else:
959
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) # type: ignore
960
 
961
- input_ids = torch.cat([input_ids, next_token], dim=-1) # type: ignore
 
962
 
963
  if streamer:
964
  streamer.put(next_token.cpu())
965
 
966
- # Update model kwargs
967
  model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
968
- if continuous_compute:
969
- model_kwargs["input_states"] = current_last_latents
970
 
971
- # Check for finished sequences
972
  for i in range(batch_size):
973
- if not finished_sequences[i] and stop_tokens is not None and next_token[i, 0] in stop_tokens:
974
- finished_sequences[i] = True
 
 
 
975
 
976
  # Break if all sequences are finished
977
- if finished_sequences.all():
978
  break
979
 
980
  if streamer:
@@ -982,7 +1300,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
982
 
983
  if generation_config.return_dict_in_generate:
984
  return GenerateDecoderOnlyOutput(
985
- sequences=input_ids,
986
  scores=compute_steps, # type: ignore
987
  logits=None,
988
  attentions=None,
@@ -991,57 +1309,242 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
991
  )
992
  return input_ids
993
 
994
- def _get_stops(self, generation_config, tokenizer):
995
- stop_tokens = set()
996
  if generation_config.eos_token_id is not None:
997
  stop_tokens.add(generation_config.eos_token_id)
998
  if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
999
  for s in generation_config.stop_strings:
1000
  token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1001
  stop_tokens.add(token_id)
1002
  return torch.tensor(list(stop_tokens))
1003
 
1004
- def _sample_next_token(self, next_token_logits, temperature=None, top_k=None, top_p=None, min_p=None):
1005
  """Helper function to sample the next token."""
1006
- if temperature:
1007
- next_token_logits = next_token_logits / temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1008
 
1009
- probs = F.softmax(next_token_logits, dim=-1)
 
 
 
1010
 
1011
- # Apply top_k
1012
- if top_k:
1013
- top_k_values, _ = torch.topk(probs, top_k, dim=-1)
1014
- min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1015
- probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
 
 
 
1016
 
1017
- # Apply top_p (nucleus sampling)
1018
- if top_p:
1019
- sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1020
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
 
 
 
 
1021
 
1022
- # Create mask for probs to keep
1023
- remove_indices = cumulative_probs > top_p
1024
- remove_indices[:, 0] = False # Keep at least the top probability
1025
 
1026
- # Convert sorted indices mask back to original indices mask
1027
- mask = torch.zeros_like(probs, dtype=torch.bool)
1028
- for i in range(probs.shape[0]):
1029
- mask[i, sorted_indices[i, remove_indices[i]]] = True
1030
 
1031
- probs = torch.where(mask, torch.zeros_like(probs), probs)
 
 
 
1032
 
1033
- # Apply min_p
1034
- if min_p:
1035
- max_probs = probs.max(dim=-1, keepdim=True)[0]
1036
- min_p_threshold = min_p * max_probs
1037
- probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1038
 
1039
- # Renormalize probabilities
1040
- probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
 
 
 
1041
 
1042
- # Sample from the distribution
1043
- next_token = torch.multinomial(probs, num_samples=1)
1044
- return next_token
 
 
 
1045
 
1046
  def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1047
  probs = torch.softmax(logits.float(), dim=-1)
@@ -1097,4 +1600,4 @@ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
1097
  # Old?
1098
  AutoConfig.register("huginn_raven", RavenConfig)
1099
  AutoModel.register(RavenConfig, RavenForCausalLM)
1100
- AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
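An aside, not part of the file: the compress-* cache lookup strategies remap recurrence step indices onto a smaller set of KV-cache slots. The standalone sketch below reproduces that remapping arithmetic from HuginnDynamicCache.update for an assumed compression stage of 4 ("compress-s4" / "compress-r4"); it is illustration only.

# Standalone sketch of the compress-s / compress-r step remapping (illustration only).
def remap_step(step_idx: int, compression_stage: int = 4, mode: str = "s") -> int:
    if step_idx <= 1:  # steps 0 and 1 (prelude layers) are never compressed
        return step_idx
    if mode == "s":    # modulo sharing: rotate through a fixed window of slots
        return (step_idx - 2) % compression_stage + 2
    return (step_idx - 2) // compression_stage + 2  # relative ("r") compression

print([remap_step(s, mode="s") for s in range(2, 12)])  # [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]
print([remap_step(s, mode="r") for s in range(2, 12)])  # [2, 2, 2, 2, 3, 3, 3, 3, 4, 4]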
 
1
+ """Modeling file for HF compatibility and zero-shot experiments."""
2
 
3
  import torch
4
  import math
5
 
6
  from torch import Tensor
7
+ from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
+ from torch.nn.attention import bias as attn_bias
9
  from dataclasses import dataclass
10
+ from typing import Union, Optional, Any
11
+
12
 
13
  from .raven_config_minimal import RavenConfig
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
 
16
  ###################### Huggingface Glue code I ##################################################################
17
  from transformers import PreTrainedModel, GenerationMixin
 
21
  import torch.nn.functional as F
22
  from transformers import GenerationConfig
23
 
24
+ torch.backends.cuda.enable_math_sdp(False)
25
+
26
 
27
  class RavenPreTrainedModel(PreTrainedModel):
28
  config_class = RavenConfig
 
35
  _supports_sdpa = True
36
  _supports_cache_class = True
37
  _supports_quantized_cache = False
38
+ _supports_static_cache = True
39
+ _tp_plan = {}
40
 
41
  def _init_weights(self, module):
42
  if not torch.rand((1,)).is_meta:
 
93
  self,
94
  key_states: torch.Tensor,
95
  value_states: torch.Tensor,
96
+ step_idx_tensor: torch.Tensor,
97
  lookup_strategy: Optional[str] = None,
98
  ) -> tuple[torch.Tensor, torch.Tensor]:
99
+ step_idx: int = int(step_idx_tensor) # todo: fix dicts with tensor step_idx, currently the memberships fail
100
  lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
101
  if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
 
102
  if "compress-s" in self.lookup_strategy:
103
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
104
  new_step_idx = (step_idx - 2) % compression_stage + 2
105
+ elif "compress-anchor" in self.lookup_strategy:
106
+ if step_idx - 2 < 4 * 8: # anchor onto first 8 recurrence steps # noqa: SIM108
107
+ new_step_idx = step_idx
108
+ else: # then re-use the next 4 KV states = one recurrence for all future recurrence
109
+ new_step_idx = 34 + (step_idx - 34) % 4
110
+ # print(step_idx, new_step_idx)
111
+ else: # compress-r
112
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
113
  new_step_idx = (step_idx - 2) // compression_stage + 2
 
114
  step_idx = new_step_idx
115
  # Init
116
  if step_idx not in self.key_cache:
 
123
  for idx, entry in enumerate(key_states.unbind(dim=-2)):
124
  if "compress-" not in self.lookup_strategy:
125
  assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
 
126
  self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
127
  for idx, entry in enumerate(value_states.unbind(dim=-2)):
128
  self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
 
134
  torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
135
  torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
136
  )
137
+ else: # some entries were not previously computed
 
 
 
138
  if lookup_strategy.startswith("latest-m4"):
139
  latest_keys = []
140
  latest_values = []
141
  for token_pos in range(self._seen_tokens):
142
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
143
  if step_idx >= 2:
144
  # Find valid steps for this token position
145
  valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
146
  max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
147
  else:
148
  max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
149
+ latest_keys.append(self.key_cache[max_step][token_pos])
150
+ latest_values.append(self.value_cache[max_step][token_pos])
151
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
152
+ elif lookup_strategy.startswith("available-m4"):
153
+ latest_keys = []
154
+ latest_values = []
155
+ for token_pos in range(self._seen_tokens):
156
+ if token_pos in self.key_cache[step_idx]:
157
+ step = step_idx
158
+ else:
159
+ # Find valid steps for this token position
160
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
161
+ step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
162
+ latest_keys.append(self.key_cache[step][token_pos])
163
+ latest_values.append(self.value_cache[step][token_pos])
164
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
165
+ elif lookup_strategy.startswith("always-last-m4"):
166
+ latest_keys = []
167
+ latest_values = []
168
+ for token_pos in range(self._seen_tokens):
169
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
170
+ if step_idx >= 2:
171
+ # Find valid steps for this token position
172
+ valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
173
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
174
+ else:
175
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
176
  latest_keys.append(self.key_cache[max_step][token_pos])
177
  latest_values.append(self.value_cache[max_step][token_pos])
178
  return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
 
210
  self.key_cache.clear()
211
  self.value_cache.clear()
212
 
213
+ def clear_last_k_entries(self, k: int = 0):
214
+ """Partially clear cache."""
215
+ assert self._seen_tokens >= k
216
+ self._seen_tokens = self._seen_tokens - k
217
+ # self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
218
+ self.key_cache = {
219
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
220
+ for step, cache in self.key_cache.items()
221
+ }
222
+ self.value_cache = {
223
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
224
+ for step, cache in self.value_cache.items()
225
+ }
226
+
227
  def get_seq_length(self, step_idx: int = 0) -> int:
228
  return self._seen_tokens
229
 
 
240
  return total_bytes * 2 / (1024 * 1024)
241
 
242
 
243
+ class HuginnStaticCache(Cache):
244
+ """Static Cache for the recurrent model"""
245
+
246
+ is_compileable = False # this is todo
247
+
248
+ def __init__(
249
+ self,
250
+ max_length: int,
251
+ max_num_steps: int,
252
+ num_heads: int,
253
+ hidden_dim: int,
254
+ batch_size: int = 1,
255
+ lookup_strategy: str = "full",
256
+ device: Optional[Union[torch.device, str]] = None,
257
+ dtype: torch.dtype = torch.float32,
258
+ ) -> None:
259
+ super().__init__()
260
+ self._seen_tokens = 0
261
+ self.max_length = max_length
262
+ self.lookup_strategy = lookup_strategy
263
+
264
+ # Adjust max_num_steps based on compression strategy
265
+ if "compress-" in lookup_strategy:
266
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
267
+ if "compress-s" in lookup_strategy:
268
+ # For modulo compression (s), we need steps for 0,1 + compressed steps
269
+ self.max_num_steps = 4 + compression_stage
270
+ else:
271
+ # For relative compression, we need steps for 0,1 + compressed steps
272
+ self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
273
+ else:
274
+ self.max_num_steps = max_num_steps
275
+
276
+ # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
277
+ device = torch.device(device) if device is not None else None
278
+ cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
279
+
280
+ self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
281
+ self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
282
+ self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
283
+ # Mark tensors as static for compile
284
+ torch._dynamo.mark_static_address(self.key_cache)
285
+ torch._dynamo.mark_static_address(self.value_cache)
286
+ torch._dynamo.mark_static_address(self.valid_mask)
287
+
288
+ def update(
289
+ self,
290
+ key_states: torch.Tensor,
291
+ value_states: torch.Tensor,
292
+ step_idx: torch.Tensor,
293
+ lookup_strategy: Optional[str] = None,
294
+ ) -> tuple[torch.Tensor, torch.Tensor]:
295
+ if step_idx == 0:
296
+ self._seen_tokens += key_states.shape[-2]
297
+
298
+ # Adjust step_idx for compression
299
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
300
+ if "compress-" in lookup_strategy and step_idx > 1:
301
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
302
+ if "compress-s" in lookup_strategy:
303
+ step_idx = (step_idx - 2) % compression_stage + 2
304
+ else:
305
+ step_idx = (step_idx - 2) // compression_stage + 2
306
+
307
+ start_idx = self._seen_tokens - key_states.shape[-2]
308
+
309
+ indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
310
+ self.key_cache[step_idx].index_copy_(2, indices, key_states)
311
+ self.value_cache[step_idx].index_copy_(2, indices, value_states)
312
+ self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True
313
+
314
+ # Return based on lookup strategy
315
+ if lookup_strategy == "full":
316
+ return (
317
+ self.key_cache[step_idx, :, :, : self._seen_tokens],
318
+ self.value_cache[step_idx, :, :, : self._seen_tokens],
319
+ )
320
+ elif lookup_strategy.startswith("latest-m4"):
321
+ if step_idx >= 2:
322
+ pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
323
+ pattern_valid = self.valid_mask[pattern_steps]
324
+ max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
325
+ return (
326
+ self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
327
+ self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
328
+ )
329
+ return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
330
+ step_idx, :, :, : self._seen_tokens
331
+ ]
332
+ elif lookup_strategy == "skip":
333
+ valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
334
+ return (
335
+ self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
336
+ self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
337
+ )
338
+ elif lookup_strategy.startswith("randomized"):
339
+ if step_idx < 2:
340
+ max_step = step_idx
341
+ else:
342
+ curr_modulo = (step_idx - 2) % 4 + 2
343
+ valid_steps = (
344
+ torch.where(
345
+ (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
346
+ )[0]
347
+ + 2
348
+ )
349
+ rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
350
+ max_step = valid_steps[rand_idx]
351
+ return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
352
+ else:
353
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
354
+
355
+ def reset(self) -> None:
356
+ self._seen_tokens = 0
357
+ self.key_cache.zero_()
358
+ self.value_cache.zero_()
359
+ self.valid_mask.zero_()
360
+
361
+ def get_seq_length(self, step_idx: int = 0) -> int:
362
+ return self._seen_tokens
363
+
364
+ def get_memory_usage(self) -> float:
365
+ return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
366
+
367
+
368
+ ValidCache = HuginnDynamicCache | HuginnStaticCache
369
+
370
+
371
  class CausalSelfAttention(torch.nn.Module):
372
  def __init__(self, config: RavenConfig) -> None:
373
  super().__init__()
 
387
  self,
388
  x: Tensor,
389
  freqs_cis: Tensor,
390
+ block_idx: torch.Tensor,
391
+ mask: Optional[BlockMask] = None,
392
+ past_key_values: Optional[ValidCache] = None,
393
+ ) -> Tensor:
 
394
  B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
395
  q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
396
  q = q.view(B, S, self.n_head, self.head_dim)
 
408
  v = v.transpose(1, 2)
409
 
410
  if past_key_values is not None:
411
+ k, v = past_key_values.update(k, v, block_idx)
412
 
413
+ if mask is not None:
414
+ y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
415
  else:
416
+ if q.shape[2] < k.shape[2]:
417
+ if q.shape[2] > 1:
418
+ bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
419
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0)
420
+ else:
421
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
422
+ else:
423
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
424
  y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
425
+ return self.proj(y)
 
 
 
426
 
427
 
428
  class GatedMLP(torch.nn.Module):
 
459
  x: Tensor,
460
  freqs_cis: Tensor,
461
  step_idx: int,
462
+ mask: Optional[BlockMask] = None,
463
+ past_key_values: Optional[ValidCache] = None,
464
+ ) -> Tensor:
465
+ attn_out = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values)
 
466
  x = self.norm_2(attn_out + x)
467
  x = self.norm_4(self.mlp(self.norm_3(x)) + x)
468
+ return x
469
 
470
 
471
  class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
472
+ freqs_cis: torch.Tensor
473
+
474
  def __init__(
475
  self,
476
  config: RavenConfig,
 
519
  )
520
  return freqs_cis
521
 
522
+ def compile_mask(
523
+ self,
524
+ input_ids: torch.Tensor,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ past_key_values: Optional[ValidCache] = None,
527
+ pad_token_id=65509,
528
+ ) -> Optional[BlockMask]:
529
+ batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
530
+
531
+ # If no padding and no attention mask, no need for a mask
532
+ if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
533
+ return None
534
+
535
+ if past_key_values is not None and seq_len == 1:
536
+ return None
537
+
538
+ # Get total sequence length including cache
539
+ cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
540
+ kv_length = cache_len + seq_len
541
+
542
+ if attention_mask is None:
543
+
544
+ def mask_mod(b, h, q_idx, kv_idx):
545
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id)
546
+ else:
547
+
548
+ def mask_mod(b, h, q_idx, kv_idx):
549
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
550
+
551
+ kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
552
+ if kv_length == 0:
553
+ kv_length = seq_len # prefill
554
+ block_mask = create_block_mask(
555
+ mask_mod,
556
+ B=batch_size,
557
+ H=None,
558
+ Q_LEN=seq_len,
559
+ KV_LEN=kv_length,
560
+ device=input_ids.device,
561
+ )
562
+
563
+ # # Define mask_mod function
564
+ # def mask_mod(b, h, q_idx, kv_idx):
565
+ # # Always apply causal constraint
566
+ # is_causal = q_idx >= kv_idx
567
+
568
+ # # Handle cache vs current tokens
569
+ # is_cache = kv_idx < cache_len
570
+ # current_idx = kv_idx - cache_len
571
+
572
+ # # For cache: always valid; For current: check padding
573
+ # not_pad = input_ids[b, current_idx] != pad_token_id
574
+ # valid = is_cache | not_pad
575
+
576
+ # # Apply attention mask if provided
577
+ # if attention_mask is not None:
578
+ # q_idx_curr = q_idx - cache_len
579
+ # attn_valid = attention_mask[b, q_idx_curr, current_idx]
580
+ # valid = valid & (is_cache | attn_valid)
581
+
582
+ # return is_causal & valid
583
+
584
+ # def mask_mod(b, h, q_idx, kv_idx):
585
+ # is_causal = q_idx >= kv_idx
586
+ # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
587
+ # current_idx = kv_idx - cache_len
588
+
589
+ # is_valid = (~is_current) | (
590
+ # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
591
+ # )
592
+
593
+ # return is_causal & is_valid
594
+
595
+ # # Define mask_mod function
596
+ # def mask_mod(b, h, q_idx, kv_idx):
597
+ # # Always apply causal constraint
598
+ # is_causal = q_idx >= kv_idx
599
+
600
+ # # Handle cache vs current tokens
601
+ # is_cache = kv_idx < cache_len
602
+ # current_idx = kv_idx - cache_len
603
+ # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
604
+
605
+ # # For cache: always valid; For current: check padding
606
+ # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
607
+ # valid = is_cache | (not_pad & in_bounds)
608
+
609
+ # # Apply attention mask if provided
610
+ # if attention_mask is not None:
611
+ # q_idx_curr = q_idx - cache_len
612
+ # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
613
+ # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
614
+ # valid = valid & (is_cache | attn_valid)
615
+
616
+ # return is_causal & valid
617
+
618
+ # Create block mask
619
+ block_mask = create_block_mask(
620
+ mask_mod,
621
+ B=batch_size,
622
+ H=None,
623
+ Q_LEN=seq_len,
624
+ KV_LEN=kv_length,
625
+ device=input_ids.device,
626
+ )
627
+
628
+ return block_mask
629
+
630
  def forward(
631
  self,
632
  input_ids: torch.Tensor,
633
  input_embeds: Optional[torch.Tensor] = None,
634
  input_states: Optional[torch.Tensor] = None,
635
+ attention_mask: Optional[torch.Tensor] = None, # binary mask of shape q x kv, True=valid position
636
  position_ids: Optional[torch.Tensor] = None,
637
  labels: Optional[torch.Tensor] = None,
638
  num_steps: Optional[torch.Tensor] = None,
639
+ past_key_values: Optional[ValidCache] = None,
640
  output_details: dict = {
641
  "return_logits": True,
642
  "return_latents": True,
 
643
  "return_head": False,
644
  "return_stats": False,
645
  },
646
  use_cache: bool = False,
647
  cache_position: Optional[torch.Tensor] = None,
648
+ init_scale: float = 1.0,
649
  **kwargs,
650
  ) -> CausalLMOutputRecurrentLatents:
651
  # Support multiple position formats:
 
657
  freqs_cis = self.freqs_cis[:, cache_position]
658
 
659
  if input_embeds is None:
660
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
661
 
662
  if self.emb_scale != 1:
663
  input_embeds = input_embeds * self.emb_scale # type: ignore
664
 
665
  if use_cache and past_key_values is None:
666
  past_key_values = HuginnDynamicCache()
667
 
668
+ prepared_attn_mask = None # self.compile_mask(input_ids, attention_mask, past_key_values)
669
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
670
  # Non-recurrent prelude
671
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
672
+ block_idx += 1
673
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
674
 
675
  # Main recurrence
676
+ x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
677
+ input_embeds, # type: ignore # mystery typing error
678
  input_states,
679
  freqs_cis,
680
  block_idx,
681
+ prepared_attn_mask,
682
  past_key_values,
683
  num_steps,
684
+ init_scale,
 
685
  )
686
  latent_states = x.clone().detach()
687
 
688
  # Coda layers
689
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
690
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
691
+ block_idx -= 1
692
+ x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
693
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
694
 
695
  # Prediction head, assuming labels really are labels and not equal to input_ids
696
  if labels is not None:
697
  logits = self.lm_head(x).float()
698
+ loss = torch.nn.functional.cross_entropy(
699
+ logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
700
+ )
701
  log_ppl = loss.clone().detach().exp()
702
  else:
703
  logits = self.lm_head(x).float()
 
710
  past_key_values=past_key_values,
711
  hidden_states=x if output_details["return_head"] else None,
712
  latent_states=latent_states if output_details["return_latents"] else None,
 
713
  stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
714
  if output_details["return_stats"]
715
  else None,
 
718
  @torch._dynamo.disable(recursive=False) # type: ignore
719
  def iterate_forward(
720
  self,
721
+ input_embeds: torch.Tensor,
722
+ input_states: torch.Tensor,
723
  freqs_cis,
724
+ block_idx: torch.Tensor,
725
+ mask: Optional[BlockMask],
726
+ past_key_values: Optional[ValidCache] = None,
727
  num_steps: Optional[torch.Tensor] = None,
728
+ init_scale: float = 1.0,
 
729
  ):
730
+ x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
731
  if num_steps is None:
732
  num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
733
  elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
 
740
  # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
741
  # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
742
  # and all parameters are always used
743
+ for no_grad_step in range(num_steps_no_grad):
744
  xk = x
745
+ x, block_idx = self.core_block_forward(
746
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
747
  )
748
 
749
+ for grad_step in range(num_steps_with_grad):
750
  xk = x
751
+ x, block_idx = self.core_block_forward(
752
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
753
  )
754
+ return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
755
 
756
  def core_block_forward(
757
  self,
758
  x,
759
  input_embeds,
760
  freqs_cis,
761
+ mask: Optional[BlockMask],
762
  past_key_values,
763
+ block_idx: torch.Tensor,
764
+ current_step: int | Tensor,
 
765
  ):
766
+ x = self._maybe_inject_noise(x, current_step)
767
+ x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
768
+ for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
769
+ block_idx += 1
770
+ x = block(x, freqs_cis, block_idx, mask, past_key_values)
771
+ return x, block_idx
772
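# A toy sketch of the weight-shared recurrence pattern implemented by iterate_forward and
# core_block_forward above (not the actual Raven blocks): a prelude embeds the input once,
# a shared core is applied num_steps times with an adapter mixing the latent state and the
# embedding, and a coda maps the final state to the output.
import torch
from torch import nn

class TinyRecurrentNet(nn.Module):
    def __init__(self, d: int = 64):
        super().__init__()
        self.prelude = nn.Linear(d, d)
        self.adapter = nn.Linear(2 * d, d)  # mixes latent state with the embedded input
        self.core = nn.Linear(d, d)         # same weights reused at every recurrence step
        self.coda = nn.Linear(d, d)

    def forward(self, inputs: torch.Tensor, num_steps: int = 4) -> torch.Tensor:
        e = torch.relu(self.prelude(inputs))
        x = torch.randn_like(e)             # random initial latent state, as in initialize_state
        for _ in range(num_steps):
            x = torch.relu(self.core(self.adapter(torch.cat([x, e], dim=-1))))
        return self.coda(x)

print(TinyRecurrentNet()(torch.randn(2, 8, 64), num_steps=6).shape)  # torch.Size([2, 8, 64])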
 
773
  @torch.no_grad()
774
  def iterate_one_step(
 
777
  input_states,
778
  position_ids: Optional[torch.Tensor] = None,
779
  cache_position: Optional[torch.Tensor] = None,
780
+ block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
781
+ attention_mask: Optional[BlockMask] = None,
782
+ past_key_values: Optional[ValidCache] = None,
783
+ current_step: int = 0,
784
  ):
785
  if position_ids is None and cache_position is None:
786
  freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
 
788
  freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
789
  elif cache_position is not None:
790
  freqs_cis = self.freqs_cis[:, cache_position]
791
+ x, block_idx = self.core_block_forward(
792
+ input_states,
793
+ input_embeds,
794
+ freqs_cis,
795
+ attention_mask,
796
+ past_key_values,
797
+ block_idx,
798
+ current_step=current_step,
799
  )
800
+ return x, block_idx, current_step + 1
801
 
802
  def predict_from_latents(
803
  self,
804
  latents,
805
+ attention_mask: Optional[BlockMask] = None,
806
  position_ids: Optional[torch.Tensor] = None,
807
  cache_position: Optional[torch.Tensor] = None,
808
+ past_key_values: Optional[ValidCache] = None,
 
 
809
  ):
810
  if position_ids is None and cache_position is None:
811
  freqs_cis = self.freqs_cis[:, : latents.shape[1]]
 
813
  freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
814
  elif cache_position is not None:
815
  freqs_cis = self.freqs_cis[:, cache_position]
816
+ x = self.transformer.ln_f(latents) # type: ignore # types broken in 2.6+
817
  # Coda layers
818
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
819
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
820
+ block_idx -= 1
821
+ x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
822
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
823
 
824
  logits = self.lm_head(x).float()
825
 
 
828
  log_ppl=torch.as_tensor(0.0),
829
  logits=logits,
830
  past_key_values=past_key_values,
831
+ latent_states=x,
832
  )
833
 
834
  def embed_inputs(
 
836
  input_ids: torch.Tensor,
837
  attention_mask: Optional[torch.Tensor] = None,
838
  position_ids: Optional[torch.Tensor] = None,
839
+ past_key_values: Optional[ValidCache] = None,
840
  use_cache: bool = False,
841
  cache_position: Optional[torch.Tensor] = None,
 
842
  **kwargs,
843
+ ) -> tuple[torch.Tensor, torch.Tensor]:
844
  # Support multiple position formats:
845
  if position_ids is None and cache_position is None:
846
  freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
 
849
  elif cache_position is not None:
850
  freqs_cis = self.freqs_cis[:, cache_position]
851
 
852
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
853
+ prepared_attn_mask = self.compile_mask(input_ids, attention_mask)
854
 
855
  if self.emb_scale != 1:
856
  input_embeds = input_embeds * self.emb_scale # type: ignore
 
858
  if use_cache and past_key_values is None:
859
  past_key_values = HuginnDynamicCache()
860
 
861
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
862
  # Non-recurrent prelude
863
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
864
+ block_idx += 1
865
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
866
+ return input_embeds, block_idx
 
 
867
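# Hedged sketch of driving the recurrence by hand with the helpers above (embed_inputs,
# initialize_state, iterate_one_step, predict_from_latents), mirroring what
# generate_with_adaptive_compute does internally. `model` and `input_ids` are assumed to
# exist; no KV cache is used, so this only illustrates the call pattern.
import torch

with torch.no_grad():
    embedded, block_idx = model.embed_inputs(input_ids)
    latents = model.initialize_state(embedded)
    current_step = 0
    for _ in range(16):  # fixed recurrence budget for the sketch
        latents, block_idx, current_step = model.iterate_one_step(
            embedded, latents, block_idx=block_idx, current_step=current_step
        )
    outputs = model.predict_from_latents(latents)
    next_token = outputs.logits[:, -1, :].argmax(dim=-1)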
 
868
  @torch._dynamo.disable(recursive=False) # type: ignore
869
  def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
 
886
 
887
  return n.to(dtype=torch.long), k.to(dtype=torch.long)
888
 
889
+ def initialize_state(self, input_embeds, scale: float = 1.0):
890
  x = torch.randn_like(input_embeds)
891
+ std = self.config.init_values["std"] * scale
892
+ if std > 0:
893
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
894
+ if self.emb_scale != 1:
895
+ x = x * self.emb_scale
896
+ else:
897
+ x.zero_()
898
+ return x
899
+
900
+ def _maybe_inject_noise(self, x, current_step, renorm=False):
901
+ if self.config.test_time_noise > 0:
902
+ n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
903
+ if self.config.test_time_noise_type == "geom":
904
+ step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
905
+ x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
906
+ elif self.config.test_time_noise_type == "sqrt":
907
+ step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
908
+ x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
909
+ elif self.config.test_time_noise_type == "line":
910
+ noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
911
+ x = x * (1 - noise) + torch.randn_like(x) * noise
912
+ elif self.config.test_time_noise_type == "chi":
913
+ noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
914
+ x = x * (1 - noise) + torch.randn_like(x) * noise
915
+ elif self.config.test_time_noise_type == "fixed":
916
+ x = x * (1 - n) + torch.randn_like(x) * n
917
+ else:
918
+ raise ValueError()
919
+
920
+ if renorm:
921
+ x = self.transformer.core_block[-1].norm_4(x) # type: ignore # ModuleDict types still broken in PyTorch
922
+ return x
923
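# Standalone illustration of the test-time noise schedules above: the update is
# x <- x * (1 - s) + eps * s with eps ~ N(0, I). For the "geom" schedule s = n / (step + 1),
# so later recurrence steps receive less noise. The values below are made up for the demo;
# in the model, n = config.test_time_noise * init std * emb_scale.
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)  # stand-in latent state
n = 0.1
for step in range(4):
    s = n / (step + 1)    # "geom" schedule
    x = x * (1 - s) + torch.randn_like(x) * s
    print(f"step {step}: noise scale {s:.3f}")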
 
924
  def prepare_inputs_for_generation(
925
  self,
926
+ input_ids: torch.Tensor,
927
  past_key_values: Optional[Cache] = None,
928
+ attention_mask: Optional[torch.Tensor] = None,
929
  inputs_embeds: Optional[torch.FloatTensor] = None,
930
+ cache_position: Optional[torch.Tensor] = None,
931
+ cache_lookup_strategy: str = "full",
932
  **kwargs,
933
  ):
934
  model_inputs = {}
935
  model_inputs["cache_position"] = cache_position
936
  current_input_length = input_ids.shape[1]
937
+
938
  if past_key_values is not None:
939
+ if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
940
+ assert past_key_values.get_seq_length() == 0 # only replace empty caches
941
+ # Need to use custom cache, detect and replace HF cache if generate injects it
942
+ if isinstance(past_key_values, StaticCache):
943
+ past_key_values = HuginnStaticCache(
944
+ max_length=getattr(self.generation_config, "max_length", self.config.block_size),
945
+ max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
946
+ num_heads=self.config.num_key_value_heads,
947
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
948
+ dtype=torch.bfloat16,
949
+ device=input_ids.device,
950
+ lookup_strategy=cache_lookup_strategy,
951
+ )
952
+ else:
953
+ past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
954
  model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
955
  input_ids = input_ids[:, cache_position] # type: ignore
 
956
 
957
+ model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
958
  if cache_position is None:
959
  position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
960
  model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
 
971
  def generate(self, *args, **kwargs):
972
  """Dispatcher - use HF generate in all normal cases."""
973
  self.generation_config = args[1] if len(args) > 1 else self.generation_config
974
+ if any(k in kwargs for k in ("criterion", "exit_threshold")):
975
+ # print("Dispatching to custom generate_adaptive function call")
 
 
 
976
  return self.generate_with_adaptive_compute(*args, **kwargs)
977
+ elif "continuous_compute" in kwargs:
978
+ # print("Dispatching to custom generate_minimal function call")
979
+ return self.generate_minimal(*args, **kwargs)
980
  else:
981
  return super().generate(*args, **kwargs)
982
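# Usage sketch for the dispatcher above: passing `criterion`/`exit_threshold` routes to
# generate_with_adaptive_compute, `continuous_compute` routes to generate_minimal, and
# everything else falls through to the stock Hugging Face generate. `model` and `tokenizer`
# are assumed to be loaded already.
inputs = tokenizer("The capital of Westphalia is", return_tensors="pt").input_ids

# adaptive compute: stop recurring once successive latent states stop changing
out = model.generate(inputs, criterion="latent-diff", exit_threshold="auto", max_new_tokens=32)

# plain Hugging Face path with a fixed recurrence budget
out = model.generate(inputs, num_steps=32, max_new_tokens=32)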
 
983
+ @torch.no_grad()
984
+ def _prep_generate_args(
985
+ self,
986
+ input_ids: torch.Tensor,
987
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
988
+ cache_lookup_strategy: str = "full",
989
+ model_kwargs: dict = {},
990
+ ):
991
+ # Setup
992
+ if generation_config is None:
993
+ generation_config: GenerationConfig = self.generation_config # type: ignore
994
+ if "max_new_tokens" in model_kwargs:
995
+ max_new_tokens = model_kwargs["max_new_tokens"]
996
+ if "max_length" in model_kwargs:
997
+ max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
998
+ else:
999
+ max_length = model_kwargs.get("max_length", generation_config.max_length)
1000
+ max_new_tokens = max_length - input_ids.shape[1]
1001
+
1002
+ if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
1003
+ model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
1004
+ else:
1005
+ model_kwargs["past_key_values"] = HuginnStaticCache(
1006
+ max_length=max_length,
1007
+ max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
1008
+ num_heads=self.config.num_key_value_heads,
1009
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
1010
+ batch_size=input_ids.shape[0],
1011
+ dtype=torch.bfloat16,
1012
+ device=input_ids.device,
1013
+ lookup_strategy=cache_lookup_strategy,
1014
+ )
1015
+ model_kwargs["use_cache"] = True
1016
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1017
+ return model_kwargs, generation_config, max_new_tokens
1018
+
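# Sketch of the cache choices handled above, assuming `model` and `inputs` exist. The default
# ("dynamic", or no cache_implementation at all) builds a HuginnDynamicCache; any other value
# selects a HuginnStaticCache sized for max_length and the recurrence depth. The lookup
# strategy defaults to "full", i.e. separate KV entries for every recurrence step.
out = model.generate(
    inputs,
    criterion="latent-diff",
    exit_threshold="auto",
    cache_implementation="dynamic",
    cache_lookup_strategy="full",
    max_new_tokens=64,
)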
1019
  @torch.no_grad()
1020
  def generate_minimal(
1021
  self,
1022
+ input_ids: torch.Tensor,
1023
  generation_config: Optional[GenerationConfig] = None, # type: ignore
1024
  tokenizer=None,
1025
  streamer=None,
1026
  continuous_compute=False, # warm-start state / continuous CoT
1027
+ init_scale: float = 1.0,
1028
+ cache_lookup_strategy: str = "full",
1029
  **model_kwargs,
1030
  ) -> Union[torch.Tensor, dict[str, Any]]:
1031
  """Minimal single-sequence generation. Template for more complicated generate tasks"""
1032
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1033
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1034
+ )
1035
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1036
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1037
+
1038
+ # Set up continuous compute if enabled
1039
  if continuous_compute:
1040
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1041
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1042
+
1043
  # Generate tokens
1044
+ batch_size = input_ids.shape[0]
1045
+ for _ in range(max_new_tokens):
1046
  # Forward pass
1047
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1048
+ outputs = self(**model_inputs, init_scale=init_scale)
1049
 
1050
+ # Get next token
1051
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
1052
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1053
+
1054
+ # Append token to sequence
1055
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1056
 
1057
  if streamer:
1058
  streamer.put(next_token.cpu())
 
1060
  # Update model kwargs
1061
  model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1062
  if continuous_compute:
1063
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1064
+
1065
+ if stop_tokens is not None:
1066
+ for i in range(batch_size):
1067
+ if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
1068
+ unfinished_sequences[i] = 0
1069
+ if "stopping_criteria" in model_kwargs:
1070
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1071
+ if unfinished_sequences.max() == 0:
1072
  break
1073
 
1074
  if streamer:
 
1076
 
1077
  if generation_config.return_dict_in_generate:
1078
  return GenerateDecoderOnlyOutput(
1079
+ sequences=input_ids, # type: ignore
1080
  scores=None,
1081
  logits=None,
1082
  attentions=None,
 
1088
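# Sketch of the warm-start ("continuous compute") path of generate_minimal above: instead of a
# fresh random draw per token, the last latent state of the previous token seeds the next one.
# Dispatched through generate() because `continuous_compute` is present in the kwargs.
out = model.generate(
    inputs,
    continuous_compute=True,
    init_scale=1.0,   # scale of the random state used for the very first token
    max_new_tokens=64,
)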
  @torch.no_grad()
1089
  def generate_with_adaptive_compute(
1090
  self,
1091
+ input_ids: torch.Tensor,
1092
  generation_config: Optional[GenerationConfig] = None, # type: ignore
1093
  tokenizer=None,
1094
  streamer=None,
1095
  continuous_compute=False, # warm-start state / continuous CoT
1096
+ criterion="none", # off by default, turn on by choosing an exit criterion
 
1097
  exit_threshold: Union[str, float, int] = "auto",
1098
+ init_scale: float = 1.0,
1099
+ cache_lookup_strategy: str = "full",
1100
  **model_kwargs,
1101
  ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1102
  """
1103
  Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1104
  For batches, on each token, we iterate until the entire batch finishes.
1105
  """
1106
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1107
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1108
+ )
1109
+ max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
1110
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1111
+ logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
 
1112
  batch_size = input_ids.shape[0]
1113
  compute_steps = []
1114
 
1115
  # Set up continuous compute if enabled
1116
  if continuous_compute:
1117
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1118
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1119
 
1120
+ # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1121
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1122
 
1123
  # Generate tokens
1124
+ for token_step in range(max_new_tokens):
1125
  # Adaptive compute forward
1126
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1127
  aux_inputs = {
1128
  k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
1129
  }
1130
+ embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
1131
+ current_latents = (
1132
+ self.initialize_state(embedded_inputs, scale=init_scale)
1133
+ if not continuous_compute
1134
+ else model_kwargs["input_states"]
1135
+ )
1136
 
1137
  # Initialize criterion tracking for each sequence in batch
1138
  exit_values_per_seq = [[] for _ in range(batch_size)]
 
1143
  if criterion == "entropy-diff":
1144
  entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
1145
  exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1146
+ elif criterion == "latent-diff":
1147
  exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1148
  elif "kl" in criterion:
1149
  V = self.config.padded_vocab_size
1150
+ log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
1151
  if criterion == "minp-kl":
1152
  exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
1153
  else:
 
1156
  stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
1157
  current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
1158
  exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1159
+ elif criterion == "none":
1160
+ exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
1161
  else:
1162
  raise ValueError("Invalid adaptive compute strategy.")
1163
 
 
1164
  next_token_logits = None
1165
 
1166
  # Iterate through compute steps
1167
+ for compute_step in range(max_steps):
1168
  prev_latents = current_latents.clone()
1169
  current_latents, block_idx, _ = self.iterate_one_step(
1170
+ embedded_inputs,
1171
+ current_latents,
1172
+ block_idx=block_idx,
1173
+ **aux_inputs,
1174
+ current_step=compute_step,
1175
  )
1176
 
1177
+ if token_step > 0: # do not exit in prefill (the first outer step processes the prompt)
 
 
 
1178
  # Check exit condition for each sequence in batch
1179
  if criterion == "entropy-diff":
1180
  prev_entropy = entropy
 
1183
  probs = F.softmax(logits[:, -1, :], dim=-1)
1184
  entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1185
  exit_values = (entropy - prev_entropy).abs()
 
1186
  elif criterion == "latent-diff":
1187
  norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
1188
  exit_values = norm_diff.mean(dim=-1)
 
1189
  elif "kl" in criterion:
1190
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
1191
  logits: torch.Tensor = outputs.logits # type: ignore
1192
  prev_log_probs = log_probs
1193
  if criterion == "minp-kl":
1194
+ probs = F.softmax(logits[:, -1, :].float(), dim=-1)
1195
  max_probs = probs.max(dim=-1, keepdim=True)[0]
1196
  probs_mask = probs < (0.1 * max_probs)
1197
+ masked_probs = probs.clone()
1198
  masked_probs[probs_mask] = 1 / V
1199
  probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1200
  log_probs = probs.log()
1201
  else:
1202
+ log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
1203
  exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
 
1204
  elif criterion == "argmax-stability":
1205
  prev_argmax = current_argmax
1206
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
 
1210
  current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
1211
  )
1212
  exit_values = stable_for_n_steps
1213
+ elif criterion == "none":
1214
+ exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
1215
 
1216
  # Record values and check exits for each sequence
1217
  for i in range(batch_size):
1218
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1219
  exit_values_per_seq[i].append(exit_values[i].item())
1220
 
1221
+ # Check for new exits, respecting unfinished_sequences
1222
  new_exits = (
1223
  exit_values < exit_threshold
1224
  if criterion != "argmax-stability"
1225
  else exit_values >= exit_threshold
1226
  )
1227
+ new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1228
 
1229
  if new_exits.any():
1230
  exit_reached = exit_reached | new_exits
 
1234
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
1235
  logits: torch.Tensor = outputs.logits # type: ignore
1236
  if next_token_logits is None:
1237
+ next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1238
  else:
1239
+ for i in range(batch_size):
1240
+ if new_exits[i]:
1241
+ next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
1242
  for i in range(batch_size):
1243
  if new_exits[i]:
1244
  compute_steps_per_seq[i] = compute_step + 1
1245
 
1246
+ # If all sequences have exited or finished, break early
1247
+ if (exit_reached | ~unfinished_sequences.bool()).all():
1248
  break
1249
  # This else is if the for loop finished without breaking
1250
  else:
1251
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1252
 
1253
  # For sequences that didn't exit early, use the final logits
1254
  if next_token_logits is None:
1255
+ next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1256
  else:
1257
  for i in range(batch_size):
1258
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1259
+ next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1260
+ compute_steps_per_seq[i] = max_steps
1261
 
1262
  # Save latent states for continuous compute if enabled
1263
  if continuous_compute:
1264
+ model_kwargs["input_states"] = current_latents[:, -1:, :]
1265
 
1266
  # Record compute steps for this token generation
1267
  compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
1268
 
1269
  # Sample or select next token based on generation config
1270
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1271
 
1272
+ # Append token to sequence
1273
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1274
 
1275
  if streamer:
1276
  streamer.put(next_token.cpu())
1277
 
1278
+ # Update model kwargs for next iteration
1279
  model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
 
1280
 
1281
+ # Check for stop tokens and update unfinished sequences
1282
  for i in range(batch_size):
1283
+ if (
1284
+ unfinished_sequences[i].bool()
1285
+ and stop_tokens is not None
1286
+ and next_token[i, 0].item() in stop_tokens
1287
+ ):
1288
+ unfinished_sequences[i] = 0
1289
+
1290
+ # Apply any custom stopping criteria
1291
+ if "stopping_criteria" in model_kwargs:
1292
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1293
 
1294
  # Break if all sequences are finished
1295
+ if unfinished_sequences.max() == 0:
1296
  break
1297
 
1298
  if streamer:
 
1300
 
1301
  if generation_config.return_dict_in_generate:
1302
  return GenerateDecoderOnlyOutput(
1303
+ sequences=input_ids, # type: ignore
1304
  scores=compute_steps, # type: ignore
1305
  logits=None,
1306
  attentions=None,
 
1309
  )
1310
  return input_ids
1311
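# Standalone illustration of the "latent-diff" exit test used above: a sequence stops recurring
# once the relative change between successive latent states falls below the threshold
# (0.03 when exit_threshold="auto"). Shapes are made up for the demo.
import torch

prev_latents = torch.randn(2, 5, 16)  # batch x seq x hidden
current_latents = prev_latents + 0.01 * torch.randn(2, 5, 16)
norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
exit_values = norm_diff.mean(dim=-1)  # one exit value per sequence
print(exit_values < 0.03)             # True where a sequence may stop recurring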
 
1312
+ def _get_stops(self, generation_config, tokenizer, model_kwargs):
1313
+ stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1314
  if generation_config.eos_token_id is not None:
1315
  stop_tokens.add(generation_config.eos_token_id)
1316
+ if "stopping_criteria" in model_kwargs and tokenizer is None:
1317
+ tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1318
  if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1319
  for s in generation_config.stop_strings:
1320
  token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1321
  stop_tokens.add(token_id)
1322
  return torch.tensor(list(stop_tokens))
1323
 
1324
+ def _sample_next_token(self, next_token_logits, generation_config):
1325
  """Helper function to sample the next token."""
1326
+ if generation_config.do_sample:
1327
+ if generation_config.temperature:
1328
+ next_token_logits = next_token_logits.float() / generation_config.temperature
1329
+
1330
+ probs = F.softmax(next_token_logits, dim=-1)
1331
+
1332
+ # Apply top_k
1333
+ if generation_config.top_k:
1334
+ top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1335
+ min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1336
+ probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1337
+
1338
+ # Apply top_p (nucleus sampling)
1339
+ if generation_config.top_p:
1340
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1341
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1342
+
1343
+ # Create mask for probs to keep
1344
+ remove_indices = cumulative_probs > generation_config.top_p
1345
+ remove_indices[:, 0] = False # Keep at least the top probability
1346
+
1347
+ # Convert sorted indices mask back to original indices mask
1348
+ mask = torch.zeros_like(probs, dtype=torch.bool)
1349
+ for i in range(probs.shape[0]):
1350
+ mask[i, sorted_indices[i, remove_indices[i]]] = True
1351
+
1352
+ probs = torch.where(mask, torch.zeros_like(probs), probs)
1353
+
1354
+ # Apply min_p
1355
+ if generation_config.min_p:
1356
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1357
+ min_p_threshold = generation_config.min_p * max_probs
1358
+ probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1359
+
1360
+ # Renormalize probabilities
1361
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1362
+
1363
+ # Sample from the distribution
1364
+ return torch.multinomial(probs, num_samples=1)
1365
+ else:
1366
+ return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1367
+
1368
+ @torch.no_grad()
1369
+ def generate_speculative(
1370
+ self,
1371
+ input_ids: torch.Tensor,
1372
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1373
+ tokenizer=None,
1374
+ streamer=None,
1375
+ continuous_compute=False, # warm-start state / continuous CoT
1376
+ init_scale: float = 1.0,
1377
+ cache_lookup_strategy: str = "full",
1378
+ draft_steps=32,
1379
+ lookahead_for_draft=8,
1380
+ verification_threshold=1,
1381
+ num_steps: int = 32, # intercept deliberately
1382
+ **model_kwargs,
1383
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1384
+ """Batched speculative decoding with per-sequence acceptance."""
1385
+ assert lookahead_for_draft > 0
1386
+ pad_id = 65509
1387
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1388
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1389
+ )
1390
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1391
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1392
+
1393
+ # Set up continuous compute if enabled
1394
+ if continuous_compute:
1395
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1396
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1397
 
1398
+ tokens_generated = 0
1399
+ # Prefill cache with full num_steps
1400
+ if model_kwargs["past_key_values"].get_seq_length() == 0:
1401
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1402
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1403
+ next_token = self._sample_next_token(
1404
+ outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
1405
+ )
1406
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1407
+ tokens_generated += 1
1408
+ if streamer:
1409
+ streamer.put(next_token.cpu())
1410
+ model_kwargs["cache_position"] = torch.as_tensor(
1411
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1412
+ )
1413
+ if continuous_compute:
1414
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1415
 
1416
+ # Generate tokens
1417
+ batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
1418
+ accepted_tokens = []
1419
+
1420
+ while tokens_generated < max_new_tokens:
1421
+ ### Run the next draft ####
1422
+ drafted_inputs = input_ids.clone()
1423
+ current_len = input_ids.shape[1]
1424
+
1425
+ for _ in range(lookahead_for_draft):
1426
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1427
+ outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
1428
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
1429
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1430
+ drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
1431
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
1432
+ if continuous_compute:
1433
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1434
+
1435
+ model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)
1436
+
1437
+ ## Verify drafted tokens ###
1438
+ model_kwargs["cache_position"] = torch.arange(
1439
+ current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
1440
+ )
1441
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1442
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1443
+ verified_next_token_preds = outputs.logits.argmax(dim=-1)
1444
 
1445
+ if verification_threshold >= 1:
1446
+ mismatched_tokens = (
1447
+ verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
1448
+ )
1449
+ not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
1450
+ else:
1451
+ verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
1452
+ verified_probs = F.softmax(verified_logits, dim=-1)
1453
+ drafted_token_probs = torch.gather(
1454
+ verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
1455
+ ).squeeze(-1)
1456
+ max_probs = verified_probs.max(dim=-1)[0]
1457
+ verification_passed = drafted_token_probs >= verification_threshold * max_probs
1458
+ not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)
1459
+
1460
+ # Per-sequence acceptance handling
1461
+ acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)
1462
+
1463
+ # Build next_tokens for each sequence
1464
+ next_tokens_batch = []
1465
+ for i in range(batch_size):
1466
+ seq_acceptance = acceptance_lengths[i].item()
1467
+ if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
1468
+ # Accept up to mismatch + sample final token
1469
+ accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1470
+ final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
1471
+ final_token = self._sample_next_token(final_token_logits, generation_config)
1472
+ seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
1473
+ else:
1474
+ # Accept all drafted tokens
1475
+ seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1476
+ next_tokens_batch.append(seq_tokens)
1477
+
1478
+ # Clean up KV cache - only if any sequence had mismatches
1479
+ if not_all_matched.any():
1480
+ min_first_mismatch = first_mismatch.min().item()
1481
+ model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)
1482
+
1483
+ # Concatenate accepted tokens to input_ids
1484
+ batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
1485
+ max_len = max(batch_accepted_counts)
1486
+ padded_tokens = [
1487
+ torch.cat(
1488
+ [
1489
+ tokens,
1490
+ pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
1491
+ ],
1492
+ dim=-1,
1493
+ )
1494
+ if tokens.shape[1] < max_len
1495
+ else tokens
1496
+ for tokens in next_tokens_batch
1497
+ ]
1498
+ next_tokens = torch.cat(padded_tokens, dim=0)
1499
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
1500
 
1501
+ accepted_tokens.append(batch_accepted_counts)
1502
+ tokens_generated += max(batch_accepted_counts)
 
1503
 
1504
+ if streamer:
1505
+ streamer.put(next_tokens_batch[0].cpu())
 
 
1506
 
1507
+ model_kwargs["cache_position"] = torch.as_tensor(
1508
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1509
+ )
1510
+ if continuous_compute:
1511
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1512
+
1513
+ # Check stopping conditions
1514
+ if stop_tokens is not None:
1515
+ for i in range(batch_size):
1516
+ if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
1517
+ unfinished_sequences[i] = 0
1518
+ if "stopping_criteria" in model_kwargs:
1519
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1520
+ if unfinished_sequences.max() == 0:
1521
+ break
1522
 
1523
+ if streamer:
1524
+ streamer.end()
 
 
 
1525
 
1526
+ # Cut off extraneous parts of the sequence per batch element
1527
+ if stop_tokens is not None:
1528
+ for i in range(batch_size):
1529
+ stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
1530
+ if len(stop_positions) > 0:
1531
+ input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
1532
+ # Trim tensor to remove columns that are pad_id across all sequences
1533
+ non_pad_mask = input_ids != pad_id
1534
+ last_real_token = non_pad_mask.any(dim=0).nonzero()
1535
+ if len(last_real_token) > 0:
1536
+ input_ids = input_ids[:, : last_real_token[-1].item() + 1]
1537
 
1538
+ if generation_config.return_dict_in_generate:
1539
+ return GenerateDecoderOnlyOutput(
1540
+ sequences=input_ids, # type: ignore
1541
+ scores=accepted_tokens, # type: ignore
1542
+ logits=None,
1543
+ attentions=None,
1544
+ hidden_states=None,
1545
+ past_key_values=model_kwargs.get("past_key_values"),
1546
+ )
1547
+ return input_ids
1548
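# Usage sketch for the self-speculative decoding above: the model drafts lookahead_for_draft
# tokens with a cheap recurrence budget (draft_steps) and verifies them with the full budget
# (num_steps). Call it directly; the generate() dispatcher does not route here on its own.
# `model` and `inputs` are assumed to exist.
out = model.generate_speculative(
    inputs,
    draft_steps=8,             # recurrence steps used while drafting
    lookahead_for_draft=8,     # tokens drafted per verification round
    num_steps=32,              # recurrence steps used for verification
    verification_threshold=1,  # exact-match acceptance; values < 1 accept near-argmax drafts
    max_new_tokens=128,
)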
 
1549
  def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1550
  probs = torch.softmax(logits.float(), dim=-1)
 
1600
  # Old?
1601
  AutoConfig.register("huginn_raven", RavenConfig)
1602
  AutoModel.register(RavenConfig, RavenForCausalLM)
1603
+ AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
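# With the registrations above, the checkpoint can be loaded through the Auto classes
# (the hub id below is a placeholder; substitute the actual repository name):
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<org>/<huginn-raven-checkpoint>", trust_remote_code=True)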