"""
ONNX-based batch image processing for the Image Tagger application.

Updated with proper ImageNet normalization and new metadata format.
"""

import glob
import json
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import onnxruntime as ort
import torchvision.transforms as transforms
from PIL import Image
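
# NOTE: The metadata dictionary used throughout this module is assumed to look roughly
# like the sketch below (inferred from the keys read in ONNXImageTagger and
# test_onnx_imagetagger; the concrete file produced by the export script may differ,
# and all values here are hypothetical):
#
#   {
#       "model_info": {"img_size": 512},
#       "dataset_info": {
#           "total_tags": 1000,
#           "tag_mapping": {
#               "idx_to_tag": {"0": "tag_a", "1": "tag_b"},
#               "tag_to_category": {"tag_a": "general", "tag_b": "character"}
#           }
#       }
#   }
#
# Older/flat metadata files may instead carry "idx_to_tag", "tag_to_category" and
# "total_tags" at the top level; both layouts are handled below.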


def preprocess_image(image_path, image_size=512):
    """
    Process an image for ImageTagger inference with proper ImageNet normalization.
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    # ImageNet mean/std normalization, applied after conversion to a [0, 1] tensor
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    try:
        with Image.open(image_path) as img:
            # Convert palette/alpha images to plain RGB
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            # Resize so the longer side equals image_size, preserving aspect ratio
            width, height = img.size
            aspect_ratio = width / height

            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad to a square canvas, centering the image on the approximate ImageNet mean color
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Normalize and return as a CHW float32 numpy array
            img_tensor = transform(new_image)
            return img_tensor.numpy()

    except Exception as e:
        raise Exception(f"Error processing {image_path}: {str(e)}") from e
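
# Example (hypothetical path): preprocess_image("photo.jpg") returns a numpy array of
# shape (3, 512, 512) and dtype float32, ready to be stacked into an ONNX input batch.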


def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall",
                              active_threshold=0.35, active_category_thresholds=None,
                              min_confidence=0.1):
    """
    Process a single image using the ONNX model with the new metadata format.

    Args:
        image_path: Path to the image file
        model_path: Path to the ONNX model file
        metadata: Model metadata dictionary
        threshold_profile: The threshold profile being used
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags and probabilities
    """
    try:
        # Reuse a cached tagger across calls so the ONNX session is only created once
        # (note: the cache ignores model_path changes between calls).
        if hasattr(process_single_image_onnx, 'tagger'):
            tagger = process_single_image_onnx.tagger
        else:
            tagger = ONNXImageTagger(model_path, metadata)
            process_single_image_onnx.tagger = tagger

        start_time = time.time()
        img_array = preprocess_image(image_path)

        # Run a batch of one through the shared batch-inference path
        results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        inference_time = time.time() - start_time

        if results:
            result = results[0]
            result['inference_time'] = inference_time
            result['success'] = True
            return result
        else:
            return {
                'success': False,
                'error': 'Failed to process image',
                'all_tags': [],
                'all_probs': {},
                'tags': {}
            }

    except Exception as e:
        print(f"Error in process_single_image_onnx: {str(e)}")
        traceback.print_exc()
        return {
            'success': False,
            'error': str(e),
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }
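
# Example (hypothetical paths): with metadata already loaded from JSON,
#   result = process_single_image_onnx("photo.jpg", "model.onnx", metadata)
#   if result['success']:
#       print(result['all_tags'], result['inference_time'])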


def preprocess_images_parallel(image_paths, image_size=512, max_workers=8):
    """Preprocess multiple images in parallel."""
    processed_images = []
    valid_paths = []

    def process_single_image(path):
        try:
            return preprocess_image(path, image_size), path
        except Exception as e:
            print(f"Error processing {path}: {str(e)}")
            return None, path

    # Preprocessing is I/O- and PIL-bound, so a thread pool works well here
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(process_single_image, image_paths))

    # Keep only the images that preprocessed successfully, in order
    for img_array, path in results:
        if img_array is not None:
            processed_images.append(img_array)
            valid_paths.append(path)

    return processed_images, valid_paths


def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied
    """
    if not category_limits or not result['success']:
        return result

    # Tags are already sorted by confidence per category, so truncation keeps the top-N
    filtered_tags = result['tags']

    for category, cat_tags in list(filtered_tags.items()):
        limit = category_limits.get(category, -1)

        if limit == 0:
            # A limit of zero removes the category entirely
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            filtered_tags[category] = cat_tags[:limit]

    # Rebuild the flat tag list from the remaining categories
    all_tags = []
    for category, cat_tags in filtered_tags.items():
        for tag, _ in cat_tags:
            all_tags.append(tag)

    result['tags'] = filtered_tags
    result['all_tags'] = all_tags

    return result
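
# Example (hypothetical limits): keep the five most confident "general" tags, drop the
# "meta" category, and leave every other category untouched:
#   result = apply_category_limits(result, {"general": 5, "meta": 0})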


class ONNXImageTagger:
    """ONNX-based image tagger for fast batch inference with the updated metadata format."""

    def __init__(self, model_path, metadata):
        self.model_path = model_path

        # Prefer CUDA when available, falling back to CPU execution
        try:
            self.session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")
        except Exception as e:
            print(f"CUDA not available, using CPU: {e}")
            self.session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")

        self.metadata = metadata

        # Support both the nested (dataset_info) and flat metadata layouts
        if 'dataset_info' in metadata:
            self.tag_mapping = metadata['dataset_info']['tag_mapping']
            self.idx_to_tag = self.tag_mapping['idx_to_tag']
            self.tag_to_category = self.tag_mapping['tag_to_category']
            self.total_tags = metadata['dataset_info']['total_tags']
        else:
            self.idx_to_tag = metadata.get('idx_to_tag', {})
            self.tag_to_category = metadata.get('tag_to_category', {})
            self.total_tags = metadata.get('total_tags', len(self.idx_to_tag))

        self.input_name = self.session.get_inputs()[0].name
        print(f"Model loaded successfully. Input name: {self.input_name}")
        print(f"Total tags: {self.total_tags}, Categories: {len(set(self.tag_to_category.values()))}")

    def predict_batch(self, image_arrays, threshold=0.5, category_thresholds=None, min_confidence=0.1):
        """Run batch inference on preprocessed image arrays."""
        # Stack the per-image CHW arrays into a single NCHW batch
        batch_input = np.stack(image_arrays)

        start_time = time.time()
        outputs = self.session.run(None, {self.input_name: batch_input})
        inference_time = time.time() - start_time
        print(f"Batch inference completed in {inference_time:.4f} seconds "
              f"({inference_time/len(image_arrays):.4f} s/image)")

        # Models exported with a refinement head return (initial_logits, refined_logits);
        # prefer the refined predictions when they are present.
        if len(outputs) >= 2:
            refined_logits = outputs[1]
            main_logits = refined_logits
            print(f"Using refined predictions (shape: {refined_logits.shape})")
        else:
            main_logits = outputs[0]
            print(f"Using single output (shape: {main_logits.shape})")

        # Convert logits to independent per-tag probabilities with a sigmoid
        main_probs = 1.0 / (1.0 + np.exp(-main_logits))

        batch_results = []

        for i in range(main_probs.shape[0]):
            probs = main_probs[i]

            # Collect every tag above min_confidence, grouped by category
            all_probs = {}
            for idx in range(probs.shape[0]):
                prob_value = float(probs[idx])
                if prob_value >= min_confidence:
                    idx_str = str(idx)
                    tag_name = self.idx_to_tag.get(idx_str, f"unknown-{idx}")
                    category = self.tag_to_category.get(tag_name, "general")

                    if category not in all_probs:
                        all_probs[category] = []

                    all_probs[category].append((tag_name, prob_value))

            # Sort each category by descending confidence
            for category in all_probs:
                all_probs[category] = sorted(
                    all_probs[category],
                    key=lambda x: x[1],
                    reverse=True
                )

            # Apply the per-category threshold (or the overall threshold) to build
            # the final tag selection
            tags = {}
            for category, cat_tags in all_probs.items():
                if category_thresholds and category in category_thresholds:
                    cat_threshold = category_thresholds[category]
                else:
                    cat_threshold = threshold

                tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]

            # Flatten the selected tags into a single list
            all_tags = []
            for category, cat_tags in tags.items():
                for tag, _ in cat_tags:
                    all_tags.append(tag)

            batch_results.append({
                'tags': tags,
                'all_probs': all_probs,
                'all_tags': all_tags,
                'success': True
            })

        return batch_results
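
# Example (hypothetical paths): using the tagger directly for a small, in-memory batch:
#   tagger = ONNXImageTagger("model.onnx", metadata)
#   arrays = [preprocess_image(p) for p in ("a.jpg", "b.jpg")]
#   results = tagger.predict_batch(arrays, threshold=0.35)
#   results[0]['tags']       # {category: [(tag, prob), ...]} above the threshold
#   results[0]['all_probs']  # everything above min_confidence, per category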


def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
                              active_threshold, active_category_thresholds, save_dir=None,
                              progress_callback=None, min_confidence=0.1, batch_size=16,
                              category_limits=None):
    """
    Process all images in a folder using the ONNX model with the new metadata format.

    Args:
        folder_path: Path to folder containing images
        model_path: Path to the ONNX model file
        metadata_path: Path to the model metadata file
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, uses the default)
        progress_callback: Optional callback for progress updates
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once
        category_limits: Dictionary mapping categories to their tag limits

    Returns:
        Dictionary with results for each image
    """
    from utils.file_utils import save_tags_to_file

    # Collect supported image files (both lower- and upper-case extensions)
    image_extensions = ['*.jpg', '*.jpeg', '*.png']
    image_files = []

    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
        image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))

    # On Windows the case-insensitive globbing above can match the same file twice,
    # so de-duplicate on the normalized, lower-cased path.
    if os.name == 'nt':
        unique_paths = set()
        unique_files = []
        for file_path in image_files:
            normalized_path = os.path.normpath(file_path).lower()
            if normalized_path not in unique_paths:
                unique_paths.add(normalized_path)
                unique_files.append(file_path)
        image_files = unique_files

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    # Default to a "saved_tags" folder next to the application
    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")

    os.makedirs(save_dir, exist_ok=True)

    # Load model metadata
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load metadata: {e}",
            'results': {}
        }

    # Create the ONNX tagger once and reuse it for every batch
    try:
        tagger = ONNXImageTagger(model_path, metadata)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load model: {e}",
            'results': {}
        }

    results = {}
    total_images = len(image_files)
    processed = 0

    start_time = time.time()

    # Process the images in fixed-size batches
    for i in range(0, total_images, batch_size):
        batch_start = time.time()

        batch_files = image_files[i:i+batch_size]
        batch_size_actual = len(batch_files)

        if progress_callback:
            progress_callback(processed, total_images, batch_files[0] if batch_files else None)

        print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: "
              f"{batch_size_actual} images")

        try:
            # Preprocess the whole batch in parallel, then run one ONNX inference
            processed_images, valid_paths = preprocess_images_parallel(batch_files)

            if processed_images:
                batch_results = tagger.predict_batch(
                    processed_images,
                    threshold=active_threshold,
                    category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    # Optionally cap the number of tags per category
                    if category_limits and result['success']:
                        print(f"Applying limits to {os.path.basename(image_path)}: "
                              f"{len(result['all_tags'])} → ", end="")
                        result = apply_category_limits(result, category_limits)
                        print(f"{len(result['all_tags'])} tags")

                    # Write the tag file alongside the in-memory result
                    if result['success']:
                        try:
                            output_path = save_tags_to_file(
                                image_path=image_path,
                                all_tags=result['all_tags'],
                                custom_dir=save_dir,
                                overwrite=True
                            )
                            result['output_path'] = str(output_path)
                        except Exception as e:
                            print(f"Error saving tags for {image_path}: {e}")
                            result['save_error'] = str(e)

                    results[image_path] = result

            processed += batch_size_actual

            batch_end = time.time()
            batch_time = batch_end - batch_start
            print(f"Batch processed in {batch_time:.2f} seconds "
                  f"({batch_time/batch_size_actual:.2f} seconds per image)")

        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            traceback.print_exc()

            # Fall back to processing the batch one image at a time so a single bad
            # image does not sink the whole batch.
            for j, image_path in enumerate(batch_files):
                try:
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    img_array = preprocess_image(image_path)

                    single_results = tagger.predict_batch(
                        [img_array],
                        threshold=active_threshold,
                        category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    if single_results:
                        result = single_results[0]

                        if category_limits and result['success']:
                            result = apply_category_limits(result, category_limits)

                        if result['success']:
                            try:
                                output_path = save_tags_to_file(
                                    image_path=image_path,
                                    all_tags=result['all_tags'],
                                    custom_dir=save_dir,
                                    overwrite=True
                                )
                                result['output_path'] = str(output_path)
                            except Exception as e:
                                print(f"Error saving tags for {image_path}: {e}")
                                result['save_error'] = str(e)

                        results[image_path] = result
                    else:
                        results[image_path] = {
                            'success': False,
                            'error': 'Failed to process image',
                            'all_tags': []
                        }

                except Exception as img_e:
                    print(f"Error processing single image {image_path}: {str(img_e)}")
                    results[image_path] = {
                        'success': False,
                        'error': str(img_e),
                        'all_tags': []
                    }

            processed += batch_size_actual

    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, "
          f"Average: {total_time/total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': total_time
    }
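
# Example (hypothetical paths and thresholds): tag every image in a folder and write
# one tag file per image into the save directory:
#   summary = batch_process_images_onnx(
#       folder_path="./images",
#       model_path="model.onnx",
#       metadata_path="metadata.json",
#       threshold_profile="Overall",
#       active_threshold=0.35,
#       active_category_thresholds=None,
#   )
#   print(summary['processed'], "of", summary['total'], "images tagged")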


def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=256):
    """
    Test the ImageTagger ONNX model with proper handling of all outputs and the new
    metadata format.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions
        top_k: Maximum number of predictions to show per category
    """
    from collections import defaultdict

    print(f"Loading ImageTagger ONNX model from {model_path}")

    # Load metadata
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}")

    # Extract the tag mapping from either the nested or the flat metadata layout
    try:
        if 'dataset_info' in metadata:
            dataset_info = metadata['dataset_info']
            tag_mapping = dataset_info['tag_mapping']
            idx_to_tag = tag_mapping['idx_to_tag']
            tag_to_category = tag_mapping['tag_to_category']
            total_tags = dataset_info['total_tags']
        else:
            idx_to_tag = metadata.get('idx_to_tag', {})
            tag_to_category = metadata.get('tag_to_category', {})
            total_tags = metadata.get('total_tags', len(idx_to_tag))

        print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")

    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}")

    # Build the provider list, preferring CUDA when onnxruntime reports a GPU
    providers = []
    if ort.get_device() == 'GPU':
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        # Describe the model's inputs and outputs
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(outputs)}")
        for i, output in enumerate(outputs):
            print(f"  Output {i}: {output.name} {output.shape}")

    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}")

    # Preprocess the test image with the same pipeline used for batch processing
    print(f"Processing image: {image_path}")
    try:
        img_size = metadata.get('model_info', {}).get('img_size', 512)
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor[np.newaxis, :]  # add the batch dimension
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")

    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}")

    # Run inference
    input_name = session.get_inputs()[0].name
    print("Running inference...")

    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")

    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}")

    # Prefer refined predictions when the model exposes them
    if len(outputs) >= 2:
        initial_logits = outputs[0]
        refined_logits = outputs[1]
        selected_candidates = outputs[2] if len(outputs) > 2 else None

        main_logits = refined_logits
        print(f"Using refined predictions (shape: {refined_logits.shape})")

    else:
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Sigmoid over logits gives independent per-tag probabilities
    main_probs = 1.0 / (1.0 + np.exp(-main_logits))

    # Keep everything at or above the threshold
    predictions_mask = (main_probs >= threshold)
    indices = np.where(predictions_mask[0])[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")

        # Show the top five raw scores to help pick a more suitable threshold
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            idx_str = str(idx)
            tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f"  {tag_name}: {prob:.3f}")
        return {}

    # Group the accepted tags by category
    tags_by_category = defaultdict(list)

    for idx in indices:
        idx_str = str(idx)
        tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        prob = float(main_probs[0, idx])

        tags_by_category[category].append((tag_name, prob))

    # Sort each category by confidence and keep at most top_k entries
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category],
            key=lambda x: x[1],
            reverse=True
        )[:top_k]

    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Print the well-known categories first, in a fixed order...
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']

    for category in category_order:
        if category in tags_by_category:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

    # ...then any remaining categories alphabetically
    for category in sorted(tags_by_category.keys()):
        if category not in category_order:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

    # Summary statistics
    print("\nPerformance:")
    print(f"  Inference time: {inference_time:.4f}s")
    print(f"  Provider: {active_provider}")
    print(f"  Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f"  Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)
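

# Minimal command-line entry point for a quick smoke test. This is a hedged sketch:
# the argument order (model, metadata, image) and the optional threshold are
# assumptions for illustration, not an established CLI of the application.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 4:
        print("Usage: python this_module.py <model.onnx> <metadata.json> <image> [threshold]")
        sys.exit(1)

    test_onnx_imagetagger(
        model_path=sys.argv[1],
        metadata_path=sys.argv[2],
        image_path=sys.argv[3],
        threshold=float(sys.argv[4]) if len(sys.argv) > 4 else 0.5,
    )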