"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class MixinLayer(nn.Module):
"""
    MixinLayer wraps a GatedCrossAttentionBlock together with an existing decoder layer.
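
    Example (illustrative sketch; ``decoder_layer`` stands for an existing
    HuggingFace decoder block, and the conditioning calls mirror the methods below):

        layer = MixinLayer(GatedCrossAttentionBlock(config), decoder_layer)
        layer.condition_vis_x(vis_x)               # pre-computed visual features
        layer.condition_use_cached_media(False)
        outputs = layer(hidden_states, attention_mask=attention_mask)
        # `outputs` takes whatever form the wrapped decoder layer returns.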
"""
def __init__(
self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
):
super().__init__()
self.gated_cross_attn_layer = gated_cross_attn_layer
self.decoder_layer = decoder_layer
        # Conditioned inputs; set through the condition_* helpers below before the forward pass.
        self.vis_x = None
        self.use_cached_media = None
if self.gated_cross_attn_layer is not None:
self.gated_cross_attn_layer._use_gradient_checkpointing = (
gradient_checkpointing
)
self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Conditioning via simple attribute setters borrows from the flamingo-mini
    # implementation (https://github.com/dhansmair/flamingo-mini/).
    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

    def condition_media(self, media, text_position_ids):
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer.media = media
            self.gated_cross_attn_layer.cross_attn.text_position_ids = text_position_ids

def condition_use_cached_media(self, use_cached_media):
self.use_cached_media = use_cached_media

    def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
):
# Cross attention
        if self.gated_cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")
hidden_states = self.gated_cross_attn_layer(
hidden_states,
self.vis_x,
use_cached_media=self.use_cached_media,
)
# Normal decoder layer
hidden_states = self.decoder_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs
)
return hidden_states


class LMMixin(nn.Module):
"""
Mixin to add cross-attention layers to a language model.
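
    Expected usage (illustrative sketch; the attribute path "model.layers" and the
    surrounding model class are assumptions, not fixed by this module):

        lang_model.set_decoder_layers_attr_name("model.layers")
        lang_model.init_mixin(config, gradient_checkpointing=False)
        outputs = lang_model(input_ids=input_ids, attention_mask=attention_mask)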
"""

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

def init_mixin(
self,
config,
gradient_checkpointing,
):
"""
        Initialize the mixin by wrapping each decoder layer in a MixinLayer and attaching a gated cross-attention block every `mixin_every_n_layers` layers.
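
        For example, with 12 decoder layers and ``config.mixin_every_n_layers = 4``,
        the 4th, 8th and 12th layers (1-indexed) receive a GatedCrossAttentionBlock,
        while the remaining layers are paired with ``None``.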
"""
self.old_decoder_blocks = self._get_decoder_layers()
mixin_every_n_layers = config.mixin_every_n_layers
self.gated_cross_attn_layers = nn.ModuleList(
[
GatedCrossAttentionBlock(config)
if (layer_idx + 1) % mixin_every_n_layers == 0
else None
for layer_idx, _ in enumerate(self._get_decoder_layers())
]
)
self.init_mixin_layers(gradient_checkpointing)
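        # The MixinLayers built above keep their own references to both module lists,
        # so the temporary attributes can be cleared afterwards (presumably to avoid
        # registering the same sub-modules twice on this model).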
self.old_decoder_blocks = None
self.gated_cross_attn_layers = None
self.initialized_mixin = True
self._use_cached_vision_x = False

    def init_mixin_layers(self, gradient_checkpointing):
        """
        Re-initialize the MixinLayers.
        Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks.
        """
self._set_decoder_layers(
nn.ModuleList(
[
MixinLayer(
gated_cross_attn_layer, decoder_layer, gradient_checkpointing
)
for gated_cross_attn_layer, decoder_layer in zip(
self.gated_cross_attn_layers, self.old_decoder_blocks
)
]
)
)

    def forward(self, position_ids=None, **kwargs):
        if not getattr(self, "initialized_mixin", False):
raise ValueError(
"Flamingo layers are not initialized. Please call `init_flamingo` first."
)
kwargs["position_ids"] = position_ids
return super().forward(**kwargs) # Call the other parent's forward method

    def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
if not is_encoder_decoder:
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
else:
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
dim=-1,
)
# To support RoPE-DHR's position_ids calculation method
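        # During incremental decoding only the last position id is advanced by one,
        # so position_ids for the next step covers just the newly generated token.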
        if model_kwargs["past_key_values"] and "position_ids" in model_kwargs:
            new_pos_ids = model_kwargs["position_ids"][:, -1:] + 1
            model_kwargs["position_ids"] = new_pos_ids
return model_kwargs

    def is_conditioned(self) -> bool:
"""Check whether all decoder layers are already conditioned."""
return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
for layer in self._get_decoder_layers():
layer.condition_vis_x(None)
layer.condition_use_cached_media(False)
layer.condition_media(None, None)