import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import timm


class ViTWrapper(nn.Module):
    """Wrapper to make ViT compatible with feature extraction for ImageTagger"""

    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model
        # Kept for API compatibility with feature-extractor-style backbones;
        # only the final block's output is used.
        self.out_indices = (-1,)

        # timm stores patch_size as an (h, w) tuple; square patches are assumed
        self.patch_size = vit_model.patch_embed.patch_size[0]
        self.embed_dim = vit_model.embed_dim

    def forward(self, x):
        B = x.size(0)

        # Patchify the input image into a sequence of patch embeddings
        x = self.vit.patch_embed(x)

        # Prepend the learnable CLS token
        cls_tok = self.vit.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tok, x), dim=1)

        # Add positional embeddings (sliced in case the sequence is shorter)
        if self.vit.pos_embed is not None:
            x = x + self.vit.pos_embed[:, : x.size(1), :]

        x = self.vit.pos_drop(x)

        # Run the transformer encoder blocks
        for blk in self.vit.blocks:
            x = blk(x)

        x = self.vit.norm(x)

        # Split the CLS token from the patch tokens
        cls_final = x[:, 0]
        patch_tokens = x[:, 1:]

        # Reshape patch tokens into a 2D feature map (assumes a square patch grid)
        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)
        patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)

        return patch_features, cls_final

    def set_grad_checkpointing(self, enable=True):
        """Enable gradient checkpointing if supported"""
        if hasattr(self.vit, 'set_grad_checkpointing'):
            self.vit.set_grad_checkpointing(enable)
            return True
        return False
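
    # Illustrative usage sketch (not part of the training pipeline). Shapes
    # assume vit_base_patch16_224 at 224x224 input, i.e. a 14x14 patch grid
    # with embed_dim 768; any other timm ViT changes these numbers.
    #
    #   vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
    #   wrapper = ViTWrapper(vit)
    #   patch_map, cls = wrapper(torch.randn(2, 3, 224, 224))
    #   patch_map.shape  # -> torch.Size([2, 768, 14, 14])
    #   cls.shape        # -> torch.Size([2, 768])

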
class ImageTagger(nn.Module):
    """
    ImageTagger with Vision Transformer backbone
    """

    def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
                 num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
                 use_gradient_checkpointing=False, img_size=224):
        super().__init__()

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.model_name = model_name
        self.img_size = img_size
        self.pretrained = pretrained

        self._flags = {
            'debug': False,
            'model_stats': True
        }

        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.total_tags = total_tags

        print(f"🏗️ Building ImageTagger with ViT backbone and {total_tags} tags")
        print(f" Backbone: {model_name}")
        print(f" Image size: {img_size}x{img_size}")
        print(f" Tag context size: {tag_context_size}")
        print(f" Gradient checkpointing: {use_gradient_checkpointing}")
        print(f" 🎯 Custom embeddings, PyTorch native attention, no ground truth inclusion")

        print("📦 Loading Vision Transformer backbone...")
        self._load_vit_backbone()

        self._determine_backbone_dimensions()

        self.embedding_dim = self.backbone.embed_dim

        # Tag embeddings are learned from scratch rather than taken from a
        # pretrained text encoder such as CLIP
        print("🎯 Using custom tag embeddings (no CLIP)")
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)

        # The initial head reuses the tag embedding matrix as its weights;
        # only a per-tag bias is learned separately
        print("🔗 Using shared weights between initial head and tag embeddings")
        self.tag_bias = nn.Parameter(torch.zeros(total_tags))

        # ViT patch tokens already match embedding_dim, so no projection is needed
        self.image_token_proj = nn.Identity()

        # Candidate tag embeddings attend over the image patch tokens
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        self._init_weights()

        if self.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()

        print("✅ ImageTagger with ViT initialized!")
        self._print_parameter_count()

    def _load_vit_backbone(self):
        """Load Vision Transformer model from timm"""
        print(f" Loading from timm: {self.model_name}")

        vit_model = timm.create_model(
            self.model_name,
            pretrained=self.pretrained,
            img_size=self.img_size,
            num_classes=0
        )

        self.backbone = ViTWrapper(vit_model)

        print(" ✅ ViT loaded successfully")
        print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
        print(f" Embed dim: {self.backbone.embed_dim}")

    def _determine_backbone_dimensions(self):
        """Determine backbone output dimensions with a dummy forward pass"""
        print(" 🔍 Determining backbone dimensions...")

        # The module is still on CPU at construction time, so a plain
        # no_grad pass is sufficient here
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
            backbone_features, _ = self.backbone(dummy_input)

        self.backbone_dim = backbone_features.shape[1]
        self.feature_map_size = backbone_features.shape[2]

        print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
        print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}")

    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency"""
        print("🔄 Enabling gradient checkpointing...")

        if self.backbone.set_grad_checkpointing(True):
            print(" ✅ ViT backbone checkpointing enabled")
        else:
            print(" ⚠️ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")

    def _checkpoint_backbone(self, x):
        """Wrapper for backbone with gradient checkpointing"""
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
        else:
            return self.backbone(x)

    def _checkpoint_image_proj(self, x):
        """Wrapper for image projection with gradient checkpointing"""
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
        else:
            return self.image_token_proj(x)

    def _checkpoint_cross_attention(self, query, key, value):
        """Wrapper for cross attention with gradient checkpointing"""
        def _attention_forward(q, k, v):
            attended_features, _ = self.cross_attention(query=q, key=k, value=v)
            return self.cross_norm(attended_features)

        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
        else:
            return _attention_forward(query, key, value)
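
    # Shape sketch for the cross-attention step (illustrative only; assumes
    # batch_first=True and the default sizes):
    #   query: tag embeddings   (B, tag_context_size, embed_dim), e.g. (B, 256, 768)
    #   key/value: image tokens (B, num_patches, embed_dim),      e.g. (B, 196, 768)
    #   output: attended tags   (B, tag_context_size, embed_dim)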

    def _checkpoint_candidate_selection(self, initial_logits):
        """Wrapper for candidate selection with gradient checkpointing"""
        def _candidate_forward(logits):
            return self._get_candidate_tags(logits)

        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
        else:
            return _candidate_forward(initial_logits)

    def _checkpoint_final_scoring(self, attended_features, candidate_indices):
        """Wrapper for final scoring with gradient checkpointing"""
        def _scoring_forward(features, indices):
            # Score each candidate as a dot product between its attended
            # feature and its (shared) tag embedding
            emb = self.tag_embedding(indices)
            return (features * emb).sum(dim=-1)

        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
        else:
            return _scoring_forward(attended_features, candidate_indices)
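
    # Scoring shape sketch (illustrative only): features (B, K, C) multiplied
    # elementwise with emb (B, K, C) and summed over C yields one logit per
    # candidate, (B, K), with K = tag_context_size. Equivalently,
    #   logits[b, j] = dot(features[b, j], tag_embedding.weight[indices[b, j]])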

    def _init_weights(self):
        """Initialize weights for new modules"""
        def _init_layer(layer):
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Embedding):
                nn.init.normal_(layer.weight, mean=0, std=0.02)

        self.image_token_proj.apply(_init_layer)

        nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
        nn.init.zeros_(self.tag_bias)

    def _print_parameter_count(self):
        """Print parameter statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        backbone_params = sum(p.numel() for p in self.backbone.parameters())

        print("📊 Parameter Statistics:")
        print(f" Total parameters: {total_params/1e6:.1f}M")
        print(f" Trainable parameters: {trainable_params/1e6:.1f}M")
        print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
        print(f" Backbone parameters: {backbone_params/1e6:.1f}M")

        if self.use_gradient_checkpointing:
            print(" 🔄 Gradient checkpointing enabled for memory efficiency")

    @property
    def debug(self):
        return self._flags['debug']

    @property
    def model_stats(self):
        return self._flags['model_stats']

    def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
        """Select candidate tags - no ground truth inclusion.

        target_tags and hard_negatives are accepted for interface
        compatibility but deliberately ignored: candidates come purely from
        the model's own initial predictions.
        """
        # Take the top-k tags by predicted probability
        _, top_indices = torch.topk(
            torch.sigmoid(initial_logits),
            k=min(self.tag_context_size, self.total_tags),
            dim=1, largest=True, sorted=True
        )

        return top_indices
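
    # Candidate-selection sketch (illustrative only): with total_tags=4 and
    # tag_context_size=2, logits [[2.0, -1.0, 0.5, 3.0]] select indices
    # [[3, 0]], the two most confident tags, independent of any labels.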

    def _analyze_predictions(self, predictions, tag_indices):
        """Analyze prediction patterns"""
        if not self.model_stats:
            return {}

        # Skip under torch.compile: Python-float stats would force a graph break
        if torch._dynamo.is_compiling():
            return {}

        with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
            probs = torch.sigmoid(predictions)
            relevant_probs = torch.gather(probs, 1, tag_indices)

            return {
                'prediction_confidence': relevant_probs.mean().item(),
                'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
                'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
                'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
            }

    def forward(self, x, targets=None, hard_negatives=None):
        """
        Forward pass with ViT backbone, CLS token support and gradient checkpointing.
        All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
        FP32 otherwise). Anything that must mix dtypes is cast to match.
        """
        model_stats = {}

        # 1. Backbone: spatial patch map plus CLS token
        patch_map, cls_token = self._checkpoint_backbone(x)

        # 2. Flatten the patch map into a token sequence (B, N, C)
        image_tokens_4d = self._checkpoint_image_proj(patch_map)
        image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)

        # 3. Global image descriptor: average of mean-pooled patches and CLS token
        global_features = 0.5 * (image_tokens.mean(dim=1) + cls_token)

        compute_dtype = global_features.dtype

        # 4. Initial logits over all tags via the shared embedding matrix
        tag_weights = self.tag_embedding.weight.to(compute_dtype)
        tag_bias = self.tag_bias.to(compute_dtype)
        initial_logits = global_features @ tag_weights.t() + tag_bias

        # 5. Select the top-k candidate tags from the initial predictions
        candidate_indices = self._checkpoint_candidate_selection(initial_logits)

        # 6. Let candidate tag embeddings attend over the image tokens
        tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)
        attended_features = self._checkpoint_cross_attention(
            tag_embeddings, image_tokens, image_tokens
        )

        # 7. Re-score the candidates from their attended features
        candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)

        if candidate_logits.dtype != initial_logits.dtype:
            candidate_logits = candidate_logits.to(initial_logits.dtype)

        # 8. Write the refined candidate logits back over the initial ones:
        #    refined_logits[b, candidate_indices[b, j]] = candidate_logits[b, j]
        refined_logits = initial_logits.clone()
        refined_logits.scatter_(1, candidate_indices, candidate_logits)

        if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
            model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_logits,
                                                                                candidate_indices)
            model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_logits,
                                                                                candidate_indices)

        return {
            'initial_predictions': initial_logits,
            'refined_predictions': refined_logits,
            'selected_candidates': candidate_indices,
            'model_stats': model_stats
        }
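
    # End-to-end usage sketch (illustrative only; 'dataset' is merely stored
    # on the module, so None works for a smoke test):
    #
    #   model = ImageTagger(total_tags=10000, dataset=None, pretrained=False)
    #   out = model(torch.randn(2, 3, 224, 224))
    #   out['initial_predictions'].shape   # -> (2, 10000)
    #   out['refined_predictions'].shape   # -> (2, 10000)
    #   out['selected_candidates'].shape   # -> (2, 256)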

    def predict