|
|
|
|
|
|
|
""" |
|
JoyCaption Alpha One |
|
|
|
This module provides functionality for generating captions for images using a
combination of a CLIP-style vision encoder (SigLIP), an LLM (Llama 3.1), and a
custom image adapter. It supports various caption types, tones, and lengths.
|
|
|
The main components include: |
|
- Loading and initializing models (CLIP, LLM, image adapter) |
|
- Processing images and generating captions |
|
- Command-line interface for batch processing images in a directory
- Optional prompt construction from e621 tag files (--feed-from-tags)
|
""" |
|
|
|
import os |
|
import argparse |
|
import re |
|
import random |
|
import math |
|
from pathlib import Path |
|
from typing import List, Tuple, Dict |
|
from PIL import Image |
|
import pillow_jxl  # noqa: F401 -- imported for its side effect of registering JPEG XL support with Pillow
|
import torch |
|
import torchvision.transforms.functional as TVF |
|
from transformers import ( |
|
AutoModel, |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
PreTrainedTokenizer, |
|
PreTrainedTokenizerFast, |
|
) |
|
from torch import nn |
|
from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq |
|
|
|
CLIP_PATH = "google/siglip-so400m-patch14-384" |
|
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B" |
|
CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808" |
|
CAPTION_TYPE_MAP = { |
|
("descriptive", "formal", False, False): [ |
|
"Write a descriptive caption for this image in a formal tone." |
|
], |
|
("descriptive", "formal", False, True): [ |
|
"Write a descriptive caption for this image in a formal tone within " |
|
"{word_count} words." |
|
], |
|
("descriptive", "formal", True, False): [ |
|
"Write a {length} descriptive caption for this image in a formal tone." |
|
], |
|
("descriptive", "informal", False, False): [ |
|
"Write a descriptive caption for this image in a casual tone." |
|
], |
|
("descriptive", "informal", False, True): [ |
|
"Write a descriptive caption for this image in a casual tone within " |
|
"{word_count} words." |
|
], |
|
("descriptive", "informal", True, False): [ |
|
"Write a {length} descriptive caption for this image in a casual tone." |
|
], |
|
("training_prompt", "formal", False, False): [ |
|
"Write a stable diffusion prompt for this image." |
|
], |
|
("training_prompt", "formal", False, True): [ |
|
"Write a stable diffusion prompt for this image within " + "{word_count} words." |
|
], |
|
("training_prompt", "formal", True, False): [ |
|
"Write a {length} stable diffusion prompt for this image." |
|
], |
|
("rng-tags", "formal", False, False): [ |
|
"Write a list of Booru tags for this image." |
|
], |
|
("rng-tags", "formal", False, True): [ |
|
"Write a list of Booru tags for this image within {word_count} words." |
|
], |
|
("rng-tags", "formal", True, False): [ |
|
"Write a {length} list of Booru tags for this image." |
|
], |
|
} |
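# Keys are (caption_type, tone, length_is_a_descriptor, length_is_a_word_count);
# _get_prompt_string() builds the same tuple from its arguments, so e.g. a
# request for a casual caption of about 50 words resolves to
# ("descriptive", "informal", False, True).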
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN", None) |
|
|
|
|
|
class ImageAdapter(nn.Module): |
|
""" |
|
Custom image adapter module for processing CLIP vision outputs. |
|
|
|
This module adapts the output of a CLIP vision model to be compatible with |
|
a text model. It supports optional layer normalization, positional |
|
embeddings, and deep feature extraction. |
|
|
|
Args: |
|
input_features (int): |
|
Number of input features from the vision model. |
|
output_features (int): |
|
Number of output features to match the text model. |
|
ln1 (bool): |
|
Whether to use layer normalization. |
|
pos_emb (bool): |
|
Whether to use positional embeddings. |
|
num_image_tokens (int): |
|
Number of image tokens. |
|
deep_extract (bool): |
|
Whether to use deep feature extraction. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_features: int, |
|
output_features: int, |
|
ln1: bool, |
|
pos_emb: bool, |
|
num_image_tokens: int, |
|
deep_extract: bool, |
|
): |
|
super().__init__() |
|
self.deep_extract = deep_extract |
|
|
|
if self.deep_extract: |
|
input_features = input_features * 5 |
|
|
|
self.linear1 = nn.Linear(input_features, output_features) |
|
self.activation = nn.GELU() |
|
self.linear2 = nn.Linear(output_features, output_features) |
|
self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features) |
|
self.pos_emb = ( |
|
None |
|
if not pos_emb |
|
else nn.Parameter(torch.zeros(num_image_tokens, input_features)) |
|
) |
|
|
|
self.other_tokens = nn.Embedding(3, output_features) |
|
self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
def forward(self, vision_outputs: torch.Tensor): |
|
""" |
|
Forward pass of the image adapter. |
|
|
|
Args: |
|
vision_outputs (torch.Tensor): |
|
Output tensor from the CLIP vision model. |
|
|
|
Returns: |
|
torch.Tensor: Adapted image features. |
|
""" |
|
if self.deep_extract: |
|
x = torch.concat( |
|
( |
|
vision_outputs[-2], |
|
vision_outputs[3], |
|
vision_outputs[7], |
|
vision_outputs[13], |
|
vision_outputs[20], |
|
), |
|
dim=-1, |
|
) |
|
assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" |
|
expected_shape = vision_outputs[-2].shape[-1] * 5 |
|
assert ( |
|
x.shape[-1] == expected_shape |
|
), f"Expected {expected_shape}, got {x.shape[-1]}" |
|
else: |
|
x = vision_outputs[-2] |
|
|
|
x = self.ln1(x) |
|
|
|
if self.pos_emb is not None: |
|
assert ( |
|
x.shape[-2:] == self.pos_emb.shape |
|
), f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}" |
|
x = x + self.pos_emb |
|
|
|
x = self.linear1(x) |
|
x = self.activation(x) |
|
x = self.linear2(x) |
|
|
|
other_tokens = self.other_tokens( |
|
torch.tensor([0, 1], device=self.other_tokens.weight.device).expand( |
|
x.shape[0], -1 |
|
) |
|
) |
|
assert other_tokens.shape == ( |
|
x.shape[0], |
|
2, |
|
x.shape[2], |
|
), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}" |
|
x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1) |
|
|
|
return x |
|
|
|
def get_eot_embedding(self): |
|
""" |
|
Get the end-of-text embedding. |
|
|
|
Returns: |
|
torch.Tensor: The end-of-text embedding. |
|
""" |
|
return self.other_tokens( |
|
torch.tensor([2], device=self.other_tokens.weight.device) |
|
).squeeze(0) |
|
|
|
|
|
STOP_WORDS: set[str] = { |
|
"the", |
|
"a", |
|
"an", |
|
"and", |
|
"or", |
|
"but", |
|
"in", |
|
"on", |
|
"at", |
|
"to", |
|
"for", |
|
"of", |
|
"with", |
|
"by", |
|
"from", |
|
"up", |
|
"down", |
|
"is", |
|
"are", |
|
"was", |
|
"were", |
|
"be", |
|
"been", |
|
"being", |
|
"have", |
|
"has", |
|
"had", |
|
"do", |
|
"does", |
|
"did", |
|
"will", |
|
"would", |
|
"shall", |
|
"should", |
|
"can", |
|
"could", |
|
"may", |
|
"might", |
|
"must", |
|
"ought", |
|
"i", |
|
"you", |
|
"he", |
|
"she", |
|
"it", |
|
"we", |
|
"they", |
|
"them", |
|
"their", |
|
"this", |
|
"that", |
|
"these", |
|
"those", |
|
"am", |
|
"is", |
|
"are", |
|
"was", |
|
"were", |
|
"be", |
|
"been", |
|
"being", |
|
"have", |
|
"has", |
|
"had", |
|
"do", |
|
"does", |
|
"did", |
|
"will", |
|
"would", |
|
"shall", |
|
"should", |
|
"can", |
|
"could", |
|
"may", |
|
"might", |
|
"must", |
|
"ought", |
|
"i'm", |
|
"you're", |
|
"he's", |
|
"she's", |
|
"it's", |
|
"we're", |
|
"they're", |
|
"i've", |
|
"you've", |
|
"we've", |
|
"they've", |
|
"i'd", |
|
"you'd", |
|
"he'd", |
|
"she'd", |
|
"we'd", |
|
"they'd", |
|
"i'll", |
|
"you'll", |
|
"he'll", |
|
"she'll", |
|
"we'll", |
|
"they'll", |
|
"isn't", |
|
"aren't", |
|
"wasn't", |
|
"weren't", |
|
"hasn't", |
|
"haven't", |
|
"hadn't", |
|
"doesn't", |
|
"don't", |
|
"didn't", |
|
"won't", |
|
"wouldn't", |
|
"shan't", |
|
"shouldn't", |
|
"can't", |
|
"cannot", |
|
"couldn't", |
|
"mustn't", |
|
"let's", |
|
"that's", |
|
"who's", |
|
"what's", |
|
"here's", |
|
"there's", |
|
"when's", |
|
"where's", |
|
"why's", |
|
"how's", |
|
"a", |
|
"an", |
|
"the", |
|
"and", |
|
"but", |
|
"if", |
|
"or", |
|
"because", |
|
"as", |
|
"until", |
|
"while", |
|
"of", |
|
"at", |
|
"by", |
|
"for", |
|
"with", |
|
"about", |
|
"against", |
|
"between", |
|
"into", |
|
"through", |
|
"during", |
|
"before", |
|
"after", |
|
"above", |
|
"below", |
|
"to", |
|
"from", |
|
"up", |
|
"down", |
|
"in", |
|
"out", |
|
"on", |
|
"off", |
|
"over", |
|
"under", |
|
"again", |
|
"further", |
|
"then", |
|
"once", |
|
"here", |
|
"there", |
|
"when", |
|
"where", |
|
"why", |
|
"how", |
|
"all", |
|
"any", |
|
"both", |
|
"each", |
|
"few", |
|
"more", |
|
"most", |
|
"other", |
|
"some", |
|
"such", |
|
"no", |
|
"nor", |
|
"not", |
|
"only", |
|
"own", |
|
"same", |
|
"so", |
|
"than", |
|
"too", |
|
"very", |
|
} |
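# Used by generate_valid_caption() to ignore function words when counting word
# repetitions, so only content words count toward max_word_repetitions.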
|
|
|
|
|
class JoyCaptionModel: |
|
""" |
|
A class for generating captions for images using CLIP, LLM, |
|
and custom image adapters. |
|
|
|
This class encapsulates the functionality to load and initialize |
|
various models (CLIP, LLM, image adapter) and use them to process |
|
images and generate captions. |
|
|
|
It supports different caption types, tones, and lengths. |
|
|
|
Attributes: |
|
clip_model: The CLIP vision model for processing images. |
|
text_model: The language model for generating captions. |
|
image_adapter: Custom adapter for processing CLIP vision outputs. |
|
tokenizer: Tokenizer for the language model. |
|
|
|
Methods: |
|
load_models(): Load and initialize all required models. |
|
process_image(input_image, caption_type, caption_tone, caption_length): |
|
Process an input image and generate a caption |
|
based on specified parameters. |
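generate_valid_caption(input_image, caption_type, caption_tone, caption_length):
    Repeatedly call process_image() and retry until the caption passes
    basic quality checks (punctuation, repetition, sentence count, entropy).

Example (illustrative sketch; assumes a CUDA device and the checkpoint
files described above are present):
    model = JoyCaptionModel()
    model.load_models()
    caption, entropy = model.process_image(image, "descriptive", "formal", "any")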
|
""" |
|
|
|
def __init__(self): |
|
self.clip_model = None |
|
self.text_model = None |
|
self.image_adapter = None |
|
self.tokenizer = None |
|
|
|
def load_models(self): |
|
""" |
|
Load and initialize all required models (CLIP, LLM, image adapter). |
|
""" |
|
print("Loading CLIP") |
|
self.clip_model = AutoModel.from_pretrained(CLIP_PATH) |
|
self.clip_model = self.clip_model.vision_model |
|
|
|
if (CHECKPOINT_PATH / "clip_model.pt").exists(): |
|
print("Loading VLM's custom vision model") |
|
checkpoint = torch.load( |
|
CHECKPOINT_PATH / "clip_model.pt", map_location="cpu" |
|
) |
|
checkpoint = { |
|
k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items() |
|
} |
|
self.clip_model.load_state_dict(checkpoint) |
|
del checkpoint |
|
|
|
self.clip_model.eval() |
|
self.clip_model.requires_grad_(False) |
|
self.clip_model.to("cuda") |
|
|
|
print("Loading tokenizer") |
|
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False) |
|
assert isinstance( |
|
self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) |
|
) |
|
|
|
print("Loading LLM") |
|
if (CHECKPOINT_PATH / "text_model").exists(): |
|
print("Loading VLM's custom text model") |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16 |
|
) |
|
else: |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16 |
|
) |
|
|
|
self.text_model.eval() |
|
|
|
print("Loading image adapter") |
|
self.image_adapter = ImageAdapter( |
|
self.clip_model.config.hidden_size, |
|
self.text_model.config.hidden_size, |
|
False, |
|
False, |
|
38, |
|
False, |
|
) |
|
self.image_adapter.load_state_dict( |
|
torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu") |
|
) |
|
self.image_adapter.eval() |
|
self.image_adapter.to("cuda") |
|
|
|
@torch.no_grad() |
|
def process_image( |
|
self, |
|
input_image: Image.Image, |
|
caption_type: str, |
|
caption_tone: str, |
|
caption_length: str | int, |
|
custom_prompt: str | None = None, |
|
) -> Tuple[str, float]: |
|
""" |
|
Process an input image and generate a caption based on specified parameters. |
|
Also calculates the entropy of the generated caption. |
|
|
|
Returns: |
|
Tuple[str, float]: The generated caption and its entropy. |
|
""" |
|
torch.cuda.empty_cache() |
|
|
|
if custom_prompt is not None: |
|
prompt_str = custom_prompt |
|
else: |
|
prompt_str = self._get_prompt_string( |
|
caption_type, caption_tone, caption_length |
|
) |
|
print(f"Prompt: {prompt_str}") |
|
|
|
pixel_values = self._preprocess_image(input_image) |
|
prompt = self._tokenize_prompt(prompt_str) |
|
|
|
embedded_images = self._embed_image(pixel_values) |
|
inputs_embeds, input_ids, attention_mask = self._construct_inputs( |
|
embedded_images, prompt |
|
) |
|
|
|
generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask) |
|
caption = self._decode_caption(generate_ids, input_ids) |
|
|
|
|
|
token_ids = generate_ids[0].tolist() |
|
entropy = self._calculate_entropy(token_ids) |
|
|
|
return caption.strip(), entropy |
|
|
|
def generate_valid_caption( |
|
self, |
|
input_image: Image.Image, |
|
caption_type: str, |
|
caption_tone: str, |
|
caption_length: str | int, |
|
custom_prompt: str | None = None, |
|
*, |
|
limited_words: Dict[str, int] = {"fluffy": 2}, |
|
min_sentence_count: int = 3, |
|
max_word_repetitions: int = 5, |
|
min_entropy: float = 1.75, |
|
stop_words: set[str] = STOP_WORDS, |
|
) -> str: |
|
""" |
|
Generate a valid caption, retrying if certain conditions are not met. |
|
|
|
Args: |
|
input_image (Image.Image): The input image to caption. |
|
caption_type (str): The type of caption to generate. |
|
caption_tone (str): The tone of the caption. |
|
caption_length (str | int): The desired length of the caption. |
|
custom_prompt (str | None): A custom prompt for caption generation. |
|
limited_words (Dict[str, int]): Words with their maximum allowed number of occurrences. Defaults to {"fluffy": 2}.
min_sentence_count (int): Minimum required number of sentences. Defaults to 3.
max_word_repetitions (int): Maximum allowed repetitions for non-stop words longer than 4 characters. Defaults to 5.
min_entropy (float): Minimum required entropy of the caption. Defaults to 1.75.
stop_words (set[str]): Words excluded from the repetition check. Defaults to STOP_WORDS.
|
|
|
Returns: |
|
str: A valid caption meeting all specified criteria. |
|
|
|
The method retries caption generation if: |
|
- The caption contains only special characters |
|
- The caption does not end with a period, exclamation mark, or question mark |
|
- Any word in limited_words appears more than its specified maximum times |
|
- Any word longer than 4 characters is repeated more than max_word_repetitions times |
|
- The caption contains fewer than min_sentence_count sentences |
|
- The entropy of the caption is below min_entropy |
|
""" |
|
while True: |
|
caption, entropy = self.process_image( |
|
input_image, caption_type, caption_tone, caption_length, custom_prompt |
|
) |
|
words = re.findall(r"\b\w+\b", caption.lower()) |
|
word_counts = { |
|
word: words.count(word) for word in set(words) if word not in stop_words |
|
} |
|
sentence_count = len(re.findall(r"[.!?]", caption)) |
|
|
|
if not re.search(r"\w", caption): |
|
print( |
|
f"Retrying: Caption contains only special characters.\nCaption: {caption!r}" |
|
) |
|
elif caption[-1] not in {".", "!", "?"}: |
|
print( |
|
f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}" |
|
) |
|
elif any( |
|
caption.lower().count(word) > max_count |
|
for word, max_count in limited_words.items() |
|
): |
|
exceeded_words = [ |
|
f"{word} ({caption.lower().count(word)}/{max_count})" |
|
for word, max_count in limited_words.items() |
|
if caption.lower().count(word) > max_count |
|
] |
|
print( |
|
f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}" |
|
) |
|
elif any( |
|
count > max_word_repetitions |
|
for word, count in word_counts.items() |
|
if len(word) > 4 |
|
): |
|
repeated_words = [ |
|
word |
|
for word, count in word_counts.items() |
|
if count > max_word_repetitions and len(word) > 4 |
|
] |
|
print( |
|
f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}" |
|
) |
|
elif sentence_count < min_sentence_count: |
|
print( |
|
f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}" |
|
) |
|
elif entropy < min_entropy: |
|
print( |
|
f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}" |
|
) |
|
else: |
|
return caption |
|
|
|
def _get_prompt_string(self, caption_type, caption_tone, caption_length): |
|
length = None if caption_length == "any" else caption_length |
|
|
|
if isinstance(length, str): |
|
try: |
|
length = int(length) |
|
except ValueError: |
|
pass |
|
|
|
if caption_type in {"rng-tags", "training_prompt"}: |
|
caption_tone = "formal" |
|
|
|
prompt_key = ( |
|
caption_type, |
|
caption_tone, |
|
isinstance(length, str), |
|
isinstance(length, int), |
|
) |
|
if prompt_key not in CAPTION_TYPE_MAP: |
|
raise ValueError(f"Invalid caption type: {prompt_key}") |
|
|
|
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format( |
|
length=length, word_count=length |
|
) |
|
return prompt_str |
|
|
|
def _preprocess_image(self, input_image): |
|
image = input_image.resize((384, 384), Image.LANCZOS) |
|
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0 |
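# The normalization below maps [0, 1] pixel values to [-1, 1] (mean 0.5,
# std 0.5), which is the preprocessing SigLIP checkpoints typically expect.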
|
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]) |
|
pixel_values = pixel_values.to("cuda") |
|
return pixel_values |
|
|
|
def _tokenize_prompt(self, prompt_str): |
|
prompt = self.tokenizer.encode( |
|
prompt_str, |
|
return_tensors="pt", |
|
padding=False, |
|
truncation=False, |
|
add_special_tokens=False, |
|
) |
|
return prompt |
|
|
|
def _embed_image(self, pixel_values): |
|
with torch.amp.autocast_mode.autocast("cuda", enabled=True): |
|
vision_outputs = self.clip_model( |
|
pixel_values=pixel_values, output_hidden_states=True |
|
) |
|
image_features = vision_outputs.hidden_states |
|
embedded_images = self.image_adapter(image_features) |
|
embedded_images = embedded_images.to("cuda") |
|
return embedded_images |
|
|
|
def _construct_inputs(self, embedded_images, prompt): |
|
prompt_embeds = self.text_model.model.embed_tokens(prompt.to("cuda")) |
|
assert prompt_embeds.shape == ( |
|
1, |
|
prompt.shape[1], |
|
self.text_model.config.hidden_size, |
|
), ( |
|
f"Prompt shape is {prompt_embeds.shape}, expected " |
|
f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}" |
|
) |
|
|
|
embedded_bos = self.text_model.model.embed_tokens( |
|
torch.tensor( |
|
[[self.tokenizer.bos_token_id]], |
|
device=self.text_model.device, |
|
dtype=torch.int64, |
|
) |
|
) |
|
|
|
eot_embed = ( |
|
self.image_adapter.get_eot_embedding() |
|
.unsqueeze(0) |
|
.to(dtype=self.text_model.dtype) |
|
) |
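# The sequence fed to the LLM is: <BOS> embedding, adapted image tokens, the
# tokenized prompt, then the adapter's end-of-text embedding. input_ids below
# mirrors this layout, with zeros standing in for the image token positions.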
|
|
|
inputs_embeds = torch.cat( |
|
[ |
|
embedded_bos.expand(embedded_images.shape[0], -1, -1), |
|
embedded_images.to(dtype=embedded_bos.dtype), |
|
prompt_embeds.expand(embedded_images.shape[0], -1, -1), |
|
eot_embed.expand(embedded_images.shape[0], -1, -1), |
|
], |
|
dim=1, |
|
) |
|
|
|
input_ids = torch.cat( |
|
[ |
|
torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long), |
|
torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), |
|
prompt, |
|
torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long), |
|
], |
|
dim=1, |
|
).to("cuda") |
|
attention_mask = torch.ones_like(input_ids) |
|
|
|
return inputs_embeds, input_ids, attention_mask |
|
|
|
def _generate_caption(self, inputs_embeds, input_ids, attention_mask): |
|
generate_ids = self.text_model.generate( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
max_new_tokens=300, |
|
do_sample=True, |
|
suppress_tokens=None, |
|
repetition_penalty=1.2, |
|
) |
|
return generate_ids |
|
|
|
def _decode_caption(self, generate_ids, input_ids): |
|
generate_ids = generate_ids[:, input_ids.shape[1] :] |
|
|
|
if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][ |
|
-1 |
|
] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"): |
|
generate_ids = generate_ids[:, :-1] |
|
|
|
caption = self.tokenizer.batch_decode( |
|
generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False |
|
)[0] |
|
return caption |
|
|
|
def _calculate_entropy(self, token_ids: List[int]) -> float: |
|
""" |
|
Calculate the Shannon entropy (in bits) of a sequence of token IDs.
|
|
|
Args: |
|
token_ids (List[int]): List of token IDs. |
|
|
|
Returns: |
|
float: Entropy of the token sequence. |
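For example, the sequence [7, 7, 9, 9] has two distinct tokens, each with
probability 0.5, giving an entropy of -2 * (0.5 * log2(0.5)) = 1.0 bit.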
|
""" |
|
token_counts = {} |
|
total_tokens = len(token_ids) |
|
|
|
for token_id in token_ids: |
|
token_counts[token_id] = token_counts.get(token_id, 0) + 1 |
|
|
|
entropy = 0 |
|
for count in token_counts.values(): |
|
probability = count / total_tokens |
|
entropy -= probability * math.log2(probability) |
|
|
|
return entropy |
|
|
|
|
|
def main(): |
|
""" |
|
Generate captions for images in a directory |
|
and save them as .caption files. |
|
""" |
|
parser = argparse.ArgumentParser( |
|
description=( |
|
"Generate captions for images in a directory and save them as " |
|
".caption files." |
|
) |
|
) |
|
parser.add_argument( |
|
"directory", type=str, help="Target directory containing images." |
|
) |
|
parser.add_argument( |
|
"--caption_type", |
|
type=str, |
|
default="descriptive", |
|
choices=["descriptive", "training_prompt", "rng-tags", "custom"], |
|
help="Type of caption to generate.", |
|
) |
|
parser.add_argument( |
|
"--caption_tone", |
|
type=str, |
|
default="formal", |
|
choices=["formal", "informal"], |
|
help="Tone of the caption.", |
|
) |
|
parser.add_argument( |
|
"--caption_length", type=str, default="any", help="Length of the caption." |
|
) |
|
parser.add_argument( |
|
"--dont-strip-commas", |
|
action="store_true", |
|
help=("If set, commas will not be stripped from the generated captions."), |
|
) |
|
parser.add_argument( |
|
"--custom_prompt", |
|
type=str, |
|
help=("Custom prompt for the captioner. " "Use with --caption_type custom."), |
|
) |
|
parser.add_argument( |
|
"--add-commas-to-sentence-ends", |
|
action="store_true", |
|
help="Add commas after periods in sentences", |
|
) |
|
parser.add_argument( |
|
"--feed-from-tags", |
|
type=int, |
|
nargs="?", |
|
const=-1, |
|
help=( |
|
"Use .txt files with the same base filename " |
|
"as the images as input to the captioner. " |
|
"Optionally specify the number of tags to use." |
|
), |
|
) |
|
parser.add_argument( |
|
"--random-tags", |
|
type=int, |
|
help=( |
|
"Randomly select n number of tags. " |
|
"Only works if --feed-from-tags is enabled." |
|
), |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.random_tags is not None and args.feed_from_tags is None: |
|
parser.error("--random-tags can only be used when --feed-from-tags is enabled") |
|
|
|
print("Loading e621 tag data") |
|
tagset_normalizer = make_tagset_normalizer() |
|
|
|
|
|
joy_caption_model = JoyCaptionModel() |
|
joy_caption_model.load_models() |
|
|
|
|
|
if args.caption_type == "custom" and not args.custom_prompt: |
|
parser.error("--custom_prompt is required when using --caption_type custom") |
|
elif args.caption_type != "custom" and args.custom_prompt: |
|
parser.error("--custom_prompt can only be used with --caption_type custom") |
|
|
|
image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"} |
|
for image_path in Path(args.directory).rglob("*"): |
|
if image_path.suffix.lower() in image_extensions: |
|
caption_file = image_path.with_suffix(".caption") |
|
|
|
|
|
if caption_file.exists(): |
|
print(f"Skipping {image_path}: Caption file already exists.") |
|
continue |
|
|
|
input_image = Image.open(image_path).convert("RGB") |
|
|
|
|
|
custom_prompt = None |
|
if args.caption_type == "custom": |
|
custom_prompt = args.custom_prompt |
|
elif args.feed_from_tags is not None: |
|
custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer) |
|
|
|
print(f"\nCustom prompt: {custom_prompt}") |
|
|
|
caption = joy_caption_model.generate_valid_caption( |
|
input_image, |
|
args.caption_type, |
|
args.caption_tone, |
|
args.caption_length, |
|
custom_prompt=custom_prompt, |
|
) |
|
|
|
|
|
if not args.dont_strip_commas: |
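# Drop a comma (and any following whitespace) unless the next character is a
# digit, so numeric groupings such as "1,000" are preserved.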
|
|
|
caption = re.sub(r",\s*([^\d])", r" \1", caption) |
|
|
|
|
|
if args.add_commas_to_sentence_ends: |
|
caption = re.sub(r"(\.)(\s+)([A-Z])", r"\1,\2\3", caption) |
|
|
|
|
|
caption = caption.replace("\n", " ") |
|
|
|
print(f"Caption for {image_path}:\n\n{caption}\n\n") |
|
|
|
|
|
with open(caption_file, "w", encoding="utf-8") as f: |
|
f.write(caption) |
|
print(f"Caption saved to {caption_file}") |
|
|
|
|
|
RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$") |
|
E6DB_DATA = Path(__file__).resolve().parent / "data" |
|
|
|
|
|
def make_tagset_normalizer(): |
|
""" |
|
Create a TagSetNormalizer for encoding/decoding tags to and from integers. |
|
Configures it to accept common alternate spellings of tags as input.
|
""" |
|
|
|
tagset_normalizer = TagSetNormalizer(E6DB_DATA) |
|
|
|
tagid2cat = tagset_normalizer.tag_normalizer.tag_categories |
|
cat_artist = tag_category2id["artist"] |
|
cat2suffix = { |
|
tag_category2id["character"]: "_(character)", |
|
tag_category2id["lore"]: "_(lore)", |
|
tag_category2id["species"]: "_(species)", |
|
tag_category2id["copyright"]: "_(copyright)", |
|
} |
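# input_map() yields the alternate spellings accepted for each known tag: the
# category suffix ("_(character)", "_(species)", ...) stripped or added, the
# "by_" prefix toggled for artist tags, and ":" rewritten to "_", so free-form
# tag files normalize onto canonical e621 tags.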
|
|
|
|
|
def input_map(tag, tid): |
|
|
|
|
|
without_suffix = RE_PARENS_SUFFIX.sub("", tag) |
|
had_suffix = tag != without_suffix |
|
if had_suffix: |
|
yield without_suffix |
|
|
|
|
|
cat = tagid2cat[tid] if tid is not None else -1 |
|
if cat == cat_artist: |
|
artist = without_suffix.removeprefix("by_") |
|
if artist != without_suffix: |
|
yield artist |
|
if not had_suffix: |
|
yield f"{artist}_(artist)" |
|
else: |
|
yield f"by_{artist}" |
|
if not had_suffix: |
|
yield f"by_{artist}_(artist)" |
|
elif not had_suffix: |
|
suffix = cat2suffix.get(cat) |
|
if suffix is not None: |
|
yield f"{without_suffix}{suffix}" |
|
|
|
|
|
if ":" in tag: |
|
yield tag.replace(":", "_") |
|
|
|
return tagset_normalizer.map_inputs(input_map, on_conflict="ignore") |
|
|
|
|
|
def format_nl_list(word_list): |
|
""" |
|
Takes a list of words and generates a natural language output. |
|
""" |
|
n = len(word_list) |
|
assert n > 0 |
|
if n == 1: |
|
return word_list[0] |
|
if n == 2: |
|
return f"{word_list[0]} and {word_list[1]}" |
|
|
|
*head, last = word_list |
|
return ", ".join(head) + ", and " + last |
|
|
|
|
|
TAG_SPECIES = tag_category2id["species"] |
|
TAG_CHARACTER = tag_category2id["character"] |
|
TAG_ARTIST = tag_category2id["artist"] |
|
TAG_COPYRIGHT = tag_category2id["copyright"] |
|
TAG_META = tag_category2id["meta"] |
|
TAG_FREQ_THRESH = 0 |
|
|
|
|
|
def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer): |
|
""" |
|
Generates a prompt from tags associated with the given image. |
|
|
|
Args: |
|
args: Parsed command-line arguments; feed_from_tags and random_tags
    control how many tags are used.
|
image_path (Path): |
|
The path to the image file. |
|
tagset_normalizer (TagSetNormalizer): |
|
An instance to normalize the tag set. |
|
|
|
Returns: |
|
str | None: The constructed prompt, or None if no tag file is found.
|
""" |
|
tag_file = find_tag_file(image_path) |
|
if tag_file is None: |
|
return None |
|
|
|
with open(tag_file, "r", encoding="utf-8") as f: |
|
tags = f.read().lower().split(",") |
|
|
|
tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories |
|
encode = tagset_normalizer.tag_normalizer.encode |
|
|
|
|
|
tag_by_category: Dict[int, List[Tuple[int, str, int]]] = { |
|
cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES] |
|
} |
|
other_tags: List[Tuple[int, str, int]] = [] |
|
implied: set = set() |
|
for tag in tags: |
|
tag = tag.strip() |
|
|
|
tag_id = encode(tag.replace(" ", "_")) |
|
if tag_id is None: |
|
other_tags.append((0, tag, 0)) |
|
implied.update(tagset_normalizer.implications_rej.get(0, ())) |
|
continue |
|
|
|
cat_id = tag_id_to_cat_id[tag_id] |
|
|
|
if cat_id == TAG_META: |
|
continue |
|
implied.update(tagset_normalizer.implications.get(tag_id, ())) |
|
|
|
freq = tag_rank_to_freq(tag_id) |
|
if freq < TAG_FREQ_THRESH: |
|
continue |
|
tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id)) |
|
|
|
other_tags = sorted( |
|
(int(freq), tag, tag_id) |
|
for freq, tag, tag_id in other_tags |
|
if tag_id not in implied |
|
) |
|
|
|
for cat_id, cat_list in tag_by_category.items(): |
|
tag_by_category[cat_id] = sorted( |
|
(int(freq), tag, tag_id) |
|
for freq, tag, tag_id in cat_list |
|
if tag_id not in implied |
|
) |
|
|
|
if args.random_tags is not None: |
|
|
|
num_tags = min(args.random_tags, len(other_tags)) |
|
other_tags = random.sample( |
|
[ |
|
(i, tag, 0) |
|
for i, tag in enumerate(tags[: round(args.random_tags * 1.5)]) |
|
], |
|
num_tags, |
|
) |
|
elif args.feed_from_tags > 0: |
|
|
|
other_tags = other_tags[: args.feed_from_tags] |
|
|
|
|
|
artist_tag = tag_by_category[TAG_ARTIST] |
|
if artist_tag: |
|
artist_list = [str(tp[1]).removeprefix("by ") for tp in artist_tag[:4]] |
|
artist_txt = f"by {format_nl_list(artist_list)}" |
|
else: |
|
artist_txt = "" |
|
|
|
character_tag = tag_by_category[TAG_CHARACTER] |
|
if character_tag: |
|
tags = [tag for _, tag, _ in character_tag[:4]] |
|
character_txt = f"named {format_nl_list(tags)}" |
|
else: |
|
character_txt = "" |
|
|
|
species_tag = tag_by_category[TAG_SPECIES] |
|
if species_tag: |
|
species_txt = ( |
|
"of a " if len(character_tag) <= 1 and len(species_tag) <= 1 else "of " |
|
) |
|
species_txt += format_nl_list([tp[1] for tp in species_tag[:4]]) |
|
else: |
|
if character_tag: |
|
species_txt = " a character" if len(character_tag) <= 1 else " characters" |
|
else: |
|
species_txt = "" |
|
|
|
copyright_tag = tag_by_category[TAG_COPYRIGHT] |
|
if copyright_tag: |
|
tags = [tag for _, tag, *_ in copyright_tag[:4]] |
|
copyright_txt = f"from {format_nl_list(tags)}" |
|
else: |
|
copyright_txt = "" |
|
tag_string = ", ".join(tp[1] for tp in other_tags) |
|
custom_prompt = " ".join( |
|
s |
|
for s in [ |
|
"Write a descriptive caption for this image", |
|
artist_txt, |
|
species_txt, |
|
character_txt, |
|
copyright_txt, |
|
"in a formal tone. Limit yourself to two paragraphs, avoid repeating yourself and think before you type anything. Use these tags to construct your caption:", |
|
tag_string, |
|
] |
|
if s |
|
) |
|
return custom_prompt |
|
|
|
|
|
def find_tag_file(image_path): |
|
""" |
|
Find the corresponding .txt file for the given image path. |
|
Handles cases where the image has a -(number) suffix. |
|
""" |
|
base_name = image_path.stem |
|
tag_file = image_path.with_suffix(".txt") |
|
|
|
if tag_file.exists(): |
|
return tag_file |
|
|
|
|
|
match = re.match(r"(.+)-\d+$", base_name) |
|
if match: |
|
base_name = match.group(1) |
|
tag_file = image_path.with_name(base_name).with_suffix(".txt") |
|
if tag_file.exists(): |
|
return tag_file |
|
|
|
return None |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|