Update raven_modeling_minimal.py
raven_modeling_minimal.py  (+26, -10)
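The change wires RavenForCausalLM into the standard Hugging Face generation stack: the class now also inherits from GenerationMixin and gains a thin generate() dispatcher, a return_attn flag is threaded explicitly through the prelude, core and coda block calls, return_stats defaults to False, the _init_weights notice is suppressed during meta-device initialization, and prepare_inputs_for_generation swaps any injected cache for HuginnDynamicCache.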
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
 from transformers.cache_utils import Cache, DynamicCache

 ###################### Huggingface Glue code I ##################################################################
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.utils import ModelOutput
 from transformers.generation.utils import GenerateDecoderOnlyOutput

@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False

     def _init_weights(self, module):
-        print("Random Initialization not implemented.")
+        if not torch.rand((1,)).is_meta:
+            print("Random Initialization not implemented.")


 @dataclass
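The new guard only skips the notice when weights are being materialized on PyTorch's meta device, which is what the low-memory loading path in transformers uses. A minimal sketch of the check itself, assuming PyTorch 2.x (where torch.device works as a context manager):

```python
import torch

# On a real device the guard is False, so the "not implemented" notice is printed.
print(torch.rand((1,)).is_meta)       # False

# During meta-device initialization, freshly created tensors are meta tensors
# and the notice is suppressed.
with torch.device("meta"):
    print(torch.rand((1,)).is_meta)   # True
```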
@@ -309,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
         return x, attn_map


-class RavenForCausalLM(RavenPreTrainedModel):
+class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
     def __init__(
         self,
         config: RavenConfig,
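A hedged aside on why the explicit mixin matters: in recent transformers releases the generation entry points live on GenerationMixin rather than being inherited automatically through PreTrainedModel, so the updated base-class list is what keeps .generate() available. A quick sanity check, assuming RavenForCausalLM has been imported from this module:

```python
from transformers import GenerationMixin

# After this change the model participates in the standard HF generation stack.
assert issubclass(RavenForCausalLM, GenerationMixin)
assert hasattr(RavenForCausalLM, "generate")
```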
@@ -367,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map

@@ -409,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()

         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)

@@ -451,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -468,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
         for step in range(num_steps_no_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )

         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps

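The variable names above suggest the usual two-phase recurrence: roll the core block forward for num_steps_no_grad steps without building a graph, then keep autograd on for the final num_steps_with_grad steps. A generic sketch of that pattern (an illustration only, not the model's code):

```python
import torch

def truncated_unroll(step_fn, x, n_no_grad, n_grad):
    # Illustrative helper: run most recurrence steps without a gradient tape,
    # then re-enable autograd for the last few steps (truncated backprop).
    with torch.no_grad():
        for _ in range(n_no_grad):
            x = step_fn(x)
    for _ in range(n_grad):
        x = step_fn(x)
    return x
```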
@@ -487,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps

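With return_attn threaded end to end, attention maps can be requested from the outside by flipping "return_attention" in output_details. A hedged usage sketch: tokenizer and model loading are elided, the input_ids keyword is assumed, and only the dict keys visible in the hunks above are listed.

```python
inputs = tokenizer("The recurrent core refines a latent state.", return_tensors="pt")
outputs = model(
    input_ids=inputs["input_ids"],
    output_details={
        "return_latents": True,
        "return_attention": True,   # forwarded as return_attn to prelude, core and coda blocks
        "return_head": False,
        "return_stats": False,
    },
)
```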
@@ -623,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values)
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
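A small sketch of what the guard above protects against, assuming HuginnDynamicCache is the cache class defined earlier in this file: generate() may hand the model a plain DynamicCache, and it is only safe to replace it while it is still empty.

```python
from transformers.cache_utils import DynamicCache

injected = DynamicCache()                  # what generate() injects by default
if type(injected) != HuginnDynamicCache:   # anything other than the custom cache
    assert injected.get_seq_length() == 0  # never discard already-cached key/value states
    injected = HuginnDynamicCache()        # swap in the model's own cache implementation
```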
@@ -643,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
             model_inputs[key] = value
         return model_inputs

+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
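Finally, a hedged usage sketch for the new generate() dispatcher. Model and tokenizer loading are elided; the trigger kwarg names come from the diff, while the concrete values passed (for example the criterion string) are illustrative assumptions rather than documented options.

```python
inputs = tokenizer("Why is the sky blue?", return_tensors="pt")

# No trigger kwargs present: the call falls through to the stock
# GenerationMixin.generate() inherited via the new base class.
out = model.generate(**inputs, max_new_tokens=64)

# Any of continuous_compute, latent_dampening, criterion, exit_threshold or
# cache_kwargs reroutes the call to generate_with_adaptive_compute().
out = model.generate(**inputs, max_new_tokens=64, criterion="entropy-diff", exit_threshold=5e-4)
```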