def preprocess_image(image_path, image_size=512):
    """
    Process an image for ImageTagger inference with proper ImageNet normalization.
    """
    import torchvision.transforms as transforms
    from PIL import Image
    import os

    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    # ImageNet normalization - CRITICAL for this model
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    try:
        with Image.open(image_path) as img:
            # Convert RGBA or palette images to RGB
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            # Compute new dimensions that preserve the aspect ratio
            width, height = img.size
            aspect_ratio = width / height
            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)

            # Resize with the LANCZOS filter
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad to a square canvas, using RGB values close to the ImageNet mean:
            # (0.485*255, 0.456*255, 0.406*255) ~= (124, 116, 104)
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Apply transforms (including ImageNet normalization)
            img_tensor = transform(new_image)
            return img_tensor
    except Exception as e:
        raise RuntimeError(f"Error processing {image_path}: {e}") from e


def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=50):
    """
    Test an ImageTagger ONNX model with proper handling of all outputs.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions
        top_k: Maximum number of predictions to show per category
    """
    import onnxruntime as ort
    import numpy as np
    import json
    import time
    from collections import defaultdict

    print(f"Loading ImageTagger ONNX model from {model_path}")

    # Load metadata with proper error handling
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}") from e

    # Extract tag mappings from the nested metadata structure
    try:
        dataset_info = metadata['dataset_info']
        tag_mapping = dataset_info['tag_mapping']
        idx_to_tag = tag_mapping['idx_to_tag']
        tag_to_category = tag_mapping['tag_to_category']
        total_tags = dataset_info['total_tags']
        print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}") from e

    # Initialize the ONNX session, preferring CUDA when available
    providers = []
    if ort.get_device() == 'GPU':
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        # Print model input/output signatures
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(outputs)}")
        for i, output in enumerate(outputs):
            print(f"  Output {i}: {output.name} {output.shape}")
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}") from e

    # Preprocess the image to the size the model was trained on
    print(f"Processing image: {image_path}")
    try:
        img_tensor = preprocess_image(image_path, image_size=metadata['model_info']['img_size'])
        img_numpy = img_tensor.unsqueeze(0).numpy()  # Add batch dimension -> NCHW float32
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}") from e

    # Run inference
    input_name = session.get_inputs()[0].name
    print("Running inference...")
    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e

    # Handle the model's outputs:
    # outputs[0] = initial_predictions, outputs[1] = refined_predictions,
    # outputs[2] = selected_candidates (optional)
    if len(outputs) >= 2:
        initial_logits = outputs[0]  # unpacked for clarity; only refined is used below
        refined_logits = outputs[1]
        selected_candidates = outputs[2] if len(outputs) > 2 else None
        # Use the refined predictions as the main output
        main_logits = refined_logits
        print(f"Using refined predictions (shape: {refined_logits.shape})")
    else:
        # Fall back to a single output
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Apply sigmoid to turn logits into per-tag probabilities
    main_probs = 1.0 / (1.0 + np.exp(-main_logits))

    # Apply the threshold and collect predicted indices
    predictions_mask = (main_probs >= threshold)
    indices = np.where(predictions_mask[0])[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        # Show the top 5 regardless of threshold
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f"  {tag_name}: {prob:.3f}")
        return {}

    # Group predictions by category (JSON keys are strings, so look up via str(idx))
    tags_by_category = defaultdict(list)
    for idx in indices:
        tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        prob = float(main_probs[0, idx])
        tags_by_category[category].append((tag_name, prob))

    # Sort by probability within each category and keep at most top_k per category
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category],
            key=lambda x: x[1],
            reverse=True
        )[:top_k]

    # Print results
    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Fixed category order for consistent display
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    for category in category_order:
        if category in tags_by_category:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

    # Show any remaining categories not in the standard order
    for category in sorted(tags_by_category.keys()):
        if category not in category_order:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

    # Performance stats
    print("\nPerformance:")
    print(f"  Inference time: {inference_time:.4f}s")
    print(f"  Provider: {active_provider}")
    print(f"  Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f"  Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)
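

# --- Example usage ---
# A minimal command-line driver sketch showing how the two functions above wire
# together. The default paths ("model.onnx", "metadata.json") are hypothetical
# placeholders, not files shipped with the model; point the flags at your own
# ONNX export, its metadata JSON, and a test image.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run ImageTagger ONNX inference on one image")
    parser.add_argument("--model", default="model.onnx", help="Path to ONNX model file (placeholder default)")
    parser.add_argument("--metadata", default="metadata.json", help="Path to metadata JSON file (placeholder default)")
    parser.add_argument("--image", required=True, help="Path to the test image")
    parser.add_argument("--threshold", type=float, default=0.5, help="Confidence threshold for predictions")
    args = parser.parse_args()

    results = test_onnx_imagetagger(
        model_path=args.model,
        metadata_path=args.metadata,
        image_path=args.image,
        threshold=args.threshold,
    )
    # results maps category -> [(tag, probability), ...], sorted by confidence
    print(f"\nReturned {sum(len(v) for v in results.values())} tags across {len(results)} categories")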