AbstractPhil committed
Commit 77bb325 · verified · 1 Parent(s): 2dff7b1

Create model.py

Files changed (1)
  1. model.py +201 -0
model.py ADDED
@@ -0,0 +1,201 @@

# ─────────────────────────────────────────────────────────────
# ░ Dual Shunt Adapter Configuration File
# - Adapter ID: 002
# - Supports flan-t5-base ↔ clip-vit-large-patch14
# - Projection stack + dropout + normalization
# - Surgeworthy training baseline adapter
# ─────────────────────────────────────────────────────────────

ADAPTER_CONFIG = {
    # Model Integration IDs
    "adapter_id": "002",
    "name": "TwoStreamShuntAdapter",

    # Backbone Model Dimensions
    "t5": {
        "model": "google/flan-t5-base",
        "hidden_size": 768,
    },
    "clip": {
        "model": "openai/clip-vit-large-patch14",
        "hidden_size": 768,
    },

    # Adapter Dimensions
    "bottleneck": 384,
    "heads": 12,

    # Guidance Parameters
    "tau_init": 0.1,
    "max_guidance": 10.0,

    # Projection Configuration
    "proj_layers": 2,       # number of Linear+GELU layers in each projection stack
    "layer_norm": True,     # apply LayerNorm before the stack
    "dropout": 0.1,
    "use_dropout": True,
    "use_proj_stack": True,

    # Runtime Safeguards
    "assert_input_dims": True,

    # Routing Logic
    "routing": {
        "type": "cross_attention",
        "enable_causal_mask": False,
        "bidirectional": True
    },

    # Version & Metadata
    "version": "v0.3.1",
    "description": "Upgraded FLAN-T5 ↔ CLIP-L token shunt with projection stack, dropout, and field-consistent architecture."
}

import torch
import torch.nn as nn
import torch.nn.functional as F

# ─── Residual Pocket Block ───────────────────────────────────
class BottleneckResBlock(nn.Module):
    """Pre-norm residual block: LayerNorm → Conv1d over the sequence axis → two-layer MLP."""

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = x.transpose(1, 2)             # (B, T, D) → (B, D, T) for Conv1d
        x = self.conv(x).transpose(1, 2)  # back to (B, T, D)
        return residual + self.proj(x)

# ─── Two Stream Shunt Adapter ──────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """Bidirectional cross-attention shunt between FLAN-T5 and CLIP-L token streams."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.t5_dim = config["t5"]["hidden_size"]
        self.clip_dim = config["clip"]["hidden_size"]
        self.bneck = config["bottleneck"]
        self.heads = config["heads"]
        self.tau_init = config["tau_init"]
        self.max_guidance = config["max_guidance"]

        use_norm = config.get("layer_norm", True)
        use_do = config.get("use_dropout", True)
        do_p = config.get("dropout", 0.1)
        proj_depth = config.get("proj_layers", 2)

        def build_projection(input_dim, output_dim):
            layers = []
            last_dim = input_dim
            if use_norm:
                layers.append(nn.LayerNorm(last_dim))
            for i in range(proj_depth):
                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
                layers.append(nn.Linear(last_dim, next_dim))
                layers.append(nn.GELU())
                if use_do:
                    layers.append(nn.Dropout(do_p))
                last_dim = next_dim
            layers.append(nn.Linear(last_dim, output_dim))
            return nn.Sequential(*layers)

        # Projections
        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
        self.proj_clip = build_projection(self.clip_dim, self.bneck)

        # Attention
        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))

        # Residual Pocket
        self.pocket_blocks = nn.Sequential(
            BottleneckResBlock(self.bneck, dropout=do_p),
            BottleneckResBlock(self.bneck, dropout=do_p)
        )

        # Fuse
        self.fuse = nn.Sequential(
            nn.LayerNorm(2 * self.bneck),
            nn.Linear(2 * self.bneck, self.bneck * 2),
            nn.GELU(),
            nn.Linear(self.bneck * 2, self.bneck)
        )

        # Output Projections
        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
        self.delta_proj = build_projection(self.bneck, self.clip_dim)
        self.logsig_proj = build_projection(self.bneck, self.clip_dim)

        # Note: Tanh followed by Sigmoid compresses the gate into roughly
        # (0.27, 0.73) rather than the full (0, 1) range.
        self.gate_proj = nn.Sequential(
            nn.LayerNorm(self.bneck),
            nn.Linear(self.bneck, self.bneck),
            nn.GELU(),
            nn.Linear(self.bneck, 1),
            nn.Tanh(),
            nn.Sigmoid()
        )

        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(self.bneck),
            nn.Linear(self.bneck, 1),
            nn.Sigmoid()
        )

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        if self.config.get("assert_input_dims", True):
            assert t5_seq.size(-1) == self.t5_dim, "t5_seq hidden size mismatch"
            assert clip_seq.size(-1) == self.clip_dim, "clip_seq hidden size mismatch"

        # Project both streams into the shared bottleneck space.
        t5_b = self.proj_t5(t5_seq)
        clip_b = self.proj_clip(clip_seq)

        # Bidirectional cross-attention between the two streams.
        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

        pocket = self.pocket_blocks(t2c)

        # Pool the pocket stream and fuse it with the CLIP-aligned stream.
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))

        gate = self.gate_proj(h)  # computed once, reused for delta and the return value
        anchor = self.anchor_proj(h)
        delta = self.delta_proj(h) * gate
        log_sigma = self.logsig_proj(h)

        g_tok = self.guidance_proj(h).squeeze(-1)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
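
# Shape sketch (illustrative only; the tensors and batch/sequence sizes below
# are hypothetical, not part of the module):
#
#   adapter  = TwoStreamShuntAdapter(ADAPTER_CONFIG)
#   t5_seq   = torch.randn(2, 77, 768)   # (batch, tokens, flan-t5-base hidden)
#   clip_seq = torch.randn(2, 77, 768)   # (batch, tokens, clip-vit-large-patch14 hidden)
#   anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)
#
#   anchor / delta / log_sigma → (2, 77, 768); attn_t2c / attn_c2t → (2, 12, 77, 77)
#   tau → (12, 1, 1); g_pred → (2, 1); gate → (2, 77, 1)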

from safetensors.torch import save_file, load_file

def save_safetensors(adapter: nn.Module, path: str, metadata: dict = None):
    """
    Save the current adapter state in safetensors format.

    All tensors are moved to CPU and saved as float32 for compatibility.
    Optional metadata may be embedded (e.g., version, prompt_mode); safetensors
    requires metadata values to be strings, so entries are stringified here.
    """
    state = {k: v.float().cpu() for k, v in adapter.state_dict().items()}
    meta = {str(k): str(v) for k, v in (metadata or {}).items()}
    save_file(state, path, metadata=meta)
    print(f"✅ Model saved to {path}")

def load_safetensors(adapter: nn.Module, path: str, map_location="cpu"):
    """
    Load a safetensors checkpoint into the adapter.

    Uses strict key matching. Tensors are loaded onto the specified device.
    """
    state = load_file(path, device=map_location)
    adapter.load_state_dict(state, strict=True)
    print(f"✅ Model loaded from {path}")
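
A minimal end-to-end sketch of the helpers above, assuming the file is importable as `model`; the checkpoint filename and metadata values are illustrative, not part of the commit:

import torch
from model import ADAPTER_CONFIG, TwoStreamShuntAdapter, save_safetensors, load_safetensors

# Build the adapter from the shipped config and write a checkpoint.
adapter = TwoStreamShuntAdapter(ADAPTER_CONFIG)
save_safetensors(adapter, "dual_shunt_002.safetensors",
                 metadata={"adapter_id": "002", "version": "v0.3.1"})

# Reload into a fresh instance and verify the round trip.
restored = TwoStreamShuntAdapter(ADAPTER_CONFIG)
load_safetensors(restored, "dual_shunt_002.safetensors", map_location="cpu")
for key, value in adapter.state_dict().items():
    assert torch.allclose(value.float(), restored.state_dict()[key].float()), key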