#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
JTP2 (Joint Tagger Project 2) Image Classification Script

This script implements a multi-label classifier for furry images using the
PILOT2 model. It processes images, generates tags, and saves the results.
The model is based on a Vision Transformer architecture and uses a custom
GatedHead for classification.

Key features:
- Image preprocessing and transformation
- Model inference using PILOT2
- Tag generation with customizable threshold
- Batch processing of image directories
- Saving results as text files alongside images

Usage:
    python jtp2.py <directory> [--threshold <float>]
"""

import os
import json
import argparse

from PIL import Image
import safetensors.torch
import timm
from timm.models import VisionTransformer
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
# Not referenced by name below; presumably imported for its side effect of
# letting Pillow open the ".jxl" files accepted by process_directory.
import pillow_jxl  # noqa: F401

# Number of output classes of the PILOT2 head; the checkpoint loaded below
# must have a head of this size for load_model to succeed.
NUM_CLASSES = 9083

# File extensions picked up by process_directory (compared lower-cased).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".jxl")


class Fit(torch.nn.Module):
    """
    A custom transform module for resizing and padding images.

    The image is scaled, preserving its aspect ratio, so that it fits inside
    ``bounds``. If ``pad`` is given, it is then padded (centered) to exactly
    ``bounds``.

    Args:
        bounds (tuple[int, int] | int): The target (height, width) for the
            image; a single int means a square bound.
        interpolation (InterpolationMode): The interpolation method for resizing.
        grow (bool): Whether to allow upscaling of images.
        pad (float | None): The fill value used to pad the image up to
            ``bounds``; ``None`` disables padding.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        """
        Applies the Fit transform to the input image.

        Args:
            img (Image.Image): The input PIL Image.

        Returns:
            Image.Image: The resized (and, if ``pad`` is set, padded) image.
        """
        wimg, himg = img.size
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        # The limiting axis decides the scale so the result fits both bounds.
        scale = min(hscale, wscale)

        if scale == 1.0:
            # Already fits; skip resampling but still fall through to padding
            # (the old early return silently ignored `pad` in this case).
            hnew, wnew = himg, wimg
        else:
            # Clamp to >= 1 so very thin images never round to a zero-size side.
            hnew = max(1, min(round(himg * scale), hbound))
            wnew = max(1, min(round(wimg * scale), wbound))
            img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Fit module.

        Returns:
            str: A string describing the module's parameters.
        """
        return (
            f"{self.__class__.__name__}(bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, grow={self.grow}, "
            f"pad={self.pad})"
        )


class CompositeAlpha(torch.nn.Module):
    """
    A module for compositing images with alpha channels over a background color.

    Args:
        background (tuple[float, float, float] | float): The background color
            to use for compositing; a single number is used for all three
            channels.
    """

    def __init__(self, background: tuple[float, float, float] | float):
        super().__init__()
        # Accept ints as well as floats for the scalar form (an int used to
        # produce a 0-d tensor and crash on unsqueeze).
        if isinstance(background, (int, float)):
            background = (float(background),) * 3
        # Stored as (3, 1, 1) so it broadcasts over (..., 3, H, W).
        self.background = torch.tensor(background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Applies alpha compositing to the input image tensor.

        Channel 3 is used as alpha and the result is
        ``rgb * alpha + background * (1 - alpha)``. Inputs that already have
        exactly 3 channels are returned unchanged.

        Note: the RGB channels of ``img`` are modified in place.

        Args:
            img (torch.Tensor): The input image tensor (channels at dim -3).

        Returns:
            torch.Tensor: The composited image tensor with 3 channels.
        """
        if img.shape[-3] == 3:
            return img
        alpha = img[..., 3, None, :, :]
        # Match the input so compositing also works for GPU / half tensors.
        background = self.background.to(dtype=img.dtype, device=img.device)
        img[..., :3, :, :] *= alpha
        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        """
        Returns a string representation of the CompositeAlpha module.

        Returns:
            str: A string describing the module's parameters.
        """
        return f"{self.__class__.__name__}(background={self.background})"


# Fit inside 384x384 without padding. Normalize maps the 0.5 composite
# background to 0, and CenterCrop pads any short side with 0, so the padding
# matches the transparent-area background.
transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

model: VisionTransformer = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli", pretrained=False, num_classes=NUM_CLASSES
)


class GatedHead(torch.nn.Module):
    """
    A custom head module with gating mechanism for the classifier.

    A single linear layer produces ``2 * num_classes`` values; the first half
    goes through a sigmoid to give scores and the second half through a
    sigmoid to give gates. The output is their elementwise product.

    Args:
        num_features (int): The number of input features.
        num_classes (int): The number of output classes.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the gated head to the input tensor.

        Args:
            x (torch.Tensor): The input tensor of shape (batch, num_features).

        Returns:
            torch.Tensor: Per-class probabilities of shape (batch, num_classes).
        """
        x = self.linear(x)
        x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
        return x


# The default head's weight is presumably (num_classes, num_features); since
# NUM_CLASSES exceeds the feature width, min() picks the feature count.
model.head = GatedHead(min(model.head.weight.shape), NUM_CLASSES)

safetensors.torch.load_model(
    model,
    "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
)

# Queried once here instead of on every run_classifier call.
_USE_CUDA = torch.cuda.is_available()
_USE_HALF = _USE_CUDA and torch.cuda.get_device_capability()[0] >= 7  # tensor cores

if _USE_CUDA:
    model.cuda()
    if _USE_HALF:
        model.to(dtype=torch.float16, memory_format=torch.channels_last)

model.eval()

with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as file:
    tags = json.load(file)  # type: dict

# Output index i of the model is looked up as allowed_tags[i].
# NOTE(review): this relies on tags.json key order matching the head's
# class order — confirm against how the file is generated.
allowed_tags = [tag.replace("_", " ") for tag in tags.keys()]

# Tag -> probability for the most recently classified image (top 250),
# sorted by descending probability. Read by create_tags.
sorted_tag_score = {}


def run_classifier(image, threshold):
    """
    Runs the classifier on a single image and returns tags based on the threshold.

    Also updates the module-level ``sorted_tag_score`` with the top-250 tag
    probabilities for this image.

    Args:
        image (PIL.Image): The input image.
        threshold (float): The probability threshold for including tags.

    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary
            of tag probabilities.
    """
    global sorted_tag_score
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)

    if _USE_CUDA:
        tensor = tensor.cuda()
        if _USE_HALF:
            tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

    with torch.no_grad():
        probits = model(tensor)[0].cpu()
        values, indices = probits.topk(min(250, probits.numel()))

    tag_score = {
        allowed_tags[index]: score
        for score, index in zip(values.tolist(), indices.tolist())
    }
    # topk already sorts, but re-sort in case two tags collapse to the same
    # display name and a later (lower) score overwrites an earlier entry.
    sorted_tag_score = dict(
        sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
    )

    return create_tags(threshold)


def create_tags(threshold):
    """
    Creates a list of tags based on the current sorted_tag_score and the given threshold.

    Args:
        threshold (float): The probability threshold for including tags
            (strictly greater-than).

    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary
            of filtered tag probabilities.
    """
    filtered_tag_score = {
        key: value for key, value in sorted_tag_score.items() if value > threshold
    }
    text_no_impl = ", ".join(filtered_tag_score.keys())
    return text_no_impl, filtered_tag_score


def process_directory(directory, threshold):
    """
    Processes all images in a directory and its subdirectories, generating tags for each image.

    For every image, a ``.txt`` file with the same base name is written next
    to it containing the comma-separated tags (overwriting any existing file).

    Args:
        directory (str): The path to the directory containing images.
        threshold (float): The probability threshold for including tags.

    Returns:
        dict: A dictionary mapping image paths to their generated tags.
    """
    results = {}
    for root, _, files in os.walk(directory):
        for name in files:
            if not name.lower().endswith(IMAGE_EXTENSIONS):
                continue
            image_path = os.path.join(root, name)
            # Context manager closes the underlying file; previously every
            # opened image leaked a file handle for the rest of the run.
            with Image.open(image_path) as image:
                tags, _ = run_classifier(image, threshold)
            results[image_path] = tags

            # Save tags to a text file with the same name as the image
            text_file_path = os.path.splitext(image_path)[0] + ".txt"
            with open(text_file_path, "w", encoding="utf-8") as text_file:
                text_file.write(tags)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference on a directory of images."
    )
    parser.add_argument("directory", type=str, help="Target directory containing images.")
    parser.add_argument(
        "--threshold", type=float, default=0.2, help="Threshold for tag filtering."
    )
    args = parser.parse_args()

    results = process_directory(args.directory, args.threshold)
    for image_path, tags in results.items():
        print(f"{image_path}: {tags}")