Update app/app.py
Browse files- app/app.py +1049 -1049
app/app.py
CHANGED
@@ -1,1050 +1,1050 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
Camie-Tagger-V2 Application
|
4 |
-
A Streamlit web app for tagging images using an AI model.
|
5 |
-
"""
|
6 |
-
|
7 |
-
import streamlit as st
|
8 |
-
import os
|
9 |
-
import sys
|
10 |
-
import traceback
|
11 |
-
import tempfile
|
12 |
-
import time
|
13 |
-
import platform
|
14 |
-
import subprocess
|
15 |
-
import webbrowser
|
16 |
-
import glob
|
17 |
-
import numpy as np
|
18 |
-
import matplotlib.pyplot as plt
|
19 |
-
import io
|
20 |
-
import base64
|
21 |
-
import json
|
22 |
-
from matplotlib.colors import LinearSegmentedColormap
|
23 |
-
from PIL import Image
|
24 |
-
from pathlib import Path
|
25 |
-
|
26 |
-
# Add parent directory to path to allow importing from utils
|
27 |
-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
28 |
-
|
29 |
-
# Import utilities
|
30 |
-
from utils.image_processing import process_image, batch_process_images
|
31 |
-
from utils.file_utils import save_tags_to_file, get_default_save_locations
|
32 |
-
from utils.ui_components import display_progress_bar, show_example_images, display_batch_results
|
33 |
-
from utils.onnx_processing import batch_process_images_onnx
|
34 |
-
|
35 |
-
# Define the model directory
|
36 |
-
MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
37 |
-
print(f"Using model directory: {MODEL_DIR}")
|
38 |
-
|
39 |
-
# Define threshold profile descriptions and explanations
|
40 |
-
threshold_profile_descriptions = {
|
41 |
-
"Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
|
42 |
-
"Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
|
43 |
-
"Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
|
44 |
-
"Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
|
45 |
-
"Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
|
46 |
-
}
|
47 |
-
|
48 |
-
threshold_profile_explanations = {
|
49 |
-
"Micro Optimized": """
|
50 |
-
### Micro Optimized Profile
|
51 |
-
|
52 |
-
**Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.
|
53 |
-
|
54 |
-
**When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.
|
55 |
-
|
56 |
-
**Effects**:
|
57 |
-
- Optimizes performance for the most frequent tags
|
58 |
-
- Gives more weight to categories with many examples (like 'character' and 'general')
|
59 |
-
- Provides higher precision in most common use cases
|
60 |
-
|
61 |
-
**Performance from validation**:
|
62 |
-
- Micro F1: ~67.3%
|
63 |
-
- Macro F1: ~46.3%
|
64 |
-
- Threshold: ~0.614
|
65 |
-
""",
|
66 |
-
|
67 |
-
"Macro Optimized": """
|
68 |
-
### Macro Optimized Profile
|
69 |
-
|
70 |
-
**Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.
|
71 |
-
|
72 |
-
**When to use**: When balanced performance across all categories is important, including rare tags.
|
73 |
-
|
74 |
-
**Effects**:
|
75 |
-
- More balanced performance across all tag categories
|
76 |
-
- Better at detecting rare or unusual tags
|
77 |
-
- Generally has lower thresholds than micro-optimized
|
78 |
-
|
79 |
-
**Performance from validation**:
|
80 |
-
- Micro F1: ~60.9%
|
81 |
-
- Macro F1: ~50.6%
|
82 |
-
- Threshold: ~0.492
|
83 |
-
""",
|
84 |
-
|
85 |
-
"Balanced": """
|
86 |
-
### Balanced Profile
|
87 |
-
|
88 |
-
**Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.
|
89 |
-
|
90 |
-
**When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.
|
91 |
-
|
92 |
-
**Effects**:
|
93 |
-
- Good middle ground between precision and recall
|
94 |
-
- Works well for most common use cases
|
95 |
-
- Default choice for most users
|
96 |
-
|
97 |
-
**Performance from validation**:
|
98 |
-
- Micro F1: ~67.3%
|
99 |
-
- Macro F1: ~46.3%
|
100 |
-
- Threshold: ~0.614
|
101 |
-
""",
|
102 |
-
|
103 |
-
"Overall": """
|
104 |
-
### Overall Profile
|
105 |
-
|
106 |
-
**Technical definition**: Uses a single threshold value across all categories.
|
107 |
-
|
108 |
-
**When to use**: When you want consistent behavior across all categories and a simple approach.
|
109 |
-
|
110 |
-
**Effects**:
|
111 |
-
- Consistent tagging threshold for all categories
|
112 |
-
- Simpler to understand than category-specific thresholds
|
113 |
-
- User-adjustable with a single slider
|
114 |
-
|
115 |
-
**Default threshold value**: 0.5 (user-adjustable)
|
116 |
-
|
117 |
-
**Note**: The threshold value is user-adjustable with the slider below.
|
118 |
-
""",
|
119 |
-
|
120 |
-
"Category-specific": """
|
121 |
-
### Category-specific Profile
|
122 |
-
|
123 |
-
**Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.
|
124 |
-
|
125 |
-
**When to use**: When you want to customize tagging sensitivity for different categories.
|
126 |
-
|
127 |
-
**Effects**:
|
128 |
-
- Each category has its own independent threshold
|
129 |
-
- Full control over category sensitivity
|
130 |
-
- Best for fine-tuning results when some categories need different treatment
|
131 |
-
|
132 |
-
**Default threshold values**: Starts with balanced thresholds for each category
|
133 |
-
|
134 |
-
**Note**: Use the category sliders below to adjust thresholds for individual categories.
|
135 |
-
"""
|
136 |
-
}
|
137 |
-
|
138 |
-
def load_validation_results(results_path):
|
139 |
-
"""Load validation results from JSON file"""
|
140 |
-
try:
|
141 |
-
with open(results_path, 'r') as f:
|
142 |
-
data = json.load(f)
|
143 |
-
return data
|
144 |
-
except Exception as e:
|
145 |
-
print(f"Error loading validation results: {e}")
|
146 |
-
return None
|
147 |
-
|
148 |
-
def extract_thresholds_from_results(validation_data):
|
149 |
-
"""Extract threshold information from validation results"""
|
150 |
-
if not validation_data or 'results' not in validation_data:
|
151 |
-
return {}
|
152 |
-
|
153 |
-
thresholds = {
|
154 |
-
'overall': {},
|
155 |
-
'categories': {}
|
156 |
-
}
|
157 |
-
|
158 |
-
# Process results to extract thresholds
|
159 |
-
for result in validation_data['results']:
|
160 |
-
category = result['CATEGORY'].lower()
|
161 |
-
profile = result['PROFILE'].lower().replace(' ', '_')
|
162 |
-
threshold = result['THRESHOLD']
|
163 |
-
micro_f1 = result['MICRO-F1']
|
164 |
-
macro_f1 = result['MACRO-F1']
|
165 |
-
|
166 |
-
# Map profile names
|
167 |
-
if profile == 'micro_opt':
|
168 |
-
profile = 'micro_optimized'
|
169 |
-
elif profile == 'macro_opt':
|
170 |
-
profile = 'macro_optimized'
|
171 |
-
|
172 |
-
threshold_info = {
|
173 |
-
'threshold': threshold,
|
174 |
-
'micro_f1': micro_f1,
|
175 |
-
'macro_f1': macro_f1
|
176 |
-
}
|
177 |
-
|
178 |
-
if category == 'overall':
|
179 |
-
thresholds['overall'][profile] = threshold_info
|
180 |
-
else:
|
181 |
-
if category not in thresholds['categories']:
|
182 |
-
thresholds['categories'][category] = {}
|
183 |
-
thresholds['categories'][category][profile] = threshold_info
|
184 |
-
|
185 |
-
return thresholds
|
186 |
-
|
187 |
-
def load_model_and_metadata():
|
188 |
-
"""Load model and metadata from available files"""
|
189 |
-
# Check for SafeTensors model
|
190 |
-
safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
|
191 |
-
safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
|
192 |
-
|
193 |
-
# Check for ONNX model
|
194 |
-
onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
|
195 |
-
|
196 |
-
# Check for validation results
|
197 |
-
validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json")
|
198 |
-
|
199 |
-
model_info = {
|
200 |
-
'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path),
|
201 |
-
'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path),
|
202 |
-
'validation_results_available': os.path.exists(validation_results_path)
|
203 |
-
}
|
204 |
-
|
205 |
-
# Load metadata (same for both model types)
|
206 |
-
metadata = None
|
207 |
-
if os.path.exists(safetensors_metadata_path):
|
208 |
-
try:
|
209 |
-
with open(safetensors_metadata_path, 'r') as f:
|
210 |
-
metadata = json.load(f)
|
211 |
-
except Exception as e:
|
212 |
-
print(f"Error loading metadata: {e}")
|
213 |
-
|
214 |
-
# Load validation results for thresholds
|
215 |
-
thresholds = {}
|
216 |
-
if model_info['validation_results_available']:
|
217 |
-
validation_data = load_validation_results(validation_results_path)
|
218 |
-
if validation_data:
|
219 |
-
thresholds = extract_thresholds_from_results(validation_data)
|
220 |
-
|
221 |
-
# Add default thresholds if not available
|
222 |
-
if not thresholds:
|
223 |
-
thresholds = {
|
224 |
-
'overall': {
|
225 |
-
'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
|
226 |
-
'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
|
227 |
-
'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
|
228 |
-
},
|
229 |
-
'categories': {}
|
230 |
-
}
|
231 |
-
|
232 |
-
return model_info, metadata, thresholds
|
233 |
-
|
234 |
-
def load_safetensors_model(safetensors_path, metadata_path):
|
235 |
-
"""Load SafeTensors model"""
|
236 |
-
try:
|
237 |
-
from safetensors.torch import load_file
|
238 |
-
import torch
|
239 |
-
|
240 |
-
# Load metadata
|
241 |
-
with open(metadata_path, 'r') as f:
|
242 |
-
metadata = json.load(f)
|
243 |
-
|
244 |
-
# Import the model class (assuming it's available)
|
245 |
-
# You'll need to make sure the ImageTagger class is importable
|
246 |
-
from utils.model_loader import ImageTagger # Update this import
|
247 |
-
|
248 |
-
model_info = metadata['model_info']
|
249 |
-
dataset_info = metadata['dataset_info']
|
250 |
-
|
251 |
-
# Recreate model architecture
|
252 |
-
model = ImageTagger(
|
253 |
-
total_tags=dataset_info['total_tags'],
|
254 |
-
dataset=None,
|
255 |
-
model_name=model_info['backbone'],
|
256 |
-
num_heads=model_info['num_attention_heads'],
|
257 |
-
dropout=0.0,
|
258 |
-
pretrained=False,
|
259 |
-
tag_context_size=model_info['tag_context_size'],
|
260 |
-
use_gradient_checkpointing=False,
|
261 |
-
img_size=model_info['img_size']
|
262 |
-
)
|
263 |
-
|
264 |
-
# Load weights
|
265 |
-
state_dict = load_file(safetensors_path)
|
266 |
-
model.load_state_dict(state_dict)
|
267 |
-
model.eval()
|
268 |
-
|
269 |
-
return model, metadata
|
270 |
-
except Exception as e:
|
271 |
-
raise Exception(f"Failed to load SafeTensors model: {e}")
|
272 |
-
|
273 |
-
def get_profile_metrics(thresholds, profile_name):
|
274 |
-
"""Extract metrics for the given profile from the thresholds dictionary"""
|
275 |
-
profile_key = None
|
276 |
-
|
277 |
-
# Map UI-friendly names to internal keys
|
278 |
-
if profile_name == "Micro Optimized":
|
279 |
-
profile_key = "micro_optimized"
|
280 |
-
elif profile_name == "Macro Optimized":
|
281 |
-
profile_key = "macro_optimized"
|
282 |
-
elif profile_name == "Balanced":
|
283 |
-
profile_key = "balanced"
|
284 |
-
elif profile_name in ["Overall", "Category-specific"]:
|
285 |
-
profile_key = "macro_optimized" # Use macro as default for these modes
|
286 |
-
|
287 |
-
if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']:
|
288 |
-
return thresholds['overall'][profile_key]
|
289 |
-
|
290 |
-
return None
|
291 |
-
|
292 |
-
def on_threshold_profile_change():
|
293 |
-
"""Handle threshold profile changes"""
|
294 |
-
new_profile = st.session_state.threshold_profile
|
295 |
-
|
296 |
-
if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'):
|
297 |
-
# Initialize category thresholds if needed
|
298 |
-
if st.session_state.settings['active_category_thresholds'] is None:
|
299 |
-
st.session_state.settings['active_category_thresholds'] = {}
|
300 |
-
|
301 |
-
current_thresholds = st.session_state.settings['active_category_thresholds']
|
302 |
-
|
303 |
-
# Map profile names to keys
|
304 |
-
profile_key = None
|
305 |
-
if new_profile == "Micro Optimized":
|
306 |
-
profile_key = "micro_optimized"
|
307 |
-
elif new_profile == "Macro Optimized":
|
308 |
-
profile_key = "macro_optimized"
|
309 |
-
elif new_profile == "Balanced":
|
310 |
-
profile_key = "balanced"
|
311 |
-
|
312 |
-
# Update thresholds based on profile
|
313 |
-
if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']:
|
314 |
-
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold']
|
315 |
-
|
316 |
-
# Set category thresholds
|
317 |
-
for category in st.session_state.categories:
|
318 |
-
if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]:
|
319 |
-
current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold']
|
320 |
-
else:
|
321 |
-
current_thresholds[category] = st.session_state.settings['active_threshold']
|
322 |
-
|
323 |
-
elif new_profile == "Overall":
|
324 |
-
# Use balanced threshold for Overall profile
|
325 |
-
if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
|
326 |
-
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
|
327 |
-
else:
|
328 |
-
st.session_state.settings['active_threshold'] = 0.5
|
329 |
-
|
330 |
-
# Clear category-specific overrides
|
331 |
-
st.session_state.settings['active_category_thresholds'] = {}
|
332 |
-
|
333 |
-
elif new_profile == "Category-specific":
|
334 |
-
# Initialize with balanced thresholds
|
335 |
-
if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
|
336 |
-
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
|
337 |
-
else:
|
338 |
-
st.session_state.settings['active_threshold'] = 0.5
|
339 |
-
|
340 |
-
# Initialize category thresholds
|
341 |
-
for category in st.session_state.categories:
|
342 |
-
if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]:
|
343 |
-
current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold']
|
344 |
-
else:
|
345 |
-
current_thresholds[category] = st.session_state.settings['active_threshold']
|
346 |
-
|
347 |
-
def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
|
348 |
-
"""Apply thresholds to raw probabilities and return filtered tags"""
|
349 |
-
tags = {}
|
350 |
-
all_tags = []
|
351 |
-
|
352 |
-
# Handle None case for active_category_thresholds
|
353 |
-
active_category_thresholds = active_category_thresholds or {}
|
354 |
-
|
355 |
-
for category, cat_probs in all_probs.items():
|
356 |
-
# Get the appropriate threshold for this category
|
357 |
-
threshold = active_category_thresholds.get(category, active_threshold)
|
358 |
-
|
359 |
-
# Filter tags above threshold
|
360 |
-
tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]
|
361 |
-
|
362 |
-
# Add to all_tags if selected
|
363 |
-
if selected_categories.get(category, True):
|
364 |
-
for tag, prob in tags[category]:
|
365 |
-
all_tags.append(tag)
|
366 |
-
|
367 |
-
return tags, all_tags
|
368 |
-
|
369 |
-
def image_tagger_app():
|
370 |
-
"""Main Streamlit application for image tagging."""
|
371 |
-
st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️")
|
372 |
-
|
373 |
-
st.title("Camie-Tagger-v2 Interface")
|
374 |
-
st.markdown("---")
|
375 |
-
|
376 |
-
# Initialize settings
|
377 |
-
if 'settings' not in st.session_state:
|
378 |
-
st.session_state.settings = {
|
379 |
-
'show_all_tags': False,
|
380 |
-
'compact_view': True,
|
381 |
-
'min_confidence': 0.01,
|
382 |
-
'threshold_profile': "Macro",
|
383 |
-
'active_threshold': 0.5,
|
384 |
-
'active_category_thresholds': {}, # Initialize as empty dict, not None
|
385 |
-
'selected_categories': {},
|
386 |
-
'replace_underscores': False
|
387 |
-
}
|
388 |
-
st.session_state.show_profile_help = False
|
389 |
-
|
390 |
-
# Session state initialization for model
|
391 |
-
if 'model_loaded' not in st.session_state:
|
392 |
-
st.session_state.model_loaded = False
|
393 |
-
st.session_state.model = None
|
394 |
-
st.session_state.thresholds = None
|
395 |
-
st.session_state.metadata = None
|
396 |
-
st.session_state.model_type = "onnx" # Default to ONNX
|
397 |
-
|
398 |
-
# Sidebar for model selection and information
|
399 |
-
with st.sidebar:
|
400 |
-
# Support information
|
401 |
-
st.subheader("💡 Notes")
|
402 |
-
|
403 |
-
st.markdown("""
|
404 |
-
This tagger was trained on a subset of the available data due to hardware limitations.
|
405 |
-
|
406 |
-
A more comprehensive model trained on the full 3+ million image dataset would provide:
|
407 |
-
- More recent characters and tags.
|
408 |
-
- Improved accuracy.
|
409 |
-
|
410 |
-
If you find this tool useful and would like to support future development:
|
411 |
-
""")
|
412 |
-
|
413 |
-
# Add Buy Me a Coffee button with Star of the City-like glow effect
|
414 |
-
st.markdown("""
|
415 |
-
<style>
|
416 |
-
@keyframes coffee-button-glow {
|
417 |
-
0% { box-shadow: 0 0 5px #FFD700; }
|
418 |
-
50% { box-shadow: 0 0 15px #FFD700; }
|
419 |
-
100% { box-shadow: 0 0 5px #FFD700; }
|
420 |
-
}
|
421 |
-
|
422 |
-
.coffee-button {
|
423 |
-
display: inline-block;
|
424 |
-
animation: coffee-button-glow 2s infinite;
|
425 |
-
border-radius: 5px;
|
426 |
-
transition: transform 0.3s ease;
|
427 |
-
}
|
428 |
-
|
429 |
-
.coffee-button:hover {
|
430 |
-
transform: scale(1.05);
|
431 |
-
}
|
432 |
-
</style>
|
433 |
-
|
434 |
-
<a href="https://ko-fi.com/camais" target="_blank" class="coffee-button">
|
435 |
-
<img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
|
436 |
-
alt="Buy Me A Coffee"
|
437 |
-
style="height: 45px; width: 162px; border-radius: 5px;" />
|
438 |
-
</a>
|
439 |
-
""", unsafe_allow_html=True)
|
440 |
-
|
441 |
-
st.markdown("""
|
442 |
-
Your support helps with:
|
443 |
-
- GPU costs for training
|
444 |
-
- Storage for larger datasets
|
445 |
-
- Development of new features
|
446 |
-
- Future projects
|
447 |
-
|
448 |
-
Thank you! 🙏
|
449 |
-
|
450 |
-
Full Details: https://huggingface.co/Camais03/camie-tagger
|
451 |
-
""")
|
452 |
-
|
453 |
-
st.header("Model Selection")
|
454 |
-
|
455 |
-
# Load model information
|
456 |
-
model_info, metadata, thresholds = load_model_and_metadata()
|
457 |
-
|
458 |
-
# Determine available model options
|
459 |
-
model_options = []
|
460 |
-
if model_info['onnx_available']:
|
461 |
-
model_options.append("ONNX (Recommended)")
|
462 |
-
if model_info['safetensors_available']:
|
463 |
-
model_options.append("SafeTensors (PyTorch)")
|
464 |
-
|
465 |
-
if not model_options:
|
466 |
-
st.error("No model files found!")
|
467 |
-
st.info(f"Looking for models in: {MODEL_DIR}")
|
468 |
-
st.info("Expected files:")
|
469 |
-
st.info("- camie-tagger-v2.onnx")
|
470 |
-
st.info("- camie-tagger-v2.safetensors")
|
471 |
-
st.info("- camie-tagger-v2-metadata.json")
|
472 |
-
st.stop()
|
473 |
-
|
474 |
-
# Model type selection
|
475 |
-
default_index = 0 if model_info['onnx_available'] else 0
|
476 |
-
model_type = st.radio(
|
477 |
-
"Select Model Type:",
|
478 |
-
model_options,
|
479 |
-
index=default_index,
|
480 |
-
help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format"
|
481 |
-
)
|
482 |
-
|
483 |
-
# Convert selection to internal model type
|
484 |
-
if model_type == "ONNX (Recommended)":
|
485 |
-
selected_model_type = "onnx"
|
486 |
-
else:
|
487 |
-
selected_model_type = "safetensors"
|
488 |
-
|
489 |
-
# If model type changed, reload
|
490 |
-
if selected_model_type != st.session_state.model_type:
|
491 |
-
st.session_state.model_loaded = False
|
492 |
-
st.session_state.model_type = selected_model_type
|
493 |
-
|
494 |
-
# Reload button
|
495 |
-
if st.button("Reload Model") and st.session_state.model_loaded:
|
496 |
-
st.session_state.model_loaded = False
|
497 |
-
st.info("Reloading model...")
|
498 |
-
|
499 |
-
# Try to load the model
|
500 |
-
if not st.session_state.model_loaded:
|
501 |
-
try:
|
502 |
-
with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."):
|
503 |
-
if st.session_state.model_type == "onnx":
|
504 |
-
# Load ONNX model
|
505 |
-
import onnxruntime as ort
|
506 |
-
|
507 |
-
onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
|
508 |
-
|
509 |
-
# Check ONNX providers
|
510 |
-
providers = ort.get_available_providers()
|
511 |
-
gpu_available = any('CUDA' in provider for provider in providers)
|
512 |
-
|
513 |
-
# Create ONNX session
|
514 |
-
session = ort.InferenceSession(onnx_path, providers=providers)
|
515 |
-
|
516 |
-
st.session_state.model = session
|
517 |
-
st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})"
|
518 |
-
st.session_state.param_dtype = "float32"
|
519 |
-
|
520 |
-
else:
|
521 |
-
# Load SafeTensors model
|
522 |
-
safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
|
523 |
-
metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
|
524 |
-
|
525 |
-
model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path)
|
526 |
-
|
527 |
-
st.session_state.model = model
|
528 |
-
device = next(model.parameters()).device
|
529 |
-
param_dtype = next(model.parameters()).dtype
|
530 |
-
st.session_state.device = device
|
531 |
-
st.session_state.param_dtype = param_dtype
|
532 |
-
metadata = loaded_metadata # Use loaded metadata instead
|
533 |
-
|
534 |
-
# Store common info
|
535 |
-
st.session_state.thresholds = thresholds
|
536 |
-
st.session_state.metadata = metadata
|
537 |
-
st.session_state.model_loaded = True
|
538 |
-
|
539 |
-
# Get categories
|
540 |
-
if metadata and 'dataset_info' in metadata:
|
541 |
-
tag_mapping = metadata['dataset_info']['tag_mapping']
|
542 |
-
categories = list(set(tag_mapping['tag_to_category'].values()))
|
543 |
-
st.session_state.categories = categories
|
544 |
-
|
545 |
-
# Initialize selected categories
|
546 |
-
if not st.session_state.settings['selected_categories']:
|
547 |
-
st.session_state.settings['selected_categories'] = {cat: True for cat in categories}
|
548 |
-
|
549 |
-
# Set initial threshold from validation results
|
550 |
-
if 'overall' in thresholds and 'balanced' in thresholds['overall']:
|
551 |
-
st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold']
|
552 |
-
|
553 |
-
except Exception as e:
|
554 |
-
st.error(f"Error loading model: {str(e)}")
|
555 |
-
st.code(traceback.format_exc())
|
556 |
-
st.stop()
|
557 |
-
|
558 |
-
# Display model information in sidebar
|
559 |
-
with st.sidebar:
|
560 |
-
st.header("Model Information")
|
561 |
-
if st.session_state.model_loaded:
|
562 |
-
if st.session_state.model_type == "onnx":
|
563 |
-
st.success("Using ONNX Model")
|
564 |
-
else:
|
565 |
-
st.success("Using SafeTensors Model")
|
566 |
-
|
567 |
-
st.write(f"Device: {st.session_state.device}")
|
568 |
-
st.write(f"Precision: {st.session_state.param_dtype}")
|
569 |
-
|
570 |
-
if st.session_state.metadata:
|
571 |
-
if 'dataset_info' in st.session_state.metadata:
|
572 |
-
total_tags = st.session_state.metadata['dataset_info']['total_tags']
|
573 |
-
st.write(f"Total tags: {total_tags}")
|
574 |
-
elif 'total_tags' in st.session_state.metadata:
|
575 |
-
st.write(f"Total tags: {st.session_state.metadata['total_tags']}")
|
576 |
-
|
577 |
-
# Show categories
|
578 |
-
with st.expander("Available Categories"):
|
579 |
-
for category in sorted(st.session_state.categories):
|
580 |
-
st.write(f"- {category.capitalize()}")
|
581 |
-
|
582 |
-
# About section
|
583 |
-
with st.expander("About this app"):
|
584 |
-
st.write("""
|
585 |
-
This app uses a trained image tagging model to analyze and tag images.
|
586 |
-
|
587 |
-
**Model Options**:
|
588 |
-
- **ONNX (Recommended)**: Optimized for inference speed with broad compatibility
|
589 |
-
- **SafeTensors**: Native PyTorch format for advanced users
|
590 |
-
|
591 |
-
**Features**:
|
592 |
-
- Upload or process images in batches
|
593 |
-
- Multiple threshold profiles based on validation results
|
594 |
-
- Category-specific threshold adjustment
|
595 |
-
- Export tags in various formats
|
596 |
-
- Fast inference with GPU acceleration (when available)
|
597 |
-
|
598 |
-
**Threshold Profiles**:
|
599 |
-
- **Micro Optimized**: Best overall F1 score (67.3% micro F1)
|
600 |
-
- **Macro Optimized**: Balanced across categories (50.6% macro F1)
|
601 |
-
- **Balanced**: Good general-purpose setting
|
602 |
-
- **Overall**: Single adjustable threshold
|
603 |
-
- **Category-specific**: Fine-tune each category individually
|
604 |
-
""")
|
605 |
-
|
606 |
-
# Main content area - Image upload and processing
|
607 |
-
col1, col2 = st.columns([1, 1.5])
|
608 |
-
|
609 |
-
with col1:
|
610 |
-
st.header("Image")
|
611 |
-
|
612 |
-
upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])
|
613 |
-
|
614 |
-
image_path = None
|
615 |
-
|
616 |
-
with upload_tab:
|
617 |
-
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
618 |
-
|
619 |
-
if uploaded_file:
|
620 |
-
# Create temporary file
|
621 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
|
622 |
-
tmp_file.write(uploaded_file.getvalue())
|
623 |
-
image_path = tmp_file.name
|
624 |
-
|
625 |
-
st.session_state.original_filename = uploaded_file.name
|
626 |
-
|
627 |
-
# Display image
|
628 |
-
image = Image.open(uploaded_file)
|
629 |
-
st.image(image, use_container_width=True)
|
630 |
-
|
631 |
-
with batch_tab:
|
632 |
-
st.subheader("Batch Process Images")
|
633 |
-
|
634 |
-
# Folder selection
|
635 |
-
batch_folder = st.text_input("Enter folder path containing images:", "")
|
636 |
-
|
637 |
-
# Save options
|
638 |
-
save_options = st.radio(
|
639 |
-
"Where to save tag files:",
|
640 |
-
["Same folder as images", "Custom location", "Default save folder"],
|
641 |
-
index=0
|
642 |
-
)
|
643 |
-
|
644 |
-
# Batch size control
|
645 |
-
st.subheader("Performance Options")
|
646 |
-
batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4,
|
647 |
-
help="Higher values may improve speed but use more memory")
|
648 |
-
|
649 |
-
# Category limits
|
650 |
-
enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)
|
651 |
-
|
652 |
-
if enable_category_limits and hasattr(st.session_state, 'categories'):
|
653 |
-
if 'category_limits' not in st.session_state:
|
654 |
-
st.session_state.category_limits = {}
|
655 |
-
|
656 |
-
st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags")
|
657 |
-
|
658 |
-
limit_cols = st.columns(2)
|
659 |
-
for i, category in enumerate(sorted(st.session_state.categories)):
|
660 |
-
col_idx = i % 2
|
661 |
-
with limit_cols[col_idx]:
|
662 |
-
current_limit = st.session_state.category_limits.get(category, -1)
|
663 |
-
new_limit = st.number_input(
|
664 |
-
f"{category.capitalize()}:",
|
665 |
-
value=current_limit,
|
666 |
-
min_value=-1,
|
667 |
-
step=1,
|
668 |
-
key=f"limit_{category}"
|
669 |
-
)
|
670 |
-
st.session_state.category_limits[category] = new_limit
|
671 |
-
|
672 |
-
# Process batch button
|
673 |
-
if batch_folder and os.path.isdir(batch_folder):
|
674 |
-
image_files = []
|
675 |
-
for ext in ['*.jpg', '*.jpeg', '*.png']:
|
676 |
-
image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
|
677 |
-
image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))
|
678 |
-
|
679 |
-
if image_files:
|
680 |
-
st.write(f"Found {len(image_files)} images")
|
681 |
-
|
682 |
-
if st.button("🔄 Process All Images", type="primary"):
|
683 |
-
if not st.session_state.model_loaded:
|
684 |
-
st.error("Model not loaded")
|
685 |
-
else:
|
686 |
-
with st.spinner("Processing images..."):
|
687 |
-
progress_bar = st.progress(0)
|
688 |
-
status_text = st.empty()
|
689 |
-
|
690 |
-
def update_progress(current, total, image_path):
|
691 |
-
progress = current / total if total > 0 else 0
|
692 |
-
progress_bar.progress(progress)
|
693 |
-
status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}")
|
694 |
-
|
695 |
-
# Determine save directory
|
696 |
-
if save_options == "Same folder as images":
|
697 |
-
save_dir = batch_folder
|
698 |
-
elif save_options == "Custom location":
|
699 |
-
save_dir = st.text_input("Custom save directory:", batch_folder)
|
700 |
-
else:
|
701 |
-
save_dir = os.path.join(os.path.dirname(__file__), "saved_tags")
|
702 |
-
os.makedirs(save_dir, exist_ok=True)
|
703 |
-
|
704 |
-
# Get current settings
|
705 |
-
category_limits = st.session_state.category_limits if enable_category_limits else None
|
706 |
-
|
707 |
-
# Process based on model type
|
708 |
-
if st.session_state.model_type == "onnx":
|
709 |
-
batch_results = batch_process_images_onnx(
|
710 |
-
folder_path=batch_folder,
|
711 |
-
model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
|
712 |
-
metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"),
|
713 |
-
threshold_profile=st.session_state.settings['threshold_profile'],
|
714 |
-
active_threshold=st.session_state.settings['active_threshold'],
|
715 |
-
active_category_thresholds=st.session_state.settings['active_category_thresholds'],
|
716 |
-
save_dir=save_dir,
|
717 |
-
progress_callback=update_progress,
|
718 |
-
min_confidence=st.session_state.settings['min_confidence'],
|
719 |
-
batch_size=batch_size,
|
720 |
-
category_limits=category_limits
|
721 |
-
)
|
722 |
-
else:
|
723 |
-
# SafeTensors processing (would need to implement)
|
724 |
-
st.error("SafeTensors batch processing not implemented yet")
|
725 |
-
batch_results = None
|
726 |
-
|
727 |
-
if batch_results:
|
728 |
-
display_batch_results(batch_results)
|
729 |
-
|
730 |
-
# Column 2: Controls and Results
|
731 |
-
with col2:
|
732 |
-
st.header("Tagging Controls")
|
733 |
-
|
734 |
-
# Threshold profile selection
|
735 |
-
all_profiles = [
|
736 |
-
"Micro Optimized",
|
737 |
-
"Macro Optimized",
|
738 |
-
"Balanced",
|
739 |
-
"Overall",
|
740 |
-
"Category-specific"
|
741 |
-
]
|
742 |
-
|
743 |
-
profile_col1, profile_col2 = st.columns([3, 1])
|
744 |
-
|
745 |
-
with profile_col1:
|
746 |
-
threshold_profile = st.selectbox(
|
747 |
-
"Select threshold profile",
|
748 |
-
options=all_profiles,
|
749 |
-
index=1, # Default to Macro
|
750 |
-
key="threshold_profile",
|
751 |
-
on_change=on_threshold_profile_change
|
752 |
-
)
|
753 |
-
|
754 |
-
with profile_col2:
|
755 |
-
if st.button("ℹ️ Help", key="profile_help"):
|
756 |
-
st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)
|
757 |
-
|
758 |
-
# Show profile help
|
759 |
-
if st.session_state.get('show_profile_help', False):
|
760 |
-
st.markdown(threshold_profile_explanations[threshold_profile])
|
761 |
-
else:
|
762 |
-
st.info(threshold_profile_descriptions[threshold_profile])
|
763 |
-
|
764 |
-
# Show profile metrics if available
|
765 |
-
if st.session_state.model_loaded:
|
766 |
-
metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile)
|
767 |
-
|
768 |
-
if metrics:
|
769 |
-
metrics_cols = st.columns(3)
|
770 |
-
|
771 |
-
with metrics_cols[0]:
|
772 |
-
st.metric("Threshold", f"{metrics['threshold']:.3f}")
|
773 |
-
|
774 |
-
with metrics_cols[1]:
|
775 |
-
st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%")
|
776 |
-
|
777 |
-
with metrics_cols[2]:
|
778 |
-
st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%")
|
779 |
-
|
780 |
-
# Threshold controls based on profile
|
781 |
-
if st.session_state.model_loaded:
|
782 |
-
active_threshold = st.session_state.settings.get('active_threshold', 0.5)
|
783 |
-
active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {})
|
784 |
-
|
785 |
-
if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]:
|
786 |
-
# Show reference threshold (disabled)
|
787 |
-
st.slider(
|
788 |
-
"Threshold (from validation)",
|
789 |
-
min_value=0.01,
|
790 |
-
max_value=1.0,
|
791 |
-
value=float(active_threshold),
|
792 |
-
step=0.01,
|
793 |
-
disabled=True,
|
794 |
-
help="This threshold is optimized from validation results"
|
795 |
-
)
|
796 |
-
|
797 |
-
elif threshold_profile == "Overall":
|
798 |
-
# Adjustable overall threshold
|
799 |
-
active_threshold = st.slider(
|
800 |
-
"Overall threshold",
|
801 |
-
min_value=0.01,
|
802 |
-
max_value=1.0,
|
803 |
-
value=float(active_threshold),
|
804 |
-
step=0.01
|
805 |
-
)
|
806 |
-
st.session_state.settings['active_threshold'] = active_threshold
|
807 |
-
|
808 |
-
elif threshold_profile == "Category-specific":
|
809 |
-
# Show reference overall threshold
|
810 |
-
st.slider(
|
811 |
-
"Overall threshold (reference)",
|
812 |
-
min_value=0.01,
|
813 |
-
max_value=1.0,
|
814 |
-
value=float(active_threshold),
|
815 |
-
step=0.01,
|
816 |
-
disabled=True
|
817 |
-
)
|
818 |
-
|
819 |
-
st.write("Adjust thresholds for individual categories:")
|
820 |
-
|
821 |
-
# Category sliders
|
822 |
-
slider_cols = st.columns(2)
|
823 |
-
|
824 |
-
if not active_category_thresholds:
|
825 |
-
active_category_thresholds = {}
|
826 |
-
|
827 |
-
for i, category in enumerate(sorted(st.session_state.categories)):
|
828 |
-
col_idx = i % 2
|
829 |
-
with slider_cols[col_idx]:
|
830 |
-
default_val = active_category_thresholds.get(category, active_threshold)
|
831 |
-
new_threshold = st.slider(
|
832 |
-
f"{category.capitalize()}",
|
833 |
-
min_value=0.01,
|
834 |
-
max_value=1.0,
|
835 |
-
value=float(default_val),
|
836 |
-
step=0.01,
|
837 |
-
key=f"slider_{category}"
|
838 |
-
)
|
839 |
-
active_category_thresholds[category] = new_threshold
|
840 |
-
|
841 |
-
st.session_state.settings['active_category_thresholds'] = active_category_thresholds
|
842 |
-
|
843 |
-
# Display options
|
844 |
-
with st.expander("Display Options", expanded=False):
|
845 |
-
col1, col2 = st.columns(2)
|
846 |
-
with col1:
|
847 |
-
show_all_tags = st.checkbox("Show all tags (including below threshold)",
|
848 |
-
value=st.session_state.settings['show_all_tags'])
|
849 |
-
compact_view = st.checkbox("Compact view (hide progress bars)",
|
850 |
-
value=st.session_state.settings['compact_view'])
|
851 |
-
replace_underscores = st.checkbox("Replace underscores with spaces",
|
852 |
-
value=st.session_state.settings.get('replace_underscores', False))
|
853 |
-
|
854 |
-
with col2:
|
855 |
-
min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
|
856 |
-
st.session_state.settings['min_confidence'], 0.01)
|
857 |
-
|
858 |
-
# Update settings
|
859 |
-
st.session_state.settings.update({
|
860 |
-
'show_all_tags': show_all_tags,
|
861 |
-
'compact_view': compact_view,
|
862 |
-
'min_confidence': min_confidence,
|
863 |
-
'replace_underscores': replace_underscores
|
864 |
-
})
|
865 |
-
|
866 |
-
# Category selection
|
867 |
-
st.write("Categories to include in 'All Tags' section:")
|
868 |
-
|
869 |
-
category_cols = st.columns(3)
|
870 |
-
selected_categories = {}
|
871 |
-
|
872 |
-
if hasattr(st.session_state, 'categories'):
|
873 |
-
for i, category in enumerate(sorted(st.session_state.categories)):
|
874 |
-
col_idx = i % 3
|
875 |
-
with category_cols[col_idx]:
|
876 |
-
default_val = st.session_state.settings['selected_categories'].get(category, True)
|
877 |
-
selected_categories[category] = st.checkbox(
|
878 |
-
f"{category.capitalize()}",
|
879 |
-
value=default_val,
|
880 |
-
key=f"cat_select_{category}"
|
881 |
-
)
|
882 |
-
|
883 |
-
st.session_state.settings['selected_categories'] = selected_categories
|
884 |
-
|
885 |
-
# Run tagging button
|
886 |
-
if image_path and st.button("Run Tagging"):
|
887 |
-
if not st.session_state.model_loaded:
|
888 |
-
st.error("Model not loaded")
|
889 |
-
else:
|
890 |
-
with st.spinner("Analyzing image..."):
|
891 |
-
try:
|
892 |
-
# Process image based on model type
|
893 |
-
if st.session_state.model_type == "onnx":
|
894 |
-
from utils.onnx_processing import process_single_image_onnx
|
895 |
-
|
896 |
-
result = process_single_image_onnx(
|
897 |
-
image_path=image_path,
|
898 |
-
model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
|
899 |
-
metadata=st.session_state.metadata,
|
900 |
-
threshold_profile=threshold_profile,
|
901 |
-
active_threshold=st.session_state.settings['active_threshold'],
|
902 |
-
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
|
903 |
-
min_confidence=st.session_state.settings['min_confidence']
|
904 |
-
)
|
905 |
-
else:
|
906 |
-
# SafeTensors processing
|
907 |
-
result = process_image(
|
908 |
-
image_path=image_path,
|
909 |
-
model=st.session_state.model,
|
910 |
-
thresholds=st.session_state.thresholds,
|
911 |
-
metadata=st.session_state.metadata,
|
912 |
-
threshold_profile=threshold_profile,
|
913 |
-
active_threshold=st.session_state.settings['active_threshold'],
|
914 |
-
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
|
915 |
-
min_confidence=st.session_state.settings['min_confidence']
|
916 |
-
)
|
917 |
-
|
918 |
-
if result['success']:
|
919 |
-
st.session_state.all_probs = result['all_probs']
|
920 |
-
st.session_state.tags = result['tags']
|
921 |
-
st.session_state.all_tags = result['all_tags']
|
922 |
-
st.success("Analysis completed!")
|
923 |
-
else:
|
924 |
-
st.error(f"Analysis failed: {result.get('error', 'Unknown error')}")
|
925 |
-
|
926 |
-
except Exception as e:
|
927 |
-
st.error(f"Error during analysis: {str(e)}")
|
928 |
-
st.code(traceback.format_exc())
|
929 |
-
|
930 |
-
# Display results
|
931 |
-
if image_path and hasattr(st.session_state, 'all_probs'):
|
932 |
-
st.header("Predictions")
|
933 |
-
|
934 |
-
# Apply current thresholds
|
935 |
-
filtered_tags, current_all_tags = apply_thresholds(
|
936 |
-
st.session_state.all_probs,
|
937 |
-
threshold_profile,
|
938 |
-
st.session_state.settings['active_threshold'],
|
939 |
-
st.session_state.settings.get('active_category_thresholds', {}),
|
940 |
-
st.session_state.settings['min_confidence'],
|
941 |
-
st.session_state.settings['selected_categories']
|
942 |
-
)
|
943 |
-
|
944 |
-
all_tags = []
|
945 |
-
|
946 |
-
# Display by category
|
947 |
-
for category in sorted(st.session_state.all_probs.keys()):
|
948 |
-
all_tags_in_category = st.session_state.all_probs.get(category, [])
|
949 |
-
filtered_tags_in_category = filtered_tags.get(category, [])
|
950 |
-
|
951 |
-
if all_tags_in_category:
|
952 |
-
expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)"
|
953 |
-
|
954 |
-
with st.expander(expander_label, expanded=True):
|
955 |
-
# Get threshold for this category (handle None case)
|
956 |
-
active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {}
|
957 |
-
threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold'])
|
958 |
-
|
959 |
-
# Determine tags to display
|
960 |
-
if st.session_state.settings['show_all_tags']:
|
961 |
-
tags_to_display = all_tags_in_category
|
962 |
-
else:
|
963 |
-
tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold]
|
964 |
-
|
965 |
-
if not tags_to_display:
|
966 |
-
st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence")
|
967 |
-
continue
|
968 |
-
|
969 |
-
# Display tags
|
970 |
-
if st.session_state.settings['compact_view']:
|
971 |
-
# Compact view
|
972 |
-
tag_list = []
|
973 |
-
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
974 |
-
|
975 |
-
for tag, prob in tags_to_display:
|
976 |
-
percentage = int(prob * 100)
|
977 |
-
display_tag = tag.replace('_', ' ') if replace_underscores else tag
|
978 |
-
tag_list.append(f"{display_tag} ({percentage}%)")
|
979 |
-
|
980 |
-
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
|
981 |
-
all_tags.append(tag)
|
982 |
-
|
983 |
-
st.markdown(", ".join(tag_list))
|
984 |
-
else:
|
985 |
-
# Expanded view with progress bars
|
986 |
-
for tag, prob in tags_to_display:
|
987 |
-
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
988 |
-
display_tag = tag.replace('_', ' ') if replace_underscores else tag
|
989 |
-
|
990 |
-
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
|
991 |
-
all_tags.append(tag)
|
992 |
-
tag_display = f"**{display_tag}**"
|
993 |
-
else:
|
994 |
-
tag_display = display_tag
|
995 |
-
|
996 |
-
st.write(tag_display)
|
997 |
-
st.markdown(display_progress_bar(prob), unsafe_allow_html=True)
|
998 |
-
|
999 |
-
# All tags summary
|
1000 |
-
st.markdown("---")
|
1001 |
-
st.subheader(f"All Tags ({len(all_tags)} total)")
|
1002 |
-
if all_tags:
|
1003 |
-
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
1004 |
-
if replace_underscores:
|
1005 |
-
display_tags = [tag.replace('_', ' ') for tag in all_tags]
|
1006 |
-
st.write(", ".join(display_tags))
|
1007 |
-
else:
|
1008 |
-
st.write(", ".join(all_tags))
|
1009 |
-
else:
|
1010 |
-
st.info("No tags detected above the threshold.")
|
1011 |
-
|
1012 |
-
# Save tags section
|
1013 |
-
st.markdown("---")
|
1014 |
-
st.subheader("Save Tags")
|
1015 |
-
|
1016 |
-
if 'custom_folders' not in st.session_state:
|
1017 |
-
st.session_state.custom_folders = get_default_save_locations()
|
1018 |
-
|
1019 |
-
selected_folder = st.selectbox(
|
1020 |
-
"Select save location:",
|
1021 |
-
options=st.session_state.custom_folders,
|
1022 |
-
format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x
|
1023 |
-
)
|
1024 |
-
|
1025 |
-
if st.button("💾 Save to Selected Location"):
|
1026 |
-
try:
|
1027 |
-
original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None
|
1028 |
-
|
1029 |
-
saved_path = save_tags_to_file(
|
1030 |
-
image_path=image_path,
|
1031 |
-
all_tags=all_tags,
|
1032 |
-
original_filename=original_filename,
|
1033 |
-
custom_dir=selected_folder,
|
1034 |
-
overwrite=True
|
1035 |
-
)
|
1036 |
-
|
1037 |
-
st.success(f"Tags saved to: {os.path.basename(saved_path)}")
|
1038 |
-
st.info(f"Full path: {saved_path}")
|
1039 |
-
|
1040 |
-
# Show file preview
|
1041 |
-
with st.expander("File Contents", expanded=True):
|
1042 |
-
with open(saved_path, 'r', encoding='utf-8') as f:
|
1043 |
-
content = f.read()
|
1044 |
-
st.code(content, language='text')
|
1045 |
-
|
1046 |
-
except Exception as e:
|
1047 |
-
st.error(f"Error saving tags: {str(e)}")
|
1048 |
-
|
1049 |
-
if __name__ == "__main__":
|
1050 |
image_tagger_app()
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Camie-Tagger-V2 Application
|
4 |
+
A Streamlit web app for tagging images using an AI model.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import streamlit as st
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import traceback
|
11 |
+
import tempfile
|
12 |
+
import time
|
13 |
+
import platform
|
14 |
+
import subprocess
|
15 |
+
import webbrowser
|
16 |
+
import glob
|
17 |
+
import numpy as np
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import io
|
20 |
+
import base64
|
21 |
+
import json
|
22 |
+
from matplotlib.colors import LinearSegmentedColormap
|
23 |
+
from PIL import Image
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
# Add parent directory to path to allow importing from utils
|
27 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
28 |
+
|
29 |
+
# Import utilities
|
30 |
+
from utils.image_processing import process_image, batch_process_images
|
31 |
+
from utils.file_utils import save_tags_to_file, get_default_save_locations
|
32 |
+
from utils.ui_components import display_progress_bar, show_example_images, display_batch_results
|
33 |
+
from utils.onnx_processing import batch_process_images_onnx
|
34 |
+
|
35 |
+
# Define the model directory
|
36 |
+
MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
37 |
+
print(f"Using model directory: {MODEL_DIR}")
|
38 |
+
|
39 |
+
# Define threshold profile descriptions and explanations
|
40 |
+
threshold_profile_descriptions = {
|
41 |
+
"Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
|
42 |
+
"Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
|
43 |
+
"Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
|
44 |
+
"Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
|
45 |
+
"Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
|
46 |
+
}
|
47 |
+
|
48 |
+
threshold_profile_explanations = {
|
49 |
+
"Micro Optimized": """
|
50 |
+
### Micro Optimized Profile
|
51 |
+
|
52 |
+
**Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.
|
53 |
+
|
54 |
+
**When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.
|
55 |
+
|
56 |
+
**Effects**:
|
57 |
+
- Optimizes performance for the most frequent tags
|
58 |
+
- Gives more weight to categories with many examples (like 'character' and 'general')
|
59 |
+
- Provides higher precision in most common use cases
|
60 |
+
|
61 |
+
**Performance from validation**:
|
62 |
+
- Micro F1: ~67.3%
|
63 |
+
- Macro F1: ~46.3%
|
64 |
+
- Threshold: ~0.614
|
65 |
+
""",
|
66 |
+
|
67 |
+
"Macro Optimized": """
|
68 |
+
### Macro Optimized Profile
|
69 |
+
|
70 |
+
**Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.
|
71 |
+
|
72 |
+
**When to use**: When balanced performance across all categories is important, including rare tags.
|
73 |
+
|
74 |
+
**Effects**:
|
75 |
+
- More balanced performance across all tag categories
|
76 |
+
- Better at detecting rare or unusual tags
|
77 |
+
- Generally has lower thresholds than micro-optimized
|
78 |
+
|
79 |
+
**Performance from validation**:
|
80 |
+
- Micro F1: ~60.9%
|
81 |
+
- Macro F1: ~50.6%
|
82 |
+
- Threshold: ~0.492
|
83 |
+
""",
|
84 |
+
|
85 |
+
"Balanced": """
|
86 |
+
### Balanced Profile
|
87 |
+
|
88 |
+
**Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.
|
89 |
+
|
90 |
+
**When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.
|
91 |
+
|
92 |
+
**Effects**:
|
93 |
+
- Good middle ground between precision and recall
|
94 |
+
- Works well for most common use cases
|
95 |
+
- Default choice for most users
|
96 |
+
|
97 |
+
**Performance from validation**:
|
98 |
+
- Micro F1: ~67.3%
|
99 |
+
- Macro F1: ~46.3%
|
100 |
+
- Threshold: ~0.614
|
101 |
+
""",
|
102 |
+
|
103 |
+
"Overall": """
|
104 |
+
### Overall Profile
|
105 |
+
|
106 |
+
**Technical definition**: Uses a single threshold value across all categories.
|
107 |
+
|
108 |
+
**When to use**: When you want consistent behavior across all categories and a simple approach.
|
109 |
+
|
110 |
+
**Effects**:
|
111 |
+
- Consistent tagging threshold for all categories
|
112 |
+
- Simpler to understand than category-specific thresholds
|
113 |
+
- User-adjustable with a single slider
|
114 |
+
|
115 |
+
**Default threshold value**: 0.5 (user-adjustable)
|
116 |
+
|
117 |
+
**Note**: The threshold value is user-adjustable with the slider below.
|
118 |
+
""",
|
119 |
+
|
120 |
+
"Category-specific": """
|
121 |
+
### Category-specific Profile
|
122 |
+
|
123 |
+
**Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.
|
124 |
+
|
125 |
+
**When to use**: When you want to customize tagging sensitivity for different categories.
|
126 |
+
|
127 |
+
**Effects**:
|
128 |
+
- Each category has its own independent threshold
|
129 |
+
- Full control over category sensitivity
|
130 |
+
- Best for fine-tuning results when some categories need different treatment
|
131 |
+
|
132 |
+
**Default threshold values**: Starts with balanced thresholds for each category
|
133 |
+
|
134 |
+
**Note**: Use the category sliders below to adjust thresholds for individual categories.
|
135 |
+
"""
|
136 |
+
}
|
137 |
+
|
138 |
+
def load_validation_results(results_path):
|
139 |
+
"""Load validation results from JSON file"""
|
140 |
+
try:
|
141 |
+
with open(results_path, 'r') as f:
|
142 |
+
data = json.load(f)
|
143 |
+
return data
|
144 |
+
except Exception as e:
|
145 |
+
print(f"Error loading validation results: {e}")
|
146 |
+
return None
|
147 |
+
|
148 |
+
def extract_thresholds_from_results(validation_data):
|
149 |
+
"""Extract threshold information from validation results"""
|
150 |
+
if not validation_data or 'results' not in validation_data:
|
151 |
+
return {}
|
152 |
+
|
153 |
+
thresholds = {
|
154 |
+
'overall': {},
|
155 |
+
'categories': {}
|
156 |
+
}
|
157 |
+
|
158 |
+
# Process results to extract thresholds
|
159 |
+
for result in validation_data['results']:
|
160 |
+
category = result['CATEGORY'].lower()
|
161 |
+
profile = result['PROFILE'].lower().replace(' ', '_')
|
162 |
+
threshold = result['THRESHOLD']
|
163 |
+
micro_f1 = result['MICRO-F1']
|
164 |
+
macro_f1 = result['MACRO-F1']
|
165 |
+
|
166 |
+
# Map profile names
|
167 |
+
if profile == 'micro_opt':
|
168 |
+
profile = 'micro_optimized'
|
169 |
+
elif profile == 'macro_opt':
|
170 |
+
profile = 'macro_optimized'
|
171 |
+
|
172 |
+
threshold_info = {
|
173 |
+
'threshold': threshold,
|
174 |
+
'micro_f1': micro_f1,
|
175 |
+
'macro_f1': macro_f1
|
176 |
+
}
|
177 |
+
|
178 |
+
if category == 'overall':
|
179 |
+
thresholds['overall'][profile] = threshold_info
|
180 |
+
else:
|
181 |
+
if category not in thresholds['categories']:
|
182 |
+
thresholds['categories'][category] = {}
|
183 |
+
thresholds['categories'][category][profile] = threshold_info
|
184 |
+
|
185 |
+
return thresholds
|
186 |
+
|
187 |
+
def load_model_and_metadata():
|
188 |
+
"""Load model and metadata from available files"""
|
189 |
+
# Check for SafeTensors model
|
190 |
+
safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
|
191 |
+
safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
|
192 |
+
|
193 |
+
# Check for ONNX model
|
194 |
+
onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
|
195 |
+
|
196 |
+
# Check for validation results
|
197 |
+
validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json")
|
198 |
+
|
199 |
+
model_info = {
|
200 |
+
'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path),
|
201 |
+
'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path),
|
202 |
+
'validation_results_available': os.path.exists(validation_results_path)
|
203 |
+
}
|
204 |
+
|
205 |
+
# Load metadata (same for both model types)
|
206 |
+
metadata = None
|
207 |
+
if os.path.exists(safetensors_metadata_path):
|
208 |
+
try:
|
209 |
+
with open(safetensors_metadata_path, 'r') as f:
|
210 |
+
metadata = json.load(f)
|
211 |
+
except Exception as e:
|
212 |
+
print(f"Error loading metadata: {e}")
|
213 |
+
|
214 |
+
# Load validation results for thresholds
|
215 |
+
thresholds = {}
|
216 |
+
if model_info['validation_results_available']:
|
217 |
+
validation_data = load_validation_results(validation_results_path)
|
218 |
+
if validation_data:
|
219 |
+
thresholds = extract_thresholds_from_results(validation_data)
|
220 |
+
|
221 |
+
# Add default thresholds if not available
|
222 |
+
if not thresholds:
|
223 |
+
thresholds = {
|
224 |
+
'overall': {
|
225 |
+
'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
|
226 |
+
'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
|
227 |
+
'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
|
228 |
+
},
|
229 |
+
'categories': {}
|
230 |
+
}
|
231 |
+
|
232 |
+
return model_info, metadata, thresholds
|
233 |
+
|
234 |
+
def load_safetensors_model(safetensors_path, metadata_path):
|
235 |
+
"""Load SafeTensors model"""
|
236 |
+
try:
|
237 |
+
from safetensors.torch import load_file
|
238 |
+
import torch
|
239 |
+
|
240 |
+
# Load metadata
|
241 |
+
with open(metadata_path, 'r') as f:
|
242 |
+
metadata = json.load(f)
|
243 |
+
|
244 |
+
# Import the model class (assuming it's available)
|
245 |
+
# You'll need to make sure the ImageTagger class is importable
|
246 |
+
from utils.model_loader import ImageTagger # Update this import
|
247 |
+
|
248 |
+
model_info = metadata['model_info']
|
249 |
+
dataset_info = metadata['dataset_info']
|
250 |
+
|
251 |
+
# Recreate model architecture
|
252 |
+
model = ImageTagger(
|
253 |
+
total_tags=dataset_info['total_tags'],
|
254 |
+
dataset=None,
|
255 |
+
model_name=model_info['backbone'],
|
256 |
+
num_heads=model_info['num_attention_heads'],
|
257 |
+
dropout=0.0,
|
258 |
+
pretrained=False,
|
259 |
+
tag_context_size=model_info['tag_context_size'],
|
260 |
+
use_gradient_checkpointing=False,
|
261 |
+
img_size=model_info['img_size']
|
262 |
+
)
|
263 |
+
|
264 |
+
# Load weights
|
265 |
+
state_dict = load_file(safetensors_path)
|
266 |
+
model.load_state_dict(state_dict)
|
267 |
+
model.eval()
|
268 |
+
|
269 |
+
return model, metadata
|
270 |
+
except Exception as e:
|
271 |
+
raise Exception(f"Failed to load SafeTensors model: {e}")
|
272 |
+
|
273 |
+
def get_profile_metrics(thresholds, profile_name):
|
274 |
+
"""Extract metrics for the given profile from the thresholds dictionary"""
|
275 |
+
profile_key = None
|
276 |
+
|
277 |
+
# Map UI-friendly names to internal keys
|
278 |
+
if profile_name == "Micro Optimized":
|
279 |
+
profile_key = "micro_optimized"
|
280 |
+
elif profile_name == "Macro Optimized":
|
281 |
+
profile_key = "macro_optimized"
|
282 |
+
elif profile_name == "Balanced":
|
283 |
+
profile_key = "balanced"
|
284 |
+
elif profile_name in ["Overall", "Category-specific"]:
|
285 |
+
profile_key = "macro_optimized" # Use macro as default for these modes
|
286 |
+
|
287 |
+
if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']:
|
288 |
+
return thresholds['overall'][profile_key]
|
289 |
+
|
290 |
+
return None
|
291 |
+
|
292 |
+
def on_threshold_profile_change():
|
293 |
+
"""Handle threshold profile changes"""
|
294 |
+
new_profile = st.session_state.threshold_profile
|
295 |
+
|
296 |
+
if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'):
|
297 |
+
# Initialize category thresholds if needed
|
298 |
+
if st.session_state.settings['active_category_thresholds'] is None:
|
299 |
+
st.session_state.settings['active_category_thresholds'] = {}
|
300 |
+
|
301 |
+
current_thresholds = st.session_state.settings['active_category_thresholds']
|
302 |
+
|
303 |
+
# Map profile names to keys
|
304 |
+
profile_key = None
|
305 |
+
if new_profile == "Micro Optimized":
|
306 |
+
profile_key = "micro_optimized"
|
307 |
+
elif new_profile == "Macro Optimized":
|
308 |
+
profile_key = "macro_optimized"
|
309 |
+
elif new_profile == "Balanced":
|
310 |
+
profile_key = "balanced"
|
311 |
+
|
312 |
+
# Update thresholds based on profile
|
313 |
+
if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']:
|
314 |
+
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold']
|
315 |
+
|
316 |
+
# Set category thresholds
|
317 |
+
for category in st.session_state.categories:
|
318 |
+
if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]:
|
319 |
+
current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold']
|
320 |
+
else:
|
321 |
+
current_thresholds[category] = st.session_state.settings['active_threshold']
|
322 |
+
|
323 |
+
elif new_profile == "Overall":
|
324 |
+
# Use balanced threshold for Overall profile
|
325 |
+
if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
|
326 |
+
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
|
327 |
+
else:
|
328 |
+
st.session_state.settings['active_threshold'] = 0.5
|
329 |
+
|
330 |
+
# Clear category-specific overrides
|
331 |
+
st.session_state.settings['active_category_thresholds'] = {}
|
332 |
+
|
333 |
+
elif new_profile == "Category-specific":
|
334 |
+
# Initialize with balanced thresholds
|
335 |
+
if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
|
336 |
+
st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
|
337 |
+
else:
|
338 |
+
st.session_state.settings['active_threshold'] = 0.5
|
339 |
+
|
340 |
+
# Initialize category thresholds
|
341 |
+
for category in st.session_state.categories:
|
342 |
+
if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]:
|
343 |
+
current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold']
|
344 |
+
else:
|
345 |
+
current_thresholds[category] = st.session_state.settings['active_threshold']
|
346 |
+
|
347 |
+
def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
|
348 |
+
"""Apply thresholds to raw probabilities and return filtered tags"""
|
349 |
+
tags = {}
|
350 |
+
all_tags = []
|
351 |
+
|
352 |
+
# Handle None case for active_category_thresholds
|
353 |
+
active_category_thresholds = active_category_thresholds or {}
|
354 |
+
|
355 |
+
for category, cat_probs in all_probs.items():
|
356 |
+
# Get the appropriate threshold for this category
|
357 |
+
threshold = active_category_thresholds.get(category, active_threshold)
|
358 |
+
|
359 |
+
# Filter tags above threshold
|
360 |
+
tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]
|
361 |
+
|
362 |
+
# Add to all_tags if selected
|
363 |
+
if selected_categories.get(category, True):
|
364 |
+
for tag, prob in tags[category]:
|
365 |
+
all_tags.append(tag)
|
366 |
+
|
367 |
+
return tags, all_tags
|
368 |
+
|
369 |
+
def image_tagger_app():
|
370 |
+
"""Main Streamlit application for image tagging."""
|
371 |
+
st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️")
|
372 |
+
|
373 |
+
st.title("Camie-Tagger-v2 Interface")
|
374 |
+
st.markdown("---")
|
375 |
+
|
376 |
+
# Initialize settings
|
377 |
+
if 'settings' not in st.session_state:
|
378 |
+
st.session_state.settings = {
|
379 |
+
'show_all_tags': False,
|
380 |
+
'compact_view': True,
|
381 |
+
'min_confidence': 0.01,
|
382 |
+
'threshold_profile': "Macro",
|
383 |
+
'active_threshold': 0.5,
|
384 |
+
'active_category_thresholds': {}, # Initialize as empty dict, not None
|
385 |
+
'selected_categories': {},
|
386 |
+
'replace_underscores': False
|
387 |
+
}
|
388 |
+
st.session_state.show_profile_help = False
|
389 |
+
|
390 |
+
# Session state initialization for model
|
391 |
+
if 'model_loaded' not in st.session_state:
|
392 |
+
st.session_state.model_loaded = False
|
393 |
+
st.session_state.model = None
|
394 |
+
st.session_state.thresholds = None
|
395 |
+
st.session_state.metadata = None
|
396 |
+
st.session_state.model_type = "onnx" # Default to ONNX
|
397 |
+
|
398 |
+
# Sidebar for model selection and information
|
399 |
+
with st.sidebar:
|
400 |
+
# Support information
|
401 |
+
st.subheader("💡 Notes")
|
402 |
+
|
403 |
+
st.markdown("""
|
404 |
+
This tagger was trained on a subset of the available data due to hardware limitations.
|
405 |
+
|
406 |
+
A more comprehensive model trained on the full 3+ million image dataset would provide:
|
407 |
+
- More recent characters and tags.
|
408 |
+
- Improved accuracy.
|
409 |
+
|
410 |
+
If you find this tool useful and would like to support future development:
|
411 |
+
""")
|
412 |
+
|
413 |
+
# Add Buy Me a Coffee button with Star of the City-like glow effect
|
414 |
+
st.markdown("""
|
415 |
+
<style>
|
416 |
+
@keyframes coffee-button-glow {
|
417 |
+
0% { box-shadow: 0 0 5px #FFD700; }
|
418 |
+
50% { box-shadow: 0 0 15px #FFD700; }
|
419 |
+
100% { box-shadow: 0 0 5px #FFD700; }
|
420 |
+
}
|
421 |
+
|
422 |
+
.coffee-button {
|
423 |
+
display: inline-block;
|
424 |
+
animation: coffee-button-glow 2s infinite;
|
425 |
+
border-radius: 5px;
|
426 |
+
transition: transform 0.3s ease;
|
427 |
+
}
|
428 |
+
|
429 |
+
.coffee-button:hover {
|
430 |
+
transform: scale(1.05);
|
431 |
+
}
|
432 |
+
</style>
|
433 |
+
|
434 |
+
<a href="https://ko-fi.com/camais" target="_blank" class="coffee-button">
|
435 |
+
<img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
|
436 |
+
alt="Buy Me A Coffee"
|
437 |
+
style="height: 45px; width: 162px; border-radius: 5px;" />
|
438 |
+
</a>
|
439 |
+
""", unsafe_allow_html=True)
|
440 |
+
|
441 |
+
st.markdown("""
|
442 |
+
Your support helps with:
|
443 |
+
- GPU costs for training
|
444 |
+
- Storage for larger datasets
|
445 |
+
- Development of new features
|
446 |
+
- Future projects
|
447 |
+
|
448 |
+
Thank you! 🙏
|
449 |
+
|
450 |
+
Full Details: https://huggingface.co/Camais03/camie-tagger-v2
|
451 |
+
""")
|
452 |
+
|
453 |
+
st.header("Model Selection")
|
454 |
+
|
455 |
+
# Load model information
|
456 |
+
model_info, metadata, thresholds = load_model_and_metadata()
|
457 |
+
|
458 |
+
# Determine available model options
|
459 |
+
model_options = []
|
460 |
+
if model_info['onnx_available']:
|
461 |
+
model_options.append("ONNX (Recommended)")
|
462 |
+
if model_info['safetensors_available']:
|
463 |
+
model_options.append("SafeTensors (PyTorch)")
|
464 |
+
|
465 |
+
if not model_options:
|
466 |
+
st.error("No model files found!")
|
467 |
+
st.info(f"Looking for models in: {MODEL_DIR}")
|
468 |
+
st.info("Expected files:")
|
469 |
+
st.info("- camie-tagger-v2.onnx")
|
470 |
+
st.info("- camie-tagger-v2.safetensors")
|
471 |
+
st.info("- camie-tagger-v2-metadata.json")
|
472 |
+
st.stop()
|
473 |
+
|
474 |
+
# Model type selection
|
475 |
+
default_index = 0 if model_info['onnx_available'] else 0
|
476 |
+
model_type = st.radio(
|
477 |
+
"Select Model Type:",
|
478 |
+
model_options,
|
479 |
+
index=default_index,
|
480 |
+
help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format"
|
481 |
+
)
|
482 |
+
|
483 |
+
# Convert selection to internal model type
|
484 |
+
if model_type == "ONNX (Recommended)":
|
485 |
+
selected_model_type = "onnx"
|
486 |
+
else:
|
487 |
+
selected_model_type = "safetensors"
|
488 |
+
|
489 |
+
# If model type changed, reload
|
490 |
+
if selected_model_type != st.session_state.model_type:
|
491 |
+
st.session_state.model_loaded = False
|
492 |
+
st.session_state.model_type = selected_model_type
|
493 |
+
|
494 |
+
# Reload button
|
495 |
+
if st.button("Reload Model") and st.session_state.model_loaded:
|
496 |
+
st.session_state.model_loaded = False
|
497 |
+
st.info("Reloading model...")
|
498 |
+
|
499 |
+
# Try to load the model
|
500 |
+
if not st.session_state.model_loaded:
|
501 |
+
try:
|
502 |
+
with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."):
|
503 |
+
if st.session_state.model_type == "onnx":
|
504 |
+
# Load ONNX model
|
505 |
+
import onnxruntime as ort
|
506 |
+
|
507 |
+
onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
|
508 |
+
|
509 |
+
# Check ONNX providers
|
510 |
+
providers = ort.get_available_providers()
|
511 |
+
gpu_available = any('CUDA' in provider for provider in providers)
|
512 |
+
|
513 |
+
# Create ONNX session
|
514 |
+
session = ort.InferenceSession(onnx_path, providers=providers)
|
515 |
+
|
516 |
+
st.session_state.model = session
|
517 |
+
st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})"
|
518 |
+
st.session_state.param_dtype = "float32"
|
519 |
+
|
520 |
+
else:
|
521 |
+
# Load SafeTensors model
|
522 |
+
safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
|
523 |
+
metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
|
524 |
+
|
525 |
+
model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path)
|
526 |
+
|
527 |
+
st.session_state.model = model
|
528 |
+
device = next(model.parameters()).device
|
529 |
+
param_dtype = next(model.parameters()).dtype
|
530 |
+
st.session_state.device = device
|
531 |
+
st.session_state.param_dtype = param_dtype
|
532 |
+
metadata = loaded_metadata # Use loaded metadata instead
|
533 |
+
|
534 |
+
# Store common info
|
535 |
+
st.session_state.thresholds = thresholds
|
536 |
+
st.session_state.metadata = metadata
|
537 |
+
st.session_state.model_loaded = True
|
538 |
+
|
539 |
+
# Get categories
|
540 |
+
if metadata and 'dataset_info' in metadata:
|
541 |
+
tag_mapping = metadata['dataset_info']['tag_mapping']
|
542 |
+
categories = list(set(tag_mapping['tag_to_category'].values()))
|
543 |
+
st.session_state.categories = categories
|
544 |
+
|
545 |
+
# Initialize selected categories
|
546 |
+
if not st.session_state.settings['selected_categories']:
|
547 |
+
st.session_state.settings['selected_categories'] = {cat: True for cat in categories}
|
548 |
+
|
549 |
+
# Set initial threshold from validation results
|
550 |
+
if 'overall' in thresholds and 'balanced' in thresholds['overall']:
|
551 |
+
st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold']
|
552 |
+
|
553 |
+
except Exception as e:
|
554 |
+
st.error(f"Error loading model: {str(e)}")
|
555 |
+
st.code(traceback.format_exc())
|
556 |
+
st.stop()
|
557 |
+
|
558 |
+
# Display model information in sidebar
|
559 |
+
with st.sidebar:
|
560 |
+
st.header("Model Information")
|
561 |
+
if st.session_state.model_loaded:
|
562 |
+
if st.session_state.model_type == "onnx":
|
563 |
+
st.success("Using ONNX Model")
|
564 |
+
else:
|
565 |
+
st.success("Using SafeTensors Model")
|
566 |
+
|
567 |
+
st.write(f"Device: {st.session_state.device}")
|
568 |
+
st.write(f"Precision: {st.session_state.param_dtype}")
|
569 |
+
|
570 |
+
if st.session_state.metadata:
|
571 |
+
if 'dataset_info' in st.session_state.metadata:
|
572 |
+
total_tags = st.session_state.metadata['dataset_info']['total_tags']
|
573 |
+
st.write(f"Total tags: {total_tags}")
|
574 |
+
elif 'total_tags' in st.session_state.metadata:
|
575 |
+
st.write(f"Total tags: {st.session_state.metadata['total_tags']}")
|
576 |
+
|
577 |
+
# Show categories
|
578 |
+
with st.expander("Available Categories"):
|
579 |
+
for category in sorted(st.session_state.categories):
|
580 |
+
st.write(f"- {category.capitalize()}")
|
581 |
+
|
582 |
+
# About section
|
583 |
+
with st.expander("About this app"):
|
584 |
+
st.write("""
|
585 |
+
This app uses a trained image tagging model to analyze and tag images.
|
586 |
+
|
587 |
+
**Model Options**:
|
588 |
+
- **ONNX (Recommended)**: Optimized for inference speed with broad compatibility
|
589 |
+
- **SafeTensors**: Native PyTorch format for advanced users
|
590 |
+
|
591 |
+
**Features**:
|
592 |
+
- Upload or process images in batches
|
593 |
+
- Multiple threshold profiles based on validation results
|
594 |
+
- Category-specific threshold adjustment
|
595 |
+
- Export tags in various formats
|
596 |
+
- Fast inference with GPU acceleration (when available)
|
597 |
+
|
598 |
+
**Threshold Profiles**:
|
599 |
+
- **Micro Optimized**: Best overall F1 score (67.3% micro F1)
|
600 |
+
- **Macro Optimized**: Balanced across categories (50.6% macro F1)
|
601 |
+
- **Balanced**: Good general-purpose setting
|
602 |
+
- **Overall**: Single adjustable threshold
|
603 |
+
- **Category-specific**: Fine-tune each category individually
|
604 |
+
""")
|
605 |
+
|
606 |
+
# Main content area - Image upload and processing
|
607 |
+
col1, col2 = st.columns([1, 1.5])
|
608 |
+
|
609 |
+
with col1:
|
610 |
+
st.header("Image")
|
611 |
+
|
612 |
+
upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])
|
613 |
+
|
614 |
+
image_path = None
|
615 |
+
|
616 |
+
with upload_tab:
|
617 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
618 |
+
|
619 |
+
if uploaded_file:
|
620 |
+
# Create temporary file
|
621 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
|
622 |
+
tmp_file.write(uploaded_file.getvalue())
|
623 |
+
image_path = tmp_file.name
|
624 |
+
|
625 |
+
st.session_state.original_filename = uploaded_file.name
|
626 |
+
|
627 |
+
# Display image
|
628 |
+
image = Image.open(uploaded_file)
|
629 |
+
st.image(image, use_container_width=True)
|
630 |
+
|
631 |
+
with batch_tab:
|
632 |
+
st.subheader("Batch Process Images")
|
633 |
+
|
634 |
+
# Folder selection
|
635 |
+
batch_folder = st.text_input("Enter folder path containing images:", "")
|
636 |
+
|
637 |
+
# Save options
|
638 |
+
save_options = st.radio(
|
639 |
+
"Where to save tag files:",
|
640 |
+
["Same folder as images", "Custom location", "Default save folder"],
|
641 |
+
index=0
|
642 |
+
)
|
643 |
+
|
644 |
+
# Batch size control
|
645 |
+
st.subheader("Performance Options")
|
646 |
+
batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4,
|
647 |
+
help="Higher values may improve speed but use more memory")
|
648 |
+
|
649 |
+
# Category limits
|
650 |
+
enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)
|
651 |
+
|
652 |
+
if enable_category_limits and hasattr(st.session_state, 'categories'):
|
653 |
+
if 'category_limits' not in st.session_state:
|
654 |
+
st.session_state.category_limits = {}
|
655 |
+
|
656 |
+
st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags")
|
657 |
+
|
658 |
+
limit_cols = st.columns(2)
|
659 |
+
for i, category in enumerate(sorted(st.session_state.categories)):
|
660 |
+
col_idx = i % 2
|
661 |
+
with limit_cols[col_idx]:
|
662 |
+
current_limit = st.session_state.category_limits.get(category, -1)
|
663 |
+
new_limit = st.number_input(
|
664 |
+
f"{category.capitalize()}:",
|
665 |
+
value=current_limit,
|
666 |
+
min_value=-1,
|
667 |
+
step=1,
|
668 |
+
key=f"limit_{category}"
|
669 |
+
)
|
670 |
+
st.session_state.category_limits[category] = new_limit
|
671 |
+
|
672 |
+
# Process batch button
|
673 |
+
if batch_folder and os.path.isdir(batch_folder):
|
674 |
+
image_files = []
|
675 |
+
for ext in ['*.jpg', '*.jpeg', '*.png']:
|
676 |
+
image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
|
677 |
+
image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))
|
678 |
+
|
679 |
+
if image_files:
|
680 |
+
st.write(f"Found {len(image_files)} images")
|
681 |
+
|
682 |
+
if st.button("🔄 Process All Images", type="primary"):
|
683 |
+
if not st.session_state.model_loaded:
|
684 |
+
st.error("Model not loaded")
|
685 |
+
else:
|
686 |
+
with st.spinner("Processing images..."):
|
687 |
+
progress_bar = st.progress(0)
|
688 |
+
status_text = st.empty()
|
689 |
+
|
690 |
+
def update_progress(current, total, image_path):
|
691 |
+
progress = current / total if total > 0 else 0
|
692 |
+
progress_bar.progress(progress)
|
693 |
+
status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}")
|
694 |
+
|
695 |
+
# Determine save directory
|
696 |
+
if save_options == "Same folder as images":
|
697 |
+
save_dir = batch_folder
|
698 |
+
elif save_options == "Custom location":
|
699 |
+
save_dir = st.text_input("Custom save directory:", batch_folder)
|
700 |
+
else:
|
701 |
+
save_dir = os.path.join(os.path.dirname(__file__), "saved_tags")
|
702 |
+
os.makedirs(save_dir, exist_ok=True)
|
703 |
+
|
704 |
+
# Get current settings
|
705 |
+
category_limits = st.session_state.category_limits if enable_category_limits else None
|
706 |
+
|
707 |
+
# Process based on model type
|
708 |
+
if st.session_state.model_type == "onnx":
|
709 |
+
batch_results = batch_process_images_onnx(
|
710 |
+
folder_path=batch_folder,
|
711 |
+
model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
|
712 |
+
metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"),
|
713 |
+
threshold_profile=st.session_state.settings['threshold_profile'],
|
714 |
+
active_threshold=st.session_state.settings['active_threshold'],
|
715 |
+
active_category_thresholds=st.session_state.settings['active_category_thresholds'],
|
716 |
+
save_dir=save_dir,
|
717 |
+
progress_callback=update_progress,
|
718 |
+
min_confidence=st.session_state.settings['min_confidence'],
|
719 |
+
batch_size=batch_size,
|
720 |
+
category_limits=category_limits
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
# SafeTensors processing (would need to implement)
|
724 |
+
st.error("SafeTensors batch processing not implemented yet")
|
725 |
+
batch_results = None
|
726 |
+
|
727 |
+
if batch_results:
|
728 |
+
display_batch_results(batch_results)
|
729 |
+
|
730 |
+
# Column 2: Controls and Results
|
731 |
+
with col2:
|
732 |
+
st.header("Tagging Controls")
|
733 |
+
|
734 |
+
# Threshold profile selection
|
735 |
+
all_profiles = [
|
736 |
+
"Micro Optimized",
|
737 |
+
"Macro Optimized",
|
738 |
+
"Balanced",
|
739 |
+
"Overall",
|
740 |
+
"Category-specific"
|
741 |
+
]
|
742 |
+
|
743 |
+
profile_col1, profile_col2 = st.columns([3, 1])
|
744 |
+
|
745 |
+
with profile_col1:
|
746 |
+
threshold_profile = st.selectbox(
|
747 |
+
"Select threshold profile",
|
748 |
+
options=all_profiles,
|
749 |
+
index=1, # Default to Macro
|
750 |
+
key="threshold_profile",
|
751 |
+
on_change=on_threshold_profile_change
|
752 |
+
)
|
753 |
+
|
754 |
+
with profile_col2:
|
755 |
+
if st.button("ℹ️ Help", key="profile_help"):
|
756 |
+
st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)
|
757 |
+
|
758 |
+
# Show profile help
|
759 |
+
if st.session_state.get('show_profile_help', False):
|
760 |
+
st.markdown(threshold_profile_explanations[threshold_profile])
|
761 |
+
else:
|
762 |
+
st.info(threshold_profile_descriptions[threshold_profile])
|
763 |
+
|
764 |
+
# Show profile metrics if available
|
765 |
+
if st.session_state.model_loaded:
|
766 |
+
metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile)
|
767 |
+
|
768 |
+
if metrics:
|
769 |
+
metrics_cols = st.columns(3)
|
770 |
+
|
771 |
+
with metrics_cols[0]:
|
772 |
+
st.metric("Threshold", f"{metrics['threshold']:.3f}")
|
773 |
+
|
774 |
+
with metrics_cols[1]:
|
775 |
+
st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%")
|
776 |
+
|
777 |
+
with metrics_cols[2]:
|
778 |
+
st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%")
|
779 |
+
|
780 |
+
# Threshold controls based on profile
|
781 |
+
if st.session_state.model_loaded:
|
782 |
+
active_threshold = st.session_state.settings.get('active_threshold', 0.5)
|
783 |
+
active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {})
|
784 |
+
|
785 |
+
if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]:
|
786 |
+
# Show reference threshold (disabled)
|
787 |
+
st.slider(
|
788 |
+
"Threshold (from validation)",
|
789 |
+
min_value=0.01,
|
790 |
+
max_value=1.0,
|
791 |
+
value=float(active_threshold),
|
792 |
+
step=0.01,
|
793 |
+
disabled=True,
|
794 |
+
help="This threshold is optimized from validation results"
|
795 |
+
)
|
796 |
+
|
797 |
+
elif threshold_profile == "Overall":
|
798 |
+
# Adjustable overall threshold
|
799 |
+
active_threshold = st.slider(
|
800 |
+
"Overall threshold",
|
801 |
+
min_value=0.01,
|
802 |
+
max_value=1.0,
|
803 |
+
value=float(active_threshold),
|
804 |
+
step=0.01
|
805 |
+
)
|
806 |
+
st.session_state.settings['active_threshold'] = active_threshold
|
807 |
+
|
808 |
+
elif threshold_profile == "Category-specific":
|
809 |
+
# Show reference overall threshold
|
810 |
+
st.slider(
|
811 |
+
"Overall threshold (reference)",
|
812 |
+
min_value=0.01,
|
813 |
+
max_value=1.0,
|
814 |
+
value=float(active_threshold),
|
815 |
+
step=0.01,
|
816 |
+
disabled=True
|
817 |
+
)
|
818 |
+
|
819 |
+
st.write("Adjust thresholds for individual categories:")
|
820 |
+
|
821 |
+
# Category sliders
|
822 |
+
slider_cols = st.columns(2)
|
823 |
+
|
824 |
+
if not active_category_thresholds:
|
825 |
+
active_category_thresholds = {}
|
826 |
+
|
827 |
+
for i, category in enumerate(sorted(st.session_state.categories)):
|
828 |
+
col_idx = i % 2
|
829 |
+
with slider_cols[col_idx]:
|
830 |
+
default_val = active_category_thresholds.get(category, active_threshold)
|
831 |
+
new_threshold = st.slider(
|
832 |
+
f"{category.capitalize()}",
|
833 |
+
min_value=0.01,
|
834 |
+
max_value=1.0,
|
835 |
+
value=float(default_val),
|
836 |
+
step=0.01,
|
837 |
+
key=f"slider_{category}"
|
838 |
+
)
|
839 |
+
active_category_thresholds[category] = new_threshold
|
840 |
+
|
841 |
+
st.session_state.settings['active_category_thresholds'] = active_category_thresholds
|
842 |
+
|
843 |
+
# Display options
|
844 |
+
with st.expander("Display Options", expanded=False):
|
845 |
+
col1, col2 = st.columns(2)
|
846 |
+
with col1:
|
847 |
+
show_all_tags = st.checkbox("Show all tags (including below threshold)",
|
848 |
+
value=st.session_state.settings['show_all_tags'])
|
849 |
+
compact_view = st.checkbox("Compact view (hide progress bars)",
|
850 |
+
value=st.session_state.settings['compact_view'])
|
851 |
+
replace_underscores = st.checkbox("Replace underscores with spaces",
|
852 |
+
value=st.session_state.settings.get('replace_underscores', False))
|
853 |
+
|
854 |
+
with col2:
|
855 |
+
min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
|
856 |
+
st.session_state.settings['min_confidence'], 0.01)
|
857 |
+
|
858 |
+
# Update settings
|
859 |
+
st.session_state.settings.update({
|
860 |
+
'show_all_tags': show_all_tags,
|
861 |
+
'compact_view': compact_view,
|
862 |
+
'min_confidence': min_confidence,
|
863 |
+
'replace_underscores': replace_underscores
|
864 |
+
})
|
865 |
+
|
866 |
+
# Category selection
|
867 |
+
st.write("Categories to include in 'All Tags' section:")
|
868 |
+
|
869 |
+
category_cols = st.columns(3)
|
870 |
+
selected_categories = {}
|
871 |
+
|
872 |
+
if hasattr(st.session_state, 'categories'):
|
873 |
+
for i, category in enumerate(sorted(st.session_state.categories)):
|
874 |
+
col_idx = i % 3
|
875 |
+
with category_cols[col_idx]:
|
876 |
+
default_val = st.session_state.settings['selected_categories'].get(category, True)
|
877 |
+
selected_categories[category] = st.checkbox(
|
878 |
+
f"{category.capitalize()}",
|
879 |
+
value=default_val,
|
880 |
+
key=f"cat_select_{category}"
|
881 |
+
)
|
882 |
+
|
883 |
+
st.session_state.settings['selected_categories'] = selected_categories
|
884 |
+
|
885 |
+
# Run tagging button
|
886 |
+
if image_path and st.button("Run Tagging"):
|
887 |
+
if not st.session_state.model_loaded:
|
888 |
+
st.error("Model not loaded")
|
889 |
+
else:
|
890 |
+
with st.spinner("Analyzing image..."):
|
891 |
+
try:
|
892 |
+
# Process image based on model type
|
893 |
+
if st.session_state.model_type == "onnx":
|
894 |
+
from utils.onnx_processing import process_single_image_onnx
|
895 |
+
|
896 |
+
result = process_single_image_onnx(
|
897 |
+
image_path=image_path,
|
898 |
+
model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
|
899 |
+
metadata=st.session_state.metadata,
|
900 |
+
threshold_profile=threshold_profile,
|
901 |
+
active_threshold=st.session_state.settings['active_threshold'],
|
902 |
+
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
|
903 |
+
min_confidence=st.session_state.settings['min_confidence']
|
904 |
+
)
|
905 |
+
else:
|
906 |
+
# SafeTensors processing
|
907 |
+
result = process_image(
|
908 |
+
image_path=image_path,
|
909 |
+
model=st.session_state.model,
|
910 |
+
thresholds=st.session_state.thresholds,
|
911 |
+
metadata=st.session_state.metadata,
|
912 |
+
threshold_profile=threshold_profile,
|
913 |
+
active_threshold=st.session_state.settings['active_threshold'],
|
914 |
+
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
|
915 |
+
min_confidence=st.session_state.settings['min_confidence']
|
916 |
+
)
|
917 |
+
|
918 |
+
if result['success']:
|
919 |
+
st.session_state.all_probs = result['all_probs']
|
920 |
+
st.session_state.tags = result['tags']
|
921 |
+
st.session_state.all_tags = result['all_tags']
|
922 |
+
st.success("Analysis completed!")
|
923 |
+
else:
|
924 |
+
st.error(f"Analysis failed: {result.get('error', 'Unknown error')}")
|
925 |
+
|
926 |
+
except Exception as e:
|
927 |
+
st.error(f"Error during analysis: {str(e)}")
|
928 |
+
st.code(traceback.format_exc())
|
929 |
+
|
930 |
+
# Display results
|
931 |
+
if image_path and hasattr(st.session_state, 'all_probs'):
|
932 |
+
st.header("Predictions")
|
933 |
+
|
934 |
+
# Apply current thresholds
|
935 |
+
filtered_tags, current_all_tags = apply_thresholds(
|
936 |
+
st.session_state.all_probs,
|
937 |
+
threshold_profile,
|
938 |
+
st.session_state.settings['active_threshold'],
|
939 |
+
st.session_state.settings.get('active_category_thresholds', {}),
|
940 |
+
st.session_state.settings['min_confidence'],
|
941 |
+
st.session_state.settings['selected_categories']
|
942 |
+
)
|
943 |
+
|
944 |
+
all_tags = []
|
945 |
+
|
946 |
+
# Display by category
|
947 |
+
for category in sorted(st.session_state.all_probs.keys()):
|
948 |
+
all_tags_in_category = st.session_state.all_probs.get(category, [])
|
949 |
+
filtered_tags_in_category = filtered_tags.get(category, [])
|
950 |
+
|
951 |
+
if all_tags_in_category:
|
952 |
+
expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)"
|
953 |
+
|
954 |
+
with st.expander(expander_label, expanded=True):
|
955 |
+
# Get threshold for this category (handle None case)
|
956 |
+
active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {}
|
957 |
+
threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold'])
|
958 |
+
|
959 |
+
# Determine tags to display
|
960 |
+
if st.session_state.settings['show_all_tags']:
|
961 |
+
tags_to_display = all_tags_in_category
|
962 |
+
else:
|
963 |
+
tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold]
|
964 |
+
|
965 |
+
if not tags_to_display:
|
966 |
+
st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence")
|
967 |
+
continue
|
968 |
+
|
969 |
+
# Display tags
|
970 |
+
if st.session_state.settings['compact_view']:
|
971 |
+
# Compact view
|
972 |
+
tag_list = []
|
973 |
+
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
974 |
+
|
975 |
+
for tag, prob in tags_to_display:
|
976 |
+
percentage = int(prob * 100)
|
977 |
+
display_tag = tag.replace('_', ' ') if replace_underscores else tag
|
978 |
+
tag_list.append(f"{display_tag} ({percentage}%)")
|
979 |
+
|
980 |
+
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
|
981 |
+
all_tags.append(tag)
|
982 |
+
|
983 |
+
st.markdown(", ".join(tag_list))
|
984 |
+
else:
|
985 |
+
# Expanded view with progress bars
|
986 |
+
for tag, prob in tags_to_display:
|
987 |
+
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
988 |
+
display_tag = tag.replace('_', ' ') if replace_underscores else tag
|
989 |
+
|
990 |
+
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
|
991 |
+
all_tags.append(tag)
|
992 |
+
tag_display = f"**{display_tag}**"
|
993 |
+
else:
|
994 |
+
tag_display = display_tag
|
995 |
+
|
996 |
+
st.write(tag_display)
|
997 |
+
st.markdown(display_progress_bar(prob), unsafe_allow_html=True)
|
998 |
+
|
999 |
+
# All tags summary
|
1000 |
+
st.markdown("---")
|
1001 |
+
st.subheader(f"All Tags ({len(all_tags)} total)")
|
1002 |
+
if all_tags:
|
1003 |
+
replace_underscores = st.session_state.settings.get('replace_underscores', False)
|
1004 |
+
if replace_underscores:
|
1005 |
+
display_tags = [tag.replace('_', ' ') for tag in all_tags]
|
1006 |
+
st.write(", ".join(display_tags))
|
1007 |
+
else:
|
1008 |
+
st.write(", ".join(all_tags))
|
1009 |
+
else:
|
1010 |
+
st.info("No tags detected above the threshold.")
|
1011 |
+
|
1012 |
+
# Save tags section
|
1013 |
+
st.markdown("---")
|
1014 |
+
st.subheader("Save Tags")
|
1015 |
+
|
1016 |
+
if 'custom_folders' not in st.session_state:
|
1017 |
+
st.session_state.custom_folders = get_default_save_locations()
|
1018 |
+
|
1019 |
+
selected_folder = st.selectbox(
|
1020 |
+
"Select save location:",
|
1021 |
+
options=st.session_state.custom_folders,
|
1022 |
+
format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
if st.button("💾 Save to Selected Location"):
|
1026 |
+
try:
|
1027 |
+
original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None
|
1028 |
+
|
1029 |
+
saved_path = save_tags_to_file(
|
1030 |
+
image_path=image_path,
|
1031 |
+
all_tags=all_tags,
|
1032 |
+
original_filename=original_filename,
|
1033 |
+
custom_dir=selected_folder,
|
1034 |
+
overwrite=True
|
1035 |
+
)
|
1036 |
+
|
1037 |
+
st.success(f"Tags saved to: {os.path.basename(saved_path)}")
|
1038 |
+
st.info(f"Full path: {saved_path}")
|
1039 |
+
|
1040 |
+
# Show file preview
|
1041 |
+
with st.expander("File Contents", expanded=True):
|
1042 |
+
with open(saved_path, 'r', encoding='utf-8') as f:
|
1043 |
+
content = f.read()
|
1044 |
+
st.code(content, language='text')
|
1045 |
+
|
1046 |
+
except Exception as e:
|
1047 |
+
st.error(f"Error saving tags: {str(e)}")
|
1048 |
+
|
1049 |
+
if __name__ == "__main__":
|
1050 |
image_tagger_app()
|