Sample script for batch

#2
by alfredplpl - opened

Great work!

Do you have any plans for a sample script?

How about this one?

input:

sample2.jpg

output:

/path/to/.venv/bin/python /path/to/github/playground/pixart_tagger.py . /path/to/Downloads/sample2.jpg 
{
  "feature": [
    "1girl",
    "looking_at_viewer",
    "simple_background",
    "smile",
    "solo",
    "brown_hair",
    "short_hair",
    "english_text",
    "speech_bubble",
    "lips",
    "blue_eyes",
    "pink_background",
    "close-up",
    "grin",
    "portrait",
    "swept_bangs",
    "very_short_hair",
    "eyelashes",
    "purple_background",
    "pink_lips",
    "pixie_cut"
  ],
  "character": [],
  "ip": []
}

Process finished with exit code 0

code:

import json
from pathlib import Path
from typing import Any

import timm
import torch
import torchvision.transforms as transforms
from PIL import Image


# -------------------------
# Utilities
# -------------------------
class TaggingHead(torch.nn.Module):
    """Linear classification head that emits one sigmoid score per class.

    The lone ``Linear`` layer is kept inside a ``Sequential`` container so
    that its parameters are registered under the ``head.0.weight`` and
    ``head.0.bias`` state-dict keys.
    """

    def __init__(self, input_dim, num_classes):
        """Create the head.

        Args:
            input_dim: Size of the incoming feature vector.
            num_classes: Number of output scores.
        """
        super().__init__()
        projection = torch.nn.Linear(input_dim, num_classes)
        # Keep the Sequential wrapper: it is what produces the
        # "head.0.weight/bias" parameter names.
        self.head = torch.nn.Sequential(projection)

    def forward(self, x):
        """Project ``x`` linearly and squash every logit with a sigmoid."""
        return torch.sigmoid(self.head(x))


def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """Read the tag vocabulary and the general/character split sizes.

    Args:
        tags_file: UTF-8 JSON file containing a ``tag_map`` entry and a
            ``tag_split`` entry; the latter must provide ``gen_tag_count``
            and ``character_tag_count``.

    Returns:
        ``(tag_map, gen_tag_count, character_tag_count)`` as stored in the
        file.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    payload = json.loads(tags_file.read_text(encoding="utf-8"))
    vocabulary = payload["tag_map"]
    split = payload["tag_split"]
    return vocabulary, split["gen_tag_count"], split["character_tag_count"]


def get_character_ip_mapping(mapping_file: Path):
    """Return the parsed contents of the UTF-8 JSON file ``mapping_file``.

    The caller in this file looks character tags up in the returned object
    and extends a list with the value found for each one.
    """
    raw_text = mapping_file.read_text(encoding="utf-8")
    return json.loads(raw_text)


def get_encoder():
    """Instantiate the backbone architecture without pretrained weights.

    The architecture is resolved by timm from the
    ``SmilingWolf/wd-eva02-large-tagger-v3`` hub entry; ``pretrained=False``
    is passed because the weights are supplied later via ``load_model``.
    """
    backbone = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3",
        pretrained=False,
    )
    # Request a zero-class classifier; presumably this leaves the backbone
    # emitting features for TaggingHead - confirm against the timm docs.
    backbone.reset_classifier(0)
    return backbone


def get_decoder():
    """Build the tagging head: 1024 input features, 13461 output scores."""
    feature_dim, tag_count = 1024, 13461
    return TaggingHead(feature_dim, tag_count)


def get_model():
    """Chain the encoder (index 0) and the decoder (index 1) into one module."""
    stages = (get_encoder(), get_decoder())
    return torch.nn.Sequential(*stages)


def load_model(weights_file: Path, device: str):
    """Build the full model, load its weights and switch it to eval mode.

    Args:
        weights_file: Checkpoint holding the state dict for the whole
            encoder + decoder stack.
        device: Torch device string; used both as ``map_location`` while
            loading and as the device the model is moved to.

    Returns:
        The model, on ``device``, in evaluation mode.
    """
    network = get_model()
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(
        str(weights_file), map_location=device, weights_only=True
    )
    network.load_state_dict(checkpoint)
    network.to(device)
    network.eval()
    return network


def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """Flatten an image with an alpha band onto a solid RGB background.

    Args:
        image: Image whose fourth band (index 3) is the alpha channel.
        color: Background colour as an ``(R, G, B)`` tuple; white by default.

    Returns:
        A new ``RGB`` image of the same size as ``image``.
    """
    # Force the pixel data to be read before the bands are split.
    image.load()
    canvas = Image.new("RGB", image.size, color)
    alpha_band = image.split()[3]  # band index 3 = alpha
    canvas.paste(image, mask=alpha_band)
    return canvas


def pil_to_rgb(image: Image.Image) -> Image.Image:
    """Convert any PIL image to ``RGB``, flattening transparency onto white.

    ``RGBA`` images are composited directly. Palette (``P``, ``PA``) and
    grayscale-with-alpha (``LA``) images are first expanded to ``RGBA`` so
    that their transparency is composited as well instead of being dropped
    by a plain ``convert("RGB")``. Every other mode is converted directly.

    Args:
        image: Source image in any mode.

    Returns:
        An ``RGB`` image.
    """
    if image.mode == "RGBA":
        return pure_pil_alpha_to_color_v2(image)
    # These modes may carry transparency; go through RGBA so that the alpha
    # band exists at index 3 for the compositing helper.
    if image.mode in ("P", "PA", "LA"):
        return pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    return image.convert("RGB")


# -------------------------
# Simple Tagger (PIL file input only)
# -------------------------
class SimpleTagger:
    """Tag a single image file with general, character and IP tags.

    The model weights, the tag vocabulary and the character-to-IP mapping
    are read from fixed file names inside ``model_dir`` when the instance is
    constructed; ``tag`` can then be called repeatedly.
    """

    def __init__(
        self,
        model_dir: str,
        general_threshold: float = 0.3,
        character_threshold: float = 0.85,
    ):
        """Load the model and its metadata from ``model_dir``.

        Args:
            model_dir: Directory containing ``model_v0.9.pth``,
                ``tags_v0.9_13k.json`` and ``char_ip_map.json``.
            general_threshold: A general tag is kept when its score is
                strictly greater than this value.
            character_threshold: A character tag is kept when its score is
                strictly greater than this value.

        Raises:
            FileNotFoundError: If ``model_dir`` is not a directory or one of
                the three required files is missing.
        """
        repo_path = Path(model_dir)
        if not repo_path.is_dir():
            raise FileNotFoundError(f"Model directory not found: {repo_path}")

        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        for required in (weights_file, tags_file, mapping_file):
            if not required.exists():
                raise FileNotFoundError(f"Required file not found: {required}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model(weights_file, self.device)

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)
        # The model yields scores by position, so invert name -> index.
        self.index_to_tag_map = {index: name for name, index in tag_map.items()}
        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

        self.general_threshold = general_threshold
        self.character_threshold = character_threshold

        # Resize to 448x448, convert to a tensor and normalise each channel
        # with mean 0.5 / std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def _tags_from_probs(self, probs) -> dict[str, Any]:
        """Turn a 1-D score tensor into the ``feature``/``character``/``ip`` dict.

        Scores at positions below ``gen_tag_count`` are compared against the
        general threshold, the remaining ones against the character
        threshold. Tags are reported in ascending index order; IP tags are
        de-duplicated and sorted.
        """
        split = self.gen_tag_count
        general_hits = (probs[:split] > self.general_threshold).nonzero(
            as_tuple=True
        )[0]
        # Offsets within the character slice are shifted back to absolute
        # indices so they can be looked up in index_to_tag_map.
        character_hits = (
            (probs[split:] > self.character_threshold).nonzero(as_tuple=True)[0]
            + split
        )

        gen_tags = [self.index_to_tag_map[i] for i in general_hits.tolist()]
        char_tags = [self.index_to_tag_map[i] for i in character_hits.tolist()]

        ip_tags = sorted(
            {
                ip
                for character in char_tags
                for ip in self.character_ip_mapping.get(character, ())
            }
        )
        return {"feature": gen_tags, "character": char_tags, "ip": ip_tags}

    def tag(self, image_path: str) -> dict[str, Any]:
        """Tag the image stored at ``image_path``.

        Args:
            image_path: Path of an image file readable by PIL.

        Returns:
            A dict with the keys ``feature`` (general tags), ``character``
            (character tags) and ``ip`` (sorted, unique IP tags looked up
            for the detected characters).
        """
        # Image.open keeps the underlying file open; the context manager
        # releases it once the independent RGB copy has been built.
        with Image.open(image_path) as source:
            rgb_image = pil_to_rgb(source)

        with torch.inference_mode():
            batch = self.transform(rgb_image).unsqueeze(0)
            # pin_memory only matters for CUDA transfers, hence the branch.
            if self.device == "cuda":
                batch = batch.pin_memory().to(self.device, non_blocking=True)
            else:
                batch = batch.to(self.device)

            probs = self.model(batch)[0]
            return self._tags_from_probs(probs)


# -------------------------
# CLI usage
# -------------------------
if __name__ == "__main__":
    # Command-line entry point: tag one image and print the result as JSON.
    import argparse
    import sys

    cli = argparse.ArgumentParser(
        description="Tag an image file using PIL loading (no API/url/base64)."
    )
    cli.add_argument("model_dir", type=str, help="Directory containing model and JSONs")
    cli.add_argument("image_path", type=str, help="Path to the input image file")
    # Both thresholds are optional floats; declare them from one table.
    for flag, default_value, help_text in (
        ("--gen-th", 0.3, "General tag threshold"),
        ("--char-th", 0.85, "Character tag threshold"),
    ):
        cli.add_argument(flag, type=float, default=default_value, help=help_text)
    options = cli.parse_args()

    try:
        tagger = SimpleTagger(options.model_dir, options.gen_th, options.char_th)
        tags = tagger.tag(options.image_path)
        print(json.dumps(tags, ensure_ascii=False, indent=2))
    except Exception as e:
        # Top-level boundary: report any failure on stderr and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

Sign up or log in to comment