Camais03 commited on
Commit
6bcc99d
·
verified ·
1 Parent(s): 3c69087

Update app/app.py

Browse files
Files changed (1) hide show
  1. app/app.py +1049 -1049
app/app.py CHANGED
@@ -1,1050 +1,1050 @@
1
- #!/usr/bin/env python3
2
- """
3
- Camie-Tagger-V2 Application
4
- A Streamlit web app for tagging images using an AI model.
5
- """
6
-
7
- import streamlit as st
8
- import os
9
- import sys
10
- import traceback
11
- import tempfile
12
- import time
13
- import platform
14
- import subprocess
15
- import webbrowser
16
- import glob
17
- import numpy as np
18
- import matplotlib.pyplot as plt
19
- import io
20
- import base64
21
- import json
22
- from matplotlib.colors import LinearSegmentedColormap
23
- from PIL import Image
24
- from pathlib import Path
25
-
26
- # Add parent directory to path to allow importing from utils
27
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28
-
29
- # Import utilities
30
- from utils.image_processing import process_image, batch_process_images
31
- from utils.file_utils import save_tags_to_file, get_default_save_locations
32
- from utils.ui_components import display_progress_bar, show_example_images, display_batch_results
33
- from utils.onnx_processing import batch_process_images_onnx
34
-
35
- # Define the model directory
36
- MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
37
- print(f"Using model directory: {MODEL_DIR}")
38
-
39
- # Define threshold profile descriptions and explanations
40
- threshold_profile_descriptions = {
41
- "Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
42
- "Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
43
- "Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
44
- "Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
45
- "Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
46
- }
47
-
48
- threshold_profile_explanations = {
49
- "Micro Optimized": """
50
- ### Micro Optimized Profile
51
-
52
- **Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.
53
-
54
- **When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.
55
-
56
- **Effects**:
57
- - Optimizes performance for the most frequent tags
58
- - Gives more weight to categories with many examples (like 'character' and 'general')
59
- - Provides higher precision in most common use cases
60
-
61
- **Performance from validation**:
62
- - Micro F1: ~67.3%
63
- - Macro F1: ~46.3%
64
- - Threshold: ~0.614
65
- """,
66
-
67
- "Macro Optimized": """
68
- ### Macro Optimized Profile
69
-
70
- **Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.
71
-
72
- **When to use**: When balanced performance across all categories is important, including rare tags.
73
-
74
- **Effects**:
75
- - More balanced performance across all tag categories
76
- - Better at detecting rare or unusual tags
77
- - Generally has lower thresholds than micro-optimized
78
-
79
- **Performance from validation**:
80
- - Micro F1: ~60.9%
81
- - Macro F1: ~50.6%
82
- - Threshold: ~0.492
83
- """,
84
-
85
- "Balanced": """
86
- ### Balanced Profile
87
-
88
- **Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.
89
-
90
- **When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.
91
-
92
- **Effects**:
93
- - Good middle ground between precision and recall
94
- - Works well for most common use cases
95
- - Default choice for most users
96
-
97
- **Performance from validation**:
98
- - Micro F1: ~67.3%
99
- - Macro F1: ~46.3%
100
- - Threshold: ~0.614
101
- """,
102
-
103
- "Overall": """
104
- ### Overall Profile
105
-
106
- **Technical definition**: Uses a single threshold value across all categories.
107
-
108
- **When to use**: When you want consistent behavior across all categories and a simple approach.
109
-
110
- **Effects**:
111
- - Consistent tagging threshold for all categories
112
- - Simpler to understand than category-specific thresholds
113
- - User-adjustable with a single slider
114
-
115
- **Default threshold value**: 0.5 (user-adjustable)
116
-
117
- **Note**: The threshold value is user-adjustable with the slider below.
118
- """,
119
-
120
- "Category-specific": """
121
- ### Category-specific Profile
122
-
123
- **Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.
124
-
125
- **When to use**: When you want to customize tagging sensitivity for different categories.
126
-
127
- **Effects**:
128
- - Each category has its own independent threshold
129
- - Full control over category sensitivity
130
- - Best for fine-tuning results when some categories need different treatment
131
-
132
- **Default threshold values**: Starts with balanced thresholds for each category
133
-
134
- **Note**: Use the category sliders below to adjust thresholds for individual categories.
135
- """
136
- }
137
-
138
- def load_validation_results(results_path):
139
- """Load validation results from JSON file"""
140
- try:
141
- with open(results_path, 'r') as f:
142
- data = json.load(f)
143
- return data
144
- except Exception as e:
145
- print(f"Error loading validation results: {e}")
146
- return None
147
-
148
- def extract_thresholds_from_results(validation_data):
149
- """Extract threshold information from validation results"""
150
- if not validation_data or 'results' not in validation_data:
151
- return {}
152
-
153
- thresholds = {
154
- 'overall': {},
155
- 'categories': {}
156
- }
157
-
158
- # Process results to extract thresholds
159
- for result in validation_data['results']:
160
- category = result['CATEGORY'].lower()
161
- profile = result['PROFILE'].lower().replace(' ', '_')
162
- threshold = result['THRESHOLD']
163
- micro_f1 = result['MICRO-F1']
164
- macro_f1 = result['MACRO-F1']
165
-
166
- # Map profile names
167
- if profile == 'micro_opt':
168
- profile = 'micro_optimized'
169
- elif profile == 'macro_opt':
170
- profile = 'macro_optimized'
171
-
172
- threshold_info = {
173
- 'threshold': threshold,
174
- 'micro_f1': micro_f1,
175
- 'macro_f1': macro_f1
176
- }
177
-
178
- if category == 'overall':
179
- thresholds['overall'][profile] = threshold_info
180
- else:
181
- if category not in thresholds['categories']:
182
- thresholds['categories'][category] = {}
183
- thresholds['categories'][category][profile] = threshold_info
184
-
185
- return thresholds
186
-
187
- def load_model_and_metadata():
188
- """Load model and metadata from available files"""
189
- # Check for SafeTensors model
190
- safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
191
- safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
192
-
193
- # Check for ONNX model
194
- onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
195
-
196
- # Check for validation results
197
- validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json")
198
-
199
- model_info = {
200
- 'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path),
201
- 'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path),
202
- 'validation_results_available': os.path.exists(validation_results_path)
203
- }
204
-
205
- # Load metadata (same for both model types)
206
- metadata = None
207
- if os.path.exists(safetensors_metadata_path):
208
- try:
209
- with open(safetensors_metadata_path, 'r') as f:
210
- metadata = json.load(f)
211
- except Exception as e:
212
- print(f"Error loading metadata: {e}")
213
-
214
- # Load validation results for thresholds
215
- thresholds = {}
216
- if model_info['validation_results_available']:
217
- validation_data = load_validation_results(validation_results_path)
218
- if validation_data:
219
- thresholds = extract_thresholds_from_results(validation_data)
220
-
221
- # Add default thresholds if not available
222
- if not thresholds:
223
- thresholds = {
224
- 'overall': {
225
- 'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
226
- 'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
227
- 'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
228
- },
229
- 'categories': {}
230
- }
231
-
232
- return model_info, metadata, thresholds
233
-
234
- def load_safetensors_model(safetensors_path, metadata_path):
235
- """Load SafeTensors model"""
236
- try:
237
- from safetensors.torch import load_file
238
- import torch
239
-
240
- # Load metadata
241
- with open(metadata_path, 'r') as f:
242
- metadata = json.load(f)
243
-
244
- # Import the model class (assuming it's available)
245
- # You'll need to make sure the ImageTagger class is importable
246
- from utils.model_loader import ImageTagger # Update this import
247
-
248
- model_info = metadata['model_info']
249
- dataset_info = metadata['dataset_info']
250
-
251
- # Recreate model architecture
252
- model = ImageTagger(
253
- total_tags=dataset_info['total_tags'],
254
- dataset=None,
255
- model_name=model_info['backbone'],
256
- num_heads=model_info['num_attention_heads'],
257
- dropout=0.0,
258
- pretrained=False,
259
- tag_context_size=model_info['tag_context_size'],
260
- use_gradient_checkpointing=False,
261
- img_size=model_info['img_size']
262
- )
263
-
264
- # Load weights
265
- state_dict = load_file(safetensors_path)
266
- model.load_state_dict(state_dict)
267
- model.eval()
268
-
269
- return model, metadata
270
- except Exception as e:
271
- raise Exception(f"Failed to load SafeTensors model: {e}")
272
-
273
- def get_profile_metrics(thresholds, profile_name):
274
- """Extract metrics for the given profile from the thresholds dictionary"""
275
- profile_key = None
276
-
277
- # Map UI-friendly names to internal keys
278
- if profile_name == "Micro Optimized":
279
- profile_key = "micro_optimized"
280
- elif profile_name == "Macro Optimized":
281
- profile_key = "macro_optimized"
282
- elif profile_name == "Balanced":
283
- profile_key = "balanced"
284
- elif profile_name in ["Overall", "Category-specific"]:
285
- profile_key = "macro_optimized" # Use macro as default for these modes
286
-
287
- if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']:
288
- return thresholds['overall'][profile_key]
289
-
290
- return None
291
-
292
- def on_threshold_profile_change():
293
- """Handle threshold profile changes"""
294
- new_profile = st.session_state.threshold_profile
295
-
296
- if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'):
297
- # Initialize category thresholds if needed
298
- if st.session_state.settings['active_category_thresholds'] is None:
299
- st.session_state.settings['active_category_thresholds'] = {}
300
-
301
- current_thresholds = st.session_state.settings['active_category_thresholds']
302
-
303
- # Map profile names to keys
304
- profile_key = None
305
- if new_profile == "Micro Optimized":
306
- profile_key = "micro_optimized"
307
- elif new_profile == "Macro Optimized":
308
- profile_key = "macro_optimized"
309
- elif new_profile == "Balanced":
310
- profile_key = "balanced"
311
-
312
- # Update thresholds based on profile
313
- if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']:
314
- st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold']
315
-
316
- # Set category thresholds
317
- for category in st.session_state.categories:
318
- if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]:
319
- current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold']
320
- else:
321
- current_thresholds[category] = st.session_state.settings['active_threshold']
322
-
323
- elif new_profile == "Overall":
324
- # Use balanced threshold for Overall profile
325
- if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
326
- st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
327
- else:
328
- st.session_state.settings['active_threshold'] = 0.5
329
-
330
- # Clear category-specific overrides
331
- st.session_state.settings['active_category_thresholds'] = {}
332
-
333
- elif new_profile == "Category-specific":
334
- # Initialize with balanced thresholds
335
- if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
336
- st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
337
- else:
338
- st.session_state.settings['active_threshold'] = 0.5
339
-
340
- # Initialize category thresholds
341
- for category in st.session_state.categories:
342
- if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]:
343
- current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold']
344
- else:
345
- current_thresholds[category] = st.session_state.settings['active_threshold']
346
-
347
- def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
348
- """Apply thresholds to raw probabilities and return filtered tags"""
349
- tags = {}
350
- all_tags = []
351
-
352
- # Handle None case for active_category_thresholds
353
- active_category_thresholds = active_category_thresholds or {}
354
-
355
- for category, cat_probs in all_probs.items():
356
- # Get the appropriate threshold for this category
357
- threshold = active_category_thresholds.get(category, active_threshold)
358
-
359
- # Filter tags above threshold
360
- tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]
361
-
362
- # Add to all_tags if selected
363
- if selected_categories.get(category, True):
364
- for tag, prob in tags[category]:
365
- all_tags.append(tag)
366
-
367
- return tags, all_tags
368
-
369
- def image_tagger_app():
370
- """Main Streamlit application for image tagging."""
371
- st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️")
372
-
373
- st.title("Camie-Tagger-v2 Interface")
374
- st.markdown("---")
375
-
376
- # Initialize settings
377
- if 'settings' not in st.session_state:
378
- st.session_state.settings = {
379
- 'show_all_tags': False,
380
- 'compact_view': True,
381
- 'min_confidence': 0.01,
382
- 'threshold_profile': "Macro",
383
- 'active_threshold': 0.5,
384
- 'active_category_thresholds': {}, # Initialize as empty dict, not None
385
- 'selected_categories': {},
386
- 'replace_underscores': False
387
- }
388
- st.session_state.show_profile_help = False
389
-
390
- # Session state initialization for model
391
- if 'model_loaded' not in st.session_state:
392
- st.session_state.model_loaded = False
393
- st.session_state.model = None
394
- st.session_state.thresholds = None
395
- st.session_state.metadata = None
396
- st.session_state.model_type = "onnx" # Default to ONNX
397
-
398
- # Sidebar for model selection and information
399
- with st.sidebar:
400
- # Support information
401
- st.subheader("💡 Notes")
402
-
403
- st.markdown("""
404
- This tagger was trained on a subset of the available data due to hardware limitations.
405
-
406
- A more comprehensive model trained on the full 3+ million image dataset would provide:
407
- - More recent characters and tags.
408
- - Improved accuracy.
409
-
410
- If you find this tool useful and would like to support future development:
411
- """)
412
-
413
- # Add Buy Me a Coffee button with Star of the City-like glow effect
414
- st.markdown("""
415
- <style>
416
- @keyframes coffee-button-glow {
417
- 0% { box-shadow: 0 0 5px #FFD700; }
418
- 50% { box-shadow: 0 0 15px #FFD700; }
419
- 100% { box-shadow: 0 0 5px #FFD700; }
420
- }
421
-
422
- .coffee-button {
423
- display: inline-block;
424
- animation: coffee-button-glow 2s infinite;
425
- border-radius: 5px;
426
- transition: transform 0.3s ease;
427
- }
428
-
429
- .coffee-button:hover {
430
- transform: scale(1.05);
431
- }
432
- </style>
433
-
434
- <a href="https://ko-fi.com/camais" target="_blank" class="coffee-button">
435
- <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
436
- alt="Buy Me A Coffee"
437
- style="height: 45px; width: 162px; border-radius: 5px;" />
438
- </a>
439
- """, unsafe_allow_html=True)
440
-
441
- st.markdown("""
442
- Your support helps with:
443
- - GPU costs for training
444
- - Storage for larger datasets
445
- - Development of new features
446
- - Future projects
447
-
448
- Thank you! 🙏
449
-
450
- Full Details: https://huggingface.co/Camais03/camie-tagger
451
- """)
452
-
453
- st.header("Model Selection")
454
-
455
- # Load model information
456
- model_info, metadata, thresholds = load_model_and_metadata()
457
-
458
- # Determine available model options
459
- model_options = []
460
- if model_info['onnx_available']:
461
- model_options.append("ONNX (Recommended)")
462
- if model_info['safetensors_available']:
463
- model_options.append("SafeTensors (PyTorch)")
464
-
465
- if not model_options:
466
- st.error("No model files found!")
467
- st.info(f"Looking for models in: {MODEL_DIR}")
468
- st.info("Expected files:")
469
- st.info("- camie-tagger-v2.onnx")
470
- st.info("- camie-tagger-v2.safetensors")
471
- st.info("- camie-tagger-v2-metadata.json")
472
- st.stop()
473
-
474
- # Model type selection
475
- default_index = 0 if model_info['onnx_available'] else 0
476
- model_type = st.radio(
477
- "Select Model Type:",
478
- model_options,
479
- index=default_index,
480
- help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format"
481
- )
482
-
483
- # Convert selection to internal model type
484
- if model_type == "ONNX (Recommended)":
485
- selected_model_type = "onnx"
486
- else:
487
- selected_model_type = "safetensors"
488
-
489
- # If model type changed, reload
490
- if selected_model_type != st.session_state.model_type:
491
- st.session_state.model_loaded = False
492
- st.session_state.model_type = selected_model_type
493
-
494
- # Reload button
495
- if st.button("Reload Model") and st.session_state.model_loaded:
496
- st.session_state.model_loaded = False
497
- st.info("Reloading model...")
498
-
499
- # Try to load the model
500
- if not st.session_state.model_loaded:
501
- try:
502
- with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."):
503
- if st.session_state.model_type == "onnx":
504
- # Load ONNX model
505
- import onnxruntime as ort
506
-
507
- onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
508
-
509
- # Check ONNX providers
510
- providers = ort.get_available_providers()
511
- gpu_available = any('CUDA' in provider for provider in providers)
512
-
513
- # Create ONNX session
514
- session = ort.InferenceSession(onnx_path, providers=providers)
515
-
516
- st.session_state.model = session
517
- st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})"
518
- st.session_state.param_dtype = "float32"
519
-
520
- else:
521
- # Load SafeTensors model
522
- safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
523
- metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
524
-
525
- model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path)
526
-
527
- st.session_state.model = model
528
- device = next(model.parameters()).device
529
- param_dtype = next(model.parameters()).dtype
530
- st.session_state.device = device
531
- st.session_state.param_dtype = param_dtype
532
- metadata = loaded_metadata # Use loaded metadata instead
533
-
534
- # Store common info
535
- st.session_state.thresholds = thresholds
536
- st.session_state.metadata = metadata
537
- st.session_state.model_loaded = True
538
-
539
- # Get categories
540
- if metadata and 'dataset_info' in metadata:
541
- tag_mapping = metadata['dataset_info']['tag_mapping']
542
- categories = list(set(tag_mapping['tag_to_category'].values()))
543
- st.session_state.categories = categories
544
-
545
- # Initialize selected categories
546
- if not st.session_state.settings['selected_categories']:
547
- st.session_state.settings['selected_categories'] = {cat: True for cat in categories}
548
-
549
- # Set initial threshold from validation results
550
- if 'overall' in thresholds and 'balanced' in thresholds['overall']:
551
- st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold']
552
-
553
- except Exception as e:
554
- st.error(f"Error loading model: {str(e)}")
555
- st.code(traceback.format_exc())
556
- st.stop()
557
-
558
- # Display model information in sidebar
559
- with st.sidebar:
560
- st.header("Model Information")
561
- if st.session_state.model_loaded:
562
- if st.session_state.model_type == "onnx":
563
- st.success("Using ONNX Model")
564
- else:
565
- st.success("Using SafeTensors Model")
566
-
567
- st.write(f"Device: {st.session_state.device}")
568
- st.write(f"Precision: {st.session_state.param_dtype}")
569
-
570
- if st.session_state.metadata:
571
- if 'dataset_info' in st.session_state.metadata:
572
- total_tags = st.session_state.metadata['dataset_info']['total_tags']
573
- st.write(f"Total tags: {total_tags}")
574
- elif 'total_tags' in st.session_state.metadata:
575
- st.write(f"Total tags: {st.session_state.metadata['total_tags']}")
576
-
577
- # Show categories
578
- with st.expander("Available Categories"):
579
- for category in sorted(st.session_state.categories):
580
- st.write(f"- {category.capitalize()}")
581
-
582
- # About section
583
- with st.expander("About this app"):
584
- st.write("""
585
- This app uses a trained image tagging model to analyze and tag images.
586
-
587
- **Model Options**:
588
- - **ONNX (Recommended)**: Optimized for inference speed with broad compatibility
589
- - **SafeTensors**: Native PyTorch format for advanced users
590
-
591
- **Features**:
592
- - Upload or process images in batches
593
- - Multiple threshold profiles based on validation results
594
- - Category-specific threshold adjustment
595
- - Export tags in various formats
596
- - Fast inference with GPU acceleration (when available)
597
-
598
- **Threshold Profiles**:
599
- - **Micro Optimized**: Best overall F1 score (67.3% micro F1)
600
- - **Macro Optimized**: Balanced across categories (50.6% macro F1)
601
- - **Balanced**: Good general-purpose setting
602
- - **Overall**: Single adjustable threshold
603
- - **Category-specific**: Fine-tune each category individually
604
- """)
605
-
606
- # Main content area - Image upload and processing
607
- col1, col2 = st.columns([1, 1.5])
608
-
609
- with col1:
610
- st.header("Image")
611
-
612
- upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])
613
-
614
- image_path = None
615
-
616
- with upload_tab:
617
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
618
-
619
- if uploaded_file:
620
- # Create temporary file
621
- with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
622
- tmp_file.write(uploaded_file.getvalue())
623
- image_path = tmp_file.name
624
-
625
- st.session_state.original_filename = uploaded_file.name
626
-
627
- # Display image
628
- image = Image.open(uploaded_file)
629
- st.image(image, use_container_width=True)
630
-
631
- with batch_tab:
632
- st.subheader("Batch Process Images")
633
-
634
- # Folder selection
635
- batch_folder = st.text_input("Enter folder path containing images:", "")
636
-
637
- # Save options
638
- save_options = st.radio(
639
- "Where to save tag files:",
640
- ["Same folder as images", "Custom location", "Default save folder"],
641
- index=0
642
- )
643
-
644
- # Batch size control
645
- st.subheader("Performance Options")
646
- batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4,
647
- help="Higher values may improve speed but use more memory")
648
-
649
- # Category limits
650
- enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)
651
-
652
- if enable_category_limits and hasattr(st.session_state, 'categories'):
653
- if 'category_limits' not in st.session_state:
654
- st.session_state.category_limits = {}
655
-
656
- st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags")
657
-
658
- limit_cols = st.columns(2)
659
- for i, category in enumerate(sorted(st.session_state.categories)):
660
- col_idx = i % 2
661
- with limit_cols[col_idx]:
662
- current_limit = st.session_state.category_limits.get(category, -1)
663
- new_limit = st.number_input(
664
- f"{category.capitalize()}:",
665
- value=current_limit,
666
- min_value=-1,
667
- step=1,
668
- key=f"limit_{category}"
669
- )
670
- st.session_state.category_limits[category] = new_limit
671
-
672
- # Process batch button
673
- if batch_folder and os.path.isdir(batch_folder):
674
- image_files = []
675
- for ext in ['*.jpg', '*.jpeg', '*.png']:
676
- image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
677
- image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))
678
-
679
- if image_files:
680
- st.write(f"Found {len(image_files)} images")
681
-
682
- if st.button("🔄 Process All Images", type="primary"):
683
- if not st.session_state.model_loaded:
684
- st.error("Model not loaded")
685
- else:
686
- with st.spinner("Processing images..."):
687
- progress_bar = st.progress(0)
688
- status_text = st.empty()
689
-
690
- def update_progress(current, total, image_path):
691
- progress = current / total if total > 0 else 0
692
- progress_bar.progress(progress)
693
- status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}")
694
-
695
- # Determine save directory
696
- if save_options == "Same folder as images":
697
- save_dir = batch_folder
698
- elif save_options == "Custom location":
699
- save_dir = st.text_input("Custom save directory:", batch_folder)
700
- else:
701
- save_dir = os.path.join(os.path.dirname(__file__), "saved_tags")
702
- os.makedirs(save_dir, exist_ok=True)
703
-
704
- # Get current settings
705
- category_limits = st.session_state.category_limits if enable_category_limits else None
706
-
707
- # Process based on model type
708
- if st.session_state.model_type == "onnx":
709
- batch_results = batch_process_images_onnx(
710
- folder_path=batch_folder,
711
- model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
712
- metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"),
713
- threshold_profile=st.session_state.settings['threshold_profile'],
714
- active_threshold=st.session_state.settings['active_threshold'],
715
- active_category_thresholds=st.session_state.settings['active_category_thresholds'],
716
- save_dir=save_dir,
717
- progress_callback=update_progress,
718
- min_confidence=st.session_state.settings['min_confidence'],
719
- batch_size=batch_size,
720
- category_limits=category_limits
721
- )
722
- else:
723
- # SafeTensors processing (would need to implement)
724
- st.error("SafeTensors batch processing not implemented yet")
725
- batch_results = None
726
-
727
- if batch_results:
728
- display_batch_results(batch_results)
729
-
730
- # Column 2: Controls and Results
731
- with col2:
732
- st.header("Tagging Controls")
733
-
734
- # Threshold profile selection
735
- all_profiles = [
736
- "Micro Optimized",
737
- "Macro Optimized",
738
- "Balanced",
739
- "Overall",
740
- "Category-specific"
741
- ]
742
-
743
- profile_col1, profile_col2 = st.columns([3, 1])
744
-
745
- with profile_col1:
746
- threshold_profile = st.selectbox(
747
- "Select threshold profile",
748
- options=all_profiles,
749
- index=1, # Default to Macro
750
- key="threshold_profile",
751
- on_change=on_threshold_profile_change
752
- )
753
-
754
- with profile_col2:
755
- if st.button("ℹ️ Help", key="profile_help"):
756
- st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)
757
-
758
- # Show profile help
759
- if st.session_state.get('show_profile_help', False):
760
- st.markdown(threshold_profile_explanations[threshold_profile])
761
- else:
762
- st.info(threshold_profile_descriptions[threshold_profile])
763
-
764
- # Show profile metrics if available
765
- if st.session_state.model_loaded:
766
- metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile)
767
-
768
- if metrics:
769
- metrics_cols = st.columns(3)
770
-
771
- with metrics_cols[0]:
772
- st.metric("Threshold", f"{metrics['threshold']:.3f}")
773
-
774
- with metrics_cols[1]:
775
- st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%")
776
-
777
- with metrics_cols[2]:
778
- st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%")
779
-
780
- # Threshold controls based on profile
781
- if st.session_state.model_loaded:
782
- active_threshold = st.session_state.settings.get('active_threshold', 0.5)
783
- active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {})
784
-
785
- if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]:
786
- # Show reference threshold (disabled)
787
- st.slider(
788
- "Threshold (from validation)",
789
- min_value=0.01,
790
- max_value=1.0,
791
- value=float(active_threshold),
792
- step=0.01,
793
- disabled=True,
794
- help="This threshold is optimized from validation results"
795
- )
796
-
797
- elif threshold_profile == "Overall":
798
- # Adjustable overall threshold
799
- active_threshold = st.slider(
800
- "Overall threshold",
801
- min_value=0.01,
802
- max_value=1.0,
803
- value=float(active_threshold),
804
- step=0.01
805
- )
806
- st.session_state.settings['active_threshold'] = active_threshold
807
-
808
- elif threshold_profile == "Category-specific":
809
- # Show reference overall threshold
810
- st.slider(
811
- "Overall threshold (reference)",
812
- min_value=0.01,
813
- max_value=1.0,
814
- value=float(active_threshold),
815
- step=0.01,
816
- disabled=True
817
- )
818
-
819
- st.write("Adjust thresholds for individual categories:")
820
-
821
- # Category sliders
822
- slider_cols = st.columns(2)
823
-
824
- if not active_category_thresholds:
825
- active_category_thresholds = {}
826
-
827
- for i, category in enumerate(sorted(st.session_state.categories)):
828
- col_idx = i % 2
829
- with slider_cols[col_idx]:
830
- default_val = active_category_thresholds.get(category, active_threshold)
831
- new_threshold = st.slider(
832
- f"{category.capitalize()}",
833
- min_value=0.01,
834
- max_value=1.0,
835
- value=float(default_val),
836
- step=0.01,
837
- key=f"slider_{category}"
838
- )
839
- active_category_thresholds[category] = new_threshold
840
-
841
- st.session_state.settings['active_category_thresholds'] = active_category_thresholds
842
-
843
- # Display options
844
- with st.expander("Display Options", expanded=False):
845
- col1, col2 = st.columns(2)
846
- with col1:
847
- show_all_tags = st.checkbox("Show all tags (including below threshold)",
848
- value=st.session_state.settings['show_all_tags'])
849
- compact_view = st.checkbox("Compact view (hide progress bars)",
850
- value=st.session_state.settings['compact_view'])
851
- replace_underscores = st.checkbox("Replace underscores with spaces",
852
- value=st.session_state.settings.get('replace_underscores', False))
853
-
854
- with col2:
855
- min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
856
- st.session_state.settings['min_confidence'], 0.01)
857
-
858
- # Update settings
859
- st.session_state.settings.update({
860
- 'show_all_tags': show_all_tags,
861
- 'compact_view': compact_view,
862
- 'min_confidence': min_confidence,
863
- 'replace_underscores': replace_underscores
864
- })
865
-
866
- # Category selection
867
- st.write("Categories to include in 'All Tags' section:")
868
-
869
- category_cols = st.columns(3)
870
- selected_categories = {}
871
-
872
- if hasattr(st.session_state, 'categories'):
873
- for i, category in enumerate(sorted(st.session_state.categories)):
874
- col_idx = i % 3
875
- with category_cols[col_idx]:
876
- default_val = st.session_state.settings['selected_categories'].get(category, True)
877
- selected_categories[category] = st.checkbox(
878
- f"{category.capitalize()}",
879
- value=default_val,
880
- key=f"cat_select_{category}"
881
- )
882
-
883
- st.session_state.settings['selected_categories'] = selected_categories
884
-
885
- # Run tagging button
886
- if image_path and st.button("Run Tagging"):
887
- if not st.session_state.model_loaded:
888
- st.error("Model not loaded")
889
- else:
890
- with st.spinner("Analyzing image..."):
891
- try:
892
- # Process image based on model type
893
- if st.session_state.model_type == "onnx":
894
- from utils.onnx_processing import process_single_image_onnx
895
-
896
- result = process_single_image_onnx(
897
- image_path=image_path,
898
- model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
899
- metadata=st.session_state.metadata,
900
- threshold_profile=threshold_profile,
901
- active_threshold=st.session_state.settings['active_threshold'],
902
- active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
903
- min_confidence=st.session_state.settings['min_confidence']
904
- )
905
- else:
906
- # SafeTensors processing
907
- result = process_image(
908
- image_path=image_path,
909
- model=st.session_state.model,
910
- thresholds=st.session_state.thresholds,
911
- metadata=st.session_state.metadata,
912
- threshold_profile=threshold_profile,
913
- active_threshold=st.session_state.settings['active_threshold'],
914
- active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
915
- min_confidence=st.session_state.settings['min_confidence']
916
- )
917
-
918
- if result['success']:
919
- st.session_state.all_probs = result['all_probs']
920
- st.session_state.tags = result['tags']
921
- st.session_state.all_tags = result['all_tags']
922
- st.success("Analysis completed!")
923
- else:
924
- st.error(f"Analysis failed: {result.get('error', 'Unknown error')}")
925
-
926
- except Exception as e:
927
- st.error(f"Error during analysis: {str(e)}")
928
- st.code(traceback.format_exc())
929
-
930
- # Display results
931
- if image_path and hasattr(st.session_state, 'all_probs'):
932
- st.header("Predictions")
933
-
934
- # Apply current thresholds
935
- filtered_tags, current_all_tags = apply_thresholds(
936
- st.session_state.all_probs,
937
- threshold_profile,
938
- st.session_state.settings['active_threshold'],
939
- st.session_state.settings.get('active_category_thresholds', {}),
940
- st.session_state.settings['min_confidence'],
941
- st.session_state.settings['selected_categories']
942
- )
943
-
944
- all_tags = []
945
-
946
- # Display by category
947
- for category in sorted(st.session_state.all_probs.keys()):
948
- all_tags_in_category = st.session_state.all_probs.get(category, [])
949
- filtered_tags_in_category = filtered_tags.get(category, [])
950
-
951
- if all_tags_in_category:
952
- expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)"
953
-
954
- with st.expander(expander_label, expanded=True):
955
- # Get threshold for this category (handle None case)
956
- active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {}
957
- threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold'])
958
-
959
- # Determine tags to display
960
- if st.session_state.settings['show_all_tags']:
961
- tags_to_display = all_tags_in_category
962
- else:
963
- tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold]
964
-
965
- if not tags_to_display:
966
- st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence")
967
- continue
968
-
969
- # Display tags
970
- if st.session_state.settings['compact_view']:
971
- # Compact view
972
- tag_list = []
973
- replace_underscores = st.session_state.settings.get('replace_underscores', False)
974
-
975
- for tag, prob in tags_to_display:
976
- percentage = int(prob * 100)
977
- display_tag = tag.replace('_', ' ') if replace_underscores else tag
978
- tag_list.append(f"{display_tag} ({percentage}%)")
979
-
980
- if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
981
- all_tags.append(tag)
982
-
983
- st.markdown(", ".join(tag_list))
984
- else:
985
- # Expanded view with progress bars
986
- for tag, prob in tags_to_display:
987
- replace_underscores = st.session_state.settings.get('replace_underscores', False)
988
- display_tag = tag.replace('_', ' ') if replace_underscores else tag
989
-
990
- if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
991
- all_tags.append(tag)
992
- tag_display = f"**{display_tag}**"
993
- else:
994
- tag_display = display_tag
995
-
996
- st.write(tag_display)
997
- st.markdown(display_progress_bar(prob), unsafe_allow_html=True)
998
-
999
- # All tags summary
1000
- st.markdown("---")
1001
- st.subheader(f"All Tags ({len(all_tags)} total)")
1002
- if all_tags:
1003
- replace_underscores = st.session_state.settings.get('replace_underscores', False)
1004
- if replace_underscores:
1005
- display_tags = [tag.replace('_', ' ') for tag in all_tags]
1006
- st.write(", ".join(display_tags))
1007
- else:
1008
- st.write(", ".join(all_tags))
1009
- else:
1010
- st.info("No tags detected above the threshold.")
1011
-
1012
- # Save tags section
1013
- st.markdown("---")
1014
- st.subheader("Save Tags")
1015
-
1016
- if 'custom_folders' not in st.session_state:
1017
- st.session_state.custom_folders = get_default_save_locations()
1018
-
1019
- selected_folder = st.selectbox(
1020
- "Select save location:",
1021
- options=st.session_state.custom_folders,
1022
- format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x
1023
- )
1024
-
1025
- if st.button("💾 Save to Selected Location"):
1026
- try:
1027
- original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None
1028
-
1029
- saved_path = save_tags_to_file(
1030
- image_path=image_path,
1031
- all_tags=all_tags,
1032
- original_filename=original_filename,
1033
- custom_dir=selected_folder,
1034
- overwrite=True
1035
- )
1036
-
1037
- st.success(f"Tags saved to: {os.path.basename(saved_path)}")
1038
- st.info(f"Full path: {saved_path}")
1039
-
1040
- # Show file preview
1041
- with st.expander("File Contents", expanded=True):
1042
- with open(saved_path, 'r', encoding='utf-8') as f:
1043
- content = f.read()
1044
- st.code(content, language='text')
1045
-
1046
- except Exception as e:
1047
- st.error(f"Error saving tags: {str(e)}")
1048
-
1049
- if __name__ == "__main__":
1050
  image_tagger_app()
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Camie-Tagger-V2 Application
4
+ A Streamlit web app for tagging images using an AI model.
5
+ """
6
+
7
+ import streamlit as st
8
+ import os
9
+ import sys
10
+ import traceback
11
+ import tempfile
12
+ import time
13
+ import platform
14
+ import subprocess
15
+ import webbrowser
16
+ import glob
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ import io
20
+ import base64
21
+ import json
22
+ from matplotlib.colors import LinearSegmentedColormap
23
+ from PIL import Image
24
+ from pathlib import Path
25
+
26
+ # Add parent directory to path to allow importing from utils
27
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28
+
29
+ # Import utilities
30
+ from utils.image_processing import process_image, batch_process_images
31
+ from utils.file_utils import save_tags_to_file, get_default_save_locations
32
+ from utils.ui_components import display_progress_bar, show_example_images, display_batch_results
33
+ from utils.onnx_processing import batch_process_images_onnx
34
+
35
+ # Define the model directory
36
+ MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
37
+ print(f"Using model directory: {MODEL_DIR}")
38
+
39
+ # Define threshold profile descriptions and explanations
40
+ threshold_profile_descriptions = {
41
+ "Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
42
+ "Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
43
+ "Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
44
+ "Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
45
+ "Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
46
+ }
47
+
48
+ threshold_profile_explanations = {
49
+ "Micro Optimized": """
50
+ ### Micro Optimized Profile
51
+
52
+ **Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.
53
+
54
+ **When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.
55
+
56
+ **Effects**:
57
+ - Optimizes performance for the most frequent tags
58
+ - Gives more weight to categories with many examples (like 'character' and 'general')
59
+ - Provides higher precision in most common use cases
60
+
61
+ **Performance from validation**:
62
+ - Micro F1: ~67.3%
63
+ - Macro F1: ~46.3%
64
+ - Threshold: ~0.614
65
+ """,
66
+
67
+ "Macro Optimized": """
68
+ ### Macro Optimized Profile
69
+
70
+ **Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.
71
+
72
+ **When to use**: When balanced performance across all categories is important, including rare tags.
73
+
74
+ **Effects**:
75
+ - More balanced performance across all tag categories
76
+ - Better at detecting rare or unusual tags
77
+ - Generally has lower thresholds than micro-optimized
78
+
79
+ **Performance from validation**:
80
+ - Micro F1: ~60.9%
81
+ - Macro F1: ~50.6%
82
+ - Threshold: ~0.492
83
+ """,
84
+
85
+ "Balanced": """
86
+ ### Balanced Profile
87
+
88
+ **Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.
89
+
90
+ **When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.
91
+
92
+ **Effects**:
93
+ - Good middle ground between precision and recall
94
+ - Works well for most common use cases
95
+ - Default choice for most users
96
+
97
+ **Performance from validation**:
98
+ - Micro F1: ~67.3%
99
+ - Macro F1: ~46.3%
100
+ - Threshold: ~0.614
101
+ """,
102
+
103
+ "Overall": """
104
+ ### Overall Profile
105
+
106
+ **Technical definition**: Uses a single threshold value across all categories.
107
+
108
+ **When to use**: When you want consistent behavior across all categories and a simple approach.
109
+
110
+ **Effects**:
111
+ - Consistent tagging threshold for all categories
112
+ - Simpler to understand than category-specific thresholds
113
+ - User-adjustable with a single slider
114
+
115
+ **Default threshold value**: 0.5 (user-adjustable)
116
+
117
+ **Note**: The threshold value is user-adjustable with the slider below.
118
+ """,
119
+
120
+ "Category-specific": """
121
+ ### Category-specific Profile
122
+
123
+ **Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.
124
+
125
+ **When to use**: When you want to customize tagging sensitivity for different categories.
126
+
127
+ **Effects**:
128
+ - Each category has its own independent threshold
129
+ - Full control over category sensitivity
130
+ - Best for fine-tuning results when some categories need different treatment
131
+
132
+ **Default threshold values**: Starts with balanced thresholds for each category
133
+
134
+ **Note**: Use the category sliders below to adjust thresholds for individual categories.
135
+ """
136
+ }
137
+
138
+ def load_validation_results(results_path):
139
+ """Load validation results from JSON file"""
140
+ try:
141
+ with open(results_path, 'r') as f:
142
+ data = json.load(f)
143
+ return data
144
+ except Exception as e:
145
+ print(f"Error loading validation results: {e}")
146
+ return None
147
+
148
+ def extract_thresholds_from_results(validation_data):
149
+ """Extract threshold information from validation results"""
150
+ if not validation_data or 'results' not in validation_data:
151
+ return {}
152
+
153
+ thresholds = {
154
+ 'overall': {},
155
+ 'categories': {}
156
+ }
157
+
158
+ # Process results to extract thresholds
159
+ for result in validation_data['results']:
160
+ category = result['CATEGORY'].lower()
161
+ profile = result['PROFILE'].lower().replace(' ', '_')
162
+ threshold = result['THRESHOLD']
163
+ micro_f1 = result['MICRO-F1']
164
+ macro_f1 = result['MACRO-F1']
165
+
166
+ # Map profile names
167
+ if profile == 'micro_opt':
168
+ profile = 'micro_optimized'
169
+ elif profile == 'macro_opt':
170
+ profile = 'macro_optimized'
171
+
172
+ threshold_info = {
173
+ 'threshold': threshold,
174
+ 'micro_f1': micro_f1,
175
+ 'macro_f1': macro_f1
176
+ }
177
+
178
+ if category == 'overall':
179
+ thresholds['overall'][profile] = threshold_info
180
+ else:
181
+ if category not in thresholds['categories']:
182
+ thresholds['categories'][category] = {}
183
+ thresholds['categories'][category][profile] = threshold_info
184
+
185
+ return thresholds
186
+
187
+ def load_model_and_metadata():
188
+ """Load model and metadata from available files"""
189
+ # Check for SafeTensors model
190
+ safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
191
+ safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
192
+
193
+ # Check for ONNX model
194
+ onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
195
+
196
+ # Check for validation results
197
+ validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json")
198
+
199
+ model_info = {
200
+ 'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path),
201
+ 'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path),
202
+ 'validation_results_available': os.path.exists(validation_results_path)
203
+ }
204
+
205
+ # Load metadata (same for both model types)
206
+ metadata = None
207
+ if os.path.exists(safetensors_metadata_path):
208
+ try:
209
+ with open(safetensors_metadata_path, 'r') as f:
210
+ metadata = json.load(f)
211
+ except Exception as e:
212
+ print(f"Error loading metadata: {e}")
213
+
214
+ # Load validation results for thresholds
215
+ thresholds = {}
216
+ if model_info['validation_results_available']:
217
+ validation_data = load_validation_results(validation_results_path)
218
+ if validation_data:
219
+ thresholds = extract_thresholds_from_results(validation_data)
220
+
221
+ # Add default thresholds if not available
222
+ if not thresholds:
223
+ thresholds = {
224
+ 'overall': {
225
+ 'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
226
+ 'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
227
+ 'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
228
+ },
229
+ 'categories': {}
230
+ }
231
+
232
+ return model_info, metadata, thresholds
233
+
234
+ def load_safetensors_model(safetensors_path, metadata_path):
235
+ """Load SafeTensors model"""
236
+ try:
237
+ from safetensors.torch import load_file
238
+ import torch
239
+
240
+ # Load metadata
241
+ with open(metadata_path, 'r') as f:
242
+ metadata = json.load(f)
243
+
244
+ # Import the model class (assuming it's available)
245
+ # You'll need to make sure the ImageTagger class is importable
246
+ from utils.model_loader import ImageTagger # Update this import
247
+
248
+ model_info = metadata['model_info']
249
+ dataset_info = metadata['dataset_info']
250
+
251
+ # Recreate model architecture
252
+ model = ImageTagger(
253
+ total_tags=dataset_info['total_tags'],
254
+ dataset=None,
255
+ model_name=model_info['backbone'],
256
+ num_heads=model_info['num_attention_heads'],
257
+ dropout=0.0,
258
+ pretrained=False,
259
+ tag_context_size=model_info['tag_context_size'],
260
+ use_gradient_checkpointing=False,
261
+ img_size=model_info['img_size']
262
+ )
263
+
264
+ # Load weights
265
+ state_dict = load_file(safetensors_path)
266
+ model.load_state_dict(state_dict)
267
+ model.eval()
268
+
269
+ return model, metadata
270
+ except Exception as e:
271
+ raise Exception(f"Failed to load SafeTensors model: {e}")
272
+
273
+ def get_profile_metrics(thresholds, profile_name):
274
+ """Extract metrics for the given profile from the thresholds dictionary"""
275
+ profile_key = None
276
+
277
+ # Map UI-friendly names to internal keys
278
+ if profile_name == "Micro Optimized":
279
+ profile_key = "micro_optimized"
280
+ elif profile_name == "Macro Optimized":
281
+ profile_key = "macro_optimized"
282
+ elif profile_name == "Balanced":
283
+ profile_key = "balanced"
284
+ elif profile_name in ["Overall", "Category-specific"]:
285
+ profile_key = "macro_optimized" # Use macro as default for these modes
286
+
287
+ if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']:
288
+ return thresholds['overall'][profile_key]
289
+
290
+ return None
291
+
292
+ def on_threshold_profile_change():
293
+ """Handle threshold profile changes"""
294
+ new_profile = st.session_state.threshold_profile
295
+
296
+ if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'):
297
+ # Initialize category thresholds if needed
298
+ if st.session_state.settings['active_category_thresholds'] is None:
299
+ st.session_state.settings['active_category_thresholds'] = {}
300
+
301
+ current_thresholds = st.session_state.settings['active_category_thresholds']
302
+
303
+ # Map profile names to keys
304
+ profile_key = None
305
+ if new_profile == "Micro Optimized":
306
+ profile_key = "micro_optimized"
307
+ elif new_profile == "Macro Optimized":
308
+ profile_key = "macro_optimized"
309
+ elif new_profile == "Balanced":
310
+ profile_key = "balanced"
311
+
312
+ # Update thresholds based on profile
313
+ if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']:
314
+ st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold']
315
+
316
+ # Set category thresholds
317
+ for category in st.session_state.categories:
318
+ if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]:
319
+ current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold']
320
+ else:
321
+ current_thresholds[category] = st.session_state.settings['active_threshold']
322
+
323
+ elif new_profile == "Overall":
324
+ # Use balanced threshold for Overall profile
325
+ if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
326
+ st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
327
+ else:
328
+ st.session_state.settings['active_threshold'] = 0.5
329
+
330
+ # Clear category-specific overrides
331
+ st.session_state.settings['active_category_thresholds'] = {}
332
+
333
+ elif new_profile == "Category-specific":
334
+ # Initialize with balanced thresholds
335
+ if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
336
+ st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
337
+ else:
338
+ st.session_state.settings['active_threshold'] = 0.5
339
+
340
+ # Initialize category thresholds
341
+ for category in st.session_state.categories:
342
+ if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]:
343
+ current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold']
344
+ else:
345
+ current_thresholds[category] = st.session_state.settings['active_threshold']
346
+
347
+ def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
348
+ """Apply thresholds to raw probabilities and return filtered tags"""
349
+ tags = {}
350
+ all_tags = []
351
+
352
+ # Handle None case for active_category_thresholds
353
+ active_category_thresholds = active_category_thresholds or {}
354
+
355
+ for category, cat_probs in all_probs.items():
356
+ # Get the appropriate threshold for this category
357
+ threshold = active_category_thresholds.get(category, active_threshold)
358
+
359
+ # Filter tags above threshold
360
+ tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]
361
+
362
+ # Add to all_tags if selected
363
+ if selected_categories.get(category, True):
364
+ for tag, prob in tags[category]:
365
+ all_tags.append(tag)
366
+
367
+ return tags, all_tags
368
+
369
+ def image_tagger_app():
370
+ """Main Streamlit application for image tagging."""
371
+ st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️")
372
+
373
+ st.title("Camie-Tagger-v2 Interface")
374
+ st.markdown("---")
375
+
376
+ # Initialize settings
377
+ if 'settings' not in st.session_state:
378
+ st.session_state.settings = {
379
+ 'show_all_tags': False,
380
+ 'compact_view': True,
381
+ 'min_confidence': 0.01,
382
+ 'threshold_profile': "Macro",
383
+ 'active_threshold': 0.5,
384
+ 'active_category_thresholds': {}, # Initialize as empty dict, not None
385
+ 'selected_categories': {},
386
+ 'replace_underscores': False
387
+ }
388
+ st.session_state.show_profile_help = False
389
+
390
+ # Session state initialization for model
391
+ if 'model_loaded' not in st.session_state:
392
+ st.session_state.model_loaded = False
393
+ st.session_state.model = None
394
+ st.session_state.thresholds = None
395
+ st.session_state.metadata = None
396
+ st.session_state.model_type = "onnx" # Default to ONNX
397
+
398
+ # Sidebar for model selection and information
399
+ with st.sidebar:
400
+ # Support information
401
+ st.subheader("💡 Notes")
402
+
403
+ st.markdown("""
404
+ This tagger was trained on a subset of the available data due to hardware limitations.
405
+
406
+ A more comprehensive model trained on the full 3+ million image dataset would provide:
407
+ - More recent characters and tags.
408
+ - Improved accuracy.
409
+
410
+ If you find this tool useful and would like to support future development:
411
+ """)
412
+
413
+ # Add Buy Me a Coffee button with Star of the City-like glow effect
414
+ st.markdown("""
415
+ <style>
416
+ @keyframes coffee-button-glow {
417
+ 0% { box-shadow: 0 0 5px #FFD700; }
418
+ 50% { box-shadow: 0 0 15px #FFD700; }
419
+ 100% { box-shadow: 0 0 5px #FFD700; }
420
+ }
421
+
422
+ .coffee-button {
423
+ display: inline-block;
424
+ animation: coffee-button-glow 2s infinite;
425
+ border-radius: 5px;
426
+ transition: transform 0.3s ease;
427
+ }
428
+
429
+ .coffee-button:hover {
430
+ transform: scale(1.05);
431
+ }
432
+ </style>
433
+
434
+ <a href="https://ko-fi.com/camais" target="_blank" class="coffee-button">
435
+ <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
436
+ alt="Buy Me A Coffee"
437
+ style="height: 45px; width: 162px; border-radius: 5px;" />
438
+ </a>
439
+ """, unsafe_allow_html=True)
440
+
441
+ st.markdown("""
442
+ Your support helps with:
443
+ - GPU costs for training
444
+ - Storage for larger datasets
445
+ - Development of new features
446
+ - Future projects
447
+
448
+ Thank you! 🙏
449
+
450
+ Full Details: https://huggingface.co/Camais03/camie-tagger-v2
451
+ """)
452
+
453
+ st.header("Model Selection")
454
+
455
+ # Load model information
456
+ model_info, metadata, thresholds = load_model_and_metadata()
457
+
458
+ # Determine available model options
459
+ model_options = []
460
+ if model_info['onnx_available']:
461
+ model_options.append("ONNX (Recommended)")
462
+ if model_info['safetensors_available']:
463
+ model_options.append("SafeTensors (PyTorch)")
464
+
465
+ if not model_options:
466
+ st.error("No model files found!")
467
+ st.info(f"Looking for models in: {MODEL_DIR}")
468
+ st.info("Expected files:")
469
+ st.info("- camie-tagger-v2.onnx")
470
+ st.info("- camie-tagger-v2.safetensors")
471
+ st.info("- camie-tagger-v2-metadata.json")
472
+ st.stop()
473
+
474
+ # Model type selection
475
+ default_index = 0 if model_info['onnx_available'] else 0
476
+ model_type = st.radio(
477
+ "Select Model Type:",
478
+ model_options,
479
+ index=default_index,
480
+ help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format"
481
+ )
482
+
483
+ # Convert selection to internal model type
484
+ if model_type == "ONNX (Recommended)":
485
+ selected_model_type = "onnx"
486
+ else:
487
+ selected_model_type = "safetensors"
488
+
489
+ # If model type changed, reload
490
+ if selected_model_type != st.session_state.model_type:
491
+ st.session_state.model_loaded = False
492
+ st.session_state.model_type = selected_model_type
493
+
494
+ # Reload button
495
+ if st.button("Reload Model") and st.session_state.model_loaded:
496
+ st.session_state.model_loaded = False
497
+ st.info("Reloading model...")
498
+
499
+ # Try to load the model
500
+ if not st.session_state.model_loaded:
501
+ try:
502
+ with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."):
503
+ if st.session_state.model_type == "onnx":
504
+ # Load ONNX model
505
+ import onnxruntime as ort
506
+
507
+ onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
508
+
509
+ # Check ONNX providers
510
+ providers = ort.get_available_providers()
511
+ gpu_available = any('CUDA' in provider for provider in providers)
512
+
513
+ # Create ONNX session
514
+ session = ort.InferenceSession(onnx_path, providers=providers)
515
+
516
+ st.session_state.model = session
517
+ st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})"
518
+ st.session_state.param_dtype = "float32"
519
+
520
+ else:
521
+ # Load SafeTensors model
522
+ safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
523
+ metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
524
+
525
+ model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path)
526
+
527
+ st.session_state.model = model
528
+ device = next(model.parameters()).device
529
+ param_dtype = next(model.parameters()).dtype
530
+ st.session_state.device = device
531
+ st.session_state.param_dtype = param_dtype
532
+ metadata = loaded_metadata # Use loaded metadata instead
533
+
534
+ # Store common info
535
+ st.session_state.thresholds = thresholds
536
+ st.session_state.metadata = metadata
537
+ st.session_state.model_loaded = True
538
+
539
+ # Get categories
540
+ if metadata and 'dataset_info' in metadata:
541
+ tag_mapping = metadata['dataset_info']['tag_mapping']
542
+ categories = list(set(tag_mapping['tag_to_category'].values()))
543
+ st.session_state.categories = categories
544
+
545
+ # Initialize selected categories
546
+ if not st.session_state.settings['selected_categories']:
547
+ st.session_state.settings['selected_categories'] = {cat: True for cat in categories}
548
+
549
+ # Set initial threshold from validation results
550
+ if 'overall' in thresholds and 'balanced' in thresholds['overall']:
551
+ st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold']
552
+
553
+ except Exception as e:
554
+ st.error(f"Error loading model: {str(e)}")
555
+ st.code(traceback.format_exc())
556
+ st.stop()
557
+
558
+ # Display model information in sidebar
559
+ with st.sidebar:
560
+ st.header("Model Information")
561
+ if st.session_state.model_loaded:
562
+ if st.session_state.model_type == "onnx":
563
+ st.success("Using ONNX Model")
564
+ else:
565
+ st.success("Using SafeTensors Model")
566
+
567
+ st.write(f"Device: {st.session_state.device}")
568
+ st.write(f"Precision: {st.session_state.param_dtype}")
569
+
570
+ if st.session_state.metadata:
571
+ if 'dataset_info' in st.session_state.metadata:
572
+ total_tags = st.session_state.metadata['dataset_info']['total_tags']
573
+ st.write(f"Total tags: {total_tags}")
574
+ elif 'total_tags' in st.session_state.metadata:
575
+ st.write(f"Total tags: {st.session_state.metadata['total_tags']}")
576
+
577
+ # Show categories
578
+ with st.expander("Available Categories"):
579
+ for category in sorted(st.session_state.categories):
580
+ st.write(f"- {category.capitalize()}")
581
+
582
+ # About section
583
+ with st.expander("About this app"):
584
+ st.write("""
585
+ This app uses a trained image tagging model to analyze and tag images.
586
+
587
+ **Model Options**:
588
+ - **ONNX (Recommended)**: Optimized for inference speed with broad compatibility
589
+ - **SafeTensors**: Native PyTorch format for advanced users
590
+
591
+ **Features**:
592
+ - Upload or process images in batches
593
+ - Multiple threshold profiles based on validation results
594
+ - Category-specific threshold adjustment
595
+ - Export tags in various formats
596
+ - Fast inference with GPU acceleration (when available)
597
+
598
+ **Threshold Profiles**:
599
+ - **Micro Optimized**: Best overall F1 score (67.3% micro F1)
600
+ - **Macro Optimized**: Balanced across categories (50.6% macro F1)
601
+ - **Balanced**: Good general-purpose setting
602
+ - **Overall**: Single adjustable threshold
603
+ - **Category-specific**: Fine-tune each category individually
604
+ """)
605
+
606
+ # Main content area - Image upload and processing
607
+ col1, col2 = st.columns([1, 1.5])
608
+
609
+ with col1:
610
+ st.header("Image")
611
+
612
+ upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])
613
+
614
+ image_path = None
615
+
616
+ with upload_tab:
617
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
618
+
619
+ if uploaded_file:
620
+ # Create temporary file
621
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
622
+ tmp_file.write(uploaded_file.getvalue())
623
+ image_path = tmp_file.name
624
+
625
+ st.session_state.original_filename = uploaded_file.name
626
+
627
+ # Display image
628
+ image = Image.open(uploaded_file)
629
+ st.image(image, use_container_width=True)
630
+
631
+ with batch_tab:
632
+ st.subheader("Batch Process Images")
633
+
634
+ # Folder selection
635
+ batch_folder = st.text_input("Enter folder path containing images:", "")
636
+
637
+ # Save options
638
+ save_options = st.radio(
639
+ "Where to save tag files:",
640
+ ["Same folder as images", "Custom location", "Default save folder"],
641
+ index=0
642
+ )
643
+
644
+ # Batch size control
645
+ st.subheader("Performance Options")
646
+ batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4,
647
+ help="Higher values may improve speed but use more memory")
648
+
649
+ # Category limits
650
+ enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)
651
+
652
+ if enable_category_limits and hasattr(st.session_state, 'categories'):
653
+ if 'category_limits' not in st.session_state:
654
+ st.session_state.category_limits = {}
655
+
656
+ st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags")
657
+
658
+ limit_cols = st.columns(2)
659
+ for i, category in enumerate(sorted(st.session_state.categories)):
660
+ col_idx = i % 2
661
+ with limit_cols[col_idx]:
662
+ current_limit = st.session_state.category_limits.get(category, -1)
663
+ new_limit = st.number_input(
664
+ f"{category.capitalize()}:",
665
+ value=current_limit,
666
+ min_value=-1,
667
+ step=1,
668
+ key=f"limit_{category}"
669
+ )
670
+ st.session_state.category_limits[category] = new_limit
671
+
672
+ # Process batch button
673
+ if batch_folder and os.path.isdir(batch_folder):
674
+ image_files = []
675
+ for ext in ['*.jpg', '*.jpeg', '*.png']:
676
+ image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
677
+ image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))
678
+
679
+ if image_files:
680
+ st.write(f"Found {len(image_files)} images")
681
+
682
+ if st.button("🔄 Process All Images", type="primary"):
683
+ if not st.session_state.model_loaded:
684
+ st.error("Model not loaded")
685
+ else:
686
+ with st.spinner("Processing images..."):
687
+ progress_bar = st.progress(0)
688
+ status_text = st.empty()
689
+
690
+ def update_progress(current, total, image_path):
691
+ progress = current / total if total > 0 else 0
692
+ progress_bar.progress(progress)
693
+ status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}")
694
+
695
+ # Determine save directory
696
+ if save_options == "Same folder as images":
697
+ save_dir = batch_folder
698
+ elif save_options == "Custom location":
699
+ save_dir = st.text_input("Custom save directory:", batch_folder)
700
+ else:
701
+ save_dir = os.path.join(os.path.dirname(__file__), "saved_tags")
702
+ os.makedirs(save_dir, exist_ok=True)
703
+
704
+ # Get current settings
705
+ category_limits = st.session_state.category_limits if enable_category_limits else None
706
+
707
+ # Process based on model type
708
+ if st.session_state.model_type == "onnx":
709
+ batch_results = batch_process_images_onnx(
710
+ folder_path=batch_folder,
711
+ model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
712
+ metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"),
713
+ threshold_profile=st.session_state.settings['threshold_profile'],
714
+ active_threshold=st.session_state.settings['active_threshold'],
715
+ active_category_thresholds=st.session_state.settings['active_category_thresholds'],
716
+ save_dir=save_dir,
717
+ progress_callback=update_progress,
718
+ min_confidence=st.session_state.settings['min_confidence'],
719
+ batch_size=batch_size,
720
+ category_limits=category_limits
721
+ )
722
+ else:
723
+ # SafeTensors processing (would need to implement)
724
+ st.error("SafeTensors batch processing not implemented yet")
725
+ batch_results = None
726
+
727
+ if batch_results:
728
+ display_batch_results(batch_results)
729
+
730
+ # Column 2: Controls and Results
731
+ with col2:
732
+ st.header("Tagging Controls")
733
+
734
+ # Threshold profile selection
735
+ all_profiles = [
736
+ "Micro Optimized",
737
+ "Macro Optimized",
738
+ "Balanced",
739
+ "Overall",
740
+ "Category-specific"
741
+ ]
742
+
743
+ profile_col1, profile_col2 = st.columns([3, 1])
744
+
745
+ with profile_col1:
746
+ threshold_profile = st.selectbox(
747
+ "Select threshold profile",
748
+ options=all_profiles,
749
+ index=1, # Default to Macro
750
+ key="threshold_profile",
751
+ on_change=on_threshold_profile_change
752
+ )
753
+
754
+ with profile_col2:
755
+ if st.button("ℹ️ Help", key="profile_help"):
756
+ st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)
757
+
758
+ # Show profile help
759
+ if st.session_state.get('show_profile_help', False):
760
+ st.markdown(threshold_profile_explanations[threshold_profile])
761
+ else:
762
+ st.info(threshold_profile_descriptions[threshold_profile])
763
+
764
+ # Show profile metrics if available
765
+ if st.session_state.model_loaded:
766
+ metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile)
767
+
768
+ if metrics:
769
+ metrics_cols = st.columns(3)
770
+
771
+ with metrics_cols[0]:
772
+ st.metric("Threshold", f"{metrics['threshold']:.3f}")
773
+
774
+ with metrics_cols[1]:
775
+ st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%")
776
+
777
+ with metrics_cols[2]:
778
+ st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%")
779
+
780
+ # Threshold controls based on profile
781
+ if st.session_state.model_loaded:
782
+ active_threshold = st.session_state.settings.get('active_threshold', 0.5)
783
+ active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {})
784
+
785
+ if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]:
786
+ # Show reference threshold (disabled)
787
+ st.slider(
788
+ "Threshold (from validation)",
789
+ min_value=0.01,
790
+ max_value=1.0,
791
+ value=float(active_threshold),
792
+ step=0.01,
793
+ disabled=True,
794
+ help="This threshold is optimized from validation results"
795
+ )
796
+
797
+ elif threshold_profile == "Overall":
798
+ # Adjustable overall threshold
799
+ active_threshold = st.slider(
800
+ "Overall threshold",
801
+ min_value=0.01,
802
+ max_value=1.0,
803
+ value=float(active_threshold),
804
+ step=0.01
805
+ )
806
+ st.session_state.settings['active_threshold'] = active_threshold
807
+
808
+ elif threshold_profile == "Category-specific":
809
+ # Show reference overall threshold
810
+ st.slider(
811
+ "Overall threshold (reference)",
812
+ min_value=0.01,
813
+ max_value=1.0,
814
+ value=float(active_threshold),
815
+ step=0.01,
816
+ disabled=True
817
+ )
818
+
819
+ st.write("Adjust thresholds for individual categories:")
820
+
821
+ # Category sliders
822
+ slider_cols = st.columns(2)
823
+
824
+ if not active_category_thresholds:
825
+ active_category_thresholds = {}
826
+
827
+ for i, category in enumerate(sorted(st.session_state.categories)):
828
+ col_idx = i % 2
829
+ with slider_cols[col_idx]:
830
+ default_val = active_category_thresholds.get(category, active_threshold)
831
+ new_threshold = st.slider(
832
+ f"{category.capitalize()}",
833
+ min_value=0.01,
834
+ max_value=1.0,
835
+ value=float(default_val),
836
+ step=0.01,
837
+ key=f"slider_{category}"
838
+ )
839
+ active_category_thresholds[category] = new_threshold
840
+
841
+ st.session_state.settings['active_category_thresholds'] = active_category_thresholds
842
+
843
+ # Display options
844
+ with st.expander("Display Options", expanded=False):
845
+ col1, col2 = st.columns(2)
846
+ with col1:
847
+ show_all_tags = st.checkbox("Show all tags (including below threshold)",
848
+ value=st.session_state.settings['show_all_tags'])
849
+ compact_view = st.checkbox("Compact view (hide progress bars)",
850
+ value=st.session_state.settings['compact_view'])
851
+ replace_underscores = st.checkbox("Replace underscores with spaces",
852
+ value=st.session_state.settings.get('replace_underscores', False))
853
+
854
+ with col2:
855
+ min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
856
+ st.session_state.settings['min_confidence'], 0.01)
857
+
858
+ # Update settings
859
+ st.session_state.settings.update({
860
+ 'show_all_tags': show_all_tags,
861
+ 'compact_view': compact_view,
862
+ 'min_confidence': min_confidence,
863
+ 'replace_underscores': replace_underscores
864
+ })
865
+
866
+ # Category selection
867
+ st.write("Categories to include in 'All Tags' section:")
868
+
869
+ category_cols = st.columns(3)
870
+ selected_categories = {}
871
+
872
+ if hasattr(st.session_state, 'categories'):
873
+ for i, category in enumerate(sorted(st.session_state.categories)):
874
+ col_idx = i % 3
875
+ with category_cols[col_idx]:
876
+ default_val = st.session_state.settings['selected_categories'].get(category, True)
877
+ selected_categories[category] = st.checkbox(
878
+ f"{category.capitalize()}",
879
+ value=default_val,
880
+ key=f"cat_select_{category}"
881
+ )
882
+
883
+ st.session_state.settings['selected_categories'] = selected_categories
884
+
885
+ # Run tagging button
886
+ if image_path and st.button("Run Tagging"):
887
+ if not st.session_state.model_loaded:
888
+ st.error("Model not loaded")
889
+ else:
890
+ with st.spinner("Analyzing image..."):
891
+ try:
892
+ # Process image based on model type
893
+ if st.session_state.model_type == "onnx":
894
+ from utils.onnx_processing import process_single_image_onnx
895
+
896
+ result = process_single_image_onnx(
897
+ image_path=image_path,
898
+ model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
899
+ metadata=st.session_state.metadata,
900
+ threshold_profile=threshold_profile,
901
+ active_threshold=st.session_state.settings['active_threshold'],
902
+ active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
903
+ min_confidence=st.session_state.settings['min_confidence']
904
+ )
905
+ else:
906
+ # SafeTensors processing
907
+ result = process_image(
908
+ image_path=image_path,
909
+ model=st.session_state.model,
910
+ thresholds=st.session_state.thresholds,
911
+ metadata=st.session_state.metadata,
912
+ threshold_profile=threshold_profile,
913
+ active_threshold=st.session_state.settings['active_threshold'],
914
+ active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
915
+ min_confidence=st.session_state.settings['min_confidence']
916
+ )
917
+
918
+ if result['success']:
919
+ st.session_state.all_probs = result['all_probs']
920
+ st.session_state.tags = result['tags']
921
+ st.session_state.all_tags = result['all_tags']
922
+ st.success("Analysis completed!")
923
+ else:
924
+ st.error(f"Analysis failed: {result.get('error', 'Unknown error')}")
925
+
926
+ except Exception as e:
927
+ st.error(f"Error during analysis: {str(e)}")
928
+ st.code(traceback.format_exc())
929
+
930
+ # Display results
931
+ if image_path and hasattr(st.session_state, 'all_probs'):
932
+ st.header("Predictions")
933
+
934
+ # Apply current thresholds
935
+ filtered_tags, current_all_tags = apply_thresholds(
936
+ st.session_state.all_probs,
937
+ threshold_profile,
938
+ st.session_state.settings['active_threshold'],
939
+ st.session_state.settings.get('active_category_thresholds', {}),
940
+ st.session_state.settings['min_confidence'],
941
+ st.session_state.settings['selected_categories']
942
+ )
943
+
944
+ all_tags = []
945
+
946
+ # Display by category
947
+ for category in sorted(st.session_state.all_probs.keys()):
948
+ all_tags_in_category = st.session_state.all_probs.get(category, [])
949
+ filtered_tags_in_category = filtered_tags.get(category, [])
950
+
951
+ if all_tags_in_category:
952
+ expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)"
953
+
954
+ with st.expander(expander_label, expanded=True):
955
+ # Get threshold for this category (handle None case)
956
+ active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {}
957
+ threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold'])
958
+
959
+ # Determine tags to display
960
+ if st.session_state.settings['show_all_tags']:
961
+ tags_to_display = all_tags_in_category
962
+ else:
963
+ tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold]
964
+
965
+ if not tags_to_display:
966
+ st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence")
967
+ continue
968
+
969
+ # Display tags
970
+ if st.session_state.settings['compact_view']:
971
+ # Compact view
972
+ tag_list = []
973
+ replace_underscores = st.session_state.settings.get('replace_underscores', False)
974
+
975
+ for tag, prob in tags_to_display:
976
+ percentage = int(prob * 100)
977
+ display_tag = tag.replace('_', ' ') if replace_underscores else tag
978
+ tag_list.append(f"{display_tag} ({percentage}%)")
979
+
980
+ if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
981
+ all_tags.append(tag)
982
+
983
+ st.markdown(", ".join(tag_list))
984
+ else:
985
+ # Expanded view with progress bars
986
+ for tag, prob in tags_to_display:
987
+ replace_underscores = st.session_state.settings.get('replace_underscores', False)
988
+ display_tag = tag.replace('_', ' ') if replace_underscores else tag
989
+
990
+ if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
991
+ all_tags.append(tag)
992
+ tag_display = f"**{display_tag}**"
993
+ else:
994
+ tag_display = display_tag
995
+
996
+ st.write(tag_display)
997
+ st.markdown(display_progress_bar(prob), unsafe_allow_html=True)
998
+
999
+ # All tags summary
1000
+ st.markdown("---")
1001
+ st.subheader(f"All Tags ({len(all_tags)} total)")
1002
+ if all_tags:
1003
+ replace_underscores = st.session_state.settings.get('replace_underscores', False)
1004
+ if replace_underscores:
1005
+ display_tags = [tag.replace('_', ' ') for tag in all_tags]
1006
+ st.write(", ".join(display_tags))
1007
+ else:
1008
+ st.write(", ".join(all_tags))
1009
+ else:
1010
+ st.info("No tags detected above the threshold.")
1011
+
1012
+ # Save tags section
1013
+ st.markdown("---")
1014
+ st.subheader("Save Tags")
1015
+
1016
+ if 'custom_folders' not in st.session_state:
1017
+ st.session_state.custom_folders = get_default_save_locations()
1018
+
1019
+ selected_folder = st.selectbox(
1020
+ "Select save location:",
1021
+ options=st.session_state.custom_folders,
1022
+ format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x
1023
+ )
1024
+
1025
+ if st.button("💾 Save to Selected Location"):
1026
+ try:
1027
+ original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None
1028
+
1029
+ saved_path = save_tags_to_file(
1030
+ image_path=image_path,
1031
+ all_tags=all_tags,
1032
+ original_filename=original_filename,
1033
+ custom_dir=selected_folder,
1034
+ overwrite=True
1035
+ )
1036
+
1037
+ st.success(f"Tags saved to: {os.path.basename(saved_path)}")
1038
+ st.info(f"Full path: {saved_path}")
1039
+
1040
+ # Show file preview
1041
+ with st.expander("File Contents", expanded=True):
1042
+ with open(saved_path, 'r', encoding='utf-8') as f:
1043
+ content = f.read()
1044
+ st.code(content, language='text')
1045
+
1046
+ except Exception as e:
1047
+ st.error(f"Error saving tags: {str(e)}")
1048
+
1049
+ if __name__ == "__main__":
1050
  image_tagger_app()