|
|
|
""" |
|
Vision Transformer Essence Generator for Tag Collector Game |
|
Based on "What do Vision Transformers Learn? A Visual Exploration" |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision.transforms.functional import to_pil_image |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import re |
|
import math |
|
import json |
|
import timm |
|
import streamlit as st |
|
from tqdm import tqdm |
|
from scipy.ndimage import gaussian_filter |
|
from functools import wraps, lru_cache |
|
from safetensors.torch import load_file |
|
import time |
|
import tag_storage |
|
|
|
from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON |
|
from tag_categories import TAG_CATEGORIES |
|
|
|
# Process-wide runtime flags, applied once when this module is imported.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid — confirm
# it is really wanted for regular game sessions.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
|
|
|
|
|
# Quality tiers for a generated essence, listed from lowest to highest
# threshold. get_quality_level() walks this mapping in reverse insertion
# order, so the ascending order of the entries matters.
ESSENCE_QUALITY_LEVELS = {
    "ZAYIN": dict(
        threshold=0.0,
        color="#1CFC00",
        description="Basic representation with minimal details.",
    ),
    "TETH": dict(
        threshold=3.0,
        color="#389DDF",
        description="Clear representation with recognizable features.",
    ),
    "HE": dict(
        threshold=5.0,
        color="#FEF900",
        description="Refined representation with distinctive elements.",
    ),
    "WAW": dict(
        threshold=10.0,
        color="#7930F1",
        description="Advanced representation with precise details.",
    ),
    "ALEPH": dict(
        threshold=12.0,
        color="#FF0000",
        description="Perfect representation with extraordinary precision.",
    ),
}
|
|
|
|
|
# Price of generating one essence, keyed by the tag's rarity name.
# get_essence_cost() falls back to 100 for rarities missing from this table.
ESSENCE_COSTS = {
    "Special": 0,
    "Canard": 100,
    "Urban Myth": 125,
    "Urban Legend": 150,
    "Urban Plague": 200,
    "Urban Nightmare": 250,
    "Star of the City": 300,
    "Impuritas Civitas": 400,
}
|
|
|
|
|
# Baseline generator settings. Callers take a .copy() of this dict and
# overlay user-customised values on top.
DEFAULT_ESSENCE_SETTINGS = {
    "iterations": 256,             # optimisation steps
    "lr": 0.05,                    # learning rate handed to the generator
    "ensemble_k": 8,               # augmented copies per step
    "neighbor_count": 8,           # similar / dissimilar tags used in the objective
    "image_size": 512,             # side length of the square image, in pixels
    "layer_emphasis": "balanced",  # one of: balanced, early, mid, late
    "tv_weight": 1e-3,             # total-variation penalty weight
}
|
|
|
def initialize_essence_settings():
    """Populate ``st.session_state.essence_custom_settings`` on first use.

    Nothing happens when the key is already present. Otherwise the persisted
    essence state is read through ``tag_storage.load_essence_state()``; every
    key known to ``DEFAULT_ESSENCE_SETTINGS`` is taken from the saved settings
    when available, and everything else keeps its default. A saved
    ``layer_emphasis`` outside the supported set is ignored.
    """
    if 'essence_custom_settings' in st.session_state:
        return

    saved_state = tag_storage.load_essence_state()
    if not (saved_state and 'essence_custom_settings' in saved_state):
        st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
        return

    persisted = saved_state['essence_custom_settings']
    merged = DEFAULT_ESSENCE_SETTINGS.copy()
    for name in DEFAULT_ESSENCE_SETTINGS:
        if name not in persisted:
            continue
        # Unknown emphasis names (e.g. from an older save format) fall back
        # to the default instead of being carried over.
        if name == 'layer_emphasis' and persisted[name] not in ['balanced', 'early', 'mid', 'late']:
            continue
        merged[name] = persisted[name]

    st.session_state.essence_custom_settings = merged
|
|
|
def initialize_manual_tags():
    """Populate ``st.session_state.manual_tags`` on first use.

    Prefers the ``manual_tags`` entry of the persisted essence state; when
    there is none, a built-in starter entry is used instead.
    """
    if 'manual_tags' in st.session_state:
        return

    saved_state = tag_storage.load_essence_state()
    if saved_state and 'manual_tags' in saved_state:
        st.session_state.manual_tags = saved_state['manual_tags']
        return

    # Built-in default: a single free ("Special") tag.
    st.session_state.manual_tags = {
        "hatsune_miku": {
            "rarity": "Special",
            "description": "Popular virtual singer with long teal twin-tails",
        },
    }
|
|
|
def timeout(seconds, fallback_value=None):
    """Decorator factory that warns when a call takes longer than ``seconds``.

    Despite its name this never interrupts the wrapped function: the call
    always runs to completion and its result is always returned. The only
    effect is a warning printed afterwards when the elapsed time exceeded
    the budget.

    Args:
        seconds: Expected maximum duration, in seconds.
        fallback_value: Accepted for interface compatibility; not used.

    Returns:
        A decorator that wraps a function with the timing check.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so the measurement cannot go
            # negative or jump when the system clock is adjusted.
            started = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - started

            if elapsed > seconds:
                print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")

            return result
        return wrapper
    return decorator
|
|
|
class TaggerTorch(nn.Module):
    """A timm backbone followed by a single linear tag classifier.

    When ``normalize`` is true the module standardises its input with the
    per-channel ``mean`` / ``std`` buffers registered below before running
    the backbone.
    """

    def __init__(self, backbone_name="vit_base_patch16_384", img_size=512, num_tags=70527, normalize=True):
        super().__init__()

        # num_classes=0 removes timm's own classifier; `head` replaces it.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
        )
        self.head = nn.Linear(self.backbone.num_features, num_tags)

        self.normalize = normalize
        if normalize:
            channel_stats = (
                ("mean", [0.485, 0.456, 0.406]),
                ("std", [0.229, 0.224, 0.225]),
            )
            for buffer_name, values in channel_stats:
                self.register_buffer(buffer_name, torch.tensor(values).view(1, 3, 1, 1))

    def forward(self, x):
        """Return raw tag logits for the image batch ``x``."""
        inputs = (x - self.mean) / self.std if self.normalize else x
        feats = self.backbone.forward_features(inputs)
        # A 3-D feature tensor is reduced to its first entry along dim 1
        # (presumably the class token — confirm against the backbone).
        pooled = feats[:, 0, :] if feats.ndim == 3 else feats
        return self.head(pooled)
|
|
|
def _remap_backbone_keys(sd): |
|
out = {} |
|
for k, v in sd.items(): |
|
if k.startswith("module."): k = k[7:] |
|
|
|
|
|
if k.startswith("backbone.vit."): |
|
k = "backbone." + k[len("backbone.vit."):] |
|
elif k.startswith("vit."): |
|
k = "backbone." + k[len("vit."):] |
|
elif k.startswith(("pos_embed","patch_embed.","blocks.","norm.","cls_token")): |
|
k = "backbone." + k |
|
|
|
out[k] = v |
|
return out |
|
|
|
def _get_logits_from_output(out): |
|
if isinstance(out, dict): |
|
return out.get("refined_predictions") or out.get("initial_predictions") |
|
return out |
|
|
|
def build_torch_model_from_safetensors(ckpt_path, num_tags, backbone="vit_base_patch16_384", img_size=512):
    """Build a ``TaggerTorch`` and fill it from a safetensors checkpoint.

    Checkpoint keys are renamed with ``_remap_backbone_keys`` and loaded
    non-strictly. ``tag_embedding.weight`` / ``tag_bias`` are taken out of
    the state dict and copied into ``head.weight`` / ``head.bias`` when
    their shapes match.

    Args:
        ckpt_path: Path to the ``.safetensors`` file.
        num_tags: Output size of the classifier head.
        backbone: timm model name for the backbone.
        img_size: Input resolution passed to the backbone.

    Returns:
        The populated ``TaggerTorch`` instance.
    """
    model = TaggerTorch(backbone_name=backbone, img_size=img_size, num_tags=num_tags, normalize=True)
    sd = _remap_backbone_keys(load_file(ckpt_path))

    # The remap above already stripped any "module." prefix, so only the
    # bare names can still be present.
    te_w = sd.pop("tag_embedding.weight", None)
    te_b = sd.pop("tag_bias", None)

    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("[load] missing:", missing[:20], "…")
    print("[load] unexpected:", unexpected[:20], "…")

    with torch.no_grad():
        if te_w is not None:
            if te_w.shape == model.head.weight.shape:
                model.head.weight.copy_(te_w)
                print("[load] copied tag_embedding.weight → head.weight")
            else:
                # Skipping silently would leave the head with whatever
                # load_state_dict gave it, so make the mismatch visible.
                print(
                    "[load] WARNING: tag_embedding.weight shape "
                    f"{tuple(te_w.shape)} != head.weight shape "
                    f"{tuple(model.head.weight.shape)}; not copied"
                )
        if te_b is not None and model.head.bias is not None:
            if te_b.shape == model.head.bias.shape:
                model.head.bias.copy_(te_b)
                print("[load] copied tag_bias → head.bias")
            else:
                print(
                    "[load] WARNING: tag_bias shape "
                    f"{tuple(te_b.shape)} != head.bias shape "
                    f"{tuple(model.head.bias.shape)}; not copied"
                )

    return model
|
|
|
@torch.no_grad() |
|
def _get_classifier_matrix(model): |
|
|
|
if hasattr(model, "tag_embedding"): |
|
return model.tag_embedding.weight.detach() |
|
if hasattr(model, "head") and hasattr(model.head, "weight"): |
|
return model.head.weight.detach() |
|
raise AttributeError("Model has neither tag_embedding nor head.weight") |
|
|
|
@torch.no_grad()
def neighbor_sets_from_embedding(model, class_idx, k_pos=8, k_neg=8):
    """Find the classes most and least similar to ``class_idx``.

    Similarity is the cosine between rows of the classifier matrix returned
    by ``_get_classifier_matrix``. The query class itself is excluded from
    both result sets.

    Args:
        model: Model handed to ``_get_classifier_matrix``.
        class_idx: Row index of the query class.
        k_pos: Number of most-similar classes wanted.
        k_neg: Number of least-similar classes wanted.

    Returns:
        ``(pos_idx, pos_w, neg_idx, neg_w)`` as plain lists. ``pos_w`` holds
        the similarities clamped to ``[0, 1]``; ``neg_w`` holds the absolute
        similarities clamped to ``[0, 1]``.
    """
    W = _get_classifier_matrix(model)
    Wn = F.normalize(W, dim=1)
    sims = (Wn[class_idx:class_idx + 1] @ Wn.T).squeeze(0)

    # At most every class except the query itself can be returned.
    max_k = max(sims.numel() - 1, 0)

    # Mask the query separately for each selection. A single large negative
    # sentinel would keep it out of the positives but make it the top
    # "most dissimilar" entry, so the objective would penalise its own target.
    pos_scores = sims.clone()
    pos_scores[class_idx] = float("-inf")
    pos_vals, pos_idx = torch.topk(pos_scores, k=min(k_pos, max_k))

    neg_scores = sims.clone()
    neg_scores[class_idx] = float("inf")
    neg_vals, neg_idx = torch.topk(neg_scores, k=min(k_neg, max_k), largest=False)

    pos_w = torch.clamp(pos_vals, 0.0, 1.0).tolist()
    neg_w = torch.clamp(neg_vals.abs(), 0.0, 1.0).tolist()
    return pos_idx.tolist(), pos_w, neg_idx.tolist(), neg_w
|
|
|
def weighted_class_objective(logits, main_idx,
                             plus_idxs=(), plus_w=None, alpha=0.25,
                             minus_idxs=(), minus_w=None, beta=0.15):
    """Score a batch of logits for one target class and its neighbours.

    The result is the batch-mean logit of ``main_idx``, plus ``alpha`` times
    the weighted mean logit of ``plus_idxs``, minus ``beta`` times the
    weighted mean logit of ``minus_idxs``. Weights are normalised to sum to
    one; missing weights mean uniform weighting.
    """
    def neighbour_term(idxs, weights):
        raw = torch.tensor(weights or [1.0] * len(idxs), device=logits.device, dtype=logits.dtype)
        normed = raw / (raw.sum() + 1e-8)  # epsilon guards an all-zero weight list
        return (logits[:, idxs] * normed).sum(dim=1).mean()

    objective = logits[:, main_idx].mean()
    if plus_idxs:
        objective = objective + alpha * neighbour_term(plus_idxs, plus_w)
    if minus_idxs:
        objective = objective - beta * neighbour_term(minus_idxs, minus_w)
    return objective
|
|
|
def idx_to_name(idx, dataset=None):
    """Translate a tag index into its tag name.

    ``dataset.idx_to_tag`` (integer keys) is consulted when available;
    otherwise the cached tagger metadata is used, whose ``idx_to_tag`` table
    is looked up by the index as a string. Unknown indices give
    ``"Tag <idx>"``.
    """
    placeholder = f"Tag {idx}"

    if dataset is not None and hasattr(dataset, "idx_to_tag"):
        return dataset.idx_to_tag.get(int(idx), placeholder)

    dataset_info = _load_tagger_metadata_cached().get("dataset_info", {})
    table = dataset_info.get("tag_mapping", {}).get("idx_to_tag", {})
    return table.get(str(int(idx)), placeholder)
|
|
|
|
|
class ViTLayerHook:
    """Forward hook that keeps the latest output of a single module."""

    def __init__(self, layer, layer_name):
        self.layer, self.layer_name = layer, layer_name
        # Overwritten by hook_fn on every forward pass through `layer`.
        self.features = None
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember ``output``."""
        self.features = output

    def close(self):
        """Detach the hook from its module."""
        self.hook.remove()
|
|
|
class ViTFeatureAnalyzer:
    """Locate MLP activation modules in a model and pick some for hooking.

    ``layer_info`` maps a dotted module path to a dict with the keys
    ``'type'``, ``'module'`` and ``'block_idx'``.
    """

    def __init__(self, model):
        self.model = model
        self.layer_info = self._analyze_architecture()

    def _analyze_architecture(self):
        """Walk the module tree and collect feed-forward activation layers."""
        layer_info = {}

        def record(key, kind, module, owner_name):
            layer_info[key] = {
                'type': kind,
                'module': module,
                'block_idx': self._extract_block_number(owner_name),
            }

        def traverse_modules(module, prefix=''):
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name

                if 'mlp' in full_name.lower() and (hasattr(child, 'act') or 'act' in dict(child.named_children())):
                    # An MLP container: register its activation sub-module.
                    act = getattr(child, 'act', None)
                    if act is not None:
                        record(full_name + ".act", 'mlp_activation', act, full_name)
                    else:
                        for n2, c2 in child.named_children():
                            if 'act' in n2.lower():
                                record(full_name + f".{n2}", 'mlp_activation', c2, full_name)
                elif 'gelu' in str(type(child)).lower() or 'activation' in name.lower():
                    # A bare activation counts only inside an mlp/ffn parent.
                    if 'mlp' in prefix.lower() or 'ffn' in prefix.lower():
                        record(full_name, 'activation', child, full_name)

                # Always descend, whatever was matched at this level.
                traverse_modules(child, full_name)

        traverse_modules(self.model)
        return layer_info

    def _extract_block_number(self, layer_name):
        """Return the first ``.<digits>.`` component of ``layer_name``, or 0."""
        numbers = re.findall(r'\.(\d+)\.', layer_name)
        if numbers:
            return int(numbers[0])
        return 0

    def get_visualization_layers(self, layer_emphasis="balanced"):
        """Choose layer names to hook for the requested emphasis.

        Only entries whose name contains both ``mlp`` and ``act`` are
        candidates. ``"early"``, ``"mid"`` and ``"late"`` select the matching
        third of the blocks; any other value selects the first, middle and
        last block (or every block when there are three or fewer).

        Returns:
            Layer names ordered by block index; an empty list when there is
            no candidate layer.
        """
        sorted_layers = sorted(
            [(n, info) for n, info in self.layer_info.items() if 'mlp' in n.lower() and 'act' in n.lower()],
            key=lambda x: x[1]['block_idx']
        )

        # Also covers a layer_info that holds only non-"mlp" entries, which
        # would otherwise reach max() with an empty sequence.
        if not sorted_layers:
            print("Warning: No suitable ViT layers found for visualization")
            return []

        total_blocks = max(info['block_idx'] for _, info in sorted_layers) + 1

        if layer_emphasis == "early":
            target_blocks = range(0, max(1, total_blocks // 3))
        elif layer_emphasis == "mid":
            start = total_blocks // 3
            end = 2 * total_blocks // 3
            target_blocks = range(start, max(start + 1, end))
        elif layer_emphasis == "late":
            target_blocks = range(2 * total_blocks // 3, total_blocks)
        elif total_blocks <= 3:
            target_blocks = range(total_blocks)
        else:
            target_blocks = (0, total_blocks // 2, total_blocks - 1)

        wanted = set(target_blocks)
        return [layer_name for layer_name, info in sorted_layers if info['block_idx'] in wanted]
|
|
|
def _jitter_reflect_crop(x, pad=16): |
|
b, c, h, w = x.shape |
|
padded = F.pad(x, (pad, pad, pad, pad), mode='reflect').contiguous() |
|
off_h = torch.randint(0, 2 * pad + 1, (b,), device=x.device) |
|
off_w = torch.randint(0, 2 * pad + 1, (b,), device=x.device) |
|
crops = [] |
|
for i in range(b): |
|
hs, ws = int(off_h[i]), int(off_w[i]) |
|
crop = padded[i:i+1, :, hs:hs+h, ws:ws+w].contiguous() |
|
crops.append(crop) |
|
return torch.cat(crops, 0).contiguous() |
|
|
|
def _channel_affine(x): |
|
|
|
b, c, _, _ = x.shape |
|
mu = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0) |
|
log_sigma = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0) |
|
sigma = torch.exp(log_sigma) |
|
return (x * sigma + mu) |
|
|
|
def _add_gaussian_noise(x, std=0.15): |
|
return (x + torch.randn_like(x) * std) |
|
|
|
def _augment_once(x, noise_std=0.15):
    """Run one augmentation pass: jittered crop, channel affine, then noise."""
    shifted = _jitter_reflect_crop(x)
    recolored = _channel_affine(shifted)
    return _add_gaussian_noise(recolored, std=noise_std)
|
|
|
def _augment_batch(x, K=8, noise_std=0.15):
    """Concatenate ``K`` independent augmentations of ``x`` along dim 0."""
    variants = [_augment_once(x, noise_std=noise_std) for _ in range(K)]
    return torch.cat(variants, dim=0).contiguous()
|
|
|
class ViTEssenceGenerator:
    """
    ViT Essence Generator based on the methodology from
    'What do Vision Transformers Learn? A Visual Exploration'

    An image is optimised by gradient ascent so that the model's logit for
    one tag (plus a weighted set of neighbouring tags) and the hooked MLP
    activations become large, under a total-variation penalty.
    """

    def __init__(
        self,
        model,
        tag_to_name=None,
        iterations=500,
        learning_rate=0.05,
        layer_emphasis="balanced",
        ensemble_K=8,
        tv_weight=1e-3
    ):
        """Initialize the ViT Essence Generator.

        Args:
            model: Model to visualise; switched to eval mode and moved to the
                selected device.
            tag_to_name: Optional mapping from tag index to display name.
            iterations: Number of optimisation steps.
            learning_rate: Initial Adam learning rate.
            layer_emphasis: Passed to ``ViTFeatureAnalyzer.get_visualization_layers``.
            ensemble_K: Number of augmented copies evaluated per step.
            tv_weight: Weight of the total-variation penalty.
        """
        self.model = model
        self.tag_to_name = tag_to_name
        self.iterations = iterations
        self.lr = learning_rate
        self.layer_emphasis = layer_emphasis
        self.ensemble_K = ensemble_K
        self.tv_weight = tv_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval().to(self.device)

        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        # A model whose `normalize` attribute is True normalises its own
        # input, so external normalisation is applied only otherwise.
        self.expect_imagenet = not (hasattr(self.model, "normalize") and getattr(self.model, "normalize") is True)

        self.analyzer = ViTFeatureAnalyzer(self.model)

        self.hooks = {}
        self.selected_layers = []

        print(f"ViT Essence Generator initialized on {self.device}")

    def _preprocess(self, x):
        """Normalise ``x`` with the stored mean/std unless the model does it itself."""
        return (x - self.imagenet_mean) / self.imagenet_std if self.expect_imagenet else x

    @staticmethod
    def _extract_logits(out):
        """Return the logits from a raw model output.

        Dict outputs are read from ``"refined_predictions"``, falling back to
        ``"initial_predictions"``; anything else is returned as is.
        """
        if isinstance(out, dict):
            refined = out.get("refined_predictions")
            return refined if refined is not None else out.get("initial_predictions")
        return out

    def setup_hooks(self, tag_idx):
        """Install forward hooks on the selected layers.

        ``tag_idx`` is accepted for interface compatibility and not used.

        Returns:
            Dict mapping each successfully hooked layer name to its weight;
            weights rise linearly from 0.3 (first layer) to 1.0 (last).
        """
        self.close_hooks()

        names = self.analyzer.get_visualization_layers(self.layer_emphasis)
        if not names:
            print("Warning: No suitable layers found for visualization")
            return {}

        print(f"Setting up hooks on {len(names)} ViT layer(s)")
        layer_weights = {}
        for i, layer_name in enumerate(names):
            try:
                layer_info = self.analyzer.layer_info[layer_name]
                layer_module = layer_info['module']
                self.hooks[layer_name] = ViTLayerHook(layer_module, layer_name)

                weight = 0.3 + 0.7 * (i / max(1, len(names) - 1))
                layer_weights[layer_name] = weight
                print(f"  - {layer_name} (block {layer_info['block_idx']}, weight: {weight:.2f})")
            except Exception as e:
                print(f"Failed to setup hook for {layer_name}: {e}")

        self.selected_layers = names
        return layer_weights

    def close_hooks(self):
        """Clean up hooks to avoid memory leaks."""
        for hook in self.hooks.values():
            hook.close()
        self.hooks.clear()

    def _fourier_init(self, size=224, decay=1.5):
        """Create a random ``(1, 3, size, size)`` image with a 1/f**decay spectrum.

        Each channel is rescaled so its values lie in ``[0, 1]``.
        """
        H = W = size

        spec = torch.randn(1, 3, H, W // 2 + 1, dtype=torch.complex64, device=self.device)
        fy = torch.fft.fftfreq(H, device=self.device).abs().view(H, 1)
        fx = torch.fft.rfftfreq(W, device=self.device).abs().view(1, W // 2 + 1)
        # Clamp at the lowest representable non-zero frequency so the DC term
        # is scaled like its neighbours. A tiny clamp makes it many orders of
        # magnitude larger and costs float32 precision in the rest of the image.
        radius = (fy ** 2 + fx ** 2).sqrt().clamp_(min=1.0 / max(H, W))
        spec = spec * (1.0 / (radius ** decay))
        img = torch.fft.irfft2(spec, s=(H, W))

        img = (img - img.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0])
        img = img / (img.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] + 1e-8)
        return img

    def create_optimizable_image(self, size=224, use_fourier=True):
        """Return a fresh ``(1, 3, size, size)`` leaf tensor that requires grad."""
        if use_fourier:
            with torch.no_grad():
                image = self._fourier_init(size)
            image = image.to(self.device)
        else:
            image = torch.rand(1, 3, size, size, device=self.device)
        image = image.detach().contiguous().requires_grad_(True)
        return image

    def total_variation_loss(self, image):
        """Mean absolute difference between neighbouring pixels, averaged over the batch."""
        diff_y = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
        diff_x = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
        tv_per_sample = diff_y.mean(dim=(1, 2, 3)) + diff_x.mean(dim=(1, 2, 3))
        return tv_per_sample.mean()

    def get_feature_activations(self, layer_weights, topk_channels=None):
        """Weighted sum of the activations captured by the installed hooks.

        Each hook's features are summed over dim 1. If ``topk_channels`` is
        set and the result has more entries along dim 1 than that, only the
        top ``topk_channels`` per sample are averaged; otherwise all entries
        are. Hooks without captured features are skipped, and layers without
        an entry in ``layer_weights`` use a weight of 0.5.

        Returns:
            ``0.0`` when nothing was captured, else a scalar tensor.
        """
        total = 0.0
        for name, hook in self.hooks.items():
            feats = hook.features
            if feats is None:
                continue
            w = layer_weights.get(name, 0.5)

            agg = feats.sum(dim=1)
            if topk_channels is not None and topk_channels > 0 and agg.shape[1] > topk_channels:
                vals, _ = torch.topk(agg, k=topk_channels, dim=1)
                act = vals.mean()
            else:
                act = agg.mean()
            total = total + w * act
        return total

    def generate_essence(self, tag_idx, neighbor_count=8, image_size=224, return_score=True, progress_callback=None):
        """Generate an essence visualization for a ViT model.

        Args:
            tag_idx: Index of the tag to visualise.
            neighbor_count: Number of similar and of dissimilar tags added to
                the objective.
            image_size: Side length of the generated square image.
            return_score: When true return ``(image, score)``, else only the
                image.
            progress_callback: Optional callable invoked periodically with the
                keyword arguments ``scale_idx``, ``scale_count``, ``iter_idx``,
                ``iter_count`` and ``score``.

        Returns:
            A PIL image, or ``(PIL image, best score)`` when ``return_score``
            is true. The score is the class term plus 0.5 times the feature
            term, without the total-variation penalty.
        """
        tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
        print(f"Generating ViT essence for '{tag_name}' (index: {tag_idx})...")

        layer_weights = self.setup_hooks(tag_idx)
        try:
            if not self.hooks and not hasattr(self.model, 'head'):
                print("Warning: No hooks set up and no classifier head found")
                fallback = self._create_fallback_image(image_size)
                return (fallback, 0.0) if return_score else fallback

            image = self.create_optimizable_image(image_size)

            optimizer = torch.optim.Adam([image], lr=self.lr)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.iterations, eta_min=self.lr * 0.01
            )

            best_score = -float('inf')
            best_image = None

            print(f"Starting optimization for {self.iterations} iterations...")

            pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(
                self.model, tag_idx, k_pos=neighbor_count, k_neg=neighbor_count
            )

            progress_every = max(1, self.iterations // 20)
            log_every = max(1, self.iterations // 10)

            for i in range(self.iterations):
                optimizer.zero_grad()

                # Drop activations left over from the previous step.
                for hook in self.hooks.values():
                    hook.features = None

                aug_batch = _augment_batch(image, K=self.ensemble_K, noise_std=0.15)
                logits = self._extract_logits(self.model(self._preprocess(aug_batch)))

                cls_term = weighted_class_objective(
                    logits, main_idx=tag_idx,
                    plus_idxs=pos_idx, plus_w=pos_w, alpha=0.25,
                    minus_idxs=neg_idx, minus_w=neg_w, beta=0.15
                )

                # The hooked activations have to land in feat_term, otherwise
                # the hooks and layer_emphasis never influence the loss.
                feat_term = 0.0
                if self.hooks:
                    feat_term = self.get_feature_activations(layer_weights, topk_channels=64)

                Ltv = self.total_variation_loss(aug_batch)
                total_loss = -(cls_term + 0.5 * feat_term) + self.tv_weight * Ltv

                total_loss.backward()

                grad = image.grad
                if grad is None:
                    print("WARN: no/invalid grad reaching the image; check hook & loss wiring.")
                grad_finite = grad is None or bool(torch.isfinite(grad).all())
                loss_finite = bool(torch.isfinite(total_loss.detach()))

                if not (grad_finite and loss_finite):
                    # Decide BEFORE stepping: one optimizer step with NaN/inf
                    # gradients would poison the image for good.
                    if not grad_finite:
                        print("WARN: no/invalid grad reaching the image; check hook & loss wiring.")
                    if not loss_finite:
                        print("WARN: non-finite loss; resetting image step")
                    optimizer.zero_grad(set_to_none=True)
                    with torch.no_grad():
                        # Nudge the image away from the problematic point.
                        image.add_(0.05 * torch.randn_like(image)).clamp_(0.0, 1.0)
                    continue

                with torch.no_grad():
                    score_tensor = -(total_loss - self.tv_weight * Ltv)
                    current_score = float(score_tensor.item())

                # Snapshot the image that was actually scored, i.e. before
                # this step's update is applied.
                if current_score > best_score:
                    best_score = current_score
                    best_image = image.detach().clone()

                torch.nn.utils.clip_grad_norm_([image], max_norm=3.0)
                optimizer.step()
                scheduler.step()

                with torch.no_grad():
                    image.clamp_(0.0, 1.0)

                if progress_callback and i % progress_every == 0:
                    progress_callback(
                        scale_idx=0,
                        scale_count=1,
                        iter_idx=i,
                        iter_count=self.iterations,
                        score=current_score
                    )

                if i % log_every == 0:
                    print(f"Iteration {i}/{self.iterations}: Score = {current_score:.4f}")

            final_image = best_image if best_image is not None else image.detach()
            final_image = torch.clamp(final_image, 0, 1)
            pil_img = to_pil_image(final_image[0].cpu())
        finally:
            # Remove the hooks even when the optimisation raised, so they do
            # not stay attached to the shared model.
            self.close_hooks()

        print(f"ViT essence generation complete for '{tag_name}'. Final score: {best_score:.4f}")

        if return_score:
            return pil_img, best_score
        return pil_img

    def _create_fallback_image(self, size):
        """Create a fallback image when generation fails."""
        image = torch.randn(1, 3, size, size) * 0.5 + 0.5
        image = torch.clamp(image, 0, 1)
        return to_pil_image(image[0])
|
|
|
|
|
def get_quality_level(score):
    """Return the name of the highest quality tier whose threshold ``score`` meets.

    Tiers are checked in reverse insertion order of
    ``ESSENCE_QUALITY_LEVELS``; ``"ZAYIN"`` is returned when none matches.
    """
    for level_name, spec in list(ESSENCE_QUALITY_LEVELS.items())[::-1]:
        if score >= spec["threshold"]:
            return level_name
    return "ZAYIN"
|
|
|
def get_essence_cost(rarity):
    """Return the generation cost for ``rarity``; unknown rarities cost 100."""
    fallback_cost = 100
    return ESSENCE_COSTS.get(rarity, fallback_cost)
|
|
|
def save_essence_to_game_folder(image, tag, score, quality_level):
    """Save the generated essence image to a persistent game folder.

    The file is written to
    ``<parent of this file's directory>/game_data/essences/<quality_level>/``
    under the name ``<tag>_<score>_<timestamp>.png``.

    Args:
        image: Object with a ``save(path)`` method (a PIL image in practice).
        tag: Tag name or index; converted to text for the file name.
        score: Numeric score, formatted with two decimals.
        quality_level: Name of the quality tier, used as the sub-folder name.

    Returns:
        The full path of the written file.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    quality_folder = os.path.join(base_dir, "game_data", "essences", quality_level)
    # makedirs creates every missing intermediate directory in one call.
    os.makedirs(quality_folder, exist_ok=True)

    # Replace path separators, whitespace and the characters that Windows
    # forbids in file names (tags such as "re:zero" or ":3" contain them).
    # str() also lets a numeric tag index through.
    safe_tag = re.sub(r'[\\/:*?"<>|\s]', '_', str(tag))
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(quality_folder, filename)

    image.save(filepath)

    print(f"Saved ViT essence to: {filepath}")
    return filepath
|
|
|
def load_tagger_metadata():
    """Load the camie-tagger-v2-metadata.json file from parent directory.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist or
        cannot be read or parsed (the reason is printed).
    """
    try:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        metadata_path = os.path.join(os.path.dirname(module_dir), "camie-tagger-v2-metadata.json")

        if not os.path.exists(metadata_path):
            print(f"Metadata file not found at: {metadata_path}")
            return None

        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        print(f"Loaded tagger metadata from: {metadata_path}")
        return metadata
    except Exception as e:
        # Best effort: callers treat None as "no metadata available".
        print(f"Error loading tagger metadata: {e}")
        return None
|
|
|
@lru_cache(maxsize=1)
def _load_tagger_metadata_cached():
    """Cached ``load_tagger_metadata()``; an empty dict stands in for a missing or empty result."""
    loaded = load_tagger_metadata()
    if loaded:
        return loaded
    return {}
|
|
|
def resolve_tag_index(tag, dataset=None):
    """Robustly resolve tag -> index using dataset, session metadata, then camie-tagger-v2-metadata.json.

    Spelling variants are tried in a fixed priority order: the stripped tag,
    the tag with spaces replaced by underscores, then the lower-cased forms
    of both. Within each source the first variant that matches wins, so an
    exact match always beats a case-folded one.

    Args:
        tag: Tag name, or anything ``int()`` accepts (returned as ``int``).
        dataset: Optional object exposing a ``tag_to_idx`` mapping.

    Returns:
        The tag index as ``int``, or ``None`` when no source knows the tag.
    """
    if not isinstance(tag, str):
        return int(tag)

    stripped = tag.strip()
    underscored = stripped.replace(" ", "_")
    # dict.fromkeys removes duplicates but keeps the priority order; a plain
    # set would make the winner depend on hash ordering.
    cands = list(dict.fromkeys(
        [stripped, underscored, stripped.lower(), underscored.lower()]
    ))

    def lookup(mapping):
        for c in cands:
            if c in mapping:
                return int(mapping[c])
        return None

    # 1) The dataset object, when it exposes a mapping.
    if dataset is not None and hasattr(dataset, "tag_to_idx"):
        found = lookup(dataset.tag_to_idx)
        if found is not None:
            return found

    # 2) Metadata kept in the Streamlit session.
    sm = getattr(st.session_state, "metadata", {}) or {}
    m = sm.get("tag_to_idx", {}) if isinstance(sm, dict) else {}
    found = lookup(m)
    if found is not None:
        return found

    # 3) The metadata JSON file shipped next to the model.
    meta = _load_tagger_metadata_cached()
    mjson = (meta.get("dataset_info", {})
                 .get("tag_mapping", {})
                 .get("tag_to_idx", {})) if isinstance(meta, dict) else {}
    return lookup(mjson)
|
|
|
def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """
    Generate an essence image for a specific tag using the ViT generator

    The tag must be present in the session's discovered or manual tags and
    the player must be able to afford its rarity-based cost. The cost is
    charged only after the image has been generated and written to disk, so
    a failed run costs nothing.

    Args:
        tag: The tag name or index
        model: The ViT model to use (the torch model actually used is taken
            from ``st.session_state.model_torch`` or built from the
            safetensors checkpoint)
        dataset: The dataset containing tag information
        custom_settings: Optional dictionary with custom generation settings

    Returns:
        PIL Image of the generated essence, score, quality level; or
        ``(None, 0, None)`` when generation is refused or fails
    """
    print(f"\n=== Starting ViT essence generation for tag '{tag}' ===")

    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags

    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None

    # At least one flag is true here; discovered tags take precedence.
    tag_records = st.session_state.discovered_tags if is_discovered else st.session_state.manual_tags
    rarity = tag_records[tag].get("rarity", "Canard")

    cost = get_essence_cost(rarity)

    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None

    settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
    print(f"Using settings: {settings}")

    # Placeholders updated while the generation runs.
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()

    try:
        message_container.info(f"Generating ViT essence for '{tag}' with {settings.get('layer_emphasis', 'balanced')} layer emphasis...")

        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            progress = iter_idx / iter_count
            progress_container.progress(progress, f"Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")

            if iter_idx % 50 == 0:
                print(f"Progress: Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")

        if isinstance(tag, str):
            tag_idx = resolve_tag_index(tag, dataset)
            if tag_idx is None:
                st.error(
                    f"Tag '{tag}' index not found in dataset or metadata. "
                    f"Make sure it exists in camie-tagger-v2-metadata.json."
                )
                return None, 0, None
        else:
            tag_idx = int(tag)

        print(f"Resolved tag '{tag}' -> index {tag_idx}")

        tag_to_name = {tag_idx: tag}

        # Reuse the torch model cached in the session; build it on first use.
        torch_model = getattr(st.session_state, "model_torch", None)
        if not isinstance(torch_model, nn.Module):
            current_dir = os.path.dirname(os.path.abspath(__file__))
            parent_dir = os.path.dirname(current_dir)
            ckpt = os.path.join(parent_dir, "camie-tagger-v2.safetensors")
            if not os.path.exists(ckpt):
                st.error(f"Missing safetensors checkpoint at: {ckpt}")
                return None, 0, None

            meta = _load_tagger_metadata_cached()
            num_tags = int(meta.get("dataset_info", {}).get("total_tags", 70527))
            img_size = int(meta.get("model_info", {}).get("img_size", 512))

            torch_model = build_torch_model_from_safetensors(
                ckpt_path=ckpt,
                num_tags=num_tags,
                backbone="vit_base_patch16_384",
                img_size=img_size
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            torch_model = torch_model.to(device).eval()
            st.session_state.model_torch = torch_model

        generator = ViTEssenceGenerator(
            model=torch_model,
            tag_to_name=tag_to_name,
            iterations=settings.get('iterations', 500),
            learning_rate=settings.get('lr', 0.05),
            layer_emphasis=settings.get('layer_emphasis', 'balanced'),
            ensemble_K=settings.get('ensemble_k', 8),
            tv_weight=settings.get('tv_weight', 1e-3)
        )

        image, score = generator.generate_essence(
            tag_idx=tag_idx,
            neighbor_count=settings.get('neighbor_count', 8),
            image_size=settings.get('image_size', 512),
            return_score=True,
            progress_callback=progress_callback
        )

        quality_level = get_quality_level(score)

        # Write the file BEFORE charging: if saving fails, the exception
        # handler below runs and the player keeps the currency.
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)

        st.session_state.enkephalin -= cost
        st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost
        st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1

        preview_container.image(image, caption=f"ViT Essence of '{tag}' - Quality: {quality_level}", width=400)

        progress_container.empty()
        message_container.empty()

        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}

        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            # Copy, so later edits to the live settings dict do not rewrite
            # the record of how this essence was produced.
            "settings": dict(settings),
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        st.success(f"Successfully generated {quality_level} ViT essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== ViT essence generation complete for '{tag}' ===\n")

        tag_storage.save_essence_state(session_state=st.session_state)

        return image, score, quality_level

    except Exception as e:
        # UI boundary: report the failure to the player and to the console.
        st.error(f"Error generating ViT essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None
|
|
|
|
|
def get_model_layers(model):
    """Return the dotted names of all sub-modules of ``model``.

    The root module, whose name is the empty string, is left out.
    """
    return [module_name for module_name, _ in model.named_modules() if module_name]
|
|
|
def get_key_layers(model, max_layers=15):
    """
    Get a curated list of the most relevant layers for visualization.

    Models with more than 30 named layers are grouped by the first two
    components of each layer name. Every group with more than three members
    contributes its first, middle and last member (in natural sort order) to
    the "early", "middle" and "late" lists, which are trimmed when together
    they exceed ``max_layers``. Smaller models are split into thirds by
    position. The last layer whose name looks like a classifier is added
    under "classifier" when one exists.

    Returns:
        Dict with the keys "early", "middle", "late" and, optionally,
        "classifier", each holding a list of layer names.
    """
    all_layers = get_model_layers(model)

    if len(all_layers) > 30:
        # Group layers by their two leading name components, e.g. "blocks.3".
        groups = {}
        for layer in all_layers:
            parts = layer.split(".")
            if len(parts) >= 2:
                groups.setdefault(".".join(parts[:2]), []).append(layer)

        def natural_key(layer_name):
            # Digit runs compare as integers, everything else as text.
            return [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', layer_name)]

        key_layers = {"early": [], "middle": [], "late": []}
        for members in groups.values():
            if len(members) <= 3:
                continue
            members.sort(key=natural_key)
            key_layers["early"].append(members[0])
            key_layers["middle"].append(members[len(members) // 2])
            key_layers["late"].append(members[-1])

        picked = sum(len(names) for names in key_layers.values())
        if picked > max_layers:
            # Budget split: late layers first, then middle, then early.
            late_count = min(len(key_layers["late"]), max_layers // 3)
            remaining = max_layers - late_count
            middle_count = min(len(key_layers["middle"]), remaining // 2)
            early_count = min(len(key_layers["early"]), remaining - middle_count)

            key_layers["early"] = key_layers["early"][:early_count]
            key_layers["middle"] = key_layers["middle"][:middle_count]
            key_layers["late"] = key_layers["late"][:late_count]
    else:
        n = len(all_layers)
        first_cut, second_cut = n // 3, 2 * n // 3
        key_layers = {
            "early": all_layers[:first_cut][:3],
            "middle": all_layers[first_cut:second_cut][:4],
            "late": all_layers[second_cut:][:3],
        }

    classifier_markers = ["classifier", "fc", "linear", "output", "logits", "head"]
    classifier_layers = [
        layer for layer in all_layers
        if any(marker in layer.lower() for marker in classifier_markers)
    ]
    if classifier_layers:
        key_layers["classifier"] = [classifier_layers[-1]]

    return key_layers
|
|
|
def get_suggested_layers(model, layer_type="balanced"):
    """
    Get suggested layers based on the desired feature type.

    Args:
        model: Model forwarded to ``get_key_layers``.
        layer_type: Which feature level to favour:
            - "low" (alias "early"): all early layers plus the first middle layer
            - "mid": all middle layers plus the last early layer
            - "high" (alias "late"): all late layers, the classifier layer and
              the last middle layer
            - anything else ("balanced"): the first layer of every category,
              plus the last middle/late layer when there is more than one

    Returns:
        list of layer names. If the chosen rule selects nothing but key layers
        exist, the last key layer is returned as a single-element list.
    """
    key_layers = get_key_layers(model)

    # Work on copies so the appends below never mutate the lists in key_layers
    early = list(key_layers.get("early", []))
    middle = list(key_layers.get("middle", []))
    late = list(key_layers.get("late", []))
    classifier = list(key_layers.get("classifier", []))

    all_key_layers = [layer for layers in key_layers.values() for layer in layers]

    # The settings UI in this file names the levels "early"/"late"; accept both
    # vocabularies so those choices do not silently fall back to "balanced".
    level = {"early": "low", "late": "high"}.get(layer_type, layer_type)

    if level == "low":
        selected = early
        if middle:
            selected.append(middle[0])
    elif level == "mid":
        selected = middle
        if early:
            selected.append(early[-1])
    elif level == "high":
        selected = late + classifier
        if middle:
            selected.append(middle[-1])
    else:
        selected = []
        for category, layers in (("early", early), ("middle", middle),
                                 ("late", late), ("classifier", classifier)):
            if not layers:
                continue
            selected.append(layers[0])
            # Middle and late also contribute their last layer for more spread
            if category in ("middle", "late") and len(layers) > 1:
                selected.append(layers[-1])

    # Never return an empty suggestion while key layers are available
    if not selected and all_key_layers:
        selected = [all_key_layers[-1]]

    return selected
|
|
|
|
|
def display_essence_generator():
    """
    Display the essence generator interface

    Renders the explanatory header and two tabs: "Generate Essence" (either the
    pending generation for ``st.session_state.selected_tag`` or the tag picker)
    and "My Essences" (the saved essence gallery).
    """
    initialize_essence_settings()

    st.title("🎨 Tag Essence Generator")
    st.write("Generate visual representations of what the AI model recognizes for specific tags.")

    with st.expander("What are Tag Essences & How to Use Them", expanded=True):
        st.markdown("""
        ### 💡 Understanding Tag Essences

        Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!

        **How to use Tag Essences:**
        1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)
        2. **Save the essence image** to your computer
        3. **Upload the essence image** back into the tagger
        4. The tagger will **almost always detect the original tag**
        5. It will often also **detect related rare tags** from the same category

        **Strategic Value:**
        - Character essences can help unlock other tags associated with that character
        - Category essences can help discover rare tags within that category
        - High-quality essences (WAW, ALEPH) have the strongest effect

        **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!
        """)

    # A model stored as None is as unusable as a missing one
    model_available = getattr(st.session_state, 'model', None) is not None
    if not model_available:
        st.warning("Model not available. You can browse your tags but cannot generate essences.")

    tabs = st.tabs(["Generate Essence", "My Essences"])

    with tabs[0]:
        pending_tag = getattr(st.session_state, 'selected_tag', None)

        if pending_tag and not model_available:
            # Generation would crash on st.session_state.model; drop the request
            # (the warning above already explains why) and show the picker.
            st.session_state.selected_tag = None
            pending_tag = None

        if pending_tag:
            tag = pending_tag

            st.subheader(f"Generating Essence for '{tag}'")

            try:
                image, score, quality = generate_essence_for_tag(
                    tag,
                    st.session_state.model,
                    st.session_state.model.dataset,
                    st.session_state.essence_custom_settings
                )
            except Exception:
                # Clear the request so a failing generation is not retried on
                # every rerun, then let the error surface as before.
                st.session_state.selected_tag = None
                raise

            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""
                    💡 **Tag Essence Usage Tips:**
                    1. Look for similar patterns, colors, and elements in real images
                    2. The essence reveals what features the AI model recognizes for this tag
                    3. Use this as inspiration when creating or finding images to get this tag
                    """)
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")

            # The request has been handled; the next run shows the picker again
            st.session_state.selected_tag = None
        else:
            selected_tag = display_essence_generation_interface(model_available)

            # Store the choice and rerun so generation starts on a fresh page
            if selected_tag:
                st.session_state.selected_tag = selected_tag
                st.rerun()

    with tabs[1]:
        display_saved_essences()
|
|
|
def essence_folder_path():
    """
    Get the path to the essence folder, creating it if necessary

    The folder is ``game_data/essences`` under the parent of the directory that
    contains this file.

    Returns:
        str: Path of the (now existing) essence folder.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    essence_folder = os.path.join(base_dir, "game_data", "essences")

    # makedirs creates the intermediate "game_data" directory as well
    os.makedirs(essence_folder, exist_ok=True)

    return essence_folder
|
|
|
def _safe_essence_tag(tag):
    """Return *tag* with path separators and spaces replaced by underscores."""
    return tag.replace('/', '_').replace('\\', '_').replace(' ', '_')


def _split_essence_filename(filename):
    """
    Split an essence filename into ``(tag_part, is_standard)``.

    Standard names follow ``{safe_tag}_{score}_{YYYYmmdd}_{HHMMSS}.ext`` (the
    pattern written by the "Reconnect File" action); they are parsed from the
    right so tags that themselves contain underscores stay intact. Any other
    name falls back to the text before the first underscore.
    """
    stem = os.path.splitext(filename)[0]
    pieces = stem.rsplit('_', 3)
    if len(pieces) == 4 and pieces[0] and pieces[2].isdigit() and pieces[3].isdigit():
        return pieces[0], True
    return stem.split('_')[0], False


def _rarity_span_style(rarity, rarity_color):
    """Return the inline CSS used for the rarity label of a saved essence."""
    if rarity == "Impuritas Civitas":
        return "animation: rainbow-text 4s linear infinite;font-weight:bold;"
    glow = {
        "Star of the City": "0 0 3px gold",
        "Urban Nightmare": "0 0 1px #FF5722",
        "Urban Plague": "0 0 1px #9C27B0",
    }.get(rarity)
    if glow:
        return f"color:{rarity_color};text-shadow:{glow};font-weight:bold;"
    return f"color:{rarity_color};font-weight:bold;"


def _open_essence_folder(folder_path):
    """Ask the OS file manager to open *folder_path* and report the outcome."""
    try:
        if os.name == 'nt':
            os.startfile(folder_path)
        elif os.name == 'posix':
            import subprocess
            import sys
            if 'darwin' in sys.platform:
                subprocess.call(['open', folder_path])
            else:
                subprocess.call(['xdg-open', folder_path])
        st.success(f"Opened folder: {folder_path}")
    except Exception as e:
        st.error(f"Could not open folder: {str(e)}")
        # Show the path so the user can navigate there manually
        st.code(folder_path)


def _render_saved_essence(tag, info, quality, color, essence_dir):
    """Render one essence card (image, labels and actions) in the current column."""
    if "path" in info and os.path.exists(info["path"]):
        rarity = info.get("rarity", "Canard")
        score = info.get("score", 0)
        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
        rarity_style = _rarity_span_style(rarity, rarity_color)

        # Close the file handle once Streamlit has consumed the image
        with Image.open(info["path"]) as image:
            st.image(image, caption=tag, use_container_width=True)

        st.markdown(f"""
        <span style='color:{color};font-weight:bold;'>{quality}</span> |
        <span style='{rarity_style}'>{rarity}</span> |
        Score: {score:.2f}
        """, unsafe_allow_html=True)

        if "discovered_on_disk" in info and info["discovered_on_disk"]:
            st.info("Found on disk (not in session state)")

        if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"):
            _open_essence_folder(os.path.dirname(info["path"]))
        return

    # The tracked file is gone: show what is known and offer to re-point it
    st.warning(f"Image file not found: {info.get('path', 'No path available')}")

    st.markdown(f"""
    <span style='color:{color};font-weight:bold;'>{quality}</span> | {tag}
    """, unsafe_allow_html=True)

    if "rarity" in info and "score" in info:
        if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"):
            safe_tag = _safe_essence_tag(tag)
            score = info.get("score", 0)
            quality_dir = os.path.join(essence_dir, quality)

            os.makedirs(quality_dir, exist_ok=True)

            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
            info["path"] = os.path.join(quality_dir, filename)

            st.info(f"Please save your image to this location: {info['path']}")
            st.session_state.generated_essences[tag] = info
            tag_storage.save_essence_state(session_state=st.session_state)
            st.rerun()


def display_saved_essences():
    """
    Display the user's saved essence images

    Tracked essences come from ``st.session_state.generated_essences``; PNG
    files found in the per-quality folders that no tracked entry points to are
    listed as well. Entries are grouped by quality (highest level first) and
    ordered by score within a group.
    """
    st.subheader("My Generated Essences")

    if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
        st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
        return

    st.markdown("""
    ### How to Use Your Essences

    1. **Click on any essence image** to open it in full size
    2. **Save the image** to your computer (right-click → Save image)
    3. **Go to the Scan Images tab** and upload the saved essence
    4. The tagger will likely detect the original tag and potentially related rare tags!

    Higher quality essences (WAW, ALEPH) generally produce the best results.
    """)

    essence_dir = essence_folder_path()
    tracked = st.session_state.generated_essences

    # Re-point tracked entries whose file is missing to a matching file on disk
    for tag, info in tracked.items():
        if "path" in info and not os.path.exists(info["path"]):
            quality = info.get("quality", "ZAYIN")
            quality_dir = os.path.join(essence_dir, quality)

            if os.path.exists(quality_dir):
                safe_tag = _safe_essence_tag(tag)

                def belongs_to_tag(filename, safe_tag=safe_tag):
                    # Standard names must match the whole tag so that e.g.
                    # "cat" never attaches to a "cat_ears" file; other names
                    # keep the lenient prefix match.
                    name, standard = _split_essence_filename(filename)
                    return name == safe_tag if standard else filename.startswith(safe_tag)

                matching_files = [f for f in os.listdir(quality_dir) if belongs_to_tag(f)]

                if matching_files:
                    matching_files.sort(reverse=True)
                    info["path"] = os.path.join(quality_dir, matching_files[0])
                    print(f"Reconnected essence for {tag} to {info['path']}")

    # Group tracked essences by quality
    essences_by_quality = {}
    for tag, info in tracked.items():
        essences_by_quality.setdefault(info.get("quality", "ZAYIN"), []).append((tag, info))

    # Look for PNG files on disk that no tracked entry points to (best effort)
    untracked_essences = {}
    try:
        tracked_filenames = {
            os.path.basename(tracked_info["path"])
            for tracked_info in tracked.values()
            if "path" in tracked_info
        }
        # Map sanitized names back to the real tag so spelling is preserved
        tag_by_safe_name = {_safe_essence_tag(tracked_tag): tracked_tag for tracked_tag in tracked}

        for quality in ESSENCE_QUALITY_LEVELS:
            quality_dir = os.path.join(essence_dir, quality)
            if not os.path.exists(quality_dir):
                continue

            for filename in os.listdir(quality_dir):
                if not filename.lower().endswith('.png') or '_' not in filename:
                    continue
                if filename in tracked_filenames:
                    continue

                safe_name, _ = _split_essence_filename(filename)
                tag = tag_by_safe_name.get(safe_name) or safe_name.replace('_', ' ')

                untracked_essences.setdefault(quality, []).append((tag, {
                    "path": os.path.join(quality_dir, filename),
                    "quality": quality,
                    "discovered_on_disk": True
                }))
    except Exception as e:
        print(f"Error checking for untracked essences: {e}")

    # Merge untracked files, at most one entry per tag within a quality level
    for quality, essences in untracked_essences.items():
        bucket = essences_by_quality.setdefault(quality, [])
        for tag, info in essences:
            if not any(tracked_tag == tag for tracked_tag, _ in bucket):
                bucket.append((tag, info))

    # Highest quality level first
    for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]:
        if quality not in essences_by_quality:
            continue

        essences = essences_by_quality[quality]
        color = ESSENCE_QUALITY_LEVELS[quality]["color"]

        with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]):
            cols = st.columns(3)
            ranked = sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)
            for i, (tag, info) in enumerate(ranked):
                with cols[i % 3]:
                    try:
                        _render_saved_essence(tag, info, quality, color, essence_dir)
                    except Exception as e:
                        # One broken entry must not take down the whole gallery
                        st.write(f"Error loading {tag}: {str(e)}")

    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        to_remove = [
            tag for tag, info in st.session_state.generated_essences.items()
            if "path" in info and not os.path.exists(info["path"])
        ]

        for tag in to_remove:
            del st.session_state.generated_essences[tag]

        tag_storage.save_essence_state(session_state=st.session_state)

        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
        else:
            st.success("No missing files found")

        st.rerun()
|
|
|
def _clamp_essence_setting(value, low, high):
    """Clamp a saved slider value into the slider's ``[low, high]`` range."""
    return max(low, min(high, value))


def _render_essence_tag_card(info, key_suffix, model_available, styled):
    """
    Render one selectable tag card in the current column.

    Args:
        info: Tag record with "tag", "rarity" and "source" plus the optional
            "library_floor" / "description" details.
        key_suffix: Suffix that keeps the button key unique within the page.
        model_available: When False the generate button is disabled.
        styled: True for the rarity-styled name with a rarity badge, False for
            a plain bold name.

    Returns:
        bool: True when the generate button of this card was clicked.
    """
    tag = info["tag"]
    rarity = info["rarity"]
    source = info["source"]

    rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
    has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

    cost = get_essence_cost(rarity)
    can_afford = st.session_state.enkephalin >= cost

    if styled:
        css_class = {
            "Impuritas Civitas": "impuritas-text",
            "Star of the City": "star-text",
            "Urban Nightmare": "nightmare-text",
            "Urban Plague": "plague-text",
        }.get(rarity)
        if css_class:
            tag_display = f'<span class="{css_class}">{tag}</span>'
        else:
            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'

        st.markdown(
            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
            unsafe_allow_html=True
        )
    else:
        st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")

    if source == "discovered" and info.get("library_floor"):
        st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>',
                    unsafe_allow_html=True)
    elif source == "manual" and info.get("description"):
        st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>',
                    unsafe_allow_html=True)

    button_label = "Generate" if not has_essence else "Regenerate ✓"
    return st.button(button_label, key=f"gen_{tag}_{key_suffix}",
                     disabled=not model_available or not can_afford)


def display_essence_generation_interface(model_available):
    """
    Display the interface for generating new essences

    Shows the generation settings (stored back into
    ``st.session_state.essence_custom_settings`` on every run), the quality
    legend, the current currency balance and the list of available tags in the
    chosen sort order.

    Args:
        model_available: When False every generate button is disabled.

    Returns:
        The tag whose generate button was clicked during this run, else None.
    """
    initialize_manual_tags()

    st.subheader("Generate Tag Essence")
    st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")

    col1, col2 = st.columns(2)

    with col1:
        st.write("Generation Settings:")

        if st.button("Reset to Defaults", help="Clear saved settings and use default values"):
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
            tag_storage.save_essence_state(session_state=st.session_state)
            st.success("Settings reset to defaults!")
            st.rerun()

        saved = st.session_state.essence_custom_settings

        with st.expander("Advanced Settings", expanded=True):
            col_a, col_b = st.columns(2)

            # Saved values are clamped because Streamlit raises when a widget's
            # initial value lies outside its range.
            with col_a:
                st.write("**Core Parameters**")
                iterations = st.slider(
                    "Iterations",
                    min_value=64,
                    max_value=2048,
                    value=_clamp_essence_setting(
                        saved.get("iterations", DEFAULT_ESSENCE_SETTINGS["iterations"]), 64, 2048),
                    step=64,
                    help="More iterations improve quality but take longer"
                )

                lr = st.slider(
                    "Learning Rate",
                    min_value=0.01,
                    max_value=0.2,
                    value=_clamp_essence_setting(saved.get("lr", 0.05), 0.01, 0.2),
                    step=0.01,
                    help="Higher learning rates converge faster but may be less stable"
                )

                ensemble_k = st.slider(
                    "Ensemble Size",
                    min_value=1,
                    max_value=16,
                    value=_clamp_essence_setting(saved.get("ensemble_k", 8), 1, 16),
                    step=1,
                    help="Number of augmented versions per iteration. Higher = more stable but slower"
                )

            with col_b:
                st.write("**Multi-Tag Enhancement**")
                neighbor_count = st.slider(
                    "Neighbor Tags",
                    min_value=0,
                    max_value=16,
                    value=_clamp_essence_setting(saved.get("neighbor_count", 8), 0, 16),
                    step=1,
                    help="Number of similar/dissimilar tags to consider. 0 = only target tag"
                )

                # select_slider requires its value to be one of the options
                tv_options = [1e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2]
                saved_tv = saved.get("tv_weight", 1e-3)
                tv_weight = st.select_slider(
                    "Smoothness",
                    options=tv_options,
                    value=saved_tv if saved_tv in tv_options else 1e-3,
                    format_func=lambda x: f"{x:.0e}",
                    help="Higher values create smoother, less noisy images"
                )

            # Start from the saved choice instead of always resetting to the
            # first option, like the sliders above do.
            emphasis_options = ["balanced", "early", "mid", "late"]
            saved_emphasis = saved.get("layer_emphasis", "balanced")
            layer_emphasis = st.selectbox(
                "Feature Targeting",
                options=emphasis_options,
                index=emphasis_options.index(saved_emphasis) if saved_emphasis in emphasis_options else 0,
                format_func=lambda x: {
                    "balanced": "Balanced (mix of features)",
                    "early": "Early (textures, patterns)",
                    "mid": "Mid (parts, components)",
                    "late": "Late (characters, objects)"
                }.get(x, x),
                help="Controls which model features to emphasize"
            )

        st.session_state.essence_custom_settings = {
            "iterations": iterations,
            "lr": lr,
            "ensemble_k": ensemble_k,
            "neighbor_count": neighbor_count,
            "image_size": 512,
            "layer_emphasis": layer_emphasis,
            "tv_weight": tv_weight
        }

        st.info(f"""
        **Current Settings:**
        - Iterations: {iterations}
        - Learning Rate: {lr}
        - Ensemble Size: {ensemble_k}
        - Neighbor Tags: {neighbor_count}
        - Feature Focus: {layer_emphasis.capitalize()}
        """)

    with col2:
        st.write("Quality Levels:")
        for level, info in ESSENCE_QUALITY_LEVELS.items():
            # The "#RRGGBB" color is split into channels for a translucent background
            st.markdown(f"""
            <div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({int(info['color'][1:3], 16)},{int(info['color'][3:5], 16)},{int(info['color'][5:7], 16)},0.1);border-left:3px solid {info['color']}">
                <span style="color:{info['color']};font-weight:bold;">{level}</span> ({info['threshold']:.0f} Score+): {info['description']}
            </div>
            """, unsafe_allow_html=True)

        st.write("Feature Targeting Explanation:")
        st.markdown("""
        - **Early**: Textures, colors, simple patterns
        - **Mid**: Parts, components, intermediate features
        - **Late**: Characters, objects, high-level concepts
        - **Balanced**: Mix of all feature levels
        """)

    st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}")
    st.divider()

    # CSS for the rarity-styled tag names rendered by the tag cards
    st.markdown("""
    <style>
    @keyframes rainbow-text {
        0% { color: red; }
        14% { color: orange; }
        28% { color: yellow; }
        42% { color: green; }
        57% { color: blue; }
        71% { color: indigo; }
        85% { color: violet; }
        100% { color: red; }
    }

    .impuritas-text {
        font-weight: bold;
        animation: rainbow-text 4s linear infinite;
    }

    @keyframes glow-text {
        0% { text-shadow: 0 0 2px gold; }
        50% { text-shadow: 0 0 6px gold; }
        100% { text-shadow: 0 0 2px gold; }
    }

    .star-text {
        color: #FFEB3B;
        text-shadow: 0 0 3px gold;
        animation: glow-text 2s infinite;
        font-weight: bold;
    }

    @keyframes pulse-text {
        0% { opacity: 0.8; }
        50% { opacity: 1; }
        100% { opacity: 0.8; }
    }

    .nightmare-text {
        color: #FF9800;
        text-shadow: 0 0 1px #FF5722;
        animation: pulse-text 3s infinite;
        font-weight: bold;
    }

    .plague-text {
        color: #9C27B0;
        text-shadow: 0 0 1px #9C27B0;
        font-weight: bold;
    }

    .category-section {
        margin-top: 20px;
        margin-bottom: 30px;
        padding: 10px;
        border-radius: 5px;
        border-left: 5px solid;
    }
    </style>
    """, unsafe_allow_html=True)

    # Collect every tag an essence can be generated for
    all_tags = []

    if hasattr(st.session_state, 'discovered_tags'):
        for tag, info in st.session_state.discovered_tags.items():
            all_tags.append({
                "tag": tag,
                "rarity": info.get("rarity", "Unknown"),
                "category": info.get("category", "unknown"),
                "source": "discovered",
                "library_floor": info.get("library_floor", ""),
                "discovery_time": info.get("discovery_time", "")
            })

    if hasattr(st.session_state, 'manual_tags'):
        for tag, info in st.session_state.manual_tags.items():
            all_tags.append({
                "tag": tag,
                "rarity": info.get("rarity", "Special"),
                "category": info.get("category", "special"),
                "source": "manual",
                "description": info.get("description", "")
            })

    rarity_counts = {}
    for info in all_tags:
        rarity_counts[info["rarity"]] = rarity_counts.get(info["rarity"], 0) + 1

    st.subheader("Available Tags for Essence Generation")
    st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!")

    # st.columns rejects an empty layout, so skip the summary without tags
    if rarity_counts:
        level_position = {name: pos for pos, name in enumerate(RARITY_LEVELS.keys())}
        ordered_counts = sorted(rarity_counts.items(), key=lambda x: level_position.get(x[0], 999))

        # NOTE(review): the grid-* classes are not defined in the CSS emitted
        # by this function - presumably another part of the app provides them.
        summary_classes = {
            "Impuritas Civitas": "grid-impuritas",
            "Star of the City": "grid-star",
            "Urban Nightmare": "grid-nightmare",
            "Urban Plague": "grid-plague",
        }

        rarity_cols = st.columns(len(rarity_counts))
        for i, (rarity, count) in enumerate(ordered_counts):
            with rarity_cols[i]:
                color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888")
                style = f"color:{color};font-weight:bold;"
                class_name = summary_classes.get(rarity, "")

                if class_name:
                    st.markdown(
                        f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity.capitalize()}</span><br>{count}</div>",
                        unsafe_allow_html=True
                    )
                else:
                    st.markdown(
                        f"<div style='text-align:center;'><span style='{style}'>{rarity.capitalize()}</span><br>{count}</div>",
                        unsafe_allow_html=True
                    )

    search_term = st.text_input("Search tags", "", key="essence_search_tags")

    sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"]
    selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort")

    if search_term:
        all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()]

    selected_tag = None

    if selected_sort == "Category (rarest first)":
        categories = {}
        for info in all_tags:
            categories.setdefault(info["category"], []).append(info)

        # The last RARITY_LEVELS entry gets the highest rank; unknown rarities rank 0
        rarity_order = list(reversed(RARITY_LEVELS.keys()))
        rarity_rank = {name: len(rarity_order) - pos for pos, name in enumerate(rarity_order)}

        for category, tags in sorted(categories.items()):
            sorted_tags = sorted(tags, key=lambda entry: rarity_rank.get(entry["rarity"], 0), reverse=True)

            has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"]
                                for info in sorted_tags)

            category_display = category.capitalize()
            if category in TAG_CATEGORIES:
                category_info = TAG_CATEGORIES[category]
                icon = category_info.get("icon", "")
                color = category_info.get("color", "#888888")
                category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>"

            header = f"{category_display} ({len(tags)} tags)"
            if has_rare_tags:
                header += " ✨ Contains rare tags!"

            st.markdown(header, unsafe_allow_html=True)
            with st.expander("Show/Hide", expanded=has_rare_tags):
                cols = st.columns(3)
                for i, info in enumerate(sorted_tags):
                    with cols[i % 3]:
                        if _render_essence_tag_card(info, info["source"], model_available, styled=True):
                            selected_tag = info["tag"]

    elif selected_sort == "Rarity":
        rarity_groups = {}
        for info in all_tags:
            rarity_groups.setdefault(info["rarity"], []).append(info)

        # Known rarities in reverse RARITY_LEVELS order, then any unknown ones
        ordered_rarities = list(RARITY_LEVELS.keys())
        ordered_rarities.reverse()
        for rarity in rarity_groups:
            if rarity not in ordered_rarities:
                ordered_rarities.append(rarity)

        for rarity in ordered_rarities:
            if rarity not in rarity_groups:
                continue

            tags = rarity_groups[rarity]
            color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

            if rarity == "Impuritas Civitas":
                rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity.capitalize()}</span>"
            elif rarity == "Star of the City":
                rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity.capitalize()}</span>"
            elif rarity == "Urban Nightmare":
                rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity.capitalize()}</span>"
            else:
                rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity.capitalize()}</span>"

            st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True)
            with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]):
                cols = st.columns(3)
                for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])):
                    with cols[i % 3]:
                        if _render_essence_tag_card(info, info["source"], model_available, styled=False):
                            selected_tag = info["tag"]

    elif selected_sort == "Discovery Time":
        discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info]

        # Newest first
        sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True)

        # Group by the date part of the timestamp (text before the first space)
        date_groups = {}
        for info in sorted_tags:
            time_str = info["discovery_time"]
            date = time_str.split()[0] if " " in time_str else time_str
            date_groups.setdefault(date, []).append(info)

        # Only the first (most recent) group starts expanded
        first_date = next(iter(date_groups), None)

        for date, tags in date_groups.items():
            date_display = date if date else "Unknown date"
            st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)")

            with st.expander("Show/Hide", expanded=date == first_date):
                cols = st.columns(3)
                for i, info in enumerate(tags):
                    with cols[i % 3]:
                        if _render_essence_tag_card(info, "disc", model_available, styled=True):
                            selected_tag = info["tag"]

        # Manual tags carry no discovery time, so they get their own section
        manual_tags = [info for info in all_tags if info["source"] == "manual"]
        if manual_tags:
            st.markdown("### Manual Tags")
            with st.expander("Show/Hide"):
                cols = st.columns(3)
                for i, info in enumerate(manual_tags):
                    with cols[i % 3]:
                        if _render_essence_tag_card(info, "manual", model_available, styled=False):
                            selected_tag = info["tag"]

    return selected_tag