Sample script for batch
#2
by
alfredplpl
- opened
Great work!
Do you have any plans for a sample script?
How about this one?
input:
output:
/path/to/.venv/bin/python /path/to/github/playground/pixart_tagger.py . /path/to/Downloads/sample2.jpg
{
"feature": [
"1girl",
"looking_at_viewer",
"simple_background",
"smile",
"solo",
"brown_hair",
"short_hair",
"english_text",
"speech_bubble",
"lips",
"blue_eyes",
"pink_background",
"close-up",
"grin",
"portrait",
"swept_bangs",
"very_short_hair",
"eyelashes",
"purple_background",
"pink_lips",
"pixie_cut"
],
"character": [],
"ip": []
}
Process finished with exit code 0
code:
import json
from pathlib import Path
from typing import Any
import timm
import torch
import torchvision.transforms as transforms
from PIL import Image
# -------------------------
# Utilities
# -------------------------
class TaggingHead(torch.nn.Module):
    """Linear multi-label head that turns feature vectors into tag probabilities.

    The single linear layer is deliberately wrapped in ``torch.nn.Sequential``
    so that its parameters are named ``head.0.weight`` / ``head.0.bias``,
    matching the keys stored in the checkpoint. Do not flatten this wrapper.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        projection = torch.nn.Linear(input_dim, num_classes)
        # Sequential wrapper -> parameter keys "head.0.weight" / "head.0.bias".
        self.head = torch.nn.Sequential(projection)

    def forward(self, x):
        """Return an independent sigmoid probability for every class."""
        return self.head(x).sigmoid()
def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """Read the tag vocabulary and category sizes from a UTF-8 JSON file.

    The file must contain a ``tag_map`` object (tag name -> output index) and
    a ``tag_split`` object with ``gen_tag_count`` and ``character_tag_count``.

    Returns:
        ``(tag_map, gen_tag_count, character_tag_count)``.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    payload = json.loads(tags_file.read_text(encoding="utf-8"))
    mapping = payload["tag_map"]
    split = payload["tag_split"]
    return mapping, split["gen_tag_count"], split["character_tag_count"]
def get_character_ip_mapping(mapping_file: Path):
    """Parse ``mapping_file`` as UTF-8 JSON and return the decoded object.

    ``SimpleTagger.tag`` looks character tags up in the result and extends
    its IP-tag list with the value found for each one.
    """
    raw_text = mapping_file.read_text(encoding="utf-8")
    return json.loads(raw_text)
def get_encoder():
    """Build the image backbone with its classification layer removed.

    The architecture comes from the ``SmilingWolf/wd-eva02-large-tagger-v3``
    timm Hub entry with ``pretrained=False``. The weights are supplied
    afterwards by ``load_model`` from the local checkpoint.
    ``reset_classifier(0)`` strips the original classifier, so the module
    outputs features for the tagging head instead of class logits.

    NOTE(review): resolving an ``hf_hub:`` name presumably needs the model
    config to be downloadable or already cached; confirm for offline use.
    """
    backbone = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=False
    )
    backbone.reset_classifier(0)  # 0 classes -> drop the classifier head
    return backbone
def get_decoder(input_dim: int = 1024, num_classes: int = 13461):
    """Create the tagging head that maps encoder features to tag probabilities.

    The defaults reproduce the original hard-coded configuration, so existing
    ``get_decoder()`` callers are unaffected.

    Args:
        input_dim: Size of the feature vector produced by the encoder
            (1024 is the value paired with ``get_encoder`` in this file).
        num_classes: Number of output tags; must match the checkpoint being
            loaded (13461 is the value used for ``model_v0.9.pth`` here).

    Returns:
        A ``TaggingHead`` instance.
    """
    return TaggingHead(input_dim, num_classes)
def get_model():
    """Assemble the full tagger: encoder followed by the tagging head.

    As submodules ``0`` and ``1`` of a ``Sequential``, their parameters are
    keyed ``0.*`` and ``1.head.0.*`` in the state dict.
    """
    encoder = get_encoder()
    decoder = get_decoder()
    return torch.nn.Sequential(encoder, decoder)
def load_model(weights_file: Path, device: str):
    """Build the tagger, load ``weights_file`` into it and move it to ``device``.

    Args:
        weights_file: Path to a state dict saved with ``torch.save``. It is
            loaded with ``weights_only=True``, so arbitrary pickled objects
            are rejected.
        device: Target device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        The model, in eval mode, on ``device``.
    """
    model = get_model()
    # The freshly built model lives on CPU. Deserializing the checkpoint
    # directly onto the GPU would only keep a second full copy of the weights
    # there (and copy them GPU -> CPU -> GPU), so load on CPU and move once.
    state = torch.load(str(weights_file), map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    del state  # free the CPU copy before allocating on the target device
    return model.to(device).eval()
def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """Flatten an RGBA image onto a solid background.

    The image's fourth band (index 3, the alpha channel of an RGBA image) is
    used as the paste mask. Transparent pixels therefore take ``color``,
    opaque pixels keep their RGB values, and partial alpha blends the two.
    The input must have at least four bands.

    Returns:
        A new RGB image the same size as ``image``.
    """
    image.load()  # force decoding before the bands are split
    bands = image.split()
    canvas = Image.new("RGB", image.size, color)
    canvas.paste(image, mask=bands[3])
    return canvas
def pil_to_rgb(image: Image.Image) -> Image.Image:
    """Convert ``image`` to RGB, compositing any transparency onto white.

    A plain ``convert("RGB")`` discards the alpha channel and exposes whatever
    colour lies under transparent pixels (often black). Every mode that can
    carry transparency is therefore first expanded to RGBA and flattened onto
    a white background. Other modes are converted directly.

    Returns:
        A new RGB image.
    """
    if image.mode == "RGBA":
        return pure_pil_alpha_to_color_v2(image)
    has_transparency = (
        image.mode == "P"
        # Modes with an explicit alpha band, e.g. "LA" and "PA".
        or "A" in image.getbands()
        # Colour-key transparency stored in the image metadata.
        or (image.mode in ("L", "RGB") and "transparency" in image.info)
    )
    if has_transparency:
        return pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    return image.convert("RGB")
# -------------------------
# Simple Tagger (PIL file input only)
# -------------------------
class SimpleTagger:
    """Tag a single image file with general ("feature"), character and IP tags.

    ``model_dir`` must contain ``model_v0.9.pth`` (weights),
    ``tags_v0.9_13k.json`` (tag vocabulary and category sizes) and
    ``char_ip_map.json`` (character tag -> IP tags).
    """

    def __init__(
        self,
        model_dir: str,
        general_threshold: float = 0.3,
        character_threshold: float = 0.85,
    ):
        """Load the model and the tag metadata from ``model_dir``.

        Args:
            model_dir: Directory holding the three files listed above.
            general_threshold: General tags with probability strictly above
                this value are reported.
            character_threshold: Character tags with probability strictly
                above this value are reported.

        Raises:
            FileNotFoundError: If the directory or a required file is missing.
        """
        repo_path = Path(model_dir)
        if not repo_path.is_dir():
            raise FileNotFoundError(f"Model directory not found: {repo_path}")
        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        for p in (weights_file, tags_file, mapping_file):
            if not p.exists():
                raise FileNotFoundError(f"Required file not found: {p}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model(weights_file, self.device)
        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)
        # Reverse lookup: model output index -> tag name.
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}
        self.character_ip_mapping = get_character_ip_mapping(mapping_file)
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        # Resize to exactly 448x448 (aspect ratio is not preserved), then map
        # pixel values from ToTensor's [0, 1] range to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def tag(self, image_path: str) -> dict[str, Any]:
        """Predict tags for the image file at ``image_path``.

        Returns:
            ``{"feature": [...], "character": [...], "ip": [...]}``.
            ``feature`` and ``character`` list the tags above their thresholds
            in ascending model-output-index order. ``ip`` is the sorted,
            de-duplicated union of the IP tags mapped from the detected
            characters.
        """
        # Close the file handle once the pixels are decoded. pil_to_rgb
        # builds a new image via convert()/paste(), so it stays valid after
        # the source image is closed.
        with Image.open(image_path) as img:
            rgb = pil_to_rgb(img)

        with torch.inference_mode():
            x = self.transform(rgb).unsqueeze(0)
            # Pinned memory only speeds up host-to-GPU copies, so pin on CUDA only.
            if self.device == "cuda":
                x = x.pin_memory().to(self.device, non_blocking=True)
            else:
                x = x.to(self.device)
            probs = self.model(x)[0]

            gen_end = self.gen_tag_count
            # Score as characters only the range declared in the tags file, so
            # outputs past it are never reported as characters.
            char_end = gen_end + self.character_tag_count
            general_indices = (
                probs[:gen_end] > self.general_threshold
            ).nonzero(as_tuple=True)[0]
            character_indices = (
                probs[gen_end:char_end] > self.character_threshold
            ).nonzero(as_tuple=True)[0] + gen_end

        gen_tags = [self.index_to_tag_map[i] for i in general_indices.tolist()]
        char_tags = [self.index_to_tag_map[i] for i in character_indices.tolist()]
        # NOTE(review): assumes each mapping value is an iterable of IP tag
        # strings (e.g. a list); a bare string would be split into characters.
        ip_tags = sorted(
            {
                ip
                for name in char_tags
                for ip in self.character_ip_mapping.get(name, ())
            }
        )
        return {"feature": gen_tags, "character": char_tags, "ip": ip_tags}
# -------------------------
# CLI usage
# -------------------------
if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="Tag an image file using PIL loading (no API/url/base64)."
    )
    parser.add_argument("model_dir", type=str, help="Directory containing model and JSONs")
    parser.add_argument(
        "image_path",
        type=str,
        nargs="+",
        help="Path(s) to the input image file(s)",
    )
    parser.add_argument("--gen-th", type=float, default=0.3, help="General tag threshold")
    parser.add_argument("--char-th", type=float, default=0.85, help="Character tag threshold")
    args = parser.parse_args()
    try:
        # Load the model once and reuse it for every image (batch tagging).
        tagger = SimpleTagger(args.model_dir, args.gen_th, args.char_th)
        if len(args.image_path) == 1:
            # Single image: same output shape as before (one result object).
            result = tagger.tag(args.image_path[0])
        else:
            # Several images: one result object per input path.
            result = {path: tagger.tag(path) for path in args.image_path}
        print(json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        # Top-level CLI boundary: report the error and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)